This file is a merged representation of the entire codebase, combined into a single document by Repomix.
This section contains a summary of this file.
This file contains a packed representation of the entire repository's contents.
It is designed to be easily consumable by AI systems for analysis, code review,
or other automated processes.
The content is organized as follows:
1. This summary section
2. Repository information
3. Directory structure
4. Repository files, each consisting of:
- File path as an attribute
- Full contents of the file
- This file should be treated as read-only. Any changes should be made to the
original repository files, not this packed version.
- When processing this file, use the file path to distinguish
between different files in the repository.
- Be aware that this file may contain sensitive information. Handle it with
the same level of security as you would the original repository.
- Some files may have been excluded based on .gitignore rules and Repomix's configuration
- Binary files are not included in this packed representation. Please refer to the Repository Structure section for a complete list of file paths, including binary files
- Files matching patterns in .gitignore are excluded
- Files matching default ignore patterns are excluded
- Files are sorted by Git change count (files with more changes are at the bottom)
base/
grouping/
__init__.py
base.py
nb.py
resampling/
__init__.py
base.py
nb.py
__init__.py
accessors.py
chunking.py
combining.py
decorators.py
flex_indexing.py
indexes.py
indexing.py
merging.py
preparing.py
reshaping.py
wrapping.py
data/
custom/
__init__.py
alpaca.py
av.py
bento.py
binance.py
ccxt.py
csv.py
custom.py
db.py
duckdb.py
feather.py
file.py
finpy.py
gbm_ohlc.py
gbm.py
hdf.py
local.py
ndl.py
parquet.py
polygon.py
random_ohlc.py
random.py
remote.py
sql.py
synthetic.py
tv.py
yf.py
__init__.py
base.py
decorators.py
nb.py
saver.py
updater.py
generic/
nb/
__init__.py
apply_reduce.py
base.py
iter_.py
patterns.py
records.py
rolling.py
sim_range.py
splitting/
__init__.py
base.py
decorators.py
nb.py
purged.py
sklearn_.py
__init__.py
accessors.py
analyzable.py
decorators.py
drawdowns.py
enums.py
plots_builder.py
plotting.py
price_records.py
ranges.py
sim_range.py
stats_builder.py
indicators/
custom/
__init__.py
adx.py
atr.py
bbands.py
hurst.py
ma.py
macd.py
msd.py
obv.py
ols.py
patsim.py
pivotinfo.py
rsi.py
sigdet.py
stoch.py
supertrend.py
vwap.py
__init__.py
configs.py
enums.py
expr.py
factory.py
nb.py
talib_.py
labels/
generators/
__init__.py
bolb.py
fixlb.py
fmax.py
fmean.py
fmin.py
fstd.py
meanlb.py
pivotlb.py
trendlb.py
__init__.py
enums.py
nb.py
ohlcv/
__init__.py
accessors.py
enums.py
nb.py
portfolio/
nb/
__init__.py
analysis.py
core.py
ctx_helpers.py
from_order_func.py
from_orders.py
from_signals.py
iter_.py
records.py
pfopt/
__init__.py
base.py
nb.py
records.py
__init__.py
base.py
call_seq.py
chunking.py
decorators.py
enums.py
logs.py
orders.py
preparing.py
trades.py
px/
__init__.py
accessors.py
decorators.py
records/
__init__.py
base.py
chunking.py
col_mapper.py
decorators.py
mapped_array.py
nb.py
registries/
__init__.py
ca_registry.py
ch_registry.py
jit_registry.py
pbar_registry.py
returns/
__init__.py
accessors.py
enums.py
nb.py
qs_adapter.py
signals/
generators/
__init__.py
ohlcstcx.py
ohlcstx.py
rand.py
randnx.py
randx.py
rprob.py
rprobcx.py
rprobnx.py
rprobx.py
stcx.py
stx.py
__init__.py
accessors.py
enums.py
factory.py
nb.py
templates/
dark.json
light.json
seaborn.json
utils/
knowledge/
__init__.py
asset_pipelines.py
base_asset_funcs.py
base_assets.py
chatting.py
custom_asset_funcs.py
custom_assets.py
formatting.py
__init__.py
annotations.py
array_.py
attr_.py
base.py
caching.py
chaining.py
checks.py
chunking.py
colors.py
config.py
cutting.py
datetime_.py
datetime_nb.py
decorators.py
enum_.py
eval_.py
execution.py
figure.py
formatting.py
hashing.py
image_.py
jitting.py
magic_decorators.py
mapping.py
math_.py
merging.py
module_.py
params.py
parsing.py
path_.py
pbar.py
pickling.py
profiling.py
random_.py
requests_.py
schedule_.py
search_.py
selection.py
tagging.py
telegram.py
template.py
warnings_.py
__init__.py
_dtypes.py
_opt_deps.py
_settings.py
_typing.py
_version.py
accessors.py
This section contains the contents of the repository's files.
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with classes and utilities for grouping."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.grouping.base import *
from vectorbtpro.base.grouping.nb import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base classes and functions for grouping.
Class `Grouper` stores metadata related to grouping an index. It can return, for example,
the number of groups, the start indices of groups, and other information useful for reducing
operations that utilize grouping. It also allows dynamically enabling, disabling, and modifying
groups, and checks whether a certain operation is permitted."""
import numpy as np
import pandas as pd
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.resample import Resampler as PandasResampler
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes
from vectorbtpro.base.grouping import nb
from vectorbtpro.base.indexes import ExceptLevel
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.array_ import is_sorted
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.decorators import cached_method
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"Grouper",
]
GroupByT = tp.Union[None, bool, tp.Index]
GrouperT = tp.TypeVar("GrouperT", bound="Grouper")
class Grouper(Configured):
"""Class that exposes methods to group index.
`group_by` can be:
* boolean (False for no grouping, True for one group),
* integer (level by position),
* string (level by name),
* sequence of integers or strings that is shorter than `index` (multiple levels),
    * any other sequence that has the same length as `index` (one group label per index element).
Set `allow_enable` to False to prohibit grouping if `Grouper.group_by` is None.
Set `allow_disable` to False to prohibit disabling of grouping if `Grouper.group_by` is not None.
Set `allow_modify` to False to prohibit modifying groups (you can still change their labels).
    All properties are read-only to enable caching. A usage sketch is provided at the end of this module."""
@classmethod
def group_by_to_index(
cls,
index: tp.Index,
group_by: tp.GroupByLike,
def_lvl_name: tp.Hashable = "group",
) -> GroupByT:
"""Convert mapper `group_by` to `pd.Index`.
!!! note
Index and mapper must have the same length."""
if group_by is None or group_by is False:
return group_by
if isinstance(group_by, CustomTemplate):
group_by = group_by.substitute(context=dict(index=index), strict=True, eval_id="group_by")
if group_by is True:
group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
elif isinstance(index, pd.MultiIndex) or isinstance(group_by, (ExceptLevel, int, str)):
if isinstance(group_by, ExceptLevel):
except_levels = group_by.value
if isinstance(except_levels, (int, str)):
except_levels = [except_levels]
new_group_by = []
for i, name in enumerate(index.names):
if i not in except_levels and name not in except_levels:
new_group_by.append(name)
if len(new_group_by) == 0:
group_by = pd.Index(["group"] * len(index), name=def_lvl_name)
else:
if len(new_group_by) == 1:
new_group_by = new_group_by[0]
group_by = indexes.select_levels(index, new_group_by)
elif isinstance(group_by, (int, str)):
group_by = indexes.select_levels(index, group_by)
elif (
isinstance(group_by, (tuple, list))
and not isinstance(group_by[0], pd.Index)
and len(group_by) <= len(index.names)
):
try:
group_by = indexes.select_levels(index, group_by)
except (IndexError, KeyError):
pass
if not isinstance(group_by, pd.Index):
if isinstance(group_by[0], pd.Index):
group_by = pd.MultiIndex.from_arrays(group_by)
else:
group_by = pd.Index(group_by, name=def_lvl_name)
if len(group_by) != len(index):
raise ValueError("group_by and index must have the same length")
return group_by
@classmethod
def group_by_to_groups_and_index(
cls,
index: tp.Index,
group_by: tp.GroupByLike,
def_lvl_name: tp.Hashable = "group",
) -> tp.Tuple[tp.Array1d, tp.Index]:
"""Return array of group indices pointing to the original index, and grouped index."""
if group_by is None or group_by is False:
return np.arange(len(index)), index
group_by = cls.group_by_to_index(index, group_by, def_lvl_name)
codes, uniques = pd.factorize(group_by)
if not isinstance(uniques, pd.Index):
new_index = pd.Index(uniques)
else:
new_index = uniques
if isinstance(group_by, pd.MultiIndex):
new_index.names = group_by.names
elif isinstance(group_by, (pd.Index, pd.Series)):
new_index.name = group_by.name
return codes, new_index
@classmethod
def iter_group_lens(cls, group_lens: tp.GroupLens) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group in group lengths."""
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in range(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
yield np.arange(from_col, to_col)
@classmethod
def iter_group_map(cls, group_map: tp.GroupMap) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group in a group map."""
group_idxs, group_lens = group_map
group_start = 0
group_end = 0
for group in range(len(group_lens)):
group_len = group_lens[group]
group_end += group_len
yield group_idxs[group_start:group_end]
group_start += group_len
@classmethod
def from_pd_group_by(
cls: tp.Type[GrouperT],
pd_group_by: tp.PandasGroupByLike,
**kwargs,
) -> GrouperT:
"""Build a `Grouper` instance from a pandas `GroupBy` object.
Indices are stored under `index` and group labels under `group_by`."""
from vectorbtpro.base.merging import concat_arrays
if not isinstance(pd_group_by, (PandasGroupBy, PandasResampler)):
raise TypeError("pd_group_by must be an instance of GroupBy or Resampler")
indices = list(pd_group_by.indices.values())
group_lens = np.asarray(list(map(len, indices)))
groups = np.full(int(np.sum(group_lens)), 0, dtype=int_)
group_start_idxs = np.cumsum(group_lens)[1:] - group_lens[1:]
groups[group_start_idxs] = 1
groups = np.cumsum(groups)
index = pd.Index(concat_arrays(indices))
group_by = pd.Index(list(pd_group_by.indices.keys()), name="group")[groups]
return cls(
index=index,
group_by=group_by,
**kwargs,
)
def __init__(
self,
index: tp.Index,
group_by: tp.GroupByLike = None,
def_lvl_name: tp.Hashable = "group",
allow_enable: bool = True,
allow_disable: bool = True,
allow_modify: bool = True,
**kwargs,
) -> None:
if not isinstance(index, pd.Index):
index = pd.Index(index)
if group_by is None or group_by is False:
group_by = None
else:
group_by = self.group_by_to_index(index, group_by, def_lvl_name=def_lvl_name)
self._index = index
self._group_by = group_by
self._def_lvl_name = def_lvl_name
self._allow_enable = allow_enable
self._allow_disable = allow_disable
self._allow_modify = allow_modify
Configured.__init__(
self,
index=index,
group_by=group_by,
def_lvl_name=def_lvl_name,
allow_enable=allow_enable,
allow_disable=allow_disable,
allow_modify=allow_modify,
**kwargs,
)
@property
def index(self) -> tp.Index:
"""Original index."""
return self._index
@property
def group_by(self) -> GroupByT:
"""Mapper for grouping."""
return self._group_by
@property
def def_lvl_name(self) -> tp.Hashable:
"""Default level name."""
return self._def_lvl_name
@property
def allow_enable(self) -> bool:
"""Whether to allow enabling grouping."""
return self._allow_enable
@property
def allow_disable(self) -> bool:
"""Whether to allow disabling grouping."""
return self._allow_disable
@property
def allow_modify(self) -> bool:
"""Whether to allow changing groups."""
return self._allow_modify
def is_grouped(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether index are grouped."""
if group_by is False:
return False
if group_by is None:
group_by = self.group_by
return group_by is not None
def is_grouping_enabled(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been enabled."""
return self.group_by is None and self.is_grouped(group_by=group_by)
def is_grouping_disabled(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been disabled."""
return self.group_by is not None and not self.is_grouped(group_by=group_by)
@cached_method(whitelist=True)
def is_grouping_modified(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been modified.
Doesn't care if grouping labels have been changed."""
if group_by is None or (group_by is False and self.group_by is None):
return False
group_by = self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
if not pd.Index.equals(group_by, self.group_by):
groups1 = self.group_by_to_groups_and_index(
self.index,
group_by,
def_lvl_name=self.def_lvl_name,
)[0]
groups2 = self.group_by_to_groups_and_index(
self.index,
self.group_by,
def_lvl_name=self.def_lvl_name,
)[0]
if not np.array_equal(groups1, groups2):
return True
return False
return True
@cached_method(whitelist=True)
def is_grouping_changed(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether grouping has been changed in any way."""
if group_by is None or (group_by is False and self.group_by is None):
return False
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
if pd.Index.equals(group_by, self.group_by):
return False
return True
def is_group_count_changed(self, group_by: tp.GroupByLike = None) -> bool:
"""Check whether the number of groups has changed."""
if group_by is None or (group_by is False and self.group_by is None):
return False
if isinstance(group_by, pd.Index) and isinstance(self.group_by, pd.Index):
return len(group_by) != len(self.group_by)
return True
def check_group_by(
self,
group_by: tp.GroupByLike = None,
allow_enable: tp.Optional[bool] = None,
allow_disable: tp.Optional[bool] = None,
allow_modify: tp.Optional[bool] = None,
) -> None:
"""Check passed `group_by` object against restrictions."""
if allow_enable is None:
allow_enable = self.allow_enable
if allow_disable is None:
allow_disable = self.allow_disable
if allow_modify is None:
allow_modify = self.allow_modify
if self.is_grouping_enabled(group_by=group_by):
if not allow_enable:
raise ValueError("Enabling grouping is not allowed")
elif self.is_grouping_disabled(group_by=group_by):
if not allow_disable:
raise ValueError("Disabling grouping is not allowed")
elif self.is_grouping_modified(group_by=group_by):
if not allow_modify:
raise ValueError("Modifying groups is not allowed")
def resolve_group_by(self, group_by: tp.GroupByLike = None, **kwargs) -> GroupByT:
"""Resolve `group_by` from either object variable or keyword argument."""
if group_by is None:
group_by = self.group_by
if group_by is False and self.group_by is None:
group_by = None
self.check_group_by(group_by=group_by, **kwargs)
return self.group_by_to_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
@cached_method(whitelist=True)
def get_groups_and_index(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.Tuple[tp.Array1d, tp.Index]:
"""See `Grouper.group_by_to_groups_and_index`."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
return self.group_by_to_groups_and_index(self.index, group_by, def_lvl_name=self.def_lvl_name)
def get_groups(self, **kwargs) -> tp.Array1d:
"""Return groups array."""
return self.get_groups_and_index(**kwargs)[0]
def get_index(self, **kwargs) -> tp.Index:
"""Return grouped index."""
return self.get_groups_and_index(**kwargs)[1]
get_grouped_index = get_index
@property
def grouped_index(self) -> tp.Index:
"""Grouped index."""
return self.get_grouped_index()
def get_stretched_index(self, **kwargs) -> tp.Index:
"""Return stretched index."""
groups, index = self.get_groups_and_index(**kwargs)
return index[groups]
def get_group_count(self, **kwargs) -> int:
"""Get number of groups."""
return len(self.get_index(**kwargs))
@cached_method(whitelist=True)
def is_sorted(self, group_by: tp.GroupByLike = None, **kwargs) -> bool:
"""Return whether groups are monolithic, sorted."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
groups = self.get_groups(group_by=group_by)
return is_sorted(groups)
@cached_method(whitelist=True)
def get_group_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupLens:
"""See `vectorbtpro.base.grouping.nb.get_group_lens_nb`."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
if group_by is None or group_by is False: # no grouping
return np.full(len(self.index), 1)
if not self.is_sorted(group_by=group_by):
raise ValueError("group_by must form monolithic, sorted groups")
groups = self.get_groups(group_by=group_by)
func = jit_reg.resolve_option(nb.get_group_lens_nb, jitted)
return func(groups)
def get_group_start_idxs(self, **kwargs) -> tp.Array1d:
"""Get first index of each group as an array."""
group_lens = self.get_group_lens(**kwargs)
return np.cumsum(group_lens) - group_lens
def get_group_end_idxs(self, **kwargs) -> tp.Array1d:
"""Get end index of each group as an array."""
group_lens = self.get_group_lens(**kwargs)
return np.cumsum(group_lens)
@cached_method(whitelist=True)
def get_group_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None, **kwargs) -> tp.GroupMap:
"""See get_group_map_nb."""
group_by = self.resolve_group_by(group_by=group_by, **kwargs)
if group_by is None or group_by is False: # no grouping
return np.arange(len(self.index)), np.full(len(self.index), 1)
groups, new_index = self.get_groups_and_index(group_by=group_by)
func = jit_reg.resolve_option(nb.get_group_map_nb, jitted)
return func(groups, len(new_index))
def iter_group_idxs(self, **kwargs) -> tp.Iterator[tp.GroupIdxs]:
"""Iterate over indices of each group."""
group_map = self.get_group_map(**kwargs)
return self.iter_group_map(group_map)
def iter_groups(
self,
key_as_index: bool = False,
**kwargs,
) -> tp.Iterator[tp.Tuple[tp.Union[tp.Hashable, pd.Index], tp.GroupIdxs]]:
"""Iterate over groups and their indices."""
index = self.get_index(**kwargs)
for group, group_idxs in enumerate(self.iter_group_idxs(**kwargs)):
if key_as_index:
yield index[[group]], group_idxs
else:
yield index[group], group_idxs
def select_groups(self, group_idxs: tp.Array1d, jitted: tp.JittedOption = None) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Select groups.
Returns indices and new group array. Automatically decides whether to use group lengths or group map."""
from vectorbtpro.base.reshaping import to_1d_array
if self.is_sorted():
func = jit_reg.resolve_option(nb.group_lens_select_nb, jitted)
new_group_idxs, new_groups = func(self.get_group_lens(), to_1d_array(group_idxs)) # faster
else:
func = jit_reg.resolve_option(nb.group_map_select_nb, jitted)
new_group_idxs, new_groups = func(self.get_group_map(), to_1d_array(group_idxs)) # more flexible
return new_group_idxs, new_groups
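# The block below is an illustrative usage sketch added for this packed document; it is not part
# of the original module. It assumes vectorbtpro is installed so that the imports at the top of
# this module resolve, and it only runs when the module is executed as a script.
if __name__ == "__main__":
    columns = pd.MultiIndex.from_tuples(
        [("BTC", "fast"), ("BTC", "slow"), ("ETH", "fast"), ("ETH", "slow")],
        names=["symbol", "strategy"],
    )
    # Group the columns by the "symbol" level
    grouper = Grouper(columns, group_by="symbol")
    groups, grouped_index = grouper.get_groups_and_index()
    print(groups)         # should print [0 0 1 1]
    print(grouped_index)  # should print Index(['BTC', 'ETH'], dtype='object', name='symbol')
    # Group lengths require monolithic, sorted groups
    print(grouper.get_group_lens())  # should print [2 2]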
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for grouping."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = []
GroupByT = tp.Union[None, bool, tp.Index]
@register_jitted(cache=True)
def get_group_lens_nb(groups: tp.Array1d) -> tp.GroupLens:
"""Return the count per group.
!!! note
Columns must form monolithic, sorted groups. For unsorted groups, use `get_group_map_nb`."""
result = np.empty(groups.shape[0], dtype=int_)
j = 0
last_group = -1
group_len = 0
for i in range(groups.shape[0]):
cur_group = groups[i]
if cur_group < last_group:
raise ValueError("Columns must form monolithic, sorted groups")
if cur_group != last_group:
if last_group != -1:
# Process previous group
result[j] = group_len
j += 1
group_len = 0
last_group = cur_group
group_len += 1
if i == groups.shape[0] - 1:
# Process last group
result[j] = group_len
j += 1
group_len = 0
return result[:j]
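# Illustrative sketch (not part of the original module): count elements per group for a
# sorted, monolithic group array. Runs only when this module is executed as a script.
if __name__ == "__main__":
    example_groups = np.array([0, 0, 1, 1, 1, 2])
    print(get_group_lens_nb(example_groups))  # should print [2 3 1]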
@register_jitted(cache=True)
def get_group_map_nb(groups: tp.Array1d, n_groups: int) -> tp.GroupMap:
"""Build the map between groups and indices.
Returns an array with indices segmented by group and an array with group lengths.
Works well for unsorted group arrays."""
group_lens_out = np.full(n_groups, 0, dtype=int_)
for g in range(groups.shape[0]):
group = groups[g]
group_lens_out[group] += 1
group_start_idxs = np.cumsum(group_lens_out) - group_lens_out
group_idxs_out = np.empty((groups.shape[0],), dtype=int_)
group_i = np.full(n_groups, 0, dtype=int_)
for g in range(groups.shape[0]):
group = groups[g]
group_idxs_out[group_start_idxs[group] + group_i[group]] = g
group_i[group] += 1
return group_idxs_out, group_lens_out
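# Illustrative sketch (not part of the original module): build a group map for an unsorted
# group array. Runs only when this module is executed as a script.
if __name__ == "__main__":
    example_groups = np.array([1, 0, 1, 0])
    group_idxs, group_lens = get_group_map_nb(example_groups, 2)
    print(group_idxs)  # should print [1 3 0 2] -- indices of group 0 first, then group 1
    print(group_lens)  # should print [2 2]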
@register_jitted(cache=True)
def group_lens_select_nb(group_lens: tp.GroupLens, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Perform indexing on a sorted array using group lengths.
Returns indices of elements corresponding to groups in `new_groups` and a new group array."""
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
n_values = np.sum(group_lens[new_groups])
indices_out = np.empty(n_values, dtype=int_)
group_arr_out = np.empty(n_values, dtype=int_)
j = 0
for c in range(new_groups.shape[0]):
from_r = group_start_idxs[new_groups[c]]
to_r = group_end_idxs[new_groups[c]]
if from_r == to_r:
continue
rang = np.arange(from_r, to_r)
indices_out[j : j + rang.shape[0]] = rang
group_arr_out[j : j + rang.shape[0]] = c
j += rang.shape[0]
return indices_out, group_arr_out
@register_jitted(cache=True)
def group_map_select_nb(group_map: tp.GroupMap, new_groups: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Perform indexing using group map."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
total_count = np.sum(group_lens[new_groups])
indices_out = np.empty(total_count, dtype=int_)
group_arr_out = np.empty(total_count, dtype=int_)
j = 0
for new_group_i in range(len(new_groups)):
new_group = new_groups[new_group_i]
group_len = group_lens[new_group]
if group_len == 0:
continue
group_start_idx = group_start_idxs[new_group]
idxs = group_idxs[group_start_idx : group_start_idx + group_len]
indices_out[j : j + group_len] = idxs
group_arr_out[j : j + group_len] = new_group_i
j += group_len
return indices_out, group_arr_out
@register_jitted(cache=True)
def group_by_evenly_nb(n: int, n_splits: int) -> tp.Array1d:
"""Get `group_by` from evenly splitting a space of values."""
out = np.empty(n, dtype=int_)
for i in range(n):
out[i] = i * n_splits // n + n_splits // (2 * n)
return out
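# Illustrative sketch (not part of the original module): split 6 positions evenly into
# 3 groups. Runs only when this module is executed as a script.
if __name__ == "__main__":
    print(group_by_evenly_nb(6, 3))  # should print [0 0 1 1 2 2]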
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with classes and utilities for resampling."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.resampling.base import *
from vectorbtpro.base.resampling.nb import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base classes and functions for resampling."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.resampling import nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.decorators import cached_property, hybrid_method
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"Resampler",
]
ResamplerT = tp.TypeVar("ResamplerT", bound="Resampler")
class Resampler(Configured):
"""Class that exposes methods to resample index.
Args:
source_index (index_like): Index being resampled.
        target_index (index_like): Index resulting from resampling.
source_freq (frequency_like or bool): Frequency or date offset of the source index.
Set to False to force-set the frequency to None.
target_freq (frequency_like or bool): Frequency or date offset of the target index.
Set to False to force-set the frequency to None.
        silence_warnings (bool): Whether to silence all warnings.
    A usage sketch is provided at the end of this module."""
def __init__(
self,
source_index: tp.IndexLike,
target_index: tp.IndexLike,
source_freq: tp.Union[None, bool, tp.FrequencyLike] = None,
target_freq: tp.Union[None, bool, tp.FrequencyLike] = None,
silence_warnings: tp.Optional[bool] = None,
**kwargs,
) -> None:
source_index = dt.prepare_dt_index(source_index)
target_index = dt.prepare_dt_index(target_index)
infer_source_freq = True
if isinstance(source_freq, bool):
if not source_freq:
infer_source_freq = False
source_freq = None
infer_target_freq = True
if isinstance(target_freq, bool):
if not target_freq:
infer_target_freq = False
target_freq = None
if infer_source_freq:
source_freq = dt.infer_index_freq(source_index, freq=source_freq)
if infer_target_freq:
target_freq = dt.infer_index_freq(target_index, freq=target_freq)
self._source_index = source_index
self._target_index = target_index
self._source_freq = source_freq
self._target_freq = target_freq
self._silence_warnings = silence_warnings
Configured.__init__(
self,
source_index=source_index,
target_index=target_index,
source_freq=source_freq,
target_freq=target_freq,
silence_warnings=silence_warnings,
**kwargs,
)
@classmethod
def from_pd_resampler(
cls: tp.Type[ResamplerT],
pd_resampler: tp.PandasResampler,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: bool = True,
) -> ResamplerT:
"""Build `Resampler` from
[pandas.core.resample.Resampler](https://pandas.pydata.org/docs/reference/resampling.html).
"""
target_index = pd_resampler.count().index
return cls(
source_index=pd_resampler.obj.index,
target_index=target_index,
source_freq=source_freq,
target_freq=None,
silence_warnings=silence_warnings,
)
@classmethod
def from_pd_resample(
cls: tp.Type[ResamplerT],
source_index: tp.IndexLike,
*args,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: bool = True,
**kwargs,
) -> ResamplerT:
"""Build `Resampler` from
[pandas.DataFrame.resample](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.resample.html).
"""
pd_resampler = pd.Series(index=source_index, dtype=object).resample(*args, **kwargs)
return cls.from_pd_resampler(pd_resampler, source_freq=source_freq, silence_warnings=silence_warnings)
@classmethod
def from_date_range(
cls: tp.Type[ResamplerT],
source_index: tp.IndexLike,
*args,
source_freq: tp.Optional[tp.FrequencyLike] = None,
silence_warnings: tp.Optional[bool] = None,
**kwargs,
) -> ResamplerT:
"""Build `Resampler` from `vectorbtpro.utils.datetime_.date_range`."""
target_index = dt.date_range(*args, **kwargs)
return cls(
source_index=source_index,
target_index=target_index,
source_freq=source_freq,
target_freq=None,
silence_warnings=silence_warnings,
)
@property
def source_index(self) -> tp.Index:
"""Index being resampled."""
return self._source_index
@property
def target_index(self) -> tp.Index:
"""Index resulted from resampling."""
return self._target_index
@property
def source_freq(self) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the source index."""
return self._source_freq
@property
def target_freq(self) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the target index."""
return self._target_freq
@property
def silence_warnings(self) -> bool:
"""Frequency or date offset of the target index."""
from vectorbtpro._settings import settings
resampling_cfg = settings["resampling"]
silence_warnings = self._silence_warnings
if silence_warnings is None:
silence_warnings = resampling_cfg["silence_warnings"]
return silence_warnings
def get_np_source_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the source index in NumPy format."""
if silence_warnings is None:
silence_warnings = self.silence_warnings
warned = False
source_freq = self.source_freq
if source_freq is not None:
if not isinstance(source_freq, (int, float)):
try:
source_freq = dt.to_timedelta64(source_freq)
except ValueError as e:
if not silence_warnings:
warn(f"Cannot convert {source_freq} to np.timedelta64. Setting to None.")
warned = True
source_freq = None
if source_freq is None:
if not warned and not silence_warnings:
warn("Using right bound of source index without frequency. Set source frequency.")
return source_freq
def get_np_target_freq(self, silence_warnings: tp.Optional[bool] = None) -> tp.AnyPandasFrequency:
"""Frequency or date offset of the target index in NumPy format."""
if silence_warnings is None:
silence_warnings = self.silence_warnings
warned = False
target_freq = self.target_freq
if target_freq is not None:
if not isinstance(target_freq, (int, float)):
try:
target_freq = dt.to_timedelta64(target_freq)
except ValueError as e:
if not silence_warnings:
warn(f"Cannot convert {target_freq} to np.timedelta64. Setting to None.")
warned = True
target_freq = None
if target_freq is None:
if not warned and not silence_warnings:
warn("Using right bound of target index without frequency. Set target frequency.")
return target_freq
@classmethod
def get_lbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index:
"""Get the left bound of a datetime index.
If `freq` is None, calculates the leftmost bound."""
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if freq is not None:
return index.shift(-1, freq=freq) + pd.Timedelta(1, "ns")
min_ts = pd.DatetimeIndex([pd.Timestamp.min.tz_localize(index.tz)])
        return min_ts.append(index[:-1] + pd.Timedelta(1, "ns"))
@classmethod
def get_rbound_index(cls, index: tp.Index, freq: tp.AnyPandasFrequency = None) -> tp.Index:
"""Get the right bound of a datetime index.
If `freq` is None, calculates the rightmost bound."""
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if freq is not None:
return index.shift(1, freq=freq) - pd.Timedelta(1, "ns")
max_ts = pd.DatetimeIndex([pd.Timestamp.max.tz_localize(index.tz)])
return (index[1:] - pd.Timedelta(1, "ns")).append(max_ts)
@cached_property
def source_lbound_index(self) -> tp.Index:
"""Get the left bound of the source datetime index."""
return self.get_lbound_index(self.source_index, freq=self.source_freq)
@cached_property
def source_rbound_index(self) -> tp.Index:
"""Get the right bound of the source datetime index."""
return self.get_rbound_index(self.source_index, freq=self.source_freq)
@cached_property
def target_lbound_index(self) -> tp.Index:
"""Get the left bound of the target datetime index."""
return self.get_lbound_index(self.target_index, freq=self.target_freq)
@cached_property
def target_rbound_index(self) -> tp.Index:
"""Get the right bound of the target datetime index."""
return self.get_rbound_index(self.target_index, freq=self.target_freq)
def map_to_target_index(
self,
before: bool = False,
raise_missing: bool = True,
return_index: bool = True,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[tp.Array1d, tp.Index]:
"""See `vectorbtpro.base.resampling.nb.map_to_target_index_nb`."""
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.map_to_target_index_nb, jitted)
mapped_arr = func(
self.source_index.values,
self.target_index.values,
target_freq=target_freq,
before=before,
raise_missing=raise_missing,
)
if return_index:
nan_mask = mapped_arr == -1
if nan_mask.any():
mapped_index = self.source_index.to_series().copy()
mapped_index[nan_mask] = np.nan
                mapped_index[~nan_mask] = self.target_index[mapped_arr[~nan_mask]]
mapped_index = pd.Index(mapped_index)
else:
mapped_index = self.target_index[mapped_arr]
return mapped_index
return mapped_arr
def index_difference(
self,
reverse: bool = False,
return_index: bool = True,
jitted: tp.JittedOption = None,
) -> tp.Union[tp.Array1d, tp.Index]:
"""See `vectorbtpro.base.resampling.nb.index_difference_nb`."""
func = jit_reg.resolve_option(nb.index_difference_nb, jitted)
if reverse:
mapped_arr = func(self.target_index.values, self.source_index.values)
else:
mapped_arr = func(self.source_index.values, self.target_index.values)
if return_index:
return self.target_index[mapped_arr]
return mapped_arr
def map_index_to_source_ranges(
self,
before: bool = False,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.resampling.nb.map_index_to_source_ranges_nb`.
        If `Resampler.target_freq` is a date offset, sets it to None and gives a warning.
        Raises another warning if `target_freq` is None.
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.map_index_to_source_ranges_nb, jitted)
return func(
self.source_index.values,
self.target_index.values,
target_freq=target_freq,
before=before,
)
@hybrid_method
def map_bounds_to_source_ranges(
cls_or_self,
source_index: tp.Optional[tp.IndexLike] = None,
target_lbound_index: tp.Optional[tp.IndexLike] = None,
target_rbound_index: tp.Optional[tp.IndexLike] = None,
closed_lbound: bool = True,
closed_rbound: bool = False,
skip_not_found: bool = False,
jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.resampling.nb.map_bounds_to_source_ranges_nb`.
Either `target_lbound_index` or `target_rbound_index` must be set.
Set `target_lbound_index` and `target_rbound_index` to 'pandas' to use
`Resampler.get_lbound_index` and `Resampler.get_rbound_index` respectively.
Also, both allow providing a single datetime string and will automatically broadcast
to the `Resampler.target_index`."""
if not isinstance(cls_or_self, type):
if target_lbound_index is None and target_rbound_index is None:
raise ValueError("Either target_lbound_index or target_rbound_index must be set")
if target_lbound_index is not None:
if isinstance(target_lbound_index, str) and target_lbound_index.lower() == "pandas":
target_lbound_index = cls_or_self.target_lbound_index
else:
target_lbound_index = dt.prepare_dt_index(target_lbound_index)
target_rbound_index = cls_or_self.target_index
            elif target_rbound_index is not None:
target_lbound_index = cls_or_self.target_index
if isinstance(target_rbound_index, str) and target_rbound_index.lower() == "pandas":
target_rbound_index = cls_or_self.target_rbound_index
else:
target_rbound_index = dt.prepare_dt_index(target_rbound_index)
if len(target_lbound_index) == 1 and len(target_rbound_index) > 1:
target_lbound_index = repeat_index(target_lbound_index, len(target_rbound_index))
elif len(target_lbound_index) > 1 and len(target_rbound_index) == 1:
target_rbound_index = repeat_index(target_rbound_index, len(target_lbound_index))
else:
source_index = dt.prepare_dt_index(source_index)
target_lbound_index = dt.prepare_dt_index(target_lbound_index)
target_rbound_index = dt.prepare_dt_index(target_rbound_index)
checks.assert_len_equal(target_rbound_index, target_lbound_index)
func = jit_reg.resolve_option(nb.map_bounds_to_source_ranges_nb, jitted)
return func(
source_index.values,
target_lbound_index.values,
target_rbound_index.values,
closed_lbound=closed_lbound,
closed_rbound=closed_rbound,
skip_not_found=skip_not_found,
)
def resample_source_mask(
self,
source_mask: tp.ArrayLike,
jitted: tp.JittedOption = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Array1d:
"""See `vectorbtpro.base.resampling.nb.resample_source_mask_nb`."""
from vectorbtpro.base.reshaping import broadcast_array_to
if silence_warnings is None:
silence_warnings = self.silence_warnings
source_mask = broadcast_array_to(source_mask, len(self.source_index))
source_freq = self.get_np_source_freq(silence_warnings=silence_warnings)
target_freq = self.get_np_target_freq(silence_warnings=silence_warnings)
func = jit_reg.resolve_option(nb.resample_source_mask_nb, jitted)
return func(
source_mask,
self.source_index.values,
self.target_index.values,
source_freq,
target_freq,
)
def last_before_target_index(self, incl_source: bool = True, jitted: tp.JittedOption = None) -> tp.Array1d:
"""See `vectorbtpro.base.resampling.nb.last_before_target_index_nb`."""
func = jit_reg.resolve_option(nb.last_before_target_index_nb, jitted)
return func(self.source_index.values, self.target_index.values, incl_source=incl_source)
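# The block below is an illustrative usage sketch added for this packed document; it is not part
# of the original module. It assumes vectorbtpro is installed so that the imports at the top of
# this module resolve, and it only runs when the module is executed as a script.
if __name__ == "__main__":
    source_index = pd.date_range("2024-01-01", periods=6, freq="D")
    target_index = pd.date_range("2024-01-01", periods=3, freq="2D")
    resampler = Resampler(source_index=source_index, target_index=target_index)
    # Frequencies are inferred from the indexes when not provided explicitly
    print(resampler.source_freq, resampler.target_freq)
    # Map each daily timestamp to the 2-day target bar that covers it;
    # days 1-2 should map to Jan 1, days 3-4 to Jan 3, and days 5-6 to Jan 5
    print(resampler.map_to_target_index(before=False))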
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for resampling."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils.datetime_nb import d_td
__all__ = []
@register_jitted(cache=True)
def date_range_nb(
start: np.datetime64,
end: np.datetime64,
freq: np.timedelta64 = d_td,
incl_left: bool = True,
incl_right: bool = True,
) -> tp.Array1d:
"""Generate a datetime index with nanosecond precision from a date range.
Inspired by [pandas.date_range](https://pandas.pydata.org/docs/reference/api/pandas.date_range.html)."""
values_len = int(np.floor((end - start) / freq)) + 1
values = np.empty(values_len, dtype="datetime64[ns]")
for i in range(values_len):
values[i] = start + i * freq
if start == end:
if not incl_left and not incl_right:
values = values[1:-1]
else:
if not incl_left or not incl_right:
if not incl_left and len(values) and values[0] == start:
values = values[1:]
if not incl_right and len(values) and values[-1] == end:
values = values[:-1]
return values
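# Illustrative sketch (not part of the original module): generate four daily timestamps,
# including both bounds, using an explicit nanosecond-based frequency. Runs only when this
# module is executed as a script.
if __name__ == "__main__":
    start = np.datetime64("2024-01-01", "ns")
    end = np.datetime64("2024-01-04", "ns")
    one_day = np.timedelta64(24 * 60 * 60 * 1_000_000_000, "ns")  # one day in nanoseconds
    print(date_range_nb(start, end, freq=one_day))  # should print 2024-01-01 through 2024-01-04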
@register_jitted(cache=True)
def map_to_target_index_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
target_freq: tp.Optional[tp.Scalar] = None,
before: bool = False,
raise_missing: bool = True,
) -> tp.Array1d:
"""Get the index of each from `source_index` in `target_index`.
If `before` is True, applied on elements that come before and including that index.
Otherwise, applied on elements that come after and including that index.
If `raise_missing` is True, will throw an error if an index cannot be mapped.
Otherwise, the element for that index becomes -1."""
out = np.empty(len(source_index), dtype=int_)
from_j = 0
for i in range(len(source_index)):
if i > 0 and source_index[i] < source_index[i - 1]:
raise ValueError("Source index must be increasing")
if i > 0 and source_index[i] == source_index[i - 1]:
out[i] = out[i - 1]
found = False
for j in range(from_j, len(target_index)):
if j > 0 and target_index[j] <= target_index[j - 1]:
raise ValueError("Target index must be strictly increasing")
if target_freq is None:
if before and source_index[i] <= target_index[j]:
if j == 0 or target_index[j - 1] < source_index[i]:
out[i] = from_j = j
found = True
break
if not before and target_index[j] <= source_index[i]:
if j == len(target_index) - 1 or source_index[i] < target_index[j + 1]:
out[i] = from_j = j
found = True
break
else:
if before and target_index[j] - target_freq < source_index[i] <= target_index[j]:
out[i] = from_j = j
found = True
break
if not before and target_index[j] <= source_index[i] < target_index[j] + target_freq:
out[i] = from_j = j
found = True
break
if not found:
if raise_missing:
raise ValueError("Resampling failed: cannot map some source indices")
out[i] = -1
return out
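# Illustrative sketch (not part of the original module): plain integers stand in for
# nanosecond timestamps. With before=True, each source element is mapped to the first
# target element at or after it. Runs only when this module is executed as a script.
if __name__ == "__main__":
    example_source = np.array([1, 2, 3, 4, 5])
    example_target = np.array([2, 4, 6])
    print(map_to_target_index_nb(example_source, example_target, before=True))
    # should print [0 0 1 1 2]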
@register_jitted(cache=True)
def index_difference_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
) -> tp.Array1d:
"""Get the elements in `source_index` not present in `target_index`."""
out = np.empty(len(source_index), dtype=int_)
from_j = 0
k = 0
for i in range(len(source_index)):
if i > 0 and source_index[i] <= source_index[i - 1]:
raise ValueError("Array index must be strictly increasing")
found = False
for j in range(from_j, len(target_index)):
if j > 0 and target_index[j] <= target_index[j - 1]:
raise ValueError("Target index must be strictly increasing")
if source_index[i] < target_index[j]:
break
if source_index[i] == target_index[j]:
from_j = j
found = True
break
from_j = j
if not found:
out[k] = i
k += 1
return out[:k]
@register_jitted(cache=True)
def map_index_to_source_ranges_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
target_freq: tp.Optional[tp.Scalar] = None,
before: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Get the source bounds that correspond to each target index.
If `target_freq` is not None, the right bound is limited by the frequency in `target_freq`.
Otherwise, the right bound is the next index in `target_index`.
    Returns two arrays: for each target element, the first array contains the absolute start
    index (inclusive) and the second array contains the absolute end index (exclusive) of the source range.
    If an element cannot be mapped, both the start and the end of its range become -1.
!!! note
Both index arrays must be increasing. Repeating values are allowed."""
range_starts_out = np.empty(len(target_index), dtype=int_)
range_ends_out = np.empty(len(target_index), dtype=int_)
to_j = 0
for i in range(len(target_index)):
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
from_j = -1
for j in range(to_j, len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Array index must be increasing")
found = False
if target_freq is None:
if before:
if i == 0 and source_index[j] <= target_index[i]:
found = True
elif i > 0 and target_index[i - 1] < source_index[j] <= target_index[i]:
found = True
elif source_index[j] > target_index[i]:
break
else:
if i == len(target_index) - 1 and target_index[i] <= source_index[j]:
found = True
elif i < len(target_index) - 1 and target_index[i] <= source_index[j] < target_index[i + 1]:
found = True
elif i < len(target_index) - 1 and source_index[j] >= target_index[i + 1]:
break
else:
if before:
if target_index[i] - target_freq < source_index[j] <= target_index[i]:
found = True
elif source_index[j] > target_index[i]:
break
else:
if target_index[i] <= source_index[j] < target_index[i] + target_freq:
found = True
elif source_index[j] >= target_index[i] + target_freq:
break
if found:
if from_j == -1:
from_j = j
to_j = j + 1
if from_j == -1:
range_starts_out[i] = -1
range_ends_out[i] = -1
else:
range_starts_out[i] = from_j
range_ends_out[i] = to_j
return range_starts_out, range_ends_out
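# Illustrative sketch (not part of the original module): map two target "bars" to the ranges
# of source elements they cover; plain integers stand in for timestamps. Runs only when this
# module is executed as a script.
if __name__ == "__main__":
    example_source = np.array([1, 2, 3, 4, 5, 6])
    example_target = np.array([1, 4])
    range_starts, range_ends = map_index_to_source_ranges_nb(example_source, example_target)
    print(range_starts, range_ends)  # should print [0 3] [3 6]
    # i.e. target 1 covers source[0:3] (1, 2, 3) and target 4 covers source[3:6] (4, 5, 6)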
@register_jitted(cache=True)
def map_bounds_to_source_ranges_nb(
source_index: tp.Array1d,
target_lbound_index: tp.Array1d,
target_rbound_index: tp.Array1d,
closed_lbound: bool = True,
closed_rbound: bool = False,
skip_not_found: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Get the source bounds that correspond to the target bounds.
    Returns two arrays: for each pair of target bounds, the first array contains the absolute start
    index (inclusive) and the second array contains the absolute end index (exclusive) of the source range.
    If an element cannot be mapped, both the start and the end of its range become -1.
!!! note
Both index arrays must be increasing. Repeating values are allowed."""
range_starts_out = np.empty(len(target_lbound_index), dtype=int_)
range_ends_out = np.empty(len(target_lbound_index), dtype=int_)
k = 0
to_j = 0
for i in range(len(target_lbound_index)):
if i > 0 and target_lbound_index[i] < target_lbound_index[i - 1]:
raise ValueError("Target left-bound index must be increasing")
if i > 0 and target_rbound_index[i] < target_rbound_index[i - 1]:
raise ValueError("Target right-bound index must be increasing")
from_j = -1
for j in range(len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Array index must be increasing")
found = False
if closed_lbound and closed_rbound:
if target_lbound_index[i] <= source_index[j] <= target_rbound_index[i]:
found = True
elif source_index[j] > target_rbound_index[i]:
break
elif closed_lbound:
if target_lbound_index[i] <= source_index[j] < target_rbound_index[i]:
found = True
elif source_index[j] >= target_rbound_index[i]:
break
elif closed_rbound:
if target_lbound_index[i] < source_index[j] <= target_rbound_index[i]:
found = True
elif source_index[j] > target_rbound_index[i]:
break
else:
if target_lbound_index[i] < source_index[j] < target_rbound_index[i]:
found = True
elif source_index[j] >= target_rbound_index[i]:
break
if found:
if from_j == -1:
from_j = j
to_j = j + 1
if skip_not_found:
if from_j != -1:
range_starts_out[k] = from_j
range_ends_out[k] = to_j
k += 1
else:
if from_j == -1:
range_starts_out[i] = -1
range_ends_out[i] = -1
else:
range_starts_out[i] = from_j
range_ends_out[i] = to_j
if skip_not_found:
return range_starts_out[:k], range_ends_out[:k]
return range_starts_out, range_ends_out
@register_jitted(cache=True)
def resample_source_mask_nb(
source_mask: tp.Array1d,
source_index: tp.Array1d,
target_index: tp.Array1d,
source_freq: tp.Optional[tp.Scalar] = None,
target_freq: tp.Optional[tp.Scalar] = None,
) -> tp.Array1d:
"""Resample a source mask to the target index.
    A target bar becomes True only if it is fully contained in a source bar, where a source bar
    is represented by an uninterrupted sequence of True values in the source mask."""
out = np.full(len(target_index), False, dtype=np.bool_)
from_j = 0
for i in range(len(target_index)):
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
target_lbound = target_index[i]
if target_freq is None:
if i + 1 < len(target_index):
target_rbound = target_index[i + 1]
else:
target_rbound = None
else:
target_rbound = target_index[i] + target_freq
found_start = False
for j in range(from_j, len(source_index)):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Source index must be increasing")
source_lbound = source_index[j]
if source_freq is None:
if j + 1 < len(source_index):
source_rbound = source_index[j + 1]
else:
source_rbound = None
else:
source_rbound = source_index[j] + source_freq
if target_rbound is not None and target_rbound <= source_lbound:
break
if found_start or (
target_lbound >= source_lbound and (source_rbound is None or target_lbound < source_rbound)
):
if not found_start:
from_j = j
found_start = True
if not source_mask[j]:
break
if source_rbound is None or (target_rbound is not None and target_rbound <= source_rbound):
out[i] = True
break
return out
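# Illustrative sketch (not part of the original module): integer "timestamps" with explicit
# frequencies. The first 4-unit-wide target bar is fully covered by True source bars, the
# second one is not. Runs only when this module is executed as a script.
if __name__ == "__main__":
    example_mask = np.array([True, True, False, True])
    example_source = np.array([0, 2, 4, 6])
    example_target = np.array([0, 4])
    print(resample_source_mask_nb(example_mask, example_source, example_target, 2, 4))
    # should print [ True False]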
@register_jitted(cache=True)
def last_before_target_index_nb(
source_index: tp.Array1d,
target_index: tp.Array1d,
incl_source: bool = True,
incl_target: bool = False,
) -> tp.Array1d:
"""For each source index, find the position of the last source index between the original
source index and the corresponding target index."""
out = np.empty(len(source_index), dtype=int_)
last_j = -1
for i in range(len(source_index)):
if i > 0 and source_index[i] < source_index[i - 1]:
raise ValueError("Source index must be increasing")
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
if source_index[i] > target_index[i]:
raise ValueError("Target index must be equal to or greater than source index")
if last_j == -1:
from_i = i + 1
else:
from_i = last_j
if incl_source:
last_j = i
else:
last_j = -1
for j in range(from_i, len(source_index)):
if source_index[j] < target_index[i]:
last_j = j
elif incl_target and source_index[j] == target_index[i]:
last_j = j
else:
break
out[i] = last_j
return out
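# Illustrative sketch (not part of the original module): for each source element, find the
# position of the last source element that still lies before the corresponding target element;
# plain integers stand in for timestamps. Runs only when this module is executed as a script.
if __name__ == "__main__":
    example_source = np.array([0, 1, 2, 3])
    example_target = np.array([2, 3, 4, 4])
    print(last_before_target_index_nb(example_source, example_target))
    # should print [1 2 3 3]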
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with base classes and utilities for pandas objects, such as broadcasting."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.base.grouping import *
from vectorbtpro.base.resampling import *
from vectorbtpro.base.accessors import *
from vectorbtpro.base.chunking import *
from vectorbtpro.base.combining import *
from vectorbtpro.base.decorators import *
from vectorbtpro.base.flex_indexing import *
from vectorbtpro.base.indexes import *
from vectorbtpro.base.indexing import *
from vectorbtpro.base.merging import *
from vectorbtpro.base.preparing import *
from vectorbtpro.base.reshaping import *
from vectorbtpro.base.wrapping import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for base operations with Pandas objects."""
import ast
import inspect
import numpy as np
import pandas as pd
from pandas.api.types import is_scalar
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.resample import Resampler as PandasResampler
from vectorbtpro import _typing as tp
from vectorbtpro.base import combining, reshaping, indexes
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.indexes import IndexApplier
from vectorbtpro.base.indexing import (
point_idxr_defaults,
range_idxr_defaults,
get_index_points,
get_index_ranges,
)
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer
from vectorbtpro.utils.config import merge_dicts, resolve_dict, Configured
from vectorbtpro.utils.decorators import hybrid_property, hybrid_method
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.eval_ import evaluate
from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods
from vectorbtpro.utils.parsing import get_context_vars, get_func_arg_names
from vectorbtpro.utils.template import substitute_templates
from vectorbtpro.utils.warnings_ import warn
if tp.TYPE_CHECKING:
from vectorbtpro.data.base import Data as DataT
else:
DataT = "Data"
if tp.TYPE_CHECKING:
from vectorbtpro.generic.splitting.base import Splitter as SplitterT
else:
SplitterT = "Splitter"
__all__ = ["BaseIDXAccessor", "BaseAccessor", "BaseSRAccessor", "BaseDFAccessor"]
BaseIDXAccessorT = tp.TypeVar("BaseIDXAccessorT", bound="BaseIDXAccessor")
class BaseIDXAccessor(Configured, IndexApplier):
"""Accessor on top of Index.
Accessible via `pd.Index.vbt` and all child accessors."""
def __init__(self, obj: tp.Index, freq: tp.Optional[tp.FrequencyLike] = None, **kwargs) -> None:
checks.assert_instance_of(obj, pd.Index)
Configured.__init__(self, obj=obj, freq=freq, **kwargs)
self._obj = obj
self._freq = freq
@property
def obj(self) -> tp.Index:
"""Pandas object."""
return self._obj
def get(self) -> tp.Index:
"""Get `IDXAccessor.obj`."""
return self.obj
# ############# Index ############# #
def to_ns(self) -> tp.Array1d:
"""Convert index to an 64-bit integer array.
Timestamps will be converted to nanoseconds."""
return dt.to_ns(self.obj)
def to_period(self, freq: tp.FrequencyLike, shift: bool = False) -> pd.PeriodIndex:
"""Convert index to period."""
index = self.obj
if isinstance(index, pd.DatetimeIndex):
index = index.tz_localize(None).to_period(freq)
if shift:
index = index.shift()
if not isinstance(index, pd.PeriodIndex):
raise TypeError(f"Cannot convert index of type {type(index)} to period")
return index
def to_period_ts(self, *args, **kwargs) -> pd.DatetimeIndex:
"""Convert index to period and then to timestamp."""
new_index = self.to_period(*args, **kwargs).to_timestamp()
if self.obj.tz is not None:
new_index = new_index.tz_localize(self.obj.tz)
return new_index
def to_period_ns(self, *args, **kwargs) -> tp.Array1d:
"""Convert index to period and then to an 64-bit integer array.
Timestamps will be converted to nanoseconds."""
return dt.to_ns(self.to_period_ts(*args, **kwargs))
@classmethod
def from_values(cls, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.index_from_values`."""
return indexes.index_from_values(*args, **kwargs)
def repeat(self, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.repeat_index`."""
return indexes.repeat_index(self.obj, *args, **kwargs)
def tile(self, *args, **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.tile_index`."""
return indexes.tile_index(self.obj, *args, **kwargs)
@hybrid_method
def stack(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
on_top: bool = False,
**kwargs,
) -> tp.Index:
"""See `vectorbtpro.base.indexes.stack_indexes`.
        Set `on_top` to True to stack the passed indexes on top of this one."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
if on_top:
objs = (*others, cls_or_self.obj)
else:
objs = (cls_or_self.obj, *others)
return indexes.stack_indexes(*objs, **kwargs)
@hybrid_method
def combine(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
on_top: bool = False,
**kwargs,
) -> tp.Index:
"""See `vectorbtpro.base.indexes.combine_indexes`.
        Set `on_top` to True to stack the passed indexes on top of this one."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
if on_top:
objs = (*others, cls_or_self.obj)
else:
objs = (cls_or_self.obj, *others)
return indexes.combine_indexes(*objs, **kwargs)
@hybrid_method
def concat(cls_or_self, *others: tp.Union[tp.IndexLike, "BaseIDXAccessor"], **kwargs) -> tp.Index:
"""See `vectorbtpro.base.indexes.concat_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.concat_indexes(*objs, **kwargs)
def apply_to_index(
self: BaseIDXAccessorT,
apply_func: tp.Callable,
*args,
**kwargs,
) -> tp.Index:
return self.replace(obj=apply_func(self.obj, *args, **kwargs)).obj
def align_to(self, *args, **kwargs) -> tp.IndexSlice:
"""See `vectorbtpro.base.indexes.align_index_to`."""
return indexes.align_index_to(self.obj, *args, **kwargs)
@hybrid_method
def align(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
**kwargs,
) -> tp.Tuple[tp.IndexSlice, ...]:
"""See `vectorbtpro.base.indexes.align_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.align_indexes(*objs, **kwargs)
def cross_with(self, *args, **kwargs) -> tp.Tuple[tp.IndexSlice, tp.IndexSlice]:
"""See `vectorbtpro.base.indexes.cross_index_with`."""
return indexes.cross_index_with(self.obj, *args, **kwargs)
@hybrid_method
def cross(
cls_or_self,
*others: tp.Union[tp.IndexLike, "BaseIDXAccessor"],
**kwargs,
) -> tp.Tuple[tp.IndexSlice, ...]:
"""See `vectorbtpro.base.indexes.cross_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseIDXAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return indexes.cross_indexes(*objs, **kwargs)
x = cross
def find_first_occurrence(self, *args, **kwargs) -> int:
"""See `vectorbtpro.base.indexes.find_first_occurrence`."""
return indexes.find_first_occurrence(self.obj, *args, **kwargs)
# ############# Frequency ############# #
@hybrid_method
def get_freq(
cls_or_self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> tp.Union[None, float, tp.PandasFrequency]:
"""Index frequency as `pd.Timedelta` or None if it cannot be converted."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
if freq is None:
freq = cls_or_self._freq
else:
checks.assert_not_none(index, arg_name="index")
if freq is None:
freq = wrapping_cfg["freq"]
try:
return dt.infer_index_freq(index, freq=freq, **kwargs)
except Exception as e:
return None
@property
def freq(self) -> tp.Optional[tp.PandasFrequency]:
"""`BaseIDXAccessor.get_freq` with date offsets and integer frequencies not allowed."""
return self.get_freq(allow_offset=True, allow_numeric=False)
@property
def ns_freq(self) -> tp.Optional[int]:
"""Convert frequency to a 64-bit integer.
Timedelta will be converted to nanoseconds."""
freq = self.get_freq(allow_offset=False, allow_numeric=True)
if freq is not None:
freq = dt.to_ns(dt.to_timedelta64(freq))
return freq
@property
def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]:
"""Index frequency of any type."""
return self.get_freq()
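# A minimal sketch of frequency inference (assuming a regular daily index):
# >>> index = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> index.vbt.ns_freq  # one day expressed in nanoseconds
# 86400000000000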
@hybrid_method
def get_periods(cls_or_self, index: tp.Optional[tp.Index] = None) -> int:
"""Get the number of periods in the index, without taking into account its datetime-like properties."""
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
else:
checks.assert_not_none(index, arg_name="index")
return len(index)
@property
def periods(self) -> int:
"""`BaseIDXAccessor.get_periods` with default arguments."""
return len(self.obj)
@hybrid_method
def get_dt_periods(
cls_or_self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.PandasFrequency] = None,
) -> float:
"""Get the number of periods in the index, taking into account its datetime-like properties."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.obj
else:
checks.assert_not_none(index, arg_name="index")
if isinstance(index, pd.DatetimeIndex):
freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=True, allow_numeric=False)
if freq is not None:
if not isinstance(freq, pd.Timedelta):
freq = dt.to_timedelta(freq, approximate=True)
return (index[-1] - index[0]) / freq + 1
if not wrapping_cfg["silence_warnings"]:
warn(
"Couldn't parse the frequency of index. Pass it as `freq` or "
"define it globally under `settings.wrapping`."
)
if checks.is_number(index[0]) and checks.is_number(index[-1]):
freq = cls_or_self.get_freq(index=index, freq=freq, allow_offset=False, allow_numeric=True)
if checks.is_number(freq):
return (index[-1] - index[0]) / freq + 1
return index[-1] - index[0] + 1
if not wrapping_cfg["silence_warnings"]:
warn("Index is neither datetime-like nor integer")
return cls_or_self.get_periods(index=index)
@property
def dt_periods(self) -> float:
"""`BaseIDXAccessor.get_dt_periods` with default arguments."""
return self.get_dt_periods()
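# For the same daily index, a hypothetical comparison of the two period counts:
# >>> index.vbt.periods  # plain element count
# 5
# >>> index.vbt.dt_periods  # (last - first) / freq + 1
# 5.0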
def arr_to_timedelta(
self,
a: tp.MaybeArray,
to_pd: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[pd.Index, tp.MaybeArray]:
"""Convert array to duration using `BaseIDXAccessor.freq`."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
freq = self.freq
if freq is None:
if not silence_warnings:
warn(
"Couldn't parse the frequency of index. Pass it as `freq` or "
"define it globally under `settings.wrapping`."
)
return a
if not isinstance(freq, pd.Timedelta):
freq = dt.to_timedelta(freq, approximate=True)
if to_pd:
out = pd.to_timedelta(a * freq)
else:
out = a * freq
return out
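# A hypothetical sketch: with a daily frequency, durations measured in bars become days.
# >>> index = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> index.vbt.arr_to_timedelta(np.array([1, 2, 3]), to_pd=True)
# TimedeltaIndex(['1 days', '2 days', '3 days'], dtype='timedelta64[ns]', freq=None)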
# ############# Grouping ############# #
def get_grouper(self, by: tp.AnyGroupByLike, groupby_kwargs: tp.KwargsLike = None, **kwargs) -> Grouper:
"""Get an index grouper of type `vectorbtpro.base.grouping.base.Grouper`.
Argument `by` can be a grouper itself, an instance of Pandas `GroupBy`,
an instance of Pandas `Resampler`, but also any supported input to any of them
such as a frequency or an array of indices.
Keyword arguments `groupby_kwargs` are passed to the Pandas methods `groupby` and `resample`,
while `**kwargs` are passed to initialize `vectorbtpro.base.grouping.base.Grouper`."""
if groupby_kwargs is None:
groupby_kwargs = {}
if isinstance(by, Grouper):
if len(kwargs) > 0:
return by.replace(**kwargs)
return by
if isinstance(by, (PandasGroupBy, PandasResampler)):
return Grouper.from_pd_group_by(by, **kwargs)
try:
return Grouper(index=self.obj, group_by=by, **kwargs)
except Exception as e:
pass
if isinstance(self.obj, pd.DatetimeIndex):
try:
return Grouper(index=self.obj, group_by=self.to_period(dt.to_freq(by)), **kwargs)
except Exception as e:
pass
try:
pd_group_by = pd.Series(index=self.obj, dtype=object).resample(dt.to_freq(by), **groupby_kwargs)
return Grouper.from_pd_group_by(pd_group_by, **kwargs)
except Exception as e:
pass
pd_group_by = pd.Series(index=self.obj, dtype=object).groupby(by, axis=0, **groupby_kwargs)
return Grouper.from_pd_group_by(pd_group_by, **kwargs)
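# A hypothetical sketch: build a monthly grouper from a daily index
# (the frequency string is resolved through the fallbacks described above):
# >>> index = pd.date_range("2020-01-01", periods=60, freq="D")
# >>> grouper = index.vbt.get_grouper("M")
# >>> type(grouper).__name__
# 'Grouper'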
def get_resampler(
self,
rule: tp.AnyRuleLike,
freq: tp.Optional[tp.FrequencyLike] = None,
resample_kwargs: tp.KwargsLike = None,
return_pd_resampler: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.Union[Resampler, tp.PandasResampler]:
"""Get an index resampler of type `vectorbtpro.base.resampling.base.Resampler`."""
if checks.is_frequency_like(rule):
try:
rule = dt.to_freq(rule)
is_td = True
except Exception as e:
is_td = False
if is_td:
resample_kwargs = merge_dicts(
dict(closed="left", label="left"),
resample_kwargs,
)
rule = pd.Series(index=self.obj, dtype=object).resample(rule, **resolve_dict(resample_kwargs))
if isinstance(rule, PandasResampler):
if return_pd_resampler:
return rule
if silence_warnings is None:
silence_warnings = True
rule = Resampler.from_pd_resampler(rule, source_freq=self.freq, silence_warnings=silence_warnings)
if return_pd_resampler:
raise TypeError("Cannot convert Resampler to Pandas Resampler")
if checks.is_dt_like(rule) or checks.is_iterable(rule):
rule = dt.prepare_dt_index(rule)
rule = Resampler(
source_index=self.obj,
target_index=rule,
source_freq=self.freq,
target_freq=freq,
silence_warnings=silence_warnings,
)
if isinstance(rule, Resampler):
if freq is not None:
rule = rule.replace(target_freq=freq)
return rule
raise ValueError(f"Cannot build Resampler from {rule}")
# ############# Points and ranges ############# #
def get_points(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.indexing.get_index_points`."""
return get_index_points(self.obj, *args, **kwargs)
def get_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.indexing.get_index_ranges`."""
return get_index_ranges(self.obj, self.any_freq, *args, **kwargs)
# ############# Splitting ############# #
def split(self, *args, splitter_cls: tp.Optional[tp.Type[SplitterT]] = None, **kwargs) -> tp.Any:
"""Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_take`.
!!! note
Splits Pandas object, not accessor!"""
from vectorbtpro.generic.splitting.base import Splitter
if splitter_cls is None:
splitter_cls = Splitter
return splitter_cls.split_and_take(self.obj, self.obj, *args, **kwargs)
def split_apply(
self,
apply_func: tp.Callable,
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
**kwargs,
) -> tp.Any:
"""Split using `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`.
!!! note
Splits Pandas object, not accessor!"""
from vectorbtpro.generic.splitting.base import Splitter, Takeable
if splitter_cls is None:
splitter_cls = Splitter
return splitter_cls.split_and_apply(self.obj, apply_func, Takeable(self.obj), *args, **kwargs)
# ############# Chunking ############# #
def chunk(
self: BaseIDXAccessorT,
min_size: tp.Optional[int] = None,
n_chunks: tp.Union[None, int, str] = None,
chunk_len: tp.Union[None, int, str] = None,
chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
select: bool = False,
return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[tp.Index, tp.Tuple[ChunkMeta, tp.Index]]]:
"""Chunk this instance.
For arguments related to chunk metadata, see `vectorbtpro.utils.chunking.iter_chunk_meta`.
!!! note
Splits Pandas object, not accessor!"""
if chunk_meta is None:
chunk_meta = iter_chunk_meta(
size=len(self.obj),
min_size=min_size,
n_chunks=n_chunks,
chunk_len=chunk_len
)
for _chunk_meta in chunk_meta:
if select:
array_taker = ArraySelector()
else:
array_taker = ArraySlicer()
if return_chunk_meta:
yield _chunk_meta, array_taker.take(self.obj, _chunk_meta)
else:
yield array_taker.take(self.obj, _chunk_meta)
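# A minimal sketch: split an index into two equal chunks.
# >>> index = pd.date_range("2020-01-01", periods=6, freq="D")
# >>> [len(chunk) for chunk in index.vbt.chunk(n_chunks=2)]
# [3, 3]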
def chunk_apply(
self: BaseIDXAccessorT,
apply_func: tp.Union[str, tp.Callable],
*args,
chunk_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MergeableResults:
"""Chunk this instance and apply a function to each chunk.
If `apply_func` is a string, becomes the method name.
For arguments related to chunking, see `BaseIDXAccessor.chunk`.
!!! note
Splits Pandas object, not accessor!"""
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if chunk_kwargs is None:
chunk_arg_names = set(get_func_arg_names(self.chunk))
chunk_kwargs = {}
for k in list(kwargs.keys()):
if k in chunk_arg_names:
chunk_kwargs[k] = kwargs.pop(k)
if execute_kwargs is None:
execute_kwargs = {}
chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs)
tasks = []
keys = []
for _chunk_meta, chunk in chunks:
tasks.append(Task(apply_func, chunk, *args, **kwargs))
keys.append(get_chunk_meta_key(_chunk_meta))
keys = pd.Index(keys, name="chunk_indices")
return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs)
BaseAccessorT = tp.TypeVar("BaseAccessorT", bound="BaseAccessor")
@attach_binary_magic_methods(lambda self, other, np_func: self.combine(other, combine_func=np_func))
@attach_unary_magic_methods(lambda self, np_func: self.apply(apply_func=np_func))
class BaseAccessor(Wrapping):
"""Accessor on top of Series and DataFrames.
Accessible via `pd.Series.vbt` and `pd.DataFrame.vbt`, and all child accessors.
Series is just a DataFrame with one column, hence to avoid defining methods exclusively for 1-dim data,
we will convert any Series to a DataFrame and perform matrix computation on it. Afterwards,
by using `BaseAccessor.wrapper`, we will convert the 2-dim output back to a Series.
`**kwargs` will be passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
!!! note
When using magic methods, ensure that `.vbt` is called on the operand on the left
if the other operand is an array.
Accessors do not utilize caching.
Grouping is only supported by the methods that accept the `group_by` argument.
Usage:
* Build a symmetric matrix:
```pycon
>>> from vectorbtpro import *
>>> # vectorbtpro.base.accessors.BaseAccessor.make_symmetric
>>> pd.Series([1, 2, 3]).vbt.make_symmetric()
0 1 2
0 1.0 2.0 3.0
1 2.0 NaN NaN
2 3.0 NaN NaN
```
* Broadcast pandas objects:
```pycon
>>> sr = pd.Series([1])
>>> df = pd.DataFrame([1, 2, 3])
>>> vbt.base.reshaping.broadcast_to(sr, df)
0
0 1
1 1
2 1
>>> sr.vbt.broadcast_to(df)
0
0 1
1 1
2 1
```
* Many methods such as `BaseAccessor.broadcast` are both class and instance methods:
```pycon
>>> from vectorbtpro.base.accessors import BaseAccessor
>>> # Same as sr.vbt.broadcast(df)
>>> new_sr, new_df = BaseAccessor.broadcast(sr, df)
>>> new_sr
0
0 1
1 1
2 1
>>> new_df
0
0 1
1 2
2 3
```
* Instead of explicitly importing `BaseAccessor` or any other accessor, we can use `pd_acc` instead:
```pycon
>>> new_sr, new_df = vbt.pd_acc.broadcast(sr, df)
>>> new_sr
0
0 1
1 1
2 1
>>> new_df
0
0 1
1 2
2 3
```
* `BaseAccessor` implements arithmetic (such as `+`), comparison (such as `>`) and
logical operators (such as `&`) by forwarding the operation to `BaseAccessor.combine`:
```pycon
>>> sr.vbt + df
0
0 2
1 3
2 4
```
Many interesting use cases can be implemented this way.
* For example, let's compare an array with 3 different thresholds:
```pycon
>>> df.vbt > vbt.Param(np.arange(3), name='threshold')
threshold 0 1 2
a2 b2 c2 a2 b2 c2 a2 b2 c2
x2 True True True False True True False False True
y2 True True True True True True True True True
z2 True True True True True True True True True
```
"""
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `BaseAccessor` after stacking along rows."""
if "obj" not in kwargs:
kwargs["obj"] = kwargs["wrapper"].row_stack_arrs(
*[obj.obj for obj in objs],
group_by=False,
wrap=False,
)
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `BaseAccessor` after stacking along columns."""
if "obj" not in kwargs:
kwargs["obj"] = kwargs["wrapper"].column_stack_arrs(
*[obj.obj for obj in objs],
reindex_kwargs=reindex_kwargs,
group_by=False,
wrap=False,
)
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> BaseAccessorT:
"""Stack multiple `BaseAccessor` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, BaseAccessor):
raise TypeError("Each object to be merged must be an instance of BaseAccessor")
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "wrapper" in kwargs and kwargs["wrapper"] is not None:
wrapper = kwargs["wrapper"]
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
else:
wrapper = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
kwargs["wrapper"] = wrapper
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
if kwargs["wrapper"].ndim == 1:
return cls.sr_accessor_cls(**kwargs)
return cls.df_accessor_cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[BaseAccessorT],
*objs: tp.MaybeTuple[BaseAccessorT],
wrapper_kwargs: tp.KwargsLike = None,
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> BaseAccessorT:
"""Stack multiple `BaseAccessor` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, BaseAccessor):
raise TypeError("Each object to be merged must be an instance of BaseAccessor")
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "wrapper" in kwargs and kwargs["wrapper"] is not None:
wrapper = kwargs["wrapper"]
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
else:
wrapper = ArrayWrapper.column_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
kwargs["wrapper"] = wrapper
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls.df_accessor_cls(**kwargs)
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
if len(kwargs) > 0:
wrapper_kwargs, kwargs = ArrayWrapper.extract_init_kwargs(**kwargs)
else:
wrapper_kwargs, kwargs = {}, {}
if not isinstance(wrapper, ArrayWrapper):
if obj is not None:
raise ValueError("Must either provide wrapper and object, or only object")
wrapper, obj = ArrayWrapper.from_obj(wrapper, **wrapper_kwargs), wrapper
else:
if obj is None:
raise ValueError("Must either provide wrapper and object, or only object")
if len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
Wrapping.__init__(self, wrapper, obj=obj, **kwargs)
self._obj = obj
def __call__(self: BaseAccessorT, **kwargs) -> BaseAccessorT:
"""Allows passing arguments to the initializer."""
return self.replace(**kwargs)
@hybrid_property
def sr_accessor_cls(cls_or_self) -> tp.Type["BaseSRAccessor"]:
"""Accessor class for `pd.Series`."""
return BaseSRAccessor
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["BaseDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return BaseDFAccessor
def indexing_func(self: BaseAccessorT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> BaseAccessorT:
"""Perform indexing on `BaseAccessor`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
new_obj = ArrayWrapper.select_from_flex_array(
self._obj,
row_idxs=wrapper_meta["row_idxs"],
col_idxs=wrapper_meta["col_idxs"],
rows_changed=wrapper_meta["rows_changed"],
columns_changed=wrapper_meta["columns_changed"],
)
if checks.is_series(new_obj):
return self.replace(cls_=self.sr_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj)
return self.replace(cls_=self.df_accessor_cls, wrapper=wrapper_meta["new_wrapper"], obj=new_obj)
def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None:
"""Perform indexing setter on `BaseAccessor`."""
pd_indexing_setter_func(self._obj)
@property
def obj(self) -> tp.SeriesFrame:
"""Pandas object."""
if isinstance(self._obj, (pd.Series, pd.DataFrame)):
if self._obj.shape == self.wrapper.shape:
if self._obj.index is self.wrapper.index:
if isinstance(self._obj, pd.Series) and self._obj.name == self.wrapper.name:
return self._obj
if isinstance(self._obj, pd.DataFrame) and self._obj.columns is self.wrapper.columns:
return self._obj
return self.wrapper.wrap(self._obj, group_by=False)
def get(self, key: tp.Optional[tp.Hashable] = None, default: tp.Optional[tp.Any] = None) -> tp.SeriesFrame:
"""Get `BaseAccessor.obj`."""
if key is None:
return self.obj
return self.obj.get(key, default=default)
@property
def unwrapped(self) -> tp.SeriesFrame:
return self.obj
@hybrid_method
def should_wrap(cls_or_self) -> bool:
return False
@hybrid_property
def ndim(cls_or_self) -> tp.Optional[int]:
"""Number of dimensions in the object.
1 -> Series, 2 -> DataFrame."""
if isinstance(cls_or_self, type):
return None
return cls_or_self.obj.ndim
@hybrid_method
def is_series(cls_or_self) -> bool:
"""Whether the object is a Series."""
if isinstance(cls_or_self, type):
raise NotImplementedError
return isinstance(cls_or_self.obj, pd.Series)
@hybrid_method
def is_frame(cls_or_self) -> bool:
"""Whether the object is a DataFrame."""
if isinstance(cls_or_self, type):
raise NotImplementedError
return isinstance(cls_or_self.obj, pd.DataFrame)
@classmethod
def resolve_shape(cls, shape: tp.ShapeLike) -> tp.Shape:
"""Resolve shape."""
shape_2d = reshaping.to_2d_shape(shape)
try:
if cls.is_series() and shape_2d[1] > 1:
raise ValueError("Use DataFrame accessor")
except NotImplementedError:
pass
return shape_2d
# ############# Creation ############# #
@classmethod
def empty(cls, shape: tp.Shape, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame:
"""Generate an empty Series/DataFrame of shape `shape` and fill with `fill_value`."""
if not isinstance(shape, tuple) or len(shape) == 1:
return pd.Series(np.full(shape, fill_value), **kwargs)
return pd.DataFrame(np.full(shape, fill_value), **kwargs)
@classmethod
def empty_like(cls, other: tp.SeriesFrame, fill_value: tp.Scalar = np.nan, **kwargs) -> tp.SeriesFrame:
"""Generate an empty Series/DataFrame like `other` and fill with `fill_value`."""
if checks.is_series(other):
return cls.empty(other.shape, fill_value=fill_value, index=other.index, name=other.name, **kwargs)
return cls.empty(other.shape, fill_value=fill_value, index=other.index, columns=other.columns, **kwargs)
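# A minimal sketch of both constructors (using the `vbt.pd_acc` shortcut mentioned
# in the class docstring):
# >>> vbt.pd_acc.empty((2, 3), fill_value=0, columns=['a', 'b', 'c'])
#    a  b  c
# 0  0  0  0
# 1  0  0  0
# >>> vbt.pd_acc.empty_like(pd.Series([1, 2], index=['x', 'y']), fill_value=0)
# x    0
# y    0
# dtype: int64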
# ############# Indexes ############# #
def apply_to_index(
self: BaseAccessorT,
*args,
wrap: bool = False,
**kwargs,
) -> tp.Union[BaseAccessorT, tp.SeriesFrame]:
"""See `vectorbtpro.base.wrapping.Wrapping.apply_to_index`.
!!! note
If `wrap` is False, returns Pandas object, not accessor!"""
result = Wrapping.apply_to_index(self, *args, **kwargs)
if wrap:
return result
return result.obj
# ############# Setting ############# #
def set(
self,
value_or_func: tp.Union[tp.MaybeArray, tp.Callable],
*args,
inplace: bool = False,
columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Optional[tp.SeriesFrame]:
"""Set value at each index point using `vectorbtpro.base.indexing.get_index_points`.
If `value_or_func` is a function, selects all keyword arguments that were not passed
to the `get_index_points` method, substitutes any templates, and passes everything to the function.
The template context consists of `kwargs`, `template_context`, and various variables such as `i` (iteration index),
`index_point` (absolute position in the index), `wrapper`, and `obj`."""
if inplace:
obj = self.obj
else:
obj = self.obj.copy()
index_points = get_index_points(self.wrapper.index, **kwargs)
if callable(value_or_func):
func_kwargs = {k: v for k, v in kwargs.items() if k not in point_idxr_defaults}
template_context = merge_dicts(kwargs, template_context)
else:
func_kwargs = None
if callable(value_or_func):
for i in range(len(index_points)):
_template_context = merge_dicts(
dict(
i=i,
index_point=index_points[i],
index_points=index_points,
wrapper=self.wrapper,
obj=self.obj,
columns=columns,
args=args,
kwargs=kwargs,
),
template_context,
)
_func_args = substitute_templates(args, _template_context, eval_id="func_args")
_func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs")
v = value_or_func(*_func_args, **_func_kwargs)
if self.is_series() or columns is None:
obj.iloc[index_points[i]] = v
elif is_scalar(columns):
obj.iloc[index_points[i], obj.columns.get_indexer([columns])[0]] = v
else:
obj.iloc[index_points[i], obj.columns.get_indexer(columns)] = v
elif checks.is_sequence(value_or_func) and not is_scalar(value_or_func):
if self.is_series():
obj.iloc[index_points] = reshaping.to_1d_array(value_or_func)
elif columns is None:
obj.iloc[index_points] = reshaping.to_2d_array(value_or_func)
elif is_scalar(columns):
obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = reshaping.to_1d_array(value_or_func)
else:
obj.iloc[index_points, obj.columns.get_indexer(columns)] = reshaping.to_2d_array(value_or_func)
else:
if self.is_series() or columns is None:
obj.iloc[index_points] = value_or_func
elif is_scalar(columns):
obj.iloc[index_points, obj.columns.get_indexer([columns])[0]] = value_or_func
else:
obj.iloc[index_points, obj.columns.get_indexer(columns)] = value_or_func
if inplace:
return None
return obj
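# A hypothetical sketch (keyword arguments such as `every` are forwarded to `get_index_points`):
# >>> sr = pd.Series(0, index=pd.date_range("2020-01-01", periods=6, freq="D"))
# >>> sr.vbt.set(1, every=2)  # set 1 at every second index point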
def set_between(
self,
value_or_func: tp.Union[tp.MaybeArray, tp.Callable],
*args,
inplace: bool = False,
columns: tp.Optional[tp.MaybeSequence[tp.Hashable]] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Optional[tp.SeriesFrame]:
"""Set value at each index range using `vectorbtpro.base.indexing.get_index_ranges`.
If `value_or_func` is a function, selects all keyword arguments that were not passed
to the `get_index_ranges` method, substitutes any templates, and passes everything to the function.
The template context consists of `kwargs`, `template_context`, and various variables such as `i` (iteration index),
`index_slice` (absolute slice of the index), `wrapper`, and `obj`."""
if inplace:
obj = self.obj
else:
obj = self.obj.copy()
index_ranges = get_index_ranges(self.wrapper.index, **kwargs)
if callable(value_or_func):
func_kwargs = {k: v for k, v in kwargs.items() if k not in range_idxr_defaults}
template_context = merge_dicts(kwargs, template_context)
else:
func_kwargs = None
for i in range(len(index_ranges[0])):
if callable(value_or_func):
_template_context = merge_dicts(
dict(
i=i,
index_slice=slice(index_ranges[0][i], index_ranges[1][i]),
range_starts=index_ranges[0],
range_ends=index_ranges[1],
wrapper=self.wrapper,
obj=self.obj,
columns=columns,
args=args,
kwargs=kwargs,
),
template_context,
)
_func_args = substitute_templates(args, _template_context, eval_id="func_args")
_func_kwargs = substitute_templates(func_kwargs, _template_context, eval_id="func_kwargs")
v = value_or_func(*_func_args, **_func_kwargs)
elif checks.is_sequence(value_or_func) and not isinstance(value_or_func, str):
v = value_or_func[i]
else:
v = value_or_func
if self.is_series() or columns is None:
obj.iloc[index_ranges[0][i] : index_ranges[1][i]] = v
elif is_scalar(columns):
obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer([columns])[0]] = v
else:
obj.iloc[index_ranges[0][i] : index_ranges[1][i], obj.columns.get_indexer(columns)] = v
if inplace:
return None
return obj
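# A hypothetical sketch (keyword arguments such as `start` and `end` are forwarded to `get_index_ranges`):
# >>> sr = pd.Series(0, index=pd.date_range("2020-01-01", periods=6, freq="D"))
# >>> sr.vbt.set_between(1, start="2020-01-02", end="2020-01-04")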
# ############# Reshaping ############# #
def to_1d_array(self) -> tp.Array1d:
"""See `vectorbtpro.base.reshaping.to_1d` with `raw` set to True."""
return reshaping.to_1d_array(self.obj)
def to_2d_array(self) -> tp.Array2d:
"""See `vectorbtpro.base.reshaping.to_2d` with `raw` set to True."""
return reshaping.to_2d_array(self.obj)
def tile(
self,
n: int,
keys: tp.Optional[tp.IndexLike] = None,
axis: int = 1,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.base.reshaping.tile`.
Set `axis` to 1 for columns and 0 for index.
Use `keys` as the outermost level."""
tiled = reshaping.tile(self.obj, n, axis=axis)
if keys is not None:
if axis == 1:
new_columns = indexes.combine_indexes([keys, self.wrapper.columns])
return ArrayWrapper.from_obj(tiled).wrap(
tiled.values,
**merge_dicts(dict(columns=new_columns), wrap_kwargs),
)
else:
new_index = indexes.combine_indexes([keys, self.wrapper.index])
return ArrayWrapper.from_obj(tiled).wrap(
tiled.values,
**merge_dicts(dict(index=new_index), wrap_kwargs),
)
return tiled
def repeat(
self,
n: int,
keys: tp.Optional[tp.IndexLike] = None,
axis: int = 1,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.base.reshaping.repeat`.
Set `axis` to 1 for columns and 0 for index.
Use `keys` as the outermost level."""
repeated = reshaping.repeat(self.obj, n, axis=axis)
if keys is not None:
if axis == 1:
new_columns = indexes.combine_indexes([self.wrapper.columns, keys])
return ArrayWrapper.from_obj(repeated).wrap(
repeated.values,
**merge_dicts(dict(columns=new_columns), wrap_kwargs),
)
else:
new_index = indexes.combine_indexes([self.wrapper.index, keys])
return ArrayWrapper.from_obj(repeated).wrap(
repeated.values,
**merge_dicts(dict(index=new_index), wrap_kwargs),
)
return repeated
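# A minimal sketch contrasting both methods along the column axis:
# >>> df = pd.DataFrame([[1, 2]], columns=['a', 'b'])
# >>> df.vbt.tile(2, keys=['x', 'y']).columns.tolist()
# [('x', 'a'), ('x', 'b'), ('y', 'a'), ('y', 'b')]
# >>> df.vbt.repeat(2, keys=['x', 'y']).columns.tolist()
# [('a', 'x'), ('a', 'y'), ('b', 'x'), ('b', 'y')]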
def align_to(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None, **kwargs) -> tp.SeriesFrame:
"""Align to `other` on their axes using `vectorbtpro.base.indexes.align_index_to`.
Usage:
```pycon
>>> df1 = pd.DataFrame(
... [[1, 2], [3, 4]],
... index=['x', 'y'],
... columns=['a', 'b']
... )
>>> df1
a b
x 1 2
y 3 4
>>> df2 = pd.DataFrame(
... [[5, 6, 7, 8], [9, 10, 11, 12]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']])
... )
>>> df2
1 2
a b a b
x 5 6 7 8
y 9 10 11 12
>>> df1.vbt.align_to(df2)
1 2
a b a b
x 1 2 1 2
y 3 4 3 4
```
"""
checks.assert_instance_of(other, (pd.Series, pd.DataFrame))
obj = reshaping.to_2d(self.obj)
other = reshaping.to_2d(other)
aligned_index = indexes.align_index_to(obj.index, other.index, **kwargs)
aligned_columns = indexes.align_index_to(obj.columns, other.columns, **kwargs)
obj = obj.iloc[aligned_index, aligned_columns]
return self.wrapper.wrap(
obj.values,
group_by=False,
**merge_dicts(dict(index=other.index, columns=other.columns), wrap_kwargs),
)
@hybrid_method
def align(
cls_or_self,
*others: tp.Union[tp.SeriesFrame, "BaseAccessor"],
**kwargs,
) -> tp.Tuple[tp.SeriesFrame, ...]:
"""Align objects using `vectorbtpro.base.indexes.align_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
objs_2d = list(map(reshaping.to_2d, objs))
index_slices, new_index = indexes.align_indexes(
*map(lambda x: x.index, objs_2d),
return_new_index=True,
**kwargs,
)
column_slices, new_columns = indexes.align_indexes(
*map(lambda x: x.columns, objs_2d),
return_new_index=True,
**kwargs,
)
new_objs = []
for i in range(len(objs_2d)):
new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False)
new_obj.index = new_index
new_obj.columns = new_columns
if objs[i].ndim == 1 and new_obj.shape[1] == 1:
new_obj = new_obj.iloc[:, 0].rename(objs[i].name)
new_objs.append(new_obj)
return tuple(new_objs)
def cross_with(self, other: tp.SeriesFrame, wrap_kwargs: tp.KwargsLike = None) -> tp.SeriesFrame:
"""Align to `other` on their axes using `vectorbtpro.base.indexes.cross_index_with`.
Usage:
```pycon
>>> df1 = pd.DataFrame(
... [[1, 2, 3, 4], [5, 6, 7, 8]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[1, 1, 2, 2], ['a', 'b', 'a', 'b']])
... )
>>> df1
1 2
a b a b
x 1 2 3 4
y 5 6 7 8
>>> df2 = pd.DataFrame(
... [[9, 10, 11, 12], [13, 14, 15, 16]],
... index=['x', 'y'],
... columns=pd.MultiIndex.from_arrays([[3, 3, 4, 4], ['a', 'b', 'a', 'b']])
... )
>>> df2
3 4
a b a b
x 9 10 11 12
y 13 14 15 16
>>> df1.vbt.cross_with(df2)
1 2
3 4 3 4
a b a b a b a b
x 1 2 1 2 3 4 3 4
y 5 6 5 6 7 8 7 8
```
"""
checks.assert_instance_of(other, (pd.Series, pd.DataFrame))
obj = reshaping.to_2d(self.obj)
other = reshaping.to_2d(other)
index_slices, new_index = indexes.cross_index_with(
obj.index,
other.index,
return_new_index=True,
)
column_slices, new_columns = indexes.cross_index_with(
obj.columns,
other.columns,
return_new_index=True,
)
obj = obj.iloc[index_slices[0], column_slices[0]]
return self.wrapper.wrap(
obj.values,
group_by=False,
**merge_dicts(dict(index=new_index, columns=new_columns), wrap_kwargs),
)
@hybrid_method
def cross(cls_or_self, *others: tp.Union[tp.SeriesFrame, "BaseAccessor"]) -> tp.Tuple[tp.SeriesFrame, ...]:
"""Align objects using `vectorbtpro.base.indexes.cross_indexes`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
objs_2d = list(map(reshaping.to_2d, objs))
index_slices, new_index = indexes.cross_indexes(
*map(lambda x: x.index, objs_2d),
return_new_index=True,
)
column_slices, new_columns = indexes.cross_indexes(
*map(lambda x: x.columns, objs_2d),
return_new_index=True,
)
new_objs = []
for i in range(len(objs_2d)):
new_obj = objs_2d[i].iloc[index_slices[i], column_slices[i]].copy(deep=False)
new_obj.index = new_index
new_obj.columns = new_columns
if objs[i].ndim == 1 and new_obj.shape[1] == 1:
new_obj = new_obj.iloc[:, 0].rename(objs[i].name)
new_objs.append(new_obj)
return tuple(new_objs)
x = cross
@hybrid_method
def broadcast(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return reshaping.broadcast(*objs, **kwargs)
def broadcast_to(self, other: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast_to`."""
if isinstance(other, BaseAccessor):
other = other.obj
return reshaping.broadcast_to(self.obj, other, **kwargs)
@hybrid_method
def broadcast_combs(cls_or_self, *others: tp.Union[tp.ArrayLike, "BaseAccessor"], **kwargs) -> tp.Any:
"""See `vectorbtpro.base.reshaping.broadcast_combs`."""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj, *others)
return reshaping.broadcast_combs(*objs, **kwargs)
def make_symmetric(self, *args, **kwargs) -> tp.Frame:
"""See `vectorbtpro.base.reshaping.make_symmetric`."""
return reshaping.make_symmetric(self.obj, *args, **kwargs)
def unstack_to_array(self, *args, **kwargs) -> tp.Array:
"""See `vectorbtpro.base.reshaping.unstack_to_array`."""
return reshaping.unstack_to_array(self.obj, *args, **kwargs)
def unstack_to_df(self, *args, **kwargs) -> tp.Frame:
"""See `vectorbtpro.base.reshaping.unstack_to_df`."""
return reshaping.unstack_to_df(self.obj, *args, **kwargs)
def to_dict(self, *args, **kwargs) -> tp.Mapping:
"""See `vectorbtpro.base.reshaping.to_dict`."""
return reshaping.to_dict(self.obj, *args, **kwargs)
# ############# Conversion ############# #
def to_data(
self,
data_cls: tp.Optional[tp.Type[DataT]] = None,
columns_are_symbols: bool = True,
**kwargs,
) -> DataT:
"""Convert to a `vectorbtpro.data.base.Data` instance."""
if data_cls is None:
from vectorbtpro.data.base import Data
data_cls = Data
return data_cls.from_data(self.obj, columns_are_symbols=columns_are_symbols, **kwargs)
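# A hypothetical sketch: treat each column as a separate symbol.
# >>> df = pd.DataFrame({'BTC': [1, 2], 'ETH': [3, 4]})
# >>> data = df.vbt.to_data()
# >>> type(data).__name__
# 'Data'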
# ############# Combining ############# #
def apply(
self,
apply_func: tp.Callable,
*args,
keep_pd: bool = False,
to_2d: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Apply a function `apply_func`.
Set `keep_pd` to True to keep inputs as pandas objects, otherwise convert to NumPy arrays.
Set `to_2d` to True to reshape inputs to 2-dim arrays, otherwise keep as-is.
`*args` and `**kwargs` are passed to `apply_func`.
!!! note
The resulted array must have the same shape as the original array.
Usage:
* Using instance method:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> sr.vbt.apply(lambda x: x ** 2)
x 1
y 4
dtype: int64
```
* Using templates and broadcasting:
```pycon
>>> sr.vbt.apply(
... lambda x, y: x + y,
... vbt.Rep('y'),
... broadcast_named_args=dict(
... y=pd.DataFrame([[3, 4]], columns=['a', 'b'])
... )
... )
a b
x 4 5
y 5 6
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
broadcast_named_args = {"obj": self.obj, **broadcast_named_args}
if len(broadcast_named_args) > 1:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
wrapper = self.wrapper
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
out = apply_func(broadcast_named_args["obj"], *args, **kwargs)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def concat(
cls_or_self,
*others: tp.ArrayLike,
broadcast_kwargs: tp.KwargsLike = None,
keys: tp.Optional[tp.IndexLike] = None,
) -> tp.Frame:
"""Concatenate with `others` along columns.
Usage:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> sr.vbt.concat(df, keys=['c', 'd'])
c d
a b a b
x 1 1 3 4
y 2 2 5 6
```
"""
others = tuple(map(lambda x: x.obj if isinstance(x, BaseAccessor) else x, others))
if isinstance(cls_or_self, type):
objs = others
else:
objs = (cls_or_self.obj,) + others
if broadcast_kwargs is None:
broadcast_kwargs = {}
broadcasted = reshaping.broadcast(*objs, **broadcast_kwargs)
broadcasted = tuple(map(reshaping.to_2d, broadcasted))
out = pd.concat(broadcasted, axis=1, keys=keys)
if not isinstance(out.columns, pd.MultiIndex) and np.all(out.columns == 0):
out.columns = pd.RangeIndex(start=0, stop=len(out.columns), step=1)
return out
def apply_and_concat(
self,
ntimes: int,
apply_func: tp.Callable,
*args,
keep_pd: bool = False,
to_2d: bool = False,
keys: tp.Optional[tp.IndexLike] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.Frame]:
"""Apply `apply_func` `ntimes` times and concatenate the results along columns.
See `vectorbtpro.base.combining.apply_and_concat`.
`ntimes` is the number of times to call `apply_func`, while `n_outputs` is the number of outputs to expect.
`*args` and `**kwargs` are passed to `vectorbtpro.base.combining.apply_and_concat`.
!!! note
The resulted arrays to be concatenated must have the same shape as broadcast input arrays.
Usage:
* Using instance method:
```pycon
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> df.vbt.apply_and_concat(
... 3,
... lambda i, a, b: a * b[i],
... [1, 2, 3],
... keys=['c', 'd', 'e']
... )
c d e
a b a b a b
x 3 4 6 8 9 12
y 5 6 10 12 15 18
```
* Using templates and broadcasting:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> sr.vbt.apply_and_concat(
... 3,
... lambda i, a, b: a * b + i,
... vbt.Rep('df'),
... broadcast_named_args=dict(
... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
apply_idx 0 1 2
a b c a b c a b c
x 1 2 3 2 3 4 3 4 5
y 2 4 6 3 5 7 4 6 8
z 3 6 9 4 7 10 5 8 11
```
* To change the execution engine or specify other engine-related arguments, use `execute_kwargs`:
```pycon
>>> import time
>>> def apply_func(i, a):
... time.sleep(1)
... return a
>>> sr = pd.Series([1, 2, 3])
>>> %timeit sr.vbt.apply_and_concat(3, apply_func)
3.02 s ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit sr.vbt.apply_and_concat(3, apply_func, execute_kwargs=dict(engine='dask'))
1.02 s ± 927 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
broadcast_named_args = {"obj": self.obj, **broadcast_named_args}
if len(broadcast_named_args) > 1:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
wrapper = self.wrapper
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, dict(ntimes=ntimes), template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
out = combining.apply_and_concat(ntimes, apply_func, broadcast_named_args["obj"], *args, **kwargs)
if keys is not None:
new_columns = indexes.combine_indexes([keys, wrapper.columns])
else:
top_columns = pd.Index(np.arange(ntimes), name="apply_idx")
new_columns = indexes.combine_indexes([top_columns, wrapper.columns])
if out is None:
return None
wrap_kwargs = merge_dicts(dict(columns=new_columns), wrap_kwargs)
if isinstance(out, list):
return tuple(map(lambda x: wrapper.wrap(x, group_by=False, **wrap_kwargs), out))
return wrapper.wrap(out, group_by=False, **wrap_kwargs)
@hybrid_method
def combine(
cls_or_self,
obj: tp.MaybeTupleList[tp.Union[tp.ArrayLike, "BaseAccessor"]],
combine_func: tp.Callable,
*args,
allow_multiple: bool = True,
keep_pd: bool = False,
to_2d: bool = False,
concat: tp.Optional[bool] = None,
keys: tp.Optional[tp.IndexLike] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Combine with `other` using `combine_func`.
Args:
obj (array_like): Object(s) to combine this array with.
combine_func (callable): Function to combine two arrays.
Can be Numba-compiled.
*args: Variable arguments passed to `combine_func`.
allow_multiple (bool): Whether a tuple/list/Index will be considered as multiple objects in `obj`.
Takes effect only when using the instance method.
keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
concat (bool): Whether to concatenate the results along the column axis.
Otherwise, pairwise combine into a Series/DataFrame of the same shape.
If True, see `vectorbtpro.base.combining.combine_and_concat`.
If False, see `vectorbtpro.base.combining.combine_multiple`.
If None, becomes True if there are multiple objects to combine.
Can only concatenate using the instance method.
keys (index_like): Outermost column level.
broadcast_named_args (dict): Dictionary with arguments to broadcast against each other.
broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
template_context (dict): Context used to substitute templates in `args` and `kwargs`.
wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
**kwargs: Keyword arguments passed to `combine_func`.
!!! note
If `combine_func` is Numba-compiled, will broadcast using `WRITEABLE` and `C_CONTIGUOUS`
flags, which can lead to an expensive computation overhead if passed objects are large and
have different shape/memory order. You also must ensure that all objects have the same data type.
Also remember to bring each argument in `*args` to a Numba-compatible format.
Usage:
* Using instance method:
```pycon
>>> sr = pd.Series([1, 2], index=['x', 'y'])
>>> df = pd.DataFrame([[3, 4], [5, 6]], index=['x', 'y'], columns=['a', 'b'])
>>> # using instance method
>>> sr.vbt.combine(df, np.add)
a b
x 4 5
y 7 8
>>> sr.vbt.combine([df, df * 2], np.add, concat=False)
a b
x 10 13
y 17 20
>>> sr.vbt.combine([df, df * 2], np.add)
combine_idx 0 1
a b a b
x 4 5 7 9
y 7 8 12 14
>>> sr.vbt.combine([df, df * 2], np.add, keys=['c', 'd'])
c d
a b a b
x 4 5 7 9
y 7 8 12 14
>>> sr.vbt.combine(vbt.Param([1, 2], name='param'), np.add)
param 1 2
x 2 3
y 3 4
>>> # using class method
>>> vbt.pd_acc.combine([sr, df, df * 2], np.add, concat=False)
a b
x 10 13
y 17 20
```
* Using templates and broadcasting:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> sr.vbt.combine(
... [1, 2, 3],
... lambda x, y, z: x + y + z,
... vbt.Rep('df'),
... broadcast_named_args=dict(
... df=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
combine_idx 0 1 2
a b c a b c a b c
x 3 4 5 4 5 6 5 6 7
y 4 5 6 5 6 7 6 7 8
z 5 6 7 6 7 8 7 8 9
```
* To change the execution engine or specify other engine-related arguments, use `execute_kwargs`:
```pycon
>>> import time
>>> def combine_func(a, b):
... time.sleep(1)
... return a + b
>>> sr = pd.Series([1, 2, 3])
>>> %timeit sr.vbt.combine([1, 1, 1], combine_func)
3.01 s ± 2.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit sr.vbt.combine([1, 1, 1], combine_func, execute_kwargs=dict(engine='dask'))
1.02 s ± 2.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
"""
from vectorbtpro.indicators.factory import IndicatorBase
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(cls_or_self, type):
objs = obj
else:
if allow_multiple and isinstance(obj, (tuple, list)):
objs = obj
if concat is None:
concat = True
else:
objs = (obj,)
new_objs = []
for obj in objs:
if isinstance(obj, BaseAccessor):
obj = obj.obj
elif isinstance(obj, IndicatorBase):
obj = obj.main_output
new_objs.append(obj)
objs = tuple(new_objs)
if not isinstance(cls_or_self, type):
objs = (cls_or_self.obj,) + objs
if checks.is_numba_func(combine_func):
# Numba requires writeable arrays and in the same order
broadcast_kwargs = merge_dicts(dict(require_kwargs=dict(requirements=["W", "C"])), broadcast_kwargs)
# Broadcast and substitute templates
broadcast_named_args = {**{"obj_" + str(i): obj for i, obj in enumerate(objs)}, **broadcast_named_args}
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
if to_2d:
broadcast_named_args = {k: reshaping.to_2d(v, raw=not keep_pd) for k, v in broadcast_named_args.items()}
elif not keep_pd:
broadcast_named_args = {k: np.asarray(v) for k, v in broadcast_named_args.items()}
template_context = merge_dicts(broadcast_named_args, template_context)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
inputs = [broadcast_named_args["obj_" + str(i)] for i in range(len(objs))]
if concat is None:
concat = len(inputs) > 2
if concat:
# Concat the results horizontally
if isinstance(cls_or_self, type):
raise TypeError("Use instance method to concatenate")
out = combining.combine_and_concat(inputs[0], inputs[1:], combine_func, *args, **kwargs)
if keys is not None:
new_columns = indexes.combine_indexes([keys, wrapper.columns])
else:
top_columns = pd.Index(np.arange(len(objs) - 1), name="combine_idx")
new_columns = indexes.combine_indexes([top_columns, wrapper.columns])
return wrapper.wrap(out, **merge_dicts(dict(columns=new_columns, force_2d=True), wrap_kwargs))
else:
# Combine arguments pairwise into one object
out = combining.combine_multiple(inputs, combine_func, *args, **kwargs)
return wrapper.wrap(out, **resolve_dict(wrap_kwargs))
@classmethod
def eval(
cls,
expr: str,
frames_back: int = 1,
use_numexpr: bool = False,
numexpr_kwargs: tp.KwargsLike = None,
local_dict: tp.Optional[tp.Mapping] = None,
global_dict: tp.Optional[tp.Mapping] = None,
broadcast_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
):
"""Evaluate a simple array expression element-wise using NumExpr or NumPy.
If NumExpr is enabled, only one-line statements are supported. Otherwise, uses
`vectorbtpro.utils.eval_.evaluate`.
!!! note
All required variables will broadcast against each other prior to the evaluation.
Usage:
```pycon
>>> sr = pd.Series([1, 2, 3], index=['x', 'y', 'z'])
>>> df = pd.DataFrame([[4, 5, 6]] * 3, index=['x', 'y', 'z'], columns=['a', 'b', 'c'])
>>> vbt.pd_acc.eval('sr + df')
a b c
x 5 6 7
y 6 7 8
z 7 8 9
```
"""
if numexpr_kwargs is None:
numexpr_kwargs = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
expr = inspect.cleandoc(expr)
parsed = ast.parse(expr)
body_nodes = list(parsed.body)
load_vars = set()
store_vars = set()
for body_node in body_nodes:
for child_node in ast.walk(body_node):
if type(child_node) is ast.Name:
if isinstance(child_node.ctx, ast.Load):
if child_node.id not in store_vars:
load_vars.add(child_node.id)
if isinstance(child_node.ctx, ast.Store):
store_vars.add(child_node.id)
load_vars = list(load_vars)
objs = get_context_vars(load_vars, frames_back=frames_back, local_dict=local_dict, global_dict=global_dict)
objs = dict(zip(load_vars, objs))
objs, wrapper = reshaping.broadcast(objs, return_wrapper=True, **broadcast_kwargs)
objs = {k: np.asarray(v) for k, v in objs.items()}
if use_numexpr:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("numexpr")
import numexpr
out = numexpr.evaluate(expr, local_dict=objs, **numexpr_kwargs)
else:
out = evaluate(expr, context=objs)
return wrapper.wrap(out, **wrap_kwargs)
class BaseSRAccessor(BaseAccessor):
"""Accessor on top of Series.
Accessible via `pd.Series.vbt` and all child accessors."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
if _full_init:
if isinstance(wrapper, ArrayWrapper):
if wrapper.ndim == 2:
if wrapper.shape[1] == 1:
wrapper = wrapper.replace(ndim=1)
else:
raise TypeError("Series accessors work only one one-dimensional data")
BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@hybrid_property
def ndim(cls_or_self) -> int:
return 1
@hybrid_method
def is_series(cls_or_self) -> bool:
return True
@hybrid_method
def is_frame(cls_or_self) -> bool:
return False
class BaseDFAccessor(BaseAccessor):
"""Accessor on top of DataFrames.
Accessible via `pd.DataFrame.vbt` and all child accessors."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
if _full_init:
if isinstance(wrapper, ArrayWrapper):
if wrapper.ndim == 1:
wrapper = wrapper.replace(ndim=2)
BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@hybrid_property
def ndim(cls_or_self) -> int:
return 2
@hybrid_method
def is_series(cls_or_self) -> bool:
return False
@hybrid_method
def is_frame(cls_or_self) -> bool:
return True
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Extensions for chunking of base operations."""
import uuid
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.chunking import (
ArgGetter,
ArgSizer,
ArraySizer,
ChunkMeta,
ChunkMapper,
ChunkSlicer,
ShapeSlicer,
ArraySelector,
ArraySlicer,
Chunked,
)
from vectorbtpro.utils.parsing import Regex
__all__ = [
"GroupLensSizer",
"GroupLensSlicer",
"ChunkedGroupLens",
"GroupLensMapper",
"GroupMapSlicer",
"ChunkedGroupMap",
"GroupIdxsMapper",
"FlexArraySizer",
"FlexArraySelector",
"FlexArraySlicer",
"ChunkedFlexArray",
"shape_gl_slicer",
"flex_1d_array_gl_slicer",
"flex_array_gl_slicer",
"array_gl_slicer",
]
class GroupLensSizer(ArgSizer):
"""Class for getting the size from group lengths.
Argument can be either a group map tuple or a group lengths array."""
@classmethod
def get_obj_size(cls, obj: tp.Union[tp.GroupLens, tp.GroupMap], single_type: tp.Optional[type] = None) -> int:
"""Get size of an object."""
if single_type is not None:
if checks.is_instance_of(obj, single_type):
return 1
if isinstance(obj, tuple):
return len(obj[1])
return len(obj)
def get_size(self, ann_args: tp.AnnArgs, **kwargs) -> int:
return self.get_obj_size(self.get_arg(ann_args), single_type=self.single_type)
class GroupLensSlicer(ChunkSlicer):
"""Class for slicing multiple elements from group lengths based on the chunk range."""
def get_size(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], **kwargs) -> int:
return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)
def take(self, obj: tp.Union[tp.GroupLens, tp.GroupMap], chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap:
if isinstance(obj, tuple):
return obj[1][chunk_meta.start : chunk_meta.end]
return obj[chunk_meta.start : chunk_meta.end]
class ChunkedGroupLens(Chunked):
"""Class representing chunkable group lengths."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
raise ValueError("Selection is not supported")
return GroupLensSlicer
return self.take_spec
def get_group_lens_slice(group_lens: tp.GroupLens, chunk_meta: ChunkMeta) -> slice:
"""Get slice of each chunk in group lengths."""
group_lens_cumsum = np.cumsum(group_lens[: chunk_meta.end])
start = group_lens_cumsum[chunk_meta.start] - group_lens[chunk_meta.start]
end = group_lens_cumsum[-1]
return slice(start, end)
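# A minimal worked example of the cumulative-sum logic above (output shown schematically,
# with the chunk covering groups 1..2):
# >>> group_lens = np.array([2, 3, 4])
# >>> chunk_meta = ChunkMeta(uuid="", idx=0, start=1, end=3, indices=None)
# >>> get_group_lens_slice(group_lens, chunk_meta)
# slice(2, 9, None)  # i.e. columns 2 through 8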
@define
class GroupLensMapper(ChunkMapper, ArgGetter, DefineMixin):
"""Class for mapping chunk metadata to per-group column lengths.
Argument can be either a group map tuple or a group lengths array."""
def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
group_lens = self.get_arg(ann_args)
if isinstance(group_lens, tuple):
group_lens = group_lens[1]
group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
return ChunkMeta(
uuid=str(uuid.uuid4()),
idx=chunk_meta.idx,
start=group_lens_slice.start,
end=group_lens_slice.stop,
indices=None,
)
group_lens_mapper = GroupLensMapper(arg_query=Regex(r"(group_lens|group_map)"))
"""Default instance of `GroupLensMapper`."""
class GroupMapSlicer(ChunkSlicer):
"""Class for slicing multiple elements from a group map based on the chunk range."""
def get_size(self, obj: tp.GroupMap, **kwargs) -> int:
return GroupLensSizer.get_obj_size(obj, single_type=self.single_type)
def take(self, obj: tp.GroupMap, chunk_meta: ChunkMeta, **kwargs) -> tp.GroupMap:
group_idxs, group_lens = obj
group_lens = group_lens[chunk_meta.start : chunk_meta.end]
return np.arange(np.sum(group_lens)), group_lens
class ChunkedGroupMap(Chunked):
"""Class representing a chunkable group map."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
raise ValueError("Selection is not supported")
return GroupMapSlicer
return self.take_spec
@define
class GroupIdxsMapper(ChunkMapper, ArgGetter, DefineMixin):
"""Class for mapping chunk metadata to per-group column indices.
Argument must be a group map tuple."""
def map(self, chunk_meta: ChunkMeta, ann_args: tp.Optional[tp.AnnArgs] = None, **kwargs) -> ChunkMeta:
group_map = self.get_arg(ann_args)
group_idxs, group_lens = group_map
group_lens_slice = get_group_lens_slice(group_lens, chunk_meta)
return ChunkMeta(
uuid=str(uuid.uuid4()),
idx=chunk_meta.idx,
start=None,
end=None,
indices=group_idxs[group_lens_slice],
)
group_idxs_mapper = GroupIdxsMapper(arg_query="group_map")
"""Default instance of `GroupIdxsMapper`."""
class FlexArraySizer(ArraySizer):
"""Class for getting the size from the length of an axis in a flexible array."""
@classmethod
def get_obj_size(cls, obj: tp.AnyArray, axis: int, single_type: tp.Optional[type] = None) -> int:
"""Get size of an object."""
if single_type is not None:
if checks.is_instance_of(obj, single_type):
return 1
obj = np.asarray(obj)
if len(obj.shape) == 0:
return 1
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1:
return 1
return obj.shape[0]
if len(obj.shape) == 2:
if axis == 1:
return obj.shape[1]
return obj.shape[0]
raise ValueError(f"FlexArraySizer supports max 2 dimensions, not {len(obj.shape)}")
@define
class FlexArraySelector(ArraySelector, DefineMixin):
"""Class for selecting one element from a NumPy array's axis flexibly based on the chunk index.
The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
and `vectorbtpro.base.flex_indexing.flex_select_nb`."""
def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)
def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
return None
def take(
self,
obj: tp.ArrayLike,
chunk_meta: ChunkMeta,
ann_args: tp.Optional[tp.AnnArgs] = None,
**kwargs,
) -> tp.ArrayLike:
if np.isscalar(obj):
return obj
obj = np.asarray(obj)
if len(obj.shape) == 0:
return obj
axis = self.axis
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1 or obj.shape[0] == 1:
return obj
if self.keep_dims:
return obj[chunk_meta.idx : chunk_meta.idx + 1]
return obj[chunk_meta.idx]
if len(obj.shape) == 2:
if axis == 1:
if obj.shape[1] == 1:
return obj
if self.keep_dims:
return obj[:, chunk_meta.idx : chunk_meta.idx + 1]
return obj[:, chunk_meta.idx]
if obj.shape[0] == 1:
return obj
if self.keep_dims:
return obj[chunk_meta.idx : chunk_meta.idx + 1, :]
return obj[chunk_meta.idx, :]
raise ValueError(f"FlexArraySelector supports max 2 dimensions, not {len(obj.shape)}")
@define
class FlexArraySlicer(ArraySlicer, DefineMixin):
"""Class for selecting one element from a NumPy array's axis flexibly based on the chunk index.
The result is intended to be used together with `vectorbtpro.base.flex_indexing.flex_select_1d_nb`
and `vectorbtpro.base.flex_indexing.flex_select_nb`."""
def get_size(self, obj: tp.ArrayLike, **kwargs) -> int:
return FlexArraySizer.get_obj_size(obj, self.axis, single_type=self.single_type)
def suggest_size(self, obj: tp.ArrayLike, **kwargs) -> tp.Optional[int]:
return None
def take(
self,
obj: tp.ArrayLike,
chunk_meta: ChunkMeta,
ann_args: tp.Optional[tp.AnnArgs] = None,
**kwargs,
) -> tp.ArrayLike:
if np.isscalar(obj):
return obj
obj = np.asarray(obj)
if len(obj.shape) == 0:
return obj
axis = self.axis
if axis is None:
if len(obj.shape) == 1:
axis = 0
checks.assert_not_none(axis, arg_name="axis")
checks.assert_in(axis, (0, 1), arg_name="axis")
if len(obj.shape) == 1:
if axis == 1 or obj.shape[0] == 1:
return obj
return obj[chunk_meta.start : chunk_meta.end]
if len(obj.shape) == 2:
if axis == 1:
if obj.shape[1] == 1:
return obj
return obj[:, chunk_meta.start : chunk_meta.end]
if obj.shape[0] == 1:
return obj
return obj[chunk_meta.start : chunk_meta.end, :]
raise ValueError(f"FlexArraySlicer supports max 2 dimensions, not {len(obj.shape)}")
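# Illustrative sketch: `FlexArraySlicer` slices the chunk range along the chosen axis, while
# arrays with a single element along that axis pass through unchanged so that flexible
# (broadcast-style) indexing keeps working downstream. Assumes `mapper` is optional:
# >>> arr = np.arange(12).reshape(3, 4)
# >>> chunk_meta = ChunkMeta(uuid="", idx=0, start=1, end=3, indices=None)
# >>> FlexArraySlicer(axis=1).take(arr, chunk_meta)              # expected: columns 1..2 of arr
# >>> FlexArraySlicer(axis=1).take(np.array([[9]]), chunk_meta)  # expected: the 1-column array as-is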
class ChunkedFlexArray(Chunked):
"""Class representing a chunkable flexible array."""
def resolve_take_spec(self) -> tp.TakeSpec:
if self.take_spec_missing:
if self.select:
return FlexArraySelector
return FlexArraySlicer
return self.take_spec
shape_gl_slicer = ShapeSlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim shape slicer along the column axis based on group lengths."""
flex_1d_array_gl_slicer = FlexArraySlicer(mapper=group_lens_mapper)
"""Flexible 1-dim array slicer along the column axis based on group lengths."""
flex_array_gl_slicer = FlexArraySlicer(axis=1, mapper=group_lens_mapper)
"""Flexible 2-dim array slicer along the column axis based on group lengths."""
array_gl_slicer = ArraySlicer(axis=1, mapper=group_lens_mapper)
"""2-dim array slicer along the column axis based on group lengths."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for combining arrays.
Combine functions combine two or more NumPy arrays using a custom function. The emphasis here is
on stacking the results into one NumPy array - since vectorbt is all about brute-forcing
large spaces of hyper-parameters, concatenating the results of each hyper-parameter combination into
a single DataFrame is important. All functions are available in both Python and Numba-compiled form."""
import numpy as np
from numba.typed import List
from vectorbtpro import _typing as tp
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.template import RepFunc
__all__ = []
@register_jitted
def custom_apply_and_concat_none_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> None:
"""Run `apply_func_nb` that returns nothing for each index.
Meant for in-place outputs."""
for i in indices:
apply_func_nb(i, *args)
@register_jitted
def apply_and_concat_none_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> None:
"""Run `apply_func_nb` that returns nothing number of times.
Uses `custom_apply_and_concat_none_nb`."""
custom_apply_and_concat_none_nb(np.arange(ntimes), apply_func_nb, *args)
@register_jitted
def to_2d_one_nb(a: tp.Array) -> tp.Array2d:
"""Expand the dimensions of the array along the axis 1."""
if a.ndim > 1:
return a
return np.expand_dims(a, axis=1)
@register_jitted
def custom_apply_and_concat_one_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Run `apply_func_nb` that returns one array for each index."""
output_0 = to_2d_one_nb(apply_func_nb(indices[0], *args))
output = np.empty((output_0.shape[0], len(indices) * output_0.shape[1]), dtype=output_0.dtype)
for i in range(len(indices)):
if i == 0:
outputs_i = output_0
else:
outputs_i = to_2d_one_nb(apply_func_nb(indices[i], *args))
output[:, i * outputs_i.shape[1] : (i + 1) * outputs_i.shape[1]] = outputs_i
return output
@register_jitted
def apply_and_concat_one_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Run `apply_func_nb` that returns one array number of times.
Uses `custom_apply_and_concat_one_nb`."""
return custom_apply_and_concat_one_nb(np.arange(ntimes), apply_func_nb, *args)
@register_jitted
def to_2d_multiple_nb(a: tp.Iterable[tp.Array]) -> tp.List[tp.Array2d]:
"""Expand the dimensions of each array in `a` along axis 1."""
lst = list()
for _a in a:
lst.append(to_2d_one_nb(_a))
return lst
@register_jitted
def custom_apply_and_concat_multiple_nb(
indices: tp.Array1d,
apply_func_nb: tp.Callable,
*args,
) -> tp.List[tp.Array2d]:
"""Run `apply_func_nb` that returns multiple arrays for each index."""
outputs = list()
outputs_0 = to_2d_multiple_nb(apply_func_nb(indices[0], *args))
for j in range(len(outputs_0)):
outputs.append(
np.empty((outputs_0[j].shape[0], len(indices) * outputs_0[j].shape[1]), dtype=outputs_0[j].dtype)
)
for i in range(len(indices)):
if i == 0:
outputs_i = outputs_0
else:
outputs_i = to_2d_multiple_nb(apply_func_nb(indices[i], *args))
for j in range(len(outputs_i)):
outputs[j][:, i * outputs_i[j].shape[1] : (i + 1) * outputs_i[j].shape[1]] = outputs_i[j]
return outputs
@register_jitted
def apply_and_concat_multiple_nb(
ntimes: int,
apply_func_nb: tp.Callable,
*args,
) -> tp.List[tp.Array2d]:
"""Run `apply_func_nb` that returns multiple arrays number of times.
Uses `custom_apply_and_concat_multiple_nb`."""
return custom_apply_and_concat_multiple_nb(np.arange(ntimes), apply_func_nb, *args)
def apply_and_concat_each(
tasks: tp.TasksLike,
n_outputs: tp.Optional[int] = None,
execute_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
"""Apply each function on its own set of positional and keyword arguments.
Executes the function using `vectorbtpro.utils.execution.execute`."""
from vectorbtpro.base.merging import column_stack_arrays
if execute_kwargs is None:
execute_kwargs = {}
out = execute(tasks, **execute_kwargs)
if n_outputs is None:
if out[0] is None:
n_outputs = 0
elif isinstance(out[0], (tuple, list, List)):
n_outputs = len(out[0])
else:
n_outputs = 1
if n_outputs == 0:
return None
if n_outputs == 1:
if isinstance(out[0], (tuple, list, List)) and len(out[0]) == 1:
out = list(map(lambda x: x[0], out))
return column_stack_arrays(out)
return list(map(column_stack_arrays, zip(*out)))
def apply_and_concat(
ntimes: int,
apply_func: tp.Callable,
*args,
n_outputs: tp.Optional[int] = None,
jitted_loop: bool = False,
jitted_warmup: bool = False,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Union[None, tp.Array2d, tp.List[tp.Array2d]]:
"""Run `apply_func` function a number of times and concatenate the results depending upon how
many array-like objects it generates.
`apply_func` must accept arguments `i`, `*args`, and `**kwargs`.
Set `jitted_loop` to True to use the JIT-compiled version.
All jitted iteration functions are resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.
!!! note
`n_outputs` must be set when `jitted_loop` is True.
Numba doesn't support variable keyword arguments."""
if jitted_loop:
if n_outputs is None:
raise ValueError("Jitted iteration requires n_outputs")
if n_outputs == 0:
func = jit_reg.resolve(custom_apply_and_concat_none_nb)
elif n_outputs == 1:
func = jit_reg.resolve(custom_apply_and_concat_one_nb)
else:
func = jit_reg.resolve(custom_apply_and_concat_multiple_nb)
if jitted_warmup:
func(np.array([0]), apply_func, *args, **kwargs)
def _tasks_template(chunk_meta):
tasks = []
for _chunk_meta in chunk_meta:
if _chunk_meta.indices is not None:
chunk_indices = np.asarray(_chunk_meta.indices)
else:
if _chunk_meta.start is None or _chunk_meta.end is None:
raise ValueError("Each chunk must have a start and an end index")
chunk_indices = np.arange(_chunk_meta.start, _chunk_meta.end)
tasks.append(Task(func, chunk_indices, apply_func, *args, **kwargs))
return tasks
tasks = RepFunc(_tasks_template)
else:
tasks = [(apply_func, (i, *args), kwargs) for i in range(ntimes)]
if execute_kwargs is None:
execute_kwargs = {}
execute_kwargs["size"] = ntimes
return apply_and_concat_each(
tasks,
n_outputs=n_outputs,
execute_kwargs=execute_kwargs,
)
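# Illustrative sketch: with the default `jitted_loop=False`, `apply_and_concat` runs a plain
# Python function once per index and column-stacks the results. `scale_by_i` is a hypothetical
# helper used only for illustration:
# >>> def scale_by_i(i, a):
# ...     return a * i
# >>> apply_and_concat(3, scale_by_i, np.array([1.0, 2.0]), n_outputs=1)
# expected: a (2, 3) array whose i-th column equals [1.0, 2.0] * i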
@register_jitted
def select_and_combine_nb(
i: int,
obj: tp.Any,
others: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.AnyArray:
"""Numba-compiled version of `select_and_combine`."""
return combine_func_nb(obj, others[i], *args)
@register_jitted
def combine_and_concat_nb(
obj: tp.Any,
others: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Numba-compiled version of `combine_and_concat`."""
return apply_and_concat_one_nb(len(others), select_and_combine_nb, obj, others, combine_func_nb, *args)
def select_and_combine(
i: int,
obj: tp.Any,
others: tp.Sequence,
combine_func: tp.Callable,
*args,
**kwargs,
) -> tp.AnyArray:
"""Combine `obj` with an array at position `i` in `others` using `combine_func`."""
return combine_func(obj, others[i], *args, **kwargs)
def combine_and_concat(
obj: tp.Any,
others: tp.Sequence,
combine_func: tp.Callable,
*args,
jitted_loop: bool = False,
**kwargs,
) -> tp.Array2d:
"""Combine `obj` with each in `others` using `combine_func` and concatenate.
`select_and_combine_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`."""
if jitted_loop:
apply_func = jit_reg.resolve(select_and_combine_nb)
else:
apply_func = select_and_combine
return apply_and_concat(
len(others),
apply_func,
obj,
others,
combine_func,
*args,
n_outputs=1,
jitted_loop=jitted_loop,
**kwargs,
)
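# Illustrative sketch: `combine_and_concat` pairs `obj` with every array in `others` and
# stacks the combined results column-wise:
# >>> combine_and_concat(np.array([1, 2]), [np.array([3, 4]), np.array([5, 6])], np.add)
# expected: array([[4, 6],
#                  [6, 8]])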
@register_jitted
def combine_multiple_nb(
objs: tp.Sequence,
combine_func_nb: tp.Callable,
*args,
) -> tp.Any:
"""Numba-compiled version of `combine_multiple`."""
result = objs[0]
for i in range(1, len(objs)):
result = combine_func_nb(result, objs[i], *args)
return result
def combine_multiple(
objs: tp.Sequence,
combine_func: tp.Callable,
*args,
jitted_loop: bool = False,
**kwargs,
) -> tp.Any:
"""Combine `objs` pairwise into a single object.
Set `jitted_loop` to True to use the JIT-compiled version.
`combine_multiple_nb` is resolved using `vectorbtpro.registries.jit_registry.JITRegistry.resolve`.
!!! note
Numba doesn't support variable keyword arguments."""
if jitted_loop:
func = jit_reg.resolve(combine_multiple_nb)
return func(objs, combine_func, *args)
result = objs[0]
for i in range(1, len(objs)):
result = combine_func(result, objs[i], *args, **kwargs)
return result
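# Illustrative sketch: `combine_multiple` reduces a sequence pairwise with `combine_func`:
# >>> combine_multiple([np.array([1, 2]), np.array([3, 4]), np.array([5, 6])], np.add)
# expected: array([ 9, 12])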
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for base classes."""
from functools import cached_property as cachedproperty
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts
__all__ = []
def override_arg_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper:
"""Class decorator to override the argument config of a class subclassing
`vectorbtpro.base.preparing.BasePreparer`.
Instead of overriding the `_arg_config` class attribute, you can pass `config` directly to this decorator.
Disable `merge_configs` to avoid merging, which effectively disables field inheritance."""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "BasePreparer")
if merge_configs:
new_config = merge_dicts(cls.arg_config, config)
else:
new_config = config
if not isinstance(new_config, Config):
new_config = HybridConfig(new_config)
setattr(cls, "_arg_config", new_config)
return cls
return wrapper
def attach_arg_properties(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
"""Class decorator to attach properties for arguments defined in the argument config
of a `vectorbtpro.base.preparing.BasePreparer` subclass."""
checks.assert_subclass_of(cls, "BasePreparer")
for arg_name, settings in cls.arg_config.items():
attach = settings.get("attach", None)
broadcast = settings.get("broadcast", False)
substitute_templates = settings.get("substitute_templates", False)
if (isinstance(attach, bool) and attach) or (attach is None and (broadcast or substitute_templates)):
if broadcast:
return_type = tp.ArrayLike
else:
return_type = object
target_pre_name = "_pre_" + arg_name
if not hasattr(cls, target_pre_name):
def pre_arg_prop(self, _arg_name: str = arg_name) -> return_type:
return self.get_arg(_arg_name)
pre_arg_prop.__name__ = target_pre_name
pre_arg_prop.__module__ = cls.__module__
pre_arg_prop.__qualname__ = f"{cls.__name__}.{pre_arg_prop.__name__}"
if broadcast and substitute_templates:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting and template substitution."
elif broadcast:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before broadcasting."
else:
pre_arg_prop.__doc__ = f"Argument `{arg_name}` before template substitution."
setattr(cls, pre_arg_prop.__name__, cachedproperty(pre_arg_prop))
getattr(cls, pre_arg_prop.__name__).__set_name__(cls, pre_arg_prop.__name__)
target_post_name = "_post_" + arg_name
if not hasattr(cls, target_post_name):
def post_arg_prop(self, _arg_name: str = arg_name) -> return_type:
return self.prepare_post_arg(_arg_name)
post_arg_prop.__name__ = target_post_name
post_arg_prop.__module__ = cls.__module__
post_arg_prop.__qualname__ = f"{cls.__name__}.{post_arg_prop.__name__}"
if broadcast and substitute_templates:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting and template substitution."
elif broadcast:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after broadcasting."
else:
post_arg_prop.__doc__ = f"Argument `{arg_name}` after template substitution."
setattr(cls, post_arg_prop.__name__, cachedproperty(post_arg_prop))
getattr(cls, post_arg_prop.__name__).__set_name__(cls, post_arg_prop.__name__)
target_name = arg_name
if not hasattr(cls, target_name):
def arg_prop(self, _target_post_name: str = target_post_name) -> return_type:
return getattr(self, _target_post_name)
arg_prop.__name__ = target_name
arg_prop.__module__ = cls.__module__
arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}"
arg_prop.__doc__ = f"Argument `{arg_name}`."
setattr(cls, arg_prop.__name__, cachedproperty(arg_prop))
getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__)
elif (isinstance(attach, bool) and attach) or attach is None:
if not hasattr(cls, arg_name):
def arg_prop(self, _arg_name: str = arg_name) -> tp.Any:
return self.get_arg(_arg_name)
arg_prop.__name__ = arg_name
arg_prop.__module__ = cls.__module__
arg_prop.__qualname__ = f"{cls.__name__}.{arg_prop.__name__}"
arg_prop.__doc__ = f"Argument `{arg_name}`."
setattr(cls, arg_prop.__name__, cachedproperty(arg_prop))
getattr(cls, arg_prop.__name__).__set_name__(cls, arg_prop.__name__)
return cls
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes and functions for flexible indexing."""
from vectorbtpro import _typing as tp
from vectorbtpro._settings import settings
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = [
"flex_select_1d_nb",
"flex_select_1d_pr_nb",
"flex_select_1d_pc_nb",
"flex_select_nb",
"flex_select_row_nb",
"flex_select_col_nb",
"flex_select_2d_row_nb",
"flex_select_2d_col_nb",
]
_rotate_rows = settings["indexing"]["rotate_rows"]
_rotate_cols = settings["indexing"]["rotate_cols"]
@register_jitted(cache=True)
def flex_choose_i_1d_nb(arr: tp.FlexArray1d, i: int) -> int:
"""Choose a position in an array as if it has been broadcast against rows or columns.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
return int(flex_i)
@register_jitted(cache=True)
def flex_select_1d_nb(arr: tp.FlexArray1d, i: int) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against rows or columns.
!!! note
Array must be one-dimensional."""
flex_i = flex_choose_i_1d_nb(arr, i)
return arr[flex_i]
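# Illustrative sketch (requires Numba): flexible selection treats a length-1 array as if it
# were broadcast, so any row or column index maps to its only element:
# >>> flex_select_1d_nb(np.array([7.0]), 5)            # expected: 7.0
# >>> flex_select_1d_nb(np.array([1.0, 2.0, 3.0]), 1)  # expected: 2.0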
@register_jitted(cache=True)
def flex_choose_i_pr_1d_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> int:
"""Choose a position in an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if rotate_rows:
return int(flex_i) % arr.shape[0]
return int(flex_i)
@register_jitted(cache=True)
def flex_choose_i_pr_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> int:
"""Choose a position in an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be two-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if rotate_rows:
return int(flex_i) % arr.shape[0]
return int(flex_i)
@register_jitted(cache=True)
def flex_select_1d_pr_nb(arr: tp.FlexArray1d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against rows.
Can use rotational indexing along rows.
!!! note
Array must be one-dimensional."""
flex_i = flex_choose_i_pr_1d_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i]
@register_jitted(cache=True)
def flex_choose_i_pc_1d_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> int:
"""Choose a position in an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be one-dimensional."""
if arr.shape[0] == 1:
flex_col = 0
else:
flex_col = col
if rotate_cols:
return int(flex_col) % arr.shape[0]
return int(flex_col)
@register_jitted(cache=True)
def flex_choose_i_pc_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> int:
"""Choose a position in an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be two-dimensional."""
if arr.shape[1] == 1:
flex_col = 0
else:
flex_col = col
if rotate_cols:
return int(flex_col) % arr.shape[1]
return int(flex_col)
@register_jitted(cache=True)
def flex_select_1d_pc_nb(arr: tp.FlexArray1d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Scalar:
"""Select an element of an array as if it has been broadcast against columns.
Can use rotational indexing along columns.
!!! note
Array must be one-dimensional."""
flex_col = flex_choose_i_pc_1d_nb(arr, col, rotate_cols=rotate_cols)
return arr[flex_col]
@register_jitted(cache=True)
def flex_choose_i_and_col_nb(
arr: tp.FlexArray2d,
i: int,
col: int,
rotate_rows: bool = _rotate_rows,
rotate_cols: bool = _rotate_cols,
) -> tp.Tuple[int, int]:
"""Choose a position in an array as if it has been broadcast rows and columns.
Can use rotational indexing along rows and columns.
!!! note
Array must be two-dimensional."""
if arr.shape[0] == 1:
flex_i = 0
else:
flex_i = i
if arr.shape[1] == 1:
flex_col = 0
else:
flex_col = col
if rotate_rows and rotate_cols:
return int(flex_i) % arr.shape[0], int(flex_col) % arr.shape[1]
if rotate_rows:
return int(flex_i) % arr.shape[0], int(flex_col)
if rotate_cols:
return int(flex_i), int(flex_col) % arr.shape[1]
return int(flex_i), int(flex_col)
@register_jitted(cache=True)
def flex_select_nb(
arr: tp.FlexArray2d,
i: int,
col: int,
rotate_rows: bool = _rotate_rows,
rotate_cols: bool = _rotate_cols,
) -> tp.Scalar:
"""Select element of an array as if it has been broadcast rows and columns.
Can use rotational indexing along rows and columns.
!!! note
Array must be two-dimensional."""
flex_i, flex_col = flex_choose_i_and_col_nb(
arr,
i,
col,
rotate_rows=rotate_rows,
rotate_cols=rotate_cols,
)
return arr[flex_i, flex_col]
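# Illustrative sketch (requires Numba): a (1, n) array behaves as if it were broadcast along
# rows, so the row index collapses to 0 regardless of its value:
# >>> arr = np.array([[10.0, 20.0, 30.0]])
# >>> flex_select_nb(arr, 5, 2)  # expected: 30.0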
@register_jitted(cache=True)
def flex_select_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array1d:
"""Select a row from a flexible 2-dim array. Returns a 1-dim array.
!!! note
Array must be two-dimensional."""
flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i]
@register_jitted(cache=True)
def flex_select_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array1d:
"""Select a column from a flexible 2-dim array. Returns a 1-dim array.
!!! note
Array must be two-dimensional."""
flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols)
return arr[:, flex_col]
@register_jitted(cache=True)
def flex_select_2d_row_nb(arr: tp.FlexArray2d, i: int, rotate_rows: bool = _rotate_rows) -> tp.Array2d:
"""Select a row from a flexible 2-dim array. Returns a 2-dim array.
!!! note
Array must be two-dimensional."""
flex_i = flex_choose_i_pr_nb(arr, i, rotate_rows=rotate_rows)
return arr[flex_i : flex_i + 1]
@register_jitted(cache=True)
def flex_select_2d_col_nb(arr: tp.FlexArray2d, col: int, rotate_cols: bool = _rotate_cols) -> tp.Array2d:
"""Select a column from a flexible 2-dim array. Returns a 2-dim array.
!!! note
Array must be two-dimensional."""
flex_col = flex_choose_i_pc_nb(arr, col, rotate_cols=rotate_cols)
return arr[:, flex_col : flex_col + 1]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for working with indexes: index and columns.
They perform operations on index objects, such as stacking, combining, and cleansing MultiIndex levels.
!!! note
"Index" in pandas context is referred to both index and columns."""
from datetime import datetime, timedelta
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.base import Base
__all__ = [
"ExceptLevel",
"repeat_index",
"tile_index",
"stack_indexes",
"combine_indexes",
]
@define
class ExceptLevel(DefineMixin):
"""Class for grouping except one or more levels."""
value: tp.MaybeLevelSequence = define.field()
"""One or more level positions or names."""
def to_any_index(index_like: tp.IndexLike) -> tp.Index:
"""Convert any index-like object to an index.
Index objects are kept as-is."""
if checks.is_np_array(index_like) and index_like.ndim == 0:
index_like = index_like[None]
if not checks.is_index(index_like):
return pd.Index(index_like)
return index_like
def get_index(obj: tp.SeriesFrame, axis: int) -> tp.Index:
"""Get index of `obj` by `axis`."""
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
checks.assert_in(axis, (0, 1))
if axis == 0:
return obj.index
else:
if checks.is_series(obj):
if obj.name is not None:
return pd.Index([obj.name])
return pd.Index([0]) # same as how pandas does it
else:
return obj.columns
def index_from_values(
values: tp.Sequence,
single_value: bool = False,
name: tp.Optional[tp.Hashable] = None,
) -> tp.Index:
"""Create a new `pd.Index` with `name` by parsing an iterable `values`.
Each element in `values` will correspond to an element in the new index."""
scalar_types = (int, float, complex, str, bool, datetime, timedelta, np.generic)
type_id_number = {}
value_names = []
if len(values) == 1:
single_value = True
for i in range(len(values)):
if i > 0 and single_value:
break
v = values[i]
if v is None or isinstance(v, scalar_types):
value_names.append(v)
elif isinstance(v, np.ndarray):
all_same = False
if np.issubdtype(v.dtype, np.floating):
if np.isclose(v, v.item(0), equal_nan=True).all():
all_same = True
elif v.dtype.names is not None:
all_same = False
else:
if np.equal(v, v.item(0)).all():
all_same = True
if all_same:
value_names.append(v.item(0))
else:
if single_value:
value_names.append("array")
else:
if "array" not in type_id_number:
type_id_number["array"] = {}
if id(v) not in type_id_number["array"]:
type_id_number["array"][id(v)] = len(type_id_number["array"])
value_names.append("array_%d" % (type_id_number["array"][id(v)]))
else:
type_name = str(type(v).__name__)
if single_value:
value_names.append("%s" % type_name)
else:
if type_name not in type_id_number:
type_id_number[type_name] = {}
if id(v) not in type_id_number[type_name]:
type_id_number[type_name][id(v)] = len(type_id_number[type_name])
value_names.append("%s_%d" % (type_name, type_id_number[type_name][id(v)]))
if single_value and len(values) > 1:
value_names *= len(values)
return pd.Index(value_names, name=name)
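# Illustrative sketch: scalars are kept as-is, while complex objects such as arrays with
# differing values are replaced by positional placeholders:
# >>> index_from_values([0.5, np.array([1, 2, 3]), "hello"], name="key")
# expected: Index([0.5, 'array_0', 'hello'], dtype='object', name='key')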
def repeat_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index:
"""Repeat each element in `index` `n` times.
Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_ranges is None:
ignore_ranges = broadcasting_cfg["ignore_ranges"]
index = to_any_index(index)
if n == 1:
return index
if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name
return pd.RangeIndex(start=0, stop=len(index) * n, step=1)
return index.repeat(n)
def tile_index(index: tp.IndexLike, n: int, ignore_ranges: tp.Optional[bool] = None) -> tp.Index:
"""Tile the whole `index` `n` times.
Set `ignore_ranges` to True to ignore indexes of type `pd.RangeIndex`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_ranges is None:
ignore_ranges = broadcasting_cfg["ignore_ranges"]
index = to_any_index(index)
if n == 1:
return index
if checks.is_default_index(index) and ignore_ranges: # ignore simple ranges without name
return pd.RangeIndex(start=0, stop=len(index) * n, step=1)
if isinstance(index, pd.MultiIndex):
return pd.MultiIndex.from_tuples(np.tile(index, n), names=index.names)
return pd.Index(np.tile(index, n), name=index.name)
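# Illustrative sketch: `repeat_index` repeats each element, `tile_index` repeats the whole index:
# >>> repeat_index(pd.Index([1, 2, 3]), 2)  # expected: Index([1, 1, 2, 2, 3, 3])
# >>> tile_index(pd.Index([1, 2, 3]), 2)    # expected: Index([1, 2, 3, 1, 2, 3])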
def clean_index(
index: tp.IndexLike,
drop_duplicates: tp.Optional[bool] = None,
keep: tp.Optional[str] = None,
drop_redundant: tp.Optional[bool] = None,
) -> tp.Index:
"""Clean index.
Set `drop_duplicates` to True to remove duplicate levels.
For details on `keep`, see `drop_duplicate_levels`.
Set `drop_redundant` to True to use `drop_redundant_levels`."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if drop_duplicates is None:
drop_duplicates = broadcasting_cfg["drop_duplicates"]
if keep is None:
keep = broadcasting_cfg["keep"]
if drop_redundant is None:
drop_redundant = broadcasting_cfg["drop_redundant"]
index = to_any_index(index)
if drop_duplicates:
index = drop_duplicate_levels(index, keep=keep)
if drop_redundant:
index = drop_redundant_levels(index)
return index
def stack_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **clean_index_kwargs) -> tp.Index:
"""Stack each index in `indexes` on top of each other, from top to bottom."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
levels = []
for i in range(len(indexes)):
index = indexes[i]
if not isinstance(index, pd.MultiIndex):
levels.append(to_any_index(index))
else:
for j in range(index.nlevels):
levels.append(index.get_level_values(j))
max_len = max(map(len, levels))
for i in range(len(levels)):
if len(levels[i]) < max_len:
if len(levels[i]) != 1:
raise ValueError(f"Index at level {i} could not be broadcast to shape ({max_len},)")
levels[i] = repeat_index(levels[i], max_len, ignore_ranges=False)
new_index = pd.MultiIndex.from_arrays(levels)
return clean_index(new_index, **clean_index_kwargs)
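# Illustrative sketch: `stack_indexes` zips the given indexes into MultiIndex levels,
# broadcasting any length-1 index to the longest one:
# >>> stack_indexes([pd.Index(["a", "b"], name="c1"), pd.Index([1, 2], name="c2")])
# expected: MultiIndex([('a', 1), ('b', 2)], names=['c1', 'c2'])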
def combine_indexes(*indexes: tp.MaybeTuple[tp.IndexLike], **kwargs) -> tp.Index:
"""Combine each index in `indexes` using Cartesian product.
Keyword arguments will be passed to `stack_indexes`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
new_index = to_any_index(indexes[0])
for i in range(1, len(indexes)):
index1, index2 = new_index, to_any_index(indexes[i])
new_index1 = repeat_index(index1, len(index2), ignore_ranges=False)
new_index2 = tile_index(index2, len(index1), ignore_ranges=False)
new_index = stack_indexes([new_index1, new_index2], **kwargs)
return new_index
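# Illustrative sketch: `combine_indexes` builds a Cartesian product by repeating the first
# index and tiling the second:
# >>> combine_indexes([pd.Index([1, 2], name="i"), pd.Index([3, 4], name="j")])
# expected: MultiIndex([(1, 3), (1, 4), (2, 3), (2, 4)], names=['i', 'j'])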
def combine_index_with_keys(index: tp.IndexLike, keys: tp.IndexLike, lens: tp.Sequence[int], **kwargs) -> tp.Index:
"""Build keys based on index lengths."""
if not isinstance(index, pd.Index):
index = pd.Index(index)
if not isinstance(keys, pd.Index):
keys = pd.Index(keys)
new_index = None
new_keys = None
start_idx = 0
for i in range(len(keys)):
_index = index[start_idx : start_idx + lens[i]]
if new_index is None:
new_index = _index
else:
new_index = new_index.append(_index)
start_idx += lens[i]
new_key = keys[[i]].repeat(lens[i])
if new_keys is None:
new_keys = new_key
else:
new_keys = new_keys.append(new_key)
return stack_indexes([new_keys, new_index], **kwargs)
def concat_indexes(
*indexes: tp.MaybeTuple[tp.IndexLike],
index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
axis: int = 1,
) -> tp.Index:
"""Concatenate indexes.
The following index concatenation methods are supported:
* 'append': append one index to another
* 'union': build a union of indexes
* 'pd_concat': convert indexes to Pandas Series or DataFrames and use `pd.concat`
* 'factorize': factorize the concatenated index
* 'factorize_each': factorize each index and concatenate while keeping numbers unique
* 'reset': reset the concatenated index without applying `keys`
* Callable: a custom callable that takes the indexes and returns the concatenated index
Argument `index_concat_method` also accepts a tuple of two options: the second option gets applied
if the first one fails.
Use `keys` as an index with the same number of elements as there are indexes to add
another index level on top of the concatenated indexes.
If `verify_integrity` is True and `keys` is None, performs various checks depending on the axis."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
if clean_index_kwargs is None:
clean_index_kwargs = {}
if axis == 0:
factorized_name = "row_idx"
elif axis == 1:
factorized_name = "col_idx"
else:
factorized_name = "group_idx"
if keys is None:
all_ranges = True
for index in indexes:
if not checks.is_default_index(index):
all_ranges = False
break
if all_ranges:
return pd.RangeIndex(stop=sum(map(len, indexes)))
if isinstance(index_concat_method, tuple):
try:
return concat_indexes(
*indexes,
index_concat_method=index_concat_method[0],
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=axis,
)
except Exception as e:
return concat_indexes(
*indexes,
index_concat_method=index_concat_method[1],
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=axis,
)
if not isinstance(index_concat_method, str):
new_index = index_concat_method(indexes)
elif index_concat_method.lower() == "append":
new_index = None
for index in indexes:
if new_index is None:
new_index = index
else:
new_index = new_index.append(index)
elif index_concat_method.lower() == "union":
if keys is not None:
raise ValueError("Cannot apply keys after concatenating indexes through union")
new_index = None
for index in indexes:
if new_index is None:
new_index = index
else:
new_index = new_index.union(index)
elif index_concat_method.lower() == "pd_concat":
new_index = None
for index in indexes:
if isinstance(index, pd.MultiIndex):
index = index.to_frame().reset_index(drop=True)
else:
index = index.to_series().reset_index(drop=True)
if new_index is None:
new_index = index
else:
if isinstance(new_index, pd.DataFrame):
if isinstance(index, pd.Series):
index = index.to_frame()
elif isinstance(index, pd.Series):
if isinstance(new_index, pd.DataFrame):
new_index = new_index.to_frame()
new_index = pd.concat((new_index, index), ignore_index=True)
if isinstance(new_index, pd.Series):
new_index = pd.Index(new_index)
else:
new_index = pd.MultiIndex.from_frame(new_index)
elif index_concat_method.lower() == "factorize":
new_index = concat_indexes(
*indexes,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=axis,
)
new_index = pd.Index(pd.factorize(new_index)[0], name=factorized_name)
elif index_concat_method.lower() == "factorize_each":
new_index = None
for index in indexes:
index = pd.Index(pd.factorize(index)[0], name=factorized_name)
if new_index is None:
new_index = index
next_min = index.max() + 1
else:
new_index = new_index.append(index + next_min)
next_min = index.max() + 1 + next_min
elif index_concat_method.lower() == "reset":
return pd.RangeIndex(stop=sum(map(len, indexes)))
else:
if axis == 0:
raise ValueError(f"Invalid index concatenation method: '{index_concat_method}'")
elif axis == 1:
raise ValueError(f"Invalid column concatenation method: '{index_concat_method}'")
else:
raise ValueError(f"Invalid group concatenation method: '{index_concat_method}'")
if keys is not None:
if isinstance(keys[0], pd.Index):
keys = concat_indexes(
*keys,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=axis,
)
new_index = stack_indexes((keys, new_index), **clean_index_kwargs)
keys = None
elif not isinstance(keys, pd.Index):
keys = pd.Index(keys)
if keys is not None:
top_index = None
for i, index in enumerate(indexes):
repeated_index = repeat_index(keys[[i]], len(index))
if top_index is None:
top_index = repeated_index
else:
top_index = top_index.append(repeated_index)
new_index = stack_indexes((top_index, new_index), **clean_index_kwargs)
if verify_integrity:
if keys is None:
if axis == 0:
if not new_index.is_monotonic_increasing:
raise ValueError("Concatenated index is not monotonically increasing")
if "mixed" in new_index.inferred_type:
raise ValueError("Concatenated index is mixed")
if new_index.has_duplicates:
raise ValueError("Concatenated index contains duplicates")
if axis == 1:
if new_index.has_duplicates:
raise ValueError("Concatenated columns contain duplicates")
if axis == 2:
if new_index.has_duplicates:
len_sum = 0
for index in indexes:
if len_sum > 0:
prev_index = new_index[:len_sum]
this_index = new_index[len_sum : len_sum + len(index)]
if len(prev_index.intersection(this_index)) > 0:
raise ValueError("Concatenated groups contain duplicates")
len_sum += len(index)
return new_index
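# Illustrative sketch: with the default 'append' method, indexes are simply appended;
# `keys` (not shown) would add another level on top, one key per source index:
# >>> concat_indexes(pd.Index(["a", "b"]), pd.Index(["c", "d"]), axis=1)
# expected: Index(['a', 'b', 'c', 'd'], dtype='object')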
def drop_levels(
index: tp.Index,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
) -> tp.Index:
"""Drop `levels` in `index` by their name(s)/position(s).
Provide `levels` as an instance of `ExceptLevel` to drop everything apart from the specified levels."""
if not isinstance(index, pd.MultiIndex):
if strict:
raise TypeError("Index must be a multi-index")
return index
if isinstance(levels, ExceptLevel):
levels = levels.value
except_mode = True
else:
except_mode = False
levels_to_drop = set()
if isinstance(levels, str) or not checks.is_sequence(levels):
levels = [levels]
for level in levels:
if level in index.names:
for level_pos in [i for i, x in enumerate(index.names) if x == level]:
levels_to_drop.add(level_pos)
elif checks.is_int(level):
if level < 0:
new_level = index.nlevels + level
if new_level < 0:
raise KeyError(f"Level at position {level} not found")
level = new_level
if 0 <= level < index.nlevels:
levels_to_drop.add(level)
else:
raise KeyError(f"Level at position {level} not found")
elif strict:
raise KeyError(f"Level '{level}' not found")
if except_mode:
levels_to_drop = set(range(index.nlevels)).difference(levels_to_drop)
if len(levels_to_drop) == 0:
if strict:
raise ValueError("No levels to drop")
return index
if len(levels_to_drop) >= index.nlevels:
if strict:
raise ValueError(
f"Cannot remove {len(levels_to_drop)} levels from an index with {index.nlevels} levels: "
"at least one level must be left"
)
return index
return index.droplevel(list(levels_to_drop))
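# Illustrative sketch: levels can be dropped by name or position; wrapping them in
# `ExceptLevel` keeps only those levels instead:
# >>> index = pd.MultiIndex.from_tuples([(1, "x"), (2, "y")], names=["a", "b"])
# >>> drop_levels(index, "a")               # expected: Index(['x', 'y'], name='b')
# >>> drop_levels(index, ExceptLevel("a"))  # expected: Index([1, 2], name='a')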
def rename_levels(index: tp.Index, mapper: tp.MaybeMappingSequence[tp.Level], strict: bool = True) -> tp.Index:
"""Rename levels in `index` by `mapper`.
Mapper can be a single new level name, a sequence of new level names, or a dictionary that maps
old level names to new level names."""
if isinstance(index, pd.MultiIndex):
nlevels = index.nlevels
if isinstance(mapper, (int, str)):
mapper = dict(zip(index.names, [mapper]))
elif checks.is_complex_sequence(mapper):
mapper = dict(zip(index.names, mapper))
else:
nlevels = 1
if isinstance(mapper, (int, str)):
mapper = dict(zip([index.name], [mapper]))
elif checks.is_complex_sequence(mapper):
mapper = dict(zip([index.name], mapper))
for k, v in mapper.items():
if k in index.names:
if isinstance(index, pd.MultiIndex):
index = index.rename(v, level=k)
else:
index = index.rename(v)
elif checks.is_int(k):
if k < 0:
new_k = nlevels + k
if new_k < 0:
raise KeyError(f"Level at position {k} not found")
k = new_k
if 0 <= k < nlevels:
if isinstance(index, pd.MultiIndex):
index = index.rename(v, level=k)
else:
index = index.rename(v)
else:
raise KeyError(f"Level at position {k} not found")
elif strict:
raise KeyError(f"Level '{k}' not found")
return index
def select_levels(
index: tp.Index,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
) -> tp.Index:
"""Build a new index by selecting one or multiple `levels` from `index`.
Provide `levels` as an instance of `ExceptLevel` to select everything apart from the specified levels."""
was_multiindex = True
if not isinstance(index, pd.MultiIndex):
was_multiindex = False
index = pd.MultiIndex.from_arrays([index])
if isinstance(levels, ExceptLevel):
levels = levels.value
except_mode = True
else:
except_mode = False
levels_to_select = list()
if isinstance(levels, str) or not checks.is_sequence(levels):
levels = [levels]
single_mode = True
else:
single_mode = False
for level in levels:
if level in index.names:
for level_pos in [i for i, x in enumerate(index.names) if x == level]:
if level_pos not in levels_to_select:
levels_to_select.append(level_pos)
elif checks.is_int(level):
if level < 0:
new_level = index.nlevels + level
if new_level < 0:
raise KeyError(f"Level at position {level} not found")
level = new_level
if 0 <= level < index.nlevels:
if level not in levels_to_select:
levels_to_select.append(level)
else:
raise KeyError(f"Level at position {level} not found")
elif strict:
raise KeyError(f"Level '{level}' not found")
if except_mode:
levels_to_select = list(set(range(index.nlevels)).difference(levels_to_select))
if len(levels_to_select) == 0:
if strict:
raise ValueError("No levels to select")
if not was_multiindex:
return index.get_level_values(0)
return index
if len(levels_to_select) == 1 and single_mode:
return index.get_level_values(levels_to_select[0])
levels = [index.get_level_values(level) for level in levels_to_select]
return pd.MultiIndex.from_arrays(levels)
def drop_redundant_levels(index: tp.Index) -> tp.Index:
"""Drop levels in `index` that either have a single unnamed value or a range from 0 to n."""
if not isinstance(index, pd.MultiIndex):
return index
levels_to_drop = []
for i in range(index.nlevels):
if len(index.levels[i]) == 1 and index.levels[i].name is None:
levels_to_drop.append(i)
elif checks.is_default_index(index.get_level_values(i)):
levels_to_drop.append(i)
if len(levels_to_drop) < index.nlevels:
return index.droplevel(levels_to_drop)
return index
def drop_duplicate_levels(index: tp.Index, keep: tp.Optional[str] = None) -> tp.Index:
"""Drop levels in `index` with the same name and values.
Set `keep` to 'last' to keep last levels, otherwise 'first'.
Set `keep` to None to use the default."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if keep is None:
keep = broadcasting_cfg["keep"]
if not isinstance(index, pd.MultiIndex):
return index
checks.assert_in(keep.lower(), ["first", "last"])
levels_to_drop = set()
level_values = [index.get_level_values(i) for i in range(index.nlevels)]
for i in range(index.nlevels):
level1 = level_values[i]
for j in range(i + 1, index.nlevels):
level2 = level_values[j]
if level1.name is None or level2.name is None or level1.name == level2.name:
if checks.is_index_equal(level1, level2, check_names=False):
if level1.name is None and level2.name is not None:
levels_to_drop.add(i)
elif level1.name is not None and level2.name is None:
levels_to_drop.add(j)
else:
if keep.lower() == "first":
levels_to_drop.add(j)
else:
levels_to_drop.add(i)
return index.droplevel(list(levels_to_drop))
@register_jitted(cache=True)
def align_arr_indices_nb(a: tp.Array1d, b: tp.Array1d) -> tp.Array1d:
"""Return indices required to align `a` to `b`."""
idxs = np.empty(b.shape[0], dtype=int_)
g = 0
for i in range(b.shape[0]):
for j in range(a.shape[0]):
if b[i] == a[j]:
idxs[g] = j
g += 1
break
return idxs
def align_index_to(index1: tp.Index, index2: tp.Index, jitted: tp.JittedOption = None) -> tp.IndexSlice:
"""Align `index1` to have the same shape as `index2` if they have any levels in common.
Returns an index slice for the alignment."""
if not isinstance(index1, pd.MultiIndex):
index1 = pd.MultiIndex.from_arrays([index1])
if not isinstance(index2, pd.MultiIndex):
index2 = pd.MultiIndex.from_arrays([index2])
if checks.is_index_equal(index1, index2):
return pd.IndexSlice[:]
if len(index1) > len(index2):
raise ValueError("Longer index cannot be aligned to shorter index")
mapper = {}
for i in range(index1.nlevels):
name1 = index1.names[i]
for j in range(index2.nlevels):
name2 = index2.names[j]
if name1 is None or name2 is None or name1 == name2:
if set(index2.levels[j]).issubset(set(index1.levels[i])):
if i in mapper:
raise ValueError(f"There are multiple candidate levels with name {name1} in second index")
mapper[i] = j
continue
if name1 == name2 and name1 is not None:
raise ValueError(f"Level {name1} in second index contains values not in first index")
if len(mapper) == 0:
if len(index1) == len(index2):
return pd.IndexSlice[:]
raise ValueError("Cannot find common levels to align indexes")
factorized = []
for k, v in mapper.items():
factorized.append(
pd.factorize(
pd.concat(
(
index1.get_level_values(k).to_series(),
index2.get_level_values(v).to_series(),
)
)
)[0],
)
stacked = np.transpose(np.stack(factorized))
indices1 = stacked[: len(index1)]
indices2 = stacked[len(index1) :]
if len(indices1) < len(indices2):
if len(np.unique(indices1, axis=0)) != len(indices1):
raise ValueError("Cannot align indexes")
if len(index2) % len(index1) == 0:
tile_times = len(index2) // len(index1)
index1_tiled = np.tile(indices1, (tile_times, 1))
if np.array_equal(index1_tiled, indices2):
return pd.IndexSlice[np.tile(np.arange(len(index1)), tile_times)]
unique_indices = np.unique(stacked, axis=0, return_inverse=True)[1]
unique1 = unique_indices[: len(index1)]
unique2 = unique_indices[len(index1) :]
if len(indices1) == len(indices2):
if np.array_equal(unique1, unique2):
return pd.IndexSlice[:]
func = jit_reg.resolve_option(align_arr_indices_nb, jitted)
return pd.IndexSlice[func(unique1, unique2)]
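# Illustrative sketch (requires Numba): aligning a shorter index to a longer one over a shared
# level is expected to return an index slice that repeats the shorter index accordingly:
# >>> index1 = pd.Index([1, 2], name="x")
# >>> index2 = pd.MultiIndex.from_tuples([(1, 3), (1, 4), (2, 3), (2, 4)], names=["x", "y"])
# >>> align_index_to(index1, index2)
# expected: an index slice equivalent to [0, 0, 1, 1], so that index1[...] matches index2's "x" level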
def align_indexes(
*indexes: tp.MaybeTuple[tp.Index],
return_new_index: bool = False,
**kwargs,
) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]:
"""Align multiple indexes to each other with `align_index_to`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
index_items = sorted([(i, indexes[i]) for i in range(len(indexes))], key=lambda x: len(x[1]))
index_slices = []
for i in range(len(index_items)):
index_slice = align_index_to(index_items[i][1], index_items[-1][1], **kwargs)
index_slices.append((index_items[i][0], index_slice))
index_slices = list(map(lambda x: x[1], sorted(index_slices, key=lambda x: x[0])))
if return_new_index:
new_index = stack_indexes(
*[indexes[i][index_slices[i]] for i in range(len(indexes))],
drop_duplicates=True,
)
return tuple(index_slices), new_index
return tuple(index_slices)
@register_jitted(cache=True)
def block_index_product_nb(
block_group_map1: tp.GroupMap,
block_group_map2: tp.GroupMap,
factorized1: tp.Array1d,
factorized2: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Return indices required for building a block-wise Cartesian product of two factorized indexes."""
group_idxs1, group_lens1 = block_group_map1
group_idxs2, group_lens2 = block_group_map2
group_start_idxs1 = np.cumsum(group_lens1) - group_lens1
group_start_idxs2 = np.cumsum(group_lens2) - group_lens2
matched1 = np.empty(len(factorized1), dtype=np.bool_)
matched2 = np.empty(len(factorized2), dtype=np.bool_)
indices1 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
indices2 = np.empty(len(factorized1) * len(factorized2), dtype=int_)
k1 = 0
k2 = 0
for g1 in range(len(group_lens1)):
group_len1 = group_lens1[g1]
group_start1 = group_start_idxs1[g1]
for g2 in range(len(group_lens2)):
group_len2 = group_lens2[g2]
group_start2 = group_start_idxs2[g2]
for c1 in range(group_len1):
i = group_idxs1[group_start1 + c1]
for c2 in range(group_len2):
j = group_idxs2[group_start2 + c2]
if factorized1[i] == factorized2[j]:
matched1[i] = True
matched2[j] = True
indices1[k1] = i
indices2[k2] = j
k1 += 1
k2 += 1
if not np.all(matched1) or not np.all(matched2):
raise ValueError("Cannot match some block level values")
return indices1[:k1], indices2[:k2]
def cross_index_with(
index1: tp.Index,
index2: tp.Index,
return_new_index: bool = False,
) -> tp.Union[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Tuple[tp.Tuple[tp.IndexSlice, tp.IndexSlice], tp.Index]]:
"""Build a Cartesian product of one index with another while taking into account levels they have in common.
Returns index slices for the alignment."""
from vectorbtpro.base.grouping.nb import get_group_map_nb
index1_default = checks.is_default_index(index1, check_names=True)
index2_default = checks.is_default_index(index2, check_names=True)
if not isinstance(index1, pd.MultiIndex):
index1 = pd.MultiIndex.from_arrays([index1])
if not isinstance(index2, pd.MultiIndex):
index2 = pd.MultiIndex.from_arrays([index2])
if not index1_default and not index2_default and checks.is_index_equal(index1, index2):
if return_new_index:
new_index = stack_indexes(index1, index2, drop_duplicates=True)
return (pd.IndexSlice[:], pd.IndexSlice[:]), new_index
return pd.IndexSlice[:], pd.IndexSlice[:]
levels1 = []
levels2 = []
for i in range(index1.nlevels):
if checks.is_default_index(index1.get_level_values(i), check_names=True):
continue
for j in range(index2.nlevels):
if checks.is_default_index(index2.get_level_values(j), check_names=True):
continue
name1 = index1.names[i]
name2 = index2.names[j]
if name1 == name2:
if set(index2.levels[j]) == set(index1.levels[i]):
if i in levels1 or j in levels2:
raise ValueError(f"There are multiple candidate block levels with name {name1}")
levels1.append(i)
levels2.append(j)
continue
if name1 is not None:
raise ValueError(f"Candidate block level {name1} in both indexes has different values")
if len(levels1) == 0:
# Regular index product
indices1 = np.repeat(np.arange(len(index1)), len(index2))
indices2 = np.tile(np.arange(len(index2)), len(index1))
else:
# Block index product
index_levels1 = select_levels(index1, levels1)
index_levels2 = select_levels(index2, levels2)
block_levels1 = list(set(range(index1.nlevels)).difference(levels1))
block_levels2 = list(set(range(index2.nlevels)).difference(levels2))
if len(block_levels1) > 0:
index_block_levels1 = select_levels(index1, block_levels1)
else:
index_block_levels1 = pd.Index(np.full(len(index1), 0))
if len(block_levels2) > 0:
index_block_levels2 = select_levels(index2, block_levels2)
else:
index_block_levels2 = pd.Index(np.full(len(index2), 0))
factorized = pd.factorize(pd.concat((index_levels1.to_series(), index_levels2.to_series())))[0]
factorized1 = factorized[: len(index_levels1)]
factorized2 = factorized[len(index_levels1) :]
block_factorized1, block_unique1 = pd.factorize(index_block_levels1)
block_factorized2, block_unique2 = pd.factorize(index_block_levels2)
block_group_map1 = get_group_map_nb(block_factorized1, len(block_unique1))
block_group_map2 = get_group_map_nb(block_factorized2, len(block_unique2))
indices1, indices2 = block_index_product_nb(
block_group_map1,
block_group_map2,
factorized1,
factorized2,
)
if return_new_index:
new_index = stack_indexes(index1[indices1], index2[indices2], drop_duplicates=True)
return (pd.IndexSlice[indices1], pd.IndexSlice[indices2]), new_index
return pd.IndexSlice[indices1], pd.IndexSlice[indices2]
def cross_indexes(
*indexes: tp.MaybeTuple[tp.Index],
return_new_index: bool = False,
) -> tp.Union[tp.Tuple[tp.IndexSlice, ...], tp.Tuple[tp.Tuple[tp.IndexSlice, ...], tp.Index]]:
"""Cross multiple indexes with `cross_index_with`."""
if len(indexes) == 1:
indexes = indexes[0]
indexes = list(indexes)
if len(indexes) == 2:
return cross_index_with(indexes[0], indexes[1], return_new_index=return_new_index)
index = None
index_slices = []
for i in range(len(indexes) - 2, -1, -1):
index1 = indexes[i]
if i == len(indexes) - 2:
index2 = indexes[i + 1]
else:
index2 = index
(index_slice1, index_slice2), index = cross_index_with(index1, index2, return_new_index=True)
if i == len(indexes) - 2:
index_slices.append(index_slice2)
else:
for j in range(len(index_slices)):
if isinstance(index_slices[j], slice):
index_slices[j] = np.arange(len(index2))[index_slices[j]]
index_slices[j] = index_slices[j][index_slice2]
index_slices.append(index_slice1)
if return_new_index:
return tuple(index_slices[::-1]), index
return tuple(index_slices[::-1])
OptionalLevelSequence = tp.Optional[tp.Sequence[tp.Union[None, tp.Level]]]
def pick_levels(
index: tp.Index,
required_levels: OptionalLevelSequence = None,
optional_levels: OptionalLevelSequence = None,
) -> tp.Tuple[tp.List[int], tp.List[int]]:
"""Pick optional and required levels and return their indices.
Raises an exception if the index has fewer or more levels than expected."""
if required_levels is None:
required_levels = []
if optional_levels is None:
optional_levels = []
checks.assert_instance_of(index, pd.MultiIndex)
n_opt_set = len(list(filter(lambda x: x is not None, optional_levels)))
n_req_set = len(list(filter(lambda x: x is not None, required_levels)))
n_levels_left = index.nlevels - n_opt_set
if n_req_set < len(required_levels):
if n_levels_left != len(required_levels):
n_expected = len(required_levels) + n_opt_set
raise ValueError(f"Expected {n_expected} levels, found {index.nlevels}")
levels_left = list(range(index.nlevels))
_optional_levels = []
for level in optional_levels:
level_pos = None
if level is not None:
checks.assert_instance_of(level, (int, str))
if isinstance(level, str):
level_pos = index.names.index(level)
else:
level_pos = level
if level_pos < 0:
level_pos = index.nlevels + level_pos
levels_left.remove(level_pos)
_optional_levels.append(level_pos)
_required_levels = []
for level in required_levels:
level_pos = None
if level is not None:
checks.assert_instance_of(level, (int, str))
if isinstance(level, str):
level_pos = index.names.index(level)
else:
level_pos = level
if level_pos < 0:
level_pos = index.nlevels + level_pos
levels_left.remove(level_pos)
_required_levels.append(level_pos)
for i, level in enumerate(_required_levels):
if level is None:
_required_levels[i] = levels_left.pop(0)
return _required_levels, _optional_levels
def find_first_occurrence(index_value: tp.Any, index: tp.Index) -> int:
"""Return index of the first occurrence in `index`."""
loc = index.get_loc(index_value)
if isinstance(loc, slice):
return loc.start
elif isinstance(loc, list):
return loc[0]
elif isinstance(loc, np.ndarray):
return np.flatnonzero(loc)[0]
return loc
IndexApplierT = tp.TypeVar("IndexApplierT", bound="IndexApplier")
class IndexApplier(Base):
"""Abstract class that can apply a function on an index."""
def apply_to_index(self: IndexApplierT, apply_func: tp.Callable, *args, **kwargs) -> IndexApplierT:
"""Apply function `apply_func` on the index of the instance and return a new instance."""
raise NotImplementedError
def add_levels(
self: IndexApplierT,
*indexes: tp.Index,
on_top: bool = True,
drop_duplicates: tp.Optional[bool] = None,
keep: tp.Optional[str] = None,
drop_redundant: tp.Optional[bool] = None,
**kwargs,
) -> IndexApplierT:
"""Append or prepend levels using `stack_indexes`.
Set `on_top` to False to stack at bottom.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
if on_top:
return stack_indexes(
[*indexes, index],
drop_duplicates=drop_duplicates,
keep=keep,
drop_redundant=drop_redundant,
)
return stack_indexes(
[index, *indexes],
drop_duplicates=drop_duplicates,
keep=keep,
drop_redundant=drop_redundant,
)
return self.apply_to_index(_apply_func, **kwargs)
def drop_levels(
self: IndexApplierT,
levels: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Drop levels using `drop_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_levels(index, levels, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def rename_levels(
self: IndexApplierT,
mapper: tp.MaybeMappingSequence[tp.Level],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Rename levels using `rename_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return rename_levels(index, mapper, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def select_levels(
self: IndexApplierT,
level_names: tp.Union[ExceptLevel, tp.MaybeLevelSequence],
strict: bool = True,
**kwargs,
) -> IndexApplierT:
"""Select levels using `select_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return select_levels(index, level_names, strict=strict)
return self.apply_to_index(_apply_func, **kwargs)
def drop_redundant_levels(self: IndexApplierT, **kwargs) -> IndexApplierT:
"""Drop any redundant levels using `drop_redundant_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_redundant_levels(index)
return self.apply_to_index(_apply_func, **kwargs)
def drop_duplicate_levels(self: IndexApplierT, keep: tp.Optional[str] = None, **kwargs) -> IndexApplierT:
"""Drop any duplicate levels using `drop_duplicate_levels`.
See `IndexApplier.apply_to_index` for other keyword arguments."""
def _apply_func(index):
return drop_duplicate_levels(index, keep=keep)
return self.apply_to_index(_apply_func, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes and functions for indexing."""
import functools
from datetime import time
from functools import partial
import numpy as np
import pandas as pd
from pandas.tseries.offsets import BaseOffset
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt, datetime_nb as dt_nb
from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import hdict, merge_dicts
from vectorbtpro.utils.mapping import to_field_mapping
from vectorbtpro.utils.pickling import pdict
from vectorbtpro.utils.selection import PosSel, LabelSel
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"PandasIndexer",
"ExtPandasIndexer",
"hslice",
"get_index_points",
"get_index_ranges",
"get_idxs",
"index_dict",
"IdxSetter",
"IdxSetterFactory",
"IdxDict",
"IdxSeries",
"IdxFrame",
"IdxRecords",
"posidx",
"maskidx",
"lbidx",
"dtidx",
"dtcidx",
"pointidx",
"rangeidx",
"autoidx",
"rowidx",
"colidx",
"idx",
]
__pdoc__ = {}
class IndexingError(Exception):
"""Exception raised when an indexing error has occurred."""
IndexingBaseT = tp.TypeVar("IndexingBaseT", bound="IndexingBase")
class IndexingBase(Base):
"""Class that supports indexing through `IndexingBase.indexing_func`."""
def indexing_func(self: IndexingBaseT, pd_indexing_func: tp.Callable, **kwargs) -> IndexingBaseT:
"""Apply `pd_indexing_func` on all pandas objects in question and return a new instance of the class.
Should be overridden."""
raise NotImplementedError
def indexing_setter_func(self, pd_indexing_setter_func: tp.Callable, **kwargs) -> None:
"""Apply `pd_indexing_setter_func` on all pandas objects in question.
Should be overridden."""
raise NotImplementedError
class LocBase(Base):
"""Class that implements location-based indexing."""
def __init__(
self,
indexing_func: tp.Callable,
indexing_setter_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> None:
self._indexing_func = indexing_func
self._indexing_setter_func = indexing_setter_func
self._indexing_kwargs = kwargs
@property
def indexing_func(self) -> tp.Callable:
"""Indexing function."""
return self._indexing_func
@property
def indexing_setter_func(self) -> tp.Optional[tp.Callable]:
"""Indexing setter function."""
return self._indexing_setter_func
@property
def indexing_kwargs(self) -> dict:
"""Keyword arguments passed to `LocBase.indexing_func`."""
return self._indexing_kwargs
def __getitem__(self, key: tp.Any) -> tp.Any:
raise NotImplementedError
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
raise NotImplementedError
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
class pdLoc(LocBase):
"""Forwards a Pandas-like indexing operation to each Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
"""Pandas-like indexing operation."""
raise NotImplementedError
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
"""Pandas-like indexing setter operation."""
raise NotImplementedError
def __getitem__(self, key: tp.Any) -> tp.Any:
return self.indexing_func(partial(self.pd_indexing_func, key=key), **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
self.indexing_setter_func(partial(self.pd_indexing_setter_func, key=key, value=value), **self.indexing_kwargs)
class iLoc(pdLoc):
"""Forwards `pd.Series.iloc`/`pd.DataFrame.iloc` operation to each
Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
return obj.iloc.__getitem__(key)
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
obj.iloc.__setitem__(key, value)
class Loc(pdLoc):
"""Forwards `pd.Series.loc`/`pd.DataFrame.loc` operation to each
Series/DataFrame and returns a new class instance."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
return obj.loc.__getitem__(key)
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
obj.loc.__setitem__(key, value)
PandasIndexerT = tp.TypeVar("PandasIndexerT", bound="PandasIndexer")
class PandasIndexer(IndexingBase):
"""Implements indexing using `iloc`, `loc`, `xs` and `__getitem__`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.indexing import PandasIndexer
>>> class C(PandasIndexer):
... def __init__(self, df1, df2):
... self.df1 = df1
... self.df2 = df2
... super().__init__()
...
... def indexing_func(self, pd_indexing_func):
... return type(self)(
... pd_indexing_func(self.df1),
... pd_indexing_func(self.df2)
... )
>>> df1 = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
>>> df2 = pd.DataFrame({'a': [5, 6], 'b': [7, 8]})
>>> c = C(df1, df2)
>>> c.iloc[:, 0]
<__main__.C object at 0x1a1cacbbe0>
>>> c.iloc[:, 0].df1
0 1
1 2
Name: a, dtype: int64
>>> c.iloc[:, 0].df2
0 5
1 6
Name: a, dtype: int64
```
"""
def __init__(self, **kwargs) -> None:
self._iloc = iLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
self._loc = Loc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
self._indexing_kwargs = kwargs
@property
def indexing_kwargs(self) -> dict:
"""Indexing keyword arguments."""
return self._indexing_kwargs
@property
def iloc(self) -> iLoc:
"""Purely integer-location based indexing for selection by position."""
return self._iloc
iloc.__doc__ = iLoc.__doc__
@property
def loc(self) -> Loc:
"""Purely label-location based indexer for selection by label."""
return self._loc
loc.__doc__ = Loc.__doc__
def xs(self: PandasIndexerT, *args, **kwargs) -> PandasIndexerT:
"""Forwards `pd.Series.xs`/`pd.DataFrame.xs`
operation to each Series/DataFrame and returns a new class instance."""
return self.indexing_func(lambda x: x.xs(*args, **kwargs), **self.indexing_kwargs)
def __getitem__(self: PandasIndexerT, key: tp.Any) -> PandasIndexerT:
def __getitem__func(x, _key=key):
return x.__getitem__(_key)
return self.indexing_func(__getitem__func, **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
def __setitem__func(x, _key=key, _value=value):
return x.__setitem__(_key, _value)
self.indexing_setter_func(__setitem__func, **self.indexing_kwargs)
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
class xLoc(iLoc):
"""Subclass of `iLoc` that transforms an `Idxr`-based operation with
`get_idxs` to an `iLoc` operation."""
@classmethod
def pd_indexing_func(cls, obj: tp.SeriesFrame, key: tp.Any) -> tp.MaybeSeriesFrame:
from vectorbtpro.base.indexes import get_index
if isinstance(key, tuple):
key = Idxr(*key)
index = get_index(obj, 0)
columns = get_index(obj, 1)
freq = dt.infer_index_freq(index)
row_idxs, col_idxs = get_idxs(key, index=index, columns=columns, freq=freq)
if isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2:
row_idxs = normalize_idxs(row_idxs, target_len=len(index))
if isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2:
col_idxs = normalize_idxs(col_idxs, target_len=len(columns))
if isinstance(obj, pd.Series):
if not isinstance(col_idxs, (slice, hslice)) or (
col_idxs.start is not None or col_idxs.stop is not None or col_idxs.step is not None
):
raise IndexingError("Too many indexers")
return obj.iloc.__getitem__(row_idxs)
return obj.iloc.__getitem__((row_idxs, col_idxs))
@classmethod
def pd_indexing_setter_func(cls, obj: tp.SeriesFrame, key: tp.Any, value: tp.Any) -> None:
IdxSetter([(key, value)]).set_pd(obj)
class ExtPandasIndexer(PandasIndexer):
"""Extension of `PandasIndexer` that also implements indexing using `xLoc`."""
def __init__(self, **kwargs) -> None:
self._xloc = xLoc(self.indexing_func, indexing_setter_func=self.indexing_setter_func, **kwargs)
PandasIndexer.__init__(self, **kwargs)
@property
def xloc(self) -> xLoc:
"""`Idxr`-based indexing."""
return self._xloc
xloc.__doc__ = xLoc.__doc__
class ParamLoc(LocBase):
"""Access a group of columns by parameter using `pd.Series.loc`.
Uses `mapper` to establish a link between columns and parameter values."""
@classmethod
def encode_key(cls, key: tp.Any):
"""Encode key."""
if isinstance(key, tuple):
return str(tuple(map(lambda k: k.item() if isinstance(k, np.generic) else k, key)))
key_str = str(key)
return str(key.item()) if isinstance(key, np.generic) else key_str
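# Illustrative sketch (not part of the original file): keys, including tuples with
# NumPy scalars, are encoded to plain strings so that object-dtype mappers match:
# >>> import numpy as np
# >>> ParamLoc.encode_key(2)
# '2'
# >>> ParamLoc.encode_key((np.int64(1), "a"))
# "(1, 'a')"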
def __init__(
self,
mapper: tp.Series,
indexing_func: tp.Callable,
indexing_setter_func: tp.Optional[tp.Callable] = None,
level_name: tp.Level = None,
**kwargs,
) -> None:
checks.assert_instance_of(mapper, pd.Series)
if mapper.dtype == "O":
if isinstance(mapper.iloc[0], tuple):
mapper = mapper.apply(self.encode_key)
else:
mapper = mapper.astype(str)
self._mapper = mapper
self._level_name = level_name
LocBase.__init__(self, indexing_func, indexing_setter_func=indexing_setter_func, **kwargs)
@property
def mapper(self) -> tp.Series:
"""Mapper."""
return self._mapper
@property
def level_name(self) -> tp.Level:
"""Level name."""
return self._level_name
def get_idxs(self, key: tp.Any) -> tp.Array1d:
"""Get array of indices affected by this key."""
if self.mapper.dtype == "O":
if isinstance(key, (slice, hslice)):
start = self.encode_key(key.start) if key.start is not None else None
stop = self.encode_key(key.stop) if key.stop is not None else None
key = slice(start, stop, key.step)
elif isinstance(key, (list, np.ndarray)):
key = list(map(self.encode_key, key))
else:
key = self.encode_key(key)
mapper = pd.Series(np.arange(len(self.mapper.index)), index=self.mapper.values)
idxs = mapper.loc.__getitem__(key)
if isinstance(idxs, pd.Series):
idxs = idxs.values
return idxs
def __getitem__(self, key: tp.Any) -> tp.Any:
idxs = self.get_idxs(key)
is_multiple = isinstance(key, (slice, hslice, list, np.ndarray))
def pd_indexing_func(obj: tp.SeriesFrame) -> tp.MaybeSeriesFrame:
from vectorbtpro.base.indexes import drop_levels
new_obj = obj.iloc[:, idxs]
if not is_multiple:
if self.level_name is not None:
if checks.is_frame(new_obj):
if isinstance(new_obj.columns, pd.MultiIndex):
new_obj.columns = drop_levels(new_obj.columns, self.level_name)
return new_obj
return self.indexing_func(pd_indexing_func, **self.indexing_kwargs)
def __setitem__(self, key: tp.Any, value: tp.Any) -> None:
idxs = self.get_idxs(key)
def pd_indexing_setter_func(obj: tp.SeriesFrame) -> None:
obj.iloc[:, idxs] = value
return self.indexing_setter_func(pd_indexing_setter_func, **self.indexing_kwargs)
def indexing_on_mapper(
mapper: tp.Series,
ref_obj: tp.SeriesFrame,
pd_indexing_func: tp.Callable,
) -> tp.Optional[tp.Series]:
"""Broadcast `mapper` Series to `ref_obj` and perform pandas indexing using `pd_indexing_func`."""
from vectorbtpro.base.reshaping import broadcast_to
checks.assert_instance_of(mapper, pd.Series)
checks.assert_instance_of(ref_obj, (pd.Series, pd.DataFrame))
if isinstance(ref_obj, pd.Series):
range_mapper = broadcast_to(0, ref_obj)
else:
range_mapper = broadcast_to(np.arange(len(mapper.index))[None], ref_obj)
loced_range_mapper = pd_indexing_func(range_mapper)
new_mapper = mapper.iloc[loced_range_mapper.values[0]]
if checks.is_frame(loced_range_mapper):
return pd.Series(new_mapper.values, index=loced_range_mapper.columns, name=mapper.name)
elif checks.is_series(loced_range_mapper):
return pd.Series([new_mapper], index=[loced_range_mapper.name], name=mapper.name)
return None
def build_param_indexer(
param_names: tp.Sequence[str],
class_name: str = "ParamIndexer",
module_name: tp.Optional[str] = None,
) -> tp.Type[IndexingBase]:
"""A factory to create a class with parameter indexing.
Parameter indexer enables accessing a group of rows and columns by a parameter array (similar to `loc`).
This way, one can query index/columns by another Series called a parameter mapper, which is just a
`pd.Series` that maps columns (its index) to params (its values).
Parameter indexing is important, since querying by column/index labels alone is not always the best option.
For example, `pandas` doesn't let you query by a list at a specific index/column level.
Args:
param_names (list of str): Names of the parameters.
class_name (str): Name of the generated class.
module_name (str): Name of the module to which the class should be bound.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.indexing import build_param_indexer, indexing_on_mapper
>>> MyParamIndexer = build_param_indexer(['my_param'])
>>> class C(MyParamIndexer):
... def __init__(self, df, param_mapper):
... self.df = df
... self._my_param_mapper = param_mapper
... super().__init__([param_mapper])
...
... def indexing_func(self, pd_indexing_func):
... return type(self)(
... pd_indexing_func(self.df),
... indexing_on_mapper(self._my_param_mapper, self.df, pd_indexing_func)
... )
>>> df = pd.DataFrame({'a': [1, 2], 'b': [3, 4]})
>>> param_mapper = pd.Series(['First', 'Second'], index=['a', 'b'])
>>> c = C(df, param_mapper)
>>> c.my_param_loc['First'].df
0 1
1 2
Name: a, dtype: int64
>>> c.my_param_loc['Second'].df
0 3
1 4
Name: b, dtype: int64
>>> c.my_param_loc[['First', 'First', 'Second', 'Second']].df
a b
0 1 1 3 3
1 2 2 4 4
```
"""
class ParamIndexer(IndexingBase):
"""Class with parameter indexing."""
def __init__(
self,
param_mappers: tp.Sequence[tp.Series],
level_names: tp.Optional[tp.LevelSequence] = None,
**kwargs,
) -> None:
checks.assert_len_equal(param_names, param_mappers)
for i, param_name in enumerate(param_names):
level_name = level_names[i] if level_names is not None else None
_param_loc = ParamLoc(param_mappers[i], self.indexing_func, level_name=level_name, **kwargs)
setattr(self, f"_{param_name}_loc", _param_loc)
for i, param_name in enumerate(param_names):
def param_loc(self, _param_name=param_name) -> ParamLoc:
return getattr(self, f"_{_param_name}_loc")
param_loc.__doc__ = f"""Access a group of columns by parameter `{param_name}` using `pd.Series.loc`.
Forwards this operation to each Series/DataFrame and returns a new class instance.
"""
setattr(ParamIndexer, param_name + "_loc", property(param_loc))
ParamIndexer.__name__ = class_name
ParamIndexer.__qualname__ = ParamIndexer.__name__
if module_name is not None:
ParamIndexer.__module__ = module_name
return ParamIndexer
hsliceT = tp.TypeVar("hsliceT", bound="hslice")
@define
class hslice(DefineMixin):
"""Hashable slice."""
start: object = define.field()
"""Start."""
stop: object = define.field()
"""Stop."""
step: object = define.field()
"""Step."""
def __init__(self, start: object = MISSING, stop: object = MISSING, step: object = MISSING) -> None:
if start is not MISSING and stop is MISSING and step is MISSING:
stop = start
start, step = None, None
else:
if start is MISSING:
start = None
if stop is MISSING:
stop = None
if step is MISSING:
step = None
DefineMixin.__init__(self, start=start, stop=stop, step=step)
@classmethod
def from_slice(cls: tp.Type[hsliceT], slice_: slice) -> hsliceT:
"""Construct from a slice."""
return cls(slice_.start, slice_.stop, slice_.step)
def to_slice(self) -> slice:
"""Convert to a slice."""
return slice(self.start, self.stop, self.step)
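# Usage sketch (illustrative, not part of the original file): `hslice` mirrors the
# built-in `slice` but is hashable, so it can be used where hashable keys are required.
# >>> hslice(1, 10, 2).to_slice()
# slice(1, 10, 2)
# >>> hslice(5).to_slice()  # a single argument is treated as `stop`, like `slice(5)`
# slice(None, 5, None)
# >>> hash(hslice(1, 10, 2)) == hash(hslice(1, 10, 2))
# True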
class IdxrBase(Base):
"""Abstract class for resolving indices."""
def get(self, *args, **kwargs) -> tp.Any:
"""Get indices."""
raise NotImplementedError
@classmethod
def slice_indexer(
cls,
index: tp.Index,
slice_: tp.Slice,
closed_start: bool = True,
closed_end: bool = False,
) -> slice:
"""Compute the slice indexer for input labels and step."""
start = slice_.start
end = slice_.stop
if start is not None:
left_start = index.get_slice_bound(start, side="left")
right_start = index.get_slice_bound(start, side="right")
if left_start == right_start or not closed_start:
start = right_start
else:
start = left_start
if end is not None:
left_end = index.get_slice_bound(end, side="left")
right_end = index.get_slice_bound(end, side="right")
if left_end == right_end or closed_end:
end = right_end
else:
end = left_end
return slice(start, end, slice_.step)
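# Illustrative sketch (not part of the original file): label bounds are resolved to
# positions, honoring `closed_start`/`closed_end`:
# >>> import pandas as pd
# >>> IdxrBase.slice_indexer(pd.Index(["a", "b", "c", "d"]), slice("b", "d"))
# slice(1, 3, None)
# >>> IdxrBase.slice_indexer(pd.Index(["a", "b", "c", "d"]), slice("b", "d"), closed_end=True)
# slice(1, 4, None)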
def check_idxs(self, idxs: tp.MaybeIndexArray, check_minus_one: bool = False) -> None:
"""Check indices after resolving them."""
if isinstance(idxs, slice):
if idxs.start is not None and not checks.is_int(idxs.start):
raise TypeError("Start of a returned index slice must be an integer or None")
if idxs.stop is not None and not checks.is_int(idxs.stop):
raise TypeError("Stop of a returned index slice must be an integer or None")
if idxs.step is not None and not checks.is_int(idxs.step):
raise TypeError("Step of a returned index slice must be an integer or None")
if check_minus_one and idxs.start == -1:
raise ValueError("Range start index couldn't be matched")
elif check_minus_one and idxs.stop == -1:
raise ValueError("Range end index couldn't be matched")
elif checks.is_int(idxs):
if check_minus_one and idxs == -1:
raise ValueError("Index couldn't be matched")
elif checks.is_sequence(idxs) and not np.isscalar(idxs):
if len(idxs) == 0:
raise ValueError("No indices could be matched")
if not isinstance(idxs, np.ndarray):
raise ValueError(f"Indices must be a NumPy array, not {type(idxs)}")
if not np.issubdtype(idxs.dtype, np.integer) or np.issubdtype(idxs.dtype, np.bool_):
raise ValueError(f"Indices must be of integer data type, not {idxs.dtype}")
if check_minus_one and -1 in idxs:
raise ValueError("Some indices couldn't be matched")
if idxs.ndim not in (1, 2):
raise ValueError("Indices array must have either 1 or 2 dimensions")
if idxs.ndim == 2 and idxs.shape[1] != 2:
raise ValueError("Indices array provided as ranges must have exactly two columns")
else:
raise TypeError(
f"Indices must be an integer, a slice, a NumPy array, or a tuple of two NumPy arrays, not {type(idxs)}"
)
def normalize_idxs(idxs: tp.MaybeIndexArray, target_len: int) -> tp.Array1d:
"""Normalize indexes into a 1-dim integer array."""
if isinstance(idxs, hslice):
idxs = idxs.to_slice()
if isinstance(idxs, slice):
idxs = np.arange(target_len)[idxs]
if checks.is_int(idxs):
idxs = np.array([idxs])
if idxs.ndim == 2:
from vectorbtpro.base.merging import concat_arrays
idxs = concat_arrays(tuple(map(lambda x: np.arange(x[0], x[1]), idxs)))
if (idxs < 0).any():
idxs = np.where(idxs >= 0, idxs, target_len + idxs)
return idxs
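# Illustrative sketch (not part of the original file): slices, two-column range arrays,
# and negative positions are all normalized into a 1-dim array of non-negative positions:
# >>> import numpy as np
# >>> normalize_idxs(slice(1, 4), target_len=6)
# array([1, 2, 3])
# >>> normalize_idxs(np.array([[0, 2], [4, 6]]), target_len=6)
# array([0, 1, 4, 5])
# >>> normalize_idxs(np.array([0, -1]), target_len=6)
# array([0, 5])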
class UniIdxr(IdxrBase):
"""Abstract class for resolving indices based on a single index."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
raise NotImplementedError
def __invert__(self):
def _op_func(x, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
idxs = np.setdiff1d(np.arange(len(index)), x)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self)
def __and__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.intersect1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __or__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.union1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __sub__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.setdiff1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __xor__(self, other):
def _op_func(x, y, index=None, freq=None):
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
y = normalize_idxs(y, len(index))
idxs = np.setxor1d(x, y)
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __lshift__(self, other):
def _op_func(x, y, index=None, freq=None):
if not checks.is_int(y):
raise TypeError("Second operand in __lshift__ must be an integer")
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
shifted = x - y
idxs = shifted[shifted >= 0]
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
def __rshift__(self, other):
def _op_func(x, y, index=None, freq=None):
if not checks.is_int(y):
raise TypeError("Second operand in __rshift__ must be an integer")
if index is None:
raise ValueError("Index is required")
x = normalize_idxs(x, len(index))
shifted = x + y
idxs = shifted[shifted >= 0]
self.check_idxs(idxs)
return idxs
return UniIdxrOp(_op_func, self, other)
@define
class UniIdxrOp(UniIdxr, DefineMixin):
"""Class for applying an operation to one or more indexers.
Produces a single set of indices."""
op_func: tp.Callable = define.field()
"""Operation function that takes the indices of each indexer (as `*args`), `index` (keyword argument),
and `freq` (keyword argument), and returns new indices."""
idxrs: tp.Tuple[object, ...] = define.field()
"""A tuple of one or more indexers."""
def __init__(self, op_func: tp.Callable, *idxrs) -> None:
if len(idxrs) == 1 and checks.is_iterable(idxrs[0]):
idxrs = idxrs[0]
DefineMixin.__init__(self, op_func=op_func, idxrs=idxrs)
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
idxr_indices = []
for idxr in self.idxrs:
if isinstance(idxr, IdxrBase):
checks.assert_instance_of(idxr, UniIdxr)
idxr_indices.append(idxr.get(index=index, freq=freq))
else:
idxr_indices.append(idxr)
return self.op_func(*idxr_indices, index=index, freq=freq)
@define
class PosIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as integer positions."""
value: tp.Union[None, tp.MaybeSequence[tp.MaybeSequence[int]], tp.Slice] = define.field()
"""One or more integer positions."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
idxs = self.value
if checks.is_sequence(idxs) and not np.isscalar(idxs):
idxs = np.asarray(idxs)
if isinstance(idxs, hslice):
idxs = idxs.to_slice()
self.check_idxs(idxs)
return idxs
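# Illustrative sketch (not part of the original file):
# >>> PosIdxr([0, 2, 3]).get()
# array([0, 2, 3])
# >>> PosIdxr(slice(None, 3)).get()
# slice(None, 3, None)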
@define
class MaskIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as a mask."""
value: tp.Union[None, tp.Sequence[bool]] = define.field()
"""Mask."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
idxs = np.flatnonzero(self.value)
self.check_idxs(idxs)
return idxs
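# Illustrative sketch (not part of the original file): a mask resolves to the positions
# of its True values; indexers can also be combined via set operators (here `|`),
# which needs an index to normalize both operands:
# >>> import pandas as pd
# >>> MaskIdxr([True, False, True, False]).get()
# array([0, 2])
# >>> (MaskIdxr([True, False, False, False]) | PosIdxr([2])).get(index=pd.RangeIndex(4))
# array([0, 2])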
@define
class LabelIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as labels."""
value: tp.Union[None, tp.MaybeSequence[tp.Label], tp.Slice] = define.field()
"""One or more labels."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=True)
"""Whether slice end should be inclusive."""
level: tp.MaybeLevelSequence = define.field(default=None)
"""One or more levels."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
if index is None:
raise ValueError("Index is required")
if self.level is not None:
from vectorbtpro.base.indexes import select_levels
index = select_levels(index, self.level)
if isinstance(self.value, (slice, hslice)):
idxs = self.slice_indexer(
index,
self.value,
closed_start=self.closed_start,
closed_end=self.closed_end,
)
elif (checks.is_sequence(self.value) and not np.isscalar(self.value)) and (
not isinstance(index, pd.MultiIndex)
or (isinstance(index, pd.MultiIndex) and isinstance(self.value[0], tuple))
):
idxs = index.get_indexer_for(self.value)
else:
idxs = index.get_loc(self.value)
if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
idxs = np.flatnonzero(idxs)
self.check_idxs(idxs, check_minus_one=True)
return idxs
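# Illustrative sketch (not part of the original file): labels resolve to positions,
# and label slices are inclusive on both ends by default:
# >>> import pandas as pd
# >>> index = pd.Index(["a", "b", "c", "d"])
# >>> LabelIdxr(["b", "d"]).get(index=index)
# array([1, 3])
# >>> LabelIdxr(hslice("b", "c")).get(index=index)
# slice(1, 3, None)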
@define
class DatetimeIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as datetime-like objects."""
value: tp.Union[None, tp.MaybeSequence[tp.DatetimeLike], tp.Slice] = define.field()
"""One or more datetime-like objects."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether slice end should be inclusive."""
indexer_method: tp.Optional[str] = define.field(default="bfill")
"""Method for `pd.Index.get_indexer`.
Allows two additional values: "before" and "after"."""
below_to_zero: bool = define.field(default=False)
"""Whether to place 0 instead of -1 if `DatetimeIdxr.value` is below the first index."""
above_to_len: bool = define.field(default=False)
"""Whether to place `len(index)` instead of -1 if `DatetimeIdxr.value` is above the last index."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
if index is None:
raise ValueError("Index is required")
index = dt.prepare_dt_index(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if not index.is_unique:
raise ValueError("Datetime index must be unique")
if not index.is_monotonic_increasing:
raise ValueError("Datetime index must be monotonically increasing")
if isinstance(self.value, (slice, hslice)):
start = dt.try_align_dt_to_index(self.value.start, index)
stop = dt.try_align_dt_to_index(self.value.stop, index)
new_value = slice(start, stop, self.value.step)
idxs = self.slice_indexer(index, new_value, closed_start=self.closed_start, closed_end=self.closed_end)
elif checks.is_sequence(self.value) and not np.isscalar(self.value):
new_value = dt.try_align_to_dt_index(self.value, index)
idxs = index.get_indexer(new_value, method=self.indexer_method)
if self.below_to_zero:
idxs = np.where(new_value < index[0], 0, idxs)
if self.above_to_len:
idxs = np.where(new_value > index[-1], len(index), idxs)
else:
new_value = dt.try_align_dt_to_index(self.value, index)
if new_value < index[0] and self.below_to_zero:
idxs = 0
elif new_value > index[-1] and self.above_to_len:
idxs = len(index)
else:
if self.indexer_method is None or new_value in index:
idxs = index.get_loc(new_value)
if isinstance(idxs, np.ndarray) and np.issubdtype(idxs.dtype, np.bool_):
idxs = np.flatnonzero(idxs)
else:
indexer_method = self.indexer_method
if indexer_method is not None:
indexer_method = indexer_method.lower()
if indexer_method == "before":
new_value = new_value - pd.Timedelta(1, "ns")
indexer_method = "ffill"
elif indexer_method == "after":
new_value = new_value + pd.Timedelta(1, "ns")
indexer_method = "bfill"
idxs = index.get_indexer([new_value], method=indexer_method)[0]
self.check_idxs(idxs, check_minus_one=True)
return idxs
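# Illustrative sketch (not part of the original file): datetime-like values are aligned
# to the index; slices are closed on the start and open on the end by default:
# >>> import pandas as pd
# >>> index = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> DatetimeIdxr("2020-01-03").get(index=index)
# 2
# >>> DatetimeIdxr(hslice("2020-01-02", "2020-01-04")).get(index=index)
# slice(1, 3, None)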
@define
class DTCIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices provided as datetime-like components."""
value: tp.Union[None, tp.MaybeSequence[tp.DTCLike], tp.Slice] = define.field()
"""One or more datetime-like components."""
parse_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `vectorbtpro.utils.datetime_.DTC.parse`."""
closed_start: bool = define.field(default=True)
"""Whether slice start should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether slice end should be inclusive."""
jitted: tp.JittedOption = define.field(default=None)
"""Jitting option passed to `vectorbtpro.utils.datetime_nb.index_matches_dtc_nb`
and `vectorbtpro.utils.datetime_nb.index_within_dtc_range_nb`."""
@staticmethod
def get_dtc_namedtuple(value: tp.Optional[tp.DTCLike] = None, **parse_kwargs) -> dt.DTCNT:
"""Convert a value to a `vectorbtpro.utils.datetime_.DTCNT` instance."""
if value is None:
return dt.DTC().to_namedtuple()
if isinstance(value, dt.DTC):
return value.to_namedtuple()
if isinstance(value, dt.DTCNT):
return value
return dt.DTC.parse(value, **parse_kwargs).to_namedtuple()
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
parse_kwargs = self.parse_kwargs
if parse_kwargs is None:
parse_kwargs = {}
if index is None:
raise ValueError("Index is required")
index = dt.prepare_dt_index(index)
ns_index = dt.to_ns(index)
checks.assert_instance_of(index, pd.DatetimeIndex)
if not index.is_unique:
raise ValueError("Datetime index must be unique")
if not index.is_monotonic_increasing:
raise ValueError("Datetime index must be monotonically increasing")
if isinstance(self.value, (slice, hslice)):
if self.value.step is not None:
raise ValueError("Step must be None")
if self.value.start is None and self.value.stop is None:
return slice(None, None, None)
start_dtc = self.get_dtc_namedtuple(self.value.start, **parse_kwargs)
end_dtc = self.get_dtc_namedtuple(self.value.stop, **parse_kwargs)
func = jit_reg.resolve_option(dt_nb.index_within_dtc_range_nb, self.jitted)
mask = func(ns_index, start_dtc, end_dtc, closed_start=self.closed_start, closed_end=self.closed_end)
elif checks.is_sequence(self.value) and not np.isscalar(self.value):
func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
dtcs = map(lambda x: self.get_dtc_namedtuple(x, **parse_kwargs), self.value)
masks = map(lambda x: func(ns_index, x), dtcs)
mask = functools.reduce(np.logical_or, masks)
else:
dtc = self.get_dtc_namedtuple(self.value, **parse_kwargs)
func = jit_reg.resolve_option(dt_nb.index_matches_dtc_nb, self.jitted)
mask = func(ns_index, dtc)
return MaskIdxr(mask).get(index=index, freq=freq)
@define
class PointIdxr(UniIdxr, DefineMixin):
"""Class for resolving index points."""
every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Frequency either as an integer or timedelta.
Gets translated into `on` array by creating a range. If integer, an index sequence from `start` to `end`
(exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence from
`start` to `end` (inclusive) is created and 'labels' as `kind` is used.
If `at_time` is not None and `every` and `on` are None, `every` defaults to one day."""
normalize_every: bool = define.field(default=False)
"""Normalize start/end dates to midnight before generating date range."""
at_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""Time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `on` gets floored to the daily frequency, while `at_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_delta`.
Index must be datetime-like."""
start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
"""Start index/date.
If (human-readable) string, gets converted into a datetime.
If `every` is None, gets used to filter the final index array."""
end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = define.field(default=None)
"""End index/date.
If (human-readable) string, gets converted into a datetime.
If `every` is None, gets used to filter the final index array."""
exact_start: bool = define.field(default=False)
"""Whether the first index should be exactly `start`.
Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
In such a case, `start` gets injected before the first index generated by `pd.date_range`."""
on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""Index/label or a sequence of such.
Gets converted into datetime format whenever possible."""
add_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `on`.
Gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of data in `on`: indices or labels.
If None, gets assigned to `indices` if `on` contains integer data, otherwise to `labels`.
If `kind` is 'labels', `on` gets converted into indices using `pd.Index.get_indexer`.
Prior to this, gets its timezone aligned to the timezone of the index. If `kind` is 'indices',
`on` gets wrapped with NumPy."""
indexer_method: str = define.field(default="bfill")
"""Method for `pd.Index.get_indexer`.
Allows two additional values: "before" and "after"."""
indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = define.field(default=None)
"""Tolerance for `pd.Index.get_indexer`.
If `at_time` is set and `indexer_method` is neither exact nor nearest, `indexer_tolerance`
becomes such that the next element must be within the current day."""
skip_not_found: bool = define.field(default=True)
"""Whether to drop indices that are -1 (not found)."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if index is None:
raise ValueError("Index is required")
idxs = get_index_points(index, **self.asdict())
self.check_idxs(idxs, check_minus_one=True)
return idxs
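# Illustrative sketch (not part of the original file): `PointIdxr` simply forwards its
# fields to `get_index_points` (below) and validates the result:
# >>> import pandas as pd
# >>> index = pd.date_range("2020-01", "2020-02", freq="1d")
# >>> PointIdxr(every="W").get(index=index)
# array([ 4, 11, 18, 25])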
point_idxr_defaults = {a.name: a.default for a in PointIdxr.fields}
def get_index_points(
index: tp.Index,
every: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["every"],
normalize_every: bool = point_idxr_defaults["normalize_every"],
at_time: tp.Optional[tp.TimeLike] = point_idxr_defaults["at_time"],
start: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["start"],
end: tp.Optional[tp.Union[int, tp.DatetimeLike]] = point_idxr_defaults["end"],
exact_start: bool = point_idxr_defaults["exact_start"],
on: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = point_idxr_defaults["on"],
add_delta: tp.Optional[tp.FrequencyLike] = point_idxr_defaults["add_delta"],
kind: tp.Optional[str] = point_idxr_defaults["kind"],
indexer_method: str = point_idxr_defaults["indexer_method"],
indexer_tolerance: tp.Optional[tp.Union[int, tp.TimedeltaLike, tp.IndexLike]] = point_idxr_defaults["indexer_tolerance"],
skip_not_found: bool = point_idxr_defaults["skip_not_found"],
) -> tp.Array1d:
"""Translate indices or labels into index points.
See `PointIdxr` for arguments.
Usage:
* Provide nothing to generate at the beginning:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020-01", "2020-02", freq="1d")
>>> vbt.get_index_points(index)
array([0])
```
* Provide `every` as an integer frequency to generate index points using NumPy:
```pycon
>>> # Generate a point every five rows
>>> vbt.get_index_points(index, every=5)
array([ 0, 5, 10, 15, 20, 25, 30])
>>> # Generate a point every five rows starting at 6th row
>>> vbt.get_index_points(index, every=5, start=5)
array([ 5, 10, 15, 20, 25, 30])
>>> # Generate a point every five rows from 6th to 16th row
>>> vbt.get_index_points(index, every=5, start=5, end=15)
array([ 5, 10])
```
* Provide `every` as a time delta frequency to generate index points using Pandas:
```pycon
>>> # Generate a point every week
>>> vbt.get_index_points(index, every="W")
array([ 4, 11, 18, 25])
>>> # Generate a point every second day of the week
>>> vbt.get_index_points(index, every="W", add_delta="2d")
array([ 6, 13, 20, 27])
>>> # Generate a point every week, starting at 11th row
>>> vbt.get_index_points(index, every="W", start=10)
array([11, 18, 25])
>>> # Generate a point every week, starting exactly at 11th row
>>> vbt.get_index_points(index, every="W", start=10, exact_start=True)
array([10, 11, 18, 25])
>>> # Generate a point every week, starting at 2020-01-10
>>> vbt.get_index_points(index, every="W", start="2020-01-10")
array([11, 18, 25])
```
* Instead of using `every`, provide indices explicitly:
```pycon
>>> # Generate one point
>>> vbt.get_index_points(index, on="2020-01-07")
array([6])
>>> # Generate multiple points
>>> vbt.get_index_points(index, on=["2020-01-07", "2020-01-14"])
array([ 6, 13])
```
"""
index = dt.prepare_dt_index(index)
if on is not None and isinstance(on, str):
on = dt.try_align_dt_to_index(on, index)
if start is not None and isinstance(start, str):
start = dt.try_align_dt_to_index(start, index)
if end is not None and isinstance(end, str):
end = dt.try_align_dt_to_index(end, index)
if every is not None and not checks.is_int(every):
every = dt.to_freq(every)
start_used = False
end_used = False
if at_time is not None and every is None and on is None:
every = pd.Timedelta(days=1)
if every is not None:
start_used = True
end_used = True
if checks.is_int(every):
if start is None:
start = 0
if end is None:
end = len(index)
on = np.arange(start, end, every)
kind = "indices"
else:
if start is None:
start = 0
if checks.is_int(start):
start_date = index[start]
else:
start_date = start
if end is None:
end = len(index) - 1
if checks.is_int(end):
end_date = index[end]
else:
end_date = end
on = dt.date_range(
start_date,
end_date,
freq=every,
tz=index.tz,
normalize=normalize_every,
inclusive="both",
)
if exact_start and on[0] > start_date:
on = on.insert(0, start_date)
kind = "labels"
if kind is None:
if on is None:
if start is not None:
if checks.is_int(start):
kind = "indices"
else:
kind = "labels"
else:
kind = "indices"
else:
on = dt.prepare_dt_index(on)
if pd.api.types.is_integer_dtype(on):
kind = "indices"
else:
kind = "labels"
checks.assert_in(kind, ("indices", "labels"))
if on is None:
if start is not None:
on = start
start_used = True
else:
if kind.lower() in ("labels",):
on = index
else:
on = np.arange(len(index))
on = dt.prepare_dt_index(on)
if at_time is not None:
checks.assert_instance_of(on, pd.DatetimeIndex)
on = on.floor("D")
add_time_delta = dt.time_to_timedelta(at_time)
if indexer_tolerance is None:
indexer_method = indexer_method.lower()
if indexer_method in ("pad", "ffill"):
indexer_tolerance = add_time_delta
elif indexer_method in ("backfill", "bfill"):
indexer_tolerance = pd.Timedelta(days=1) - pd.Timedelta(1, "ns") - add_time_delta
if add_delta is None:
add_delta = add_time_delta
else:
add_delta += add_time_delta
if add_delta is not None:
on += dt.to_freq(add_delta)
if kind.lower() == "labels":
on = dt.try_align_to_dt_index(on, index)
if indexer_method is not None:
indexer_method = indexer_method.lower()
if indexer_method == "before":
on = on - pd.Timedelta(1, "ns")
indexer_method = "ffill"
elif indexer_method == "after":
on = on + pd.Timedelta(1, "ns")
indexer_method = "bfill"
index_points = index.get_indexer(on, method=indexer_method, tolerance=indexer_tolerance)
else:
index_points = np.asarray(on)
if start is not None and not start_used:
if not checks.is_int(start):
start = index.get_indexer([start], method="bfill").item(0)
index_points = index_points[index_points >= start]
if end is not None and not end_used:
if not checks.is_int(end):
end = index.get_indexer([end], method="ffill").item(0)
index_points = index_points[index_points <= end]
else:
index_points = index_points[index_points < end]
if skip_not_found:
index_points = index_points[index_points != -1]
return index_points
@define
class RangeIdxr(UniIdxr, DefineMixin):
"""Class for resolving index ranges."""
every: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Frequency either as an integer or timedelta.
Gets translated into `start` and `end` arrays by creating a range. If integer, an index sequence from `start`
to `end` (exclusive) is created and 'indices' as `kind` is used. If timedelta-like, a date sequence
from `start` to `end` (inclusive) is created and 'bounds' as `kind` is used.
If `start_time` and `end_time` are not None and `every`, `start`, and `end` are None,
`every` defaults to one day."""
normalize_every: bool = define.field(default=False)
"""Normalize start/end dates to midnight before generating date range."""
split_every: bool = define.field(default=True)
"""Whether to split the sequence generated using `every` into `start` and `end` arrays.
After creation, and if `split_every` is True, an index range is created from each pair of elements in
the generated sequence. Otherwise, the entire sequence is assigned to `start` and `end`, and only time
and delta instructions can be used to further differentiate between them.
Forced to False if `every`, `start_time`, and `end_time` are not None and `fixed_start` is False."""
start_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""Start time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `start` gets floored to the daily frequency, while `start_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_start_delta`.
Index must be datetime-like."""
end_time: tp.Optional[tp.TimeLike] = define.field(default=None)
"""End time of the day either as a (human-readable) string or `datetime.time`.
Every datetime in `end` gets floored to the daily frequency, while `end_time` gets converted into
a timedelta using `vectorbtpro.utils.datetime_.time_to_timedelta` and added to `add_end_delta`.
Index must be datetime-like."""
lookback_period: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Lookback period either as an integer or offset.
If `lookback_period` is set, `start` becomes `end-lookback_period`. If `every` is not None,
the sequence is generated from `start+lookback_period` to `end` and then assigned to `end`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`.
If integer, gets multiplied by the frequency of the index if the index is not integer."""
start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""Start index/label or a sequence of such.
Gets converted into datetime format whenever possible.
Gets broadcasted together with `end`."""
end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = define.field(default=None)
"""End index/label or a sequence of such.
Gets converted into datetime format whenever possible.
Gets broadcasted together with `start`."""
exact_start: bool = define.field(default=False)
"""Whether the first index in the `start` array should be exactly `start`.
Depending on `every`, the first index picked by `pd.date_range` may happen after `start`.
In such a case, `start` gets injected before the first index generated by `pd.date_range`.
Cannot be used together with `lookback_period`."""
fixed_start: bool = define.field(default=False)
"""Whether all indices in the `start` array should be exactly `start`.
Works only together with `every`.
Cannot be used together with `lookback_period`."""
closed_start: bool = define.field(default=True)
"""Whether `start` should be inclusive."""
closed_end: bool = define.field(default=False)
"""Whether `end` should be inclusive."""
add_start_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `start`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
add_end_delta: tp.Optional[tp.FrequencyLike] = define.field(default=None)
"""Offset to be added to each in `end`.
If string, gets converted to a proper offset/timedelta using `vectorbtpro.utils.datetime_.to_freq`."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of data in `on`: indices, labels or bounds.
If None, gets assigned to `indices` if `start` and `end` contain integer data, to `bounds`
if `start`, `end`, and index are datetime-like, otherwise to `labels`.
If `kind` is 'labels', `start` and `end` get converted into indices using `pd.Index.get_indexer`.
Prior to this, get their timezone aligned to the timezone of the index. If `kind` is 'indices',
`start` and `end` get wrapped with NumPy. If `kind` is 'bounds',
`vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges` is used."""
skip_not_found: bool = define.field(default=True)
"""Whether to drop indices that are -1 (not found)."""
jitted: tp.JittedOption = define.field(default=None)
"""Jitting option passed to `vectorbtpro.base.resampling.base.Resampler.map_bounds_to_source_ranges`."""
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if index is None:
raise ValueError("Index is required")
from vectorbtpro.base.merging import column_stack_arrays
start_idxs, end_idxs = get_index_ranges(index, index_freq=freq, **self.asdict())
idxs = column_stack_arrays((start_idxs, end_idxs))
self.check_idxs(idxs, check_minus_one=True)
return idxs
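# Illustrative sketch (not part of the original file): `RangeIdxr` forwards its fields
# to `get_index_ranges` (below) and stacks the start/end arrays into two columns:
# >>> import pandas as pd
# >>> index = pd.date_range("2020-01", "2020-02", freq="1d")
# >>> RangeIdxr(every="W").get(index=index)
# array([[ 4, 11],
#        [11, 18],
#        [18, 25]])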
range_idxr_defaults = {a.name: a.default for a in RangeIdxr.fields}
def get_index_ranges(
index: tp.Index,
index_freq: tp.Optional[tp.FrequencyLike] = None,
every: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["every"],
normalize_every: bool = range_idxr_defaults["normalize_every"],
split_every: bool = range_idxr_defaults["split_every"],
start_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["start_time"],
end_time: tp.Optional[tp.TimeLike] = range_idxr_defaults["end_time"],
lookback_period: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["lookback_period"],
start: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["start"],
end: tp.Optional[tp.Union[int, tp.DatetimeLike, tp.IndexLike]] = range_idxr_defaults["end"],
exact_start: bool = range_idxr_defaults["exact_start"],
fixed_start: bool = range_idxr_defaults["fixed_start"],
closed_start: bool = range_idxr_defaults["closed_start"],
closed_end: bool = range_idxr_defaults["closed_end"],
add_start_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_start_delta"],
add_end_delta: tp.Optional[tp.FrequencyLike] = range_idxr_defaults["add_end_delta"],
kind: tp.Optional[str] = range_idxr_defaults["kind"],
skip_not_found: bool = range_idxr_defaults["skip_not_found"],
jitted: tp.JittedOption = range_idxr_defaults["jitted"],
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Translate indices, labels, or bounds into index ranges.
See `RangeIdxr` for arguments.
Usage:
* Provide nothing to generate one largest index range:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020-01", "2020-02", freq="1d")
>>> np.column_stack(vbt.get_index_ranges(index))
array([[ 0, 32]])
```
* Provide `every` as an integer frequency to generate index ranges using NumPy:
```pycon
>>> # Generate a range every five rows
>>> np.column_stack(vbt.get_index_ranges(index, every=5))
array([[ 0, 5],
[ 5, 10],
[10, 15],
[15, 20],
[20, 25],
[25, 30]])
>>> # Generate a range every five rows, starting at 6th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every=5,
... start=5
... ))
array([[ 5, 10],
[10, 15],
[15, 20],
[20, 25],
[25, 30]])
>>> # Generate a range every five rows from 6th to 16th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every=5,
... start=5,
... end=15
... ))
array([[ 5, 10],
[10, 15]])
```
* Provide `every` as a time delta frequency to generate index ranges using Pandas:
```pycon
>>> # Generate a range every week
>>> np.column_stack(vbt.get_index_ranges(index, every="W"))
array([[ 4, 11],
[11, 18],
[18, 25]])
>>> # Generate a range every second day of the week
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... add_start_delta="2d"
... ))
array([[ 6, 11],
[13, 18],
[20, 25]])
>>> # Generate a range every week, starting at 11th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=10
... ))
array([[11, 18],
[18, 25]])
>>> # Generate a range every week, starting exactly at 11th row
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=10,
... exact_start=True
... ))
array([[10, 11],
[11, 18],
[18, 25]])
>>> # Generate a range every week, starting at 2020-01-10
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start="2020-01-10"
... ))
array([[11, 18],
[18, 25]])
>>> # Generate a range every week, each starting at 2020-01-10
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start="2020-01-10",
... fixed_start=True
... ))
array([[11, 18],
[11, 25]])
>>> # Generate an expanding range that increments by week
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... start=0,
... exact_start=True,
... fixed_start=True
... ))
array([[ 0, 4],
[ 0, 11],
[ 0, 18],
[ 0, 25]])
```
* Use a look-back period (instead of an end index):
```pycon
>>> # Generate a range every week, looking 5 days back
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... lookback_period=5
... ))
array([[ 6, 11],
[13, 18],
[20, 25]])
>>> # Generate a range every week, looking 2 weeks back
>>> np.column_stack(vbt.get_index_ranges(
... index,
... every="W",
... lookback_period="2W"
... ))
array([[ 0, 11],
[ 4, 18],
[11, 25]])
```
* Instead of using `every`, provide start and end indices explicitly:
```pycon
>>> # Generate one range
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start="2020-01-01",
... end="2020-01-07"
... ))
array([[0, 6]])
>>> # Generate ranges between multiple dates
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start=["2020-01-01", "2020-01-07"],
... end=["2020-01-07", "2020-01-14"]
... ))
array([[ 0, 6],
[ 6, 13]])
>>> # Generate ranges with a fixed start
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start="2020-01-01",
... end=["2020-01-07", "2020-01-14"]
... ))
array([[ 0, 6],
[ 0, 13]])
```
* Use `closed_start` and `closed_end` to exclude any of the bounds:
```pycon
>>> # Generate ranges between multiple dates
>>> # by excluding the start date and including the end date
>>> np.column_stack(vbt.get_index_ranges(
... index,
... start=["2020-01-01", "2020-01-07"],
... end=["2020-01-07", "2020-01-14"],
... closed_start=False,
... closed_end=True
... ))
array([[ 1, 7],
[ 7, 14]])
```
"""
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.resampling.base import Resampler
index = dt.prepare_dt_index(index)
if isinstance(index, pd.DatetimeIndex):
if start is not None:
start = dt.try_align_to_dt_index(start, index)
if isinstance(start, pd.DatetimeIndex):
start = start.tz_localize(None)
if end is not None:
end = dt.try_align_to_dt_index(end, index)
if isinstance(end, pd.DatetimeIndex):
end = end.tz_localize(None)
naive_index = index.tz_localize(None)
else:
if start is not None:
if not isinstance(start, pd.Index):
try:
start = pd.Index(start)
except Exception as e:
start = pd.Index([start])
if end is not None:
if not isinstance(end, pd.Index):
try:
end = pd.Index(end)
except Exception as e:
end = pd.Index([end])
naive_index = index
if every is not None and not checks.is_int(every):
every = dt.to_freq(every)
if lookback_period is not None and not checks.is_int(lookback_period):
lookback_period = dt.to_freq(lookback_period)
if fixed_start and lookback_period is not None:
raise ValueError("Cannot use fixed_start and lookback_period together")
if exact_start and lookback_period is not None:
raise ValueError("Cannot use exact_start and lookback_period together")
if start_time is not None or end_time is not None:
if every is None and start is None and end is None:
every = pd.Timedelta(days=1)
if every is not None:
if not fixed_start:
if start_time is None and end_time is not None:
start_time = time(0, 0, 0, 0)
closed_start = True
if start_time is not None and end_time is None:
end_time = time(0, 0, 0, 0)
closed_end = False
if start_time is not None and end_time is not None and not fixed_start:
split_every = False
if checks.is_int(every):
if start is None:
start = 0
else:
start = start[0]
if end is None:
end = len(naive_index)
else:
end = end[-1]
if closed_end:
end -= 1
if lookback_period is None:
new_index = np.arange(start, end + 1, every)
if not split_every:
start = end = new_index
else:
if fixed_start:
start = np.full(len(new_index) - 1, new_index[0])
else:
start = new_index[:-1]
end = new_index[1:]
else:
end = np.arange(start + lookback_period, end + 1, every)
start = end - lookback_period
kind = "indices"
lookback_period = None
else:
if start is None:
start = 0
else:
start = start[0]
if checks.is_int(start):
start_date = naive_index[start]
else:
start_date = start
if end is None:
end = len(naive_index) - 1
else:
end = end[-1]
if checks.is_int(end):
end_date = naive_index[end]
else:
end_date = end
if lookback_period is None:
new_index = dt.date_range(
start_date,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
if exact_start and new_index[0] > start_date:
new_index = new_index.insert(0, start_date)
if not split_every:
start = end = new_index
else:
if fixed_start:
start = repeat_index(new_index[[0]], len(new_index) - 1)
else:
start = new_index[:-1]
end = new_index[1:]
else:
if checks.is_int(lookback_period):
lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
if isinstance(lookback_period, BaseOffset):
end = dt.date_range(
start_date,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
start = end - lookback_period
start_mask = start >= start_date
start = start[start_mask]
end = end[start_mask]
else:
end = dt.date_range(
start_date + lookback_period,
end_date,
freq=every,
normalize=normalize_every,
inclusive="both",
)
start = end - lookback_period
kind = "bounds"
lookback_period = None
if kind is None:
if start is None and end is None:
kind = "indices"
else:
if start is not None:
ref_index = start
if end is not None:
ref_index = end
if pd.api.types.is_integer_dtype(ref_index):
kind = "indices"
elif isinstance(ref_index, pd.DatetimeIndex) and isinstance(naive_index, pd.DatetimeIndex):
kind = "bounds"
else:
kind = "labels"
checks.assert_in(kind, ("indices", "labels", "bounds"))
if end is None:
if kind.lower() in ("labels", "bounds"):
end = pd.Index([naive_index[-1]])
else:
end = pd.Index([len(naive_index)])
if start is not None and lookback_period is not None:
raise ValueError("Cannot use start and lookback_period together")
if start is None:
if lookback_period is None:
if kind.lower() in ("labels", "bounds"):
start = pd.Index([naive_index[0]])
else:
start = pd.Index([0])
else:
if checks.is_int(lookback_period) and not pd.api.types.is_integer_dtype(end):
lookback_period *= dt.infer_index_freq(naive_index, freq=index_freq)
start = end - lookback_period
if len(start) == 1 and len(end) > 1:
start = repeat_index(start, len(end))
elif len(start) > 1 and len(end) == 1:
end = repeat_index(end, len(start))
checks.assert_len_equal(start, end)
if start_time is not None:
checks.assert_instance_of(start, pd.DatetimeIndex)
start = start.floor("D")
add_start_time_delta = dt.time_to_timedelta(start_time)
if add_start_delta is None:
add_start_delta = add_start_time_delta
else:
add_start_delta += add_start_time_delta
else:
add_start_time_delta = None
if end_time is not None:
checks.assert_instance_of(end, pd.DatetimeIndex)
end = end.floor("D")
add_end_time_delta = dt.time_to_timedelta(end_time)
if add_start_time_delta is not None:
if add_end_time_delta < add_start_delta:
add_end_time_delta += pd.Timedelta(days=1)
if add_end_delta is None:
add_end_delta = add_end_time_delta
else:
add_end_delta += add_end_time_delta
if add_start_delta is not None:
start += dt.to_freq(add_start_delta)
if add_end_delta is not None:
end += dt.to_freq(add_end_delta)
if kind.lower() == "bounds":
range_starts, range_ends = Resampler.map_bounds_to_source_ranges(
source_index=naive_index.values,
target_lbound_index=start.values,
target_rbound_index=end.values,
closed_lbound=closed_start,
closed_rbound=closed_end,
skip_not_found=skip_not_found,
jitted=jitted,
)
else:
if kind.lower() == "labels":
range_starts = np.empty(len(start), dtype=int_)
range_ends = np.empty(len(end), dtype=int_)
range_index = pd.Series(np.arange(len(naive_index)), index=naive_index)
for i in range(len(range_starts)):
selected_range = range_index[start[i] : end[i]]
if len(selected_range) > 0 and not closed_start and selected_range.index[0] == start[i]:
selected_range = selected_range.iloc[1:]
if len(selected_range) > 0 and not closed_end and selected_range.index[-1] == end[i]:
selected_range = selected_range.iloc[:-1]
if len(selected_range) > 0:
range_starts[i] = selected_range.iloc[0]
range_ends[i] = selected_range.iloc[-1]
else:
range_starts[i] = -1
range_ends[i] = -1
else:
if not closed_start:
start = start + 1
if closed_end:
end = end + 1
range_starts = np.asarray(start)
range_ends = np.asarray(end)
if skip_not_found:
valid_mask = (range_starts != -1) & (range_ends != -1)
range_starts = range_starts[valid_mask]
range_ends = range_ends[valid_mask]
if np.any(range_starts >= range_ends):
raise ValueError("Some start indices are equal to or higher than end indices")
return range_starts, range_ends
@define
class AutoIdxr(UniIdxr, DefineMixin):
"""Class for resolving indices, datetime-like objects, frequency-like objects, and labels for one axis."""
value: tp.Union[
None,
tp.PosSel,
tp.LabelSel,
tp.MaybeSequence[tp.MaybeSequence[int]],
tp.MaybeSequence[tp.Label],
tp.MaybeSequence[tp.DatetimeLike],
tp.MaybeSequence[tp.DTCLike],
tp.FrequencyLike,
tp.Slice,
] = define.field()
"""One or more integer indices, datetime-like objects, frequency-like objects, or labels.
Can also be an instance of `vectorbtpro.utils.selection.PosSel` holding position(s)
or `vectorbtpro.utils.selection.LabelSel` holding label(s)."""
closed_start: bool = define.optional_field()
"""Whether slice start should be inclusive."""
closed_end: bool = define.optional_field()
"""Whether slice end should be inclusive."""
indexer_method: tp.Optional[str] = define.optional_field()
"""Method for `pd.Index.get_indexer`."""
below_to_zero: bool = define.optional_field()
"""Whether to place 0 instead of -1 if `AutoIdxr.value` is below the first index."""
above_to_len: bool = define.optional_field()
"""Whether to place `len(index)` instead of -1 if `AutoIdxr.value` is above the last index."""
level: tp.MaybeLevelSequence = define.field(default=None)
"""One or more levels.
If `level` is not None and `kind` is None, `kind` becomes "labels"."""
kind: tp.Optional[str] = define.field(default=None)
"""Kind of value.
Allowed are
* "position(s)" for `PosIdxr`
* "mask" for `MaskIdxr`
* "label(s)" for `LabelIdxr`
* "datetime" for `DatetimeIdxr`
* "dtc" for `DTCIdxr`
* "frequency" for `PointIdxr`
If None, will (try to) determine the kind automatically based on the type of the value."""
idxr_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to the selected indexer."""
def __init__(self, *args, **kwargs) -> None:
idxr_kwargs = kwargs.pop("idxr_kwargs", None)
if idxr_kwargs is None:
idxr_kwargs = {}
else:
idxr_kwargs = dict(idxr_kwargs)
builtin_keys = {a.name for a in self.fields}
for k in list(kwargs.keys()):
if k not in builtin_keys:
idxr_kwargs[k] = kwargs.pop(k)
DefineMixin.__init__(self, *args, idxr_kwargs=idxr_kwargs, **kwargs)
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.MaybeIndexArray:
if self.value is None:
return slice(None, None, None)
value = self.value
kind = self.kind
if self.level is not None:
from vectorbtpro.base.indexes import select_levels
if index is None:
raise ValueError("Index is required")
index = select_levels(index, self.level)
if kind is None:
kind = "labels"
idxr_kwargs = self.idxr_kwargs
if idxr_kwargs is None:
idxr_kwargs = {}
def _dtc_check_func(dtc):
return (
not dtc.has_full_datetime()
and self.indexer_method in (MISSING, None)
and self.below_to_zero is MISSING
and self.above_to_len is MISSING
)
if kind is None:
if isinstance(value, PosSel):
kind = "positions"
value = value.value
elif isinstance(value, LabelSel):
kind = "labels"
value = value.value
elif isinstance(value, (slice, hslice)):
if checks.is_int(value.start) or checks.is_int(value.stop):
kind = "positions"
elif value.start is None and value.stop is None:
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
if isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value.start, check_func=_dtc_check_func) or dt.DTC.is_parsable(
value.stop, check_func=_dtc_check_func
):
kind = "dtc"
else:
kind = "datetime"
else:
kind = "labels"
elif (checks.is_sequence(value) and not np.isscalar(value)) and (
index is None
or (
not isinstance(index, pd.MultiIndex)
or (isinstance(index, pd.MultiIndex) and isinstance(value[0], tuple))
)
):
if checks.is_bool(value[0]):
kind = "mask"
elif checks.is_int(value[0]):
kind = "positions"
elif (
(index is None or not isinstance(index, pd.MultiIndex) or not isinstance(value[0], tuple))
and checks.is_sequence(value[0])
and len(value[0]) == 2
and checks.is_int(value[0][0])
and checks.is_int(value[0][1])
):
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
elif isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value[0], check_func=_dtc_check_func):
kind = "dtc"
else:
kind = "datetime"
else:
kind = "labels"
else:
if checks.is_bool(value):
kind = "mask"
elif checks.is_int(value):
kind = "positions"
else:
if index is None:
raise ValueError("Index is required")
if isinstance(index, pd.DatetimeIndex):
if dt.DTC.is_parsable(value, check_func=_dtc_check_func):
kind = "dtc"
elif isinstance(value, str):
try:
if not value.isupper() and not value.islower():
raise Exception # "2020" shouldn't be a frequency
_ = dt.to_freq(value)
kind = "frequency"
except Exception as e:
try:
_ = dt.to_timestamp(value)
kind = "datetime"
except Exception as e:
raise ValueError(f"'{value}' is neither a frequency nor a datetime")
elif checks.is_frequency(value):
kind = "frequency"
else:
kind = "datetime"
else:
kind = "labels"
def _expand_target_kwargs(target_cls, **target_kwargs):
source_arg_names = {a.name for a in self.fields if a.default is MISSING}
target_arg_names = {a.name for a in target_cls.fields}
for arg_name in source_arg_names:
if arg_name in target_arg_names:
arg_value = getattr(self, arg_name)
if arg_value is not MISSING:
target_kwargs[arg_name] = arg_value
return target_kwargs
if kind.lower() in ("position", "positions"):
idx = PosIdxr(value, **_expand_target_kwargs(PosIdxr, **idxr_kwargs))
elif kind.lower() == "mask":
idx = MaskIdxr(value, **_expand_target_kwargs(MaskIdxr, **idxr_kwargs))
elif kind.lower() in ("label", "labels"):
idx = LabelIdxr(value, **_expand_target_kwargs(LabelIdxr, **idxr_kwargs))
elif kind.lower() == "datetime":
idx = DatetimeIdxr(value, **_expand_target_kwargs(DatetimeIdxr, **idxr_kwargs))
elif kind.lower() == "dtc":
idx = DTCIdxr(value, **_expand_target_kwargs(DTCIdxr, **idxr_kwargs))
elif kind.lower() == "frequency":
idx = PointIdxr(every=value, **_expand_target_kwargs(PointIdxr, **idxr_kwargs))
else:
raise ValueError(f"Invalid kind: '{kind}'")
return idx.get(index=index, freq=freq)
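# Illustrative sketch, not part of the library: how `AutoIdxr` picks an indexer kind when
# `kind` is left as None. The resolved values come from the indexer classes defined above,
# so only the detected kind is hinted at in the comments; nothing is asserted here.
_example_index = pd.date_range("2020-01-01", periods=5)
_pos_rows = AutoIdxr(2).get(index=_example_index)  # integer -> "positions" (PosIdxr)
_mask_rows = AutoIdxr([True, False, True, False, True]).get(index=_example_index)  # booleans -> "mask" (MaskIdxr)
_date_rows = AutoIdxr("2020-01-03").get(index=_example_index)  # date-like string -> "dtc" or "datetime"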
@define
class RowIdxr(IdxrBase, DefineMixin):
"""Class for resolving row indices."""
idxr: object = define.field()
"""Indexer.
Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `AutoIdxr`."""
def __init__(self, idxr: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.MaybeIndexArray:
idxr = self.idxr
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(index=index, freq=freq), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if not isinstance(idxr, UniIdxr):
if isinstance(idxr, IdxrBase):
raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr")
idxr = AutoIdxr(idxr, **self.idxr_kwargs)
return idxr.get(index=index, freq=freq)
@define
class ColIdxr(IdxrBase, DefineMixin):
"""Class for resolving column indices."""
idxr: object = define.field()
"""Indexer.
Can be an instance of `UniIdxr`, a custom template, or a value to be wrapped with `AutoIdxr`."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `AutoIdxr`."""
def __init__(self, idxr: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxr=idxr, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
columns: tp.Optional[tp.Index] = None,
template_context: tp.KwargsLike = None,
) -> tp.MaybeIndexArray:
idxr = self.idxr
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(columns=columns), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if not isinstance(idxr, UniIdxr):
if isinstance(idxr, IdxrBase):
raise TypeError(f"Indexer of {type(self)} must be an instance of UniIdxr")
idxr = AutoIdxr(idxr, **self.idxr_kwargs)
return idxr.get(index=columns)
@define
class Idxr(IdxrBase, DefineMixin):
"""Class for resolving indices."""
idxrs: tp.Tuple[object, ...] = define.field()
"""A tuple of one or more indexers.
If one indexer is provided, can be an instance of `RowIdxr` or `ColIdxr`,
a custom template, or a value to be wrapped with `RowIdxr`.
If two indexers are provided, can be an instance of `RowIdxr` and `ColIdxr` respectively,
or a value to be wrapped with `RowIdxr` and `ColIdxr` respectively."""
idxr_kwargs: tp.KwargsLike = define.field()
"""Keyword arguments passed to `RowIdxr` and `ColIdxr`."""
def __init__(self, *idxrs: object, **idxr_kwargs) -> None:
DefineMixin.__init__(self, idxrs=idxrs, idxr_kwargs=hdict(idxr_kwargs))
def get(
self,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]:
if len(self.idxrs) == 0:
raise ValueError("Must provide at least one indexer")
elif len(self.idxrs) == 1:
idxr = self.idxrs[0]
if isinstance(idxr, CustomTemplate):
_template_context = merge_dicts(dict(index=index, columns=columns, freq=freq), template_context)
idxr = idxr.substitute(_template_context, eval_id="idxr")
if isinstance(idxr, tuple):
return type(self)(*idxr).get(
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
return type(self)(idxr).get(
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
if isinstance(idxr, ColIdxr):
row_idxr = None
col_idxr = idxr
else:
row_idxr = idxr
col_idxr = None
elif len(self.idxrs) == 2:
row_idxr = self.idxrs[0]
col_idxr = self.idxrs[1]
else:
raise ValueError("Must provide at most two indexers")
if not isinstance(row_idxr, RowIdxr):
if isinstance(row_idxr, (ColIdxr, Idxr)):
raise TypeError(f"Indexer {type(row_idxr)} not supported as a row indexer")
row_idxr = RowIdxr(row_idxr, **self.idxr_kwargs)
row_idxs = row_idxr.get(index=index, freq=freq, template_context=template_context)
if not isinstance(col_idxr, ColIdxr):
if isinstance(col_idxr, (RowIdxr, Idxr)):
raise TypeError(f"Indexer {type(col_idxr)} not supported as a column indexer")
col_idxr = ColIdxr(col_idxr, **self.idxr_kwargs)
col_idxs = col_idxr.get(columns=columns, template_context=template_context)
return row_idxs, col_idxs
def get_idxs(
idxr: object,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Tuple[tp.MaybeIndexArray, tp.MaybeIndexArray]:
"""Translate indexer to row and column indices.
If `idxr` is not an indexer class, wraps it with `Idxr`.
Keyword arguments are passed when constructing a new `Idxr`."""
if not isinstance(idxr, Idxr):
idxr = Idxr(idxr, **kwargs)
return idxr.get(index=index, columns=columns, freq=freq, template_context=template_context)
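# Illustrative sketch, not part of the library: `get_idxs` translates any supported value into
# a pair of row and column indices. A single value selects rows only (columns resolve to a
# full slice), while a two-element `Idxr` selects rows and columns; exact outputs depend on
# the indexer classes above and are not asserted here.
_example_index = pd.date_range("2020-01-01", periods=3)
_example_columns = pd.Index(["a", "b"])
_row_idxs, _col_idxs = get_idxs("2020-01-02", index=_example_index, columns=_example_columns)  # rows only
_row_idxs2, _col_idxs2 = get_idxs(Idxr(0, "b"), index=_example_index, columns=_example_columns)  # rows and columns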
class index_dict(pdict):
"""Dict that contains indexer objects as keys and values to be set as values.
Each indexer object must be hashable. To make a slice hashable, use `hslice`.
To make an array hashable, convert it into a tuple.
To set a default value, use the `_def` key (case-sensitive!)."""
pass
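# Illustrative sketch, not part of the library: an `index_dict` pairs hashable indexers with
# values to set; `hslice` stands in for a regular slice, and "_def" provides the default.
# The dict below is only constructed here, not applied to any array.
_example_index_dict = index_dict({0: 100.0, hslice(2, 4): 200.0, "_def": np.nan})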
IdxSetterT = tp.TypeVar("IdxSetterT", bound="IdxSetter")
@define
class IdxSetter(DefineMixin):
"""Class for setting values based on indexing."""
idx_items: tp.List[tp.Tuple[object, tp.ArrayLike]] = define.field()
"""Items where the first element is an indexer and the second element is a value to be set."""
@classmethod
def set_row_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
"""Set row indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1 or (v.ndim == 2 and v.shape[0] == 1)
if arr.ndim == 2:
single_row = not isinstance(idxs, slice) and (np.isscalar(idxs) or idxs.size == 1)
if not single_row:
if v.ndim == 1 and v.size > 1:
v = v[:, None]
if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
if not single_v:
if arr.ndim == 2:
v = broadcast_array_to(v, (len(idxs), arr.shape[1]))
else:
v = broadcast_array_to(v, (len(idxs),))
for i in range(len(idxs)):
_slice = slice(idxs[i, 0], idxs[i, 1])
if not single_v:
cls.set_row_idxs(arr, _slice, v[[i]])
else:
cls.set_row_idxs(arr, _slice, v)
else:
arr[idxs] = v
@classmethod
def set_col_idxs(cls, arr: tp.Array, idxs: tp.MaybeIndexArray, v: tp.Any) -> None:
"""Set column indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1 or (v.ndim == 2 and v.shape[1] == 1)
if isinstance(idxs, np.ndarray) and idxs.ndim == 2:
if not single_v:
v = broadcast_array_to(v, (arr.shape[0], len(idxs)))
for j in range(len(idxs)):
_slice = slice(idxs[j, 0], idxs[j, 1])
if not single_v:
cls.set_col_idxs(arr, _slice, v[:, [j]])
else:
cls.set_col_idxs(arr, _slice, v)
else:
arr[:, idxs] = v
@classmethod
def set_row_and_col_idxs(
cls,
arr: tp.Array,
row_idxs: tp.MaybeIndexArray,
col_idxs: tp.MaybeIndexArray,
v: tp.Any,
) -> None:
"""Set row and column indices in an array."""
from vectorbtpro.base.reshaping import broadcast_array_to
if not isinstance(v, np.ndarray):
v = np.asarray(v)
single_v = v.size == 1
if (
isinstance(row_idxs, np.ndarray)
and row_idxs.ndim == 2
and isinstance(col_idxs, np.ndarray)
and col_idxs.ndim == 2
):
if not single_v:
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for i in range(len(row_idxs)):
for j in range(len(col_idxs)):
row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_slice, col_slice, v[i, j])
else:
cls.set_row_and_col_idxs(arr, row_slice, col_slice, v)
elif isinstance(row_idxs, np.ndarray) and row_idxs.ndim == 2:
if not single_v:
if isinstance(col_idxs, slice):
col_idxs = np.arange(arr.shape[1])[col_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for i in range(len(row_idxs)):
row_slice = slice(row_idxs[i, 0], row_idxs[i, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v[[i]])
else:
cls.set_row_and_col_idxs(arr, row_slice, col_idxs, v)
elif isinstance(col_idxs, np.ndarray) and col_idxs.ndim == 2:
if not single_v:
if isinstance(row_idxs, slice):
row_idxs = np.arange(arr.shape[0])[row_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
for j in range(len(col_idxs)):
col_slice = slice(col_idxs[j, 0], col_idxs[j, 1])
if not single_v:
cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v[:, [j]])
else:
cls.set_row_and_col_idxs(arr, row_idxs, col_slice, v)
else:
if np.isscalar(row_idxs) or np.isscalar(col_idxs):
arr[row_idxs, col_idxs] = v
elif np.isscalar(v) and (isinstance(row_idxs, slice) or isinstance(col_idxs, slice)):
arr[row_idxs, col_idxs] = v
elif np.isscalar(v):
arr[np.ix_(row_idxs, col_idxs)] = v
else:
if isinstance(row_idxs, slice):
row_idxs = np.arange(arr.shape[0])[row_idxs]
if isinstance(col_idxs, slice):
col_idxs = np.arange(arr.shape[1])[col_idxs]
v = broadcast_array_to(v, (len(row_idxs), len(col_idxs)))
arr[np.ix_(row_idxs, col_idxs)] = v
def get_set_meta(
self,
shape: tp.ShapeLike,
index: tp.Optional[tp.Index] = None,
columns: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
) -> tp.Kwargs:
"""Get meta of setting operations in `IdxSetter.idx_items`."""
from vectorbtpro.base.reshaping import to_tuple_shape
shape = to_tuple_shape(shape)
rows_changed = False
cols_changed = False
set_funcs = []
default = None
for idxr, v in self.idx_items:
if isinstance(idxr, str) and idxr == "_def":
if default is None:
default = v
continue
row_idxs, col_idxs = get_idxs(
idxr,
index=index,
columns=columns,
freq=freq,
template_context=template_context,
)
if isinstance(v, CustomTemplate):
_template_context = merge_dicts(
dict(
idxr=idxr,
row_idxs=row_idxs,
col_idxs=col_idxs,
),
template_context,
)
v = v.substitute(_template_context, eval_id="set")
if not isinstance(v, np.ndarray):
v = np.asarray(v)
def _check_use_idxs(idxs):
use_idxs = True
if isinstance(idxs, slice):
if idxs.start is None and idxs.stop is None and idxs.step is None:
use_idxs = False
if isinstance(idxs, np.ndarray):
if idxs.size == 0:
use_idxs = False
return use_idxs
use_row_idxs = _check_use_idxs(row_idxs)
use_col_idxs = _check_use_idxs(col_idxs)
if use_row_idxs and use_col_idxs:
set_funcs.append(partial(self.set_row_and_col_idxs, row_idxs=row_idxs, col_idxs=col_idxs, v=v))
rows_changed = True
cols_changed = True
elif use_col_idxs:
set_funcs.append(partial(self.set_col_idxs, idxs=col_idxs, v=v))
if checks.is_int(col_idxs):
if v.size > 1:
rows_changed = True
else:
if v.ndim == 2:
if v.shape[0] > 1:
rows_changed = True
cols_changed = True
else:
set_funcs.append(partial(self.set_row_idxs, idxs=row_idxs, v=v))
if use_row_idxs:
rows_changed = True
if len(shape) == 2:
if checks.is_int(row_idxs):
if v.size > 1:
cols_changed = True
else:
if v.ndim == 2:
if v.shape[1] > 1:
cols_changed = True
return dict(
default=default,
set_funcs=set_funcs,
rows_changed=rows_changed,
cols_changed=cols_changed,
)
def set(self, arr: tp.Array, set_funcs: tp.Optional[tp.Sequence[tp.Callable]] = None, **kwargs) -> None:
"""Set values of a NumPy array based on `IdxSetter.get_set_meta`."""
if set_funcs is None:
set_meta = self.get_set_meta(arr.shape, **kwargs)
set_funcs = set_meta["set_funcs"]
for set_op in set_funcs:
set_op(arr)
def set_pd(self, pd_arr: tp.SeriesFrame, **kwargs) -> None:
"""Set values of a Pandas array based on `IdxSetter.get_set_meta`."""
from vectorbtpro.base.indexes import get_index
index = get_index(pd_arr, 0)
columns = get_index(pd_arr, 1)
freq = dt.infer_index_freq(index)
self.set(pd_arr.values, index=index, columns=columns, freq=freq, **kwargs)
def fill_and_set(
self,
shape: tp.ShapeLike,
keep_flex: bool = False,
fill_value: tp.Scalar = np.nan,
**kwargs,
) -> tp.Array:
"""Fill a new array and set its values based on `IdxSetter.get_set_meta`.
If `keep_flex` is True, will return the most memory-efficient array representation
capable of flexible indexing.
If a `_def` key is present in `IdxSetter.idx_items`, its value overrides `fill_value`.
Otherwise, the array is filled with `fill_value` (NaN by default)."""
set_meta = self.get_set_meta(shape, **kwargs)
if set_meta["default"] is not None:
fill_value = set_meta["default"]
if isinstance(fill_value, str):
dtype = object
else:
dtype = None
if keep_flex and not set_meta["cols_changed"] and not set_meta["rows_changed"]:
arr = np.full((1,) if len(shape) == 1 else (1, 1), fill_value, dtype=dtype)
elif keep_flex and not set_meta["cols_changed"]:
arr = np.full(shape if len(shape) == 1 else (shape[0], 1), fill_value, dtype=dtype)
elif keep_flex and not set_meta["rows_changed"]:
arr = np.full((1, shape[1]), fill_value, dtype=dtype)
else:
arr = np.full(shape, fill_value, dtype=dtype)
self.set(arr, set_funcs=set_meta["set_funcs"])
return arr
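# Illustrative sketch, not part of the library: building a small array with `IdxSetter`.
# Each item pairs an indexer with a value; `fill_and_set` allocates the array, fills it with
# `fill_value` (or the "_def" default), and applies the setting operations. With purely
# positional indexers no index is needed; the expected result is noted only as a hint.
_example_setter = IdxSetter([(0, 1.0), (2, 3.0)])
_example_arr = _example_setter.fill_and_set((5,), fill_value=0.0)  # roughly array([1., 0., 3., 0., 0.])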
class IdxSetterFactory(Base):
"""Class for building index setters."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
"""Get an instance of `IdxSetter` or a dict of such instances - one per array name."""
raise NotImplementedError
@define
class IdxDict(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a dict."""
index_dct: dict = define.field()
"""Dict that contains indexer objects as keys and values to be set as values."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
return IdxSetter(list(self.index_dct.items()))
@define
class IdxSeries(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a Series."""
sr: tp.AnyArray1d = define.field()
"""Series or any array-like object to create the Series from."""
split: bool = define.field(default=False)
"""Whether to split the setting operation.
If False, will set all values using a single operation.
Otherwise, will do one operation per element."""
idx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `idx` if the indexer isn't an instance of `Idxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
sr = self.sr
split = self.split
idx_kwargs = self.idx_kwargs
if idx_kwargs is None:
idx_kwargs = {}
if not isinstance(sr, pd.Series):
sr = pd.Series(sr)
if split:
idx_items = list(sr.items())
else:
idx_items = [(sr.index, sr.values)]
new_idx_items = []
for idxr, v in idx_items:
if idxr is None:
raise ValueError("Indexer cannot be None")
if not isinstance(idxr, Idxr):
idxr = idx(idxr, **idx_kwargs)
new_idx_items.append((idxr, v))
return IdxSetter(new_idx_items)
@define
class IdxFrame(IdxSetterFactory, DefineMixin):
"""Class for building an index setter from a DataFrame."""
df: tp.AnyArray2d = define.field()
"""DataFrame or any array-like object to create the DataFrame from."""
split: tp.Union[bool, str] = define.field(default=False)
"""Whether to split the setting operation.
If False, will set all values using a single operation.
Otherwise, the following options are supported:
* 'columns': one operation per column
* 'rows': one operation per row
* True or 'elements': one operation per element"""
rowidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`."""
colidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
df = self.df
split = self.split
rowidx_kwargs = self.rowidx_kwargs
colidx_kwargs = self.colidx_kwargs
if rowidx_kwargs is None:
rowidx_kwargs = {}
if colidx_kwargs is None:
colidx_kwargs = {}
if not isinstance(df, pd.DataFrame):
df = pd.DataFrame(df)
if isinstance(split, bool):
if split:
split = "elements"
else:
split = None
if split is not None:
if split.lower() == "columns":
idx_items = []
for col, sr in df.items():
idx_items.append((sr.index, col, sr.values))
elif split.lower() == "rows":
idx_items = []
for row, sr in df.iterrows():
idx_items.append((row, df.columns, sr.values))
elif split.lower() == "elements":
idx_items = []
for col, sr in df.items():
for row, v in sr.items():
idx_items.append((row, col, v))
else:
raise ValueError(f"Invalid split: '{split}'")
else:
idx_items = [(df.index, df.columns, df.values)]
new_idx_items = []
for row_idxr, col_idxr, v in idx_items:
if row_idxr is None:
raise ValueError("Row indexer cannot be None")
if col_idxr is None:
raise ValueError("Column indexer cannot be None")
if row_idxr is not None and not isinstance(row_idxr, RowIdxr):
row_idxr = rowidx(row_idxr, **rowidx_kwargs)
if col_idxr is not None and not isinstance(col_idxr, ColIdxr):
col_idxr = colidx(col_idxr, **colidx_kwargs)
new_idx_items.append((idx(row_idxr, col_idxr), v))
return IdxSetter(new_idx_items)
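# Illustrative sketch, not part of the library: turning a small DataFrame into an `IdxSetter`.
# With split=False a single operation covers the whole frame; "columns", "rows", or
# "elements" produce one operation per column, row, or element respectively.
_example_df = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
_example_frame_setter = IdxFrame(_example_df, split="columns").get()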
@define
class IdxRecords(IdxSetterFactory, DefineMixin):
"""Class for building index setters from records - one per field."""
records: tp.RecordsLike = define.field()
"""Series, DataFrame, or any sequence of mapping-like objects.
If a Series or DataFrame and the index is not a default range, the index will become a row field.
If a custom row field is provided, the index will be ignored."""
row_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
"""Row field.
If None or True, will search for "row", "index", "open time", "date", and "datetime" (case-insensitive).
If `IdxRecords.records` is a Series or DataFrame, will also include the index name
if the index is not a default range.
If a record doesn't have a row field, all rows will be set.
If there's neither a row nor a column field, the field value will become the default of the entire array."""
col_field: tp.Union[None, bool, tp.Label] = define.field(default=None)
"""Column field.
If None or True, will search for "col", "column", and "symbol" (case-insensitive).
If a record doesn't have a column field, all columns will be set.
If there's neither a row nor a column field, the field value will become the default of the entire array."""
rowidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `rowidx` if the indexer isn't an instance of `RowIdxr`."""
colidx_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `colidx` if the indexer isn't an instance of `ColIdxr`."""
def get(self) -> tp.Union[IdxSetter, tp.Dict[tp.Label, IdxSetter]]:
records = self.records
row_field = self.row_field
col_field = self.col_field
rowidx_kwargs = self.rowidx_kwargs
colidx_kwargs = self.colidx_kwargs
if rowidx_kwargs is None:
rowidx_kwargs = {}
if colidx_kwargs is None:
colidx_kwargs = {}
default_index = False
index_field = None
if isinstance(records, pd.Series):
records = records.to_frame()
if isinstance(records, pd.DataFrame):
if checks.is_default_index(records.index):
default_index = True
records = records.reset_index(drop=default_index)
if not default_index:
index_field = records.columns[0]
records = records.itertuples(index=False)
def _resolve_field_meta(fields):
_row_field = row_field
_row_kind = None
_col_field = col_field
_col_kind = None
row_fields = set()
col_fields = set()
for field in fields:
if isinstance(field, str) and index_field is not None and field == index_field:
row_fields.add((field, None))
if isinstance(field, str) and field.lower() in ("row", "index"):
row_fields.add((field, None))
if isinstance(field, str) and field.lower() in ("open time", "date", "datetime"):
if (field, None) in row_fields:
row_fields.remove((field, None))
row_fields.add((field, "datetime"))
if isinstance(field, str) and field.lower() in ("col", "column"):
col_fields.add((field, None))
if isinstance(field, str) and field.lower() == "symbol":
if (field, None) in col_fields:
col_fields.remove((field, None))
col_fields.add((field, "labels"))
if _row_field in (None, True):
if len(row_fields) == 0:
if _row_field is True:
raise ValueError("Cannot find row field")
_row_field = None
elif len(row_fields) == 1:
_row_field, _row_kind = row_fields.pop()
else:
raise ValueError("Multiple row field candidates")
elif _row_field is False:
_row_field = None
if _col_field in (None, True):
if len(col_fields) == 0:
if _col_field is True:
raise ValueError("Cannot find column field")
_col_field = None
elif len(col_fields) == 1:
_col_field, _col_kind = col_fields.pop()
else:
raise ValueError("Multiple column field candidates")
elif _col_field is False:
_col_field = None
field_meta = dict()
field_meta["row_field"] = _row_field
field_meta["row_kind"] = _row_kind
field_meta["col_field"] = _col_field
field_meta["col_kind"] = _col_kind
return field_meta
idx_items = dict()
for r in records:
r = to_field_mapping(r)
field_meta = _resolve_field_meta(r.keys())
if field_meta["row_field"] is None:
row_idxr = None
else:
row_idxr = r.get(field_meta["row_field"], None)
if isinstance(row_idxr, str) and row_idxr == "_def":
row_idxr = None
if row_idxr is not None and not isinstance(row_idxr, RowIdxr):
_rowidx_kwargs = dict(rowidx_kwargs)
if field_meta["row_kind"] is not None and "kind" not in _rowidx_kwargs:
_rowidx_kwargs["kind"] = field_meta["row_kind"]
row_idxr = rowidx(row_idxr, **_rowidx_kwargs)
if field_meta["col_field"] is None:
col_idxr = None
else:
col_idxr = r.get(field_meta["col_field"], None)
if isinstance(col_idxr, str) and col_idxr == "_def":
col_idxr = None
if col_idxr is not None and not isinstance(col_idxr, ColIdxr):
_colidx_kwargs = dict(colidx_kwargs)
if field_meta["col_kind"] is not None and "kind" not in _colidx_kwargs:
_colidx_kwargs["kind"] = field_meta["col_kind"]
col_idxr = colidx(col_idxr, **_colidx_kwargs)
item_produced = False
for k, v in r.items():
if index_field is not None and k == index_field:
continue
if field_meta["row_field"] is not None and k == field_meta["row_field"]:
continue
if field_meta["col_field"] is not None and k == field_meta["col_field"]:
continue
if k not in idx_items:
idx_items[k] = []
if row_idxr is None and col_idxr is None:
idx_items[k].append(("_def", v))
else:
idx_items[k].append((idx(row_idxr, col_idxr), v))
item_produced = True
if not item_produced:
raise ValueError(f"Record {r} has no fields to set")
idx_setters = dict()
for k, v in idx_items.items():
idx_setters[k] = IdxSetter(v)
return idx_setters
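# Illustrative sketch, not part of the library: records with row and column fields become one
# `IdxSetter` per remaining field. The field names follow the search rules documented above
# ("row"/"index"/date-like for rows, "col"/"column"/"symbol" for columns); the record
# contents here are hypothetical.
_example_records = IdxRecords([dict(row=0, col="a", size=1.0), dict(row=2, col="b", size=2.0)])
_example_record_setters = _example_records.get()  # e.g. {"size": IdxSetter([...])}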
posidx = PosIdxr
"""Shortcut for `PosIdxr`."""
__pdoc__["posidx"] = False
maskidx = MaskIdxr
"""Shortcut for `MaskIdxr`."""
__pdoc__["maskidx"] = False
lbidx = LabelIdxr
"""Shortcut for `LabelIdxr`."""
__pdoc__["lbidx"] = False
dtidx = DatetimeIdxr
"""Shortcut for `DatetimeIdxr`."""
__pdoc__["dtidx"] = False
dtcidx = DTCIdxr
"""Shortcut for `DTCIdxr`."""
__pdoc__["dtcidx"] = False
pointidx = PointIdxr
"""Shortcut for `PointIdxr`."""
__pdoc__["pointidx"] = False
rangeidx = RangeIdxr
"""Shortcut for `RangeIdxr`."""
__pdoc__["rangeidx"] = False
autoidx = AutoIdxr
"""Shortcut for `AutoIdxr`."""
__pdoc__["autoidx"] = False
rowidx = RowIdxr
"""Shortcut for `RowIdxr`."""
__pdoc__["rowidx"] = False
colidx = ColIdxr
"""Shortcut for `ColIdxr`."""
__pdoc__["colidx"] = False
idx = Idxr
"""Shortcut for `Idxr`."""
__pdoc__["idx"] = False
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for merging arrays."""
from functools import partial
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes, concat_indexes, clean_index
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import resolve_dict, merge_dicts, HybridConfig
from vectorbtpro.utils.execution import NoResult, NoResultsException, filter_out_no_results
from vectorbtpro.utils.merging import MergeFunc
__all__ = [
"concat_arrays",
"row_stack_arrays",
"column_stack_arrays",
"concat_merge",
"row_stack_merge",
"column_stack_merge",
"imageio_merge",
"mixed_merge",
]
__pdoc__ = {}
def concat_arrays(*arrs: tp.MaybeSequence[tp.AnyArray]) -> tp.Array1d:
"""Concatenate arrays."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(to_1d_array, arrs))
return np.concatenate(arrs)
def row_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d:
"""Stack arrays along rows."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs))
return np.concatenate(arrs, axis=0)
def column_stack_arrays(*arrs: tp.MaybeSequence[tp.AnyArray], expand_axis: int = 1) -> tp.Array2d:
"""Stack arrays along columns."""
if len(arrs) == 1:
arrs = arrs[0]
arrs = list(arrs)
arrs = list(map(partial(to_2d_array, expand_axis=expand_axis), arrs))
common_shape = None
can_concatenate = True
for arr in arrs:
if common_shape is None:
common_shape = arr.shape
if arr.shape != common_shape:
can_concatenate = False
continue
if not (arr.ndim == 1 or (arr.ndim == 2 and arr.shape[1] == 1)):
can_concatenate = False
continue
if can_concatenate:
return np.concatenate(arrs, axis=0).reshape((len(arrs), common_shape[0])).T
return np.concatenate(arrs, axis=1)
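# Illustrative sketch, not part of the library: the three low-level stacking helpers above on
# tiny NumPy arrays. One-dimensional inputs are treated as columns, so row stacking appends
# them vertically while column stacking places them side by side.
_a = np.array([1, 2])
_b = np.array([3, 4])
_concatenated = concat_arrays(_a, _b)  # shape (4,): [1, 2, 3, 4]
_row_stacked = row_stack_arrays(_a, _b)  # shape (4, 1): one column, inputs stacked vertically
_col_stacked = column_stack_arrays(_a, _b)  # shape (2, 2): each input becomes a column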
def concat_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Optional[bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like objects through concatenation.
Supports a sequence of tuples.
If `wrap` is None, it will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None.
If `wrap` is True, each array will be wrapped with Pandas Series and merged using `pd.concat`.
Otherwise, arrays will be kept as-is and merged using `concat_arrays`.
`wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap_reduced`.
Keyword arguments `**kwargs` are passed to `pd.concat` only.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
concat_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: concat_merge(
x,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
raise TypeError("Concatenating Wrapping instances is not supported")
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = isinstance(objs[0], pd.Series) or wrapper is not None or keys is not None or len(wrap_kwargs) > 0
if not checks.is_complex_iterable(objs[0]):
if wrap:
if keys is not None and isinstance(keys[0], pd.Index):
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
clean_index_kwargs=clean_index_kwargs,
verify_integrity=False,
axis=0,
)
wrap_kwargs = merge_dicts(dict(index=keys), wrap_kwargs)
return pd.Series(objs, **wrap_kwargs)
return np.asarray(objs)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_index = True
if not isinstance(objs[0], pd.Series):
if isinstance(objs[0], pd.DataFrame):
raise ValueError("Use row stacking for concatenating DataFrames")
if wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
if "force_1d" not in _wrap_kwargs:
_wrap_kwargs["force_1d"] = True
new_objs.append(wrapper.wrap_reduced(obj, **_wrap_kwargs))
else:
new_objs.append(pd.Series(obj, **_wrap_kwargs))
if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True):
default_index = False
objs = new_objs
if not wrap:
return concat_arrays(objs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=0, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=0,
)
if default_index:
new_obj.index = keys
else:
new_obj.index = stack_indexes((keys, new_obj.index))
else:
new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.index = clean_index(new_obj.index, **clean_index_kwargs)
return new_obj
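# Illustrative sketch, not part of the library: `concat_merge` on plain arrays returns a NumPy
# array, while passing `keys` switches on wrapping and yields a pandas Series whose outer
# index level comes from the keys (the exact index layout follows `pd.concat`).
_raw_merged = concat_merge([np.array([1, 2]), np.array([3, 4])])
_wrapped_merged = concat_merge([np.array([1, 2]), np.array([3, 4])], keys=pd.Index(["x", "y"], name="split"))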
def row_stack_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Union[None, str, bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLikeSequence = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through row stacking.
Supports a sequence of tuples.
Argument `wrap` supports the following options:
* None: will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None
* True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions)
* 'sr', 'series': each array will be wrapped with Pandas Series
* 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame
Without wrapping, arrays will be kept as-is and merged using `row_stack_arrays`.
Argument `wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
Keyword arguments `**kwargs` are passed to `pd.concat` and
`vectorbtpro.base.wrapping.Wrapping.row_stack` only.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
row_stack_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: row_stack_merge(
x,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs)
return type(objs[0]).row_stack(objs, **kwargs)
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = (
isinstance(objs[0], (pd.Series, pd.DataFrame))
or wrapper is not None
or keys is not None
or len(wrap_kwargs) > 0
)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_index = True
if not isinstance(objs[0], (pd.Series, pd.DataFrame)):
if isinstance(wrap, str) or wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
new_objs.append(wrapper.wrap(obj, **_wrap_kwargs))
else:
if not isinstance(wrap, str):
if isinstance(obj, np.ndarray):
ndim = obj.ndim
else:
ndim = np.asarray(obj).ndim
if ndim == 1:
wrap = "series"
else:
wrap = "frame"
if isinstance(wrap, str):
if wrap.lower() in ("sr", "series"):
new_objs.append(pd.Series(obj, **_wrap_kwargs))
elif wrap.lower() in ("df", "frame", "dataframe"):
new_objs.append(pd.DataFrame(obj, **_wrap_kwargs))
else:
raise ValueError(f"Invalid wrapping option: '{wrap}'")
if default_index and not checks.is_default_index(new_objs[-1].index, check_names=True):
default_index = False
objs = new_objs
if not wrap:
return row_stack_arrays(objs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=0, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=0,
)
if default_index:
new_obj.index = keys
else:
new_obj.index = stack_indexes((keys, new_obj.index))
else:
new_obj = pd.concat(objs, axis=0, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.index = clean_index(new_obj.index, **clean_index_kwargs)
return new_obj
def column_stack_merge(
*objs,
reset_index: tp.Union[None, bool, str] = None,
fill_value: tp.Scalar = np.nan,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
wrap: tp.Union[None, str, bool] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLikeSequence = None,
clean_index_kwargs: tp.KwargsLikeSequence = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge multiple array-like or `vectorbtpro.base.wrapping.Wrapping` objects through column stacking.
Supports a sequence of tuples.
Argument `wrap` supports the following options:
* None: will become True if `wrapper`, `keys`, or `wrap_kwargs` are not None
* True: each array will be wrapped with Pandas Series/DataFrame (depending on dimensions)
* 'sr', 'series': each array will be wrapped with Pandas Series
* 'df', 'frame', 'dataframe': each array will be wrapped with Pandas DataFrame
Without wrapping, arrays will be kept as-is and merged using `column_stack_arrays`.
Argument `wrap_kwargs` can be a dictionary or a list of dictionaries.
If `wrapper` is provided, will use `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
Keyword arguments `**kwargs` are passed to `pd.concat` and
`vectorbtpro.base.wrapping.Wrapping.column_stack` only.
Argument `reset_index` supports the following options:
* False or None: Keep original index of each object
* True or 'from_start': Reset index of each object and align them at start
* 'from_end': Reset index of each object and align them at end
Options above work on Pandas, NumPy, and `vectorbtpro.base.wrapping.Wrapping` instances.
!!! note
All arrays are assumed to have the same type and dimensionality."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(reset_index, bool):
if reset_index:
reset_index = "from_start"
else:
reset_index = None
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
column_stack_merge(
list(map(lambda x: x[0], objs)),
reset_index=reset_index,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: column_stack_merge(
x,
reset_index=reset_index,
keys=keys,
wrap=wrap,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if isinstance(objs[0], Wrapping):
if reset_index is not None:
max_length = max(map(lambda x: x.wrapper.shape[0], objs))
new_objs = []
for obj in objs:
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_index = pd.RangeIndex(stop=obj.wrapper.shape[0])
new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index))
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_index = pd.RangeIndex(start=max_length - obj.wrapper.shape[0], stop=max_length)
new_obj = obj.replace(wrapper=obj.wrapper.replace(index=new_index))
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
new_objs.append(new_obj)
objs = new_objs
kwargs = merge_dicts(dict(wrapper_kwargs=dict(keys=keys)), kwargs)
return type(objs[0]).column_stack(objs, **kwargs)
if wrap_kwargs is None:
wrap_kwargs = {}
if wrap is None:
wrap = (
isinstance(objs[0], (pd.Series, pd.DataFrame))
or wrapper is not None
or keys is not None
or len(wrap_kwargs) > 0
)
if isinstance(objs[0], pd.Index):
objs = list(map(lambda x: x.to_series(), objs))
default_columns = True
if not isinstance(objs[0], (pd.Series, pd.DataFrame)):
if isinstance(wrap, str) or wrap:
new_objs = []
for i, obj in enumerate(objs):
_wrap_kwargs = resolve_dict(wrap_kwargs, i)
if wrapper is not None:
new_objs.append(wrapper.wrap(obj, **_wrap_kwargs))
else:
if not isinstance(wrap, str):
if isinstance(obj, np.ndarray):
ndim = obj.ndim
else:
ndim = np.asarray(obj).ndim
if ndim == 1:
wrap = "series"
else:
wrap = "frame"
if isinstance(wrap, str):
if wrap.lower() in ("sr", "series"):
new_objs.append(pd.Series(obj, **_wrap_kwargs))
elif wrap.lower() in ("df", "frame", "dataframe"):
new_objs.append(pd.DataFrame(obj, **_wrap_kwargs))
else:
raise ValueError(f"Invalid wrapping option: '{wrap}'")
if (
default_columns
and isinstance(new_objs[-1], pd.DataFrame)
and not checks.is_default_index(new_objs[-1].columns, check_names=True)
):
default_columns = False
objs = new_objs
if not wrap:
if reset_index is not None:
min_n_rows = None
max_n_rows = None
n_cols = 0
new_objs = []
for obj in objs:
new_obj = to_2d_array(obj)
new_objs.append(new_obj)
if min_n_rows is None or new_obj.shape[0] < min_n_rows:
min_n_rows = new_obj.shape[0]
if max_n_rows is None or new_obj.shape[0] > max_n_rows:
max_n_rows = new_obj.shape[0]
n_cols += new_obj.shape[1]
if min_n_rows == max_n_rows:
return column_stack_arrays(new_objs)
new_obj = np.full((max_n_rows, n_cols), fill_value)
start_col = 0
for obj in new_objs:
end_col = start_col + obj.shape[1]
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_obj[: len(obj), start_col:end_col] = obj
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_obj[-len(obj) :, start_col:end_col] = obj
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
start_col = end_col
return new_obj
return column_stack_arrays(objs)
if reset_index is not None:
max_length = max(map(len, objs))
new_objs = []
for obj in objs:
new_obj = obj.copy(deep=False)
if isinstance(reset_index, str) and reset_index.lower() == "from_start":
new_obj.index = pd.RangeIndex(stop=len(new_obj))
elif isinstance(reset_index, str) and reset_index.lower() == "from_end":
new_obj.index = pd.RangeIndex(start=max_length - len(new_obj), stop=max_length)
else:
raise ValueError(f"Invalid index resetting option: '{reset_index}'")
new_objs.append(new_obj)
objs = new_objs
kwargs = merge_dicts(dict(sort=True), kwargs)
if keys is not None and isinstance(keys[0], pd.Index):
new_obj = pd.concat(objs, axis=1, **kwargs)
if len(keys) == 1:
keys = keys[0]
else:
keys = concat_indexes(
*keys,
index_concat_method="append",
verify_integrity=False,
axis=1,
)
if default_columns:
new_obj.columns = keys
else:
new_obj.columns = stack_indexes((keys, new_obj.columns))
else:
new_obj = pd.concat(objs, axis=1, keys=keys, **kwargs)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_obj.columns = clean_index(new_obj.columns, **clean_index_kwargs)
return new_obj
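# Illustrative sketch, not part of the library: column-stacking two Series of different lengths.
# With reset_index="from_start" both indexes are replaced by a default range and aligned at the
# beginning, so the shorter column ends up padded with NaN at the bottom.
_short_sr = pd.Series([1.0, 2.0])
_long_sr = pd.Series([3.0, 4.0, 5.0])
_stacked_df = column_stack_merge([_short_sr, _long_sr], reset_index="from_start")  # 3 rows x 2 columns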
def imageio_merge(
*objs,
keys: tp.Optional[tp.Index] = None,
filter_results: bool = True,
raise_no_results: bool = True,
to_image_kwargs: tp.KwargsLike = None,
imread_kwargs: tp.KwargsLike = None,
**imwrite_kwargs,
) -> tp.MaybeTuple[tp.Union[None, bytes]]:
"""Merge multiple figure-like objects by writing them with `imageio`.
Keyword arguments `to_image_kwargs` are passed to `plotly.graph_objects.Figure.to_image`.
Keyword arguments `imread_kwargs` and `**imwrite_kwargs` are passed to
`imageio.imread` and `imageio.imwrite` respectively.
Keys are not used in any way."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
import imageio.v3 as iio
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if isinstance(objs[0], tuple):
if len(objs[0]) == 1:
out_tuple = (
imageio_merge(
list(map(lambda x: x[0], objs)),
keys=keys,
imread_kwargs=imread_kwargs,
to_image_kwargs=to_image_kwargs,
**imwrite_kwargs,
),
)
else:
out_tuple = tuple(
map(
lambda x: imageio_merge(
x,
keys=keys,
imread_kwargs=imread_kwargs,
to_image_kwargs=to_image_kwargs,
**imwrite_kwargs,
),
zip(*objs),
)
)
if checks.is_namedtuple(objs[0]):
return type(objs[0])(*out_tuple)
return type(objs[0])(out_tuple)
if filter_results:
try:
objs, keys = filter_out_no_results(objs, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
if imread_kwargs is None:
imread_kwargs = {}
if to_image_kwargs is None:
to_image_kwargs = {}
frames = []
for obj in objs:
if obj is None:
continue
if isinstance(obj, (go.Figure, go.FigureWidget)):
obj = obj.to_image(**to_image_kwargs)
if not isinstance(obj, np.ndarray):
obj = iio.imread(obj, **imread_kwargs)
frames.append(obj)
return iio.imwrite(image=frames, **imwrite_kwargs)
def mixed_merge(
*objs,
merge_funcs: tp.Optional[tp.MergeFuncLike] = None,
mixed_kwargs: tp.Optional[tp.Sequence[tp.KwargsLike]] = None,
**kwargs,
) -> tp.MaybeTuple[tp.AnyArray]:
"""Merge objects of mixed types."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if len(objs) == 0:
raise ValueError("No objects to be merged")
if merge_funcs is None:
raise ValueError("Merging functions or their names are required")
if not isinstance(objs[0], tuple):
raise ValueError("Mixed merging must be applied to tuples")
outputs = []
for i, output_objs in enumerate(zip(*objs)):
output_objs = list(output_objs)
merge_func = resolve_merge_func(merge_funcs[i])
if merge_func is None:
outputs.append(output_objs)
else:
if mixed_kwargs is None:
_kwargs = kwargs
else:
_kwargs = merge_dicts(kwargs, mixed_kwargs[i])
output = merge_func(output_objs, **_kwargs)
outputs.append(output)
return tuple(outputs)
merge_func_config = HybridConfig(
dict(
concat=concat_merge,
row_stack=row_stack_merge,
column_stack=column_stack_merge,
reset_column_stack=partial(column_stack_merge, reset_index=True),
from_start_column_stack=partial(column_stack_merge, reset_index="from_start"),
from_end_column_stack=partial(column_stack_merge, reset_index="from_end"),
imageio=imageio_merge,
)
)
"""_"""
__pdoc__[
"merge_func_config"
] = f"""Config for merging functions.
```python
{merge_func_config.prettify()}
```
"""
def resolve_merge_func(merge_func: tp.MergeFuncLike) -> tp.Optional[tp.Callable]:
"""Resolve a merging function into a callable.
If a string, looks up into `merge_func_config`. If a sequence, uses `mixed_merge` with
`merge_funcs=merge_func`. If an instance of `vectorbtpro.utils.merging.MergeFunc`, calls
`vectorbtpro.utils.merging.MergeFunc.resolve_merge_func` to get the actual callable."""
if merge_func is None:
return None
if isinstance(merge_func, str):
if merge_func.lower() not in merge_func_config:
raise ValueError(f"Invalid merging function name: '{merge_func}'")
return merge_func_config[merge_func.lower()]
if checks.is_sequence(merge_func):
return partial(mixed_merge, merge_funcs=merge_func)
if isinstance(merge_func, MergeFunc):
return merge_func.resolve_merge_func()
return merge_func
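# Illustrative sketch, not part of the library: resolving merging functions by name.
_resolved_single = resolve_merge_func("column_stack")  # -> column_stack_merge
_resolved_mixed = resolve_merge_func(("concat", "row_stack"))  # -> partial of mixed_merge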
def is_merge_func_from_config(merge_func: tp.MergeFuncLike) -> bool:
"""Return whether the merging function can be found in `merge_func_config`."""
if merge_func is None:
return False
if isinstance(merge_func, str):
return merge_func.lower() in merge_func_config
if checks.is_sequence(merge_func):
return all(map(is_merge_func_from_config, merge_func))
if isinstance(merge_func, MergeFunc):
return is_merge_func_from_config(merge_func.merge_func)
return False
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for preparing arguments."""
import inspect
import string
from collections import defaultdict
from datetime import timedelta, time
from functools import cached_property as cachedproperty
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.decorators import override_arg_config, attach_arg_properties
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.indexing import index_dict, IdxSetter, IdxSetterFactory, IdxRecords
from vectorbtpro.base.merging import concat_arrays, column_stack_arrays
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import BCO, Default, Ref, broadcast
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.cutting import suggest_module_path, cut_and_save_func
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.module_ import import_module_from_path
from vectorbtpro.utils.params import Param
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.path_ import remove_dir
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import CustomTemplate, RepFunc, substitute_templates
__all__ = [
"BasePreparer",
]
__pdoc__ = {}
base_arg_config = ReadonlyConfig(
dict(
broadcast_named_args=dict(is_dict=True),
broadcast_kwargs=dict(is_dict=True),
template_context=dict(is_dict=True),
seed=dict(),
jitted=dict(),
chunked=dict(),
staticized=dict(),
records=dict(),
)
)
"""_"""
__pdoc__[
"base_arg_config"
] = f"""Argument config for `BasePreparer`.
```python
{base_arg_config.prettify()}
```
"""
class MetaBasePreparer(type(Configured)):
"""Metaclass for `BasePreparer`."""
@property
def arg_config(cls) -> Config:
"""Argument config."""
return cls._arg_config
@attach_arg_properties
@override_arg_config(base_arg_config)
class BasePreparer(Configured, metaclass=MetaBasePreparer):
"""Base class for preparing target functions and arguments.
!!! warning
Most properties are force-cached - create a new instance to override any attribute."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_writeable_attrs: tp.WriteableAttrs = {"_arg_config"}
_settings_path: tp.SettingsPath = None
def __init__(self, arg_config: tp.KwargsLike = None, **kwargs) -> None:
Configured.__init__(self, arg_config=arg_config, **kwargs)
# Copy writeable attrs
self._arg_config = type(self)._arg_config.copy()
if arg_config is not None:
self._arg_config = merge_dicts(self._arg_config, arg_config)
_arg_config: tp.ClassVar[Config] = HybridConfig()
@property
def arg_config(self) -> Config:
"""Argument config of `${cls_name}`.
```python
${arg_config}
```
"""
return self._arg_config
@classmethod
def map_enum_value(cls, value: tp.ArrayLike, look_for_type: tp.Optional[type] = None, **kwargs) -> tp.ArrayLike:
"""Map enumerated value(s)."""
if look_for_type is not None:
if isinstance(value, look_for_type):
return map_enum_fields(value, **kwargs)
return value
if isinstance(value, (CustomTemplate, Ref)):
return value
if isinstance(value, (Param, BCO, Default)):
attr_dct = value.asdict()
if isinstance(value, Param) and attr_dct["map_template"] is None:
attr_dct["map_template"] = RepFunc(lambda values: cls.map_enum_value(values, **kwargs))
elif not isinstance(value, Param):
attr_dct["value"] = cls.map_enum_value(attr_dct["value"], **kwargs)
return type(value)(**attr_dct)
if isinstance(value, index_dict):
return index_dict({k: cls.map_enum_value(v, **kwargs) for k, v in value.items()})
if isinstance(value, IdxSetterFactory):
value = value.get()
if not isinstance(value, IdxSetter):
raise ValueError("Index setter factory must return exactly one index setter")
if isinstance(value, IdxSetter):
return IdxSetter([(k, cls.map_enum_value(v, **kwargs)) for k, v in value.idx_items])
return map_enum_fields(value, **kwargs)
@classmethod
def prepare_td_obj(cls, td_obj: object, old_as_keys: bool = True) -> object:
"""Prepare a timedelta object for broadcasting."""
if isinstance(td_obj, Param):
return td_obj.map_value(cls.prepare_td_obj, old_as_keys=old_as_keys)
if isinstance(td_obj, (str, timedelta, pd.DateOffset, pd.Timedelta)):
td_obj = dt.to_timedelta64(td_obj)
elif isinstance(td_obj, pd.Index):
td_obj = td_obj.values
return td_obj
@classmethod
def prepare_dt_obj(
cls,
dt_obj: object,
old_as_keys: bool = True,
last_before: tp.Optional[bool] = None,
) -> object:
"""Prepare a datetime object for broadcasting."""
if isinstance(dt_obj, Param):
return dt_obj.map_value(cls.prepare_dt_obj, old_as_keys=old_as_keys)
if isinstance(dt_obj, (str, time, timedelta, pd.DateOffset, pd.Timedelta)):
def _apply_last_before(source_index, target_index, source_freq):
resampler = Resampler(source_index, target_index, source_freq=source_freq)
last_indices = resampler.last_before_target_index(incl_source=False)
source_rbound_ns = resampler.source_rbound_index.vbt.to_ns()
return np.where(last_indices != -1, source_rbound_ns[last_indices], -1)
def _to_dt(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = False
_dt_obj = dt.try_align_dt_to_index(_dt_obj, wrapper.index)
source_index = wrapper.index[wrapper.index < _dt_obj]
target_index = repeat_index(pd.Index([_dt_obj]), len(source_index))
if _last_before:
target_ns = _apply_last_before(source_index, target_index, wrapper.freq)
else:
target_ns = target_index.vbt.to_ns()
if len(target_ns) < len(wrapper.index):
target_ns = concat_arrays((target_ns, np.full(len(wrapper.index) - len(target_ns), -1)))
return target_ns
def _to_td(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = True
target_index = wrapper.index.vbt.to_period_ts(dt.to_freq(_dt_obj), shift=True)
if _last_before:
return _apply_last_before(wrapper.index, target_index, wrapper.freq)
return target_index.vbt.to_ns()
def _to_time(wrapper, _dt_obj=dt_obj, _last_before=last_before):
if _last_before is None:
_last_before = False
floor_index = wrapper.index.floor("1d") + dt.time_to_timedelta(_dt_obj)
target_index = floor_index.where(wrapper.index < floor_index, floor_index + pd.Timedelta(days=1))
if _last_before:
return _apply_last_before(wrapper.index, target_index, wrapper.freq)
return target_index.vbt.to_ns()
dt_obj_dt_template = RepFunc(_to_dt)
dt_obj_td_template = RepFunc(_to_td)
dt_obj_time_template = RepFunc(_to_time)
if isinstance(dt_obj, str):
try:
time.fromisoformat(dt_obj)
dt_obj = dt_obj_time_template
except Exception:
try:
dt.to_freq(dt_obj)
dt_obj = dt_obj_td_template
except Exception:
dt_obj = dt_obj_dt_template
elif isinstance(dt_obj, time):
dt_obj = dt_obj_time_template
else:
dt_obj = dt_obj_td_template
elif isinstance(dt_obj, pd.Index):
dt_obj = dt_obj.values
return dt_obj
def get_raw_arg_default(self, arg_name: str, is_dict: bool = False) -> tp.Any:
"""Get raw argument default."""
if self._settings_path is None:
if is_dict:
return {}
return None
value = self.get_setting(arg_name)
if is_dict and value is None:
return {}
return value
def get_raw_arg(self, arg_name: str, is_dict: bool = False, has_default: bool = True) -> tp.Any:
"""Get raw argument."""
value = self.config.get(arg_name, None)
if is_dict:
if has_default:
return merge_dicts(self.get_raw_arg_default(arg_name), value)
if value is None:
return {}
return value
if value is None and has_default:
return self.get_raw_arg_default(arg_name)
return value
@cachedproperty
def idx_setters(self) -> tp.Optional[tp.Dict[tp.Label, IdxSetter]]:
"""Index setters from resolving the argument `records`."""
arg_config = self.arg_config["records"]
records = self.get_raw_arg(
"records",
is_dict=arg_config.get("is_dict", False),
has_default=arg_config.get("has_default", True),
)
if records is None:
return None
if not isinstance(records, IdxRecords):
records = IdxRecords(records)
idx_setters = records.get()
for k in idx_setters:
if k in self.arg_config and not self.arg_config[k].get("broadcast", False):
raise ValueError(f"Field {k} is not broadcastable and cannot be included in records")
rename_fields = arg_config.get("rename_fields", {})
new_idx_setters = {}
for k, v in idx_setters.items():
if k in rename_fields:
k = rename_fields[k]
new_idx_setters[k] = v
return new_idx_setters
def get_arg_default(self, arg_name: str) -> tp.Any:
"""Get argument default according to the argument config."""
arg_config = self.arg_config[arg_name]
arg = self.get_raw_arg_default(
arg_name,
is_dict=arg_config.get("is_dict", False),
)
if arg is not None:
if len(arg_config.get("map_enum_kwargs", {})) > 0:
arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.prepare_td_obj(
arg,
old_as_keys=arg_config.get("old_as_keys", True),
)
if arg_config.get("is_dt", False):
arg = self.prepare_dt_obj(
arg,
old_as_keys=arg_config.get("old_as_keys", True),
last_before=arg_config.get("last_before", None),
)
return arg
def get_arg(self, arg_name: str, use_idx_setter: bool = True, use_default: bool = True) -> tp.Any:
"""Get mapped argument according to the argument config."""
arg_config = self.arg_config[arg_name]
if use_idx_setter and self.idx_setters is not None and arg_name in self.idx_setters:
arg = self.idx_setters[arg_name]
else:
arg = self.get_raw_arg(
arg_name,
is_dict=arg_config.get("is_dict", False),
has_default=arg_config.get("has_default", True) if use_default else False,
)
if arg is not None:
if len(arg_config.get("map_enum_kwargs", {})) > 0:
arg = self.map_enum_value(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.prepare_td_obj(arg)
if arg_config.get("is_dt", False):
arg = self.prepare_dt_obj(arg, last_before=arg_config.get("last_before", None))
return arg
def __getitem__(self, arg_name) -> tp.Any:
return self.get_arg(arg_name)
def __iter__(self):
raise TypeError(f"'{type(self).__name__}' object is not iterable")
@classmethod
def prepare_td_arr(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a timedelta array."""
if td_arr.dtype == object:
if td_arr.ndim in (0, 1):
td_arr = pd.to_timedelta(td_arr)
if isinstance(td_arr, pd.Timedelta):
td_arr = td_arr.to_timedelta64()
else:
td_arr = td_arr.values
else:
td_arr_cols = []
for col in range(td_arr.shape[1]):
td_arr_col = pd.to_timedelta(td_arr[:, col])
td_arr_cols.append(td_arr_col.values)
td_arr = column_stack_arrays(td_arr_cols)
return td_arr
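# Illustrative sketch (assumption, not original code): object-dtype inputs are routed
# through pandas so that string timedeltas become timedelta64 values, column by column
# for 2-dim arrays.
# >>> BasePreparer.prepare_td_arr(np.array(["1 days", "2 days"], dtype=object))
# # -> NumPy array with dtype 'timedelta64[ns]'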
@classmethod
def prepare_dt_arr(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a datetime array."""
if dt_arr.dtype == object:
if dt_arr.ndim in (0, 1):
dt_arr = pd.to_datetime(dt_arr).tz_localize(None)
if isinstance(dt_arr, pd.Timestamp):
dt_arr = dt_arr.to_datetime64()
else:
dt_arr = dt_arr.values
else:
dt_arr_cols = []
for col in range(dt_arr.shape[1]):
dt_arr_col = pd.to_datetime(dt_arr[:, col]).tz_localize(None)
dt_arr_cols.append(dt_arr_col.values)
dt_arr = column_stack_arrays(dt_arr_cols)
return dt_arr
@classmethod
def td_arr_to_ns(cls, td_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a timedelta array and convert it to nanoseconds."""
return dt.to_ns(cls.prepare_td_arr(td_arr))
@classmethod
def dt_arr_to_ns(cls, dt_arr: tp.ArrayLike) -> tp.ArrayLike:
"""Prepare a datetime array and convert it to nanoseconds."""
return dt.to_ns(cls.prepare_dt_arr(dt_arr))
def prepare_post_arg(self, arg_name: str, value: tp.Optional[tp.ArrayLike] = None) -> object:
"""Prepare an argument after broadcasting and/or template substitution."""
if value is None:
if arg_name in self.post_args:
arg = self.post_args[arg_name]
else:
arg = getattr(self, "_pre_" + arg_name)
else:
arg = value
if arg is not None:
arg_config = self.arg_config[arg_name]
if arg_config.get("substitute_templates", False):
arg = substitute_templates(arg, self.template_context, eval_id=arg_name)
if "map_enum_kwargs" in arg_config:
arg = map_enum_fields(arg, **arg_config["map_enum_kwargs"])
if arg_config.get("is_td", False):
arg = self.td_arr_to_ns(arg)
if arg_config.get("is_dt", False):
arg = self.dt_arr_to_ns(arg)
if "type" in arg_config:
checks.assert_instance_of(arg, arg_config["type"], arg_name=arg_name)
if "subdtype" in arg_config:
checks.assert_subdtype(arg, arg_config["subdtype"], arg_name=arg_name)
return arg
@classmethod
def adapt_staticized_to_udf(cls, staticized: tp.Kwargs, func: tp.Union[str, tp.Callable], func_name: str) -> None:
"""Adapt `staticized` dictionary to a UDF."""
target_func_module = inspect.getmodule(staticized["func"])
if isinstance(func, tuple):
func, actual_func_name = func
else:
actual_func_name = None
if isinstance(func, (str, Path)):
if actual_func_name is None:
actual_func_name = func_name
if isinstance(func, str) and not func.endswith(".py") and hasattr(target_func_module, func):
staticized[f"{func_name}_block"] = func
return None
func = Path(func)
module_path = func.resolve()
else:
if actual_func_name is None:
actual_func_name = func.__name__
if inspect.getmodule(func) == target_func_module:
staticized[f"{func_name}_block"] = actual_func_name
return None
module = inspect.getmodule(func)
if not hasattr(module, "__file__"):
raise TypeError(f"{func_name} must be defined in a Python file")
module_path = Path(module.__file__).resolve()
if "import_lines" not in staticized:
staticized["import_lines"] = []
reload = staticized.get("reload", False)
staticized["import_lines"].extend(
[
f'{func_name}_path = r"{module_path}"',
f"globals().update(vbt.import_module_from_path({func_name}_path).__dict__, reload={reload})",
]
)
if actual_func_name != func_name:
staticized["import_lines"].append(f"{func_name} = {actual_func_name}")
@classmethod
def find_target_func(cls, target_func_name: str) -> tp.Callable:
"""Find target function by its name."""
raise NotImplementedError
@classmethod
def resolve_dynamic_target_func(cls, target_func_name: str, staticized: tp.KwargsLike) -> tp.Callable:
"""Resolve a dynamic target function."""
if staticized is None:
func = cls.find_target_func(target_func_name)
else:
if isinstance(staticized, dict):
staticized = dict(staticized)
module_path = suggest_module_path(
staticized.get("suggest_fname", target_func_name),
path=staticized.pop("path", None),
mkdir_kwargs=staticized.get("mkdir_kwargs", None),
)
if "new_func_name" not in staticized:
staticized["new_func_name"] = target_func_name
if staticized.pop("override", False) or not module_path.exists():
if "skip_func" not in staticized:
def _skip_func(out_lines, func_name):
to_skip = lambda x: f"def {func_name}" in x or x.startswith(f"{func_name}_path =")
return any(map(to_skip, out_lines))
staticized["skip_func"] = _skip_func
module_path = cut_and_save_func(path=module_path, **staticized)
if staticized.get("clear_cache", True):
remove_dir(module_path.parent / "__pycache__", with_contents=True, missing_ok=True)
reload = staticized.pop("reload", False)
module = import_module_from_path(module_path, reload=reload)
func = getattr(module, staticized["new_func_name"])
else:
func = staticized
return func
def set_seed(self) -> None:
"""Set seed."""
seed = self.seed
if seed is not None:
set_seed(seed)
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_template_context(self) -> tp.Kwargs:
"""Argument `template_context` before broadcasting."""
return merge_dicts(dict(preparer=self), self["template_context"])
# ############# Broadcasting ############# #
@cachedproperty
def pre_args(self) -> tp.Kwargs:
"""Arguments before broadcasting."""
pre_args = dict()
for k, v in self.arg_config.items():
if v.get("broadcast", False):
pre_args[k] = getattr(self, "_pre_" + k)
return pre_args
@cachedproperty
def args_to_broadcast(self) -> dict:
"""Arguments to broadcast."""
return merge_dicts(self.idx_setters, self.pre_args, self.broadcast_named_args)
@cachedproperty
def def_broadcast_kwargs(self) -> tp.Kwargs:
"""Default keyword arguments for broadcasting."""
return dict(
to_pd=False,
keep_flex=dict(cash_earnings=self.keep_inout_flex, _def=True),
wrapper_kwargs=dict(
freq=self._pre_freq,
group_by=self.group_by,
),
return_wrapper=True,
template_context=self._pre_template_context,
)
@cachedproperty
def broadcast_kwargs(self) -> tp.Kwargs:
"""Argument `broadcast_kwargs`."""
arg_broadcast_kwargs = defaultdict(dict)
for k, v in self.arg_config.items():
if v.get("broadcast", False):
broadcast_kwargs = v.get("broadcast_kwargs", None)
if broadcast_kwargs is None:
broadcast_kwargs = {}
for k2, v2 in broadcast_kwargs.items():
arg_broadcast_kwargs[k2][k] = v2
for k in self.args_to_broadcast:
new_fill_value = None
if k in self.pre_args:
fill_default = self.arg_config[k].get("fill_default", True)
if self.idx_setters is not None and k in self.idx_setters:
new_fill_value = self.get_arg(k, use_idx_setter=False, use_default=fill_default)
elif fill_default and self.arg_config[k].get("has_default", True):
new_fill_value = self.get_arg_default(k)
elif k in self.broadcast_named_args:
if self.idx_setters is not None and k in self.idx_setters:
new_fill_value = self.broadcast_named_args[k]
if new_fill_value is not None:
if not np.isscalar(new_fill_value):
raise TypeError(f"Argument '{k}' (and its default) must be a scalar when also provided via records")
if "reindex_kwargs" not in arg_broadcast_kwargs:
arg_broadcast_kwargs["reindex_kwargs"] = {}
if k not in arg_broadcast_kwargs["reindex_kwargs"]:
arg_broadcast_kwargs["reindex_kwargs"][k] = {}
arg_broadcast_kwargs["reindex_kwargs"][k]["fill_value"] = new_fill_value
return merge_dicts(
self.def_broadcast_kwargs,
dict(arg_broadcast_kwargs),
self["broadcast_kwargs"],
)
@cachedproperty
def broadcast_result(self) -> tp.Any:
"""Result of broadcasting."""
return broadcast(self.args_to_broadcast, **self.broadcast_kwargs)
@cachedproperty
def post_args(self) -> tp.Kwargs:
"""Arguments after broadcasting."""
return self.broadcast_result[0]
@cachedproperty
def post_broadcast_named_args(self) -> tp.Kwargs:
"""Custom arguments after broadcasting."""
if self.broadcast_named_args is None:
return dict()
post_broadcast_named_args = dict()
for k, v in self.post_args.items():
if k in self.broadcast_named_args:
post_broadcast_named_args[k] = v
elif self.idx_setters is not None and k in self.idx_setters and k not in self.pre_args:
post_broadcast_named_args[k] = v
return post_broadcast_named_args
@cachedproperty
def wrapper(self) -> ArrayWrapper:
"""Array wrapper."""
return self.broadcast_result[1]
@cachedproperty
def target_shape(self) -> tp.Shape:
"""Target shape."""
return self.wrapper.shape_2d
@cachedproperty
def index(self) -> tp.Array1d:
"""Index in nanosecond format."""
return self.wrapper.ns_index
@cachedproperty
def freq(self) -> int:
"""Frequency in nanosecond format."""
return self.wrapper.ns_freq
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
"""Argument `template_context`."""
builtin_args = {}
for k, v in self.arg_config.items():
if v.get("broadcast", False):
builtin_args[k] = getattr(self, k)
return merge_dicts(
dict(
wrapper=self.wrapper,
target_shape=self.target_shape,
index=self.index,
freq=self.freq,
),
builtin_args,
self.post_broadcast_named_args,
self._pre_template_context,
)
# ############# Result ############# #
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
"""Target function."""
return None
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
"""Map of the target arguments to the preparer attributes."""
return dict()
@cachedproperty
def target_args(self) -> tp.Optional[tp.Kwargs]:
"""Arguments to be passed to the target function."""
if self.target_func is not None:
target_arg_map = self.target_arg_map
func_arg_names = get_func_arg_names(self.target_func)
target_args = {}
for k in func_arg_names:
arg_attr = target_arg_map.get(k, k)
if arg_attr is not None and hasattr(self, arg_attr):
target_args[k] = getattr(self, arg_attr)
return target_args
return None
# ############# Docs ############# #
@classmethod
def build_arg_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build argument config documentation."""
if source_cls is None:
source_cls = BasePreparer
return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "arg_config").__doc__)).substitute(
{"arg_config": cls.arg_config.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_arg_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `BasePreparer.arg_config`."""
__pdoc__[cls.__name__ + ".arg_config"] = cls.build_arg_config_doc(source_cls=source_cls)
BasePreparer.override_arg_config_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for reshaping arrays.
Reshape functions transform a Pandas object/NumPy array in some way."""
import functools
import itertools
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base import indexes, wrapping, indexing
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.config import resolve_dict, merge_dicts
from vectorbtpro.utils.params import combine_params, Param
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"to_1d_shape",
"to_2d_shape",
"repeat_shape",
"tile_shape",
"to_1d_array",
"to_2d_array",
"to_2d_pr_array",
"to_2d_pc_array",
"to_1d_array_nb",
"to_2d_array_nb",
"to_2d_pr_array_nb",
"to_2d_pc_array_nb",
"broadcast_shapes",
"broadcast_array_to",
"broadcast_arrays",
"repeat",
"tile",
"align_pd_arrays",
"BCO",
"Default",
"Ref",
"broadcast",
"broadcast_to",
]
def to_tuple_shape(shape: tp.ShapeLike) -> tp.Shape:
"""Convert a shape-like object to a tuple."""
if checks.is_int(shape):
return (int(shape),)
return tuple(shape)
def to_1d_shape(shape: tp.ShapeLike) -> tp.Shape:
"""Convert a shape-like object to a 1-dim shape."""
shape = to_tuple_shape(shape)
if len(shape) == 0:
return (1,)
if len(shape) == 1:
return shape
if len(shape) == 2 and shape[1] == 1:
return (shape[0],)
raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 1 dimension")
def to_2d_shape(shape: tp.ShapeLike, expand_axis: int = 1) -> tp.Shape:
"""Convert a shape-like object to a 2-dim shape."""
shape = to_tuple_shape(shape)
if len(shape) == 0:
return 1, 1
if len(shape) == 1:
if expand_axis == 1:
return shape[0], 1
else:
return 1, shape[0]
if len(shape) == 2:
return shape
raise ValueError(f"Cannot reshape a {len(shape)}-dimensional shape to 2 dimensions")
def repeat_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape:
"""Repeat shape `n` times along the specified axis."""
shape = to_tuple_shape(shape)
if len(shape) <= axis:
shape = tuple([shape[i] if i < len(shape) else 1 for i in range(axis + 1)])
return *shape[:axis], shape[axis] * n, *shape[axis + 1 :]
def tile_shape(shape: tp.ShapeLike, n: int, axis: int = 1) -> tp.Shape:
"""Tile shape `n` times along the specified axis.
Identical to `repeat_shape`. Exists purely for naming consistency."""
return repeat_shape(shape, n, axis=axis)
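# Sketch (illustrative): repeating along axis 1 first pads a 1-dim shape to two dimensions.
# >>> repeat_shape((3,), 2, axis=0)
# (6,)
# >>> repeat_shape((3,), 2, axis=1)
# (3, 2)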
def index_to_series(obj: tp.Index, reset_index: bool = False) -> tp.Series:
"""Convert Index to Series."""
if reset_index:
return obj.to_series(index=pd.RangeIndex(stop=len(obj)))
return obj.to_series()
def index_to_frame(obj: tp.Index, reset_index: bool = False) -> tp.Frame:
"""Convert Index to DataFrame."""
if not isinstance(obj, pd.MultiIndex):
return index_to_series(obj, reset_index=reset_index).to_frame()
return obj.to_frame(index=not reset_index)
def mapping_to_series(obj: tp.MappingLike) -> tp.Series:
"""Convert a mapping-like object to Series."""
if checks.is_namedtuple(obj):
obj = obj._asdict()
return pd.Series(obj)
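# Sketch (illustrative): plain dicts and named tuples both become a Series keyed by field.
# >>> from collections import namedtuple
# >>> Point = namedtuple("Point", ["x", "y"])
# >>> mapping_to_series(Point(1, 2))
# x    1
# y    2
# dtype: int64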
def to_any_array(obj: tp.ArrayLike, raw: bool = False, convert_index: bool = True) -> tp.AnyArray:
"""Convert any array-like object to an array.
Pandas objects are kept as-is unless `raw` is True."""
from vectorbtpro.indicators.factory import IndicatorBase
if isinstance(obj, IndicatorBase):
obj = obj.main_output
if not raw:
if checks.is_any_array(obj):
if convert_index and checks.is_index(obj):
return index_to_series(obj)
return obj
if checks.is_mapping_like(obj):
return mapping_to_series(obj)
return np.asarray(obj)
def to_pd_array(obj: tp.ArrayLike, convert_index: bool = True) -> tp.PandasArray:
"""Convert any array-like object to a Pandas object."""
from vectorbtpro.indicators.factory import IndicatorBase
if isinstance(obj, IndicatorBase):
obj = obj.main_output
if checks.is_pandas(obj):
if convert_index and checks.is_index(obj):
return index_to_series(obj)
return obj
if checks.is_mapping_like(obj):
return mapping_to_series(obj)
obj = np.asarray(obj)
if obj.ndim == 0:
obj = obj[None]
if obj.ndim == 1:
return pd.Series(obj)
if obj.ndim == 2:
return pd.DataFrame(obj)
raise ValueError("Wrong number of dimensions: cannot convert to Series or DataFrame")
def soft_to_ndim(obj: tp.ArrayLike, ndim: int, raw: bool = False) -> tp.AnyArray:
"""Try to softly bring `obj` to the specified number of dimensions `ndim` (max 2)."""
obj = to_any_array(obj, raw=raw)
if ndim == 1:
if obj.ndim == 2:
if obj.shape[1] == 1:
if checks.is_frame(obj):
return obj.iloc[:, 0]
return obj[:, 0] # downgrade
if ndim == 2:
if obj.ndim == 1:
if checks.is_series(obj):
return obj.to_frame()
return obj[:, None] # upgrade
return obj # do nothing
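# Sketch (illustrative): "soft" means the conversion happens only when it is lossless.
# >>> soft_to_ndim(pd.DataFrame({"a": [1, 2]}), 1)               # one column -> Series
# >>> soft_to_ndim(pd.DataFrame({"a": [1, 2], "b": [3, 4]}), 1)  # left as a DataFrame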
def to_1d(obj: tp.ArrayLike, raw: bool = False) -> tp.AnyArray1d:
"""Reshape argument to one dimension.
If `raw` is True, returns NumPy array.
If 2-dim, will collapse along axis 1 (i.e., DataFrame with one column to Series)."""
obj = to_any_array(obj, raw=raw)
if obj.ndim == 2:
if obj.shape[1] == 1:
if checks.is_frame(obj):
return obj.iloc[:, 0]
return obj[:, 0]
if obj.ndim == 1:
return obj
elif obj.ndim == 0:
return obj.reshape((1,))
raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 1 dimension")
to_1d_array = functools.partial(to_1d, raw=True)
"""`to_1d` with `raw` enabled."""
def to_2d(obj: tp.ArrayLike, raw: bool = False, expand_axis: int = 1) -> tp.AnyArray2d:
"""Reshape argument to two dimensions.
If `raw` is True, returns NumPy array.
If 1-dim, will expand along axis 1 (i.e., Series to DataFrame with one column)."""
obj = to_any_array(obj, raw=raw)
if obj.ndim == 2:
return obj
elif obj.ndim == 1:
if checks.is_series(obj):
if expand_axis == 0:
return pd.DataFrame(obj.values[None, :], columns=obj.index)
elif expand_axis == 1:
return obj.to_frame()
return np.expand_dims(obj, expand_axis)
elif obj.ndim == 0:
return obj.reshape((1, 1))
raise ValueError(f"Cannot reshape a {obj.ndim}-dimensional array to 2 dimensions")
to_2d_array = functools.partial(to_2d, raw=True)
"""`to_2d` with `raw` enabled."""
to_2d_pr_array = functools.partial(to_2d_array, expand_axis=1)
"""`to_2d_array` with `expand_axis=1`."""
to_2d_pc_array = functools.partial(to_2d_array, expand_axis=0)
"""`to_2d_array` with `expand_axis=0`."""
@register_jitted(cache=True)
def to_1d_array_nb(obj: tp.Array) -> tp.Array1d:
"""Resize array to one dimension."""
if obj.ndim == 0:
return np.expand_dims(obj, axis=0)
if obj.ndim == 1:
return obj
if obj.ndim == 2 and obj.shape[1] == 1:
return obj[:, 0]
raise ValueError("Array cannot be resized to one dimension")
@register_jitted(cache=True)
def to_2d_array_nb(obj: tp.Array, expand_axis: int = 1) -> tp.Array2d:
"""Resize array to two dimensions."""
if obj.ndim == 0:
return np.expand_dims(np.expand_dims(obj, axis=0), axis=0)
if obj.ndim == 1:
return np.expand_dims(obj, axis=expand_axis)
if obj.ndim == 2:
return obj
raise ValueError("Array cannot be resized to two dimensions")
@register_jitted(cache=True)
def to_2d_pr_array_nb(obj: tp.Array) -> tp.Array2d:
"""`to_2d_array_nb` with `expand_axis=1`."""
return to_2d_array_nb(obj, expand_axis=1)
@register_jitted(cache=True)
def to_2d_pc_array_nb(obj: tp.Array) -> tp.Array2d:
"""`to_2d_array_nb` with `expand_axis=0`."""
return to_2d_array_nb(obj, expand_axis=0)
def to_dict(obj: tp.ArrayLike, orient: str = "dict") -> dict:
"""Convert object to dict."""
obj = to_pd_array(obj)
if orient == "index_series":
return {obj.index[i]: obj.iloc[i] for i in range(len(obj.index))}
return obj.to_dict(orient)
def repeat(
obj: tp.ArrayLike,
n: int,
axis: int = 1,
raw: bool = False,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""Repeat `obj` `n` times along the specified axis."""
obj = to_any_array(obj, raw=raw)
if axis == 0:
if checks.is_pandas(obj):
new_index = indexes.repeat_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=0), index=new_index)
return np.repeat(obj, n, axis=0)
elif axis == 1:
obj = to_2d(obj)
if checks.is_pandas(obj):
new_columns = indexes.repeat_index(obj.columns, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.repeat(obj.values, n, axis=1), columns=new_columns)
return np.repeat(obj, n, axis=1)
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
def tile(
obj: tp.ArrayLike,
n: int,
axis: int = 1,
raw: bool = False,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""Tile `obj` `n` times along the specified axis."""
obj = to_any_array(obj, raw=raw)
if axis == 0:
if obj.ndim == 2:
if checks.is_pandas(obj):
new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (n, 1)), index=new_index)
return np.tile(obj, (n, 1))
if checks.is_pandas(obj):
new_index = indexes.tile_index(obj.index, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, n), index=new_index)
return np.tile(obj, n)
elif axis == 1:
obj = to_2d(obj)
if checks.is_pandas(obj):
new_columns = indexes.tile_index(obj.columns, n, ignore_ranges=ignore_ranges)
return wrapping.ArrayWrapper.from_obj(obj).wrap(np.tile(obj.values, (1, n)), columns=new_columns)
return np.tile(obj, (1, n))
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
def broadcast_shapes(
*shapes: tp.ArrayLike,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Shape, ...]:
"""Broadcast shape-like objects using vectorbt's broadcasting rules."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if expand_axis is None:
expand_axis = broadcasting_cfg["expand_axis"]
is_2d = False
for i, shape in enumerate(shapes):
shape = to_tuple_shape(shape)
if len(shape) == 2:
is_2d = True
break
new_shapes = []
for i, shape in enumerate(shapes):
shape = to_tuple_shape(shape)
if is_2d:
if checks.is_sequence(expand_axis):
_expand_axis = expand_axis[i]
else:
_expand_axis = expand_axis
new_shape = to_2d_shape(shape, expand_axis=_expand_axis)
else:
new_shape = to_1d_shape(shape)
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
if _axis is not None:
if _axis == 0:
if is_2d:
new_shape = (new_shape[0], 1)
else:
new_shape = (new_shape[0],)
elif _axis == 1:
if is_2d:
new_shape = (1, new_shape[1])
else:
new_shape = (1,)
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {_axis}")
new_shapes.append(new_shape)
return tuple(np.broadcast_shapes(*new_shapes))
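# Sketch (illustrative): once any participating shape is 2-dim, 1-dim shapes are treated
# as columns, so they broadcast against the row axis rather than the columns.
# >>> broadcast_shapes((3,), (3, 4))
# (3, 4)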
def broadcast_array_to(
arr: tp.ArrayLike,
target_shape: tp.ShapeLike,
axis: tp.Optional[int] = None,
expand_axis: tp.Optional[int] = None,
) -> tp.Array:
"""Broadcast an array-like object to a target shape using vectorbt's broadcasting rules."""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if expand_axis is None:
expand_axis = broadcasting_cfg["expand_axis"]
arr = np.asarray(arr)
target_shape = to_tuple_shape(target_shape)
if len(target_shape) not in (1, 2):
raise ValueError(f"Target shape must have either 1 or 2 dimensions, not {len(target_shape)}")
if len(target_shape) == 2:
new_arr = to_2d_array(arr, expand_axis=expand_axis)
else:
new_arr = to_1d_array(arr)
if axis is not None:
if axis == 0:
if len(target_shape) == 2:
target_shape = (target_shape[0], new_arr.shape[1])
else:
target_shape = (target_shape[0],)
elif axis == 1:
target_shape = (new_arr.shape[0], target_shape[1])
else:
raise ValueError(f"Only axes 0 and 1 are supported, not {axis}")
return np.broadcast_to(new_arr, target_shape)
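# Sketch (illustrative): the 1-dim input becomes a column and is then stretched across
# the requested number of columns.
# >>> broadcast_array_to(np.array([1, 2, 3]), (3, 2))
# array([[1, 1],
#        [2, 2],
#        [3, 3]])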
def broadcast_arrays(
*arrs: tp.ArrayLike,
target_shape: tp.Optional[tp.ShapeLike] = None,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
expand_axis: tp.Optional[tp.MaybeSequence[int]] = None,
) -> tp.Tuple[tp.Array, ...]:
"""Broadcast array-like objects using vectorbt's broadcasting rules.
Optionally to a target shape."""
if target_shape is None:
shapes = []
for arr in arrs:
shapes.append(np.asarray(arr).shape)
target_shape = broadcast_shapes(*shapes, axis=axis, expand_axis=expand_axis)
new_arrs = []
for i, arr in enumerate(arrs):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if expand_axis is not None:
if checks.is_sequence(expand_axis):
_expand_axis = expand_axis[i]
else:
_expand_axis = expand_axis
else:
_expand_axis = None
new_arr = broadcast_array_to(arr, target_shape, axis=_axis, expand_axis=_expand_axis)
new_arrs.append(new_arr)
return tuple(new_arrs)
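# Sketch (illustrative): scalars are expanded to the common shape as well.
# >>> a, b = broadcast_arrays(0, np.array([[1, 2], [3, 4]]))
# >>> a
# array([[0, 0],
#        [0, 0]])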
IndexFromLike = tp.Union[None, str, int, tp.Any]
"""Any object that can be coerced into a `index_from` argument."""
def broadcast_index(
objs: tp.Sequence[tp.AnyArray],
to_shape: tp.Shape,
index_from: IndexFromLike = None,
axis: int = 0,
ignore_sr_names: tp.Optional[bool] = None,
ignore_ranges: tp.Optional[bool] = None,
check_index_names: tp.Optional[bool] = None,
**clean_index_kwargs,
) -> tp.Optional[tp.Index]:
"""Produce a broadcast index/columns.
Args:
objs (iterable of array_like): Array-like objects.
to_shape (tuple of int): Target shape.
index_from (any): Broadcasting rule for this index/these columns.
Accepts the following values:
* 'keep' or None - keep the original index/columns of the objects in `objs`
* 'stack' - stack different indexes/columns using `vectorbtpro.base.indexes.stack_indexes`
* 'strict' - ensure that all Pandas objects have the same index/columns
* 'reset' - reset any index/columns (they become a simple range)
* integer - use the index/columns of the i-th object in `objs`
* everything else will be converted to `pd.Index`
axis (int): Set to 0 for index and 1 for columns.
ignore_sr_names (bool): Whether to ignore Series names if they are in conflict.
Conflicting Series names are those that are different but not None.
ignore_ranges (bool): Whether to ignore indexes of type `pd.RangeIndex`.
check_index_names (bool): See `vectorbtpro.utils.checks.is_index_equal`.
**clean_index_kwargs: Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`.
For defaults, see `vectorbtpro._settings.broadcasting`.
!!! note
Series names are treated as columns with a single element but without a name.
If a nameless column level would lose its meaning, convert the Series to a one-column
DataFrame before broadcasting. If the Series name is not important, drop it
altogether by setting it to None.
"""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if ignore_sr_names is None:
ignore_sr_names = broadcasting_cfg["ignore_sr_names"]
if check_index_names is None:
check_index_names = broadcasting_cfg["check_index_names"]
index_str = "columns" if axis == 1 else "index"
to_shape_2d = (to_shape[0], 1) if len(to_shape) == 1 else to_shape
maxlen = to_shape_2d[1] if axis == 1 else to_shape_2d[0]
new_index = None
objs = list(objs)
if index_from is None or (isinstance(index_from, str) and index_from.lower() == "keep"):
return None
if isinstance(index_from, int):
if not checks.is_pandas(objs[index_from]):
raise TypeError(f"Argument under index {index_from} must be a pandas object")
new_index = indexes.get_index(objs[index_from], axis)
elif isinstance(index_from, str):
if index_from.lower() == "reset":
new_index = pd.RangeIndex(start=0, stop=maxlen, step=1)
elif index_from.lower() in ("stack", "strict"):
last_index = None
index_conflict = False
for obj in objs:
if checks.is_pandas(obj):
index = indexes.get_index(obj, axis)
if last_index is not None:
if not checks.is_index_equal(index, last_index, check_names=check_index_names):
index_conflict = True
last_index = index
continue
if not index_conflict:
new_index = last_index
else:
for obj in objs:
if checks.is_pandas(obj):
index = indexes.get_index(obj, axis)
if axis == 1 and checks.is_series(obj) and ignore_sr_names:
continue
if checks.is_default_index(index):
continue
if new_index is None:
new_index = index
else:
if checks.is_index_equal(index, new_index, check_names=check_index_names):
continue
if index_from.lower() == "strict":
raise ValueError(
f"Arrays have different index. Broadcasting {index_str} "
f"is not allowed when {index_str}_from=strict"
)
if len(index) != len(new_index):
if len(index) > 1 and len(new_index) > 1:
raise ValueError("Indexes could not be broadcast together")
if len(index) > len(new_index):
new_index = indexes.repeat_index(new_index, len(index), ignore_ranges=ignore_ranges)
elif len(index) < len(new_index):
index = indexes.repeat_index(index, len(new_index), ignore_ranges=ignore_ranges)
new_index = indexes.stack_indexes([new_index, index], **clean_index_kwargs)
else:
raise ValueError(f"Invalid value '{index_from}' for {'columns' if axis == 1 else 'index'}_from")
else:
if not isinstance(index_from, pd.Index):
index_from = pd.Index(index_from)
new_index = index_from
if new_index is not None:
if maxlen > len(new_index):
if isinstance(index_from, str) and index_from.lower() == "strict":
raise ValueError(f"Broadcasting {index_str} is not allowed when {index_str}_from=strict")
if maxlen > 1 and len(new_index) > 1:
raise ValueError("Indexes could not be broadcast together")
new_index = indexes.repeat_index(new_index, maxlen, ignore_ranges=ignore_ranges)
else:
new_index = pd.RangeIndex(start=0, stop=maxlen, step=1)
return new_index
def wrap_broadcasted(
new_obj: tp.Array,
old_obj: tp.Optional[tp.AnyArray] = None,
axis: tp.Optional[int] = None,
is_pd: bool = False,
new_index: tp.Optional[tp.Index] = None,
new_columns: tp.Optional[tp.Index] = None,
ignore_ranges: tp.Optional[bool] = None,
) -> tp.AnyArray:
"""If the newly brodcasted array was originally a Pandas object, make it Pandas object again
and assign it the newly broadcast index/columns."""
if is_pd:
if axis == 0:
new_columns = None
elif axis == 1:
new_index = None
if old_obj is not None and checks.is_pandas(old_obj):
if new_index is None:
old_index = indexes.get_index(old_obj, 0)
if old_obj.shape[0] == new_obj.shape[0]:
new_index = old_index
else:
new_index = indexes.repeat_index(old_index, new_obj.shape[0], ignore_ranges=ignore_ranges)
if new_columns is None:
old_columns = indexes.get_index(old_obj, 1)
new_ncols = new_obj.shape[1] if new_obj.ndim == 2 else 1
if len(old_columns) == new_ncols:
new_columns = old_columns
else:
new_columns = indexes.repeat_index(old_columns, new_ncols, ignore_ranges=ignore_ranges)
if new_obj.ndim == 2:
return pd.DataFrame(new_obj, index=new_index, columns=new_columns)
if new_columns is not None and len(new_columns) == 1:
name = new_columns[0]
if name == 0:
name = None
else:
name = None
return pd.Series(new_obj, index=new_index, name=name)
return new_obj
def align_pd_arrays(
*objs: tp.AnyArray,
align_index: bool = True,
align_columns: bool = True,
to_index: tp.Optional[tp.Index] = None,
to_columns: tp.Optional[tp.Index] = None,
axis: tp.Optional[tp.MaybeSequence[int]] = None,
reindex_kwargs: tp.KwargsLikeSequence = None,
) -> tp.MaybeTuple[tp.ArrayLike]:
"""Align Pandas arrays against common index and/or column levels using reindexing
and `vectorbtpro.base.indexes.align_indexes` respectively."""
objs = list(objs)
if align_index:
indexes_to_align = []
for i in range(len(objs)):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if _axis in (None, 0):
if checks.is_pandas(objs[i]):
if not checks.is_default_index(objs[i].index):
indexes_to_align.append(i)
if (len(indexes_to_align) > 0 and to_index is not None) or len(indexes_to_align) > 1:
if to_index is None:
new_index = None
index_changed = False
for i in indexes_to_align:
arg_index = objs[i].index
if new_index is None:
new_index = arg_index
else:
if not checks.is_index_equal(new_index, arg_index):
if new_index.dtype != arg_index.dtype:
raise ValueError("Indexes to be aligned must have the same data type")
new_index = new_index.union(arg_index)
index_changed = True
else:
new_index = to_index
index_changed = True
if index_changed:
for i in indexes_to_align:
if to_index is None or not checks.is_index_equal(objs[i].index, to_index):
if objs[i].index.has_duplicates:
raise ValueError(f"Index at position {i} contains duplicates")
if not objs[i].index.is_monotonic_increasing:
raise ValueError(f"Index at position {i} is not monotonically increasing")
_reindex_kwargs = resolve_dict(reindex_kwargs, i=i)
was_bool = (isinstance(objs[i], pd.Series) and objs[i].dtype == "bool") or (
isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "bool").all()
)
objs[i] = objs[i].reindex(new_index, **_reindex_kwargs)
is_object = (isinstance(objs[i], pd.Series) and objs[i].dtype == "object") or (
isinstance(objs[i], pd.DataFrame) and (objs[i].dtypes == "object").all()
)
if was_bool and is_object:
objs[i] = objs[i].astype(None)
if align_columns:
columns_to_align = []
for i in range(len(objs)):
if axis is not None:
if checks.is_sequence(axis):
_axis = axis[i]
else:
_axis = axis
else:
_axis = None
if _axis in (None, 1):
if checks.is_frame(objs[i]) and len(objs[i].columns) > 1:
if not checks.is_default_index(objs[i].columns):
columns_to_align.append(i)
if (len(columns_to_align) > 0 and to_columns is not None) or len(columns_to_align) > 1:
indexes_ = [objs[i].columns for i in columns_to_align]
if to_columns is not None:
indexes_.append(to_columns)
if len(set(map(len, indexes_))) > 1:
col_indices = indexes.align_indexes(*indexes_)
for i in columns_to_align:
objs[i] = objs[i].iloc[:, col_indices[columns_to_align.index(i)]]
if len(objs) == 1:
return objs[0]
return tuple(objs)
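# Sketch (illustrative): indexes are aligned via their union, and missing entries are
# introduced by reindexing (NaN by default).
# >>> sr1 = pd.Series([1.0, 2.0], index=["x", "y"])
# >>> sr2 = pd.Series([3.0, 4.0], index=["y", "z"])
# >>> a1, a2 = align_pd_arrays(sr1, sr2)
# >>> list(a1.index), list(a2.index)
# (['x', 'y', 'z'], ['x', 'y', 'z'])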
@define
class BCO(DefineMixin):
"""Class that represents an object passed to `broadcast`.
If any value is None, mostly defaults to the global value passed to `broadcast`."""
value: tp.Any = define.field()
"""Value of the object."""
axis: tp.Optional[int] = define.field(default=None)
"""Axis to broadcast.
Set to None to broadcast all axes."""
to_pd: tp.Optional[bool] = define.field(default=None)
"""Whether to convert the output array to a Pandas object."""
keep_flex: tp.Optional[bool] = define.field(default=None)
"""Whether to keep the raw version of the output for flexible indexing.
Only makes sure that the array can broadcast to the target shape."""
min_ndim: tp.Optional[int] = define.field(default=None)
"""Minimum number of dimensions."""
expand_axis: tp.Optional[int] = define.field(default=None)
"""Axis to expand if the array is 1-dim but the target shape is 2-dim."""
post_func: tp.Optional[tp.Callable] = define.field(default=None)
"""Function to post-process the output array."""
require_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `np.require`."""
reindex_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `pd.DataFrame.reindex`."""
merge_kwargs: tp.Optional[tp.Kwargs] = define.field(default=None)
"""Keyword arguments passed to `vectorbtpro.base.merging.column_stack_merge`."""
context: tp.KwargsLike = define.field(default=None)
"""Context used in evaluation of templates.
Will be merged over `template_context`."""
@define
class Default(DefineMixin):
"""Class for wrapping default values."""
value: tp.Any = define.field()
"""Default value."""
@define
class Ref(DefineMixin):
"""Class for wrapping references to other values."""
key: tp.Hashable = define.field()
"""Reference to another key."""
def resolve_ref(dct: dict, k: tp.Hashable, inside_bco: bool = False, keep_wrap_default: bool = False) -> tp.Any:
"""Resolve a potential reference."""
v = dct[k]
is_default = False
if isinstance(v, Default):
v = v.value
is_default = True
if isinstance(v, Ref):
new_v = resolve_ref(dct, v.key, inside_bco=inside_bco)
if keep_wrap_default and is_default:
return Default(new_v)
return new_v
if isinstance(v, BCO) and inside_bco:
v = v.value
is_default = False
if isinstance(v, Default):
v = v.value
is_default = True
if isinstance(v, Ref):
new_v = resolve_ref(dct, v.key, inside_bco=inside_bco)
if keep_wrap_default and is_default:
return Default(new_v)
return new_v
return v
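# Sketch (illustrative): references are resolved recursively against the same mapping.
# >>> dct = dict(a=1, b=Ref("a"), c=Default(Ref("b")))
# >>> resolve_ref(dct, "b")
# 1
# >>> resolve_ref(dct, "c", keep_wrap_default=True)
# Default(value=1)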
def broadcast(
*objs,
to_shape: tp.Optional[tp.ShapeLike] = None,
align_index: tp.Optional[bool] = None,
align_columns: tp.Optional[bool] = None,
index_from: tp.Optional[IndexFromLike] = None,
columns_from: tp.Optional[IndexFromLike] = None,
to_frame: tp.Optional[bool] = None,
axis: tp.Optional[tp.MaybeMappingSequence[int]] = None,
to_pd: tp.Optional[tp.MaybeMappingSequence[bool]] = None,
keep_flex: tp.MaybeMappingSequence[tp.Optional[bool]] = None,
min_ndim: tp.MaybeMappingSequence[tp.Optional[int]] = None,
expand_axis: tp.MaybeMappingSequence[tp.Optional[int]] = None,
post_func: tp.MaybeMappingSequence[tp.Optional[tp.Callable]] = None,
require_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
reindex_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
merge_kwargs: tp.MaybeMappingSequence[tp.Optional[tp.Kwargs]] = None,
tile: tp.Union[None, int, tp.IndexLike] = None,
random_subset: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
keep_wrap_default: tp.Optional[bool] = None,
return_wrapper: bool = False,
wrapper_kwargs: tp.KwargsLike = None,
ignore_sr_names: tp.Optional[bool] = None,
ignore_ranges: tp.Optional[bool] = None,
check_index_names: tp.Optional[bool] = None,
clean_index_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
) -> tp.Any:
"""Bring any array-like object in `objs` to the same shape by using NumPy-like broadcasting.
See [Broadcasting](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
!!! important
The major difference to NumPy is that one-dimensional arrays will always broadcast against the row axis!
Can broadcast Pandas objects by broadcasting their index/columns with `broadcast_index`.
Args:
*objs: Objects to broadcast.
If the first and only argument is a mapping, will return a dict.
Allows using `BCO`, `Ref`, `Default`, `vectorbtpro.utils.params.Param`,
`vectorbtpro.base.indexing.index_dict`, `vectorbtpro.base.indexing.IdxSetter`,
`vectorbtpro.base.indexing.IdxSetterFactory`, and templates.
If an index dictionary, fills using `vectorbtpro.base.wrapping.ArrayWrapper.fill_and_set`.
to_shape (tuple of int): Target shape. If set, will broadcast every object in `objs` to `to_shape`.
align_index (bool): Whether to align index of Pandas objects using union.
Pass None to use the default.
align_columns (bool): Whether to align columns of Pandas objects using multi-index.
Pass None to use the default.
index_from (any): Broadcasting rule for index.
Pass None to use the default.
columns_from (any): Broadcasting rule for columns.
Pass None to use the default.
to_frame (bool): Whether to convert all Series to DataFrames.
axis (int, sequence or mapping): See `BCO.axis`.
to_pd (bool, sequence or mapping): See `BCO.to_pd`.
If None, converts only if there is at least one Pandas object among them.
keep_flex (bool, sequence or mapping): See `BCO.keep_flex`.
min_ndim (int, sequence or mapping): See `BCO.min_ndim`.
If None, becomes 2 if `keep_flex` is True, otherwise 1.
expand_axis (int, sequence or mapping): See `BCO.expand_axis`.
post_func (callable, sequence or mapping): See `BCO.post_func`.
Applied only when `keep_flex` is False.
require_kwargs (dict, sequence or mapping): See `BCO.require_kwargs`.
This key will be merged with any argument-specific dict. If the mapping contains all keys in
`np.require`, it will be applied to all objects.
reindex_kwargs (dict, sequence or mapping): See `BCO.reindex_kwargs`.
This key will be merged with any argument-specific dict. If the mapping contains all keys in
`pd.DataFrame.reindex`, it will be applied to all objects.
merge_kwargs (dict, sequence or mapping): See `BCO.merge_kwargs`.
This key will be merged with any argument-specific dict. If the mapping contains all keys in
`pd.DataFrame.merge`, it will be applied to all objects.
tile (int or index_like): Tile the final object by the number of times or index.
random_subset (int): Select a random subset of parameter values.
Seed can be set using NumPy before calling this function.
seed (int): Seed to make output deterministic.
keep_wrap_default (bool): Whether to keep wrapping with `vectorbtpro.base.reshaping.Default`.
return_wrapper (bool): Whether to also return the wrapper associated with the operation.
wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
ignore_sr_names (bool): See `broadcast_index`.
ignore_ranges (bool): See `broadcast_index`.
check_index_names (bool): See `broadcast_index`.
clean_index_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.indexes.clean_index`.
template_context (dict): Context used to substitute templates.
For defaults, see `vectorbtpro._settings.broadcasting`.
Any keyword argument that can be associated with an object can be passed as
* a const that is applied to all objects,
* a sequence with value per object, and
* a mapping with value per object name and the special key `_def` denoting the default value.
Additionally, any object can be passed wrapped with `BCO`, whose attributes will override
any of the above arguments if not None.
Usage:
* Without broadcasting index and columns:
```pycon
>>> from vectorbtpro import *
>>> v = 0
>>> a = np.array([1, 2, 3])
>>> sr = pd.Series([1, 2, 3], index=pd.Index(['x', 'y', 'z']), name='a')
>>> df = pd.DataFrame(
... [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
... index=pd.Index(['x2', 'y2', 'z2']),
... columns=pd.Index(['a2', 'b2', 'c2']),
... )
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from='keep',
... columns_from='keep',
... align_index=False
... ): print(i)
0 1 2
0 0 0 0
1 0 0 0
2 0 0 0
0 1 2
0 1 2 3
1 1 2 3
2 1 2 3
a a a
x 1 1 1
y 2 2 2
z 3 3 3
a2 b2 c2
x2 1 2 3
y2 4 5 6
z2 7 8 9
```
* Take index and columns from the argument at specific position:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from=2,
... columns_from=3,
... align_index=False
... ): print(i)
a2 b2 c2
x 0 0 0
y 0 0 0
z 0 0 0
a2 b2 c2
x 1 2 3
y 1 2 3
z 1 2 3
a2 b2 c2
x 1 1 1
y 2 2 2
z 3 3 3
a2 b2 c2
x 1 2 3
y 4 5 6
z 7 8 9
```
* Broadcast index and columns through stacking:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from='stack',
... columns_from='stack',
... align_index=False
... ): print(i)
a2 b2 c2
x x2 0 0 0
y y2 0 0 0
z z2 0 0 0
a2 b2 c2
x x2 1 2 3
y y2 1 2 3
z z2 1 2 3
a2 b2 c2
x x2 1 1 1
y y2 2 2 2
z z2 3 3 3
a2 b2 c2
x x2 1 2 3
y y2 4 5 6
z z2 7 8 9
```
* Set index and columns manually:
```pycon
>>> for i in vbt.broadcast(
... v, a, sr, df,
... index_from=['a', 'b', 'c'],
... columns_from=['d', 'e', 'f'],
... align_index=False
... ): print(i)
d e f
a 0 0 0
b 0 0 0
c 0 0 0
d e f
a 1 2 3
b 1 2 3
c 1 2 3
d e f
a 1 1 1
b 2 2 2
c 3 3 3
d e f
a 1 2 3
b 4 5 6
c 7 8 9
```
* Passing arguments as a mapping returns a mapping:
```pycon
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df),
... index_from='stack',
... align_index=False
... )
{'v': a2 b2 c2
x x2 0 0 0
y y2 0 0 0
z z2 0 0 0,
'a': a2 b2 c2
x x2 1 2 3
y y2 1 2 3
z z2 1 2 3,
'sr': a2 b2 c2
x x2 1 1 1
y y2 2 2 2
z z2 3 3 3,
'df': a2 b2 c2
x x2 1 2 3
y y2 4 5 6
z z2 7 8 9}
```
* Keep all results in a format suitable for flexible indexing apart from one:
```pycon
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df),
... index_from='stack',
... keep_flex=dict(_def=True, df=False),
... require_kwargs=dict(df=dict(dtype=float)),
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': a2 b2 c2
x x2 1.0 2.0 3.0
y y2 4.0 5.0 6.0
z z2 7.0 8.0 9.0}
```
* Specify arguments per object using `BCO`:
```pycon
>>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float))
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df_bco),
... index_from='stack',
... keep_flex=True,
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': a2 b2 c2
x x2 1.0 2.0 3.0
y y2 4.0 5.0 6.0
z z2 7.0 8.0 9.0}
```
* Introduce a parameter that should build a Cartesian product of its values and other objects:
```pycon
>>> df_bco = vbt.BCO(df, keep_flex=False, require_kwargs=dict(dtype=float))
>>> p_bco = vbt.BCO(vbt.Param([1, 2, 3], name='my_p'))
>>> vbt.broadcast(
... dict(v=v, a=a, sr=sr, df=df_bco, p=p_bco),
... index_from='stack',
... keep_flex=True,
... align_index=False
... )
{'v': array([0]),
'a': array([1, 2, 3, 1, 2, 3, 1, 2, 3]),
'sr': array([[1],
[2],
[3]]),
'df': my_p 1 2 3
a2 b2 c2 a2 b2 c2 a2 b2 c2
x x2 1.0 2.0 3.0 1.0 2.0 3.0 1.0 2.0 3.0
y y2 4.0 5.0 6.0 4.0 5.0 6.0 4.0 5.0 6.0
z z2 7.0 8.0 9.0 7.0 8.0 9.0 7.0 8.0 9.0,
'p': array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
[1, 1, 1, 2, 2, 2, 3, 3, 3],
[1, 1, 1, 2, 2, 2, 3, 3, 3]])}
```
* Build a Cartesian product of all parameters:
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3]),
... b=vbt.Param(['x', 'y']),
... c=vbt.Param([False, True])
... )
... )
{'a': array([[1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]]),
'b': array([['x', 'x', 'y', 'y', 'x', 'x', 'y', 'y', 'x', 'x', 'y', 'y']], dtype='<U1'),
'c': array([[False,  True, False,  True, False,  True, False,  True, False,  True, False,  True]])}
```
* Link parameters using levels so that parameters on the same level vary together:
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3], level=0),
... b=vbt.Param(['x', 'y'], level=1),
... d=vbt.Param([100., 200., 300.], level=0),
... c=vbt.Param([False, True], level=1)
... )
... )
{'a': array([[1, 1, 2, 2, 3, 3]]),
'b': array([['x', 'y', 'x', 'y', 'x', 'y']], dtype='<U1'),
'd': array([[100., 100., 200., 200., 300., 300.]]),
'c': array([[False,  True, False,  True, False,  True]])}
```
* Select a random subset of parameter combinations:
```pycon
>>> vbt.broadcast(
... dict(
... a=vbt.Param([1, 2, 3]),
... b=vbt.Param(['x', 'y']),
... c=vbt.Param([False, True])
... ),
... random_subset=5,
... seed=42
... )
{'a': array([[1, 2, 3, 3, 3]]),
'b': array([['x', 'x', 'x', 'x', 'y']], dtype='<U1'),
'c': array([[...]])}
```
"""
from vectorbtpro._settings import settings
broadcasting_cfg = settings["broadcasting"]
if clean_index_kwargs is None:
clean_index_kwargs = {}
if len(objs) > 0 and checks.is_mapping(objs[0]) and not isinstance(objs[0], indexing.index_dict):
if len(objs) > 1:
raise ValueError("Only one argument is allowed when passing a mapping")
all_keys = list(dict(objs[0]).keys())
objs = list(objs[0].values())
return_dict = True
else:
objs = list(objs)
all_keys = list(range(len(objs)))
return_dict = False
def _resolve_arg(obj: tp.Any, arg_name: str, global_value: tp.Any, default_value: tp.Any) -> tp.Any:
if isinstance(obj, BCO) and getattr(obj, arg_name) is not None:
return getattr(obj, arg_name)
if checks.is_mapping(global_value):
return global_value.get(k, global_value.get("_def", default_value))
if checks.is_sequence(global_value):
return global_value[i]
return global_value
# Build BCO instances
none_keys = set()
default_keys = set()
param_keys = set()
special_keys = set()
bco_instances = {}
pool = dict(zip(all_keys, objs))
for i, k in enumerate(all_keys):
obj = objs[i]
if isinstance(obj, Default):
obj = obj.value
default_keys.add(k)
if isinstance(obj, Ref):
obj = resolve_ref(pool, k)
if isinstance(obj, BCO):
value = obj.value
else:
value = obj
if isinstance(value, Default):
value = value.value
default_keys.add(k)
if isinstance(value, Ref):
value = resolve_ref(pool, k, inside_bco=True)
if value is None:
none_keys.add(k)
continue
_axis = _resolve_arg(obj, "axis", axis, None)
_to_pd = _resolve_arg(obj, "to_pd", to_pd, None)
_keep_flex = _resolve_arg(obj, "keep_flex", keep_flex, None)
if _keep_flex is None:
_keep_flex = broadcasting_cfg["keep_flex"]
_min_ndim = _resolve_arg(obj, "min_ndim", min_ndim, None)
if _min_ndim is None:
_min_ndim = broadcasting_cfg["min_ndim"]
_expand_axis = _resolve_arg(obj, "expand_axis", expand_axis, None)
if _expand_axis is None:
_expand_axis = broadcasting_cfg["expand_axis"]
_post_func = _resolve_arg(obj, "post_func", post_func, None)
if isinstance(obj, BCO) and obj.require_kwargs is not None:
_require_kwargs = obj.require_kwargs
else:
_require_kwargs = None
if checks.is_mapping(require_kwargs) and require_kwargs_per_obj:
_require_kwargs = merge_dicts(
require_kwargs.get("_def", None),
require_kwargs.get(k, None),
_require_kwargs,
)
elif checks.is_sequence(require_kwargs) and require_kwargs_per_obj:
_require_kwargs = merge_dicts(require_kwargs[i], _require_kwargs)
else:
_require_kwargs = merge_dicts(require_kwargs, _require_kwargs)
if isinstance(obj, BCO) and obj.reindex_kwargs is not None:
_reindex_kwargs = obj.reindex_kwargs
else:
_reindex_kwargs = None
if checks.is_mapping(reindex_kwargs) and reindex_kwargs_per_obj:
_reindex_kwargs = merge_dicts(
reindex_kwargs.get("_def", None),
reindex_kwargs.get(k, None),
_reindex_kwargs,
)
elif checks.is_sequence(reindex_kwargs) and reindex_kwargs_per_obj:
_reindex_kwargs = merge_dicts(reindex_kwargs[i], _reindex_kwargs)
else:
_reindex_kwargs = merge_dicts(reindex_kwargs, _reindex_kwargs)
if isinstance(obj, BCO) and obj.merge_kwargs is not None:
_merge_kwargs = obj.merge_kwargs
else:
_merge_kwargs = None
if checks.is_mapping(merge_kwargs) and merge_kwargs_per_obj:
_merge_kwargs = merge_dicts(
merge_kwargs.get("_def", None),
merge_kwargs.get(k, None),
_merge_kwargs,
)
elif checks.is_sequence(merge_kwargs) and merge_kwargs_per_obj:
_merge_kwargs = merge_dicts(merge_kwargs[i], _merge_kwargs)
else:
_merge_kwargs = merge_dicts(merge_kwargs, _merge_kwargs)
if isinstance(obj, BCO):
_context = merge_dicts(template_context, obj.context)
else:
_context = template_context
if isinstance(value, Param):
param_keys.add(k)
elif isinstance(value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory, CustomTemplate)):
special_keys.add(k)
else:
value = to_any_array(value)
bco_instances[k] = BCO(
value,
axis=_axis,
to_pd=_to_pd,
keep_flex=_keep_flex,
min_ndim=_min_ndim,
expand_axis=_expand_axis,
post_func=_post_func,
require_kwargs=_require_kwargs,
reindex_kwargs=_reindex_kwargs,
merge_kwargs=_merge_kwargs,
context=_context,
)
# Check whether we should broadcast Pandas metadata and work on 2-dim data
is_pd = False
is_2d = False
old_objs = {}
obj_axis = {}
obj_reindex_kwargs = {}
for k, bco_obj in bco_instances.items():
if k in none_keys or k in param_keys or k in special_keys:
continue
obj = bco_obj.value
if obj.ndim > 1:
is_2d = True
if checks.is_pandas(obj):
is_pd = True
if bco_obj.to_pd is not None and bco_obj.to_pd:
is_pd = True
old_objs[k] = obj
obj_axis[k] = bco_obj.axis
obj_reindex_kwargs[k] = bco_obj.reindex_kwargs
if to_shape is not None:
if isinstance(to_shape, int):
to_shape = (to_shape,)
if len(to_shape) > 1:
is_2d = True
if to_frame is not None:
is_2d = to_frame
if to_pd is not None:
is_pd = to_pd or (return_wrapper and is_pd)
# Align pandas arrays
if index_from is not None and not isinstance(index_from, (int, str, pd.Index)):
index_from = pd.Index(index_from)
if columns_from is not None and not isinstance(columns_from, (int, str, pd.Index)):
columns_from = pd.Index(columns_from)
aligned_objs = align_pd_arrays(
*old_objs.values(),
align_index=align_index,
align_columns=align_columns,
to_index=index_from if isinstance(index_from, pd.Index) else None,
to_columns=columns_from if isinstance(columns_from, pd.Index) else None,
axis=list(obj_axis.values()),
reindex_kwargs=list(obj_reindex_kwargs.values()),
)
if not isinstance(aligned_objs, tuple):
aligned_objs = (aligned_objs,)
aligned_objs = dict(zip(old_objs.keys(), aligned_objs))
# Convert to NumPy
ready_objs = {}
for k, obj in aligned_objs.items():
_expand_axis = bco_instances[k].expand_axis
new_obj = np.asarray(obj)
if is_2d and new_obj.ndim == 1:
if isinstance(obj, pd.Series):
new_obj = new_obj[:, None]
else:
new_obj = np.expand_dims(new_obj, _expand_axis)
ready_objs[k] = new_obj
# Get final shape
if to_shape is None:
try:
to_shape = broadcast_shapes(
*map(lambda x: x.shape, ready_objs.values()),
axis=list(obj_axis.values()),
)
except ValueError:
arr_shapes = {}
for i, k in enumerate(bco_instances):
if k in none_keys or k in param_keys or k in special_keys:
continue
if len(ready_objs[k].shape) > 0:
arr_shapes[k] = ready_objs[k].shape
raise ValueError("Could not broadcast shapes: %s" % str(arr_shapes))
if not isinstance(to_shape, tuple):
to_shape = (to_shape,)
if len(to_shape) == 0:
to_shape = (1,)
to_shape_2d = to_shape if len(to_shape) > 1 else (*to_shape, 1)
if is_pd:
# Decide on index and columns
# NOTE: Important to pass aligned_objs, not ready_objs, to preserve original shape info
new_index = broadcast_index(
[v for k, v in aligned_objs.items() if obj_axis[k] in (None, 0)],
to_shape,
index_from=index_from,
axis=0,
ignore_sr_names=ignore_sr_names,
ignore_ranges=ignore_ranges,
check_index_names=check_index_names,
**clean_index_kwargs,
)
new_columns = broadcast_index(
[v for k, v in aligned_objs.items() if obj_axis[k] in (None, 1)],
to_shape,
index_from=columns_from,
axis=1,
ignore_sr_names=ignore_sr_names,
ignore_ranges=ignore_ranges,
check_index_names=check_index_names,
**clean_index_kwargs,
)
else:
new_index = pd.RangeIndex(stop=to_shape_2d[0])
new_columns = pd.RangeIndex(stop=to_shape_2d[1])
# Build a product
param_product = None
param_columns = None
n_params = 0
if len(param_keys) > 0:
# Combine parameters
param_dct = {}
for k, bco_obj in bco_instances.items():
if k not in param_keys:
continue
param_dct[k] = bco_obj.value
param_product, param_columns = combine_params(
param_dct,
random_subset=random_subset,
seed=seed,
clean_index_kwargs=clean_index_kwargs,
)
n_params = len(param_columns)
# Combine parameter columns with new columns
if param_columns is not None and new_columns is not None:
new_columns = indexes.combine_indexes([param_columns, new_columns], **clean_index_kwargs)
# Tile
if tile is not None:
if isinstance(tile, int):
if new_columns is not None:
new_columns = indexes.tile_index(new_columns, tile)
else:
if new_columns is not None:
new_columns = indexes.combine_indexes([tile, new_columns], **clean_index_kwargs)
tile = len(tile)
n_params = max(n_params, 1) * tile
# Build wrapper
if n_params == 0:
new_shape = to_shape
else:
new_shape = (to_shape_2d[0], to_shape_2d[1] * n_params)
wrapper = wrapping.ArrayWrapper.from_shape(
new_shape,
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
),
wrapper_kwargs,
),
)
def _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis):
if _min_ndim is None:
if _keep_flex:
_min_ndim = 2
else:
_min_ndim = 1
if _min_ndim not in (1, 2):
raise ValueError("Argument min_ndim must be either 1 or 2")
if _min_ndim in (1, 2) and new_obj.ndim == 0:
new_obj = new_obj[None]
if _min_ndim == 2 and new_obj.ndim == 1:
if len(to_shape) == 1:
new_obj = new_obj[:, None]
else:
new_obj = np.expand_dims(new_obj, _expand_axis)
return new_obj
# Perform broadcasting
aligned_objs2 = {}
new_objs = {}
for i, k in enumerate(all_keys):
if k in none_keys or k in special_keys:
continue
_keep_flex = bco_instances[k].keep_flex
_min_ndim = bco_instances[k].min_ndim
_axis = bco_instances[k].axis
_expand_axis = bco_instances[k].expand_axis
_merge_kwargs = bco_instances[k].merge_kwargs
_context = bco_instances[k].context
must_reset_index = _merge_kwargs.get("reset_index", None) not in (None, False)
_reindex_kwargs = resolve_dict(bco_instances[k].reindex_kwargs)
_fill_value = _reindex_kwargs.get("fill_value", np.nan)
if k in param_keys:
# Broadcast parameters
from vectorbtpro.base.merging import column_stack_merge
if _axis == 0:
raise ValueError("Parameters do not support broadcasting with axis=0")
obj = param_product[k]
new_obj = []
any_needs_broadcasting = False
all_forced_broadcast = True
for o in obj:
if isinstance(o, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)):
o = wrapper.fill_and_set(
o,
fill_value=_fill_value,
keep_flex=_keep_flex,
)
elif isinstance(o, CustomTemplate):
context = merge_dicts(
dict(
bco_instances=bco_instances,
wrapper=wrapper,
obj_name=k,
bco=bco_instances[k],
),
_context,
)
o = o.substitute(context, eval_id="broadcast")
o = to_2d_array(o)
if not _keep_flex:
needs_broadcasting = True
elif o.shape[0] > 1:
needs_broadcasting = True
elif o.shape[1] > 1 and o.shape[1] != to_shape_2d[1]:
needs_broadcasting = True
else:
needs_broadcasting = False
if needs_broadcasting:
any_needs_broadcasting = True
o = broadcast_array_to(o, to_shape_2d, axis=_axis)
elif o.size == 1:
all_forced_broadcast = False
o = np.repeat(o, to_shape_2d[1], axis=1)
else:
all_forced_broadcast = False
new_obj.append(o)
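# If only some parameter values had to be broadcast to the full 2-dim shape,
# bring the remaining values to a compatible shape as well so they can be column-stacked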
if any_needs_broadcasting and not all_forced_broadcast:
new_obj2 = []
for o in new_obj:
if o.shape[1] != to_shape_2d[1] or (not must_reset_index and o.shape[0] != to_shape_2d[0]):
o = broadcast_array_to(o, to_shape_2d, axis=_axis)
new_obj2.append(o)
new_obj = new_obj2
obj = column_stack_merge(new_obj, **_merge_kwargs)
if tile is not None:
obj = np.tile(obj, (1, tile))
old_obj = obj
new_obj = obj
else:
# Broadcast regular objects
old_obj = aligned_objs[k]
new_obj = ready_objs[k]
if _axis in (None, 0) and new_obj.ndim >= 1 and new_obj.shape[0] > 1 and new_obj.shape[0] != to_shape[0]:
raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}")
if _axis in (None, 1) and new_obj.ndim == 2 and new_obj.shape[1] > 1 and new_obj.shape[1] != to_shape[1]:
raise ValueError(f"Could not broadcast argument {k} of shape {new_obj.shape} to {to_shape}")
if _keep_flex:
if n_params > 0 and _axis in (None, 1):
if len(to_shape) == 1:
if new_obj.ndim == 1 and new_obj.shape[0] > 1:
new_obj = new_obj[:, None] # product changes is_2d behavior
else:
if new_obj.ndim == 1 and new_obj.shape[0] > 1:
new_obj = np.tile(new_obj, n_params)
elif new_obj.ndim == 2 and new_obj.shape[1] > 1:
new_obj = np.tile(new_obj, (1, n_params))
else:
new_obj = broadcast_array_to(new_obj, to_shape, axis=_axis)
if n_params > 0 and _axis in (None, 1):
if new_obj.ndim == 1:
new_obj = new_obj[:, None] # product changes is_2d behavior
new_obj = np.tile(new_obj, (1, n_params))
new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis)
aligned_objs2[k] = old_obj
new_objs[k] = new_obj
# Resolve special objects
new_objs2 = {}
for i, k in enumerate(all_keys):
if k in none_keys:
continue
if k in special_keys:
bco = bco_instances[k]
if isinstance(bco.value, (indexing.index_dict, indexing.IdxSetter, indexing.IdxSetterFactory)):
_is_pd = bco.to_pd
if _is_pd is None:
_is_pd = is_pd
_keep_flex = bco.keep_flex
_min_ndim = bco.min_ndim
_expand_axis = bco.expand_axis
_reindex_kwargs = resolve_dict(bco.reindex_kwargs)
_fill_value = _reindex_kwargs.get("fill_value", np.nan)
new_obj = wrapper.fill_and_set(
bco.value,
fill_value=_fill_value,
keep_flex=_keep_flex,
)
if not _is_pd and not _keep_flex:
new_obj = new_obj.values
new_obj = _adjust_dims(new_obj, _keep_flex, _min_ndim, _expand_axis)
elif isinstance(bco.value, CustomTemplate):
context = merge_dicts(
dict(
bco_instances=bco_instances,
new_objs=new_objs,
wrapper=wrapper,
obj_name=k,
bco=bco,
),
bco.context,
)
new_obj = bco.value.substitute(context, eval_id="broadcast")
else:
raise TypeError(f"Special type {type(bco.value)} is not supported")
else:
new_obj = new_objs[k]
# Force to match requirements
new_obj = np.require(new_obj, **resolve_dict(bco_instances[k].require_kwargs))
new_objs2[k] = new_obj
# Perform wrapping and post-processing
new_objs3 = {}
for i, k in enumerate(all_keys):
if k in none_keys:
continue
new_obj = new_objs2[k]
_axis = bco_instances[k].axis
_keep_flex = bco_instances[k].keep_flex
if not _keep_flex:
# Wrap array
_is_pd = bco_instances[k].to_pd
if _is_pd is None:
_is_pd = is_pd
new_obj = wrap_broadcasted(
new_obj,
old_obj=aligned_objs2[k] if k not in special_keys else None,
axis=_axis,
is_pd=_is_pd,
new_index=new_index,
new_columns=new_columns,
ignore_ranges=ignore_ranges,
)
# Post-process array
_post_func = bco_instances[k].post_func
if _post_func is not None:
new_obj = _post_func(new_obj)
new_objs3[k] = new_obj
# Prepare outputs
return_objs = []
for k in all_keys:
if k not in none_keys:
if k in default_keys and keep_wrap_default:
return_objs.append(Default(new_objs3[k]))
else:
return_objs.append(new_objs3[k])
else:
if k in default_keys and keep_wrap_default:
return_objs.append(Default(None))
else:
return_objs.append(None)
if return_dict:
return_objs = dict(zip(all_keys, return_objs))
else:
return_objs = tuple(return_objs)
if len(return_objs) > 1 or return_dict:
if return_wrapper:
return return_objs, wrapper
return return_objs
if return_wrapper:
return return_objs[0], wrapper
return return_objs[0]
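# A minimal, self-contained sketch of a typical `broadcast` call (assuming default
# broadcasting settings); the inputs below are illustrative only.
import numpy as np
import pandas as pd
from vectorbtpro.base.reshaping import broadcast

arr = np.array([1, 2, 3])
df = pd.DataFrame({"a": [10, 20, 30], "b": [40, 50, 60]})
new_arr, new_df = broadcast(arr, df)
# Both outputs share the shape (3, 2) and carry df's index and columns:
# the 1-dim array is tiled across both columns, while df keeps its values.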
def broadcast_to(
arg1: tp.ArrayLike,
arg2: tp.Union[tp.ArrayLike, tp.ShapeLike, wrapping.ArrayWrapper],
to_pd: tp.Optional[bool] = None,
index_from: tp.Optional[IndexFromLike] = None,
columns_from: tp.Optional[IndexFromLike] = None,
**kwargs,
) -> tp.Any:
"""Broadcast `arg1` to `arg2`.
Argument `arg2` can be a shape, an instance of `vectorbtpro.base.wrapping.ArrayWrapper`,
or any array-like object.
Pass None to `index_from`/`columns_from` to use index/columns of the second argument.
Keyword arguments `**kwargs` are passed to `broadcast`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_to
>>> a = np.array([1, 2, 3])
>>> sr = pd.Series([4, 5, 6], index=pd.Index(['x', 'y', 'z']), name='a')
>>> broadcast_to(a, sr)
x 1
y 2
z 3
Name: a, dtype: int64
>>> broadcast_to(sr, a)
array([4, 5, 6])
```
"""
if checks.is_int(arg2) or isinstance(arg2, tuple):
arg2 = to_tuple_shape(arg2)
if isinstance(arg2, tuple):
to_shape = arg2
elif isinstance(arg2, wrapping.ArrayWrapper):
to_pd = True
if index_from is None:
index_from = arg2.index
if columns_from is None:
columns_from = arg2.columns
to_shape = arg2.shape
else:
arg2 = to_any_array(arg2)
if to_pd is None:
to_pd = checks.is_pandas(arg2)
if to_pd:
# Take index and columns from arg2
if index_from is None:
index_from = indexes.get_index(arg2, 0)
if columns_from is None:
columns_from = indexes.get_index(arg2, 1)
to_shape = arg2.shape
return broadcast(
arg1,
to_shape=to_shape,
to_pd=to_pd,
index_from=index_from,
columns_from=columns_from,
**kwargs,
)
def broadcast_to_array_of(arg1: tp.ArrayLike, arg2: tp.ArrayLike) -> tp.Array:
"""Broadcast `arg1` to the shape `(1, *arg2.shape)`.
`arg1` must be either a scalar, a 1-dim array, or have 1 dimension more than `arg2`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_to_array_of
>>> broadcast_to_array_of([0.1, 0.2], np.empty((2, 2)))
[[[0.1 0.1]
[0.1 0.1]]
[[0.2 0.2]
[0.2 0.2]]]
```
"""
arg1 = np.asarray(arg1)
arg2 = np.asarray(arg2)
if arg1.ndim == arg2.ndim + 1:
if arg1.shape[1:] == arg2.shape:
return arg1
# From here on arg1 can be only a 1-dim array
if arg1.ndim == 0:
arg1 = to_1d(arg1)
checks.assert_ndim(arg1, 1)
if arg2.ndim == 0:
return arg1
for i in range(arg2.ndim):
arg1 = np.expand_dims(arg1, axis=-1)
return np.tile(arg1, (1, *arg2.shape))
def broadcast_to_axis_of(
arg1: tp.ArrayLike,
arg2: tp.ArrayLike,
axis: int,
require_kwargs: tp.KwargsLike = None,
) -> tp.Array:
"""Broadcast `arg1` to an axis of `arg2`.
If `arg2` has less dimensions than requested, will broadcast `arg1` to a single number.
For other keyword arguments, see `broadcast`."""
if require_kwargs is None:
require_kwargs = {}
arg2 = to_any_array(arg2)
if arg2.ndim < axis + 1:
return broadcast_array_to(arg1, (1,))[0] # to a single number
arg1 = broadcast_array_to(arg1, (arg2.shape[axis],))
arg1 = np.require(arg1, **require_kwargs)
return arg1
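# A small sketch of `broadcast_to_axis_of` with illustrative shapes: a scalar is
# stretched to match one axis of the target array, or collapsed to a single number
# if the target lacks that axis.
import numpy as np
from vectorbtpro.base.reshaping import broadcast_to_axis_of

per_column = broadcast_to_axis_of(0.5, np.empty((3, 4)), 1)  # shape (4,): one value per column
as_number = broadcast_to_axis_of(0.5, np.empty(3), 1)  # a single number (0.5), since axis 1 does not exist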
def broadcast_combs(
*objs: tp.ArrayLike,
axis: int = 1,
comb_func: tp.Callable = itertools.product,
**broadcast_kwargs,
) -> tp.Any:
"""Align an axis of each array using a combinatoric function and broadcast their indexes.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import broadcast_combs
>>> df = pd.DataFrame([[1, 2, 3], [3, 4, 5]], columns=pd.Index(['a', 'b', 'c'], name='df_param'))
>>> df2 = pd.DataFrame([[6, 7], [8, 9]], columns=pd.Index(['d', 'e'], name='df2_param'))
>>> sr = pd.Series([10, 11], name='f')
>>> new_df, new_df2, new_sr = broadcast_combs(df, df2, sr)
>>> new_df
df_param a b c
df2_param d e d e d e
0 1 1 2 2 3 3
1 3 3 4 4 5 5
>>> new_df2
df_param a b c
df2_param d e d e d e
0 6 7 6 7 6 7
1 8 9 8 9 8 9
>>> new_sr
df_param a b c
df2_param d e d e d e
0 10 10 10 10 10 10
1 11 11 11 11 11 11
```
"""
if broadcast_kwargs is None:
broadcast_kwargs = {}
objs = list(objs)
if len(objs) < 2:
raise ValueError("At least two arguments are required")
for i in range(len(objs)):
obj = to_any_array(objs[i])
if axis == 1:
obj = to_2d(obj)
objs[i] = obj
indices = []
for obj in objs:
indices.append(np.arange(len(indexes.get_index(to_pd_array(obj), axis))))
new_indices = list(map(list, zip(*list(comb_func(*indices)))))
results = []
for i, obj in enumerate(objs):
if axis == 1:
if checks.is_pandas(obj):
results.append(obj.iloc[:, new_indices[i]])
else:
results.append(obj[:, new_indices[i]])
else:
if checks.is_pandas(obj):
results.append(obj.iloc[new_indices[i]])
else:
results.append(obj[new_indices[i]])
if axis == 1:
broadcast_kwargs = merge_dicts(dict(columns_from="stack"), broadcast_kwargs)
else:
broadcast_kwargs = merge_dicts(dict(index_from="stack"), broadcast_kwargs)
return broadcast(*results, **broadcast_kwargs)
def get_multiindex_series(obj: tp.SeriesFrame) -> tp.Series:
"""Get Series with a multi-index.
If DataFrame has been passed, must at maximum have one row or column."""
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
if checks.is_frame(obj):
if obj.shape[0] == 1:
obj = obj.iloc[0, :]
elif obj.shape[1] == 1:
obj = obj.iloc[:, 0]
else:
raise ValueError("Supported are either Series or DataFrame with one column/row")
checks.assert_instance_of(obj.index, pd.MultiIndex)
return obj
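# A short sketch of `get_multiindex_series`: a one-column DataFrame indexed by a
# MultiIndex is squeezed into a Series; the frame below is illustrative only.
import pandas as pd
from vectorbtpro.base.reshaping import get_multiindex_series

index = pd.MultiIndex.from_tuples([(1, "a"), (2, "b")], names=["x", "y"])
df = pd.DataFrame({"value": [10, 20]}, index=index)
sr = get_multiindex_series(df)  # Series with the same MultiIndex and values [10, 20]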
def unstack_to_array(
obj: tp.SeriesFrame,
levels: tp.Optional[tp.MaybeLevelSequence] = None,
sort: bool = True,
return_indexes: bool = False,
) -> tp.Union[tp.Array, tp.Tuple[tp.Array, tp.List[tp.Index]]]:
"""Reshape `obj` based on its multi-index into a multi-dimensional array.
Use `levels` to specify what index levels to unstack and in which order.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import unstack_to_array
>>> index = pd.MultiIndex.from_arrays(
... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']])
>>> sr = pd.Series([1, 2, 3, 4], index=index)
>>> unstack_to_array(sr).shape
(2, 2, 4)
>>> unstack_to_array(sr)
[[[ 1. nan nan nan]
[nan 2. nan nan]]
[[nan nan 3. nan]
[nan nan nan 4.]]]
>>> unstack_to_array(sr, levels=(2, 0))
[[ 1. nan]
[ 2. nan]
[nan 3.]
[nan 4.]]
```
"""
sr = get_multiindex_series(obj)
if sr.index.duplicated().any():
raise ValueError("Index contains duplicate entries, cannot reshape")
new_index_list = []
value_indices_list = []
if levels is None:
levels = range(sr.index.nlevels)
if isinstance(levels, (int, str)):
levels = (levels,)
for level in levels:
level_values = indexes.select_levels(sr.index, level)
new_index = level_values.unique()
if sort:
new_index = new_index.sort_values()
new_index_list.append(new_index)
index_map = pd.Series(range(len(new_index)), index=new_index)
value_indices = index_map.loc[level_values]
value_indices_list.append(value_indices)
a = np.full(list(map(len, new_index_list)), np.nan)
a[tuple(zip(value_indices_list))] = sr.values
if return_indexes:
return a, new_index_list
return a
def make_symmetric(obj: tp.SeriesFrame, sort: bool = True) -> tp.Frame:
"""Make `obj` symmetric.
The index and columns of the resulting DataFrame will be identical.
Requires the index and columns to have the same number of levels.
Pass `sort=False` if the index and columns should not be sorted but rather
concatenated with duplicates removed.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import make_symmetric
>>> df = pd.DataFrame([[1, 2], [3, 4]], index=['a', 'b'], columns=['c', 'd'])
>>> make_symmetric(df)
a b c d
a NaN NaN 1.0 2.0
b NaN NaN 3.0 4.0
c 1.0 3.0 NaN NaN
d 2.0 4.0 NaN NaN
```
"""
from vectorbtpro.base.merging import concat_arrays
checks.assert_instance_of(obj, (pd.Series, pd.DataFrame))
df = to_2d(obj)
if isinstance(df.index, pd.MultiIndex) or isinstance(df.columns, pd.MultiIndex):
checks.assert_instance_of(df.index, pd.MultiIndex)
checks.assert_instance_of(df.columns, pd.MultiIndex)
checks.assert_array_equal(df.index.nlevels, df.columns.nlevels)
names1, names2 = tuple(df.index.names), tuple(df.columns.names)
else:
names1, names2 = df.index.name, df.columns.name
if names1 == names2:
new_name = names1
else:
if isinstance(df.index, pd.MultiIndex):
new_name = tuple(zip(*[names1, names2]))
else:
new_name = (names1, names2)
if sort:
idx_vals = np.unique(concat_arrays((df.index, df.columns))).tolist()
else:
idx_vals = list(dict.fromkeys(concat_arrays((df.index, df.columns))))
df_index = df.index.copy()
df_columns = df.columns.copy()
if isinstance(df.index, pd.MultiIndex):
unique_index = pd.MultiIndex.from_tuples(idx_vals, names=new_name)
df_index.names = new_name
df_columns.names = new_name
else:
unique_index = pd.Index(idx_vals, name=new_name)
df_index.name = new_name
df_columns.name = new_name
df = df.copy(deep=False)
df.index = df_index
df.columns = df_columns
df_out_dtype = np.promote_types(df.values.dtype, np.min_scalar_type(np.nan))
df_out = pd.DataFrame(index=unique_index, columns=unique_index, dtype=df_out_dtype)
df_out.loc[:, :] = df
df_out[df_out.isnull()] = df.transpose()
return df_out
def unstack_to_df(
obj: tp.SeriesFrame,
index_levels: tp.Optional[tp.MaybeLevelSequence] = None,
column_levels: tp.Optional[tp.MaybeLevelSequence] = None,
symmetric: bool = False,
sort: bool = True,
) -> tp.Frame:
"""Reshape `obj` based on its multi-index into a DataFrame.
Use `index_levels` to specify which index levels will form the new index, and `column_levels`
for the new columns. Set `symmetric` to True to make the DataFrame symmetric.
Usage:
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.base.reshaping import unstack_to_df
>>> index = pd.MultiIndex.from_arrays(
... [[1, 1, 2, 2], [3, 4, 3, 4], ['a', 'b', 'c', 'd']],
... names=['x', 'y', 'z'])
>>> sr = pd.Series([1, 2, 3, 4], index=index)
>>> unstack_to_df(sr, index_levels=(0, 1), column_levels=2)
z a b c d
x y
1 3 1.0 NaN NaN NaN
1 4 NaN 2.0 NaN NaN
2 3 NaN NaN 3.0 NaN
2 4 NaN NaN NaN 4.0
```
"""
sr = get_multiindex_series(obj)
if sr.index.nlevels > 2:
if index_levels is None:
raise ValueError("index_levels must be specified")
if column_levels is None:
raise ValueError("column_levels must be specified")
else:
if index_levels is None:
index_levels = 0
if column_levels is None:
column_levels = 1
unstacked, (new_index, new_columns) = unstack_to_array(
sr,
levels=(index_levels, column_levels),
sort=sort,
return_indexes=True,
)
df = pd.DataFrame(unstacked, index=new_index, columns=new_columns)
if symmetric:
return make_symmetric(df, sort=sort)
return df
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for wrapping NumPy arrays into Series/DataFrames."""
import numpy as np
import pandas as pd
from pandas.core.groupby import GroupBy as PandasGroupBy
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes, reshaping
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.indexes import stack_indexes, concat_indexes, IndexApplier
from vectorbtpro.base.indexing import IndexingError, ExtPandasIndexer, index_dict, IdxSetter, IdxSetterFactory, IdxDict
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.array_ import is_range, cast_to_min_precision, cast_to_max_precision
from vectorbtpro.utils.attr_ import AttrResolverMixin, AttrResolverMixinT
from vectorbtpro.utils.chunking import ChunkMeta, iter_chunk_meta, get_chunk_meta_key, ArraySelector, ArraySlicer
from vectorbtpro.utils.config import Configured, merge_dicts, resolve_dict
from vectorbtpro.utils.decorators import hybrid_method, cached_method, cached_property
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.params import ItemParamable
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.warnings_ import warn
if tp.TYPE_CHECKING:
from vectorbtpro.base.accessors import BaseIDXAccessor as BaseIDXAccessorT
from vectorbtpro.generic.splitting.base import Splitter as SplitterT
else:
BaseIDXAccessorT = "BaseIDXAccessor"
SplitterT = "Splitter"
__all__ = [
"ArrayWrapper",
"Wrapping",
]
HasWrapperT = tp.TypeVar("HasWrapperT", bound="HasWrapper")
class HasWrapper(ExtPandasIndexer, ItemParamable):
"""Abstract class that manages a wrapper."""
@property
def unwrapped(self) -> tp.Any:
"""Unwrapped object."""
raise NotImplementedError
@hybrid_method
def should_wrap(cls_or_self) -> bool:
"""Whether to wrap where applicable."""
return True
@property
def wrapper(self) -> "ArrayWrapper":
"""Array wrapper of the type `ArrayWrapper`."""
raise NotImplementedError
@property
def column_only_select(self) -> bool:
"""Whether to perform indexing on columns only."""
raise NotImplementedError
@property
def range_only_select(self) -> bool:
"""Whether to perform indexing on rows using slices only."""
raise NotImplementedError
@property
def group_select(self) -> bool:
"""Whether to allow indexing on groups."""
raise NotImplementedError
def regroup(self: HasWrapperT, group_by: tp.GroupByLike, **kwargs) -> HasWrapperT:
"""Regroup this instance."""
raise NotImplementedError
def ungroup(self: HasWrapperT, **kwargs) -> HasWrapperT:
"""Ungroup this instance."""
return self.regroup(False, **kwargs)
# ############# Selection ############# #
def select_col(
self: HasWrapperT,
column: tp.Any = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> HasWrapperT:
"""Select one column/group.
`column` can be a label-based position as well as an integer position (if label fails)."""
_self = self.regroup(group_by, **kwargs)
def _check_out_dim(out: HasWrapperT) -> HasWrapperT:
if out.wrapper.get_ndim() == 2:
if out.wrapper.get_shape_2d()[1] == 1:
if out.column_only_select:
return out.iloc[0]
return out.iloc[:, 0]
if _self.wrapper.grouper.is_grouped():
raise TypeError("Could not select one group: multiple groups returned")
else:
raise TypeError("Could not select one column: multiple columns returned")
return out
if column is None:
if _self.wrapper.get_ndim() == 2 and _self.wrapper.get_shape_2d()[1] == 1:
column = 0
if column is not None:
if _self.wrapper.grouper.is_grouped():
if _self.wrapper.grouped_ndim == 1:
raise TypeError("This instance already contains one group of data")
if column not in _self.wrapper.get_columns():
if isinstance(column, int):
if _self.column_only_select:
return _check_out_dim(_self.iloc[column])
return _check_out_dim(_self.iloc[:, column])
raise KeyError(f"Group '{column}' not found")
else:
if _self.wrapper.ndim == 1:
raise TypeError("This instance already contains one column of data")
if column not in _self.wrapper.columns:
if isinstance(column, int):
if _self.column_only_select:
return _check_out_dim(_self.iloc[column])
return _check_out_dim(_self.iloc[:, column])
raise KeyError(f"Column '{column}' not found")
return _check_out_dim(_self[column])
if _self.wrapper.grouper.is_grouped():
if _self.wrapper.grouped_ndim == 1:
return _self
raise TypeError("Only one group is allowed. Use indexing or column argument.")
if _self.wrapper.ndim == 1:
return _self
raise TypeError("Only one column is allowed. Use indexing or column argument.")
@hybrid_method
def select_col_from_obj(
cls_or_self,
obj: tp.Optional[tp.SeriesFrame],
column: tp.Any = None,
obj_ungrouped: bool = False,
group_by: tp.GroupByLike = None,
wrapper: tp.Optional["ArrayWrapper"] = None,
**kwargs,
) -> tp.MaybeSeries:
"""Select one column/group from a Pandas object.
`column` can be a label-based position as well as an integer position (if label fails)."""
if obj is None:
return None
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
_wrapper = wrapper.regroup(group_by, **kwargs)
def _check_out_dim(out: tp.SeriesFrame, from_df: bool) -> tp.Series:
bad_shape = False
if from_df and isinstance(out, pd.DataFrame):
if len(out.columns) == 1:
return out.iloc[:, 0]
bad_shape = True
if not from_df and isinstance(out, pd.Series):
if len(out) == 1:
return out.iloc[0]
bad_shape = True
if bad_shape:
if _wrapper.grouper.is_grouped():
raise TypeError("Could not select one group: multiple groups returned")
else:
raise TypeError("Could not select one column: multiple columns returned")
return out
if column is None:
if _wrapper.get_ndim() == 2 and _wrapper.get_shape_2d()[1] == 1:
column = 0
if column is not None:
if _wrapper.grouper.is_grouped():
if _wrapper.grouped_ndim == 1:
raise TypeError("This instance already contains one group of data")
if obj_ungrouped:
mask = _wrapper.grouper.group_by == column
if not mask.any():
raise KeyError(f"Group '{column}' not found")
if isinstance(obj, pd.DataFrame):
return obj.loc[:, mask]
return obj.loc[mask]
else:
if column not in _wrapper.get_columns():
if isinstance(column, int):
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj.iloc[:, column], True)
return _check_out_dim(obj.iloc[column], False)
raise KeyError(f"Group '{column}' not found")
else:
if _wrapper.ndim == 1:
raise TypeError("This instance already contains one column of data")
if column not in _wrapper.columns:
if isinstance(column, int):
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj.iloc[:, column], True)
return _check_out_dim(obj.iloc[column], False)
raise KeyError(f"Column '{column}' not found")
if isinstance(obj, pd.DataFrame):
return _check_out_dim(obj[column], True)
return _check_out_dim(obj[column], False)
if not _wrapper.grouper.is_grouped():
if _wrapper.ndim == 1:
return obj
raise TypeError("Only one column is allowed. Use indexing or column argument.")
if _wrapper.grouped_ndim == 1:
return obj
raise TypeError("Only one group is allowed. Use indexing or column argument.")
# ############# Splitting ############# #
def split(
self,
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
wrap: tp.Optional[bool] = None,
**kwargs,
) -> tp.Any:
"""Split this instance.
Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_take`."""
from vectorbtpro.generic.splitting.base import Splitter
if splitter_cls is None:
splitter_cls = Splitter
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
return splitter_cls.split_and_take(self.wrapper.index, wrapped_self, *args, **kwargs)
def split_apply(
self,
apply_func: tp.Union[str, tp.Callable],
*args,
splitter_cls: tp.Optional[tp.Type[SplitterT]] = None,
wrap: tp.Optional[bool] = None,
**kwargs,
) -> tp.Any:
"""Split this instance and apply a function to each split.
Uses `vectorbtpro.generic.splitting.base.Splitter.split_and_apply`."""
from vectorbtpro.generic.splitting.base import Splitter, Takeable
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if splitter_cls is None:
splitter_cls = Splitter
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
return splitter_cls.split_and_apply(self.wrapper.index, apply_func, Takeable(wrapped_self), *args, **kwargs)
# ############# Chunking ############# #
def chunk(
self: HasWrapperT,
axis: tp.Optional[int] = None,
min_size: tp.Optional[int] = None,
n_chunks: tp.Union[None, int, str] = None,
chunk_len: tp.Union[None, int, str] = None,
chunk_meta: tp.Optional[tp.Iterable[ChunkMeta]] = None,
select: bool = False,
wrap: tp.Optional[bool] = None,
return_chunk_meta: bool = False,
) -> tp.Iterator[tp.Union[HasWrapperT, tp.Tuple[ChunkMeta, HasWrapperT]]]:
"""Chunk this instance.
If `axis` is None, becomes 0 if the instance is one-dimensional and 1 otherwise.
For arguments related to chunking meta, see `vectorbtpro.utils.chunking.iter_chunk_meta`."""
if axis is None:
axis = 0 if self.wrapper.ndim == 1 else 1
if self.wrapper.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
size = self.wrapper.shape_2d[axis]
if wrap is None:
wrap = self.should_wrap()
wrapped_self = self if wrap else self.unwrapped
if chunk_meta is None:
chunk_meta = iter_chunk_meta(
size=size,
min_size=min_size,
n_chunks=n_chunks,
chunk_len=chunk_len,
)
for _chunk_meta in chunk_meta:
if select:
array_taker = ArraySelector(axis=axis)
else:
array_taker = ArraySlicer(axis=axis)
if return_chunk_meta:
yield _chunk_meta, array_taker.take(wrapped_self, _chunk_meta)
else:
yield array_taker.take(wrapped_self, _chunk_meta)
def chunk_apply(
self: HasWrapperT,
apply_func: tp.Union[str, tp.Callable],
*args,
chunk_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MergeableResults:
"""Chunk this instance and apply a function to each chunk.
If `apply_func` is a string, it is resolved to the method of this class with that name.
For arguments related to chunking, see `Wrapping.chunk`."""
if isinstance(apply_func, str):
apply_func = getattr(type(self), apply_func)
if chunk_kwargs is None:
chunk_arg_names = set(get_func_arg_names(self.chunk))
chunk_kwargs = {}
for k in list(kwargs.keys()):
if k in chunk_arg_names:
chunk_kwargs[k] = kwargs.pop(k)
if execute_kwargs is None:
execute_kwargs = {}
chunks = self.chunk(return_chunk_meta=True, **chunk_kwargs)
tasks = []
keys = []
for _chunk_meta, chunk in chunks:
tasks.append(Task(apply_func, chunk, *args, **kwargs))
keys.append(get_chunk_meta_key(_chunk_meta))
keys = pd.Index(keys, name="chunk_indices")
return execute(tasks, size=len(tasks), keys=keys, **execute_kwargs)
# ############# Iteration ############# #
def get_item_keys(self, group_by: tp.GroupByLike = None) -> tp.Index:
"""Get keys for `Wrapping.items`."""
_self = self.regroup(group_by=group_by)
if _self.group_select and _self.wrapper.grouper.is_grouped():
return _self.wrapper.get_columns()
return _self.wrapper.columns
def items(
self,
group_by: tp.GroupByLike = None,
apply_group_by: bool = False,
keep_2d: bool = False,
key_as_index: bool = False,
wrap: tp.Optional[bool] = None,
) -> tp.Items:
"""Iterate over columns or groups (if grouped and `Wrapping.group_select` is True).
If `apply_group_by` is False, `group_by` becomes a grouping instruction for the iteration only,
not for the yielded objects. In this case, an error is raised if the instance is grouped
and that grouping would have to be changed."""
if wrap is None:
wrap = self.should_wrap()
def _resolve_v(self):
return self if wrap else self.unwrapped
if group_by is None or apply_group_by:
_self = self.regroup(group_by=group_by)
if _self.group_select and _self.wrapper.grouper.is_grouped():
columns = _self.wrapper.get_columns()
ndim = _self.wrapper.get_ndim()
else:
columns = _self.wrapper.columns
ndim = _self.wrapper.ndim
if ndim == 1:
if key_as_index:
yield columns, _resolve_v(_self)
else:
yield columns[0], _resolve_v(_self)
else:
for i in range(len(columns)):
if key_as_index:
key = columns[[i]]
else:
key = columns[i]
if _self.column_only_select:
if keep_2d:
yield key, _resolve_v(_self.iloc[i : i + 1])
else:
yield key, _resolve_v(_self.iloc[i])
else:
if keep_2d:
yield key, _resolve_v(_self.iloc[:, i : i + 1])
else:
yield key, _resolve_v(_self.iloc[:, i])
else:
if self.group_select and self.wrapper.grouper.is_grouped():
raise ValueError("Cannot change grouping")
wrapper = self.wrapper.regroup(group_by=group_by)
if wrapper.get_ndim() == 1:
if key_as_index:
yield wrapper.get_columns(), _resolve_v(self)
else:
yield wrapper.get_columns()[0], _resolve_v(self)
else:
for group, group_idxs in wrapper.grouper.iter_groups(key_as_index=key_as_index):
if self.column_only_select:
if keep_2d or len(group_idxs) > 1:
yield group, _resolve_v(self.iloc[group_idxs])
else:
yield group, _resolve_v(self.iloc[group_idxs[0]])
else:
if keep_2d or len(group_idxs) > 1:
yield group, _resolve_v(self.iloc[:, group_idxs])
else:
yield group, _resolve_v(self.iloc[:, group_idxs[0]])
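# A rough sketch of the iteration API above, using `ArrayWrapper` (defined below) as a
# concrete `HasWrapper`; exact labels and grouping behavior depend on wrapper settings,
# so treat this as an approximation rather than a reference.
import pandas as pd
from vectorbtpro.base.wrapping import ArrayWrapper

wrapper = ArrayWrapper(index=pd.RangeIndex(3), columns=["a", "b"], ndim=2)
columns = dict(wrapper.items())  # {"a": <1-dim wrapper for "a">, "b": <1-dim wrapper for "b">}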
ArrayWrapperT = tp.TypeVar("ArrayWrapperT", bound="ArrayWrapper")
class ArrayWrapper(Configured, HasWrapper, IndexApplier):
"""Class that stores index, columns, and shape metadata for wrapping NumPy arrays.
Tightly integrated with `vectorbtpro.base.grouping.base.Grouper` for grouping columns.
If the underlying object is a Series, pass `[sr.name]` as `columns`.
`**kwargs` are passed to `vectorbtpro.base.grouping.base.Grouper`.
!!! note
This class is meant to be immutable. To change any attribute, use `ArrayWrapper.replace`.
Use methods that begin with `get_` to get group-aware results."""
@classmethod
def from_obj(cls: tp.Type[ArrayWrapperT], obj: tp.ArrayLike, **kwargs) -> ArrayWrapperT:
"""Derive metadata from an object."""
from vectorbtpro.base.reshaping import to_pd_array
from vectorbtpro.data.base import Data
if isinstance(obj, Data):
obj = obj.symbol_wrapper
if isinstance(obj, Wrapping):
obj = obj.wrapper
if isinstance(obj, ArrayWrapper):
return obj.replace(**kwargs)
pd_obj = to_pd_array(obj)
index = indexes.get_index(pd_obj, 0)
columns = indexes.get_index(pd_obj, 1)
ndim = pd_obj.ndim
kwargs.pop("index", None)
kwargs.pop("columns", None)
kwargs.pop("ndim", None)
return cls(index, columns, ndim, **kwargs)
@classmethod
def from_shape(
cls: tp.Type[ArrayWrapperT],
shape: tp.ShapeLike,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
ndim: tp.Optional[int] = None,
*args,
**kwargs,
) -> ArrayWrapperT:
"""Derive metadata from shape."""
shape = reshaping.to_tuple_shape(shape)
if index is None:
index = pd.RangeIndex(stop=shape[0])
if columns is None:
columns = pd.RangeIndex(stop=shape[1] if len(shape) > 1 else 1)
if ndim is None:
ndim = len(shape)
return cls(index, columns, ndim, *args, **kwargs)
@staticmethod
def extract_init_kwargs(**kwargs) -> tp.Tuple[tp.Kwargs, tp.Kwargs]:
"""Extract keyword arguments that can be passed to `ArrayWrapper` or `Grouper`."""
wrapper_arg_names = get_func_arg_names(ArrayWrapper.__init__)
grouper_arg_names = get_func_arg_names(Grouper.__init__)
init_kwargs = dict()
for k in list(kwargs.keys()):
if k in wrapper_arg_names or k in grouper_arg_names:
init_kwargs[k] = kwargs.pop(k)
return init_kwargs, kwargs
@classmethod
def resolve_stack_kwargs(cls, *wrappers: tp.MaybeTuple[ArrayWrapperT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `ArrayWrapper` after stacking."""
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
common_keys = set()
for wrapper in wrappers:
common_keys = common_keys.union(set(wrapper.config.keys()))
if "grouper" not in kwargs:
common_keys = common_keys.union(set(wrapper.grouper.config.keys()))
common_keys.remove("grouper")
init_wrapper = wrappers[0]
for i in range(1, len(wrappers)):
wrapper = wrappers[i]
for k in common_keys:
if k not in kwargs:
same_k = True
try:
if k in wrapper.config:
if not checks.is_deep_equal(init_wrapper.config[k], wrapper.config[k]):
same_k = False
elif "grouper" not in kwargs and k in wrapper.grouper.config:
if not checks.is_deep_equal(init_wrapper.grouper.config[k], wrapper.grouper.config[k]):
same_k = False
else:
same_k = False
except KeyError as e:
same_k = False
if not same_k:
raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.")
for k in common_keys:
if k not in kwargs:
if k in init_wrapper.config:
kwargs[k] = init_wrapper.config[k]
elif "grouper" not in kwargs and k in init_wrapper.grouper.config:
kwargs[k] = init_wrapper.grouper.config[k]
else:
raise ValueError(f"Objects to be merged must have compatible '{k}'. Pass to override.")
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[ArrayWrapperT],
*wrappers: tp.MaybeTuple[ArrayWrapperT],
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
group_by: tp.GroupByLike = None,
stack_columns: bool = True,
index_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
**kwargs,
) -> ArrayWrapperT:
"""Stack multiple `ArrayWrapper` instances along rows.
Concatenates indexes using `vectorbtpro.base.indexes.concat_indexes`.
Frequency must be the same across all indexes. A custom frequency can be provided via `freq`.
If column levels in some instances differ, they will be stacked upon each other.
Custom columns can be provided via `columns`.
If `group_by` is None, all instances must be either grouped or not, and they must
contain the same group values and labels.
All instances must contain the same keys and values in their configs and configs of their
grouper instances, apart from those arguments provided explicitly via `kwargs`."""
if not isinstance(cls_or_self, type):
wrappers = (cls_or_self, *wrappers)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
for wrapper in wrappers:
if not checks.is_instance_of(wrapper, ArrayWrapper):
raise TypeError("Each object to be merged must be an instance of ArrayWrapper")
if index is None:
index = concat_indexes(
[wrapper.index for wrapper in wrappers],
index_concat_method=index_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=0,
)
elif not isinstance(index, pd.Index):
index = pd.Index(index)
kwargs["index"] = index
if freq is None:
new_freq = None
for wrapper in wrappers:
if new_freq is None:
new_freq = wrapper.freq
else:
if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq:
raise ValueError("Objects to be merged must have the same frequency")
freq = new_freq
kwargs["freq"] = freq
if columns is None:
new_columns = None
for wrapper in wrappers:
if new_columns is None:
new_columns = wrapper.columns
else:
if not checks.is_index_equal(new_columns, wrapper.columns):
if not stack_columns:
raise ValueError("Objects to be merged must have the same columns")
new_columns = stack_indexes(
(new_columns, wrapper.columns),
**resolve_dict(clean_index_kwargs),
)
columns = new_columns
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns)
kwargs["columns"] = columns
if "grouper" in kwargs:
if not checks.is_index_equal(columns, kwargs["grouper"].index):
raise ValueError("Columns and grouper index must match")
if group_by is not None:
kwargs["group_by"] = group_by
else:
if group_by is None:
grouped = None
for wrapper in wrappers:
wrapper_grouped = wrapper.grouper.is_grouped()
if grouped is None:
grouped = wrapper_grouped
else:
if grouped is not wrapper_grouped:
raise ValueError("Objects to be merged must be either grouped or not")
if grouped:
new_group_by = None
for wrapper in wrappers:
wrapper_groups, wrapper_grouped_index = wrapper.grouper.get_groups_and_index()
wrapper_group_by = wrapper_grouped_index[wrapper_groups]
if new_group_by is None:
new_group_by = wrapper_group_by
else:
if not checks.is_index_equal(new_group_by, wrapper_group_by):
raise ValueError("Objects to be merged must have the same groups")
group_by = new_group_by
else:
group_by = False
kwargs["group_by"] = group_by
if "ndim" not in kwargs:
ndim = None
for wrapper in wrappers:
if ndim is None or wrapper.ndim > 1:
ndim = wrapper.ndim
kwargs["ndim"] = ndim
return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs))
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[ArrayWrapperT],
*wrappers: tp.MaybeTuple[ArrayWrapperT],
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
group_by: tp.GroupByLike = None,
union_index: bool = True,
col_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = "append",
group_concat_method: tp.MaybeTuple[tp.Union[str, tp.Callable]] = ("append", "factorize_each"),
keys: tp.Optional[tp.IndexLike] = None,
clean_index_kwargs: tp.KwargsLike = None,
verify_integrity: bool = True,
**kwargs,
) -> ArrayWrapperT:
"""Stack multiple `ArrayWrapper` instances along columns.
If the indexes of all wrappers are the same, that index will be used. If the indexes differ and
`union_index` is True, they will be merged into a single one using a set union operation;
otherwise, an error will be raised. The merged index must have no duplicates or mixed data types,
and must be monotonically increasing. A custom index can be provided via `index`.
Frequency must be the same across all indexes. A custom frequency can be provided via `freq`.
Concatenates columns and groups using `vectorbtpro.base.indexes.concat_indexes`.
If any of the instances has `column_only_select` enabled, the final wrapper will also enable it.
If any of the instances has `group_select` or other grouping-related flags disabled, the final
wrapper will also disable them.
All instances must contain the same keys and values in their configs and configs of their
grouper instances, apart from those arguments provided explicitly via `kwargs`."""
if not isinstance(cls_or_self, type):
wrappers = (cls_or_self, *wrappers)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(wrappers) == 1:
wrappers = wrappers[0]
wrappers = list(wrappers)
for wrapper in wrappers:
if not checks.is_instance_of(wrapper, ArrayWrapper):
raise TypeError("Each object to be merged must be an instance of ArrayWrapper")
for wrapper in wrappers:
if wrapper.index.has_duplicates:
raise ValueError("Index of some objects to be merged contains duplicates")
if index is None:
new_index = None
for wrapper in wrappers:
if new_index is None:
new_index = wrapper.index
else:
if not checks.is_index_equal(new_index, wrapper.index):
if not union_index:
raise ValueError(
"Objects to be merged must have the same index. "
"Use union_index=True to merge index as well."
)
else:
if new_index.dtype != wrapper.index.dtype:
raise ValueError("Indexes to be merged must have the same data type")
new_index = new_index.union(wrapper.index)
if not new_index.is_monotonic_increasing:
raise ValueError("Merged index must be monotonically increasing")
index = new_index
elif not isinstance(index, pd.Index):
index = pd.Index(index)
kwargs["index"] = index
if freq is None:
new_freq = None
for wrapper in wrappers:
if new_freq is None:
new_freq = wrapper.freq
else:
if new_freq is not None and wrapper.freq is not None and new_freq != wrapper.freq:
raise ValueError("Objects to be merged must have the same frequency")
freq = new_freq
kwargs["freq"] = freq
if columns is None:
columns = concat_indexes(
[wrapper.columns for wrapper in wrappers],
index_concat_method=col_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=1,
)
elif not isinstance(columns, pd.Index):
columns = pd.Index(columns)
kwargs["columns"] = columns
if "grouper" in kwargs:
if not checks.is_index_equal(columns, kwargs["grouper"].index):
raise ValueError("Columns and grouper index must match")
if group_by is not None:
kwargs["group_by"] = group_by
else:
if group_by is None:
any_grouped = False
for wrapper in wrappers:
if wrapper.grouper.is_grouped():
any_grouped = True
break
if any_grouped:
group_by = concat_indexes(
[wrapper.grouper.get_stretched_index() for wrapper in wrappers],
index_concat_method=group_concat_method,
keys=keys,
clean_index_kwargs=clean_index_kwargs,
verify_integrity=verify_integrity,
axis=2,
)
else:
group_by = False
kwargs["group_by"] = group_by
if "ndim" not in kwargs:
kwargs["ndim"] = 2
if "grouped_ndim" not in kwargs:
kwargs["grouped_ndim"] = None
if "column_only_select" not in kwargs:
column_only_select = None
for wrapper in wrappers:
if column_only_select is None or wrapper.column_only_select:
column_only_select = wrapper.column_only_select
kwargs["column_only_select"] = column_only_select
if "range_only_select" not in kwargs:
range_only_select = None
for wrapper in wrappers:
if range_only_select is None or wrapper.range_only_select:
range_only_select = wrapper.range_only_select
kwargs["range_only_select"] = range_only_select
if "group_select" not in kwargs:
group_select = None
for wrapper in wrappers:
if group_select is None or not wrapper.group_select:
group_select = wrapper.group_select
kwargs["group_select"] = group_select
if "grouper" not in kwargs:
if "allow_enable" not in kwargs:
allow_enable = None
for wrapper in wrappers:
if allow_enable is None or not wrapper.grouper.allow_enable:
allow_enable = wrapper.grouper.allow_enable
kwargs["allow_enable"] = allow_enable
if "allow_disable" not in kwargs:
allow_disable = None
for wrapper in wrappers:
if allow_disable is None or not wrapper.grouper.allow_disable:
allow_disable = wrapper.grouper.allow_disable
kwargs["allow_disable"] = allow_disable
if "allow_modify" not in kwargs:
allow_modify = None
for wrapper in wrappers:
if allow_modify is None or not wrapper.grouper.allow_modify:
allow_modify = wrapper.grouper.allow_modify
kwargs["allow_modify"] = allow_modify
return cls(**ArrayWrapper.resolve_stack_kwargs(*wrappers, **kwargs))
def __init__(
self,
index: tp.IndexLike,
columns: tp.Optional[tp.IndexLike] = None,
ndim: tp.Optional[int] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
parse_index: tp.Optional[bool] = None,
column_only_select: tp.Optional[bool] = None,
range_only_select: tp.Optional[bool] = None,
group_select: tp.Optional[bool] = None,
grouped_ndim: tp.Optional[int] = None,
grouper: tp.Optional[Grouper] = None,
**kwargs,
) -> None:
checks.assert_not_none(index, arg_name="index")
index = dt.prepare_dt_index(index, parse_index=parse_index)
if columns is None:
columns = [None]
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if ndim is None:
if len(columns) == 1 and not isinstance(columns, pd.MultiIndex):
ndim = 1
else:
ndim = 2
else:
if len(columns) > 1:
ndim = 2
grouper_arg_names = get_func_arg_names(Grouper.__init__)
grouper_kwargs = dict()
for k in list(kwargs.keys()):
if k in grouper_arg_names:
grouper_kwargs[k] = kwargs.pop(k)
if grouper is None:
grouper = Grouper(columns, **grouper_kwargs)
elif not checks.is_index_equal(columns, grouper.index) or len(grouper_kwargs) > 0:
grouper = grouper.replace(index=columns, **grouper_kwargs)
HasWrapper.__init__(self)
Configured.__init__(
self,
index=index,
columns=columns,
ndim=ndim,
freq=freq,
parse_index=parse_index,
column_only_select=column_only_select,
range_only_select=range_only_select,
group_select=group_select,
grouped_ndim=grouped_ndim,
grouper=grouper,
**kwargs,
)
self._index = index
self._columns = columns
self._ndim = ndim
self._freq = freq
self._parse_index = parse_index
self._column_only_select = column_only_select
self._range_only_select = range_only_select
self._group_select = group_select
self._grouper = grouper
self._grouped_ndim = grouped_ndim
def indexing_func_meta(
self: ArrayWrapperT,
pd_indexing_func: tp.PandasIndexingFunc,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
column_only_select: tp.Optional[bool] = None,
range_only_select: tp.Optional[bool] = None,
group_select: tp.Optional[bool] = None,
return_slices: bool = True,
return_none_slices: bool = True,
return_scalars: bool = True,
group_by: tp.GroupByLike = None,
wrapper_kwargs: tp.KwargsLike = None,
) -> dict:
"""Perform indexing on `ArrayWrapper` and also return metadata.
Takes into account column grouping.
Flipping rows and columns is not allowed. If one row is selected, the result will still be
a Series when indexing a Series and a DataFrame when indexing a DataFrame.
Set `column_only_select` to True to index the array wrapper as a Series of columns/groups;
this way, selection along the index (axis 0) can be avoided. Set `range_only_select` to True to
allow row selection only via slices. Set `group_select` to True to allow selection of groups;
otherwise, indexing is performed on columns even if grouping is enabled. The latter takes effect
only if grouping is enabled.
Returns the new array wrapper, row indices, column indices, and group indices.
If `return_slices` is True (default), indices will be returned as a slice if they were
identified as a range. If `return_none_slices` is True (default), indices will be returned as a slice
`(None, None, None)` if the axis hasn't been changed.
!!! note
If `column_only_select` is True, make sure to index the array wrapper
as a Series of columns rather than a DataFrame. For example, the operation
`.iloc[:, :2]` should become `.iloc[:2]`. Operations are not allowed if the
object is already a Series and thus has only one column/group."""
if column_only_select is None:
column_only_select = self.column_only_select
if range_only_select is None:
range_only_select = self.range_only_select
if group_select is None:
group_select = self.group_select
if wrapper_kwargs is None:
wrapper_kwargs = {}
_self = self.regroup(group_by)
group_select = group_select and _self.grouper.is_grouped()
if index is None:
index = _self.index
if not isinstance(index, pd.Index):
index = pd.Index(index)
if columns is None:
if group_select:
columns = _self.get_columns()
else:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if group_select:
# Groups as columns
i_wrapper = ArrayWrapper(index, columns, _self.get_ndim())
else:
# Columns as columns
i_wrapper = ArrayWrapper(index, columns, _self.ndim)
n_rows = len(index)
n_cols = len(columns)
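# Convert an index array into a slice (when it forms a contiguous range) or a scalar,
# and report whether the selection actually changes that axis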
def _resolve_arr(arr, n):
if checks.is_np_array(arr) and is_range(arr):
if arr[0] == 0 and arr[-1] == n - 1:
if return_none_slices:
return slice(None, None, None), False
return arr, False
if return_slices:
return slice(arr[0], arr[-1] + 1, None), True
return arr, True
if isinstance(arr, np.integer):
arr = arr.item()
columns_changed = True
if isinstance(arr, int):
if arr == 0 and n == 1:
columns_changed = False
if not return_scalars:
arr = np.array([arr])
return arr, columns_changed
if column_only_select:
if i_wrapper.ndim == 1:
raise IndexingError("Columns only: This instance already contains one column of data")
try:
col_mapper = pd_indexing_func(i_wrapper.wrap_reduced(np.arange(n_cols), columns=columns))
except pd.core.indexing.IndexingError as e:
warn("Columns only: Make sure to treat this instance as a Series of columns rather than a DataFrame")
raise e
if checks.is_series(col_mapper):
new_columns = col_mapper.index
col_idxs = col_mapper.values
new_ndim = 2
else:
new_columns = columns[[col_mapper]]
col_idxs = col_mapper
new_ndim = 1
new_index = index
row_idxs = np.arange(len(index))
else:
init_row_mapper_values = reshaping.broadcast_array_to(np.arange(n_rows)[:, None], (n_rows, n_cols))
init_row_mapper = i_wrapper.wrap(init_row_mapper_values, index=index, columns=columns)
row_mapper = pd_indexing_func(init_row_mapper)
if i_wrapper.ndim == 1:
if not checks.is_series(row_mapper):
row_idxs = np.array([row_mapper])
new_index = index[row_idxs]
else:
row_idxs = row_mapper.values
new_index = indexes.get_index(row_mapper, 0)
col_idxs = 0
new_columns = columns
new_ndim = 1
else:
init_col_mapper_values = reshaping.broadcast_array_to(np.arange(n_cols)[None], (n_rows, n_cols))
init_col_mapper = i_wrapper.wrap(init_col_mapper_values, index=index, columns=columns)
col_mapper = pd_indexing_func(init_col_mapper)
if checks.is_frame(col_mapper):
# Multiple rows and columns selected
row_idxs = row_mapper.values[:, 0]
col_idxs = col_mapper.values[0]
new_index = indexes.get_index(row_mapper, 0)
new_columns = indexes.get_index(col_mapper, 1)
new_ndim = 2
elif checks.is_series(col_mapper):
multi_index = isinstance(index, pd.MultiIndex)
multi_columns = isinstance(columns, pd.MultiIndex)
multi_name = isinstance(col_mapper.name, tuple)
if multi_index and multi_name and col_mapper.name in index:
one_row = True
elif not multi_index and not multi_name and col_mapper.name in index:
one_row = True
else:
one_row = False
if multi_columns and multi_name and col_mapper.name in columns:
one_col = True
elif not multi_columns and not multi_name and col_mapper.name in columns:
one_col = True
else:
one_col = False
if (one_row and one_col) or (not one_row and not one_col):
one_row = np.all(row_mapper.values == row_mapper.values.item(0))
one_col = np.all(col_mapper.values == col_mapper.values.item(0))
if (one_row and one_col) or (not one_row and not one_col):
raise IndexingError("Could not parse indexing operation")
if one_row:
# One row selected
row_idxs = row_mapper.values[[0]]
col_idxs = col_mapper.values
new_index = index[row_idxs]
new_columns = indexes.get_index(col_mapper, 0)
new_ndim = 2
else:
# One column selected
row_idxs = row_mapper.values
col_idxs = col_mapper.values[0]
new_index = indexes.get_index(row_mapper, 0)
new_columns = columns[[col_idxs]]
new_ndim = 1
else:
# One row and column selected
row_idxs = np.array([row_mapper])
col_idxs = col_mapper
new_index = index[row_idxs]
new_columns = columns[[col_idxs]]
new_ndim = 1
if _self.grouper.is_grouped():
# Grouping enabled
if np.asarray(row_idxs).ndim == 0:
raise IndexingError("Flipping index and columns is not allowed")
if group_select:
# Selection based on groups
# Get indices of columns corresponding to selected groups
group_idxs = col_idxs
col_idxs, new_groups = _self.grouper.select_groups(group_idxs)
ungrouped_columns = _self.columns[col_idxs]
if new_ndim == 1 and len(ungrouped_columns) == 1:
ungrouped_ndim = 1
col_idxs = col_idxs[0]
else:
ungrouped_ndim = 2
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=ungrouped_columns,
ndim=ungrouped_ndim,
grouped_ndim=new_ndim,
group_by=new_columns[new_groups],
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=group_idxs,
groups_changed=groups_changed,
)
# Selection based on columns
group_idxs = _self.grouper.get_groups()[col_idxs]
new_group_by = _self.grouper.group_by[reshaping.to_1d_array(col_idxs)]
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
group_idxs, groups_changed = _resolve_arr(group_idxs, _self.get_shape_2d()[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
ndim=new_ndim,
grouped_ndim=None,
group_by=new_group_by,
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=group_idxs,
groups_changed=groups_changed,
)
# Grouping disabled
row_idxs, rows_changed = _resolve_arr(row_idxs, _self.shape[0])
if range_only_select and rows_changed:
if not isinstance(row_idxs, slice):
raise ValueError("Rows can be selected only by slicing")
if row_idxs.step not in (1, None):
raise ValueError("Slice for selecting rows must have a step of 1 or None")
col_idxs, columns_changed = _resolve_arr(col_idxs, _self.shape_2d[1])
return dict(
new_wrapper=_self.replace(
**merge_dicts(
dict(
index=new_index,
columns=new_columns,
ndim=new_ndim,
grouped_ndim=None,
group_by=None,
),
wrapper_kwargs,
)
),
row_idxs=row_idxs,
rows_changed=rows_changed,
col_idxs=col_idxs,
columns_changed=columns_changed,
group_idxs=col_idxs,
groups_changed=columns_changed,
)
def indexing_func(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT:
"""Perform indexing on `ArrayWrapper`."""
return self.indexing_func_meta(*args, **kwargs)["new_wrapper"]
@staticmethod
def select_from_flex_array(
arr: tp.ArrayLike,
row_idxs: tp.Union[int, tp.Array1d, slice] = None,
col_idxs: tp.Union[int, tp.Array1d, slice] = None,
rows_changed: bool = True,
columns_changed: bool = True,
rotate_rows: bool = False,
rotate_cols: bool = True,
) -> tp.Array2d:
"""Select rows and columns from a flexible array.
Always returns a 2-dim NumPy array."""
new_arr = arr_2d = reshaping.to_2d_array(arr)
if row_idxs is not None and rows_changed:
if arr_2d.shape[0] > 1:
if isinstance(row_idxs, slice):
max_idx = row_idxs.stop - 1
else:
row_idxs = reshaping.to_1d_array(row_idxs)
max_idx = np.max(row_idxs)
if arr_2d.shape[0] <= max_idx:
if rotate_rows and not isinstance(row_idxs, slice):
new_arr = new_arr[row_idxs % arr_2d.shape[0], :]
else:
new_arr = new_arr[row_idxs, :]
else:
new_arr = new_arr[row_idxs, :]
if col_idxs is not None and columns_changed:
if arr_2d.shape[1] > 1:
if isinstance(col_idxs, slice):
max_idx = col_idxs.stop - 1
else:
col_idxs = reshaping.to_1d_array(col_idxs)
max_idx = np.max(col_idxs)
if arr_2d.shape[1] <= max_idx:
if rotate_cols and not isinstance(col_idxs, slice):
new_arr = new_arr[:, col_idxs % arr_2d.shape[1]]
else:
new_arr = new_arr[:, col_idxs]
else:
new_arr = new_arr[:, col_idxs]
return new_arr
def get_resampler(self, *args, **kwargs) -> tp.Union[Resampler, tp.PandasResampler]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_resampler`."""
return self.index_acc.get_resampler(*args, **kwargs)
def resample_meta(self: ArrayWrapperT, *args, wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> dict:
"""Perform resampling on `ArrayWrapper` and also return metadata.
`*args` and `**kwargs` are passed to `ArrayWrapper.get_resampler`."""
resampler = self.get_resampler(*args, **kwargs)
if isinstance(resampler, Resampler):
_resampler = resampler
else:
_resampler = Resampler.from_pd_resampler(resampler)
if wrapper_kwargs is None:
wrapper_kwargs = {}
if "index" not in wrapper_kwargs:
wrapper_kwargs["index"] = _resampler.target_index
if "freq" not in wrapper_kwargs:
wrapper_kwargs["freq"] = dt.infer_index_freq(wrapper_kwargs["index"], freq=_resampler.target_freq)
new_wrapper = self.replace(**wrapper_kwargs)
return dict(resampler=resampler, new_wrapper=new_wrapper)
def resample(self: ArrayWrapperT, *args, **kwargs) -> ArrayWrapperT:
"""Perform resampling on `ArrayWrapper`.
Uses `ArrayWrapper.resample_meta`."""
return self.resample_meta(*args, **kwargs)["new_wrapper"]
@property
def wrapper(self) -> "ArrayWrapper":
return self
@property
def index(self) -> tp.Index:
"""Index."""
return self._index
@cached_property(whitelist=True)
def index_acc(self) -> BaseIDXAccessorT:
"""Get index accessor of the type `vectorbtpro.base.accessors.BaseIDXAccessor`."""
from vectorbtpro.base.accessors import BaseIDXAccessor
return BaseIDXAccessor(self.index, freq=self._freq)
@property
def ns_index(self) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.to_ns`."""
return self.index_acc.to_ns()
def get_period_ns_index(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.to_period_ns`."""
return self.index_acc.to_period_ns(*args, **kwargs)
@property
def columns(self) -> tp.Index:
"""Columns."""
return self._columns
def get_columns(self, group_by: tp.GroupByLike = None) -> tp.Index:
"""Get group-aware `ArrayWrapper.columns`."""
return self.resolve(group_by=group_by).columns
@property
def name(self) -> tp.Any:
"""Name."""
if self.ndim == 1:
if self.columns[0] == 0:
return None
return self.columns[0]
return None
def get_name(self, group_by: tp.GroupByLike = None) -> tp.Any:
"""Get group-aware `ArrayWrapper.name`."""
return self.resolve(group_by=group_by).name
@property
def ndim(self) -> int:
"""Number of dimensions."""
return self._ndim
def get_ndim(self, group_by: tp.GroupByLike = None) -> int:
"""Get group-aware `ArrayWrapper.ndim`."""
return self.resolve(group_by=group_by).ndim
@property
def shape(self) -> tp.Shape:
"""Shape."""
if self.ndim == 1:
return (len(self.index),)
return len(self.index), len(self.columns)
def get_shape(self, group_by: tp.GroupByLike = None) -> tp.Shape:
"""Get group-aware `ArrayWrapper.shape`."""
return self.resolve(group_by=group_by).shape
@property
def shape_2d(self) -> tp.Shape:
"""Shape as if the instance was two-dimensional."""
if self.ndim == 1:
return self.shape[0], 1
return self.shape
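# Illustrative sketch, assuming the `from_obj` class method (not shown in this excerpt)
# builds a wrapper from a pandas object:
#
# >>> import pandas as pd
# >>> import vectorbtpro as vbt
# >>> sr_wrapper = vbt.ArrayWrapper.from_obj(pd.Series([1, 2, 3]))
# >>> sr_wrapper.shape     # (3,)
# >>> sr_wrapper.shape_2d  # (3, 1)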
def get_shape_2d(self, group_by: tp.GroupByLike = None) -> tp.Shape:
"""Get group-aware `ArrayWrapper.shape_2d`."""
return self.resolve(group_by=group_by).shape_2d
def get_freq(self, *args, **kwargs) -> tp.Union[None, float, tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_freq`."""
return self.index_acc.get_freq(*args, **kwargs)
@property
def freq(self) -> tp.Optional[tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.freq`."""
return self.index_acc.freq
@property
def ns_freq(self) -> tp.Optional[int]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.ns_freq`."""
return self.index_acc.ns_freq
@property
def any_freq(self) -> tp.Union[None, float, tp.PandasFrequency]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.any_freq`."""
return self.index_acc.any_freq
@property
def periods(self) -> int:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.periods`."""
return self.index_acc.periods
@property
def dt_periods(self) -> float:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.dt_periods`."""
return self.index_acc.dt_periods
def arr_to_timedelta(self, *args, **kwargs) -> tp.Union[pd.Index, tp.MaybeArray]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.arr_to_timedelta`."""
return self.index_acc.arr_to_timedelta(*args, **kwargs)
@property
def parse_index(self) -> tp.Optional[bool]:
"""Whether to try to convert the index into a datetime index.
Applied during the initialization and passed to `vectorbtpro.utils.datetime_.prepare_dt_index`."""
return self._parse_index
@property
def column_only_select(self) -> bool:
"""Whether to perform indexing on columns only."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
column_only_select = self._column_only_select
if column_only_select is None:
column_only_select = wrapping_cfg["column_only_select"]
return column_only_select
@property
def range_only_select(self) -> bool:
"""Whether to allow selecting rows only by slices (ranges)."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
range_only_select = self._range_only_select
if range_only_select is None:
range_only_select = wrapping_cfg["range_only_select"]
return range_only_select
@property
def group_select(self) -> bool:
"""Whether to allow selecting groups when columns are grouped."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
group_select = self._group_select
if group_select is None:
group_select = wrapping_cfg["group_select"]
return group_select
@property
def grouper(self) -> Grouper:
"""Column grouper."""
return self._grouper
@property
def grouped_ndim(self) -> int:
"""Number of dimensions under column grouping."""
if self._grouped_ndim is None:
if self.grouper.is_grouped():
return 2 if self.grouper.get_group_count() > 1 else 1
return self.ndim
return self._grouped_ndim
@cached_method(whitelist=True)
def regroup(self: ArrayWrapperT, group_by: tp.GroupByLike, **kwargs) -> ArrayWrapperT:
"""Regroup this instance.
Only creates a new instance if grouping has changed, otherwise returns itself."""
if self.grouper.is_grouping_changed(group_by=group_by):
self.grouper.check_group_by(group_by=group_by)
grouped_ndim = None
if self.grouper.is_grouped(group_by=group_by):
if not self.grouper.is_group_count_changed(group_by=group_by):
grouped_ndim = self.grouped_ndim
return self.replace(grouped_ndim=grouped_ndim, group_by=group_by, **kwargs)
if len(kwargs) > 0:
return self.replace(**kwargs)
return self # important for keeping cache
def flip(self: ArrayWrapperT, **kwargs) -> ArrayWrapperT:
"""Flip index and columns."""
if "grouper" not in kwargs:
kwargs["grouper"] = None
return self.replace(index=self.columns, columns=self.index, **kwargs)
@cached_method(whitelist=True)
def resolve(self: ArrayWrapperT, group_by: tp.GroupByLike = None, **kwargs) -> ArrayWrapperT:
"""Resolve this instance.
Replaces columns and other metadata with groups."""
_self = self.regroup(group_by=group_by, **kwargs)
if _self.grouper.is_grouped():
return _self.replace(
columns=_self.grouper.get_index(),
ndim=_self.grouped_ndim,
grouped_ndim=None,
group_by=None,
)
return _self # important for keeping cache
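# Illustrative sketch of regrouping and resolving; the group labels below are an
# assumed example input, and the resulting group index name may differ:
#
# >>> import pandas as pd
# >>> import vectorbtpro as vbt
# >>> wrapper = vbt.ArrayWrapper(pd.RangeIndex(3), pd.Index(["a", "b", "c"]))
# >>> grouped = wrapper.regroup(group_by=["g1", "g1", "g2"])
# >>> grouped.get_columns()  # group-aware columns, e.g. Index(['g1', 'g2'], ...)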
def get_index_grouper(self, *args, **kwargs) -> Grouper:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`."""
return self.index_acc.get_grouper(*args, **kwargs)
def wrap(
self,
arr: tp.ArrayLike,
group_by: tp.GroupByLike = None,
index: tp.Optional[tp.IndexLike] = None,
columns: tp.Optional[tp.IndexLike] = None,
zero_to_none: tp.Optional[bool] = None,
force_2d: bool = False,
fillna: tp.Optional[tp.Scalar] = None,
dtype: tp.Optional[tp.PandasDTypeLike] = None,
min_precision: tp.Union[None, int, str] = None,
max_precision: tp.Union[None, int, str] = None,
prec_float_only: tp.Optional[bool] = None,
prec_check_bounds: tp.Optional[bool] = None,
prec_strict: tp.Optional[bool] = None,
to_timedelta: bool = False,
to_index: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SeriesFrame:
"""Wrap a NumPy array using the stored metadata.
Runs the following pipeline:
1) Converts to NumPy array
2) Fills NaN (optional)
3) Wraps using index, columns, and dtype (optional)
4) Converts to index (optional)
5) Converts to timedelta using `ArrayWrapper.arr_to_timedelta` (optional)"""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if zero_to_none is None:
zero_to_none = wrapping_cfg["zero_to_none"]
if min_precision is None:
min_precision = wrapping_cfg["min_precision"]
if max_precision is None:
max_precision = wrapping_cfg["max_precision"]
if prec_float_only is None:
prec_float_only = wrapping_cfg["prec_float_only"]
if prec_check_bounds is None:
prec_check_bounds = wrapping_cfg["prec_check_bounds"]
if prec_strict is None:
prec_strict = wrapping_cfg["prec_strict"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
_self = self.resolve(group_by=group_by)
if index is None:
index = _self.index
if not isinstance(index, pd.Index):
index = pd.Index(index)
if columns is None:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if len(columns) == 1:
name = columns[0]
if zero_to_none and name == 0: # was a Series before
name = None
else:
name = None
def _apply_dtype(obj):
if dtype is None:
return obj
return obj.astype(dtype, errors="ignore")
def _wrap(arr):
orig_arr = arr
arr = np.asarray(arr)
if fillna is not None:
arr[pd.isnull(arr)] = fillna
shape_2d = (arr.shape[0] if arr.ndim > 0 else 1, arr.shape[1] if arr.ndim > 1 else 1)
target_shape_2d = (len(index), len(columns))
if shape_2d != target_shape_2d:
if isinstance(orig_arr, (pd.Series, pd.DataFrame)):
arr = reshaping.align_pd_arrays(orig_arr, to_index=index, to_columns=columns).values
arr = reshaping.broadcast_array_to(arr, target_shape_2d)
arr = reshaping.soft_to_ndim(arr, self.ndim)
if min_precision is not None:
arr = cast_to_min_precision(
arr,
min_precision,
float_only=prec_float_only,
)
if max_precision is not None:
arr = cast_to_max_precision(
arr,
max_precision,
float_only=prec_float_only,
check_bounds=prec_check_bounds,
strict=prec_strict,
)
if arr.ndim == 1:
if force_2d:
return _apply_dtype(pd.DataFrame(arr[:, None], index=index, columns=columns))
return _apply_dtype(pd.Series(arr, index=index, name=name))
if arr.ndim == 2:
if not force_2d and arr.shape[1] == 1 and _self.ndim == 1:
return _apply_dtype(pd.Series(arr[:, 0], index=index, name=name))
return _apply_dtype(pd.DataFrame(arr, index=index, columns=columns))
raise ValueError(f"{arr.ndim}-d input is not supported")
out = _wrap(arr)
if to_index:
# Convert to index
if checks.is_series(out):
out = out.map(lambda x: self.index[x] if x != -1 else np.nan)
else:
out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan)
if to_timedelta:
# Convert to timedelta
out = self.arr_to_timedelta(out, silence_warnings=silence_warnings)
return out
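# Illustrative sketch of the wrapping pipeline above, reusing the constructor
# convention from this module's usage examples; values are arbitrary:
#
# >>> import numpy as np
# >>> import pandas as pd
# >>> import vectorbtpro as vbt
# >>> wrapper = vbt.ArrayWrapper(pd.date_range("2020", periods=3), pd.Index(["a", "b"]))
# >>> wrapper.wrap(np.arange(6).reshape(3, 2))  # DataFrame with the stored index and columns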
def wrap_reduced(
self,
arr: tp.ArrayLike,
group_by: tp.GroupByLike = None,
name_or_index: tp.NameIndex = None,
columns: tp.Optional[tp.IndexLike] = None,
force_1d: bool = False,
fillna: tp.Optional[tp.Scalar] = None,
dtype: tp.Optional[tp.PandasDTypeLike] = None,
to_timedelta: bool = False,
to_index: bool = False,
silence_warnings: tp.Optional[bool] = None,
) -> tp.MaybeSeriesFrame:
"""Wrap result of reduction.
`name_or_index` can be the name of the resulting series if reducing to a scalar per column,
or the index of the resulting series/dataframe if reducing to an array per column.
`columns` can be set to override the object's default columns.
See `ArrayWrapper.wrap` for the pipeline."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
_self = self.resolve(group_by=group_by)
if columns is None:
columns = _self.columns
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
if to_index:
if dtype is None:
dtype = int_
if fillna is None:
fillna = -1
def _apply_dtype(obj):
if dtype is None:
return obj
return obj.astype(dtype, errors="ignore")
def _wrap_reduced(arr):
nonlocal name_or_index
if isinstance(arr, dict):
arr = reshaping.to_pd_array(arr)
if isinstance(arr, pd.Series):
if not checks.is_index_equal(arr.index, columns):
arr = arr.iloc[indexes.align_indexes(arr.index, columns)[0]]
arr = np.asarray(arr)
if force_1d and arr.ndim == 0:
arr = arr[None]
if fillna is not None:
if arr.ndim == 0:
if pd.isnull(arr):
arr = fillna
else:
arr[pd.isnull(arr)] = fillna
if arr.ndim == 0:
# Scalar per Series/DataFrame
return _apply_dtype(pd.Series(arr[None]))[0]
if arr.ndim == 1:
if not force_1d and _self.ndim == 1:
if arr.shape[0] == 1:
# Scalar per Series/DataFrame with one column
return _apply_dtype(pd.Series(arr))[0]
# Array per Series
sr_name = columns[0]
if sr_name == 0:
sr_name = None
if isinstance(name_or_index, str):
name_or_index = None
return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name))
# Scalar per column in DataFrame
if arr.shape[0] == 1 and len(columns) > 1:
arr = reshaping.broadcast_array_to(arr, len(columns))
return _apply_dtype(pd.Series(arr, index=columns, name=name_or_index))
if arr.ndim == 2:
if arr.shape[1] == 1 and _self.ndim == 1:
arr = reshaping.soft_to_ndim(arr, 1)
# Array per Series
sr_name = columns[0]
if sr_name == 0:
sr_name = None
if isinstance(name_or_index, str):
name_or_index = None
return _apply_dtype(pd.Series(arr, index=name_or_index, name=sr_name))
# Array per column in DataFrame
if isinstance(name_or_index, str):
name_or_index = None
if arr.shape[0] == 1 and len(columns) > 1:
arr = reshaping.broadcast_array_to(arr, (arr.shape[0], len(columns)))
return _apply_dtype(pd.DataFrame(arr, index=name_or_index, columns=columns))
raise ValueError(f"{arr.ndim}-d input is not supported")
out = _wrap_reduced(arr)
if to_index:
# Convert to index
if checks.is_series(out):
out = out.map(lambda x: self.index[x] if x != -1 else np.nan)
elif checks.is_frame(out):
out = out.applymap(lambda x: self.index[x] if x != -1 else np.nan)
else:
out = self.index[out] if out != -1 else np.nan
if to_timedelta:
# Convert to timedelta
out = self.arr_to_timedelta(out, silence_warnings=silence_warnings)
return out
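# Illustrative sketch: reducing to one scalar per column yields a Series indexed
# by the wrapper's columns, with `name_or_index` as its name (same assumed wrapper
# convention as in the comment above):
#
# >>> wrapper.wrap_reduced(np.array([1.0, 2.0]), name_or_index="sum")
# a    1.0
# b    2.0
# Name: sum, dtype: float64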
def concat_arrs(
self,
*objs: tp.ArrayLike,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray1d:
"""Stack reduced objects along columns and wrap the final object."""
from vectorbtpro.base.merging import concat_arrays
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
new_objs.append(reshaping.to_1d_array(obj))
stacked_obj = concat_arrays(new_objs)
if wrap:
_self = self.resolve(group_by=group_by)
return _self.wrap_reduced(stacked_obj, **kwargs)
return stacked_obj
def row_stack_arrs(
self,
*objs: tp.ArrayLike,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray:
"""Stack objects along rows and wrap the final object."""
from vectorbtpro.base.merging import row_stack_arrays
_self = self.resolve(group_by=group_by)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
obj = reshaping.to_2d_array(obj)
if obj.shape[1] != _self.shape_2d[1]:
if obj.shape[1] != 1:
raise ValueError(f"Cannot broadcast {obj.shape[1]} to {_self.shape_2d[1]} columns")
obj = np.repeat(obj, _self.shape_2d[1], axis=1)
new_objs.append(obj)
stacked_obj = row_stack_arrays(new_objs)
if wrap:
return _self.wrap(stacked_obj, **kwargs)
return stacked_obj
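# Illustrative sketch (same assumed wrapper as above): stacking two (3, 2) arrays
# along rows yields a (6, 2) array; with `wrap=True` it would also be wrapped:
#
# >>> wrapper.row_stack_arrs(np.ones((3, 2)), np.zeros((3, 2)), wrap=False).shape  # (6, 2)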
def column_stack_arrs(
self,
*objs: tp.ArrayLike,
reindex_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
wrap: bool = True,
**kwargs,
) -> tp.AnyArray2d:
"""Stack objects along columns and wrap the final object.
`reindex_kwargs` will be passed to
[pandas.DataFrame.reindex](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.reindex.html)."""
from vectorbtpro.base.merging import column_stack_arrays
_self = self.resolve(group_by=group_by)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
new_objs = []
for obj in objs:
if not checks.is_index_equal(obj.index, _self.index, check_names=False):
was_bool = (isinstance(obj, pd.Series) and obj.dtype == "bool") or (
isinstance(obj, pd.DataFrame) and (obj.dtypes == "bool").all()
)
obj = obj.reindex(_self.index, **resolve_dict(reindex_kwargs))
is_object = (isinstance(obj, pd.Series) and obj.dtype == "object") or (
isinstance(obj, pd.DataFrame) and (obj.dtypes == "object").all()
)
if was_bool and is_object:
obj = obj.astype(None)
new_objs.append(reshaping.to_2d_array(obj))
stacked_obj = column_stack_arrays(new_objs)
if wrap:
return _self.wrap(stacked_obj, **kwargs)
return stacked_obj
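# Illustrative sketch (same assumed wrapper as above): pandas objects aligned to the
# wrapper's index are stacked along columns; wrapping is skipped with `wrap=False`:
#
# >>> sr = pd.Series(np.ones(3), index=wrapper.index)
# >>> wrapper.column_stack_arrs(sr, sr, wrap=False).shape  # (3, 2)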
def dummy(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Create a dummy Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap(np.empty(_self.shape), **kwargs)
def fill(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Fill a Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap(np.full(_self.shape_2d, fill_value), **kwargs)
def fill_reduced(self, fill_value: tp.Scalar = np.nan, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Fill a reduced Series/DataFrame."""
_self = self.resolve(group_by=group_by)
return _self.wrap_reduced(np.full(_self.shape_2d[1], fill_value), **kwargs)
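# Illustrative sketch (same assumed wrapper as above): `fill` produces a full-shaped
# object, `fill_reduced` one value per column:
#
# >>> wrapper.fill(0.0)          # 3x2 DataFrame of zeros
# >>> wrapper.fill_reduced(0.0)  # Series of zeros indexed by ['a', 'b']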
def apply_to_index(
self: ArrayWrapperT,
apply_func: tp.Callable,
*args,
axis: tp.Optional[int] = None,
**kwargs,
) -> ArrayWrapperT:
"""Apply a function to the index (axis 0) or columns (axis 1) and return a new instance."""
if axis is None:
axis = 0 if self.ndim == 1 else 1
if self.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
if axis == 1:
return self.replace(columns=apply_func(self.columns, *args, **kwargs))
return self.replace(index=apply_func(self.index, *args, **kwargs))
def get_index_points(self, *args, **kwargs) -> tp.Array1d:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_points`."""
return self.index_acc.get_points(*args, **kwargs)
def get_index_ranges(self, *args, **kwargs) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""See `vectorbtpro.base.accessors.BaseIDXAccessor.get_ranges`."""
return self.index_acc.get_ranges(*args, **kwargs)
def fill_and_set(
self,
idx_setter: tp.Union[index_dict, IdxSetter, IdxSetterFactory],
keep_flex: bool = False,
fill_value: tp.Scalar = np.nan,
**kwargs,
) -> tp.AnyArray:
"""Fill a new array using an index object such as `vectorbtpro.base.indexing.index_dict`.
Will be wrapped with `vectorbtpro.base.indexing.IdxSetter` if not already.
Will call `vectorbtpro.base.indexing.IdxSetter.fill_and_set`.
Usage:
* Set a single row:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", periods=5)
>>> columns = pd.Index(["a", "b", "c"])
>>> wrapper = vbt.ArrayWrapper(index, columns)
>>> wrapper.fill_and_set(vbt.index_dict({
... 1: 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... "2020-01-02": 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... "2020-01-02": [1, 2, 3]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
```
* Set multiple rows:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... (1, 3): [2, 3]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 3.0 3.0 3.0
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ("2020-01-02", "2020-01-04"): [[1, 2, 3], [4, 5, 6]]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 4.0 5.0 6.0
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ("2020-01-02", "2020-01-04"): [[1, 2, 3]]
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 1.0 2.0 3.0
2020-01-05 NaN NaN NaN
```
* Set rows using slices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.hslice(1, 3): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.hslice("2020-01-02", "2020-01-04"): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... ((0, 2), (3, 5)): [[1], [2]]
... }))
a b c
2020-01-01 1.0 1.0 1.0
2020-01-02 1.0 1.0 1.0
2020-01-03 NaN NaN NaN
2020-01-04 2.0 2.0 2.0
2020-01-05 2.0 2.0 2.0
>>> wrapper.fill_and_set(vbt.index_dict({
... ((0, 2), (3, 5)): [[1, 2, 3], [4, 5, 6]]
... }))
a b c
2020-01-01 1.0 2.0 3.0
2020-01-02 1.0 2.0 3.0
2020-01-03 NaN NaN NaN
2020-01-04 4.0 5.0 6.0
2020-01-05 4.0 5.0 6.0
```
* Set rows using index points:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.pointidx(every="2D"): 2
... }))
a b c
2020-01-01 2.0 2.0 2.0
2020-01-02 NaN NaN NaN
2020-01-03 2.0 2.0 2.0
2020-01-04 NaN NaN NaN
2020-01-05 2.0 2.0 2.0
```
* Set rows using index ranges:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.rangeidx(
... start=("2020-01-01", "2020-01-03"),
... end=("2020-01-02", "2020-01-05")
... ): 2
... }))
a b c
2020-01-01 2.0 2.0 2.0
2020-01-02 NaN NaN NaN
2020-01-03 2.0 2.0 2.0
2020-01-04 2.0 2.0 2.0
2020-01-05 NaN NaN NaN
```
* Set column indices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx("a"): 2
... }))
a b c
2020-01-01 2.0 NaN NaN
2020-01-02 2.0 NaN NaN
2020-01-03 2.0 NaN NaN
2020-01-04 2.0 NaN NaN
2020-01-05 2.0 NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx(("a", "b")): [1, 2]
... }))
a b c
2020-01-01 1.0 2.0 NaN
2020-01-02 1.0 2.0 NaN
2020-01-03 1.0 2.0 NaN
2020-01-04 1.0 2.0 NaN
2020-01-05 1.0 2.0 NaN
>>> multi_columns = pd.MultiIndex.from_arrays(
... [["a", "a", "b", "b"], [1, 2, 1, 2]],
... names=["c1", "c2"]
... )
>>> multi_wrapper = vbt.ArrayWrapper(index, multi_columns)
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx(("a", 2)): 2
... }))
c1 a b
c2 1 2 1 2
2020-01-01 NaN 2.0 NaN NaN
2020-01-02 NaN 2.0 NaN NaN
2020-01-03 NaN 2.0 NaN NaN
2020-01-04 NaN 2.0 NaN NaN
2020-01-05 NaN 2.0 NaN NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.colidx("b", level="c1"): [3, 4]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 NaN NaN 3.0 4.0
2020-01-02 NaN NaN 3.0 4.0
2020-01-03 NaN NaN 3.0 4.0
2020-01-04 NaN NaN 3.0 4.0
2020-01-05 NaN NaN 3.0 4.0
```
* Set row and column indices:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(2, 2): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 NaN NaN NaN
2020-01-03 NaN NaN 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(("2020-01-01", "2020-01-03"), 2): [1, 2]
... }))
a b c
2020-01-01 NaN NaN 1.0
2020-01-02 NaN NaN NaN
2020-01-03 NaN NaN 2.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(("2020-01-01", "2020-01-03"), (0, 2)): [[1, 2], [3, 4]]
... }))
a b c
2020-01-01 1.0 NaN 2.0
2020-01-02 NaN NaN NaN
2020-01-03 3.0 NaN 4.0
2020-01-04 NaN NaN NaN
2020-01-05 NaN NaN NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(
... vbt.pointidx(every="2d"),
... vbt.colidx(1, level="c2")
... ): [[1, 2]]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 1.0 NaN 2.0 NaN
2020-01-02 NaN NaN NaN NaN
2020-01-03 1.0 NaN 2.0 NaN
2020-01-04 NaN NaN NaN NaN
2020-01-05 1.0 NaN 2.0 NaN
>>> multi_wrapper.fill_and_set(vbt.index_dict({
... vbt.idx(
... vbt.pointidx(every="2d"),
... vbt.colidx(1, level="c2")
... ): [[1], [2], [3]]
... }))
c1 a b
c2 1 2 1 2
2020-01-01 1.0 NaN 1.0 NaN
2020-01-02 NaN NaN NaN NaN
2020-01-03 2.0 NaN 2.0 NaN
2020-01-04 NaN NaN NaN NaN
2020-01-05 3.0 NaN 3.0 NaN
```
* Set rows using a template:
```pycon
>>> wrapper.fill_and_set(vbt.index_dict({
... vbt.RepEval("index.day % 2 == 0"): 2
... }))
a b c
2020-01-01 NaN NaN NaN
2020-01-02 2.0 2.0 2.0
2020-01-03 NaN NaN NaN
2020-01-04 2.0 2.0 2.0
2020-01-05 NaN NaN NaN
```
"""
if isinstance(idx_setter, index_dict):
idx_setter = IdxDict(idx_setter)
if isinstance(idx_setter, IdxSetterFactory):
idx_setter = idx_setter.get()
if not isinstance(idx_setter, IdxSetter):
raise ValueError("Index setter factory must return exactly one index setter")
checks.assert_instance_of(idx_setter, IdxSetter)
arr = idx_setter.fill_and_set(
self.shape,
keep_flex=keep_flex,
fill_value=fill_value,
index=self.index,
columns=self.columns,
freq=self.freq,
**kwargs,
)
if not keep_flex:
return self.wrap(arr, group_by=False)
return arr
WrappingT = tp.TypeVar("WrappingT", bound="Wrapping")
class Wrapping(Configured, HasWrapper, IndexApplier, AttrResolverMixin):
"""Class that uses `ArrayWrapper` globally."""
@classmethod
def resolve_row_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking along rows."""
return kwargs
@classmethod
def resolve_column_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking along columns."""
return kwargs
@classmethod
def resolve_stack_kwargs(cls, *wrappings: tp.MaybeTuple[WrappingT], **kwargs) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Wrapping` after stacking.
Should be called after `Wrapping.resolve_row_stack_kwargs` or `Wrapping.resolve_column_stack_kwargs`."""
return cls.resolve_merge_kwargs(*[wrapping.config for wrapping in wrappings], **kwargs)
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[WrappingT],
*objs: tp.MaybeTuple[WrappingT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> WrappingT:
"""Stack multiple `Wrapping` instances along rows.
Should use `ArrayWrapper.row_stack`."""
raise NotImplementedError
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[WrappingT],
*objs: tp.MaybeTuple[WrappingT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> WrappingT:
"""Stack multiple `Wrapping` instances along columns.
Should use `ArrayWrapper.column_stack`."""
raise NotImplementedError
def __init__(self, wrapper: ArrayWrapper, **kwargs) -> None:
checks.assert_instance_of(wrapper, ArrayWrapper)
self._wrapper = wrapper
Configured.__init__(self, wrapper=wrapper, **kwargs)
HasWrapper.__init__(self)
AttrResolverMixin.__init__(self)
def indexing_func(self: WrappingT, *args, **kwargs) -> WrappingT:
"""Perform indexing on `Wrapping`."""
new_wrapper = self.wrapper.indexing_func(
*args,
column_only_select=self.column_only_select,
range_only_select=self.range_only_select,
group_select=self.group_select,
**kwargs,
)
return self.replace(wrapper=new_wrapper)
def resample(self: WrappingT, *args, **kwargs) -> WrappingT:
"""Perform resampling on `Wrapping`.
When overriding, make sure to create a resampler by passing `*args` and `**kwargs`
to `ArrayWrapper.get_resampler`."""
raise NotImplementedError
@property
def wrapper(self) -> ArrayWrapper:
return self._wrapper
def apply_to_index(
self: WrappingT,
apply_func: tp.Callable,
*args,
axis: tp.Optional[int] = None,
**kwargs,
) -> WrappingT:
"""Apply a function to the wrapper's index (axis 0) or columns (axis 1) and return a new instance."""
if axis is None:
axis = 0 if self.wrapper.ndim == 1 else 1
if self.wrapper.ndim == 1 and axis == 1:
raise TypeError("Axis 1 is not supported for one dimension")
checks.assert_in(axis, (0, 1))
if axis == 1:
new_wrapper = self.wrapper.replace(columns=apply_func(self.wrapper.columns, *args, **kwargs))
else:
new_wrapper = self.wrapper.replace(index=apply_func(self.wrapper.index, *args, **kwargs))
return self.replace(wrapper=new_wrapper)
@property
def column_only_select(self) -> bool:
column_only_select = getattr(self, "_column_only_select", None)
if column_only_select is None:
return self.wrapper.column_only_select
return column_only_select
@property
def range_only_select(self) -> bool:
range_only_select = getattr(self, "_range_only_select", None)
if range_only_select is None:
return self.wrapper.range_only_select
return range_only_select
@property
def group_select(self) -> bool:
group_select = getattr(self, "_group_select", None)
if group_select is None:
return self.wrapper.group_select
return group_select
def regroup(self: WrappingT, group_by: tp.GroupByLike, **kwargs) -> WrappingT:
"""Regroup this instance.
Only creates a new instance if grouping has changed, otherwise returns itself.
`**kwargs` will be passed to `ArrayWrapper.regroup`."""
if self.wrapper.grouper.is_grouping_changed(group_by=group_by):
self.wrapper.grouper.check_group_by(group_by=group_by)
return self.replace(wrapper=self.wrapper.regroup(group_by, **kwargs))
return self # important for keeping cache
def resolve_self(
self: AttrResolverMixinT,
cond_kwargs: tp.KwargsLike = None,
custom_arg_names: tp.Optional[tp.Set[str]] = None,
impacts_caching: bool = True,
silence_warnings: tp.Optional[bool] = None,
) -> AttrResolverMixinT:
"""Resolve self.
Creates a copy of this instance if a different `freq` can be found in `cond_kwargs`."""
from vectorbtpro._settings import settings
wrapping_cfg = settings["wrapping"]
if cond_kwargs is None:
cond_kwargs = {}
if custom_arg_names is None:
custom_arg_names = set()
if silence_warnings is None:
silence_warnings = wrapping_cfg["silence_warnings"]
if "freq" in cond_kwargs:
wrapper_copy = self.wrapper.replace(freq=cond_kwargs["freq"])
if wrapper_copy.freq != self.wrapper.freq:
if not silence_warnings:
warn(
f"Changing the frequency will create a copy of this instance. "
f"Consider setting it upon instantiation to re-use existing cache."
)
self_copy = self.replace(wrapper=wrapper_copy)
for alias in self.self_aliases:
if alias not in custom_arg_names:
cond_kwargs[alias] = self_copy
cond_kwargs["freq"] = self_copy.wrapper.freq
if impacts_caching:
cond_kwargs["use_caching"] = False
return self_copy
return self
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with custom data classes."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.data.custom.alpaca import *
from vectorbtpro.data.custom.av import *
from vectorbtpro.data.custom.bento import *
from vectorbtpro.data.custom.binance import *
from vectorbtpro.data.custom.ccxt import *
from vectorbtpro.data.custom.csv import *
from vectorbtpro.data.custom.custom import *
from vectorbtpro.data.custom.db import *
from vectorbtpro.data.custom.duckdb import *
from vectorbtpro.data.custom.feather import *
from vectorbtpro.data.custom.file import *
from vectorbtpro.data.custom.finpy import *
from vectorbtpro.data.custom.gbm import *
from vectorbtpro.data.custom.gbm_ohlc import *
from vectorbtpro.data.custom.hdf import *
from vectorbtpro.data.custom.local import *
from vectorbtpro.data.custom.ndl import *
from vectorbtpro.data.custom.parquet import *
from vectorbtpro.data.custom.polygon import *
from vectorbtpro.data.custom.random import *
from vectorbtpro.data.custom.random_ohlc import *
from vectorbtpro.data.custom.remote import *
from vectorbtpro.data.custom.sql import *
from vectorbtpro.data.custom.synthetic import *
from vectorbtpro.data.custom.tv import *
from vectorbtpro.data.custom.yf import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `AlpacaData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
try:
if not tp.TYPE_CHECKING:
raise ImportError
from alpaca.common.rest import RESTClient as AlpacaClientT
except ImportError:
AlpacaClientT = "AlpacaClient"
__all__ = [
"AlpacaData",
]
AlpacaDataT = tp.TypeVar("AlpacaDataT", bound="AlpacaData")
class AlpacaData(RemoteData):
"""Data class for fetching from Alpaca.
See https://github.com/alpacahq/alpaca-py for API.
See `AlpacaData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional for crypto):
```pycon
>>> from vectorbtpro import *
>>> vbt.AlpacaData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY",
... secret_key="YOUR_SECRET"
... )
... )
```
* Pull stock data:
```pycon
>>> data = vbt.AlpacaData.pull(
... "AAPL",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
* Pull crypto data:
```pycon
>>> data = vbt.AlpacaData.pull(
... "BTC/USD",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.alpaca")
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
status: tp.Optional[str] = None,
asset_class: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
trading_client: tp.Optional[AlpacaClientT] = None,
client_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.
Arguments `status`, `asset_class`, and `exchange` can be strings, such as `asset_class="crypto"`.
For possible values, take a look into `alpaca.trading.enums`.
!!! note
If you get an authorization error, make sure that you either enable or disable
the `paper` flag in `client_config` depending upon the account whose credentials you used.
By default, the credentials are assumed to be of a live trading account (`paper=False`)."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.trading.client import TradingClient
from alpaca.trading.requests import GetAssetsRequest
from alpaca.trading.enums import AssetStatus, AssetClass, AssetExchange
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if trading_client is None:
arg_names = get_func_arg_names(TradingClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
trading_client = TradingClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
if status is not None:
if isinstance(status, str):
status = getattr(AssetStatus, status.upper())
if asset_class is not None:
if isinstance(asset_class, str):
asset_class = getattr(AssetClass, asset_class.upper())
if exchange is not None:
if isinstance(exchange, str):
exchange = getattr(AssetExchange, exchange.upper())
search_params = GetAssetsRequest(status=status, asset_class=asset_class, exchange=exchange)
assets = trading_client.get_all_assets(search_params)
all_symbols = []
for asset in assets:
symbol = asset.symbol
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
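# Illustrative sketch (assumes alpaca-py is installed and valid credentials are
# configured; the returned list depends on the account):
#
# >>> vbt.AlpacaData.list_symbols(asset_class="crypto")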
@classmethod
def resolve_client(
cls,
client: tp.Optional[AlpacaClientT] = None,
client_type: tp.Optional[str] = None,
**client_config,
) -> AlpacaClientT:
"""Resolve the client.
If provided, must be of the type `alpaca.data.historical.CryptoHistoricalDataClient`
for `client_type="crypto"` and `alpaca.data.historical.StockHistoricalDataClient` for
`client_type="stocks"`. Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient
client = cls.resolve_custom_setting(client, "client")
client_type = cls.resolve_custom_setting(client_type, "client_type")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
if client_type == "crypto":
arg_names = get_func_arg_names(CryptoHistoricalDataClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
client = CryptoHistoricalDataClient(**client_config)
elif client_type == "stocks":
arg_names = get_func_arg_names(StockHistoricalDataClient.__init__)
client_config = {k: v for k, v in client_config.items() if k in arg_names}
client = StockHistoricalDataClient(**client_config)
else:
raise ValueError(f"Invalid client type: '{client_type}'")
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[AlpacaClientT] = None,
client_type: tp.Optional[str] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjustment: tp.Optional[str] = None,
feed: tp.Optional[str] = None,
limit: tp.Optional[int] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Alpaca.
Args:
symbol (str): Symbol.
client (alpaca.common.rest.RESTClient): Client.
See `AlpacaData.resolve_client`.
client_type (str): Client type.
See `AlpacaData.resolve_client`.
Determined automatically based on the symbol. Crypto symbols contain "/".
client_config (dict): Client config.
See `AlpacaData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjustment (str): Specifies the corporate action adjustment for the returned bars.
Options are: "raw", "split", "dividend" or "all". Default is "raw".
feed (str): The feed to pull market data from.
This is either "iex", "otc", or "sip". Feeds "sip" and "otc" are only available to
those with a subscription. Default is "iex" for free plans and "sip" for paid.
limit (int): The maximum number of returned items.
For defaults, see `custom.alpaca` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpaca")
from alpaca.data.historical import CryptoHistoricalDataClient, StockHistoricalDataClient
from alpaca.data.requests import CryptoBarsRequest, StockBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit
if client_type is None:
client_type = "crypto" if "/" in symbol else "stocks"
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, client_type=client_type, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
feed = cls.resolve_custom_setting(feed, "feed")
limit = cls.resolve_custom_setting(limit, "limit")
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
unit = TimeFrameUnit.Minute
elif unit == "h":
unit = TimeFrameUnit.Hour
elif unit == "D":
unit = TimeFrameUnit.Day
elif unit == "W":
unit = TimeFrameUnit.Week
elif unit == "M":
unit = TimeFrameUnit.Month
else:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
timeframe = TimeFrame(multiplier, unit)
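# For example, a human-readable "15 minutes" is expected to be split into a
# (15, "m") pair by `dt.split_freq_str` and mapped to TimeFrame(15, TimeFrameUnit.Minute)
# above (an illustration of this mapping, not an additional code path).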
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start_str = start.replace(tzinfo=None).isoformat("T")
else:
start_str = None
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end_str = end.replace(tzinfo=None).isoformat("T")
else:
end_str = None
if isinstance(client, CryptoHistoricalDataClient):
request = CryptoBarsRequest(
symbol_or_symbols=symbol,
timeframe=timeframe,
start=start_str,
end=end_str,
limit=limit,
)
df = client.get_crypto_bars(request).df
elif isinstance(client, StockHistoricalDataClient):
request = StockBarsRequest(
symbol_or_symbols=symbol,
timeframe=timeframe,
start=start_str,
end=end_str,
limit=limit,
adjustment=adjustment,
feed=feed,
)
df = client.get_stock_bars(request).df
else:
raise TypeError(f"Invalid client of type {type(client)}")
df = df.droplevel("symbol", axis=0)
df.index = df.index.rename("Open time")
df.rename(
columns={
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"volume": "Volume",
"trade_count": "Trade count",
"vwap": "VWAP",
},
inplace=True,
)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
if "Trade count" in df.columns:
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
if "VWAP" in df.columns:
df["VWAP"] = df["VWAP"].astype(float)
if not df.empty:
if start is not None:
start = dt.to_timestamp(start, tz=df.index.tz)
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
end = dt.to_timestamp(end, tz=df.index.tz)
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `AVData`."""
import re
import urllib.parse
from functools import lru_cache
import numpy as np
import pandas as pd
import requests
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.module_ import check_installed
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from alpha_vantage.alphavantage import AlphaVantage as AlphaVantageT
except ImportError:
AlphaVantageT = "AlphaVantage"
__all__ = [
"AVData",
]
__pdoc__ = {}
AVDataT = tp.TypeVar("AVDataT", bound="AVData")
class AVData(RemoteData):
"""Data class for fetching from Alpha Vantage.
See https://www.alphavantage.co/documentation/ for API.
Apart from using the https://github.com/RomelTorres/alpha_vantage package, this class can also
parse the API documentation with `AVData.parse_api_meta` using `BeautifulSoup4` and build
the API query based on this metadata (pass `use_parser=True`).
This approach is the most flexible since it can instantly react to changes in Alpha Vantage's API.
If the data provider changes its API documentation, you can always adapt the parsing
procedure by overriding `AVData.parse_api_meta`.
If the parser still fails, you can disable parsing entirely and specify all information manually
by setting `function` and disabling `match_params`.
See `AVData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.AVData.set_custom_settings(
... apikey="YOUR_KEY"
... )
```
* Pull data:
```pycon
>>> data = vbt.AVData.pull(
... "GOOGL",
... timeframe="1 day",
... )
>>> data = vbt.AVData.pull(
... "BTC_USD",
... timeframe="30 minutes", # premium?
... category="digital-currency",
... outputsize="full"
... )
>>> data = vbt.AVData.pull(
... "REAL_GDP",
... category="economic-indicators"
... )
>>> data = vbt.AVData.pull(
... "IBM",
... category="technical-indicators",
... function="STOCHRSI",
... params=dict(fastkperiod=14)
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.av")
@classmethod
def list_symbols(cls, keywords: str, apikey: tp.Optional[str] = None, sort: bool = True) -> tp.List[str]:
"""List all symbols."""
apikey = cls.resolve_custom_setting(apikey, "apikey")
query = dict()
query["function"] = "SYMBOL_SEARCH"
query["keywords"] = keywords
query["datatype"] = "csv"
query["apikey"] = apikey
url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(query)
df = pd.read_csv(url)
if sort:
return sorted(dict.fromkeys(df["symbol"].tolist()))
return list(dict.fromkeys(df["symbol"].tolist()))
@classmethod
@lru_cache()
def parse_api_meta(cls) -> dict:
"""Parse API metadata from the documentation at https://www.alphavantage.co/documentation
Cached class method. To avoid re-parsing the same metadata in different runtimes, save it manually."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("bs4")
from bs4 import BeautifulSoup
page = requests.get("https://www.alphavantage.co/documentation")
soup = BeautifulSoup(page.content, "html.parser")
api_meta = {}
for section in soup.select("article section"):
category = {}
function = None
function_args = dict(req_args=set(), opt_args=set())
for tag in section.find_all(True):
if tag.name == "h6":
if function is not None and tag.select("b")[0].getText().strip() == "API Parameters":
category[function] = function_args
function = None
function_args = dict(req_args=set(), opt_args=set())
if tag.name == "b":
b_text = tag.getText().strip()
if b_text.startswith("❚ Required"):
arg = tag.select("code")[0].getText().strip()
function_args["req_args"].add(arg)
if tag.name == "p":
p_text = tag.getText().strip()
if p_text.startswith("❚ Optional"):
arg = tag.select("code")[0].getText().strip()
function_args["opt_args"].add(arg)
if tag.name == "code":
code_text = tag.getText().strip()
if code_text.startswith("function="):
function = code_text.replace("function=", "")
if function is not None:
category[function] = function_args
api_meta[section.select("h2")[0]["id"]] = category
return api_meta
@classmethod
def fetch_symbol(
cls,
symbol: str,
use_parser: tp.Optional[bool] = None,
apikey: tp.Optional[str] = None,
api_meta: tp.Optional[dict] = None,
category: tp.Union[None, str, AlphaVantageT, tp.Type[AlphaVantageT]] = None,
function: tp.Union[None, str, tp.Callable] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjusted: tp.Optional[bool] = None,
extended: tp.Optional[bool] = None,
slice: tp.Optional[str] = None,
series_type: tp.Optional[str] = None,
time_period: tp.Optional[int] = None,
outputsize: tp.Optional[str] = None,
match_params: tp.Optional[bool] = None,
params: tp.KwargsLike = None,
read_csv_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SymbolData:
"""Fetch a symbol from Alpha Vantage.
If `use_parser` is False (or None while `alpha_vantage` is installed), uses the `alpha_vantage` package.
Otherwise, parses the API documentation and pulls the data directly.
See https://www.alphavantage.co/documentation/ for API endpoints and their parameters.
!!! note
Supports the CSV format only.
Args:
symbol (str): Symbol.
May combine symbol/from_currency and market/to_currency using an underscore.
use_parser (bool): Whether to use the parser instead of the `alpha_vantage` package.
apikey (str): API key.
api_meta (dict): API meta.
If None, will use `AVData.parse_api_meta` if `function` is not provided
or `match_params` is True.
category (str or AlphaVantage): API category of your choice.
Used if `function` is not provided or `match_params` is True.
Supported are:
* `alpha_vantage.alphavantage.AlphaVantage` instance, class, or class name
* "time-series-data" or "time-series"
* "fundamental-data" or "fundamentals"
* "foreign-exchange", "forex", or "fx"
* "digital-currency", "cryptocurrencies", "cryptocurrency", or "crypto"
* "commodities"
* "economic-indicators"
* "technical-indicators" or "indicators"
function (str or callable): API function of your choice.
If None, will try to resolve it based on other arguments, such as `timeframe`,
`adjusted`, and `extended`. Required for technical indicators, economic indicators,
and fundamental data.
See the keys in sub-dictionaries returned by `AVData.parse_api_meta`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
For time series, forex, and crypto, looks for interval type in the function's name.
Defaults to "60min" if extended, otherwise to "daily".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjusted (bool): Whether to return time series adjusted by historical split and dividend events.
extended (bool): Whether to return historical intraday time series for the trailing 2 years.
slice (str): Slice of the trailing 2 years.
series_type (str): The desired price type in the time series.
time_period (int): Number of data points used to calculate each window value.
outputsize (str): Output size.
Supported are
* "compact" that returns only the latest 100 data points
* "full" that returns the full-length time series
match_params (bool): Whether to match parameters with the ones required by the endpoint.
Otherwise, uses only (resolved) `function`, `apikey`, `datatype="csv"`, and `params`.
params (dict): Additional keyword arguments passed as key/value pairs in the URL.
read_csv_kwargs (dict): Keyword arguments passed to `pd.read_csv`.
silence_warnings (bool): Whether to silence all warnings.
For defaults, see `custom.av` in `vectorbtpro._settings.data`.
"""
use_parser = cls.resolve_custom_setting(use_parser, "use_parser")
apikey = cls.resolve_custom_setting(apikey, "apikey")
api_meta = cls.resolve_custom_setting(api_meta, "api_meta")
category = cls.resolve_custom_setting(category, "category")
function = cls.resolve_custom_setting(function, "function")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjusted = cls.resolve_custom_setting(adjusted, "adjusted")
extended = cls.resolve_custom_setting(extended, "extended")
slice = cls.resolve_custom_setting(slice, "slice")
series_type = cls.resolve_custom_setting(series_type, "series_type")
time_period = cls.resolve_custom_setting(time_period, "time_period")
outputsize = cls.resolve_custom_setting(outputsize, "outputsize")
read_csv_kwargs = cls.resolve_custom_setting(read_csv_kwargs, "read_csv_kwargs", merge=True)
match_params = cls.resolve_custom_setting(match_params, "match_params")
params = cls.resolve_custom_setting(params, "params", merge=True)
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
if use_parser is None:
if api_meta is None and check_installed("alpha_vantage"):
use_parser = False
else:
use_parser = True
if not use_parser:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("alpha_vantage")
if use_parser and api_meta is None and (function is None or match_params):
if not silence_warnings and cls.parse_api_meta.cache_info().misses == 0:
warn("Parsing API documentation...")
try:
api_meta = cls.parse_api_meta()
except Exception as e:
raise ValueError("Can't fetch/parse the API documentation. Specify function and disable match_params.")
freq = timeframe
interval = None
interval_type = None
if timeframe is not None:
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
interval = str(multiplier) + "min"
interval_type = "intraday"
elif unit == "h":
interval = str(60 * multiplier) + "min"
interval_type = "intraday"
elif unit == "D":
interval = "daily"
interval_type = "daily"
elif unit == "W":
interval = "weekly"
interval_type = "weekly"
elif unit == "M":
interval = "monthly"
interval_type = "monthly"
elif unit == "Q":
interval = "quarterly"
interval_type = "quarterly"
elif unit == "Y":
interval = "annual"
interval_type = "annual"
if interval is None and multiplier > 1:
raise ValueError("Multipliers are supported only for intraday timeframes")
else:
if extended:
interval_type = "intraday"
interval = "60min"
else:
interval_type = "daily"
interval = "daily"
if category is not None:
if isinstance(category, str):
if category.lower() in ("time-series-data", "time-series", "timeseries"):
if use_parser:
category = "time-series-data"
else:
from alpha_vantage.timeseries import TimeSeries
category = TimeSeries
elif category.lower() in ("fundamentals", "fundamental-data", "fundamentaldata"):
if use_parser:
category = "fundamentals"
else:
from alpha_vantage.fundamentaldata import FundamentalData
category = FundamentalData
elif category.lower() in ("fx", "forex", "foreign-exchange", "foreignexchange"):
if use_parser:
category = "fx"
else:
from alpha_vantage.foreignexchange import ForeignExchange
category = ForeignExchange
elif category.lower() in ("digital-currency", "cryptocurrencies", "cryptocurrency", "crypto"):
if use_parser:
category = "digital-currency"
else:
from alpha_vantage.cryptocurrencies import CryptoCurrencies
category = CryptoCurrencies
elif category.lower() in ("commodities",):
if use_parser:
category = "commodities"
else:
raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
elif category.lower() in ("economic-indicators",):
if use_parser:
category = "economic-indicators"
else:
raise NotImplementedError(f"Category '{category}' not supported by alpha_vantage. Use parser.")
elif category.lower() in ("technical-indicators", "techindicators", "indicators"):
if use_parser:
category = "technical-indicators"
else:
from alpha_vantage.techindicators import TechIndicators
category = TechIndicators
else:
raise ValueError(f"Invalid category: '{category}'")
else:
if use_parser:
raise TypeError("Category must be a string")
else:
from alpha_vantage.alphavantage import AlphaVantage
if isinstance(category, type):
if not issubclass(category, AlphaVantage):
raise TypeError("Category must be a subclass of AlphaVantage")
elif not isinstance(category, AlphaVantage):
raise TypeError("Category must be an instance of AlphaVantage")
if use_parser:
if function is None:
if category is not None:
if category in ("commodities", "economic-indicators"):
function = symbol
if function is None:
if category is None:
category = "time-series-data"
if category in ("fundamentals", "technical-indicators"):
raise ValueError("Function is required")
adjusted_in_functions = False
extended_in_functions = False
matched_functions = []
for k in api_meta[category]:
if interval_type is None or interval_type.upper() in k:
if "ADJUSTED" in k:
adjusted_in_functions = True
if "EXTENDED" in k:
extended_in_functions = True
matched_functions.append(k)
if adjusted_in_functions:
matched_functions = [
k
for k in matched_functions
if (adjusted and "ADJUSTED" in k) or (not adjusted and "ADJUSTED" not in k)
]
if extended_in_functions:
matched_functions = [
k
for k in matched_functions
if (extended and "EXTENDED" in k) or (not extended and "EXTENDED" not in k)
]
if len(matched_functions) == 0:
raise ValueError("No functions satisfy the requirements")
if len(matched_functions) > 1:
raise ValueError("More than one function satisfies the requirements")
function = matched_functions[0]
if match_params:
if function is not None and category is None:
category = None
for k, v in api_meta.items():
if function in v:
category = k
break
if category is None:
raise ValueError("Category is required")
req_args = api_meta[category][function]["req_args"]
opt_args = api_meta[category][function]["opt_args"]
args = set(req_args) | set(opt_args)
matched_params = dict()
matched_params["function"] = function
matched_params["datatype"] = "csv"
matched_params["apikey"] = apikey
if "symbol" in args and "market" in args:
matched_params["symbol"] = symbol.split("_")[0]
matched_params["market"] = symbol.split("_")[1]
elif "from_" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "from_currency" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "symbol" in args:
matched_params["symbol"] = symbol
if "interval" in args:
matched_params["interval"] = interval
if "adjusted" in args:
matched_params["adjusted"] = adjusted
if "extended" in args:
matched_params["extended"] = extended
if "extended_hours" in args:
matched_params["extended_hours"] = extended
if "slice" in args:
matched_params["slice"] = slice
if "series_type" in args:
matched_params["series_type"] = series_type
if "time_period" in args:
matched_params["time_period"] = time_period
if "outputsize" in args:
matched_params["outputsize"] = outputsize
for k, v in params.items():
if k in args:
matched_params[k] = v
else:
raise ValueError(f"Function '{function}' does not expect parameter '{k}'")
for arg in req_args:
if arg not in matched_params:
raise ValueError(f"Function '{function}' requires parameter '{arg}'")
else:
matched_params = dict(params)
matched_params["function"] = function
matched_params["apikey"] = apikey
matched_params["datatype"] = "csv"
url = "https://www.alphavantage.co/query?" + urllib.parse.urlencode(matched_params)
df = pd.read_csv(url, **read_csv_kwargs)
else:
from alpha_vantage.alphavantage import AlphaVantage
from alpha_vantage.timeseries import TimeSeries
from alpha_vantage.fundamentaldata import FundamentalData
from alpha_vantage.foreignexchange import ForeignExchange
from alpha_vantage.cryptocurrencies import CryptoCurrencies
from alpha_vantage.techindicators import TechIndicators
if isinstance(category, type) and issubclass(category, AlphaVantage):
category = category(key=apikey, output_format="pandas")
if function is None:
if category is None:
category = TimeSeries(key=apikey, output_format="pandas")
if isinstance(category, (TechIndicators, FundamentalData)):
raise ValueError("Function is required")
adjusted_in_methods = False
extended_in_methods = False
matched_methods = []
for k in dir(category):
if interval_type is None or interval_type.lower() in k:
if "adjusted" in k:
adjusted_in_methods = True
if "extended" in k:
extended_in_methods = True
matched_methods.append(k)
if adjusted_in_methods:
matched_methods = [
k
for k in matched_methods
if (adjusted and "adjusted" in k) or (not adjusted and "adjusted" not in k)
]
if extended_in_methods:
matched_methods = [
k
for k in matched_methods
if (extended and "extended" in k) or (not extended and "extended" not in k)
]
if len(matched_methods) == 0:
raise ValueError("No methods satisfy the requirements")
if len(matched_methods) > 1:
raise ValueError("More than one method satisfies the requirements")
function = matched_methods[0]
if isinstance(function, str):
function = function.lower()
if not function.startswith("get_"):
function = "get_" + function
if category is not None:
function = getattr(category, function)
else:
categories = [
TimeSeries,
FundamentalData,
ForeignExchange,
CryptoCurrencies,
TechIndicators,
]
matched_methods = []
for category in categories:
if function in dir(category):
matched_methods.append(getattr(category, function))
if len(matched_methods) == 0:
raise ValueError("No methods satisfy the requirements")
if len(matched_methods) > 1:
raise ValueError("More than one method satisfies the requirements")
function = matched_methods[0]
if match_params:
args = set(get_func_arg_names(function))
matched_params = dict()
if "symbol" in args and "market" in args:
matched_params["symbol"] = symbol.split("_")[0]
matched_params["market"] = symbol.split("_")[1]
elif "from_" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "from_currency" in args and "to_currency" in args:
matched_params["from_currency"] = symbol.split("_")[0]
matched_params["to_currency"] = symbol.split("_")[1]
elif "symbol" in args:
matched_params["symbol"] = symbol
if "interval" in args:
matched_params["interval"] = interval
if "adjusted" in args:
matched_params["adjusted"] = adjusted
if "extended" in args:
matched_params["extended"] = extended
if "extended_hours" in args:
matched_params["extended_hours"] = extended
if "slice" in args:
matched_params["slice"] = slice
if "series_type" in args:
matched_params["series_type"] = series_type
if "time_period" in args:
matched_params["time_period"] = time_period
if "outputsize" in args:
matched_params["outputsize"] = outputsize
else:
matched_params = dict(params)
df, df_metadata = function(**matched_params)
for k, v in df_metadata.items():
if "Time Zone" in k:
if tz is None:
if v.endswith(" Time"):
v = v[: -len(" Time")]
tz = v
df.index.name = None
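# Normalize column names (e.g. "1. open" -> "Open") and drop the " (USD)" suffix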
new_columns = []
for c in df.columns:
new_c = re.sub(r"^\d+\w*\.\s*", "", c)
new_c = new_c[0].title() + new_c[1:]
if new_c.endswith(" (USD)"):
new_c = new_c[: -len(" (USD)")]
new_columns.append(new_c)
df = df.rename(columns=dict(zip(df.columns, new_columns)))
df = df.loc[:, ~df.columns.duplicated()]
for c in df.columns:
if df[c].dtype == "O":
df[c] = df[c].replace({".": np.nan})
df = df.apply(pd.to_numeric, errors="ignore")
if len(df.index) > 1 and df.index[0] > df.index[1]:
df = df.iloc[::-1]
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None and tz is not None:
df = df.tz_localize(tz)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BentoData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
try:
if not tp.TYPE_CHECKING:
raise ImportError
from databento import Historical as HistoricalT
except ImportError:
HistoricalT = "Historical"
__all__ = [
"BentoData",
]
class BentoData(RemoteData):
"""Data class for fetching from Databento.
See https://github.com/databento/databento-python for API.
See `BentoData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.BentoData.set_custom_settings(
... client_config=dict(
... key="YOUR_KEY"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.BentoData.pull(
... "AAPL",
... dataset="XNAS.ITCH"
... )
```
```pycon
>>> data = vbt.BentoData.pull(
... "AAPL",
... dataset="XNAS.ITCH",
... timeframe="hourly",
... start="one week ago"
... )
```
```pycon
>>> data = vbt.BentoData.pull(
... "ES.FUT",
... dataset="GLBX.MDP3",
... stype_in="parent",
... schema="mbo",
... start="2022-06-10T14:30",
... end="2022-06-11",
... limit=1000
... )
```
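* Estimate the cost of a pull before fetching it (an illustrative sketch; the dataset and
date range below are placeholders and the actual quote depends on your account):
```pycon
>>> vbt.BentoData.get_cost(
...     "AAPL",
...     dataset="XNAS.ITCH",
...     timeframe="1 minute",
...     start="2024-01-01",
...     end="2024-02-01"
... )
```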
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.bento")
@classmethod
def resolve_client(cls, client: tp.Optional[HistoricalT] = None, **client_config) -> HistoricalT:
"""Resolve the client.
If provided, must be of the type `databento.historical.client.Historical`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("databento")
from databento import Historical
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = Historical(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def get_cost(cls, symbols: tp.MaybeSymbols, **kwargs) -> float:
"""Get the cost of calling `BentoData.fetch_symbol` on one or more symbols."""
if isinstance(symbols, str):
symbols = [symbols]
costs = []
for symbol in symbols:
client, params = cls.fetch_symbol(symbol, **kwargs, return_params=True)
cost_arg_names = get_func_arg_names(client.metadata.get_cost)
for k in list(params.keys()):
if k not in cost_arg_names:
del params[k]
costs.append(client.metadata.get_cost(**params, mode="historical"))
return sum(costs)
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[HistoricalT] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
resolve_dates: tp.Optional[bool] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
dataset: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
return_params: bool = False,
df_kwargs: tp.KwargsLike = None,
**params,
) -> tp.Union[float, tp.SymbolData]:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Databento.
Args:
symbol (str): Symbol.
Symbol can be in the `DATASET:SYMBOL` format if `dataset` is None.
client (databento.Historical): Client.
See `BentoData.resolve_client`.
client_config (dict): Client config.
See `BentoData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
resolve_dates (bool): Whether to resolve `start` and `end`, or pass them as they are.
timeframe (str): Timeframe to create `schema` from.
Allows human-readable strings such as "1 minute".
If `timeframe` is provided together with a `schema` other than "ohlcv", an error is raised.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
dataset (str): See `databento.historical.client.Historical.get_range`.
schema (str): See `databento.historical.client.Historical.get_range`.
return_params (bool): Whether to return the client and (final) parameters instead of data.
Used by `BentoData.get_cost`.
df_kwargs (dict): Keyword arguments passed to `databento.common.dbnstore.DBNStore.to_df`.
**params: Keyword arguments passed to `databento.historical.client.Historical.get_range`.
For defaults, see `custom.bento` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("databento")
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
resolve_dates = cls.resolve_custom_setting(resolve_dates, "resolve_dates")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
dataset = cls.resolve_custom_setting(dataset, "dataset")
schema = cls.resolve_custom_setting(schema, "schema")
params = cls.resolve_custom_setting(params, "params", merge=True)
df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True)
if dataset is None:
if ":" in symbol:
dataset, symbol = symbol.split(":")
if timeframe is None and schema is None:
schema = "ohlcv-1d"
freq = "1d"
elif timeframe is not None:
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
timeframe = str(multiplier) + unit
if schema is None or schema.lower() == "ohlcv":
schema = f"ohlcv-{timeframe}"
else:
raise ValueError("Timeframe cannot be used together with schema")
else:
if schema.startswith("ohlcv-"):
freq = schema[len("ohlcv-") :]
else:
freq = None
if resolve_dates:
dataset_range = client.metadata.get_dataset_range(dataset)
if "start_date" in dataset_range:
start_date = dt.to_tzaware_timestamp(dataset_range["start_date"], naive_tz="utc", tz="utc")
else:
start_date = dt.to_tzaware_timestamp(dataset_range["start"], naive_tz="utc", tz="utc")
if "end_date" in dataset_range:
end_date = dt.to_tzaware_timestamp(dataset_range["end_date"], naive_tz="utc", tz="utc")
else:
end_date = dt.to_tzaware_timestamp(dataset_range["end"], naive_tz="utc", tz="utc")
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz="utc")
if start < start_date:
start = start_date
else:
start = start_date
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz="utc")
if end > end_date:
end = end_date
else:
end = end_date
if start.floor("d") == start:
start = start.date().isoformat()
else:
start = start.isoformat()
if end.floor("d") == end:
end = end.date().isoformat()
else:
end = end.isoformat()
params = merge_dicts(
dict(
dataset=dataset,
start=start,
end=end,
symbols=symbol,
schema=schema,
),
params,
)
if return_params:
return client, params
df = client.timeseries.get_range(**params).to_df(**df_kwargs)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BinanceData`."""
import time
import traceback
from functools import partial
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from binance.client import Client as BinanceClientT
except ImportError:
BinanceClientT = "BinanceClient"
__all__ = [
"BinanceData",
]
__pdoc__ = {}
BinanceDataT = tp.TypeVar("BinanceDataT", bound="BinanceData")
class BinanceData(RemoteData):
"""Data class for fetching from Binance.
See https://github.com/sammchardy/python-binance for API.
See `BinanceData.fetch_symbol` for arguments.
!!! note
If you are using an exchange from the US, Japan, or another TLD, make sure to pass `tld="us"`
in `client_config` when creating the client.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.BinanceData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY",
... api_secret="YOUR_SECRET"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.BinanceData.pull(
... "BTCUSDT",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
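* List available symbols matching a glob-style pattern (a sketch; requires a working client):
```pycon
>>> vbt.BinanceData.list_symbols("BTC*")
```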
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.binance")
_feature_config: tp.ClassVar[Config] = HybridConfig(
{
"Quote volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Taker base volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Taker quote volume": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
}
)
@property
def feature_config(self) -> Config:
return self._feature_config
@classmethod
def resolve_client(cls, client: tp.Optional[BinanceClientT] = None, **client_config) -> BinanceClientT:
"""Resolve the client.
If provided, must be of the type `binance.client.Client`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("binance")
from binance.client import Client
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = Client(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[BinanceClientT] = None,
client_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
all_symbols = []
for dct in client.get_exchange_info()["symbols"]:
symbol = dct["symbol"]
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[BinanceClientT] = None,
client_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
klines_type: tp.Union[None, int, str] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[float] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
**get_klines_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Binance.
Args:
symbol (str): Symbol.
client (binance.client.Client): Client.
See `BinanceData.resolve_client`.
client_config (dict): Client config.
See `BinanceData.resolve_client`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
klines_type (int or str): Kline type.
See `binance.enums.HistoricalKlinesType`. Supports strings.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
**get_klines_kwargs: Keyword arguments passed to `binance.client.Client.get_klines`.
For defaults, see `custom.binance` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("binance")
from binance.enums import HistoricalKlinesType
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
klines_type = cls.resolve_custom_setting(klines_type, "klines_type")
if isinstance(klines_type, str):
klines_type = map_enum_fields(klines_type, HistoricalKlinesType)
if isinstance(klines_type, int):
klines_type = {i.value: i for i in HistoricalKlinesType}[klines_type]
limit = cls.resolve_custom_setting(limit, "limit")
delay = cls.resolve_custom_setting(delay, "delay")
show_progress = cls.resolve_custom_setting(show_progress, "show_progress")
pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "binance"
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
get_klines_kwargs = cls.resolve_custom_setting(get_klines_kwargs, "get_klines_kwargs", merge=True)
# Prepare parameters
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "w"
timeframe = str(multiplier) + unit
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
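# Clamp the requested start to the earliest timestamp available for this symbol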
first_valid_ts = client._get_earliest_valid_timestamp(symbol, timeframe, klines_type)
start_ts = max(start_ts, first_valid_ts)
else:
start_ts = None
prev_end_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = None
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d[0] < start_ts:
return False
if _prev_end_ts is not None:
if d[0] <= _prev_end_ts:
return False
if end_ts is not None:
if d[0] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = client._klines(
symbol=symbol,
interval=timeframe,
limit=limit,
startTime=start_ts if prev_end_ts is None else prev_end_ts,
endTime=end_ts,
klines_type=klines_type,
**get_klines_kwargs,
)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0][0]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0])))
pbar.update()
prev_end_ts = next_data[-1][0]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
# Convert data to a DataFrame
df = pd.DataFrame(
data,
columns=[
"Open time",
"Open",
"High",
"Low",
"Close",
"Volume",
"Close time",
"Quote volume",
"Trade count",
"Taker base volume",
"Taker quote volume",
"Ignore",
],
)
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
df["Open"] = df["Open"].astype(float)
df["High"] = df["High"].astype(float)
df["Low"] = df["Low"].astype(float)
df["Close"] = df["Close"].astype(float)
df["Volume"] = df["Volume"].astype(float)
df["Quote volume"] = df["Quote volume"].astype(float)
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
df["Taker base volume"] = df["Taker base volume"].astype(float)
df["Taker quote volume"] = df["Taker quote volume"].astype(float)
del df["Open time"]
del df["Close time"]
del df["Ignore"]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
BinanceData.override_feature_config_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CCXTData`."""
import time
import traceback
from functools import wraps, partial
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from ccxt.base.exchange import Exchange as CCXTExchangeT
except ImportError:
CCXTExchangeT = "CCXTExchange"
__all__ = [
"CCXTData",
]
__pdoc__ = {}
class CCXTData(RemoteData):
"""Data class for fetching using CCXT.
See https://github.com/ccxt/ccxt for API.
See `CCXTData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.CCXTData.set_exchange_settings(
... exchange_name="binance",
... populate_=True,
... exchange_config=dict(
... apiKey="YOUR_KEY",
... secret="YOUR_SECRET"
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.CCXTData.pull(
... "BTCUSDT",
... exchange="binance",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
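* Pull data using the `EXCHANGE:SYMBOL` format instead of the `exchange` argument
(a sketch under the same assumptions as above):
```pycon
>>> data = vbt.CCXTData.pull(
...     "binance:BTCUSDT",
...     start="2020-01-01",
...     end="2021-01-01",
...     timeframe="1 day"
... )
```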
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.ccxt")
@classmethod
def get_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> dict:
"""`CCXTData.get_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`CCXTData.has_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def get_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`CCXTData.get_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`CCXTData.has_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_exchange_setting(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`CCXTData.resolve_custom_setting` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def set_exchange_settings(cls, *args, exchange_name: tp.Optional[str] = None, **kwargs) -> None:
"""`CCXTData.set_custom_settings` with `sub_path=exchange_name`."""
if exchange_name is not None:
sub_path = "exchanges." + exchange_name
else:
sub_path = None
cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
exchange_config: tp.Optional[tp.KwargsLike] = None,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if exchange_config is None:
exchange_config = {}
exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
all_symbols = []
for symbol in exchange.load_markets():
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def resolve_exchange(
cls,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
**exchange_config,
) -> CCXTExchangeT:
"""Resolve the exchange.
If provided, must be of the type `ccxt.base.exchange.Exchange`.
Otherwise, will be created using `exchange_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ccxt")
import ccxt
exchange = cls.resolve_exchange_setting(exchange, "exchange")
if exchange is None:
exchange = "binance"
if isinstance(exchange, str):
exchange = exchange.lower()
exchange_name = exchange
elif isinstance(exchange, ccxt.Exchange):
exchange_name = type(exchange).__name__
else:
raise ValueError(f"Unknown exchange of type {type(exchange)}")
if exchange_config is None:
exchange_config = {}
has_exchange_config = len(exchange_config) > 0
exchange_config = cls.resolve_exchange_setting(
exchange_config, "exchange_config", merge=True, exchange_name=exchange_name
)
if isinstance(exchange, str):
if not hasattr(ccxt, exchange):
raise ValueError(f"Exchange '{exchange}' not found in CCXT")
exchange = getattr(ccxt, exchange)(exchange_config)
else:
if has_exchange_config:
raise ValueError("Cannot apply config after instantiation of the exchange")
return exchange
@staticmethod
def _find_earliest_date(
fetch_func: tp.Callable,
start: tp.DatetimeLike = 0,
end: tp.DatetimeLike = "now",
tz: tp.TimezoneLike = None,
for_internal_use: bool = False,
) -> tp.Optional[pd.Timestamp]:
"""Find the earliest date using binary search."""
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
fetched_data = fetch_func(start_ts, 1)
if for_internal_use and len(fetched_data) > 0:
return pd.Timestamp(start_ts, unit="ms", tz="utc")
else:
fetched_data = []
if len(fetched_data) == 0 and start != 0:
fetched_data = fetch_func(0, 1)
if for_internal_use and len(fetched_data) > 0:
return pd.Timestamp(0, unit="ms", tz="utc")
if len(fetched_data) == 0:
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(0, naive_tz=tz, tz="utc"))
start_ts = start_ts - start_ts % 86400000
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime("now", naive_tz=tz, tz="utc"))
end_ts = end_ts - end_ts % 86400000 + 86400000
start_time = start_ts
end_time = end_ts
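# Binary-search whole days between start and end for the earliest day that returns any data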
while True:
mid_time = (start_time + end_time) // 2
mid_time = mid_time - mid_time % 86400000
if mid_time == start_time:
break
_fetched_data = fetch_func(mid_time, 1)
if len(_fetched_data) == 0:
start_time = mid_time
else:
end_time = mid_time
fetched_data = _fetched_data
if len(fetched_data) > 0:
return pd.Timestamp(fetched_data[0][0], unit="ms", tz="utc")
return None
@classmethod
def find_earliest_date(cls, symbol: str, for_internal_use: bool = False, **kwargs) -> tp.Optional[pd.Timestamp]:
"""Find the earliest date using binary search.
See `CCXTData.fetch_symbol` for arguments."""
return cls._find_earliest_date(
**cls.fetch_symbol(symbol, return_fetch_method=True, **kwargs),
for_internal_use=for_internal_use,
)
@classmethod
def fetch_symbol(
cls,
symbol: str,
exchange: tp.Union[None, str, CCXTExchangeT] = None,
exchange_config: tp.Optional[tp.KwargsLike] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
find_earliest_date: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[float] = None,
retries: tp.Optional[int] = None,
fetch_params: tp.Optional[tp.KwargsLike] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
return_fetch_method: bool = False,
) -> tp.Union[dict, tp.SymbolData]:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from CCXT.
Args:
symbol (str): Symbol.
Symbol can be in the `EXCHANGE:SYMBOL` format if `exchange` is None.
exchange (str or object): Exchange identifier or an exchange object.
See `CCXTData.resolve_exchange`.
exchange_config (dict): Exchange config.
See `CCXTData.resolve_exchange`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
find_earliest_date (bool): Whether to find the earliest date using `CCXTData.find_earliest_date`.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
!!! note
Use only if `enableRateLimit` is not set.
retries (int): The number of retries on failure to fetch data.
fetch_params (dict): Exchange-specific keyword arguments passed to `fetch_ohlcv`.
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
return_fetch_method (bool): Whether to return the fetch function and resolved dates instead of data.
Used by `CCXTData.find_earliest_date`.
For defaults, see `custom.ccxt` in `vectorbtpro._settings.data`.
Global settings can be provided per exchange id using the `exchanges` dictionary.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ccxt")
import ccxt
exchange = cls.resolve_custom_setting(exchange, "exchange")
if exchange is None and ":" in symbol:
exchange, symbol = symbol.split(":")
if exchange_config is None:
exchange_config = {}
exchange = cls.resolve_exchange(exchange=exchange, **exchange_config)
exchange_name = type(exchange).__name__
start = cls.resolve_exchange_setting(start, "start", exchange_name=exchange_name)
end = cls.resolve_exchange_setting(end, "end", exchange_name=exchange_name)
timeframe = cls.resolve_exchange_setting(timeframe, "timeframe", exchange_name=exchange_name)
tz = cls.resolve_exchange_setting(tz, "tz", exchange_name=exchange_name)
find_earliest_date = cls.resolve_exchange_setting(
find_earliest_date, "find_earliest_date", exchange_name=exchange_name
)
limit = cls.resolve_exchange_setting(limit, "limit", exchange_name=exchange_name)
delay = cls.resolve_exchange_setting(delay, "delay", exchange_name=exchange_name)
retries = cls.resolve_exchange_setting(retries, "retries", exchange_name=exchange_name)
fetch_params = cls.resolve_exchange_setting(
fetch_params, "fetch_params", merge=True, exchange_name=exchange_name
)
show_progress = cls.resolve_exchange_setting(show_progress, "show_progress", exchange_name=exchange_name)
pbar_kwargs = cls.resolve_exchange_setting(pbar_kwargs, "pbar_kwargs", merge=True, exchange_name=exchange_name)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "ccxt"
silence_warnings = cls.resolve_exchange_setting(
silence_warnings, "silence_warnings", exchange_name=exchange_name
)
if not exchange.has["fetchOHLCV"]:
raise ValueError(f"Exchange {exchange} does not support OHLCV")
if exchange.has["fetchOHLCV"] == "emulated":
if not silence_warnings:
warn("Using emulated OHLCV candles")
freq = timeframe
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "w"
elif unit == "Y":
unit = "y"
timeframe = str(multiplier) + unit
if timeframe not in exchange.timeframes:
raise ValueError(f"Exchange {exchange} does not support {timeframe} timeframe")
def _retry(method):
@wraps(method)
def retry_method(*args, **kwargs):
for i in range(retries):
try:
return method(*args, **kwargs)
except ccxt.NetworkError as e:
if i == retries - 1:
raise e
if not silence_warnings:
warn(traceback.format_exc())
if delay is not None:
time.sleep(delay)
return retry_method
@_retry
def _fetch(_since, _limit):
return exchange.fetch_ohlcv(
symbol,
timeframe=timeframe,
since=_since,
limit=_limit,
params=fetch_params,
)
if return_fetch_method:
return dict(fetch_func=_fetch, start=start, end=end, tz=tz)
# Establish the timestamps
if find_earliest_date and start is not None:
start = cls._find_earliest_date(_fetch, start=start, end=end, tz=tz, for_internal_use=True)
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="UTC"))
else:
end_ts = None
prev_end_ts = None
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Sequence, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d[0] < start_ts:
return False
if _prev_end_ts is not None:
if d[0] <= _prev_end_ts:
return False
if end_ts is not None:
if d[0] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0][0]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1][0])))
pbar.update()
prev_end_ts = next_data[-1][0]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
# Convert data to a DataFrame
df = pd.DataFrame(data, columns=["Open time", "Open", "High", "Low", "Close", "Volume"])
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
del df["Open time"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CSVData`."""
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"CSVData",
]
__pdoc__ = {}
CSVDataT = tp.TypeVar("CSVDataT", bound="CSVData")
class CSVData(FileData):
"""Data class for fetching CSV data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.csv")
@classmethod
def is_csv_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a CSV/TSV file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".csv" in path.suffixes:
return True
if path.exists() and path.is_file() and ".tsv" in path.suffixes:
return True
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_csv_file(path)
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
header: tp.Optional[tp.MaybeSequence[int]] = None,
index_col: tp.Optional[int] = None,
parse_dates: tp.Optional[bool] = None,
chunk_func: tp.Optional[tp.Callable] = None,
squeeze: tp.Optional[bool] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the CSV file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the CSV file.
start (any): Start datetime.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
end (any): End datetime.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row (inclusive).
Must exclude header rows.
end_row (int): End row (exclusive).
Must exclude header rows.
header (int or sequence of int): See `pd.read_csv`.
index_col (int): See `pd.read_csv`.
If False, will pass None.
parse_dates (bool): See `pd.read_csv`.
chunk_func (callable): Function to select and concatenate chunks from `TextFileReader`.
Gets called only if `iterator` or `chunksize` are set.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_kwargs: Other keyword arguments passed to `pd.read_csv`.
`skiprows` and `nrows` will be automatically calculated based on `start_row` and `end_row`.
When either `start` or `end` is provided, will fetch the entire data first and filter it thereafter.
See https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html for other arguments.
For defaults, see `custom.csv` in `vectorbtpro._settings.data`."""
from pandas.io.parsers import TextFileReader
from pandas.api.types import is_object_dtype
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
start_row = cls.resolve_custom_setting(start_row, "start_row")
if start_row is None:
start_row = 0
end_row = cls.resolve_custom_setting(end_row, "end_row")
header = cls.resolve_custom_setting(header, "header")
index_col = cls.resolve_custom_setting(index_col, "index_col")
if index_col is False:
index_col = None
parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates")
chunk_func = cls.resolve_custom_setting(chunk_func, "chunk_func")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
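# Translate start_row/end_row (which exclude header rows) into skiprows/nrows for pd.read_csv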
if isinstance(header, int):
header = [header]
header_rows = header[-1] + 1
start_row += header_rows
if end_row is not None:
end_row += header_rows
skiprows = range(header_rows, start_row)
if end_row is not None:
nrows = end_row - start_row
else:
nrows = None
sep = read_kwargs.pop("sep", None)
if isinstance(path, (str, Path)):
try:
_path = Path(path)
if _path.suffix.lower() == ".csv":
if sep is None:
sep = ","
if _path.suffix.lower() == ".tsv":
if sep is None:
sep = "\t"
except Exception as e:
pass
if sep is None:
sep = ","
obj = pd.read_csv(
path,
sep=sep,
header=header,
index_col=index_col,
parse_dates=parse_dates,
skiprows=skiprows,
nrows=nrows,
**read_kwargs,
)
if isinstance(obj, TextFileReader):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if index_col is not None and parse_dates and is_object_dtype(obj.index.dtype):
obj.index = pd.to_datetime(obj.index, utc=True)
if tz is not None:
obj.index = obj.index.tz_convert(tz)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if start is not None or end is not None:
if not isinstance(obj.index, pd.DatetimeIndex):
raise TypeError("Cannot filter index that is not DatetimeIndex")
if obj.index.tz is not None:
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=obj.index.tz)
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=obj.index.tz)
else:
if start is not None:
start = dt.to_naive_timestamp(start, tz=tz)
if end is not None:
end = dt.to_naive_timestamp(end, tz=tz)
mask = True
if start is not None:
mask &= obj.index >= start
if end is not None:
mask &= obj.index < end
mask_indices = np.flatnonzero(mask)
if len(mask_indices) == 0:
return None
obj = obj.iloc[mask_indices[0] : mask_indices[-1] + 1]
start_row += mask_indices[0]
return obj, dict(last_row=start_row - header_rows + len(obj.index) - 1, tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the CSV file of a feature.
Uses `CSVData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the CSV file of a symbol.
Uses `CSVData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
fetch_kwargs["start_row"] = returned_kwargs["last_row"]
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `CSVData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `CSVData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `CustomData`."""
import fnmatch
import re
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import Data
__all__ = [
"CustomData",
]
__pdoc__ = {}
class CustomData(Data):
"""Data class for fetching custom data."""
_settings_path: tp.SettingsPath = dict(custom=None)
@classmethod
def get_custom_settings(cls, *args, **kwargs) -> dict:
"""`CustomData.get_settings` with `path_id="custom"`."""
return cls.get_settings(*args, path_id="custom", **kwargs)
@classmethod
def has_custom_settings(cls, *args, **kwargs) -> bool:
"""`CustomData.has_settings` with `path_id="custom"`."""
return cls.has_settings(*args, path_id="custom", **kwargs)
@classmethod
def get_custom_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.get_setting` with `path_id="custom"`."""
return cls.get_setting(*args, path_id="custom", **kwargs)
@classmethod
def has_custom_setting(cls, *args, **kwargs) -> bool:
"""`CustomData.has_setting` with `path_id="custom"`."""
return cls.has_setting(*args, path_id="custom", **kwargs)
@classmethod
def resolve_custom_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.resolve_setting` with `path_id="custom"`."""
return cls.resolve_setting(*args, path_id="custom", **kwargs)
@classmethod
def set_custom_settings(cls, *args, **kwargs) -> None:
"""`CustomData.set_settings` with `path_id="custom"`."""
cls.set_settings(*args, path_id="custom", **kwargs)
@staticmethod
def key_match(key: str, pattern: str, use_regex: bool = False):
"""Return whether key matches pattern.
If `use_regex` is True, checks against a regular expression.
Otherwise, checks against a glob-style pattern."""
if use_regex:
return re.match(pattern, key)
return re.match(fnmatch.translate(pattern), key)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `DBData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.local import LocalData
__all__ = [
"DBData",
]
__pdoc__ = {}
class DBData(LocalData):
"""Data class for fetching database data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.db")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `DuckDBData`."""
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import key_dict
from vectorbtpro.data.custom.db import DBData
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from duckdb import DuckDBPyConnection as DuckDBPyConnectionT, DuckDBPyRelation as DuckDBPyRelationT
except ImportError:
DuckDBPyConnectionT = "DuckDBPyConnection"
DuckDBPyRelationT = "DuckDBPyRelation"
__all__ = [
"DuckDBData",
]
__pdoc__ = {}
DuckDBDataT = tp.TypeVar("DuckDBDataT", bound="DuckDBData")
class DuckDBData(DBData):
"""Data class for fetching data using DuckDB.
See `DuckDBData.pull` and `DuckDBData.fetch_key` for arguments.
Usage:
* Set up the connection settings globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.DuckDBData.set_custom_settings(connection="database.duckdb")
```
* Pull tables:
```pycon
>>> data = vbt.DuckDBData.pull(["TABLE1", "TABLE2"])
```
* Rename tables:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... table=vbt.key_dict({
... "SYMBOL1": "TABLE1",
... "SYMBOL2": "TABLE2"
... })
... )
```
* Pull queries:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... query=vbt.key_dict({
... "SYMBOL1": "SELECT * FROM TABLE1",
... "SYMBOL2": "SELECT * FROM TABLE2"
... })
... )
```
* Pull Parquet files:
```pycon
>>> data = vbt.DuckDBData.pull(
... ["SYMBOL1", "SYMBOL2"],
... read_path=vbt.key_dict({
... "SYMBOL1": "s1.parquet",
... "SYMBOL2": "s2.parquet"
... })
... )
```
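* List the tables and views available through the configured connection (a sketch;
assumes the connection has been set globally as shown above):
```pycon
>>> vbt.DuckDBData.list_tables()
```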
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.duckdb")
@classmethod
def resolve_connection(
cls,
connection: tp.Union[None, str, tp.PathLike, DuckDBPyConnectionT] = None,
read_only: bool = True,
return_meta: bool = False,
**connection_config,
) -> tp.Union[DuckDBPyConnectionT, dict]:
"""Resolve the connection."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from duckdb import connect, default_connection
connection_meta = {}
connection = cls.resolve_custom_setting(connection, "connection")
if connection_config is None:
connection_config = {}
has_connection_config = len(connection_config) > 0
connection_config["read_only"] = read_only
connection_config = cls.resolve_custom_setting(connection_config, "connection_config", merge=True)
read_only = connection_config.pop("read_only", read_only)
should_close = False
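# Use the default connection, open a new one (marking it to be closed later), or reuse the provided one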
if connection is None:
if len(connection_config) == 0:
connection = default_connection
else:
database = connection_config.pop("database", None)
if "config" in connection_config or len(connection_config) == 0:
connection = connect(database, read_only=read_only, **connection_config)
else:
connection = connect(database, read_only=read_only, config=connection_config)
should_close = True
elif isinstance(connection, (str, Path)):
if "config" in connection_config or len(connection_config) == 0:
connection = connect(str(connection), read_only=read_only, **connection_config)
else:
connection = connect(str(connection), read_only=read_only, config=connection_config)
should_close = True
elif has_connection_config:
raise ValueError("Cannot apply connection_config to already initialized connection")
if return_meta:
connection_meta["connection"] = connection
connection_meta["should_close"] = should_close
return connection_meta
return connection
@classmethod
def list_catalogs(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
incl_system: bool = False,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all catalogs.
Catalogs "system" and "temp" are skipped if `incl_system` is False.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each catalog against `pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df()
catalogs = []
for catalog in schemata_df["catalog_name"].tolist():
if pattern is not None:
if not cls.key_match(catalog, pattern, use_regex=use_regex):
continue
if not incl_system and catalog == "system":
continue
if not incl_system and catalog == "temp":
continue
catalogs.append(catalog)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(catalogs))
return list(dict.fromkeys(catalogs))
@classmethod
def list_schemas(
cls,
catalog_pattern: tp.Optional[str] = None,
schema_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
catalog: tp.Optional[str] = None,
incl_system: bool = False,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all schemas.
If `catalog` is None, searches for all catalog names in the database and prefixes each schema
with the respective catalog name. If `catalog` is provided, returns the schemas corresponding
to this catalog without a prefix. Schemas "information_schema" and "pg_catalog" are skipped
if `incl_system` is False.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against `schema_pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is None:
catalogs = cls.list_catalogs(
pattern=catalog_pattern,
use_regex=use_regex,
sort=sort,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
if len(catalogs) == 1:
prefix_catalog = False
else:
prefix_catalog = True
else:
catalogs = [catalog]
prefix_catalog = False
schemata_df = connection.sql("SELECT * FROM information_schema.schemata").df()
schemas = []
for catalog in catalogs:
all_schemas = schemata_df[schemata_df["catalog_name"] == catalog]["schema_name"].tolist()
for schema in all_schemas:
if schema_pattern is not None:
if not cls.key_match(schema, schema_pattern, use_regex=use_regex):
continue
if not incl_system and schema == "information_schema":
continue
if not incl_system and schema == "pg_catalog":
continue
if prefix_catalog:
schema = catalog + ":" + schema
schemas.append(schema)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(schemas))
return list(dict.fromkeys(schemas))
@classmethod
def get_current_schema(
cls,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> str:
"""Get the current schema."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
current_schema = connection.sql("SELECT current_schema()").fetchall()[0][0]
if should_close:
connection.close()
return current_schema
@classmethod
def list_tables(
cls,
*,
catalog_pattern: tp.Optional[str] = None,
schema_pattern: tp.Optional[str] = None,
table_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
incl_system: bool = False,
incl_temporary: bool = False,
incl_views: bool = True,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.List[str]:
"""List all tables and views.
If `schema` is None, searches for all schema names in the database and prefixes each table
with the respective catalog and schema name (unless there's only one schema which is the current
schema or `schema` is `current_schema`). If `schema` is provided, returns the tables corresponding
to this schema without a prefix.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against
`schema_pattern` and each table against `table_pattern`."""
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is None:
catalogs = cls.list_catalogs(
pattern=catalog_pattern,
use_regex=use_regex,
sort=sort,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
if catalog_pattern is None and len(catalogs) == 1:
prefix_catalog = False
else:
prefix_catalog = True
else:
catalogs = [catalog]
prefix_catalog = False
current_schema = cls.get_current_schema(
connection=connection,
connection_config=connection_config,
)
if schema is None:
catalogs_schemas = []
for catalog in catalogs:
catalog_schemas = cls.list_schemas(
schema_pattern=schema_pattern,
use_regex=use_regex,
sort=sort,
catalog=catalog,
incl_system=incl_system,
connection=connection,
connection_config=connection_config,
)
for schema in catalog_schemas:
catalogs_schemas.append((catalog, schema))
if len(catalogs_schemas) == 1 and catalogs_schemas[0][1] == current_schema:
prefix_schema = False
else:
prefix_schema = True
else:
if schema == "current_schema":
schema = current_schema
catalogs_schemas = []
for catalog in catalogs:
catalogs_schemas.append((catalog, schema))
prefix_schema = prefix_catalog
tables_df = connection.sql("SELECT * FROM information_schema.tables").df()
tables = []
for catalog, schema in catalogs_schemas:
all_tables = []
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "BASE TABLE")
]["table_name"].tolist()
)
if incl_temporary:
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "LOCAL TEMPORARY")
]["table_name"].tolist()
)
if incl_views:
all_tables.extend(
tables_df[
(tables_df["table_catalog"] == catalog)
& (tables_df["table_schema"] == schema)
& (tables_df["table_type"] == "VIEW")
]["table_name"].tolist()
)
for table in all_tables:
if table_pattern is not None:
if not cls.key_match(table, table_pattern, use_regex=use_regex):
continue
if not prefix_catalog and prefix_schema:
table = schema + ":" + table
elif prefix_catalog or prefix_schema:
table = catalog + ":" + schema + ":" + table
tables.append(table)
if should_close:
connection.close()
if sort:
return sorted(dict.fromkeys(tables))
return list(dict.fromkeys(tables))
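# A minimal usage sketch for the method above, assuming a DuckDB database "data.duckdb"
# that holds two tables in the current schema (file and table names are illustrative):
#
# >>> import vectorbtpro as vbt
# >>> vbt.DuckDBData.list_tables(connection="data.duckdb")
# ['BTCUSDT', 'ETHUSDT']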
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
) -> tp.Kwargs:
keys_meta = DBData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None:
if cls.has_key_dict(catalog):
raise ValueError("Cannot populate keys if catalog is defined per key")
if cls.has_key_dict(schema):
raise ValueError("Cannot populate keys if schema is defined per key")
if cls.has_key_dict(list_tables_kwargs):
raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key")
if cls.has_key_dict(connection):
raise ValueError("Cannot populate keys if connection is defined per key")
if cls.has_key_dict(connection_config):
raise ValueError("Cannot populate keys if connection_config is defined per key")
if cls.has_key_dict(read_path):
raise ValueError("Cannot populate keys if read_path is defined per key")
if cls.has_key_dict(read_format):
raise ValueError("Cannot populate keys if read_format is defined per key")
if read_path is not None or read_format is not None:
if read_path is None:
read_path = "."
if read_format is not None:
read_format = read_format.lower()
checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format")
keys_meta["keys"] = FileData.list_paths(read_path, extension=read_format)
else:
if list_tables_kwargs is None:
list_tables_kwargs = {}
keys_meta["keys"] = cls.list_tables(
catalog=catalog,
schema=schema,
connection=connection,
connection_config=connection_config,
**list_tables_kwargs,
)
return keys_meta
@classmethod
def pull(
cls: tp.Type[DuckDBDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
catalog: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
share_connection: tp.Optional[bool] = None,
**kwargs,
) -> DuckDBDataT:
"""Override `vectorbtpro.data.base.Data.pull` to resolve and share the connection among the keys
and use the table names available in the database in case no keys were provided."""
if share_connection is None:
if not cls.has_key_dict(connection) and not cls.has_key_dict(connection_config):
share_connection = True
else:
share_connection = False
if share_connection:
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
else:
should_close = False
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
catalog=catalog,
schema=schema,
list_tables_kwargs=list_tables_kwargs,
read_path=read_path,
read_format=read_format,
connection=connection,
connection_config=connection_config,
)
keys = keys_meta["keys"]
if isinstance(read_path, key_dict):
new_read_path = read_path.copy()
else:
new_read_path = key_dict()
if isinstance(keys, dict):
new_keys = {}
for k, v in keys.items():
if isinstance(k, Path):
new_k = FileData.path_to_key(k)
new_read_path[new_k] = k
k = new_k
new_keys[k] = v
keys = new_keys
elif cls.has_multiple_keys(keys):
new_keys = []
for k in keys:
if isinstance(k, Path):
new_k = FileData.path_to_key(k)
new_read_path[new_k] = k
k = new_k
new_keys.append(k)
keys = new_keys
else:
if isinstance(keys, Path):
new_keys = FileData.path_to_key(keys)
new_read_path[new_keys] = keys
keys = new_keys
if len(new_read_path) > 0:
read_path = new_read_path
keys_are_features = keys_meta["keys_are_features"]
outputs = super(DBData, cls).pull(
keys,
keys_are_features=keys_are_features,
catalog=catalog,
schema=schema,
read_path=read_path,
read_format=read_format,
connection=connection,
connection_config=connection_config,
**kwargs,
)
if should_close:
connection.close()
return outputs
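# A minimal usage sketch for the override above: omitting the keys pulls every table
# found in the (assumed) database "data.duckdb", as described in the docstring:
#
# >>> import vectorbtpro as vbt
# >>> data = vbt.DuckDBData.pull(connection="data.duckdb")
# >>> data.symbols  # one symbol per table found in the database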
@classmethod
def format_write_option(cls, option: tp.Any) -> str:
"""Format a write option."""
if isinstance(option, str):
return f"'{option}'"
if isinstance(option, (tuple, list)):
return "(" + ", ".join(map(str, option)) + ")"
if isinstance(option, dict):
return "{" + ", ".join(map(lambda y: f"{y[0]}: {cls.format_write_option(y[1])}", option.items())) + "}"
return f"{option}"
@classmethod
def format_write_options(cls, options: tp.Union[str, dict]) -> str:
"""Format write options."""
if isinstance(options, str):
return options
new_options = []
for k, v in options.items():
new_options.append(f"{k.upper()} {cls.format_write_option(v)}")
return ", ".join(new_options)
@classmethod
def format_read_option(cls, option: tp.Any) -> str:
"""Format a read option."""
if isinstance(option, str):
return f"'{option}'"
if isinstance(option, (tuple, list)):
return "[" + ", ".join(map(cls.format_read_option, option)) + "]"
if isinstance(option, dict):
return "{" + ", ".join(map(lambda y: f"'{y[0]}': {cls.format_read_option(y[1])}", option.items())) + "}"
return f"{option}"
@classmethod
def format_read_options(cls, options: tp.Union[str, dict]) -> str:
"""Format read options."""
if isinstance(options, str):
return options
new_options = []
for k, v in options.items():
new_options.append(f"{k.lower()}={cls.format_read_option(v)}")
return ", ".join(new_options)
@classmethod
def fetch_key(
cls,
key: str,
table: tp.Optional[str] = None,
schema: tp.Optional[str] = None,
catalog: tp.Optional[str] = None,
read_path: tp.Optional[tp.PathLike] = None,
read_format: tp.Optional[str] = None,
read_options: tp.Union[None, str, dict] = None,
query: tp.Union[None, str, DuckDBPyRelationT] = None,
connection: tp.Union[None, str, DuckDBPyConnectionT] = None,
connection_config: tp.KwargsLike = None,
start: tp.Optional[tp.Any] = None,
end: tp.Optional[tp.Any] = None,
align_dates: tp.Optional[bool] = None,
parse_dates: tp.Union[None, bool, tp.Sequence[str]] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None,
tz: tp.TimezoneLike = None,
index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None,
squeeze: tp.Optional[bool] = None,
df_kwargs: tp.KwargsLike = None,
**sql_kwargs,
) -> tp.KeyData:
"""Fetch a feature or symbol from a DuckDB database.
Can use a table name (which defaults to the key) or a custom query.
Args:
key (str): Feature or symbol.
If `table`, `read_path`, `read_format`, and `query` are all None, becomes the table name.
The key can also be given in the `SCHEMA:TABLE` or `CATALOG:SCHEMA:TABLE` format;
in that case, the `schema` and `catalog` arguments are ignored.
table (str): Table name.
Cannot be used together with `read_path` or `query`.
schema (str): Schema name.
Cannot be used together with `read_path` or `query`.
catalog (str): Catalog name.
Cannot be used together with `read_path` or `query`.
read_path (path_like): Path to a file to read.
Cannot be used together with `table`, `schema`, `catalog`, or `query`.
read_format (str): Format of the file to read.
Allowed values are "csv", "parquet", and "json".
Requires `read_path` to be set.
read_options (str or dict): Options used to read the file.
Requires `read_path` and `read_format` to be set.
Uses `DuckDBData.format_read_options` to transform a dictionary to a string.
query (str or DuckDBPyRelation): Custom query.
Cannot be used together with `catalog`, `schema`, `table`, or `read_path`.
connection (str or object): See `DuckDBData.resolve_connection`.
connection_config (dict): See `DuckDBData.resolve_connection`.
start (any): Start datetime (if datetime index) or any other start value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
Cannot be used together with `query`; include the condition in the query itself.
end (any): End datetime (if datetime index) or any other end value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
Cannot be used together with `query`; include the condition in the query itself.
align_dates (bool): Whether to align `start` and `end` to the timezone of the index.
Will run a `DESCRIBE ... LIMIT 1` query to determine the name and type of the index column.
parse_dates (bool or sequence of str): See `DuckDBData.prepare_dt`.
to_utc (bool, str, or sequence of str): See `DuckDBData.prepare_dt`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
index_col (int, str, or list): One or more columns that should become the index.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
df_kwargs (dict): Keyword arguments passed to `relation.df` to convert a relation to a DataFrame.
**sql_kwargs: Other keyword arguments passed to `connection.execute` to run a SQL query.
For defaults, see `custom.duckdb` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from duckdb import DuckDBPyRelation
if connection_config is None:
connection_config = {}
connection_meta = cls.resolve_connection(connection, return_meta=True, **connection_config)
connection = connection_meta["connection"]
should_close = connection_meta["should_close"]
if catalog is not None and query is not None:
raise ValueError("Cannot use catalog and query together")
if schema is not None and query is not None:
raise ValueError("Cannot use schema and query together")
if table is not None and query is not None:
raise ValueError("Cannot use table and query together")
if read_path is not None and query is not None:
raise ValueError("Cannot use read_path and query together")
if read_path is not None and (catalog is not None or schema is not None or table is not None):
raise ValueError("Cannot use read_path and catalog/schema/table together")
if table is None and read_path is None and read_format is None and query is None:
if ":" in key:
key_parts = key.split(":")
if len(key_parts) == 2:
schema, table = key_parts
else:
catalog, schema, table = key_parts
else:
table = key
if read_format is not None:
read_format = read_format.lower()
checks.assert_in(read_format, ["csv", "parquet", "json"], arg_name="read_format")
if read_path is None:
read_path = (Path(".") / key).with_suffix("." + read_format)
else:
if read_path is not None:
if isinstance(read_path, str):
read_path = Path(read_path)
if read_path.suffix[1:] in ["csv", "parquet", "json"]:
read_format = read_path.suffix[1:]
if read_path is not None:
if isinstance(read_path, Path):
read_path = str(read_path)
read_path = cls.format_read_option(read_path)
if read_options is not None:
if read_format is None:
raise ValueError("Must provide read_format for read_options")
read_options = cls.format_read_options(read_options)
catalog = cls.resolve_custom_setting(catalog, "catalog")
schema = cls.resolve_custom_setting(schema, "schema")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
align_dates = cls.resolve_custom_setting(align_dates, "align_dates")
parse_dates = cls.resolve_custom_setting(parse_dates, "parse_dates")
to_utc = cls.resolve_custom_setting(to_utc, "to_utc")
tz = cls.resolve_custom_setting(tz, "tz")
index_col = cls.resolve_custom_setting(index_col, "index_col")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
df_kwargs = cls.resolve_custom_setting(df_kwargs, "df_kwargs", merge=True)
sql_kwargs = cls.resolve_custom_setting(sql_kwargs, "sql_kwargs", merge=True)
if query is None:
if read_path is not None:
if read_options is not None:
query = f"SELECT * FROM read_{read_format}({read_path}, {read_options})"
elif read_format is not None:
query = f"SELECT * FROM read_{read_format}({read_path})"
else:
query = f"SELECT * FROM {read_path}"
else:
if catalog is not None:
if schema is None:
schema = cls.get_current_schema(
connection=connection,
connection_config=connection_config,
)
query = f'SELECT * FROM "{catalog}"."{schema}"."{table}"'
elif schema is not None:
query = f'SELECT * FROM "{schema}"."{table}"'
else:
query = f'SELECT * FROM "{table}"'
if start is not None or end is not None:
if index_col is None:
raise ValueError("Must provide index column for filtering by start and end")
if not checks.is_int(index_col) and not isinstance(index_col, str):
raise ValueError("Index column must be integer or string for filtering by start and end")
if checks.is_int(index_col) or align_dates:
metadata_df = connection.sql("DESCRIBE " + query + " LIMIT 1").df()
else:
metadata_df = None
if checks.is_int(index_col):
index_name = metadata_df["column_name"].tolist()[0]
else:
index_name = index_col
if parse_dates:
index_column_type = metadata_df[metadata_df["column_name"] == index_name]["column_type"].item()
if index_column_type in (
"TIMESTAMP_NS",
"TIMESTAMP_MS",
"TIMESTAMP_S",
"TIMESTAMP",
"DATETIME",
):
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start = dt.to_naive_datetime(start)
else:
start = dt.to_naive_datetime(start, tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end = dt.to_naive_datetime(end)
else:
end = dt.to_naive_datetime(end, tz=tz)
elif index_column_type in ("TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"):
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
else:
start = dt.to_tzaware_datetime(start, naive_tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and index_name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
else:
end = dt.to_tzaware_datetime(end, naive_tz=tz)
if start is not None and end is not None:
query += f' WHERE "{index_name}" >= $start AND "{index_name}" < $end'
elif start is not None:
query += f' WHERE "{index_name}" >= $start'
elif end is not None:
query += f' WHERE "{index_name}" < $end'
params = sql_kwargs.get("params", None)
if params is None:
params = {}
else:
params = dict(params)
if not isinstance(params, dict):
raise ValueError("Parameters must be a dictionary for filtering by start and end")
if start is not None:
if "start" in params:
raise ValueError("Start is already in params")
params["start"] = start
if end is not None:
if "end" in params:
raise ValueError("End is already in params")
params["end"] = end
sql_kwargs["params"] = params
else:
if start is not None:
raise ValueError("Start cannot be applied to custom queries")
if end is not None:
raise ValueError("End cannot be applied to custom queries")
if not isinstance(query, DuckDBPyRelation):
relation = connection.sql(query, **sql_kwargs)
else:
relation = query
obj = relation.df(**df_kwargs)
if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index):
if index_col is not None:
if checks.is_int(index_col):
keys = obj.columns[index_col]
elif isinstance(index_col, str):
keys = index_col
else:
keys = []
for col in index_col:
if checks.is_int(col):
keys.append(obj.columns[col])
else:
keys.append(col)
obj = obj.set_index(keys)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
obj = cls.prepare_dt(obj, to_utc=to_utc, parse_dates=parse_dates)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if should_close:
connection.close()
return obj, dict(tz=tz)
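# A minimal usage sketch for fetching via a custom query; the database file, table,
# and column names are illustrative assumptions:
#
# >>> import vectorbtpro as vbt
# >>> data = vbt.DuckDBData.pull(
# ...     "BTCUSDT",
# ...     connection="data.duckdb",
# ...     query='SELECT * FROM "BTCUSDT" ORDER BY "Datetime"',
# ...     index_col="Datetime",
# ... )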
@classmethod
def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData:
"""Fetch the table of a feature.
Uses `DuckDBData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData:
"""Fetch the table for a symbol.
Uses `DuckDBData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: str, from_last_index: tp.Optional[bool] = None, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
pre_kwargs = merge_dicts(fetch_kwargs, kwargs)
if from_last_index is None:
if pre_kwargs.get("query", None) is not None:
from_last_index = False
else:
from_last_index = True
if from_last_index:
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if self.feature_oriented:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: str, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `DuckDBData.update_key`."""
return self.update_key(feature, **kwargs)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `DuckDBData.update_key`."""
return self.update_key(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FeatherData`."""
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FeatherData",
]
__pdoc__ = {}
FeatherDataT = tp.TypeVar("FeatherDataT", bound="FeatherData")
class FeatherData(FileData):
"""Data class for fetching Feather data using PyArrow."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.feather")
@classmethod
def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
path = path / "*.feather"
return cls.match_path(path, **match_path_kwargs)
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = "*.feather"
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
tz: tp.TimezoneLike = None,
index_col: tp.Optional[tp.MaybeSequence[tp.IntStr]] = None,
squeeze: tp.Optional[bool] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the Feather file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the Feather file.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
index_col (int, str, or sequence): Position(s) or name(s) of column(s) that should become the index.
Will only apply if the fetched object has a default index.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_kwargs: Other keyword arguments passed to `pd.read_feather`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_feather.html for other arguments.
For defaults, see `custom.feather` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pyarrow")
tz = cls.resolve_custom_setting(tz, "tz")
index_col = cls.resolve_custom_setting(index_col, "index_col")
if index_col is False:
index_col = None
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
obj = pd.read_feather(path, **read_kwargs)
if isinstance(obj, pd.DataFrame) and checks.is_default_index(obj.index):
if index_col is not None:
if checks.is_int(index_col):
keys = obj.columns[index_col]
elif isinstance(index_col, str):
keys = index_col
else:
keys = []
for col in index_col:
if checks.is_int(col):
keys.append(obj.columns[col])
else:
keys.append(col)
obj = obj.set_index(keys)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
return obj, dict(tz=tz)
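# A minimal usage sketch for the method above, assuming Feather files exist under
# "data/" (the directory and file names are illustrative):
#
# >>> import vectorbtpro as vbt
# >>> data = vbt.FeatherData.pull(paths="data/*.feather")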
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the Feather file of a feature.
Uses `FeatherData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the Feather file of a symbol.
Uses `FeatherData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `FeatherData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `FeatherData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FileData`."""
import re
from glob import glob
from pathlib import Path
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import key_dict
from vectorbtpro.data.custom.local import LocalData
from vectorbtpro.utils import checks
__all__ = [
"FileData",
]
__pdoc__ = {}
FileDataT = tp.TypeVar("FileDataT", bound="FileData")
class FileData(LocalData):
"""Data class for fetching file data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.file")
@classmethod
def is_dir_match(cls, path: tp.PathLike) -> bool:
"""Return whether a directory is a valid match."""
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
"""Return whether a file is a valid match."""
return True
@classmethod
def match_path(
cls,
path: tp.PathLike,
match_regex: tp.Optional[str] = None,
sort_paths: bool = True,
recursive: bool = True,
extension: tp.Optional[str] = None,
**kwargs,
) -> tp.List[Path]:
"""Get the list of all paths matching a path.
If `FileData.is_dir_match` returns True for a directory, it gets returned as-is.
Otherwise, iterates through all files in that directory and invokes `FileData.is_file_match`.
If a pattern was provided, these methods aren't invoked."""
if not isinstance(path, Path):
path = Path(path)
if path.exists():
if path.is_dir() and not cls.is_dir_match(path):
sub_paths = []
for p in path.iterdir():
if p.is_dir() and cls.is_dir_match(p):
sub_paths.append(p)
if p.is_file() and cls.is_file_match(p):
if extension is None or p.suffix == "." + extension:
sub_paths.append(p)
else:
sub_paths = [path]
else:
sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)])
if match_regex is not None:
sub_paths = [p for p in sub_paths if re.match(match_regex, str(p))]
if sort_paths:
sub_paths = sorted(sub_paths)
return sub_paths
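# A worked example of the matching above, assuming two CSV files exist under "data/"
# (the paths shown are illustrative):
#
# >>> FileData.match_path("data/*.csv")
# [PosixPath('data/AAPL.csv'), PosixPath('data/MSFT.csv')]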
@classmethod
def list_paths(cls, path: tp.PathLike = ".", **match_path_kwargs) -> tp.List[Path]:
"""List all features or symbols under a path."""
return cls.match_path(path, **match_path_kwargs)
@classmethod
def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
"""Convert a path into a feature or symbol."""
return Path(path).stem
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
return LocalData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
@classmethod
def pull(
cls: tp.Type[FileDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
paths: tp.Any = None,
match_paths: tp.Optional[bool] = None,
match_regex: tp.Optional[str] = None,
sort_paths: tp.Optional[bool] = None,
match_path_kwargs: tp.KwargsLike = None,
path_to_key_kwargs: tp.KwargsLike = None,
**kwargs,
) -> FileDataT:
"""Override `vectorbtpro.data.base.Data.pull` to take care of paths.
Use either features, symbols, or `paths` to specify the path to one or multiple files.
Paths can be provided as strings, `pathlib.Path` objects, or string expressions accepted by `glob.glob`.
Set `match_paths` to False to not parse paths and behave like a regular
`vectorbtpro.data.base.Data` instance.
For defaults, see `custom.file` in `vectorbtpro._settings.data`.
"""
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
paths=paths,
)
keys = keys_meta["keys"]
keys_are_features = keys_meta["keys_are_features"]
dict_type = keys_meta["dict_type"]
match_paths = cls.resolve_custom_setting(match_paths, "match_paths")
match_regex = cls.resolve_custom_setting(match_regex, "match_regex")
sort_paths = cls.resolve_custom_setting(sort_paths, "sort_paths")
if match_paths:
sync = False
if paths is None:
paths = keys
sync = True
elif keys is None:
sync = True
if paths is None:
if keys_are_features:
raise ValueError("At least features or paths must be set")
else:
raise ValueError("At least symbols or paths must be set")
if match_path_kwargs is None:
match_path_kwargs = {}
if path_to_key_kwargs is None:
path_to_key_kwargs = {}
single_key = False
if isinstance(keys, (str, Path)):
# Single key
keys = [keys]
single_key = True
single_path = False
if isinstance(paths, (str, Path)):
# Single path
paths = [paths]
single_path = True
if sync:
single_key = True
cls.check_dict_type(paths, "paths", dict_type=dict_type)
if isinstance(paths, key_dict):
# Dict of path per key
if sync:
keys = list(paths.keys())
elif len(keys) != len(paths):
if keys_are_features:
raise ValueError("The number of features must be equal to the number of matched paths")
else:
raise ValueError("The number of symbols must be equal to the number of matched paths")
elif checks.is_iterable(paths) or checks.is_sequence(paths):
# Multiple paths
matched_paths = [
p
for sub_path in paths
for p in cls.match_path(
sub_path,
match_regex=match_regex,
sort_paths=sort_paths,
**match_path_kwargs,
)
]
if len(matched_paths) == 0:
raise FileNotFoundError(f"No paths could be matched with {paths}")
if sync:
keys = []
paths = key_dict()
for p in matched_paths:
s = cls.path_to_key(p, **path_to_key_kwargs)
keys.append(s)
paths[s] = p
elif len(keys) != len(matched_paths):
if keys_are_features:
raise ValueError("The number of features must be equal to the number of matched paths")
else:
raise ValueError("The number of symbols must be equal to the number of matched paths")
else:
paths = key_dict({s: matched_paths[i] for i, s in enumerate(keys)})
if len(matched_paths) == 1 and single_path:
paths = matched_paths[0]
else:
raise TypeError(f"Path '{paths}' is not supported")
if len(keys) == 1 and single_key:
keys = keys[0]
return super(FileData, cls).pull(
keys,
keys_are_features=keys_are_features,
path=paths,
**kwargs,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FinPyData`."""
from itertools import product
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from findatapy.market import Market as MarketT
from findatapy.util import ConfigManager as ConfigManagerT
except ImportError:
MarketT = "Market"
ConfigManagerT = "ConfigManager"
__all__ = [
"FinPyData",
]
FinPyDataT = tp.TypeVar("FinPyDataT", bound="FinPyData")
class FinPyData(RemoteData):
"""Data class for fetching using findatapy.
See https://github.com/cuemacro/findatapy for API.
See `FinPyData.fetch_symbol` for arguments.
Usage:
* Pull data (keyword argument format):
```pycon
>>> data = vbt.FinPyData.pull(
... "EURUSD",
... start="14 June 2016",
... end="15 June 2016",
... timeframe="tick",
... category="fx",
... fields=["bid", "ask"],
... data_source="dukascopy"
... )
```
* Pull data (string format):
```pycon
>>> data = vbt.FinPyData.pull(
... "fx.dukascopy.tick.NYC.EURUSD.bid,ask",
... start="14 June 2016",
... end="15 June 2016",
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.finpy")
@classmethod
def resolve_market(
cls,
market: tp.Optional[MarketT] = None,
**market_config,
) -> MarketT:
"""Resolve the market.
If provided, must be of the type `findatapy.market.market.Market`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.market import Market, MarketDataGenerator
market = cls.resolve_custom_setting(market, "market")
if market_config is None:
market_config = {}
has_market_config = len(market_config) > 0
market_config = cls.resolve_custom_setting(market_config, "market_config", merge=True)
if "market_data_generator" not in market_config:
market_config["market_data_generator"] = MarketDataGenerator()
if market is None:
market = Market(**market_config)
elif has_market_config:
raise ValueError("Cannot apply market_config to already initialized market")
return market
@classmethod
def resolve_config_manager(
cls,
config_manager: tp.Optional[ConfigManagerT] = None,
**config_manager_config,
) -> ConfigManagerT:
"""Resolve the config manager.
If provided, must be of the type `findatapy.util.ConfigManager`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.util import ConfigManager
config_manager = cls.resolve_custom_setting(config_manager, "config_manager")
if config_manager_config is None:
config_manager_config = {}
has_config_manager_config = len(config_manager_config) > 0
config_manager_config = cls.resolve_custom_setting(config_manager_config, "config_manager_config", merge=True)
if config_manager is None:
config_manager = ConfigManager().get_instance(**config_manager_config)
elif has_config_manager_config:
raise ValueError("Cannot apply config_manager_config to already initialized config_manager")
return config_manager
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
config_manager: tp.Optional[ConfigManagerT] = None,
config_manager_config: tp.KwargsLike = None,
category: tp.Optional[tp.MaybeList[str]] = None,
data_source: tp.Optional[tp.MaybeList[str]] = None,
freq: tp.Optional[tp.MaybeList[str]] = None,
cut: tp.Optional[tp.MaybeList[str]] = None,
tickers: tp.Optional[tp.MaybeList[str]] = None,
dict_filter: tp.DictLike = None,
smart_group: bool = False,
return_fields: tp.Optional[tp.MaybeList[str]] = None,
combine_parts: bool = True,
) -> tp.List[str]:
"""List all symbols.
Passes most arguments to `findatapy.util.ConfigManager.free_form_tickers_regex_query`.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`."""
if config_manager_config is None:
config_manager_config = {}
config_manager = cls.resolve_config_manager(config_manager=config_manager, **config_manager_config)
if dict_filter is None:
dict_filter = {}
def_ret_fields = ["category", "data_source", "freq", "cut", "tickers"]
if return_fields is None:
ret_fields = def_ret_fields
elif isinstance(return_fields, str):
if return_fields.lower() == "all":
ret_fields = def_ret_fields + ["fields"]
else:
ret_fields = [return_fields]
else:
ret_fields = return_fields
df = config_manager.free_form_tickers_regex_query(
category=category,
data_source=data_source,
freq=freq,
cut=cut,
tickers=tickers,
dict_filter=dict_filter,
smart_group=smart_group,
ret_fields=ret_fields,
)
all_symbols = []
for _, row in df.iterrows():
parts = []
if "category" in row.index:
parts.append(row.loc["category"])
if "data_source" in row.index:
parts.append(row.loc["data_source"])
if "freq" in row.index:
parts.append(row.loc["freq"])
if "cut" in row.index:
parts.append(row.loc["cut"])
if "tickers" in row.index:
parts.append(row.loc["tickers"])
if "fields" in row.index:
parts.append(row.loc["fields"])
if combine_parts:
split_parts = [part.split(",") for part in parts]
combinations = list(product(*split_parts))
else:
combinations = [parts]
for symbol in [".".join(combination) for combination in combinations]:
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
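# A minimal usage sketch for the method above, assuming findatapy is installed and its
# default ticker library is configured (the filter values are illustrative):
#
# >>> import vectorbtpro as vbt
# >>> vbt.FinPyData.list_symbols(category="fx", data_source="dukascopy", freq="tick")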
@classmethod
def fetch_symbol(
cls,
symbol: str,
market: tp.Optional[MarketT] = None,
market_config: tp.KwargsLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
**request_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from findatapy.
Args:
symbol (str): Symbol.
Also accepts a dotted string such as "fx.bloomberg.daily.NYC.EURUSD.close".
The fields `freq`, `cut`, `tickers`, and `fields` are optional in this format.
market (findatapy.market.market.Market): Market.
See `FinPyData.resolve_market`.
market_config (dict): Market config.
See `FinPyData.resolve_market`.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
**request_kwargs: Other keyword arguments passed to `findatapy.market.marketdatarequest.MarketDataRequest`.
For defaults, see `custom.finpy` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("findatapy")
from findatapy.market import MarketDataRequest
if market_config is None:
market_config = {}
market = cls.resolve_market(market=market, **market_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
request_kwargs = cls.resolve_custom_setting(request_kwargs, "request_kwargs", merge=True)
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "s":
unit = "second"
freq = timeframe
elif unit == "m":
unit = "minute"
freq = timeframe
elif unit == "h":
unit = "hourly"
freq = timeframe
elif unit == "D":
unit = "daily"
freq = timeframe
elif unit == "W":
unit = "weekly"
freq = timeframe
elif unit == "M":
unit = "monthly"
freq = timeframe
elif unit == "Q":
unit = "quarterly"
freq = timeframe
elif unit == "Y":
unit = "annually"
freq = timeframe
else:
freq = None
if "resample" in request_kwargs:
freq = request_kwargs["resample"]
if start is not None:
start = dt.to_naive_datetime(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
if end is not None:
end = dt.to_naive_datetime(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
if "md_request" in request_kwargs:
md_request = request_kwargs["md_request"]
elif "md_request_df" in request_kwargs:
md_request = market.create_md_request_from_dataframe(
md_request_df=request_kwargs["md_request_df"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif "md_request_str" in request_kwargs:
md_request = market.create_md_request_from_str(
md_request_str=request_kwargs["md_request_str"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif "md_request_dict" in request_kwargs:
md_request = market.create_md_request_from_dict(
md_request_dict=request_kwargs["md_request_dict"],
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
elif symbol.count(".") >= 2:
md_request = market.create_md_request_from_str(
md_request_str=symbol,
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
else:
md_request = MarketDataRequest(
tickers=symbol,
start_date=start,
finish_date=end,
freq_mult=multiplier,
freq=unit,
**request_kwargs,
)
df = market.fetch_market(md_request=md_request)
if df is None:
return None
if isinstance(md_request.tickers, str):
ticker = md_request.tickers
elif len(md_request.tickers) == 1:
ticker = md_request.tickers[0]
else:
ticker = None
if ticker is not None:
df.columns = df.columns.map(lambda x: x.replace(ticker + ".", ""))
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `GBMOHLCData`."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import broadcast_array_to
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.ohlcv import nb as ohlcv_nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import substitute_templates
__all__ = [
"GBMOHLCData",
]
__pdoc__ = {}
class GBMOHLCData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_1d_nb`
and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.gbm_ohlc")
@classmethod
def generate_symbol(
cls,
symbol: tp.Symbol,
index: tp.Index,
n_ticks: tp.Optional[tp.ArrayLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
dt: tp.Optional[float] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.SymbolData:
"""Generate a symbol.
Args:
symbol (hashable): Symbol.
index (pd.Index): Pandas index.
n_ticks (int or array_like): Number of ticks per bar.
Flexible argument. Can be a template with a context containing `symbol` and `index`.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
dt (float): Time change (one period of time).
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
template_context (dict): Context used to substitute templates.
For defaults, see `custom.gbm_ohlc` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`.
"""
n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks")
template_context = merge_dicts(dict(symbol=symbol, index=index), template_context)
n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks")
n_ticks = broadcast_array_to(n_ticks, len(index))
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
dt = cls.resolve_custom_setting(dt, "dt")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_gbm_data_1d_nb, jitted)
ticks = func(
np.sum(n_ticks),
start_value=start_value,
mean=mean,
std=std,
dt=dt,
)
func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
out = func(ticks, n_ticks)
return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"])
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[symbol]["Open"].iloc[-1]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, start_value=start_value, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `GBMData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
__all__ = [
"GBMData",
]
__pdoc__ = {}
class GBMData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_gbm_data_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.gbm")
@classmethod
def generate_key(
cls,
key: tp.Key,
index: tp.Index,
columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
dt: tp.Optional[float] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
**kwargs,
) -> tp.KeyData:
"""Generate a feature or symbol.
Args:
key (hashable): Feature or symbol.
index (pd.Index): Pandas index.
columns (hashable or index_like): Column names.
Provide a single value (hashable) to make a Series.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
dt (float): Time change (one period of time).
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
For defaults, see `custom.gbm` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per feature/symbol using
`vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally
`vectorbtpro.data.base.key_dict`.
"""
if checks.is_hashable(columns):
columns = [columns]
make_series = True
else:
make_series = False
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
dt = cls.resolve_custom_setting(dt, "dt")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_gbm_data_nb, jitted)
out = func(
(len(index), len(columns)),
start_value=to_1d_array(start_value),
mean=to_1d_array(mean),
std=to_1d_array(std),
dt=to_1d_array(dt),
)
if make_series:
return pd.Series(out[:, 0], index=index, name=columns[0])
return pd.DataFrame(out, index=index, columns=columns)
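# A minimal sketch of calling the generator above directly with a prebuilt index
# (the key, column name, and seed are illustrative assumptions):
#
# >>> import pandas as pd
# >>> index = pd.date_range("2020-01-01", periods=5, freq="D")
# >>> GBMData.generate_key("GBM", index, columns="price", seed=42)  # returns a Series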
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[key].iloc[-2]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, start_value=start_value, **kwargs)
return self.fetch_symbol(key, start_value=start_value, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `HDFData`."""
import re
from glob import glob
from pathlib import Path, PurePath
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names
__all__ = [
"HDFData",
]
__pdoc__ = {}
class HDFPathNotFoundError(Exception):
"""Gets raised if the path to an HDF file could not be found."""
pass
class HDFKeyNotFoundError(Exception):
"""Gets raised if the key to an HDF object could not be found."""
pass
HDFDataT = tp.TypeVar("HDFDataT", bound="HDFData")
class HDFData(FileData):
"""Data class for fetching HDF data using PyTables."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.hdf")
@classmethod
def is_hdf_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is an HDF file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".hdf" in path.suffixes:
return True
if path.exists() and path.is_file() and ".hdf5" in path.suffixes:
return True
if path.exists() and path.is_file() and ".h5" in path.suffixes:
return True
return False
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_hdf_file(path)
@classmethod
def split_hdf_path(
cls,
path: tp.PathLike,
key: tp.Optional[str] = None,
_full_path: tp.Optional[Path] = None,
) -> tp.Tuple[Path, tp.Optional[str]]:
"""Split the path to an HDF object into the path to the file and the key."""
path = Path(path)
if _full_path is None:
_full_path = path
if path.exists():
if path.is_dir():
raise HDFPathNotFoundError(f"No HDF files could be matched with {_full_path}")
return path, key
new_path = path.parent
if key is None:
new_key = path.name
else:
new_key = str(Path(path.name) / key)
return cls.split_hdf_path(new_path, new_key, _full_path=_full_path)
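# A worked example of the splitting above, assuming the file "data/my_data.h5" exists
# on disk (the paths and key are illustrative):
#
# >>> HDFData.split_hdf_path("data/my_data.h5/df_key")
# (PosixPath('data/my_data.h5'), 'df_key')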
@classmethod
def match_path(
cls,
path: tp.PathLike,
match_regex: tp.Optional[str] = None,
sort_paths: bool = True,
recursive: bool = True,
**kwargs,
) -> tp.List[Path]:
"""Override `FileData.match_path` to return a list of HDF paths
(path to file + key) matching a path."""
path = Path(path)
if path.exists():
if path.is_dir() and not cls.is_dir_match(path):
sub_paths = []
for p in path.iterdir():
if p.is_dir() and cls.is_dir_match(p):
sub_paths.append(p)
if p.is_file() and cls.is_file_match(p):
sub_paths.append(p)
key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)]
else:
with pd.HDFStore(str(path), mode="r") as store:
keys = [k[1:] for k in store.keys()]
key_paths = [path / k for k in keys]
else:
try:
file_path, key = cls.split_hdf_path(path)
with pd.HDFStore(str(file_path), mode="r") as store:
keys = [k[1:] for k in store.keys()]
if key is None:
key_paths = [file_path / k for k in keys]
elif key in keys:
key_paths = [file_path / key]
else:
matching_keys = []
for k in keys:
if k.startswith(key) or PurePath("/" + str(k)).match("/" + str(key)):
matching_keys.append(k)
if len(matching_keys) == 0:
raise HDFKeyNotFoundError(f"No HDF keys could be matched with {key}")
key_paths = [file_path / k for k in matching_keys]
except HDFPathNotFoundError:
sub_paths = list([Path(p) for p in glob(str(path), recursive=recursive)])
if len(sub_paths) == 0 and re.match(r".+\..+", str(path)):
base_path = None
base_ended = False
key_path = None
for part in path.parts:
part = Path(part)
if base_ended:
if key_path is None:
key_path = part
else:
key_path /= part
else:
if re.match(r".+\..+", str(part)):
base_ended = True
if base_path is None:
base_path = part
else:
base_path /= part
sub_paths = list([Path(p) for p in glob(str(base_path), recursive=recursive)])
if key_path is not None:
sub_paths = [p / key_path for p in sub_paths]
key_paths = [p for sub_path in sub_paths for p in cls.match_path(sub_path, sort_paths=False, **kwargs)]
if match_regex is not None:
key_paths = [p for p in key_paths if re.match(match_regex, str(p))]
if sort_paths:
key_paths = sorted(key_paths)
return key_paths
@classmethod
def path_to_key(cls, path: tp.PathLike, **kwargs) -> str:
return Path(path).name
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
chunk_func: tp.Optional[tp.Callable] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the HDF object of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
Will be resolved with `HDFData.split_hdf_path`.
If `path` is None, uses `key` as the path to the HDF file.
start (any): Start datetime.
Will extract the object's index and compare the index to the date.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
!!! note
Can only be used if the object was saved in the table format!
end (any): End datetime.
Will extract the object's index and compare the index to the date.
Will use the timezone of the object. See `vectorbtpro.utils.datetime_.to_timestamp`.
!!! note
Can only be used if the object was saved in the table format!
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row (inclusive).
Will use it when querying index as well.
end_row (int): End row (exclusive).
Will use it when querying index as well.
chunk_func (callable): Function to select and concatenate chunks from `TableIterator`.
Gets called only if `iterator` or `chunksize` are set.
**read_kwargs: Other keyword arguments passed to `pd.read_hdf`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_hdf.html for other arguments.
For defaults, see `custom.hdf` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("tables")
from pandas.io.pytables import TableIterator
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
start_row = cls.resolve_custom_setting(start_row, "start_row")
if start_row is None:
start_row = 0
end_row = cls.resolve_custom_setting(end_row, "end_row")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if path is None:
path = key
path = Path(path)
file_path, file_key = cls.split_hdf_path(path)
if file_key is not None:
key = file_key
if start is not None or end is not None:
hdf_store_arg_names = get_func_arg_names(pd.HDFStore.__init__)
hdf_store_kwargs = dict()
for k, v in read_kwargs.items():
if k in hdf_store_arg_names:
hdf_store_kwargs[k] = v
with pd.HDFStore(str(file_path), mode="r", **hdf_store_kwargs) as store:
index = store.select_column(key, "index", start=start_row, stop=end_row)
if not isinstance(index, pd.Index):
index = pd.Index(index)
if not isinstance(index, pd.DatetimeIndex):
raise TypeError("Cannot filter index that is not DatetimeIndex")
if tz is None:
tz = index.tz
if index.tz is not None:
if start is not None:
start = dt.to_tzaware_timestamp(start, naive_tz=tz, tz=index.tz)
if end is not None:
end = dt.to_tzaware_timestamp(end, naive_tz=tz, tz=index.tz)
else:
if start is not None:
start = dt.to_naive_timestamp(start, tz=tz)
if end is not None:
end = dt.to_naive_timestamp(end, tz=tz)
mask = True
if start is not None:
mask &= index >= start
if end is not None:
mask &= index < end
mask_indices = np.flatnonzero(mask)
if len(mask_indices) == 0:
return None
start_row += mask_indices[0]
end_row = start_row + mask_indices[-1] - mask_indices[0] + 1
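# Read the object within the resolved row range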
obj = pd.read_hdf(file_path, key=key, start=start_row, stop=end_row, **read_kwargs)
if isinstance(obj, TableIterator):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
return obj, dict(last_row=start_row + len(obj.index) - 1, tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the HDF object of a feature.
Uses `HDFData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Load the HDF object for a symbol.
Uses `HDFData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
fetch_kwargs["start_row"] = returned_kwargs["last_row"]
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `HDFData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `HDFData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `LocalData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
__all__ = [
"LocalData",
]
__pdoc__ = {}
class LocalData(CustomData):
"""Data class for fetching local data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.local")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `NDLData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"NDLData",
]
__pdoc__ = {}
NDLDataT = tp.TypeVar("NDLDataT", bound="NDLData")
class NDLData(RemoteData):
"""Data class for fetching from Nasdaq Data Link.
See https://github.com/Nasdaq/data-link-python for API.
See `NDLData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.NDLData.set_custom_settings(
... api_key="YOUR_KEY"
... )
```
* Pull a dataset:
```pycon
>>> data = vbt.NDLData.pull(
... "FRED/GDP",
... start="2001-12-31",
... end="2005-12-31"
... )
```
* Pull a datatable:
```pycon
>>> data = vbt.NDLData.pull(
... "MER/F1",
... data_format="datatable",
... compnumber="39102",
... paginate=True
... )
```
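* Pull only specific columns of a dataset (a sketch; column 1 is assumed to be the value column of this dataset):
```pycon
>>> data = vbt.NDLData.pull(
... "FRED/GDP",
... column_indices=1,
... start="2001-12-31",
... end="2005-12-31"
... )
```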
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.ndl")
@classmethod
def fetch_symbol(
cls,
symbol: str,
api_key: tp.Optional[str] = None,
data_format: tp.Optional[str] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
tz: tp.TimezoneLike = None,
column_indices: tp.Optional[tp.MaybeIterable[int]] = None,
**params,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Nasdaq Data Link.
Args:
symbol (str): Symbol.
api_key (str): API key.
data_format (str): Data format.
Supported are "dataset" and "datatable".
start (any): Retrieve data rows on and after the specified start date.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): Retrieve data rows up to and including the specified end date.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
column_indices (int or iterable): Request one or more specific columns.
Column 0 is the date column and is always returned. Data begins at column 1.
**params: Keyword arguments sent as field/value params to Nasdaq Data Link with no interference.
For defaults, see `custom.ndl` in `vectorbtpro._settings.data`.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("nasdaqdatalink")
import nasdaqdatalink
api_key = cls.resolve_custom_setting(api_key, "api_key")
data_format = cls.resolve_custom_setting(data_format, "data_format")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
tz = cls.resolve_custom_setting(tz, "tz")
column_indices = cls.resolve_custom_setting(column_indices, "column_indices")
if column_indices is not None:
if isinstance(column_indices, int):
dataset = symbol + "." + str(column_indices)
else:
dataset = [symbol + "." + str(index) for index in column_indices]
else:
dataset = symbol
params = cls.resolve_custom_setting(params, "params", merge=True)
# Establish the timestamps
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start_date = pd.Timestamp(start).isoformat()
if "start_date" not in params:
params["start_date"] = start_date
else:
start_date = None
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end_date = pd.Timestamp(end).isoformat()
if "end_date" not in params:
params["end_date"] = end_date
else:
end_date = None
# Collect and format the data
if data_format.lower() == "dataset":
df = nasdaqdatalink.get(
dataset,
api_key=api_key,
**params,
)
else:
df = nasdaqdatalink.get_table(
dataset,
api_key=api_key,
**params,
)
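# Strip the symbol prefix from column names and rename "Last" to "Close"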
new_columns = []
for c in df.columns:
new_c = c
if isinstance(symbol, str):
new_c = new_c.replace(symbol + " - ", "")
if new_c == "Last":
new_c = "Close"
new_columns.append(new_c)
df = df.rename(columns=dict(zip(df.columns, new_columns)))
if df.index.name == "None":
df.index.name = None
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if isinstance(df.index, pd.DatetimeIndex) and not df.empty:
if start is not None:
start = dt.to_timestamp(start, tz=df.index.tz)
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
end = dt.to_timestamp(end, tz=df.index.tz)
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `ParquetData`."""
import re
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.file import FileData
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"ParquetData",
]
__pdoc__ = {}
ParquetDataT = tp.TypeVar("ParquetDataT", bound="ParquetData")
class ParquetData(FileData):
"""Data class for fetching Parquet data using PyArrow or FastParquet."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.parquet")
@classmethod
def is_parquet_file(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a Parquet file."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_file() and ".parquet" in path.suffixes:
return True
return False
@classmethod
def is_parquet_group_dir(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a directory that is a group of Parquet partitions.
!!! note
Assumes the Hive partitioning scheme."""
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
partition_regex = r"^(.+)=(.+)"
if re.match(partition_regex, path.name):
for p in path.iterdir():
if cls.is_parquet_group_dir(p) or cls.is_parquet_file(p):
return True
return False
@classmethod
def is_parquet_dir(cls, path: tp.PathLike) -> bool:
"""Return whether the path is a directory that is a group itself or
contains groups of Parquet partitions."""
if cls.is_parquet_group_dir(path):
return True
if not isinstance(path, Path):
path = Path(path)
if path.exists() and path.is_dir():
for p in path.iterdir():
if cls.is_parquet_group_dir(p):
return True
return False
@classmethod
def is_dir_match(cls, path: tp.PathLike) -> bool:
return cls.is_parquet_dir(path)
@classmethod
def is_file_match(cls, path: tp.PathLike) -> bool:
return cls.is_parquet_file(path)
@classmethod
def list_partition_cols(cls, path: tp.PathLike) -> tp.List[str]:
"""List partitioning columns under a path.
!!! note
Assumes the Hive partitioning scheme."""
if not isinstance(path, Path):
path = Path(path)
partition_cols = []
found_last_level = False
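# Descend into the first partition directory at each level and collect the column names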
while not found_last_level:
found_new_level = False
for p in path.iterdir():
if cls.is_parquet_group_dir(p):
partition_cols.append(p.name.split("=")[0])
path = p
found_new_level = True
break
if not found_new_level:
found_last_level = True
return partition_cols
@classmethod
def is_default_partition_col(cls, level: str) -> bool:
"""Return whether a partitioning column is a default partitioning column."""
return re.match(r"^(\bgroup\b)|(group_\d+)", level) is not None
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
paths: tp.Any = None,
) -> tp.Kwargs:
keys_meta = FileData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None and paths is None:
keys_meta["keys"] = cls.list_paths()
return keys_meta
@classmethod
def fetch_key(
cls,
key: tp.Key,
path: tp.Any = None,
tz: tp.TimezoneLike = None,
squeeze: tp.Optional[bool] = None,
keep_partition_cols: tp.Optional[bool] = None,
engine: tp.Optional[str] = None,
**read_kwargs,
) -> tp.KeyData:
"""Fetch the Parquet file of a feature or symbol.
Args:
key (hashable): Feature or symbol.
path (str): Path.
If `path` is None, uses `key` as the path to the Parquet file.
tz (any): Target timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
keep_partition_cols (bool): Whether to return partitioning columns (if any).
If None, will remove any partitioning column that is "group" or "group_{index}".
Retrieves the list of partitioning columns with `ParquetData.list_partition_cols`.
engine (str): See `pd.read_parquet`.
**read_kwargs: Other keyword arguments passed to `pd.read_parquet`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html for other arguments.
For defaults, see `custom.parquet` in `vectorbtpro._settings.data`."""
from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any
tz = cls.resolve_custom_setting(tz, "tz")
squeeze = cls.resolve_custom_setting(squeeze, "squeeze")
keep_partition_cols = cls.resolve_custom_setting(keep_partition_cols, "keep_partition_cols")
engine = cls.resolve_custom_setting(engine, "engine")
read_kwargs = cls.resolve_custom_setting(read_kwargs, "read_kwargs", merge=True)
if engine == "pyarrow":
assert_can_import("pyarrow")
elif engine == "fastparquet":
assert_can_import("fastparquet")
elif engine == "auto":
assert_can_import_any("pyarrow", "fastparquet")
else:
raise ValueError(f"Invalid engine: '{engine}'")
if path is None:
path = key
obj = pd.read_parquet(path, engine=engine, **read_kwargs)
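# Drop partitioning columns added by the Hive scheme (all if False, only default ones if None)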
if keep_partition_cols in (None, False):
if cls.is_parquet_dir(path):
drop_columns = []
partition_cols = cls.list_partition_cols(path)
for col in obj.columns:
if col in partition_cols:
if keep_partition_cols is False or cls.is_default_partition_col(col):
drop_columns.append(col)
obj = obj.drop(drop_columns, axis=1)
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Fetch the Parquet file of a feature.
Uses `ParquetData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Fetch the Parquet file of a symbol.
Uses `ParquetData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `ParquetData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `ParquetData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `PolygonData`."""
import time
import traceback
from functools import wraps, partial
import pandas as pd
import requests
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from polygon import RESTClient as PolygonClientT
except ImportError:
PolygonClientT = "PolygonClient"
__all__ = [
"PolygonData",
]
PolygonDataT = tp.TypeVar("PolygonDataT", bound="PolygonData")
class PolygonData(RemoteData):
"""Data class for fetching from Polygon.
See https://github.com/polygon-io/client-python for API.
See `PolygonData.fetch_symbol` for arguments.
Usage:
* Set up the API key globally:
```pycon
>>> from vectorbtpro import *
>>> vbt.PolygonData.set_custom_settings(
... client_config=dict(
... api_key="YOUR_KEY"
... )
... )
```
* Pull stock data:
```pycon
>>> data = vbt.PolygonData.pull(
... "AAPL",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
* Pull crypto data:
```pycon
>>> data = vbt.PolygonData.pull(
... "X:BTCUSD",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 day"
... )
```
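* Pull currency data (a sketch; assumes your plan includes forex aggregates, which use the `C:` prefix):
```pycon
>>> data = vbt.PolygonData.pull(
... "C:EURUSD",
... start="2021-01-01",
... end="2022-01-01",
... timeframe="1 hour"
... )
```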
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.polygon")
@classmethod
def list_symbols(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[PolygonClientT] = None,
client_config: tp.DictLike = None,
**list_tickers_kwargs,
) -> tp.List[str]:
"""List all symbols.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each symbol against `pattern`.
For supported keyword arguments, see `polygon.RESTClient.list_tickers`."""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
all_symbols = []
for ticker in client.list_tickers(**list_tickers_kwargs):
symbol = ticker.ticker
if pattern is not None:
if not cls.key_match(symbol, pattern, use_regex=use_regex):
continue
all_symbols.append(symbol)
if sort:
return sorted(dict.fromkeys(all_symbols))
return list(dict.fromkeys(all_symbols))
@classmethod
def resolve_client(cls, client: tp.Optional[PolygonClientT] = None, **client_config) -> PolygonClientT:
"""Resolve the client.
If provided, must be of the type `polygon.rest.RESTClient`.
Otherwise, will be created using `client_config`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("polygon")
from polygon import RESTClient
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = RESTClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[PolygonClientT] = None,
client_config: tp.DictLike = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
adjusted: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
params: tp.KwargsLike = None,
delay: tp.Optional[float] = None,
retries: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Polygon.
Args:
symbol (str): Symbol.
Supports the following APIs:
* Stocks and equities
* Currencies - symbol must have the prefix `C:`
* Crypto - symbol must have the prefix `X:`
client (polygon.rest.RESTClient): Client.
See `PolygonData.resolve_client`.
client_config (dict): Client config.
See `PolygonData.resolve_client`.
start (any): The start of the aggregate time window.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): The end of the aggregate time window.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
adjusted (bool): Whether the results are adjusted for splits.
By default, results are adjusted.
Set this to False to get results that are NOT adjusted for splits.
limit (int): Limits the number of base aggregates queried to create the aggregate results.
Maximum is 50000 and default is 5000.
params (dict): Any additional query params.
delay (float): Time to sleep after each request (in seconds).
retries (int): The number of retries on failure to fetch data.
show_progress (bool): Whether to show the progress bar.
pbar_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`.
silence_warnings (bool): Whether to silence all warnings.
For defaults, see `custom.polygon` in `vectorbtpro._settings.data`.
!!! note
If you're using a free plan that has an API rate limit of several requests per minute,
make sure to set `delay` to a higher number, such as 12 (which makes 5 requests per minute).
"""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
adjusted = cls.resolve_custom_setting(adjusted, "adjusted")
limit = cls.resolve_custom_setting(limit, "limit")
params = cls.resolve_custom_setting(params, "params", merge=True)
delay = cls.resolve_custom_setting(delay, "delay")
retries = cls.resolve_custom_setting(retries, "retries")
show_progress = cls.resolve_custom_setting(show_progress, "show_progress")
pbar_kwargs = cls.resolve_custom_setting(pbar_kwargs, "pbar_kwargs", merge=True)
if "bar_id" not in pbar_kwargs:
pbar_kwargs["bar_id"] = "polygon"
silence_warnings = cls.resolve_custom_setting(silence_warnings, "silence_warnings")
# Resolve the timeframe
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
if unit == "m":
unit = "minute"
elif unit == "h":
unit = "hour"
elif unit == "D":
unit = "day"
elif unit == "W":
unit = "week"
elif unit == "M":
unit = "month"
elif unit == "Q":
unit = "quarter"
elif unit == "Y":
unit = "year"
# Establish the timestamps
if start is not None:
start_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc"))
else:
start_ts = None
if end is not None:
end_ts = dt.datetime_to_ms(dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc"))
else:
end_ts = None
prev_end_ts = None
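# Retry wrapper: waits a minute on HTTP 429 (rate limit) and retries on connection errors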
def _retry(method):
@wraps(method)
def retry_method(*args, **kwargs):
for i in range(retries):
try:
return method(*args, **kwargs)
except requests.exceptions.HTTPError as e:
if isinstance(e, requests.exceptions.HTTPError) and e.response.status_code == 429:
if not silence_warnings:
warn(traceback.format_exc())
# Polygon.io API rate limit is per minute
warn("Waiting 1 minute...")
time.sleep(60)
else:
raise e
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
if i == retries - 1:
raise e
if not silence_warnings:
warn(traceback.format_exc())
if delay is not None:
time.sleep(delay)
return retry_method
def _postprocess(agg):
return dict(
o=agg.open,
h=agg.high,
l=agg.low,
c=agg.close,
v=agg.volume,
vw=agg.vwap,
t=agg.timestamp,
n=agg.transactions,
)
@_retry
def _fetch(_start_ts, _limit):
return list(
map(
_postprocess,
client.get_aggs(
ticker=symbol,
multiplier=multiplier,
timespan=unit,
from_=_start_ts,
to=end_ts,
adjusted=adjusted,
sort="asc",
limit=_limit,
params=params,
raw=False,
),
)
)
def _ts_to_str(ts: tp.Optional[int]) -> str:
if ts is None:
return "?"
return dt.readable_datetime(pd.Timestamp(ts, unit="ms", tz="utc"), freq=timeframe)
def _filter_func(d: tp.Dict, _prev_end_ts: tp.Optional[int] = None) -> bool:
if start_ts is not None:
if d["t"] < start_ts:
return False
if _prev_end_ts is not None:
if d["t"] <= _prev_end_ts:
return False
if end_ts is not None:
if d["t"] >= end_ts:
return False
return True
# Iteratively collect the data
data = []
try:
with ProgressBar(show_progress=show_progress, **pbar_kwargs) as pbar:
pbar.set_description("{} → ?".format(_ts_to_str(start_ts if prev_end_ts is None else prev_end_ts)))
while True:
# Fetch the klines for the next timeframe
next_data = _fetch(start_ts if prev_end_ts is None else prev_end_ts, limit)
next_data = list(filter(partial(_filter_func, _prev_end_ts=prev_end_ts), next_data))
# Update the timestamps and the progress bar
if not len(next_data):
break
data += next_data
if start_ts is None:
start_ts = next_data[0]["t"]
pbar.set_description("{} → {}".format(_ts_to_str(start_ts), _ts_to_str(next_data[-1]["t"])))
pbar.update()
prev_end_ts = next_data[-1]["t"]
if end_ts is not None and prev_end_ts >= end_ts:
break
if delay is not None:
time.sleep(delay) # be kind to api
except Exception as e:
if not silence_warnings:
warn(traceback.format_exc())
warn(
f"Symbol '{str(symbol)}' raised an exception. Returning incomplete data. "
"Use update() method to fetch missing data."
)
df = pd.DataFrame(data)
df = df[["t", "o", "h", "l", "c", "v", "n", "vw"]]
df = df.rename(
columns={
"t": "Open time",
"o": "Open",
"h": "High",
"l": "Low",
"c": "Close",
"v": "Volume",
"n": "Trade count",
"vw": "VWAP",
}
)
df.index = pd.to_datetime(df["Open time"], unit="ms", utc=True)
del df["Open time"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
if "Trade count" in df.columns:
df["Trade count"] = df["Trade count"].astype(int, errors="ignore")
if "VWAP" in df.columns:
df["VWAP"] = df["VWAP"].astype(float)
return df, dict(tz=tz, freq=timeframe)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RandomOHLCData`."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import broadcast_array_to
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.ohlcv import nb as ohlcv_nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import substitute_templates
__all__ = [
"RandomOHLCData",
]
__pdoc__ = {}
class RandomOHLCData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_1d_nb`
and then resampled using `vectorbtpro.ohlcv.nb.ohlc_every_1d_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.random_ohlc")
@classmethod
def generate_symbol(
cls,
symbol: tp.Symbol,
index: tp.Index,
n_ticks: tp.Optional[tp.ArrayLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
symmetric: tp.Optional[bool] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.SymbolData:
"""Generate a symbol.
Args:
symbol (hashable): Symbol.
index (pd.Index): Pandas index.
n_ticks (int or array_like): Number of ticks per bar.
Flexible argument. Can be a template with a context containing `symbol` and `index`.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones.
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
template_context (dict): Template context.
For defaults, see `custom.random_ohlc` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per symbol using `vectorbtpro.data.base.symbol_dict`.
"""
n_ticks = cls.resolve_custom_setting(n_ticks, "n_ticks")
template_context = merge_dicts(dict(symbol=symbol, index=index), template_context)
n_ticks = substitute_templates(n_ticks, template_context, eval_id="n_ticks")
n_ticks = broadcast_array_to(n_ticks, len(index))
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
symmetric = cls.resolve_custom_setting(symmetric, "symmetric")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
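# Generate a single random tick series and aggregate every n_ticks into one OHLC bar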
func = jit_reg.resolve_option(nb.generate_random_data_1d_nb, jitted)
ticks = func(np.sum(n_ticks), start_value=start_value, mean=mean, std=std, symmetric=symmetric)
func = jit_reg.resolve_option(ohlcv_nb.ohlc_every_1d_nb, jitted)
out = func(ticks, n_ticks)
return pd.DataFrame(out, index=index, columns=["Open", "High", "Low", "Close"])
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[symbol]["Open"].iloc[-1]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, start_value=start_value, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RandomData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array
from vectorbtpro.data import nb
from vectorbtpro.data.custom.synthetic import SyntheticData
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.random_ import set_seed
__all__ = [
"RandomData",
]
__pdoc__ = {}
class RandomData(SyntheticData):
"""`SyntheticData` for data generated using `vectorbtpro.data.nb.generate_random_data_nb`."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.random")
@classmethod
def generate_key(
cls,
key: tp.Key,
index: tp.Index,
columns: tp.Union[tp.Hashable, tp.IndexLike] = None,
start_value: tp.Optional[float] = None,
mean: tp.Optional[float] = None,
std: tp.Optional[float] = None,
symmetric: tp.Optional[bool] = None,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
**kwargs,
) -> tp.KeyData:
"""Generate a feature or symbol.
Args:
key (hashable): Feature or symbol.
index (pd.Index): Pandas index.
columns (hashable or index_like): Column names.
Provide a single value (hashable) to make a Series.
start_value (float): Value at time 0.
Does not appear as the first value in the output data.
mean (float): Drift, or mean of the percentage change.
std (float): Standard deviation of the percentage change.
symmetric (bool): Whether to diminish negative returns and make them symmetric to positive ones.
seed (int): Seed to make output deterministic.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
For defaults, see `custom.random` in `vectorbtpro._settings.data`.
!!! note
When setting a seed, remember to pass a seed per feature/symbol using
`vectorbtpro.data.base.feature_dict`/`vectorbtpro.data.base.symbol_dict` or generally
`vectorbtpro.data.base.key_dict`.
"""
if checks.is_hashable(columns):
columns = [columns]
make_series = True
else:
make_series = False
if not isinstance(columns, pd.Index):
columns = pd.Index(columns)
start_value = cls.resolve_custom_setting(start_value, "start_value")
mean = cls.resolve_custom_setting(mean, "mean")
std = cls.resolve_custom_setting(std, "std")
symmetric = cls.resolve_custom_setting(symmetric, "symmetric")
seed = cls.resolve_custom_setting(seed, "seed")
if seed is not None:
set_seed(seed)
func = jit_reg.resolve_option(nb.generate_random_data_nb, jitted)
out = func(
(len(index), len(columns)),
start_value=to_1d_array(start_value),
mean=to_1d_array(mean),
std=to_1d_array(std),
symmetric=to_1d_array(symmetric),
)
if make_series:
return pd.Series(out[:, 0], index=index, name=columns[0])
return pd.DataFrame(out, index=index, columns=columns)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
_ = fetch_kwargs.pop("start_value", None)
start_value = self.data[key].iloc[-2]
fetch_kwargs["seed"] = None
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, start_value=start_value, **kwargs)
return self.fetch_symbol(key, start_value=start_value, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RemoteData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
__all__ = [
"RemoteData",
]
__pdoc__ = {}
class RemoteData(CustomData):
"""Data class for fetching remote data."""
_settings_path: tp.SettingsPath = dict(custom="data.custom.remote")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SQLData`."""
from typing import Iterator
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.db import DBData
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
try:
if not tp.TYPE_CHECKING:
raise ImportError
from sqlalchemy import Engine as EngineT, Selectable as SelectableT, Table as TableT
except ImportError:
EngineT = "Engine"
SelectableT = "Selectable"
TableT = "Table"
__all__ = [
"SQLData",
]
__pdoc__ = {}
SQLDataT = tp.TypeVar("SQLDataT", bound="SQLData")
class SQLData(DBData):
"""Data class for fetching data from a database using SQLAlchemy.
See https://www.sqlalchemy.org/ for the SQLAlchemy's API.
See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for the read method.
See `SQLData.pull` and `SQLData.fetch_key` for arguments.
Usage:
* Set up the engine settings globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.SQLData.set_engine_settings(
... engine_name="postgresql",
... populate_=True,
... engine="postgresql+psycopg2://...",
... engine_config=dict(),
... schema="public"
... )
```
* Pull tables:
```pycon
>>> data = vbt.SQLData.pull(
... ["TABLE1", "TABLE2"],
... engine="postgresql",
... start="2020-01-01",
... end="2021-01-01"
... )
```
* Pull queries:
```pycon
>>> data = vbt.SQLData.pull(
... ["SYMBOL1", "SYMBOL2"],
... query=vbt.key_dict({
... "SYMBOL1": "SELECT * FROM TABLE1",
... "SYMBOL2": "SELECT * FROM TABLE2"
... }),
... engine="postgresql"
... )
```
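* Pull a table from a SQLite database via an engine URL (a sketch; the file `data.db` and the table `TABLE1` are hypothetical):
```pycon
>>> data = vbt.SQLData.pull(
... "TABLE1",
... engine="sqlite:///data.db"
... )
```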
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.sql")
@classmethod
def get_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> dict:
"""`SQLData.get_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.get_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`SQLData.has_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.has_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def get_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`SQLData.get_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.get_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def has_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> bool:
"""`SQLData.has_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.has_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_engine_setting(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> tp.Any:
"""`SQLData.resolve_custom_setting` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
return cls.resolve_custom_setting(*args, sub_path=sub_path, **kwargs)
@classmethod
def set_engine_settings(cls, *args, engine_name: tp.Optional[str] = None, **kwargs) -> None:
"""`SQLData.set_custom_settings` with `sub_path=engine_name`."""
if engine_name is not None:
sub_path = "engines." + engine_name
else:
sub_path = None
cls.set_custom_settings(*args, sub_path=sub_path, **kwargs)
@classmethod
def resolve_engine(
cls,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
return_meta: bool = False,
**engine_config,
) -> tp.Union[EngineT, dict]:
"""Resolve the engine.
Argument `engine` can be
1) an object of the type `sqlalchemy.engine.base.Engine`,
2) a URL of the engine as a string, which will be used to create an engine with
`sqlalchemy.engine.create.create_engine` and `engine_config` passed as keyword arguments
(you should not include `url` in the `engine_config`), or
3) an engine name, which is the name of a sub-config with engine settings under `custom.sql.engines`
in `vectorbtpro._settings.data`. Such a sub-config can then contain the actual engine as an object or a URL.
Argument `engine_name` can be provided instead of `engine`, or also together with `engine`
to pull other settings from a sub-config. URLs can also be used as engine names, but not the
other way around."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import create_engine
if engine is None and engine_name is None:
engine_name = cls.resolve_engine_setting(engine_name, "engine_name")
if engine_name is not None:
engine = cls.resolve_engine_setting(engine, "engine", engine_name=engine_name)
if engine is None:
raise ValueError("Must provide engine or URL (via engine argument)")
else:
engine = cls.resolve_engine_setting(engine, "engine")
if engine is None:
raise ValueError("Must provide engine or URL (via engine argument)")
if isinstance(engine, str):
engine_name = engine
else:
engine_name = None
if engine_name is not None:
if cls.has_engine_setting("engine", engine_name=engine_name, sub_path_only=True):
engine = cls.get_engine_setting("engine", engine_name=engine_name, sub_path_only=True)
has_engine_config = len(engine_config) > 0
engine_config = cls.resolve_engine_setting(engine_config, "engine_config", merge=True, engine_name=engine_name)
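# Create the engine from a URL string; DuckDB URLs additionally require duckdb_engine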
if isinstance(engine, str):
if engine.startswith("duckdb:"):
assert_can_import("duckdb_engine")
engine = create_engine(engine, **engine_config)
should_dispose = True
else:
if has_engine_config:
raise ValueError("Cannot apply engine_config to initialized created engine")
should_dispose = False
if return_meta:
return dict(
engine=engine,
engine_name=engine_name,
should_dispose=should_dispose,
)
return engine
@classmethod
def list_schemas(
cls,
pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
**kwargs,
) -> tp.List[str]:
"""List all schemas.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against `pattern`.
Keyword arguments `**kwargs` are passed to `inspector.get_schema_names`.
If `dispose_engine` is None, disposes the engine if it wasn't provided."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
inspector = inspect(engine)
all_schemas = inspector.get_schema_names(**kwargs)
schemas = []
for schema in all_schemas:
if pattern is not None:
if not cls.key_match(schema, pattern, use_regex=use_regex):
continue
if schema == "information_schema":
continue
schemas.append(schema)
if dispose_engine:
engine.dispose()
if sort:
return sorted(dict.fromkeys(schemas))
return list(dict.fromkeys(schemas))
@classmethod
def list_tables(
cls,
*,
schema_pattern: tp.Optional[str] = None,
table_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
schema: tp.Optional[str] = None,
incl_views: bool = True,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
**kwargs,
) -> tp.List[str]:
"""List all tables and views.
If `schema` is None, searches for all schema names in the database and prefixes each table
with the respective schema name (unless there's only one schema "main"). If `schema` is False,
sets the schema to None. If `schema` is provided, returns the tables corresponding to this
schema without a prefix.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each schema against
`schema_pattern` and each table against `table_pattern`.
Keyword arguments `**kwargs` are passed to `inspector.get_table_names`.
If `dispose_engine` is None, disposes the engine if it wasn't provided."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name)
if schema is None:
schemas = cls.list_schemas(
pattern=schema_pattern,
use_regex=use_regex,
sort=sort,
engine=engine,
engine_name=engine_name,
**kwargs,
)
if len(schemas) == 0:
schemas = [None]
prefix_schema = False
elif len(schemas) == 1 and schemas[0] == "main":
prefix_schema = False
else:
prefix_schema = True
elif schema is False:
schemas = [None]
prefix_schema = False
else:
schemas = [schema]
prefix_schema = False
inspector = inspect(engine)
tables = []
for schema in schemas:
all_tables = inspector.get_table_names(schema, **kwargs)
if incl_views:
try:
all_tables += inspector.get_view_names(schema, **kwargs)
except NotImplementedError:
pass
try:
all_tables += inspector.get_materialized_view_names(schema, **kwargs)
except NotImplementedError:
pass
for table in all_tables:
if table_pattern is not None:
if not cls.key_match(table, table_pattern, use_regex=use_regex):
continue
if prefix_schema and schema is not None:
table = str(schema) + ":" + table
tables.append(table)
if dispose_engine:
engine.dispose()
if sort:
return sorted(dict.fromkeys(tables))
return list(dict.fromkeys(tables))
@classmethod
def has_schema(
cls,
schema: str,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> bool:
"""Check whether the database has a schema."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
return inspect(engine).has_schema(schema)
@classmethod
def create_schema(
cls,
schema: str,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> None:
"""Create a schema if it doesn't exist yet."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy.schema import CreateSchema
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
if not cls.has_schema(schema, engine=engine, engine_name=engine_name):
with engine.connect() as connection:
connection.execute(CreateSchema(schema))
connection.commit()
@classmethod
def has_table(
cls,
table: str,
schema: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> bool:
"""Check whether the database has a table."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import inspect
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
return inspect(engine).has_table(table, schema=schema)
@classmethod
def get_table_relation(
cls,
table: str,
schema: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> TableT:
"""Get table relation."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import MetaData
if engine_config is None:
engine_config = {}
engine = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
**engine_config,
)
schema = cls.resolve_engine_setting(schema, "schema", engine_name=engine_name)
metadata_obj = MetaData()
metadata_obj.reflect(bind=engine, schema=schema, only=[table], views=True)
if schema is not None and schema + "." + table in metadata_obj.tables:
return metadata_obj.tables[schema + "." + table]
return metadata_obj.tables[table]
@classmethod
def get_last_row_number(
cls,
table: str,
schema: tp.Optional[str] = None,
row_number_column: tp.Optional[str] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> int:
"""Get last row number."""
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
row_number_column = cls.resolve_engine_setting(
row_number_column,
"row_number_column",
engine_name=engine_name,
)
table_relation = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name)
table_column_names = []
for column in table_relation.columns:
table_column_names.append(column.name)
if row_number_column not in table_column_names:
raise ValueError(f"Row number column '{row_number_column}' not found")
query = (
table_relation.select()
.with_only_columns(table_relation.columns.get(row_number_column))
.order_by(table_relation.columns.get(row_number_column).desc())
.limit(1)
)
with engine.connect() as connection:
results = connection.execute(query)
last_row_number = results.first()[0]
connection.commit()
return last_row_number
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
) -> tp.Kwargs:
keys_meta = DBData.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
if keys_meta["keys"] is None:
if cls.has_key_dict(schema):
raise ValueError("Cannot populate keys if schema is defined per key")
if cls.has_key_dict(list_tables_kwargs):
raise ValueError("Cannot populate keys if list_tables_kwargs is defined per key")
if cls.has_key_dict(engine):
raise ValueError("Cannot populate keys if engine is defined per key")
if cls.has_key_dict(engine_name):
raise ValueError("Cannot populate keys if engine_name is defined per key")
if cls.has_key_dict(engine_config):
raise ValueError("Cannot populate keys if engine_config is defined per key")
if list_tables_kwargs is None:
list_tables_kwargs = {}
keys_meta["keys"] = cls.list_tables(
schema=schema,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
**list_tables_kwargs,
)
return keys_meta
@classmethod
def pull(
cls: tp.Type[SQLDataT],
keys: tp.Union[tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[tp.MaybeFeatures] = None,
symbols: tp.Union[tp.MaybeSymbols] = None,
schema: tp.Optional[str] = None,
list_tables_kwargs: tp.KwargsLike = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
share_engine: tp.Optional[bool] = None,
**kwargs,
) -> SQLDataT:
"""Override `vectorbtpro.data.base.Data.pull` to resolve and share the engine among the keys
and use the table names available in the database in case no keys were provided."""
if share_engine is None:
if (
not cls.has_key_dict(engine)
and not cls.has_key_dict(engine_name)
and not cls.has_key_dict(engine_config)
):
share_engine = True
else:
share_engine = False
if share_engine:
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
else:
engine_name = None
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
schema=schema,
list_tables_kwargs=list_tables_kwargs,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
)
keys = keys_meta["keys"]
keys_are_features = keys_meta["keys_are_features"]
outputs = super(DBData, cls).pull(
keys,
keys_are_features=keys_are_features,
schema=schema,
engine=engine,
engine_name=engine_name,
engine_config=engine_config,
dispose_engine=False if share_engine else dispose_engine,
**kwargs,
)
if share_engine and dispose_engine:
engine.dispose()
return outputs
@classmethod
def fetch_key(
cls,
key: str,
table: tp.Union[None, str, TableT] = None,
schema: tp.Optional[str] = None,
query: tp.Union[None, str, SelectableT] = None,
engine: tp.Union[None, str, EngineT] = None,
engine_name: tp.Optional[str] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
start: tp.Optional[tp.Any] = None,
end: tp.Optional[tp.Any] = None,
align_dates: tp.Optional[bool] = None,
parse_dates: tp.Union[None, bool, tp.List[tp.IntStr], tp.Dict[tp.IntStr, tp.Any]] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = None,
tz: tp.TimezoneLike = None,
start_row: tp.Optional[int] = None,
end_row: tp.Optional[int] = None,
keep_row_number: tp.Optional[bool] = None,
row_number_column: tp.Optional[str] = None,
index_col: tp.Union[None, bool, tp.MaybeList[tp.IntStr]] = None,
columns: tp.Optional[tp.MaybeList[tp.IntStr]] = None,
dtype: tp.Union[None, tp.DTypeLike, tp.Dict[tp.IntStr, tp.DTypeLike]] = None,
chunksize: tp.Optional[int] = None,
chunk_func: tp.Optional[tp.Callable] = None,
squeeze: tp.Optional[bool] = None,
**read_sql_kwargs,
) -> tp.KeyData:
"""Fetch a feature or symbol from a SQL database.
Can use a table name (which defaults to the key) or a custom query.
Args:
key (str): Feature or symbol.
If `table` and `query` are both None, becomes the table name.
Key can be in the `SCHEMA:TABLE` format, in which case the `schema` argument is ignored.
table (str or Table): Table name or actual object.
Cannot be used together with `query`.
schema (str): Schema.
Cannot be used together with `query`.
query (str or Selectable): Custom query.
Cannot be used together with `table` and `schema`.
engine (str or object): See `SQLData.resolve_engine`.
engine_name (str): See `SQLData.resolve_engine`.
engine_config (dict): See `SQLData.resolve_engine`.
dispose_engine (bool): See `SQLData.resolve_engine`.
start (any): Start datetime (if datetime index) or any other start value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
If the index is a multi-index, start value must be a tuple.
Cannot be used together with `query`; include the condition in the query instead.
end (any): End datetime (if datetime index) or any other end value.
Will parse with `vectorbtpro.utils.datetime_.to_timestamp` if `align_dates` is True
and the index is a datetime index. Otherwise, you must ensure the correct type is provided.
If the index is a multi-index, end value must be a tuple.
Cannot be used together with `query`; include the condition in the query instead.
align_dates (bool): Whether to align `start` and `end` to the timezone of the index.
Will pull one row (using `LIMIT 1`) and use `SQLData.prepare_dt` to get the index.
parse_dates (bool, list, or dict): Whether to parse dates and how to do it.
If `query` is not used, will get mapped into column names. Otherwise,
usage of integers is not allowed and column names must be used directly.
If enabled, will also try to parse the datetime columns that couldn't be parsed
by Pandas after the object has been fetched.
For dict format, see `pd.read_sql_query`.
to_utc (bool, str, or sequence of str): See `SQLData.prepare_dt`.
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
start_row (int): Start row.
Table must contain the column defined in `row_number_column`.
Cannot be used together with `query`; include the condition in the query instead.
end_row (int): End row.
Table must contain the column defined in `row_number_column`.
Cannot be used together with `query`; include the condition in the query instead.
keep_row_number (bool): Whether to return the column defined in `row_number_column`.
row_number_column (str): Name of the column with row numbers.
index_col (int, str, or list): One or more columns that should become the index.
If `query` is not used, will get mapped into column names. Otherwise,
usage of integers is not allowed and column names must be used directly.
columns (int, str, or list): One or more columns to select.
Will get mapped into column names. Cannot be used together with `query`.
dtype (dtype_like or dict): Data type of each column.
If `query` is not used, will get mapped into column names. Otherwise,
usage of integers is not allowed and column names must be used directly.
For dict format, see `pd.read_sql_query`.
chunksize (int): See `pd.read_sql_query`.
chunk_func (callable): Function to select and concatenate chunks from `Iterator`.
Gets called only if `chunksize` is set.
squeeze (bool): Whether to squeeze a DataFrame with one column into a Series.
**read_sql_kwargs: Other keyword arguments passed to `pd.read_sql_query`.
See https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html for other arguments.
For defaults, see `custom.sql` in `vectorbtpro._settings.data`.
Global settings can be provided per engine name using the `engines` dictionary.
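Usage:
* A minimal sketch of pulling a table from a local SQLite database; the engine URL,
table name, and index column below are placeholders rather than defaults of this class:

```pycon
>>> from vectorbtpro import *

>>> data = vbt.SQLData.pull(
...     "YOUR_TABLE",
...     engine="sqlite:///YOUR_DATABASE.db",
...     index_col="YOUR_INDEX_COLUMN"
... )
```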
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from sqlalchemy import Selectable, Select, FromClause, and_, text
if engine_config is None:
engine_config = {}
engine_meta = cls.resolve_engine(
engine=engine,
engine_name=engine_name,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
dispose_engine = should_dispose
if table is not None and query is not None:
raise ValueError("Must provide either table name or query, not both")
if schema is not None and query is not None:
raise ValueError("Schema cannot be applied to custom queries")
if table is None and query is None:
if ":" in key:
schema, table = key.split(":")
else:
table = key
start = cls.resolve_engine_setting(start, "start", engine_name=engine_name)
end = cls.resolve_engine_setting(end, "end", engine_name=engine_name)
align_dates = cls.resolve_engine_setting(align_dates, "align_dates", engine_name=engine_name)
parse_dates = cls.resolve_engine_setting(parse_dates, "parse_dates", engine_name=engine_name)
to_utc = cls.resolve_engine_setting(to_utc, "to_utc", engine_name=engine_name)
tz = cls.resolve_engine_setting(tz, "tz", engine_name=engine_name)
start_row = cls.resolve_engine_setting(start_row, "start_row", engine_name=engine_name)
end_row = cls.resolve_engine_setting(end_row, "end_row", engine_name=engine_name)
keep_row_number = cls.resolve_engine_setting(keep_row_number, "keep_row_number", engine_name=engine_name)
row_number_column = cls.resolve_engine_setting(row_number_column, "row_number_column", engine_name=engine_name)
index_col = cls.resolve_engine_setting(index_col, "index_col", engine_name=engine_name)
columns = cls.resolve_engine_setting(columns, "columns", engine_name=engine_name)
dtype = cls.resolve_engine_setting(dtype, "dtype", engine_name=engine_name)
chunksize = cls.resolve_engine_setting(chunksize, "chunksize", engine_name=engine_name)
chunk_func = cls.resolve_engine_setting(chunk_func, "chunk_func", engine_name=engine_name)
squeeze = cls.resolve_engine_setting(squeeze, "squeeze", engine_name=engine_name)
read_sql_kwargs = cls.resolve_engine_setting(
read_sql_kwargs, "read_sql_kwargs", merge=True, engine_name=engine_name
)
if query is None or isinstance(query, (Selectable, FromClause)):
if query is None:
if isinstance(table, str):
table = cls.get_table_relation(table, schema=schema, engine=engine, engine_name=engine_name)
else:
table = query
table_column_names = []
for column in table.columns:
table_column_names.append(column.name)
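# Resolve integer positions to column names and match string names case-insensitively
# against the columns of the table or query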
def _resolve_columns(c):
if checks.is_int(c):
c = table_column_names[int(c)]
elif not isinstance(c, str):
new_c = []
for _c in c:
if checks.is_int(_c):
new_c.append(table_column_names[int(_c)])
else:
if _c not in table_column_names:
for __c in table_column_names:
if _c.lower() == __c.lower():
_c = __c
break
new_c.append(_c)
c = new_c
else:
if c not in table_column_names:
for _c in table_column_names:
if c.lower() == _c.lower():
return _c
return c
if index_col is False:
index_col = None
if index_col is not None:
index_col = _resolve_columns(index_col)
if isinstance(index_col, str):
index_col = [index_col]
if columns is not None:
columns = _resolve_columns(columns)
if isinstance(columns, str):
columns = [columns]
if parse_dates is not None:
if not isinstance(parse_dates, bool):
if isinstance(parse_dates, dict):
parse_dates = dict(zip(_resolve_columns(parse_dates.keys()), parse_dates.values()))
else:
parse_dates = _resolve_columns(parse_dates)
if isinstance(parse_dates, str):
parse_dates = [parse_dates]
if dtype is not None:
if isinstance(dtype, dict):
dtype = dict(zip(_resolve_columns(dtype.keys()), dtype.values()))
if not isinstance(table, Select):
query = table.select()
else:
query = table
if index_col is not None:
for col in index_col:
query = query.order_by(col)
if index_col is not None and columns is not None:
pre_columns = []
for col in index_col:
if col not in columns:
pre_columns.append(col)
columns = pre_columns + columns
if keep_row_number and columns is not None:
if row_number_column in table_column_names and row_number_column not in columns:
columns = [row_number_column] + columns
elif not keep_row_number and columns is None:
if row_number_column in table_column_names:
columns = [col for col in table_column_names if col != row_number_column]
if columns is not None:
query = query.with_only_columns(*[table.columns.get(c) for c in columns])
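# Convert NumPy scalars to native Python types so they can be compared in SQL expressions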
def _to_native_type(x):
if checks.is_np_scalar(x):
return x.item()
return x
if start_row is not None or end_row is not None:
if start is not None or end is not None:
raise ValueError("Can either filter by row numbers or by index, not both")
_row_number_column = table.columns.get(row_number_column)
if _row_number_column is None:
raise ValueError(f"Row number column '{row_number_column}' not found")
and_list = []
if start_row is not None:
and_list.append(_row_number_column >= _to_native_type(start_row))
if end_row is not None:
and_list.append(_row_number_column < _to_native_type(end_row))
query = query.where(and_(*and_list))
if start is not None or end is not None:
if index_col is None:
raise ValueError("Must provide index column for filtering by start and end")
if align_dates:
first_obj = pd.read_sql_query(
query.limit(1),
engine.connect(),
index_col=index_col,
parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted
dtype=dtype,
chunksize=None,
**read_sql_kwargs,
)
first_obj = cls.prepare_dt(
first_obj,
parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates,
to_utc=False,
)
if isinstance(first_obj.index, pd.DatetimeIndex):
if tz is None:
tz = first_obj.index.tz
if first_obj.index.tz is not None:
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=first_obj.index.tz)
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=first_obj.index.tz)
else:
if start is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc)
):
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz="utc")
start = dt.to_naive_datetime(start)
else:
start = dt.to_naive_datetime(start, tz=tz)
if end is not None:
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and first_obj.index.name in to_utc)
):
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz="utc")
end = dt.to_naive_datetime(end)
else:
end = dt.to_naive_datetime(end, tz=tz)
and_list = []
if start is not None:
if len(index_col) > 1:
if not isinstance(start, tuple):
raise TypeError("Start must be a tuple if the index is a multi-index")
if len(start) != len(index_col):
raise ValueError("Start tuple must match the number of levels in the multi-index")
for i in range(len(index_col)):
index_column = table.columns.get(index_col[i])
and_list.append(index_column >= _to_native_type(start[i]))
else:
index_column = table.columns.get(index_col[0])
and_list.append(index_column >= _to_native_type(start))
if end is not None:
if len(index_col) > 1:
if not isinstance(end, tuple):
raise TypeError("End must be a tuple if the index is a multi-index")
if len(end) != len(index_col):
raise ValueError("End tuple must match the number of levels in the multi-index")
for i in range(len(index_col)):
index_column = table.columns.get(index_col[i])
and_list.append(index_column < _to_native_type(end[i]))
else:
index_column = table.columns.get(index_col[0])
and_list.append(index_column < _to_native_type(end))
query = query.where(and_(*and_list))
else:
def _check_columns(c, arg_name):
if checks.is_int(c):
raise ValueError(f"Must provide column as a string for '{arg_name}'")
elif not isinstance(c, str):
for _c in c:
if checks.is_int(_c):
raise ValueError(f"Must provide each column as a string for '{arg_name}'")
if start is not None:
raise ValueError("Start cannot be applied to custom queries")
if end is not None:
raise ValueError("End cannot be applied to custom queries")
if start_row is not None:
raise ValueError("Start row cannot be applied to custom queries")
if end_row is not None:
raise ValueError("End row cannot be applied to custom queries")
if index_col is False:
index_col = None
if index_col is not None:
_check_columns(index_col, "index_col")
if isinstance(index_col, str):
index_col = [index_col]
if columns is not None:
raise ValueError("Columns cannot be applied to custom queries")
if parse_dates is not None:
if not isinstance(parse_dates, bool):
if isinstance(parse_dates, dict):
_check_columns(parse_dates.keys(), "parse_dates")
else:
_check_columns(parse_dates, "parse_dates")
if isinstance(parse_dates, str):
parse_dates = [parse_dates]
if dtype is not None:
_check_columns(dtype.keys(), "dtype")
if isinstance(query, str):
query = text(query)
obj = pd.read_sql_query(
query,
engine.connect(),
index_col=index_col,
parse_dates=None if isinstance(parse_dates, bool) else parse_dates, # bool not accepted
dtype=dtype,
chunksize=chunksize,
**read_sql_kwargs,
)
if isinstance(obj, Iterator):
if chunk_func is None:
obj = pd.concat(list(obj), axis=0)
else:
obj = chunk_func(obj)
obj = cls.prepare_dt(
obj,
parse_dates=list(parse_dates) if isinstance(parse_dates, dict) else parse_dates,
to_utc=to_utc,
)
if not isinstance(obj.index, pd.MultiIndex):
if obj.index.name == "index":
obj.index.name = None
if isinstance(obj.index, pd.DatetimeIndex) and tz is None:
tz = obj.index.tz
if isinstance(obj, pd.DataFrame) and squeeze:
obj = obj.squeeze("columns")
if isinstance(obj, pd.Series) and obj.name == "0":
obj.name = None
if dispose_engine:
engine.dispose()
if keep_row_number:
return obj, dict(tz=tz, row_number_column=row_number_column)
return obj, dict(tz=tz)
@classmethod
def fetch_feature(cls, feature: str, **kwargs) -> tp.FeatureData:
"""Fetch the table of a feature.
Uses `SQLData.fetch_key`."""
return cls.fetch_key(feature, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: str, **kwargs) -> tp.SymbolData:
"""Fetch the table for a symbol.
Uses `SQLData.fetch_key`."""
return cls.fetch_key(symbol, **kwargs)
def update_key(
self,
key: str,
from_last_row: tp.Optional[bool] = None,
from_last_index: tp.Optional[bool] = None,
**kwargs,
) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
returned_kwargs = self.select_returned_kwargs(key)
pre_kwargs = merge_dicts(fetch_kwargs, kwargs)
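# Decide how to continue fetching: prefer continuing from the last row number when a
# row number column is available, otherwise continue from the last index value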
if from_last_row is None:
if pre_kwargs.get("query", None) is not None:
from_last_row = False
elif from_last_index is True:
from_last_row = False
elif pre_kwargs.get("start", None) is not None or pre_kwargs.get("end", None) is not None:
from_last_row = False
elif "row_number_column" not in returned_kwargs:
from_last_row = False
elif returned_kwargs["row_number_column"] not in self.wrapper.columns:
from_last_row = False
else:
from_last_row = True
if from_last_index is None:
if pre_kwargs.get("query", None) is not None:
from_last_index = False
elif from_last_row is True:
from_last_index = False
elif pre_kwargs.get("start_row", None) is not None or pre_kwargs.get("end_row", None) is not None:
from_last_index = False
else:
from_last_index = True
if from_last_row:
if "row_number_column" not in returned_kwargs:
raise ValueError("Argument row_number_column must be in returned_kwargs for from_last_row")
row_number_column = returned_kwargs["row_number_column"]
fetch_kwargs["start_row"] = self.data[key][row_number_column].iloc[-1]
if from_last_index:
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if self.feature_oriented:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: str, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `SQLData.update_key`."""
return self.update_key(feature, **kwargs)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `SQLData.update_key`."""
return self.update_key(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SyntheticData`."""
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.custom import CustomData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"SyntheticData",
]
__pdoc__ = {}
class SyntheticData(CustomData):
"""Data class for fetching synthetic data.
Exposes an abstract class method `SyntheticData.generate_symbol`.
Everything else is taken care of.
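Usage:
* A minimal subclass sketch; the class name, the generated values, and the date range
below are illustrative assumptions rather than parts of this module:

```pycon
>>> from vectorbtpro import *

>>> class MyData(vbt.SyntheticData):
...     @classmethod
...     def generate_symbol(cls, symbol, index, **kwargs):
...         return pd.Series(range(len(index)), index=index)

>>> data = MyData.pull("SYMBOL", start="2021-01-01", end="2021-01-10")
```
"""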
_settings_path: tp.SettingsPath = dict(custom="data.custom.synthetic")
@classmethod
def generate_key(cls, key: tp.Key, index: tp.Index, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Abstract method to generate data of a feature or symbol."""
raise NotImplementedError
@classmethod
def generate_feature(cls, feature: tp.Feature, index: tp.Index, **kwargs) -> tp.FeatureData:
"""Abstract method to generate data of a feature.
Uses `SyntheticData.generate_key` with `key_is_feature=True`."""
return cls.generate_key(feature, index, key_is_feature=True, **kwargs)
@classmethod
def generate_symbol(cls, symbol: tp.Symbol, index: tp.Index, **kwargs) -> tp.SymbolData:
"""Abstract method to generate data for a symbol.
Uses `SyntheticData.generate_key` with `key_is_feature=False`."""
return cls.generate_key(symbol, index, key_is_feature=False, **kwargs)
@classmethod
def fetch_key(
cls,
key: tp.Symbol,
key_is_feature: bool = False,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
periods: tp.Optional[int] = None,
timeframe: tp.Optional[tp.FrequencyLike] = None,
tz: tp.TimezoneLike = None,
normalize: tp.Optional[bool] = None,
inclusive: tp.Optional[str] = None,
**kwargs,
) -> tp.KeyData:
"""Generate data of a feature or symbol.
Generates datetime index using `vectorbtpro.utils.datetime_.date_range` and passes it to
`SyntheticData.generate_key` to fill the Series/DataFrame with generated data.
For defaults, see `custom.synthetic` in `vectorbtpro._settings.data`."""
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
normalize = cls.resolve_custom_setting(normalize, "normalize")
inclusive = cls.resolve_custom_setting(inclusive, "inclusive")
index = dt.date_range(
start=start,
end=end,
periods=periods,
freq=timeframe,
normalize=normalize,
inclusive=inclusive,
)
if tz is None:
tz = index.tz
if len(index) == 0:
raise ValueError("Date range is empty")
if key_is_feature:
return cls.generate_feature(key, index, **kwargs), dict(tz=tz, freq=timeframe)
return cls.generate_symbol(key, index, **kwargs), dict(tz=tz, freq=timeframe)
@classmethod
def fetch_feature(cls, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Generate data of a feature.
Uses `SyntheticData.fetch_key` with `key_is_feature=True`."""
return cls.fetch_key(feature, key_is_feature=True, **kwargs)
@classmethod
def fetch_symbol(cls, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Generate data for a symbol.
Uses `SyntheticData.fetch_key` with `key_is_feature=False`."""
return cls.fetch_key(symbol, key_is_feature=False, **kwargs)
def update_key(self, key: tp.Key, key_is_feature: bool = False, **kwargs) -> tp.KeyData:
"""Update data of a feature or symbol."""
fetch_kwargs = self.select_fetch_kwargs(key)
fetch_kwargs["start"] = self.select_last_index(key)
kwargs = merge_dicts(fetch_kwargs, kwargs)
if key_is_feature:
return self.fetch_feature(key, **kwargs)
return self.fetch_symbol(key, **kwargs)
def update_feature(self, feature: tp.Feature, **kwargs) -> tp.FeatureData:
"""Update data of a feature.
Uses `SyntheticData.update_key` with `key_is_feature=True`."""
return self.update_key(feature, key_is_feature=True, **kwargs)
def update_symbol(self, symbol: tp.Symbol, **kwargs) -> tp.SymbolData:
"""Update data for a symbol.
Uses `SyntheticData.update_key` with `key_is_feature=False`."""
return self.update_key(symbol, key_is_feature=False, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `TVData`."""
import datetime
import json
import math
import random
import re
import string
import time
import pandas as pd
import requests
from websocket import WebSocket
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Configured
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.template import CustomTemplate
__all__ = [
"TVClient",
"TVData",
]
SIGNIN_URL = "https://www.tradingview.com/accounts/signin/"
"""Sign-in URL."""
SEARCH_URL = (
"https://symbol-search.tradingview.com/symbol_search/v3/?"
"text={text}&"
"start={start}&"
"hl=1&"
"exchange={exchange}&"
"lang=en&"
"search_type=undefined&"
"domain=production&"
"sort_by_country=US"
)
"""Symbol search URL."""
SCAN_URL = "https://scanner.tradingview.com/{market}/scan"
"""Market scanner URL."""
ORIGIN_URL = "https://data.tradingview.com"
"""Origin URL."""
REFERER_URL = "https://www.tradingview.com"
"""Referer URL."""
WS_URL = "wss://data.tradingview.com/socket.io/websocket"
"""Websocket URL."""
PRO_WS_URL = "wss://prodata.tradingview.com/socket.io/websocket"
"""Websocket URL (Pro)."""
WS_TIMEOUT = 5
"""Websocket timeout."""
MARKET_LIST = [
"america",
"argentina",
"australia",
"austria",
"bahrain",
"bangladesh",
"belgium",
"brazil",
"canada",
"chile",
"china",
"colombia",
"cyprus",
"czech",
"denmark",
"egypt",
"estonia",
"euronext",
"finland",
"france",
"germany",
"greece",
"hongkong",
"hungary",
"iceland",
"india",
"indonesia",
"israel",
"italy",
"japan",
"kenya",
"korea",
"ksa",
"kuwait",
"latvia",
"lithuania",
"luxembourg",
"malaysia",
"mexico",
"morocco",
"netherlands",
"newzealand",
"nigeria",
"norway",
"pakistan",
"peru",
"philippines",
"poland",
"portugal",
"qatar",
"romania",
"rsa",
"russia",
"serbia",
"singapore",
"slovakia",
"spain",
"srilanka",
"sweden",
"switzerland",
"taiwan",
"thailand",
"tunisia",
"turkey",
"uae",
"uk",
"venezuela",
"vietnam",
]
"""List of markets supported by the market scanner (list may be incomplete)."""
FIELD_LIST = [
"name",
"description",
"logoid",
"update_mode",
"type",
"typespecs",
"close",
"pricescale",
"minmov",
"fractional",
"minmove2",
"currency",
"change",
"change_abs",
"Recommend.All",
"volume",
"Value.Traded",
"market_cap_basic",
"fundamental_currency_code",
"Perf.1Y.MarketCap",
"price_earnings_ttm",
"earnings_per_share_basic_ttm",
"number_of_employees_fy",
"sector",
"market",
]
"""List of fields supported by the market scanner (list may be incomplete)."""
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/125.0.0.0 Safari/537.36"
"""User agent."""
class TVClient(Configured):
"""Client for TradingView."""
def __init__(
self,
username: tp.Optional[str] = None,
password: tp.Optional[str] = None,
auth_token: tp.Optional[str] = None,
**kwargs,
) -> None:
"""Client for TradingView."""
Configured.__init__(
self,
username=username,
password=password,
auth_token=auth_token,
**kwargs,
)
if auth_token is None:
auth_token = self.auth(username, password)
elif username is not None or password is not None:
raise ValueError("Must provide either username and password, or auth_token")
self._auth_token = auth_token
self._ws = None
self._session = self.generate_session()
self._chart_session = self.generate_chart_session()
@property
def auth_token(self) -> str:
"""Authentication token."""
return self._auth_token
@property
def ws(self) -> WebSocket:
"""Instance of `websocket.Websocket`."""
return self._ws
@property
def session(self) -> str:
"""Session."""
return self._session
@property
def chart_session(self) -> str:
"""Chart session."""
return self._chart_session
@classmethod
def auth(
cls,
username: tp.Optional[str] = None,
password: tp.Optional[str] = None,
) -> str:
"""Authenticate."""
if username is not None and password is not None:
data = {"username": username, "password": password, "remember": "on"}
headers = {"Referer": REFERER_URL, "User-Agent": USER_AGENT}
response = requests.post(url=SIGNIN_URL, data=data, headers=headers)
response.raise_for_status()
json = response.json()
if "user" not in json or "auth_token" not in json["user"]:
raise ValueError(json)
return json["user"]["auth_token"]
if username is not None or password is not None:
raise ValueError("Must provide both username and password")
return "unauthorized_user_token"
@classmethod
def generate_session(cls) -> str:
"""Generate session."""
string_length = 12
letters = string.ascii_lowercase
random_string = "".join(random.choice(letters) for _ in range(string_length))
return "qs_" + random_string
@classmethod
def generate_chart_session(cls) -> str:
"""Generate chart session."""
string_length = 12
letters = string.ascii_lowercase
random_string = "".join(random.choice(letters) for _ in range(string_length))
return "cs_" + random_string
def create_connection(self, pro_data: bool = True) -> None:
"""Create a websocket connection."""
from websocket import create_connection
if pro_data:
self._ws = create_connection(
PRO_WS_URL,
headers=json.dumps({"Origin": ORIGIN_URL}),
timeout=WS_TIMEOUT,
)
else:
self._ws = create_connection(
WS_URL,
headers=json.dumps({"Origin": ORIGIN_URL}),
timeout=WS_TIMEOUT,
)
@classmethod
def filter_raw_message(cls, text) -> tp.Tuple[str, str]:
"""Filter raw message."""
found = re.search('"m":"(.+?)",', text).group(1)
found2 = re.search('"p":(.+?"}"])}', text).group(1)
return found, found2
@classmethod
def prepend_header(cls, st: str) -> str:
"""Prepend a header."""
return "~m~" + str(len(st)) + "~m~" + st
@classmethod
def construct_message(cls, func: str, param_list: tp.List[str]) -> str:
"""Construct a message."""
return json.dumps({"m": func, "p": param_list}, separators=(",", ":"))
def create_message(self, func: str, param_list: tp.List[str]) -> str:
"""Create a message."""
return self.prepend_header(self.construct_message(func, param_list))
def send_message(self, func: str, param_list: tp.List[str]) -> None:
"""Send a message."""
m = self.create_message(func, param_list)
self.ws.send(m)
@classmethod
def convert_raw_data(cls, raw_data: str, symbol: str) -> pd.DataFrame:
"""Process raw data into a DataFrame."""
search_result = re.search(r'"s":\[(.+?)\}\]', raw_data)
if search_result is None:
raise ValueError("Couldn't parse data returned by TradingView: {}".format(raw_data))
out = search_result.group(1)
x = out.split(',{"')
data = list()
volume_data = True
for xi in x:
xi = re.split(r"\[|:|,|\]", xi)
ts = datetime.datetime.utcfromtimestamp(float(xi[4]))
row = [ts]
for i in range(5, 10):
# skip converting volume data if it does not exist
if not volume_data and i == 9:
row.append(0.0)
continue
try:
row.append(float(xi[i]))
except ValueError:
volume_data = False
row.append(0.0)
data.append(row)
data = pd.DataFrame(data, columns=["datetime", "open", "high", "low", "close", "volume"])
data = data.set_index("datetime")
data.insert(0, "symbol", value=symbol)
return data
@classmethod
def format_symbol(cls, symbol: str, exchange: str, fut_contract: tp.Optional[int] = None) -> str:
"""Format a symbol."""
if ":" in symbol:
pass
elif fut_contract is None:
symbol = f"{exchange}:{symbol}"
elif isinstance(fut_contract, int):
symbol = f"{exchange}:{symbol}{fut_contract}!"
else:
raise ValueError(f"Invalid fut_contract: '{fut_contract}'")
return symbol
def get_hist(
self,
symbol: str,
exchange: str = "NSE",
interval: str = "1D",
fut_contract: tp.Optional[int] = None,
adjustment: str = "splits",
extended_session: bool = False,
pro_data: bool = True,
limit: int = 20000,
return_raw: bool = False,
) -> tp.Union[str, tp.Frame]:
"""Get historical data."""
symbol = self.format_symbol(symbol=symbol, exchange=exchange, fut_contract=fut_contract)
backadjustment = False
if symbol.endswith("!A"):
backadjustment = True
symbol = symbol.replace("!A", "!")
self.create_connection(pro_data=pro_data)
self.send_message("set_auth_token", [self.auth_token])
self.send_message("chart_create_session", [self.chart_session, ""])
self.send_message("quote_create_session", [self.session])
self.send_message(
"quote_set_fields",
[
self.session,
"ch",
"chp",
"current_session",
"description",
"local_description",
"language",
"exchange",
"fractional",
"is_tradable",
"lp",
"lp_time",
"minmov",
"minmove2",
"original_name",
"pricescale",
"pro_name",
"short_name",
"type",
"update_mode",
"volume",
"currency_code",
"rchp",
"rtc",
],
)
self.send_message("quote_add_symbols", [self.session, symbol, {"flags": ["force_permission"]}])
self.send_message("quote_fast_symbols", [self.session, symbol])
self.send_message(
"resolve_symbol",
[
self.chart_session,
"symbol_1",
'={"symbol":"'
+ symbol
+ '","adjustment":"'
+ adjustment
+ ("" if not backadjustment else '","backadjustment":"default')
+ '","session":'
+ ('"regular"' if not extended_session else '"extended"')
+ "}",
],
)
self.send_message("create_series", [self.chart_session, "s1", "s1", "symbol_1", interval, limit])
self.send_message("switch_timezone", [self.chart_session, "exchange"])
raw_data = ""
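# Collect websocket messages until the series is reported as completed or the
# connection times out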
while True:
try:
result = self.ws.recv()
raw_data += result + "\n"
except Exception as e:
break
if "series_completed" in result:
break
if return_raw:
return raw_data
return self.convert_raw_data(raw_data, symbol)
@classmethod
def search_symbol(
cls,
text: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
pages: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: int = 3,
show_progress: bool = True,
pbar_kwargs: tp.KwargsLike = None,
) -> tp.List[dict]:
"""Search for a symbol."""
if text is None:
text = ""
if exchange is None:
exchange = ""
if pbar_kwargs is None:
pbar_kwargs = {}
symbols_list = []
pbar = None
pages_fetched = 0
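# Paginate through the search results, retrying failed requests, until no symbols
# remain or the requested number of pages has been fetched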
while True:
for i in range(retries):
try:
url = SEARCH_URL.format(text=text, exchange=exchange.upper(), start=len(symbols_list))
headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT}
resp = requests.get(url, headers=headers)
symbols_data = json.loads(resp.text.replace("<em>", "").replace("</em>", ""))
break
except json.JSONDecodeError as e:
if i == retries - 1:
raise e
if delay is not None:
time.sleep(delay)
symbols_remaining = symbols_data.get("symbols_remaining", 0)
new_symbols = symbols_data.get("symbols", [])
symbols_list.extend(new_symbols)
if pages is None and symbols_remaining > 0:
show_pbar = True
elif pages is not None and pages > 1:
show_pbar = True
else:
show_pbar = False
if pbar is None and show_pbar:
if pages is not None:
total = pages
else:
total = math.ceil((len(new_symbols) + symbols_remaining) / len(new_symbols))
pbar = ProgressBar(
total=total,
show_progress=show_progress,
**pbar_kwargs,
)
pbar.enter()
if pbar is not None:
max_symbols = len(symbols_list) + symbols_remaining
if pages is not None:
max_symbols = min(max_symbols, pages * len(new_symbols))
pbar.set_description(dict(symbols="%d/%d" % (len(symbols_list), max_symbols)))
pbar.update()
if symbols_remaining == 0:
break
pages_fetched += 1
if pages is not None and pages_fetched >= pages:
break
if delay is not None:
time.sleep(delay)
if pbar is not None:
pbar.exit()
return symbols_list
@classmethod
def scan_symbols(cls, market: tp.Optional[str] = None, **kwargs) -> tp.List[dict]:
"""Scan symbols in a region/market."""
if market is None:
market = "global"
url = SCAN_URL.format(market=market.lower())
headers = {"Referer": REFERER_URL, "Origin": ORIGIN_URL, "User-Agent": USER_AGENT}
resp = requests.post(url, json.dumps(kwargs), headers=headers)
symbols_list = json.loads(resp.text)["data"]
return symbols_list
TVDataT = tp.TypeVar("TVDataT", bound="TVData")
class TVData(RemoteData):
"""Data class for fetching from TradingView.
See `TVData.fetch_symbol` for arguments.
!!! note
If you're getting the error "Please confirm that you are not a robot by clicking the captcha box."
when attempting to authenticate, use `auth_token` instead of `username` and `password`.
To get the authentication token, go to TradingView, log in, visit any chart, open your console's
developer tools, and search for "auth_token".
Usage:
* Set up the credentials globally (optional):
```pycon
>>> from vectorbtpro import *
>>> vbt.TVData.set_custom_settings(
... client_config=dict(
... username="YOUR_USERNAME",
... password="YOUR_PASSWORD",
... auth_token="YOUR_AUTH_TOKEN", # optional, instead of username and password
... )
... )
```
* Pull data:
```pycon
>>> data = vbt.TVData.pull(
... "NASDAQ:AAPL",
... timeframe="1 hour"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.tv")
@classmethod
def list_symbols(
cls,
*,
exchange_pattern: tp.Optional[str] = None,
symbol_pattern: tp.Optional[str] = None,
use_regex: bool = False,
sort: bool = True,
client: tp.Optional[TVClient] = None,
client_config: tp.DictLike = None,
text: tp.Optional[str] = None,
exchange: tp.Optional[str] = None,
pages: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
market: tp.Optional[str] = None,
markets: tp.Optional[tp.List[str]] = None,
fields: tp.Optional[tp.MaybeIterable[str]] = None,
filter_by: tp.Union[None, tp.Callable, CustomTemplate] = None,
groups: tp.Optional[tp.MaybeIterable[tp.Dict[str, tp.MaybeIterable[str]]]] = None,
template_context: tp.KwargsLike = None,
return_field_data: bool = False,
**scanner_kwargs,
) -> tp.Union[tp.List[str], tp.List[tp.Kwargs]]:
"""List all symbols.
Uses symbol search when either `text` or `exchange` is provided (returns a subset of symbols).
Otherwise, uses the market scanner (returns all symbols, big payload).
When using the market scanner, use `market` to filter by one or multiple markets. For the list
of available markets, see `MARKET_LIST`.
Use `fields` to make the market scanner return additional information that can be used for
filtering with `filter_by`. Such information is passed to the filtering function as a dictionary
keyed by field. Instead of a function, `filter_by` can also be a template that receives the same
information as a context, or a list of values to be matched against the values of the corresponding fields.
For the list of available fields, see `FIELD_LIST`. Argument `fields` can also be "all".
Set `return_field_data` to True to return a list with (filtered) field data.
Use `groups` to provide a single dictionary or a list of dictionaries with groups.
Each dictionary can be provided either in a compressed format, such as `dict(index=index)`,
or in a full format, such as `dict(type="index", values=[index])`.
Keyword arguments `scanner_kwargs` are encoded and passed directly to the market scanner.
Uses `vectorbtpro.data.custom.custom.CustomData.key_match` to check each exchange against
`exchange_pattern` and each symbol against `symbol_pattern`.
Usage:
* List all symbols (market scanner):
```pycon
>>> from vectorbtpro import *
>>> vbt.TVData.list_symbols()
```
* Search for symbols matching a pattern (market scanner, client-side):
```pycon
>>> vbt.TVData.list_symbols(symbol_pattern="BTC*")
```
* Search for exchanges matching a pattern (market scanner, client-side):
```pycon
>>> vbt.TVData.list_symbols(exchange_pattern="NASDAQ")
```
* Search for symbols containing a text (symbol search, server-side):
```pycon
>>> vbt.TVData.list_symbols(text="BTC")
```
* List symbols from an exchange (symbol search):
```pycon
>>> vbt.TVData.list_symbols(exchange="NASDAQ")
```
* List symbols from a market (market scanner):
```pycon
>>> vbt.TVData.list_symbols(market="poland")
```
* List index constituents (market scanner):
```pycon
>>> vbt.TVData.list_symbols(groups=dict(index="NASDAQ:NDX"))
```
* Filter symbols by fields using a function (market scanner):
```pycon
>>> vbt.TVData.list_symbols(
... market="america",
... fields=["sector"],
... filter_by=lambda context: context["sector"] == "Technology Services"
... )
```
* Filter symbols by fields using a template (market scanner):
```pycon
>>> vbt.TVData.list_symbols(
... market="america",
... fields=["sector"],
... filter_by=vbt.RepEval("sector == 'Technology Services'")
... )
```
"""
pages = cls.resolve_custom_setting(pages, "pages", sub_path="search", sub_path_only=True)
delay = cls.resolve_custom_setting(delay, "delay", sub_path="search", sub_path_only=True)
retries = cls.resolve_custom_setting(retries, "retries", sub_path="search", sub_path_only=True)
show_progress = cls.resolve_custom_setting(
show_progress, "show_progress", sub_path="search", sub_path_only=True
)
pbar_kwargs = cls.resolve_custom_setting(
pbar_kwargs, "pbar_kwargs", sub_path="search", sub_path_only=True, merge=True
)
markets = cls.resolve_custom_setting(markets, "markets", sub_path="scanner", sub_path_only=True)
fields = cls.resolve_custom_setting(fields, "fields", sub_path="scanner", sub_path_only=True)
filter_by = cls.resolve_custom_setting(filter_by, "filter_by", sub_path="scanner", sub_path_only=True)
groups = cls.resolve_custom_setting(groups, "groups", sub_path="scanner", sub_path_only=True)
template_context = cls.resolve_custom_setting(
template_context, "template_context", sub_path="scanner", sub_path_only=True, merge=True
)
scanner_kwargs = cls.resolve_custom_setting(
scanner_kwargs, "scanner_kwargs", sub_path="scanner", sub_path_only=True, merge=True
)
if market is None and text is None and exchange is None:
market = "global"
if market is not None and (text is not None or exchange is not None):
raise ValueError("Please provide either market, or text and/or exchange")
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
if market is None:
data = client.search_symbol(
text=text,
exchange=exchange,
pages=pages,
delay=delay,
retries=retries,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
)
all_symbols = map(lambda x: x["exchange"] + ":" + x["symbol"], data)
return_field_data = False
else:
if markets is not None:
scanner_kwargs["markets"] = markets
if fields is not None:
if "columns" in scanner_kwargs:
raise ValueError("Use fields instead of columns")
if isinstance(fields, str):
if fields.lower() == "all":
fields = FIELD_LIST
else:
fields = [fields]
scanner_kwargs["columns"] = fields
if groups is not None:
if isinstance(groups, dict):
groups = [groups]
new_groups = []
for group in groups:
if "type" in group:
new_groups.append(group)
else:
for k, v in group.items():
if isinstance(v, str):
v = [v]
new_groups.append(dict(type=k, values=v))
groups = new_groups
if "symbols" in scanner_kwargs:
scanner_kwargs["symbols"] = dict(scanner_kwargs["symbols"])
else:
scanner_kwargs["symbols"] = dict()
scanner_kwargs["symbols"]["groups"] = groups
if filter_by is not None:
if isinstance(filter_by, str):
filter_by = [filter_by]
data = client.scan_symbols(market.lower(), **scanner_kwargs)
if data is None:
raise ValueError("No data returned by TradingView")
all_symbols = []
for item in data:
if fields is not None:
item = {"symbol": item["s"], **dict(zip(fields, item["d"]))}
else:
item = {"symbol": item["s"]}
if filter_by is not None:
if fields is not None:
context = merge_dicts(item, template_context)
else:
raise ValueError("Must provide fields for filter_by")
if isinstance(filter_by, CustomTemplate):
if not filter_by.substitute(context, eval_id="filter_by"):
continue
elif callable(filter_by):
if not filter_by(context):
continue
else:
if len(fields) != len(filter_by):
raise ValueError("Fields and filter_by must have the same number of values")
conditions_met = True
for i in range(len(fields)):
if context[fields[i]] != filter_by[i]:
conditions_met = False
break
if not conditions_met:
continue
if return_field_data:
all_symbols.append(item)
else:
all_symbols.append(item["symbol"])
found_symbols = []
for symbol in all_symbols:
if return_field_data:
item = symbol
symbol = item["symbol"]
else:
item = symbol
if '"symbol"' in symbol:
continue
if exchange_pattern is not None:
if not cls.key_match(symbol.split(":")[0], exchange_pattern, use_regex=use_regex):
continue
if symbol_pattern is not None:
if not cls.key_match(symbol.split(":")[1], symbol_pattern, use_regex=use_regex):
continue
found_symbols.append(item)
if sort:
if return_field_data:
return sorted(found_symbols, key=lambda x: x["symbol"])
return sorted(dict.fromkeys(found_symbols))
if return_field_data:
return found_symbols
return list(dict.fromkeys(found_symbols))
@classmethod
def resolve_client(cls, client: tp.Optional[TVClient] = None, **client_config) -> TVClient:
"""Resolve the client.
If provided, must be of the type `TVClient`. Otherwise, will be created using `client_config`."""
client = cls.resolve_custom_setting(client, "client")
if client_config is None:
client_config = {}
has_client_config = len(client_config) > 0
client_config = cls.resolve_custom_setting(client_config, "client_config", merge=True)
if client is None:
client = TVClient(**client_config)
elif has_client_config:
raise ValueError("Cannot apply client_config to already initialized client")
return client
@classmethod
def fetch_symbol(
cls,
symbol: str,
client: tp.Optional[TVClient] = None,
client_config: tp.KwargsLike = None,
exchange: tp.Optional[str] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
fut_contract: tp.Optional[int] = None,
adjustment: tp.Optional[str] = None,
extended_session: tp.Optional[bool] = None,
pro_data: tp.Optional[bool] = None,
limit: tp.Optional[int] = None,
delay: tp.Optional[int] = None,
retries: tp.Optional[int] = None,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from TradingView.
Args:
symbol (str): Symbol.
Symbol must be in the `EXCHANGE:SYMBOL` format if `exchange` is None.
client (TVClient): Client.
See `TVData.resolve_client`.
client_config (dict): Client config.
See `TVData.resolve_client`.
exchange (str): Exchange.
Can be omitted if already provided via `symbol`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
fut_contract (int): None for cash, 1 for continuous current contract in front,
2 for continuous next contract in front.
adjustment (str): Adjustment.
Either "splits" (default) or "dividends".
extended_session (bool): Regular session if False, extended session if True.
pro_data (bool): Whether to use pro data.
limit (int): The maximum number of returned items.
delay (float): Time to sleep after each request (in seconds).
retries (int): The number of retries on failure to fetch data.
For defaults, see `custom.tv` in `vectorbtpro._settings.data`.
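Usage:
* A hedged sketch of passing the exchange separately from the symbol; the symbol and
exchange below are placeholders:

```pycon
>>> from vectorbtpro import *

>>> data = vbt.TVData.pull(
...     "AAPL",
...     exchange="NASDAQ",
...     timeframe="1 day"
... )
```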
"""
if client_config is None:
client_config = {}
client = cls.resolve_client(client=client, **client_config)
exchange = cls.resolve_custom_setting(exchange, "exchange")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
fut_contract = cls.resolve_custom_setting(fut_contract, "fut_contract")
adjustment = cls.resolve_custom_setting(adjustment, "adjustment")
extended_session = cls.resolve_custom_setting(extended_session, "extended_session")
pro_data = cls.resolve_custom_setting(pro_data, "pro_data")
limit = cls.resolve_custom_setting(limit, "limit")
delay = cls.resolve_custom_setting(delay, "delay")
retries = cls.resolve_custom_setting(retries, "retries")
freq = timeframe
if not isinstance(timeframe, str):
raise ValueError(f"Invalid timeframe: '{timeframe}'")
split = dt.split_freq_str(timeframe)
if split is None:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
multiplier, unit = split
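# Translate the parsed timeframe into TradingView's interval notation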
if unit == "s":
interval = f"{str(multiplier)}S"
elif unit == "m":
interval = str(multiplier)
elif unit == "h":
interval = f"{str(multiplier)}H"
elif unit == "D":
interval = f"{str(multiplier)}D"
elif unit == "W":
interval = f"{str(multiplier)}W"
elif unit == "M":
interval = f"{str(multiplier)}M"
else:
raise ValueError(f"Invalid timeframe: '{timeframe}'")
for i in range(retries):
try:
df = client.get_hist(
symbol=symbol,
exchange=exchange,
interval=interval,
fut_contract=fut_contract,
adjustment=adjustment,
extended_session=extended_session,
pro_data=pro_data,
limit=limit,
)
break
except Exception as e:
if i == retries - 1:
raise e
if delay is not None:
time.sleep(delay)
df.rename(
columns={
"symbol": "Symbol",
"open": "Open",
"high": "High",
"low": "Low",
"close": "Close",
"volume": "Volume",
},
inplace=True,
)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize("utc")
if "Symbol" in df:
del df["Symbol"]
if "Open" in df.columns:
df["Open"] = df["Open"].astype(float)
if "High" in df.columns:
df["High"] = df["High"].astype(float)
if "Low" in df.columns:
df["Low"] = df["Low"].astype(float)
if "Close" in df.columns:
df["Close"] = df["Close"].astype(float)
if "Volume" in df.columns:
df["Volume"] = df["Volume"].astype(float)
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `YFData`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.data.custom.remote import RemoteData
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.utils import datetime_ as dt
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.parsing import get_func_kwargs
__all__ = [
"YFData",
]
__pdoc__ = {}
class YFData(RemoteData):
"""Data class for fetching from Yahoo Finance.
See https://github.com/ranaroussi/yfinance for API.
See `YFData.fetch_symbol` for arguments.
Usage:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... "BTC-USD",
... start="2020-01-01",
... end="2021-01-01",
... timeframe="1 day"
... )
```
"""
_settings_path: tp.SettingsPath = dict(custom="data.custom.yf")
_feature_config: tp.ClassVar[Config] = HybridConfig(
{
"Dividends": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
"Stock Splits": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.nonzero_prod_reduce_nb,
)
),
"Capital Gains": dict(
resample_func=lambda self, obj, resampler: obj.vbt.resample_apply(
resampler,
generic_nb.sum_reduce_nb,
)
),
}
)
@property
def feature_config(self) -> Config:
return self._feature_config
@classmethod
def fetch_symbol(
cls,
symbol: str,
period: tp.Optional[str] = None,
start: tp.Optional[tp.DatetimeLike] = None,
end: tp.Optional[tp.DatetimeLike] = None,
timeframe: tp.Optional[str] = None,
tz: tp.TimezoneLike = None,
**history_kwargs,
) -> tp.SymbolData:
"""Override `vectorbtpro.data.base.Data.fetch_symbol` to fetch a symbol from Yahoo Finance.
Args:
symbol (str): Symbol.
period (str): Period.
start (any): Start datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
end (any): End datetime.
See `vectorbtpro.utils.datetime_.to_tzaware_datetime`.
timeframe (str): Timeframe.
Allows human-readable strings such as "15 minutes".
tz (any): Timezone.
See `vectorbtpro.utils.datetime_.to_timezone`.
**history_kwargs: Keyword arguments passed to `yfinance.base.TickerBase.history`.
For defaults, see `custom.yf` in `vectorbtpro._settings.data`.
!!! warning
Data coming from Yahoo is not the most stable data out there. Yahoo may manipulate data
however they want, add noise, return missing data points (volume in particular), etc.
It's only used in vectorbt for demonstration purposes.
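Usage:
* A hedged sketch of pulling the full available history; the symbol below is a placeholder:

```pycon
>>> from vectorbtpro import *

>>> data = vbt.YFData.pull("MSFT", period="max", timeframe="1 day")
```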
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("yfinance")
import yfinance as yf
period = cls.resolve_custom_setting(period, "period")
start = cls.resolve_custom_setting(start, "start")
end = cls.resolve_custom_setting(end, "end")
timeframe = cls.resolve_custom_setting(timeframe, "timeframe")
tz = cls.resolve_custom_setting(tz, "tz")
history_kwargs = cls.resolve_custom_setting(history_kwargs, "history_kwargs", merge=True)
ticker = yf.Ticker(symbol)
def_history_kwargs = get_func_kwargs(yf.Tickers.history)
ticker_tz = ticker._get_ticker_tz(
history_kwargs.get("proxy", def_history_kwargs["proxy"]),
history_kwargs.get("timeout", def_history_kwargs["timeout"]),
)
if tz is None:
tz = ticker_tz
if start is not None:
start = dt.to_tzaware_datetime(start, naive_tz=tz, tz=ticker_tz)
if end is not None:
end = dt.to_tzaware_datetime(end, naive_tz=tz, tz=ticker_tz)
freq = timeframe
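# Translate vectorbt frequency units into the interval units expected by yfinance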
split = dt.split_freq_str(timeframe)
if split is not None:
multiplier, unit = split
if unit == "D":
unit = "d"
elif unit == "W":
unit = "wk"
elif unit == "M":
unit = "mo"
timeframe = str(multiplier) + unit
df = ticker.history(period=period, start=start, end=end, interval=timeframe, **history_kwargs)
if isinstance(df.index, pd.DatetimeIndex) and df.index.tz is None:
df = df.tz_localize(ticker_tz)
if not df.empty:
if start is not None:
if df.index[0] < start:
df = df[df.index >= start]
if end is not None:
if df.index[-1] >= end:
df = df[df.index < end]
return df, dict(tz=tz, freq=freq)
def update_symbol(self, symbol: str, **kwargs) -> tp.SymbolData:
fetch_kwargs = self.select_fetch_kwargs(symbol)
fetch_kwargs["start"] = self.select_last_index(symbol)
kwargs = merge_dicts(fetch_kwargs, kwargs)
return self.fetch_symbol(symbol, **kwargs)
YFData.override_feature_config_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with data sources."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.data.base import *
from vectorbtpro.data.custom import *
from vectorbtpro.data.decorators import *
from vectorbtpro.data.nb import *
from vectorbtpro.data.saver import *
from vectorbtpro.data.updater import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with data sources."""
import inspect
import string
import traceback
from pathlib import Path
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes
from vectorbtpro.base.merging import column_stack_arrays, is_merge_func_from_config
from vectorbtpro.base.reshaping import to_any_array, to_pd_array, to_1d_array, to_2d_array, broadcast_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.data.decorators import attach_symbol_dict_methods
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.ohlcv.nb import mirror_ohlc_nb
from vectorbtpro.ohlcv.enums import PriceFeature
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig, copy_dict
from vectorbtpro.utils.decorators import cached_property, hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute
from vectorbtpro.utils.merging import MergeFunc
from vectorbtpro.utils.parsing import get_func_arg_names, extend_args
from vectorbtpro.utils.path_ import check_mkdir
from vectorbtpro.utils.pickling import pdict, RecState
from vectorbtpro.utils.template import Rep, RepEval, CustomTemplate, substitute_templates
from vectorbtpro.utils.warnings_ import warn
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
try:
if not tp.TYPE_CHECKING:
raise ImportError
from sqlalchemy import Engine as EngineT
except ImportError:
EngineT = "Engine"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from duckdb import DuckDBPyConnection as DuckDBPyConnectionT
except ImportError:
DuckDBPyConnectionT = "DuckDBPyConnection"
__all__ = [
"key_dict",
"feature_dict",
"symbol_dict",
"run_func_dict",
"run_arg_dict",
"Data",
]
__pdoc__ = {}
class key_dict(pdict):
"""Dict that contains features or symbols as keys."""
pass
class feature_dict(key_dict):
"""Dict that contains features as keys."""
pass
class symbol_dict(key_dict):
"""Dict that contains symbols as keys."""
pass
class run_func_dict(pdict):
"""Dict that contains function names as keys for `Data.run`."""
pass
class run_arg_dict(pdict):
"""Dict that contains argument names as keys for `Data.run`."""
pass
BaseDataMixinT = tp.TypeVar("BaseDataMixinT", bound="BaseDataMixin")
class BaseDataMixin(Base):
"""Base mixin class for working with data."""
@property
def feature_wrapper(self) -> ArrayWrapper:
"""Column wrapper."""
raise NotImplementedError
@property
def symbol_wrapper(self) -> ArrayWrapper:
"""Symbol wrapper."""
raise NotImplementedError
@property
def features(self) -> tp.List[tp.Feature]:
"""List of features."""
return self.feature_wrapper.columns.tolist()
@property
def symbols(self) -> tp.List[tp.Symbol]:
"""List of symbols."""
return self.symbol_wrapper.columns.tolist()
@classmethod
def has_multiple_keys(cls, keys: tp.MaybeKeys) -> bool:
"""Check whether there are one or multiple keys."""
if checks.is_hashable(keys):
return False
elif checks.is_sequence(keys):
return True
raise TypeError("Keys must be either a hashable or a sequence of hashable")
@classmethod
def prepare_key(cls, key: tp.Key) -> tp.Key:
"""Prepare a key."""
if isinstance(key, tuple):
return tuple([cls.prepare_key(k) for k in key])
if isinstance(key, str):
return key.lower().strip().replace(" ", "_")
return key
def get_feature_idx(self, feature: tp.Feature, raise_error: bool = False) -> int:
"""Return the index of a feature."""
# shortcut
columns = self.feature_wrapper.columns
if not columns.has_duplicates:
if feature in columns:
return columns.get_loc(feature)
feature = self.prepare_key(feature)
found_indices = []
for i, c in enumerate(self.features):
c = self.prepare_key(c)
if feature == c:
found_indices.append(i)
if len(found_indices) == 0:
if raise_error:
raise ValueError(f"No features match the feature '{str(feature)}'")
return -1
if len(found_indices) == 1:
return found_indices[0]
raise ValueError(f"Multiple features match the feature '{str(feature)}'")
def get_symbol_idx(self, symbol: tp.Symbol, raise_error: bool = False) -> int:
"""Return the index of a symbol."""
# shortcut
columns = self.symbol_wrapper.columns
if not columns.has_duplicates:
if symbol in columns:
return columns.get_loc(symbol)
symbol = self.prepare_key(symbol)
found_indices = []
for i, c in enumerate(self.symbols):
c = self.prepare_key(c)
if symbol == c:
found_indices.append(i)
if len(found_indices) == 0:
if raise_error:
raise ValueError(f"No symbols match the symbol '{str(symbol)}'")
return -1
if len(found_indices) == 1:
return found_indices[0]
raise ValueError(f"Multiple symbols match the symbol '{str(symbol)}'")
def select_feature_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
"""Select one or more features by index.
Returns a new instance."""
raise NotImplementedError
def select_symbol_idxs(self: BaseDataMixinT, idxs: tp.MaybeSequence[int], **kwargs) -> BaseDataMixinT:
"""Select one or more symbols by index.
Returns a new instance."""
raise NotImplementedError
def select_features(self: BaseDataMixinT, features: tp.MaybeFeatures, **kwargs) -> BaseDataMixinT:
"""Select one or more features.
Returns a new instance."""
if self.has_multiple_keys(features):
feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
else:
feature_idxs = self.get_feature_idx(features, raise_error=True)
return self.select_feature_idxs(feature_idxs, **kwargs)
def select_symbols(self: BaseDataMixinT, symbols: tp.MaybeSymbols, **kwargs) -> BaseDataMixinT:
"""Select one or more symbols.
Returns a new instance."""
if self.has_multiple_keys(symbols):
symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols]
else:
symbol_idxs = self.get_symbol_idx(symbols, raise_error=True)
return self.select_symbol_idxs(symbol_idxs, **kwargs)
def get(
self,
features: tp.Optional[tp.MaybeFeatures] = None,
symbols: tp.Optional[tp.MaybeSymbols] = None,
feature: tp.Optional[tp.Feature] = None,
symbol: tp.Optional[tp.Symbol] = None,
**kwargs,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Get one or more features of one or more symbols of data."""
raise NotImplementedError
def has_feature(self, feature: tp.Feature) -> bool:
"""Whether feature exists."""
feature_idx = self.get_feature_idx(feature, raise_error=False)
return feature_idx != -1
def has_symbol(self, symbol: tp.Symbol) -> bool:
"""Whether symbol exists."""
symbol_idx = self.get_symbol_idx(symbol, raise_error=False)
return symbol_idx != -1
def assert_has_feature(self, feature: tp.Feature) -> None:
"""Assert that feature exists."""
self.get_feature_idx(feature, raise_error=True)
def assert_has_symbol(self, symbol: tp.Symbol) -> None:
"""Assert that symbol exists."""
self.get_symbol_idx(symbol, raise_error=True)
def get_feature(
self,
feature: tp.Union[int, tp.Feature],
raise_error: bool = False,
) -> tp.Optional[tp.SeriesFrame]:
"""Get feature that match a feature index or label."""
if checks.is_int(feature):
return self.get(features=self.features[feature])
feature_idx = self.get_feature_idx(feature, raise_error=raise_error)
if feature_idx == -1:
return None
return self.get(features=self.features[feature_idx])
def get_symbol(
self,
symbol: tp.Union[int, tp.Symbol],
raise_error: bool = False,
) -> tp.Optional[tp.SeriesFrame]:
"""Get symbol that match a symbol index or label."""
if checks.is_int(symbol):
return self.get(symbol=self.symbols[symbol])
symbol_idx = self.get_symbol_idx(symbol, raise_error=raise_error)
if symbol_idx == -1:
return None
return self.get(symbol=self.symbols[symbol_idx])
OHLCDataMixinT = tp.TypeVar("OHLCDataMixinT", bound="OHLCDataMixin")
class OHLCDataMixin(BaseDataMixin):
"""Mixin class for working with OHLC data."""
@property
def open(self) -> tp.Optional[tp.SeriesFrame]:
"""Open."""
return self.get_feature("Open")
@property
def high(self) -> tp.Optional[tp.SeriesFrame]:
"""High."""
return self.get_feature("High")
@property
def low(self) -> tp.Optional[tp.SeriesFrame]:
"""Low."""
return self.get_feature("Low")
@property
def close(self) -> tp.Optional[tp.SeriesFrame]:
"""Close."""
return self.get_feature("Close")
@property
def volume(self) -> tp.Optional[tp.SeriesFrame]:
"""Volume."""
return self.get_feature("Volume")
@property
def trade_count(self) -> tp.Optional[tp.SeriesFrame]:
"""Trade count."""
return self.get_feature("Trade count")
@property
def vwap(self) -> tp.Optional[tp.SeriesFrame]:
"""VWAP."""
return self.get_feature("VWAP")
@property
def hlc3(self) -> tp.SeriesFrame:
"""HLC/3."""
high = self.get_feature("High", raise_error=True)
low = self.get_feature("Low", raise_error=True)
close = self.get_feature("Close", raise_error=True)
return (high + low + close) / 3
@property
def ohlc4(self) -> tp.SeriesFrame:
"""OHLC/4."""
open = self.get_feature("Open", raise_error=True)
high = self.get_feature("High", raise_error=True)
low = self.get_feature("Low", raise_error=True)
close = self.get_feature("Close", raise_error=True)
return (open + high + low + close) / 4
@property
def has_any_ohlc(self) -> bool:
"""Whether the instance has any of the OHLC features."""
return (
self.has_feature("Open") or self.has_feature("High") or self.has_feature("Low") or self.has_feature("Close")
)
@property
def has_ohlc(self) -> bool:
"""Whether the instance has all the OHLC features."""
return (
self.has_feature("Open")
and self.has_feature("High")
and self.has_feature("Low")
and self.has_feature("Close")
)
@property
def has_any_ohlcv(self) -> bool:
"""Whether the instance has any of the OHLCV features."""
return self.has_any_ohlc or self.has_feature("Volume")
@property
def has_ohlcv(self) -> bool:
"""Whether the instance has all the OHLCV features."""
return self.has_ohlc and self.has_feature("Volume")
@property
def ohlc(self: OHLCDataMixinT) -> OHLCDataMixinT:
"""Return a `OHLCDataMixin` instance with the OHLC features only."""
open_idx = self.get_feature_idx("Open", raise_error=True)
high_idx = self.get_feature_idx("High", raise_error=True)
low_idx = self.get_feature_idx("Low", raise_error=True)
close_idx = self.get_feature_idx("Close", raise_error=True)
return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx])
@property
def ohlcv(self: OHLCDataMixinT) -> OHLCDataMixinT:
"""Return a `OHLCDataMixin` instance with the OHLCV features only."""
open_idx = self.get_feature_idx("Open", raise_error=True)
high_idx = self.get_feature_idx("High", raise_error=True)
low_idx = self.get_feature_idx("Low", raise_error=True)
close_idx = self.get_feature_idx("Close", raise_error=True)
volume_idx = self.get_feature_idx("Volume", raise_error=True)
return self.select_feature_idxs([open_idx, high_idx, low_idx, close_idx, volume_idx])
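# Illustrative sketch (not part of the original module): OHLC convenience accessors.
# Assuming a `data` instance that holds the "Open", "High", "Low", "Close", and "Volume" features:
#   data.hlc3    # (High + Low + Close) / 3
#   data.ohlc4   # (Open + High + Low + Close) / 4
#   data.ohlc    # new instance restricted to the four OHLC features
#   data.ohlcv   # new instance restricted to the OHLC features plus Volume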
def get_returns_acc(self, **kwargs) -> ReturnsAccessor:
"""Return accessor of type `vectorbtpro.returns.accessors.ReturnsAccessor`."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
**kwargs,
)
@property
def returns_acc(self) -> ReturnsAccessor:
"""`OHLCDataMixin.get_returns_acc` with default arguments."""
return self.get_returns_acc()
def get_returns(self, **kwargs) -> tp.SeriesFrame:
"""Returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=True,
**kwargs,
)
@property
def returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_returns` with default arguments."""
return self.get_returns()
def get_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Log returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=True,
log_returns=True,
**kwargs,
)
@property
def log_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_log_returns` with default arguments."""
return self.get_log_returns()
def get_daily_returns(self, **kwargs) -> tp.SeriesFrame:
"""Daily returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
**kwargs,
).daily()
@property
def daily_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_daily_returns` with default arguments."""
return self.get_daily_returns()
def get_daily_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Daily log returns."""
return ReturnsAccessor.from_value(
self.get_feature("Close", raise_error=True),
wrapper=self.symbol_wrapper,
return_values=False,
log_returns=True,
**kwargs,
).daily()
@property
def daily_log_returns(self) -> tp.SeriesFrame:
"""`OHLCDataMixin.get_daily_log_returns` with default arguments."""
return self.get_daily_log_returns()
def get_drawdowns(self, **kwargs) -> Drawdowns:
"""Generate drawdown records.
See `vectorbtpro.generic.drawdowns.Drawdowns`."""
return Drawdowns.from_price(
open=self.get_feature("Open", raise_error=True),
high=self.get_feature("High", raise_error=True),
low=self.get_feature("Low", raise_error=True),
close=self.get_feature("Close", raise_error=True),
**kwargs,
)
@property
def drawdowns(self) -> Drawdowns:
"""`OHLCDataMixin.get_drawdowns` with default arguments."""
return self.get_drawdowns()
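# Illustrative sketch (not part of the original module): return and drawdown shortcuts,
# all derived from the "Close" feature (required; an error is raised if it's missing):
#   data.returns           # simple returns via ReturnsAccessor.from_value
#   data.log_returns       # log returns
#   data.daily_returns     # returns resampled to a daily frequency
#   data.drawdowns         # Drawdowns records built from the OHLC features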
DataT = tp.TypeVar("DataT", bound="Data")
class MetaData(type(Analyzable)):
"""Metaclass for `Data`."""
@property
def feature_config(cls) -> Config:
"""Feature config."""
return cls._feature_config
@attach_symbol_dict_methods
class Data(Analyzable, OHLCDataMixin, metaclass=MetaData):
"""Class that downloads, updates, and manages data coming from a data source."""
_settings_path: tp.SettingsPath = dict(base="data")
_writeable_attrs: tp.WriteableAttrs = {"_feature_config"}
_feature_config: tp.ClassVar[Config] = HybridConfig()
_key_dict_attrs = [
"fetch_kwargs",
"returned_kwargs",
"last_index",
"delisted",
"classes",
]
"""Attributes that subclass either `feature_dict` or `symbol_dict`."""
_data_dict_type_attrs = [
"classes",
]
"""Attributes that subclass the data dict type."""
_updatable_attrs = [
"fetch_kwargs",
"returned_kwargs",
"classes",
]
"""Attributes that have a method for updating."""
@property
def feature_config(self) -> Config:
"""Column config of `${cls_name}`.
```python
${feature_config}
```
Returns `${cls_name}._feature_config`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change fields, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._feature_config`.
"""
return self._feature_config
def use_feature_config_of(self, cls: tp.Type[DataT]) -> None:
"""Copy feature config from another `Data` class."""
self._feature_config = cls.feature_config.copy()
@classmethod
def modify_state(cls, rec_state: RecState) -> RecState:
# Ensure backward compatibility
if "_column_config" in rec_state.attr_dct and "_feature_config" not in rec_state.attr_dct:
new_attr_dct = dict(rec_state.attr_dct)
new_attr_dct["_feature_config"] = new_attr_dct.pop("_column_config")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=rec_state.init_kwargs,
attr_dct=new_attr_dct,
)
if "single_symbol" in rec_state.init_kwargs and "single_key" not in rec_state.init_kwargs:
new_init_kwargs = dict(rec_state.init_kwargs)
new_init_kwargs["single_key"] = new_init_kwargs.pop("single_symbol")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=new_init_kwargs,
attr_dct=rec_state.attr_dct,
)
if "symbol_classes" in rec_state.init_kwargs and "classes" not in rec_state.init_kwargs:
new_init_kwargs = dict(rec_state.init_kwargs)
new_init_kwargs["classes"] = new_init_kwargs.pop("symbol_classes")
rec_state = RecState(
init_args=rec_state.init_args,
init_kwargs=new_init_kwargs,
attr_dct=rec_state.attr_dct,
)
return rec_state
@classmethod
def fix_data_dict_type(cls, data: dict) -> tp.Union[feature_dict, symbol_dict]:
"""Fix dict type for data."""
checks.assert_instance_of(data, dict, arg_name="data")
if not isinstance(data, key_dict):
data = symbol_dict(data)
return data
@classmethod
def fix_dict_types_in_kwargs(
cls,
data_type: tp.Type[tp.Union[feature_dict, symbol_dict]],
**kwargs: tp.Kwargs,
) -> tp.Kwargs:
"""Fix dict types in keyword arguments."""
for attr in cls._key_dict_attrs:
if attr in kwargs:
attr_value = kwargs[attr]
if attr_value is None:
attr_value = {}
checks.assert_instance_of(attr_value, dict, arg_name=attr)
if not isinstance(attr_value, key_dict):
attr_value = data_type(attr_value)
if attr in cls._data_dict_type_attrs:
checks.assert_instance_of(attr_value, data_type, arg_name=attr)
kwargs[attr] = attr_value
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[DataT],
*objs: tp.MaybeTuple[DataT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Stack multiple `Data` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Data):
raise TypeError("Each object to be merged must be an instance of Data")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
keys = set()
for obj in objs:
keys = keys.union(set(obj.data.keys()))
data_type = None
for obj in objs:
if len(keys.difference(set(obj.data.keys()))) > 0:
if isinstance(obj.data, feature_dict):
raise ValueError("Objects to be merged must have the same features")
else:
raise ValueError("Objects to be merged must have the same symbols")
if data_type is None:
data_type = type(obj.data)
elif not isinstance(obj.data, data_type):
raise TypeError("Objects to be merged must have the same dict type for data")
if "data" not in kwargs:
new_data = data_type()
for k in objs[0].data.keys():
new_data[k] = kwargs["wrapper"].row_stack_arrs(*[obj.data[k] for obj in objs], group_by=False)
kwargs["data"] = new_data
kwargs["data"] = cls.fix_data_dict_type(kwargs["data"])
for attr in cls._key_dict_attrs:
if attr not in kwargs:
attr_data_type = None
for obj in objs:
v = getattr(obj, attr)
if attr_data_type is None:
attr_data_type = type(v)
elif not isinstance(v, attr_data_type):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
kwargs[attr] = getattr(objs[-1], attr)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs)
return cls(**kwargs)
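# Illustrative sketch (not part of the original module): stacking two instances along rows.
# Assuming `data_2020` and `data_2021` (hypothetical) share the same keys and dict type and
# cover consecutive date ranges, either the class-method or the instance-method form works:
#   stacked = Data.row_stack(data_2020, data_2021)
#   stacked = data_2020.row_stack(data_2021)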
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[DataT],
*objs: tp.MaybeTuple[DataT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Stack multiple `Data` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Data):
raise TypeError("Each object to be merged must be an instance of Data")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
keys = set()
for obj in objs:
keys = keys.union(set(obj.data.keys()))
data_type = None
for obj in objs:
if len(keys.difference(set(obj.data.keys()))) > 0:
if isinstance(obj.data, feature_dict):
raise ValueError("Objects to be merged must have the same features")
else:
raise ValueError("Objects to be merged must have the same symbols")
if data_type is None:
data_type = type(obj.data)
elif not isinstance(obj.data, data_type):
raise TypeError("Objects to be merged must have the same dict type for data")
if "data" not in kwargs:
new_data = data_type()
for k in objs[0].data.keys():
new_data[k] = kwargs["wrapper"].column_stack_arrs(*[obj.data[k] for obj in objs], group_by=False)
kwargs["data"] = new_data
kwargs["data"] = cls.fix_data_dict_type(kwargs["data"])
for attr in cls._key_dict_attrs:
if attr not in kwargs:
attr_data_type = None
for obj in objs:
v = getattr(obj, attr)
if attr_data_type is None:
attr_data_type = type(v)
elif not isinstance(v, attr_data_type):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
if (issubclass(data_type, feature_dict) and issubclass(attr_data_type, symbol_dict)) or (
issubclass(data_type, symbol_dict) and issubclass(attr_data_type, feature_dict)
):
kwargs[attr] = attr_data_type()
for obj in objs:
kwargs[attr].update(**getattr(obj, attr))
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
kwargs = cls.fix_dict_types_in_kwargs(type(kwargs["data"]), **kwargs)
return cls(**kwargs)
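# Illustrative sketch (not part of the original module): stacking two instances along columns.
# Assuming `data_a` and `data_b` (hypothetical) share the same keys but hold different columns:
#   combined = Data.column_stack(data_a, data_b)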
def __init__(
self,
wrapper: ArrayWrapper,
data: tp.Union[feature_dict, symbol_dict],
single_key: bool = True,
classes: tp.Union[None, feature_dict, symbol_dict] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
fetch_kwargs: tp.Union[None, feature_dict, symbol_dict] = None,
returned_kwargs: tp.Union[None, feature_dict, symbol_dict] = None,
last_index: tp.Union[None, feature_dict, symbol_dict] = None,
delisted: tp.Union[None, feature_dict, symbol_dict] = None,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
missing_index: tp.Optional[str] = None,
missing_columns: tp.Optional[str] = None,
**kwargs,
) -> None:
Analyzable.__init__(
self,
wrapper,
data=data,
single_key=single_key,
classes=classes,
level_name=level_name,
fetch_kwargs=fetch_kwargs,
returned_kwargs=returned_kwargs,
last_index=last_index,
delisted=delisted,
tz_localize=tz_localize,
tz_convert=tz_convert,
missing_index=missing_index,
missing_columns=missing_columns,
**kwargs,
)
if len(set(map(self.prepare_key, data.keys()))) < len(list(map(self.prepare_key, data.keys()))):
raise ValueError("Found duplicate keys in data dictionary")
data = self.fix_data_dict_type(data)
for obj in data.values():
checks.assert_meta_equal(obj, data[list(data.keys())[0]])
if len(data) > 1:
single_key = False
self._data = data
self._single_key = single_key
self._level_name = level_name
self._tz_localize = tz_localize
self._tz_convert = tz_convert
self._missing_index = missing_index
self._missing_columns = missing_columns
attr_kwargs = dict()
for attr in self._key_dict_attrs:
attr_value = locals()[attr]
attr_kwargs[attr] = attr_value
attr_kwargs = self.fix_dict_types_in_kwargs(type(data), **attr_kwargs)
for k, v in attr_kwargs.items():
setattr(self, "_" + k, v)
# Copy writeable attrs
self._feature_config = type(self)._feature_config.copy()
def replace(self: DataT, **kwargs) -> DataT:
"""See `vectorbtpro.utils.config.Configured.replace`.
Replaces the data's index and/or columns if they were changed in the wrapper."""
if "wrapper" in kwargs and "data" not in kwargs:
wrapper = kwargs["wrapper"]
if isinstance(wrapper, dict):
new_index = wrapper.get("index", self.wrapper.index)
new_columns = wrapper.get("columns", self.wrapper.columns)
else:
new_index = wrapper.index
new_columns = wrapper.columns
data = self.config["data"]
new_data = {}
index_changed = False
columns_changed = False
for k, v in data.items():
if isinstance(v, (pd.Series, pd.DataFrame)):
if not checks.is_index_equal(v.index, new_index):
v = v.copy(deep=False)
v.index = new_index
index_changed = True
if isinstance(v, pd.DataFrame):
if not checks.is_index_equal(v.columns, new_columns):
v = v.copy(deep=False)
v.columns = new_columns
columns_changed = True
new_data[k] = v
if index_changed or columns_changed:
kwargs["data"] = self.fix_data_dict_type(new_data)
if columns_changed:
rename = dict(zip(self.keys, new_columns))
for attr in self._key_dict_attrs:
if attr not in kwargs:
attr_value = getattr(self, attr)
if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or (
self.symbol_oriented and isinstance(attr_value, feature_dict)
):
kwargs[attr] = self.rename_in_dict(getattr(self, attr), rename)
kwargs = self.fix_dict_types_in_kwargs(type(kwargs.get("data", self.data)), **kwargs)
return Analyzable.replace(self, **kwargs)
def indexing_func(self: DataT, *args, replace_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Perform indexing on `Data`."""
if replace_kwargs is None:
replace_kwargs = {}
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
new_wrapper = wrapper_meta["new_wrapper"]
new_data = self.dict_type()
for k, v in self._data.items():
if wrapper_meta["rows_changed"]:
v = v.iloc[wrapper_meta["row_idxs"]]
if wrapper_meta["columns_changed"]:
v = v.iloc[:, wrapper_meta["col_idxs"]]
new_data[k] = v
attr_dicts = dict()
attr_dicts["last_index"] = type(self.last_index)()
for k in self.last_index:
attr_dicts["last_index"][k] = min([self.last_index[k], new_wrapper.index[-1]])
if wrapper_meta["columns_changed"]:
new_symbols = new_wrapper.columns
for attr in self._key_dict_attrs:
attr_value = getattr(self, attr)
if (self.feature_oriented and isinstance(attr_value, symbol_dict)) or (
self.symbol_oriented and isinstance(attr_value, feature_dict)
):
if attr in attr_dicts:
attr_dicts[attr] = self.select_from_dict(attr_dicts[attr], new_symbols)
else:
attr_dicts[attr] = self.select_from_dict(attr_value, new_symbols)
return self.replace(wrapper=new_wrapper, data=new_data, **attr_dicts, **replace_kwargs)
@property
def data(self) -> tp.Union[feature_dict, symbol_dict]:
"""Data dictionary.
Has the type `feature_dict` for feature-oriented data or `symbol_dict` for symbol-oriented data."""
return self._data
@property
def dict_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]:
"""Return the dict type."""
return type(self.data)
@property
def column_type(self) -> tp.Type[tp.Union[feature_dict, symbol_dict]]:
"""Return the column type."""
if isinstance(self.data, feature_dict):
return symbol_dict
return feature_dict
@property
def feature_oriented(self) -> bool:
"""Whether data has features as keys."""
return issubclass(self.dict_type, feature_dict)
@property
def symbol_oriented(self) -> bool:
"""Whether data has symbols as keys."""
return issubclass(self.dict_type, symbol_dict)
def get_keys(self, dict_type: tp.Type[tp.Union[feature_dict, symbol_dict]]) -> tp.List[tp.Key]:
"""Get keys depending on the provided dict type."""
checks.assert_subclass_of(dict_type, (feature_dict, symbol_dict), arg_name="dict_type")
if issubclass(dict_type, feature_dict):
return self.features
return self.symbols
@property
def keys(self) -> tp.List[tp.Union[tp.Feature, tp.Symbol]]:
"""Keys in data.
Features if `feature_dict` and symbols if `symbol_dict`."""
return list(self.data.keys())
@property
def single_key(self) -> bool:
"""Whether there is only one key in `Data.data`."""
return self._single_key
@property
def single_feature(self) -> bool:
"""Whether there is only one feature in `Data.data`."""
if self.feature_oriented:
return self.single_key
return self.wrapper.ndim == 1
@property
def single_symbol(self) -> bool:
"""Whether there is only one symbol in `Data.data`."""
if self.symbol_oriented:
return self.single_key
return self.wrapper.ndim == 1
@property
def classes(self) -> tp.Union[feature_dict, symbol_dict]:
"""Key classes."""
return self._classes
@property
def feature_classes(self) -> tp.Optional[feature_dict]:
"""Feature classes."""
if self.feature_oriented:
return self.classes
return None
@property
def symbol_classes(self) -> tp.Optional[symbol_dict]:
"""Symbol classes."""
if self.symbol_oriented:
return self.classes
return None
@hybrid_method
def get_level_name(
cls_or_self,
keys: tp.Optional[tp.Keys] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
feature_oriented: tp.Optional[bool] = None,
) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]:
"""Get level name(s) for keys."""
if isinstance(cls_or_self, type):
checks.assert_not_none(keys, arg_name="keys")
checks.assert_not_none(feature_oriented, arg_name="feature_oriented")
else:
if keys is None:
keys = cls_or_self.keys
if level_name is None:
level_name = cls_or_self._level_name
if feature_oriented is None:
feature_oriented = cls_or_self.feature_oriented
first_key = keys[0]
if isinstance(level_name, bool):
if level_name:
level_name = None
else:
return None
if feature_oriented:
key_prefix = "feature"
else:
key_prefix = "symbol"
if isinstance(first_key, tuple):
if level_name is None:
level_name = ["%s_%d" % (key_prefix, i) for i in range(len(first_key))]
if not checks.is_iterable(level_name) or isinstance(level_name, str):
raise TypeError("Level name should be list-like for a MultiIndex")
return tuple(level_name)
if level_name is None:
level_name = key_prefix
return level_name
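# Illustrative sketch (not part of the original module): how level names are derived.
# For plain keys, the level name defaults to "feature" or "symbol" depending on orientation;
# for tuple keys, one name per tuple element is generated, e.g. ("symbol_0", "symbol_1").
# Passing level_name=False disables level names entirely.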
@property
def level_name(self) -> tp.Optional[tp.MaybeIterable[tp.Hashable]]:
"""Level name(s) for keys.
Keys are symbols or features depending on the data dict type.
Must be a sequence if keys are tuples, otherwise a hashable.
If False, no level names will be used."""
return self.get_level_name()
@hybrid_method
def get_key_index(
cls_or_self,
keys: tp.Optional[tp.Keys] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
feature_oriented: tp.Optional[bool] = None,
) -> tp.Index:
"""Get key index."""
if isinstance(cls_or_self, type):
checks.assert_not_none(keys, arg_name="keys")
else:
if keys is None:
keys = cls_or_self.keys
level_name = cls_or_self.get_level_name(keys=keys, level_name=level_name, feature_oriented=feature_oriented)
if isinstance(level_name, tuple):
return pd.MultiIndex.from_tuples(keys, names=level_name)
return pd.Index(keys, name=level_name)
@property
def key_index(self) -> tp.Index:
"""Key index."""
return self.get_key_index()
@property
def fetch_kwargs(self) -> tp.Union[feature_dict, symbol_dict]:
"""Keyword arguments of type `symbol_dict` initially passed to `Data.fetch_symbol`."""
return self._fetch_kwargs
@property
def returned_kwargs(self) -> tp.Union[feature_dict, symbol_dict]:
"""Keyword arguments of type `symbol_dict` returned by `Data.fetch_symbol`."""
return self._returned_kwargs
@property
def last_index(self) -> tp.Union[feature_dict, symbol_dict]:
"""Last fetched index per symbol of type `symbol_dict`."""
return self._last_index
@property
def delisted(self) -> tp.Union[feature_dict, symbol_dict]:
"""Delisted flag per symbol of type `symbol_dict`."""
return self._delisted
@property
def tz_localize(self) -> tp.Union[None, bool, tp.TimezoneLike]:
"""Timezone to localize a datetime-naive index to, which is initially passed to `Data.pull`."""
return self._tz_localize
@property
def tz_convert(self) -> tp.Union[None, bool, tp.TimezoneLike]:
"""Timezone to convert a datetime-aware to, which is initially passed to `Data.pull`."""
return self._tz_convert
@property
def missing_index(self) -> tp.Optional[str]:
"""Argument `missing` passed to `Data.align_index`."""
return self._missing_index
@property
def missing_columns(self) -> tp.Optional[str]:
"""Argument `missing` passed to `Data.align_columns`."""
return self._missing_columns
# ############# Settings ############# #
@classmethod
def get_base_settings(cls, *args, **kwargs) -> dict:
"""`CustomData.get_settings` with `path_id="base"`."""
return cls.get_settings(*args, path_id="base", **kwargs)
@classmethod
def has_base_settings(cls, *args, **kwargs) -> bool:
"""`CustomData.has_settings` with `path_id="base"`."""
return cls.has_settings(*args, path_id="base", **kwargs)
@classmethod
def get_base_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.get_setting` with `path_id="base"`."""
return cls.get_setting(*args, path_id="base", **kwargs)
@classmethod
def has_base_setting(cls, *args, **kwargs) -> bool:
"""`CustomData.has_setting` with `path_id="base"`."""
return cls.has_setting(*args, path_id="base", **kwargs)
@classmethod
def resolve_base_setting(cls, *args, **kwargs) -> tp.Any:
"""`CustomData.resolve_setting` with `path_id="base"`."""
return cls.resolve_setting(*args, path_id="base", **kwargs)
@classmethod
def set_base_settings(cls, *args, **kwargs) -> None:
"""`CustomData.set_settings` with `path_id="base"`."""
cls.set_settings(*args, path_id="base", **kwargs)
# ############# Iteration ############# #
def items(
self,
over: str = "symbols",
group_by: tp.GroupByLike = None,
apply_group_by: bool = False,
keep_2d: bool = False,
key_as_index: bool = False,
) -> tp.Items:
"""Iterate over columns (or groups if grouped and `Wrapping.group_select` is True), keys,
features, or symbols. The respective mode can be selected with `over`.
See `vectorbtpro.base.wrapping.Wrapping.items` for iteration over columns.
Iteration over keys supports `group_by` but doesn't support `apply_group_by`."""
if (
over.lower() == "columns"
or (over.lower() == "symbols" and self.feature_oriented)
or (over.lower() == "features" and self.symbol_oriented)
):
for k, v in Analyzable.items(
self,
group_by=group_by,
apply_group_by=apply_group_by,
keep_2d=keep_2d,
key_as_index=key_as_index,
):
yield k, v
elif (
over.lower() == "keys"
or (over.lower() == "features" and self.feature_oriented)
or (over.lower() == "symbols" and self.symbol_oriented)
):
if apply_group_by:
raise ValueError("Cannot apply grouping to keys")
if group_by is not None:
key_wrapper = self.get_key_wrapper(group_by=group_by)
if key_wrapper.get_ndim() == 1:
if key_as_index:
yield key_wrapper.get_columns(), self
else:
yield key_wrapper.get_columns()[0], self
else:
for group, group_idxs in key_wrapper.grouper.iter_groups(key_as_index=key_as_index):
if keep_2d or len(group_idxs) > 1:
yield group, self.select_keys([self.keys[i] for i in group_idxs])
else:
yield group, self.select_keys(self.keys[group_idxs[0]])
else:
key_wrapper = self.get_key_wrapper(attach_classes=False)
if key_wrapper.ndim == 1:
if key_as_index:
yield key_wrapper.columns, self
else:
yield key_wrapper.columns[0], self
else:
for i in range(len(key_wrapper.columns)):
if key_as_index:
key = key_wrapper.columns[[i]]
else:
key = key_wrapper.columns[i]
if keep_2d:
yield key, self.select_keys([key])
else:
yield key, self.select_keys(key)
else:
raise ValueError(f"Invalid over: '{over}'")
# ############# Getting ############# #
def get_key_wrapper(
self,
keys: tp.Optional[tp.MaybeKeys] = None,
attach_classes: bool = True,
clean_index_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> ArrayWrapper:
"""Get wrapper with keys as columns.
If `attach_classes` is True, attaches `Data.classes` by stacking them over
the keys using `vectorbtpro.base.indexes.stack_indexes`.
Other keyword arguments are passed to the constructor of the wrapper."""
if clean_index_kwargs is None:
clean_index_kwargs = {}
if keys is None:
keys = self.keys
ndim = 1 if self.single_key else 2
else:
if self.has_multiple_keys(keys):
ndim = 2
else:
keys = [keys]
ndim = 1
for key in keys:
if self.feature_oriented:
self.assert_has_feature(key)
else:
self.assert_has_symbol(key)
new_columns = self.get_key_index(keys=keys)
wrapper = self.wrapper.replace(
columns=new_columns,
ndim=ndim,
grouper=None,
**kwargs,
)
if attach_classes:
classes = []
all_have_classes = True
for key in wrapper.columns:
if key in self.classes:
key_classes = self.classes[key]
if len(key_classes) > 0:
classes.append(key_classes)
else:
all_have_classes = False
else:
all_have_classes = False
if len(classes) > 0 and not all_have_classes:
if self.feature_oriented:
raise ValueError("Some features have classes while others not")
else:
raise ValueError("Some symbols have classes while others not")
if len(classes) > 0:
classes_frame = pd.DataFrame(classes)
if len(classes_frame.columns) == 1:
classes_columns = pd.Index(classes_frame.iloc[:, 0])
else:
classes_columns = pd.MultiIndex.from_frame(classes_frame)
new_columns = stack_indexes((classes_columns, wrapper.columns), **clean_index_kwargs)
wrapper = wrapper.replace(columns=new_columns)
if group_by is not None:
wrapper = wrapper.replace(group_by=group_by)
return wrapper
@cached_property
def key_wrapper(self) -> ArrayWrapper:
"""Key wrapper."""
return self.get_key_wrapper()
def get_feature_wrapper(self, features: tp.Optional[tp.MaybeFeatures] = None, **kwargs) -> ArrayWrapper:
"""Get wrapper with features as columns."""
if self.feature_oriented:
return self.get_key_wrapper(keys=features, **kwargs)
wrapper = self.wrapper
if features is not None:
wrapper = wrapper[features]
return wrapper
@cached_property
def feature_wrapper(self) -> ArrayWrapper:
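"""Feature wrapper."""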
return self.get_feature_wrapper()
def get_symbol_wrapper(self, symbols: tp.Optional[tp.MaybeSymbols] = None, **kwargs) -> ArrayWrapper:
"""Get wrapper with symbols as columns."""
if self.symbol_oriented:
return self.get_key_wrapper(keys=symbols, **kwargs)
wrapper = self.wrapper
if symbols is not None:
wrapper = wrapper[symbols]
return wrapper
@cached_property
def symbol_wrapper(self) -> ArrayWrapper:
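"""Symbol wrapper."""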
return self.get_symbol_wrapper()
@property
def ndim(self) -> int:
"""Number of dimensions.
Based on the default symbol wrapper."""
return self.symbol_wrapper.ndim
@property
def shape(self) -> tp.Shape:
"""Shape.
Based on the default symbol wrapper."""
return self.symbol_wrapper.shape
@property
def shape_2d(self) -> tp.Shape:
"""Shape as if the object was two-dimensional.
Based on the default symbol wrapper."""
return self.symbol_wrapper.shape_2d
@property
def columns(self) -> tp.Index:
"""Columns.
Based on the default symbol wrapper."""
return self.symbol_wrapper.columns
@property
def index(self) -> tp.Index:
"""Index.
Based on the default symbol wrapper."""
return self.symbol_wrapper.index
@property
def freq(self) -> tp.Optional[tp.PandasFrequency]:
"""Frequency.
Based on the default symbol wrapper."""
return self.symbol_wrapper.freq
@property
def features(self) -> tp.List[tp.Feature]:
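"""Features in data."""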
if self.feature_oriented:
return self.keys
return self.wrapper.columns.tolist()
@property
def symbols(self) -> tp.List[tp.Symbol]:
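"""Symbols in data."""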
if self.feature_oriented:
return self.wrapper.columns.tolist()
return self.keys
def resolve_feature(self, feature: tp.Feature, raise_error: bool = False) -> tp.Optional[tp.Feature]:
"""Return the feature of this instance that matches the provided feature."""
feature_idx = self.get_feature_idx(feature, raise_error=raise_error)
if feature_idx == -1:
return None
return self.features[feature_idx]
def resolve_symbol(self, symbol: tp.Symbol, raise_error: bool = False) -> tp.Optional[tp.Symbol]:
"""Return the symbol of this instance that matches the provided symbol."""
symbol_idx = self.get_symbol_idx(symbol, raise_error=raise_error)
if symbol_idx == -1:
return None
return self.symbols[symbol_idx]
def resolve_key(self, key: tp.Key, raise_error: bool = False) -> tp.Optional[tp.Key]:
"""Return the key of this instance that matches the provided key."""
if self.feature_oriented:
return self.resolve_feature(key, raise_error=raise_error)
return self.resolve_symbol(key, raise_error=raise_error)
def resolve_column(self, column: tp.Column, raise_error: bool = False) -> tp.Optional[tp.Column]:
"""Return the column of this instance that matches the provided column."""
if self.feature_oriented:
return self.resolve_symbol(column, raise_error=raise_error)
return self.resolve_feature(column, raise_error=raise_error)
def resolve_features(self, features: tp.MaybeFeatures, raise_error: bool = True) -> tp.MaybeFeatures:
"""Return the features of this instance that match the provided features."""
if not self.has_multiple_keys(features):
features = [features]
single_feature = True
else:
single_feature = False
new_features = []
for feature in features:
new_features.append(self.resolve_feature(feature, raise_error=raise_error))
if single_feature:
return new_features[0]
return new_features
def resolve_symbols(self, symbols: tp.MaybeSymbols, raise_error: bool = True) -> tp.MaybeSymbols:
"""Return the symbols of this instance that match the provided symbols."""
if not self.has_multiple_keys(symbols):
symbols = [symbols]
single_symbol = True
else:
single_symbol = False
new_symbols = []
for symbol in symbols:
new_symbols.append(self.resolve_symbol(symbol, raise_error=raise_error))
if single_symbol:
return new_symbols[0]
return new_symbols
def resolve_keys(self, keys: tp.MaybeKeys, raise_error: bool = True) -> tp.MaybeKeys:
"""Return the keys of this instance that match the provided keys."""
if self.feature_oriented:
return self.resolve_features(keys, raise_error=raise_error)
return self.resolve_symbols(keys, raise_error=raise_error)
def resolve_columns(self, columns: tp.MaybeColumns, raise_error: bool = True) -> tp.MaybeColumns:
"""Return the columns of this instance that match the provided columns."""
if self.feature_oriented:
return self.resolve_symbols(columns, raise_error=raise_error)
return self.resolve_features(columns, raise_error=raise_error)
def concat(
self,
keys: tp.Optional[tp.Symbols] = None,
attach_classes: bool = True,
clean_index_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Union[feature_dict, symbol_dict]:
"""Concatenate keys along columns."""
key_wrapper = self.get_key_wrapper(
keys=keys,
attach_classes=attach_classes,
clean_index_kwargs=clean_index_kwargs,
**kwargs,
)
if keys is None:
keys = self.keys
new_data = self.column_type()
first_data = self.data[keys[0]]
if key_wrapper.ndim == 1:
if isinstance(first_data, pd.Series):
new_data[first_data.name] = key_wrapper.wrap(first_data.values, zero_to_none=False)
else:
for c in first_data.columns:
new_data[c] = key_wrapper.wrap(first_data[c].values, zero_to_none=False)
else:
if isinstance(first_data, pd.Series):
columns = pd.Index([first_data.name])
else:
columns = first_data.columns
for c in columns:
col_data = []
for k in keys:
if isinstance(self.data[k], pd.Series):
col_data.append(self.data[k].values)
else:
col_data.append(self.data[k][c].values)
new_data[c] = key_wrapper.wrap(column_stack_arrays(col_data), zero_to_none=False)
return new_data
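# Illustrative sketch (not part of the original module): concatenating keys along columns.
# For a symbol-oriented instance with symbols ["S1", "S2"] (hypothetical) and features
# ["Open", "Close"], concat() returns a feature_dict of the form:
#   {"Open": <DataFrame with columns S1, S2>, "Close": <DataFrame with columns S1, S2>}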
def get(
self,
features: tp.Optional[tp.MaybeFeatures] = None,
symbols: tp.Optional[tp.MaybeSymbols] = None,
feature: tp.Optional[tp.Feature] = None,
symbol: tp.Optional[tp.Symbol] = None,
squeeze_features: bool = False,
squeeze_symbols: bool = False,
per: str = "feature",
as_dict: bool = False,
**kwargs,
) -> tp.Union[tp.MaybeTuple[tp.SeriesFrame], dict]:
"""Get one or more features of one or more symbols of data."""
if features is not None and feature is not None:
raise ValueError("Must provide either features or feature, not both")
if symbols is not None and symbol is not None:
raise ValueError("Must provide either symbols or symbol, not both")
if feature is not None:
features = feature
single_feature = True
else:
if features is None:
features = self.features
single_feature = self.single_feature
if single_feature:
features = features[0]
else:
single_feature = not self.has_multiple_keys(features)
if not single_feature and squeeze_features and len(features) == 1:
features = features[0]
single_feature = True
if symbol is not None:
symbols = symbol
single_symbol = True
else:
if symbols is None:
symbols = self.symbols
single_symbol = self.single_symbol
if single_symbol:
symbols = symbols[0]
else:
single_symbol = not self.has_multiple_keys(symbols)
if not single_symbol and squeeze_symbols and len(symbols) == 1:
symbols = symbols[0]
single_symbol = True
if not single_feature:
feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
features = [self.features[i] for i in feature_idxs]
else:
feature_idxs = self.get_feature_idx(features, raise_error=True)
features = self.features[feature_idxs]
if not single_symbol:
symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols]
symbols = [self.symbols[i] for i in symbol_idxs]
else:
symbol_idxs = self.get_symbol_idx(symbols, raise_error=True)
symbols = self.symbols[symbol_idxs]
def _get_objs():
if self.feature_oriented:
if single_feature:
if self.single_symbol:
return list(self.data.values())[feature_idxs], features
return list(self.data.values())[feature_idxs].iloc[:, symbol_idxs], features
if single_symbol:
concat_data = self.concat(keys=features, **kwargs)
return list(concat_data.values())[symbol_idxs], symbols
if per.lower() in ("symbol", "column"):
concat_data = self.concat(keys=features, **kwargs)
return tuple([list(concat_data.values())[i] for i in symbol_idxs]), symbols
if per.lower() in ("feature", "key"):
if self.single_feature:
if self.single_symbol:
return list(self.data.values())[feature_idxs], features
return list(self.data.values())[feature_idxs].iloc[:, symbol_idxs], features
if self.single_symbol:
return tuple([list(self.data.values())[i] for i in feature_idxs]), features
return tuple([list(self.data.values())[i].iloc[:, symbol_idxs] for i in feature_idxs]), features
raise ValueError(f"Invalid per: '{per}'")
else:
if single_symbol:
if self.single_feature:
return self.data[self.symbols[symbol_idxs]], symbols
return self.data[self.symbols[symbol_idxs]].iloc[:, feature_idxs], symbols
if single_feature:
concat_data = self.concat(keys=symbols, **kwargs)
return list(concat_data.values())[feature_idxs], features
if per.lower() in ("feature", "column"):
concat_data = self.concat(keys=symbols, **kwargs)
return tuple([list(concat_data.values())[i] for i in feature_idxs]), features
if per.lower() in ("symbol", "key"):
if self.single_symbol:
if self.single_feature:
return list(self.data.values())[symbol_idxs], symbols
return list(self.data.values())[symbol_idxs].iloc[:, feature_idxs], symbols
if self.single_feature:
return tuple([list(self.data.values())[i] for i in symbol_idxs]), symbols
return tuple([list(self.data.values())[i].iloc[:, feature_idxs] for i in symbol_idxs]), symbols
raise ValueError(f"Invalid per: '{per}'")
objs, keys = _get_objs()
if as_dict:
if isinstance(objs, tuple):
return dict(zip(keys, objs))
return {keys: objs}
return objs
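# Illustrative sketch (not part of the original module): common get() calls.
# Assuming a `data` instance with several features and symbols (hypothetical labels):
#   data.get(feature="Close")                            # one feature across all symbols
#   data.get(features=["Open", "Close"])                 # tuple of objects, one per feature
#   data.get(features=["Open", "Close"], per="symbol")   # tuple of objects, one per symbol
#   data.get(feature="Close", as_dict=True)              # dict keyed by the selected key(s)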
# ############# Pre- and post-processing ############# #
@classmethod
def prepare_dt_index(
cls,
index: tp.Index,
parse_dates: bool = False,
tz_localize: tp.TimezoneLike = None,
tz_convert: tp.TimezoneLike = None,
force_tz_convert: bool = False,
remove_tz: bool = False,
) -> tp.Index:
"""Prepare datetime index.
If `parse_dates` is True, will try to convert the index with an object data type
into a datetime format using `vectorbtpro.utils.datetime_.prepare_dt_index`.
If `tz_localize` is not None, will localize a datetime-naive index into this timezone.
If `tz_convert` is not None, will convert a datetime-aware index into this timezone.
If `force_tz_convert` is True, will convert regardless of whether the index is datetime-aware."""
if parse_dates:
if not isinstance(index, (pd.DatetimeIndex, pd.MultiIndex)) and index.dtype == object:
index = dt.prepare_dt_index(index)
if isinstance(index, pd.DatetimeIndex):
if index.tz is None and tz_localize is not None:
index = index.tz_localize(dt.to_timezone(tz_localize))
if tz_convert is not None:
if index.tz is not None or force_tz_convert:
index = index.tz_convert(dt.to_timezone(tz_convert))
if remove_tz and index.tz is not None:
index = index.tz_localize(None)
return index
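# Illustrative sketch (not part of the original module): preparing a datetime index.
# Assuming pandas is imported as pd, an object-dtype index of strings (hypothetical values)
# can be parsed and localized:
#   index = pd.Index(["2020-01-01", "2020-01-02"])
#   Data.prepare_dt_index(index, parse_dates=True, tz_localize="utc")
#   # -> DatetimeIndex localized to UTC; tz_convert would then convert tz-aware indexes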
@classmethod
def prepare_dt_column(
cls,
sr: tp.Series,
parse_dates: bool = False,
tz_localize: tp.TimezoneLike = None,
tz_convert: tp.TimezoneLike = None,
force_tz_convert: bool = False,
remove_tz: bool = False,
) -> tp.Series:
"""Prepare datetime column.
See `Data.prepare_dt_index` for arguments."""
index = cls.prepare_dt_index(
pd.Index(sr),
parse_dates=parse_dates,
tz_localize=tz_localize,
tz_convert=tz_convert,
force_tz_convert=force_tz_convert,
remove_tz=remove_tz,
)
if isinstance(index, pd.DatetimeIndex):
return pd.Series(index, index=sr.index, name=sr.name)
return sr
@classmethod
def prepare_dt(
cls,
obj: tp.SeriesFrame,
parse_dates: tp.Union[None, bool, tp.Sequence[str]] = True,
to_utc: tp.Union[None, bool, str, tp.Sequence[str]] = True,
remove_utc_tz: bool = False,
) -> tp.SeriesFrame:
"""Prepare datetime index and columns.
If `parse_dates` is True, will try to convert any index and column with object data type
into a datetime format using `vectorbtpro.utils.datetime_.prepare_dt_index`.
If `parse_dates` is a sequence of names, only the index and columns whose names
appear in `parse_dates` are parsed.
If `to_utc` is True or `to_utc` is "index" or `to_utc` is a sequence and index name is in this
sequence, will localize/convert any datetime index to the UTC timezone. If `to_utc` is True or
`to_utc` is "columns" or `to_utc` is a sequence and column name is in this sequence, will
localize/convert any datetime column to the UTC timezone."""
obj = obj.copy(deep=False)
made_frame = False
if isinstance(obj, pd.Series):
obj = obj.to_frame()
made_frame = True
index_parse_dates = False
if not isinstance(obj.index, pd.MultiIndex) and obj.index.dtype == object:
if parse_dates is True:
index_parse_dates = True
elif checks.is_sequence(parse_dates) and obj.index.name in parse_dates:
index_parse_dates = True
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "index")
or (checks.is_sequence(to_utc) and obj.index.name in to_utc)
):
index_tz_localize = "utc"
index_tz_convert = "utc"
index_remove_tz = remove_utc_tz
else:
index_tz_localize = None
index_tz_convert = None
index_remove_tz = False
obj.index = cls.prepare_dt_index(
obj.index,
parse_dates=index_parse_dates,
tz_localize=index_tz_localize,
tz_convert=index_tz_convert,
remove_tz=index_remove_tz,
)
for column_name in obj.columns:
column_parse_dates = False
if obj[column_name].dtype == object:
if parse_dates is True:
column_parse_dates = True
elif checks.is_sequence(parse_dates) and column_name in parse_dates:
column_parse_dates = True
elif not hasattr(obj[column_name], "dt"):
continue
if (
to_utc is True
or (isinstance(to_utc, str) and to_utc.lower() == "columns")
or (checks.is_sequence(to_utc) and column_name in to_utc)
):
column_tz_localize = "utc"
column_tz_convert = "utc"
column_remove_tz = remove_utc_tz
else:
column_tz_localize = None
column_tz_convert = None
column_remove_tz = False
obj[column_name] = cls.prepare_dt_column(
obj[column_name],
parse_dates=column_parse_dates,
tz_localize=column_tz_localize,
tz_convert=column_tz_convert,
remove_tz=column_remove_tz,
)
if made_frame:
obj = obj.iloc[:, 0]
return obj
@classmethod
def prepare_tzaware_index(
cls,
obj: tp.SeriesFrame,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
) -> tp.SeriesFrame:
"""Prepare a timezone-aware index of a Pandas object.
Uses `Data.prepare_dt_index` with `parse_dates=True` and `force_tz_convert=True`.
For defaults, see `vectorbtpro._settings.data`."""
obj = obj.copy(deep=False)
tz_localize = cls.resolve_base_setting(tz_localize, "tz_localize")
if isinstance(tz_localize, bool):
if tz_localize:
raise ValueError("tz_localize cannot be True")
else:
tz_localize = None
tz_convert = cls.resolve_base_setting(tz_convert, "tz_convert")
if isinstance(tz_convert, bool):
if tz_convert:
raise ValueError("tz_convert cannot be True")
else:
tz_convert = None
obj.index = cls.prepare_dt_index(
obj.index,
parse_dates=True,
tz_localize=tz_localize,
tz_convert=tz_convert,
force_tz_convert=True,
)
return obj
@classmethod
def align_index(
cls,
data: dict,
missing: tp.Optional[str] = None,
silence_warnings: tp.Optional[bool] = None,
) -> dict:
"""Align data to have the same index.
The argument `missing` accepts the following values:
* 'nan': set missing data points to NaN
* 'drop': remove missing data points
* 'raise': raise an error
For defaults, see `vectorbtpro._settings.data`."""
missing = cls.resolve_base_setting(missing, "missing_index")
silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings")
index = None
index_changed = False
for k, obj in data.items():
if index is None:
index = obj.index
else:
if not checks.is_index_equal(index, obj.index, check_names=False):
if missing == "nan":
if not silence_warnings:
warn("Symbols have mismatching index. Setting missing data points to NaN.")
index = index.union(obj.index)
index_changed = True
elif missing == "drop":
if not silence_warnings:
warn("Symbols have mismatching index. Dropping missing data points.")
index = index.intersection(obj.index)
index_changed = True
elif missing == "raise":
raise ValueError("Symbols have mismatching index")
else:
raise ValueError(f"Invalid missing: '{missing}'")
if not index_changed:
return data
new_data = {k: obj.reindex(index=index) for k, obj in data.items()}
return type(data)(new_data)
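# Illustrative sketch (not part of the original module): aligning mismatching indexes.
# Assuming two hypothetical frames `df1` and `df2` whose indexes only partially overlap:
#   Data.align_index(symbol_dict({"S1": df1, "S2": df2}), missing="nan")   # union, gaps as NaN
#   Data.align_index(symbol_dict({"S1": df1, "S2": df2}), missing="drop")  # intersection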
@classmethod
def align_columns(
cls,
data: dict,
missing: tp.Optional[str] = None,
silence_warnings: tp.Optional[bool] = None,
) -> dict:
"""Align data to have the same columns.
See `Data.align_index` for `missing`."""
if len(data) == 1:
return data
missing = cls.resolve_base_setting(missing, "missing_columns")
silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings")
columns = None
multiple_columns = False
name_is_none = False
columns_changed = False
for k, obj in data.items():
if isinstance(obj, pd.Series):
if obj.name is None:
name_is_none = True
obj = obj.to_frame()
else:
multiple_columns = True
obj_columns = obj.columns
if columns is None:
columns = obj_columns
else:
if not checks.is_index_equal(columns, obj_columns, check_names=False):
if missing == "nan":
if not silence_warnings:
warn("Symbols have mismatching columns. Setting missing data points to NaN.")
columns = columns.union(obj_columns)
columns_changed = True
elif missing == "drop":
if not silence_warnings:
warn("Symbols have mismatching columns. Dropping missing data points.")
columns = columns.intersection(obj_columns)
columns_changed = True
elif missing == "raise":
raise ValueError("Symbols have mismatching columns")
else:
raise ValueError(f"Invalid missing: '{missing}'")
if not columns_changed:
return data
new_data = {}
for k, obj in data.items():
if isinstance(obj, pd.Series):
obj = obj.to_frame()
obj = obj.reindex(columns=columns)
if not multiple_columns:
obj = obj[columns[0]]
if name_is_none:
obj = obj.rename(None)
new_data[k] = obj
return type(data)(new_data)
def switch_class(
self,
new_cls: tp.Type[DataT],
clear_fetch_kwargs: bool = False,
clear_returned_kwargs: bool = False,
**kwargs,
) -> DataT:
"""Switch the class of the data instance."""
if clear_fetch_kwargs:
new_fetch_kwargs = type(self.fetch_kwargs)({k: {} for k in self.symbols})
else:
new_fetch_kwargs = copy_dict(self.fetch_kwargs)
if clear_returned_kwargs:
new_returned_kwargs = type(self.returned_kwargs)({k: {} for k in self.symbols})
else:
new_returned_kwargs = copy_dict(self.returned_kwargs)
return self.replace(
cls_=new_cls,
fetch_kwargs=new_fetch_kwargs,
returned_kwargs=new_returned_kwargs,
**kwargs,
)
@classmethod
def invert_data(cls, dct: tp.Dict[tp.Key, tp.SeriesFrame]) -> tp.Dict[tp.Key, tp.SeriesFrame]:
"""Invert data by swapping keys and columns."""
if len(dct) == 0:
return dct
new_dct = dict()
for k, v in dct.items():
if isinstance(v, pd.Series):
if v.name not in new_dct:
new_dct[v.name] = []
new_dct[v.name].append(v.rename(k))
else:
for c in v.columns:
if c not in new_dct:
new_dct[c] = []
new_dct[c].append(v[c].rename(k))
new_dct2 = {}
for k, v in new_dct.items():
if len(v) == 1:
new_dct2[k] = v[0]
else:
new_dct2[k] = pd.concat(v, axis=1)
if isinstance(dct, symbol_dict):
return feature_dict(new_dct2)
if isinstance(dct, feature_dict):
return symbol_dict(new_dct2)
return new_dct2
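# Illustrative sketch (not part of the original module): inverting a data dictionary.
# A symbol_dict keyed by symbol whose frames have feature columns becomes a feature_dict
# keyed by feature whose frames have symbol columns (hypothetical labels):
#   Data.invert_data(symbol_dict({"S1": df_s1, "S2": df_s2}))
#   # -> feature_dict({"Open": <frame with columns S1, S2>, "Close": <frame with columns S1, S2>})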
@hybrid_method
def align_data(
cls_or_self,
data: dict,
last_index: tp.Union[None, feature_dict, symbol_dict] = None,
delisted: tp.Union[None, feature_dict, symbol_dict] = None,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
missing_index: tp.Optional[str] = None,
missing_columns: tp.Optional[str] = None,
silence_warnings: tp.Optional[bool] = None,
) -> dict:
"""Align data.
Removes any index duplicates, prepares the datetime index, and aligns the index and columns."""
if last_index is None:
last_index = {}
if delisted is None:
delisted = {}
if tz_localize is None and not isinstance(cls_or_self, type):
tz_localize = cls_or_self.tz_localize
if tz_convert is None and not isinstance(cls_or_self, type):
tz_convert = cls_or_self.tz_convert
if missing_index is None and not isinstance(cls_or_self, type):
missing_index = cls_or_self.missing_index
if missing_columns is None and not isinstance(cls_or_self, type):
missing_columns = cls_or_self.missing_columns
data = type(data)(data)
for k, obj in data.items():
obj = to_pd_array(obj)
obj = cls_or_self.prepare_tzaware_index(obj, tz_localize=tz_localize, tz_convert=tz_convert)
if obj.index.is_monotonic_decreasing:
obj = obj.iloc[::-1]
elif not obj.index.is_monotonic_increasing:
obj = obj.sort_index()
if obj.index.has_duplicates:
obj = obj[~obj.index.duplicated(keep="last")]
data[k] = obj
if (isinstance(data, symbol_dict) and isinstance(last_index, symbol_dict)) or (
isinstance(data, feature_dict) and isinstance(last_index, feature_dict)
):
if k not in last_index:
last_index[k] = obj.index[-1]
if (isinstance(data, symbol_dict) and isinstance(delisted, symbol_dict)) or (
isinstance(data, feature_dict) and isinstance(delisted, feature_dict)
):
if k not in delisted:
delisted[k] = False
data = cls_or_self.align_index(data, missing=missing_index, silence_warnings=silence_warnings)
data = cls_or_self.align_columns(data, missing=missing_columns, silence_warnings=silence_warnings)
first_data = data[list(data.keys())[0]]
if isinstance(first_data, pd.Series):
columns = [first_data.name]
else:
columns = first_data.columns
for k in columns:
if (isinstance(data, symbol_dict) and isinstance(last_index, feature_dict)) or (
isinstance(data, feature_dict) and isinstance(last_index, symbol_dict)
):
if k not in last_index:
last_index[k] = first_data.index[-1]
if (isinstance(data, symbol_dict) and isinstance(delisted, feature_dict)) or (
isinstance(data, feature_dict) and isinstance(delisted, symbol_dict)
):
if k not in delisted:
delisted[k] = False
for obj in data.values():
if isinstance(obj.index, pd.DatetimeIndex):
obj.index.freq = obj.index.inferred_freq
return data
@classmethod
def from_data(
cls: tp.Type[DataT],
data: tp.Union[dict, tp.SeriesFrame],
columns_are_symbols: bool = False,
invert_data: bool = False,
single_key: bool = True,
classes: tp.Optional[dict] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
missing_index: tp.Optional[str] = None,
missing_columns: tp.Optional[str] = None,
wrapper_kwargs: tp.KwargsLike = None,
fetch_kwargs: tp.Optional[dict] = None,
returned_kwargs: tp.Optional[dict] = None,
last_index: tp.Optional[dict] = None,
delisted: tp.Optional[dict] = None,
silence_warnings: tp.Optional[bool] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance from data.
Args:
data (dict): Dictionary of array-like objects keyed by symbol.
columns_are_symbols (bool): Whether columns in each DataFrame are symbols.
invert_data (bool): Whether to invert the data dictionary with `Data.invert_data`.
single_key (bool): See `Data.single_key`.
classes (feature_dict or symbol_dict): See `Data.classes`.
level_name (bool, hashable or iterable of hashable): See `Data.level_name`.
tz_localize (timezone_like): See `Data.prepare_tzaware_index`.
tz_convert (timezone_like): See `Data.prepare_tzaware_index`.
missing_index (str): See `Data.align_index`.
missing_columns (str): See `Data.align_columns`.
wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
fetch_kwargs (feature_dict or symbol_dict): Keyword arguments initially passed to `Data.fetch_symbol`.
returned_kwargs (feature_dict or symbol_dict): Keyword arguments returned by `Data.fetch_symbol`.
last_index (feature_dict or symbol_dict): Last fetched index per symbol.
delisted (feature_dict or symbol_dict): Whether symbol has been delisted.
silence_warnings (bool): Whether to silence all warnings.
**kwargs: Keyword arguments passed to the `__init__` method.
For defaults, see `vectorbtpro._settings.data`."""
if wrapper_kwargs is None:
wrapper_kwargs = {}
if classes is None:
classes = {}
if fetch_kwargs is None:
fetch_kwargs = {}
if returned_kwargs is None:
returned_kwargs = {}
if last_index is None:
last_index = {}
if delisted is None:
delisted = {}
if columns_are_symbols and isinstance(data, symbol_dict):
raise TypeError("Data cannot have the type symbol_dict when columns_are_symbols=True")
if isinstance(data, (pd.Series, pd.DataFrame)):
if columns_are_symbols:
data = feature_dict(feature=data)
else:
data = symbol_dict(symbol=data)
checks.assert_instance_of(data, dict, arg_name="data")
if not isinstance(data, key_dict):
if columns_are_symbols:
data = feature_dict(data)
else:
data = symbol_dict(data)
if invert_data:
data = cls.invert_data(data)
if len(data) > 1:
single_key = False
checks.assert_instance_of(last_index, dict, arg_name="last_index")
if not isinstance(last_index, key_dict):
last_index = type(data)(last_index)
checks.assert_instance_of(delisted, dict, arg_name="delisted")
if not isinstance(delisted, key_dict):
delisted = type(data)(delisted)
data = cls.align_data(
data,
last_index=last_index,
delisted=delisted,
tz_localize=tz_localize,
tz_convert=tz_convert,
missing_index=missing_index,
missing_columns=missing_columns,
silence_warnings=silence_warnings,
)
first_data = data[list(data.keys())[0]]
wrapper = ArrayWrapper.from_obj(first_data, **wrapper_kwargs)
attr_dicts = cls.fix_dict_types_in_kwargs(
type(data),
classes=classes,
fetch_kwargs=fetch_kwargs,
returned_kwargs=returned_kwargs,
last_index=last_index,
delisted=delisted,
)
return cls(
wrapper,
data,
single_key=single_key,
level_name=level_name,
tz_localize=tz_localize,
tz_convert=tz_convert,
missing_index=missing_index,
missing_columns=missing_columns,
**attr_dicts,
**kwargs,
)
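# Illustrative sketch (not part of the original module): building an instance from raw data.
# Assuming `df1` and `df2` are OHLCV frames with a datetime index (hypothetical):
#   data = Data.from_data(symbol_dict({"S1": df1, "S2": df2}), missing_index="nan")
#   data = Data.from_data(close_df, columns_are_symbols=True)  # one frame, symbols as columns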
def invert(self: DataT, key_wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Invert data and return a new instance."""
if key_wrapper_kwargs is None:
key_wrapper_kwargs = {}
new_data = self.concat(attach_classes=False)
if "wrapper" not in kwargs:
kwargs["wrapper"] = self.get_key_wrapper(**key_wrapper_kwargs)
if "classes" not in kwargs:
kwargs["classes"] = self.column_type()
if "single_key" not in kwargs:
kwargs["single_key"] = self.wrapper.ndim == 1
if "level_name" not in kwargs:
if isinstance(self.wrapper.columns, pd.MultiIndex):
if self.wrapper.columns.names == [None] * self.wrapper.columns.nlevels:
kwargs["level_name"] = False
else:
kwargs["level_name"] = self.wrapper.columns.names
else:
if self.wrapper.columns.name is None:
kwargs["level_name"] = False
else:
kwargs["level_name"] = self.wrapper.columns.name
return self.replace(data=new_data, **kwargs)
def to_feature_oriented(self: DataT, **kwargs) -> DataT:
"""Convert this instance to the feature-oriented format.
Returns self if the data is already feature-oriented (or a replaced copy if keyword arguments are given)."""
if self.feature_oriented:
if len(kwargs) > 0:
return self.replace(**kwargs)
return self
return self.invert(**kwargs)
def to_symbol_oriented(self: DataT, **kwargs) -> DataT:
"""Convert this instance to the symbol-oriented format.
Returns self if the data is already symbol-oriented (or a replaced copy if keyword arguments are given)."""
if self.symbol_oriented:
if len(kwargs) > 0:
return self.replace(**kwargs)
return self
return self.invert(**kwargs)
@classmethod
def has_key_dict(
cls,
arg: tp.Any,
dict_type: tp.Optional[tp.Type[tp.Union[feature_dict, symbol_dict]]] = None,
) -> bool:
"""Check whether the argument contains any data dictionary."""
if dict_type is None:
dict_type = key_dict
if isinstance(arg, dict_type):
return True
if isinstance(arg, dict):
for k, v in arg.items():
if isinstance(v, dict_type):
return True
return False
@hybrid_method
def check_dict_type(
cls_or_self,
arg: tp.Any,
arg_name: tp.Optional[str] = None,
dict_type: tp.Optional[tp.Type[tp.Union[feature_dict, symbol_dict]]] = None,
) -> None:
"""Check whether the argument conforms to a data dictionary."""
if isinstance(cls_or_self, type):
checks.assert_not_none(dict_type, arg_name="dict_type")
if dict_type is None:
dict_type = cls_or_self.dict_type
if issubclass(dict_type, feature_dict):
checks.assert_not_instance_of(arg, symbol_dict, arg_name=arg_name)
if issubclass(dict_type, symbol_dict):
checks.assert_not_instance_of(arg, feature_dict, arg_name=arg_name)
@hybrid_method
def select_key_kwargs(
cls_or_self,
key: tp.Key,
kwargs: tp.KwargsLike,
kwargs_name: str = "kwargs",
dict_type: tp.Optional[tp.Type[tp.Union[feature_dict, symbol_dict]]] = None,
check_dict_type: bool = True,
) -> tp.Kwargs:
"""Select the keyword arguments belonging to a feature or symbol."""
if isinstance(cls_or_self, type):
checks.assert_not_none(dict_type, arg_name="dict_type")
if dict_type is None:
dict_type = cls_or_self.dict_type
if kwargs is None:
return {}
if check_dict_type:
cls_or_self.check_dict_type(kwargs, arg_name=kwargs_name, dict_type=dict_type)
if type(kwargs) is key_dict or isinstance(kwargs, dict_type):
if key not in kwargs:
return {}
kwargs = dict(kwargs[key])
_kwargs = {}
for k, v in kwargs.items():
if check_dict_type:
cls_or_self.check_dict_type(v, arg_name=f"{kwargs_name}[{k}]", dict_type=dict_type)
if type(v) is key_dict or isinstance(v, dict_type):
if key in v:
_kwargs[k] = v[key]
else:
_kwargs[k] = v
return _kwargs
@classmethod
def select_feature_kwargs(cls, feature: tp.Feature, kwargs: tp.KwargsLike, **kwargs_) -> tp.Kwargs:
"""Select the keyword arguments belonging to a feature."""
return cls.select_key_kwargs(feature, kwargs, dict_type=feature_dict, **kwargs_)
@classmethod
def select_symbol_kwargs(cls, symbol: tp.Symbol, kwargs: tp.KwargsLike, **kwargs_) -> tp.Kwargs:
"""Select the keyword arguments belonging to a symbol."""
return cls.select_key_kwargs(symbol, kwargs, dict_type=symbol_dict, **kwargs_)
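# --- Editor's usage sketch (not part of the original module): per-key kwargs ---
# `select_key_kwargs` resolves keyword arguments where individual values may be
# wrapped in `feature_dict`/`symbol_dict` to target a single key, while plain
# values are shared across keys.
from vectorbtpro.data.base import Data, symbol_dict

kwargs = dict(timeframe=symbol_dict({"BTC-USD": "1h", "ETH-USD": "4h"}), limit=500)
btc_kwargs = Data.select_symbol_kwargs("BTC-USD", kwargs)
# btc_kwargs == {"timeframe": "1h", "limit": 500}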
@hybrid_method
def select_key_from_dict(
cls_or_self,
key: tp.Key,
dct: key_dict,
dct_name: str = "dct",
dict_type: tp.Optional[tp.Type[tp.Union[feature_dict, symbol_dict]]] = None,
check_dict_type: bool = True,
) -> tp.Any:
"""Select the dictionary value belonging to a feature or symbol."""
if isinstance(cls_or_self, type):
checks.assert_not_none(dict_type, arg_name="dict_type")
if dict_type is None:
dict_type = cls_or_self.dict_type
if check_dict_type:
cls_or_self.check_dict_type(dct, arg_name=dct_name, dict_type=dict_type)
return dct[key]
@classmethod
def select_feature_from_dict(cls, feature: tp.Feature, dct: feature_dict, **kwargs) -> tp.Any:
"""Select the dictionary value belonging to a feature."""
return cls.select_key_from_dict(feature, dct, dict_type=feature_dict, **kwargs)

@classmethod
def select_symbol_from_dict(cls, symbol: tp.Symbol, dct: symbol_dict, **kwargs) -> tp.Any:
"""Select the dictionary value belonging to a symbol."""
return cls.select_key_from_dict(symbol, dct, dict_type=symbol_dict, **kwargs)
@classmethod
def select_from_dict(cls, dct: dict, keys: tp.Keys, raise_error: bool = False) -> dict:
"""Select keys from a dict."""
if raise_error:
return type(dct)({k: dct[k] for k in keys})
return type(dct)({k: dct[k] for k in keys if k in dct})
@classmethod
def get_intersection_dict(cls, dct: dict) -> dict:
"""Get sub-keys and corresponding sub-values that are the same for all keys."""
dct_values = list(dct.values())
overlapping_keys = set(dct_values[0].keys())
for d in dct_values[1:]:
overlapping_keys.intersection_update(d.keys())
new_dct = dict()
for i, k in enumerate(dct.keys()):
for k2 in overlapping_keys:
v2 = dct[k][k2]
if i == 0 and k2 not in new_dct:
new_dct[k2] = v2
elif k2 in new_dct and new_dct[k2] is not v2:
del new_dct[k2]
return new_dct
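# --- Editor's usage sketch (not part of the original module): shared sub-keys ---
# `get_intersection_dict` keeps only the sub-key/value pairs that are identical
# (by identity) for every key; this is how shared fetch settings are reused by
# `add_feature`/`add_symbol` below.
from vectorbtpro.data.base import Data

shared_tf = "1h"
fetch_kwargs = {
    "BTC-USD": {"timeframe": shared_tf, "limit": 100},
    "ETH-USD": {"timeframe": shared_tf, "limit": 200},
}
Data.get_intersection_dict(fetch_kwargs)  # -> {"timeframe": "1h"}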
def select_keys(self: DataT, keys: tp.MaybeKeys, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more keys selected from this instance."""
keys = self.resolve_keys(keys)
if self.has_multiple_keys(keys):
single_key = False
else:
single_key = True
keys = [keys]
attr_dicts = dict()
for attr in self._key_dict_attrs:
attr_value = getattr(self, attr)
if isinstance(attr_value, self.dict_type):
attr_dicts[attr] = self.select_from_dict(attr_value, keys)
return self.replace(
data=self.select_from_dict(self.data, keys, raise_error=True),
single_key=single_key,
**attr_dicts,
**kwargs,
)
def select_columns(self: DataT, columns: tp.MaybeColumns, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more columns selected from this instance."""
columns = self.resolve_columns(columns)
def _pd_indexing_func(obj):
return obj[columns]
return self.indexing_func(_pd_indexing_func, replace_kwargs=kwargs)
def select_feature_idxs(self: DataT, idxs: tp.MaybeSequence[int], **kwargs) -> DataT:
"""Create a new `Data` instance with one or more features selected by position from this instance."""
if checks.is_int(idxs):
features = self.features[idxs]
else:
features = [self.features[i] for i in idxs]
if self.feature_oriented:
return self.select_keys(features, **kwargs)
return self.select_columns(features, **kwargs)
def select_symbol_idxs(self: DataT, idxs: tp.MaybeSequence[int], **kwargs) -> DataT:
"""Create a new `Data` instance with one or more symbols selected by position from this instance."""
if checks.is_int(idxs):
symbols = self.symbols[idxs]
else:
symbols = [self.symbols[i] for i in idxs]
if self.feature_oriented:
return self.select_columns(symbols, **kwargs)
return self.select_keys(symbols, **kwargs)
def select(self: DataT, keys: tp.MaybeKeys, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more features or symbols selected from this instance.
Will try to determine the orientation automatically."""
if not self.has_multiple_keys(keys):
keys = [keys]
single_key = True
else:
single_key = False
feature_keys = set(self.resolve_features(keys, raise_error=False))
symbol_keys = set(self.resolve_symbols(keys, raise_error=False))
features_and_keys = set(self.features).intersection(feature_keys)
symbols_and_keys = set(self.symbols).intersection(symbol_keys)
if features_and_keys and not symbols_and_keys:
if single_key:
return self.select_features(keys[0], **kwargs)
return self.select_features(keys, **kwargs)
if symbols_and_keys and not features_and_keys:
if single_key:
return self.select_symbols(keys[0], **kwargs)
return self.select_symbols(keys, **kwargs)
raise ValueError("Cannot determine orientation. Use select_features or select_symbols.")
def add_feature(
self: DataT,
feature: tp.Feature,
data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
pull_feature: bool = False,
pull_kwargs: tp.KwargsLike = None,
reuse_fetch_kwargs: bool = True,
run_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
merge_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with a new feature added to this instance.
If `data` is None, uses `Data.run`. If in addition `pull_feature` is True, uses `Data.pull` instead."""
if pull_kwargs is None:
pull_kwargs = {}
if run_kwargs is None:
run_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
if data is None:
if pull_feature:
if isinstance(self.fetch_kwargs, feature_dict) and reuse_fetch_kwargs:
pull_kwargs = merge_dicts(self.get_intersection_dict(self.fetch_kwargs), pull_kwargs)
data = type(self).pull(features=feature, **pull_kwargs).get(feature=feature)
else:
data = self.run(feature, **run_kwargs, unpack=True)
data = self.symbol_wrapper.wrap(data, **wrap_kwargs)
if isinstance(data, CustomTemplate):
data = data.substitute(dict(data=self), eval_id="data")
if isinstance(data, pd.Series) and self.symbol_wrapper.ndim == 1:
data = data.copy(deep=False)
data.name = self.symbols[0]
for attr in self._key_dict_attrs:
if attr in kwargs:
checks.assert_not_instance_of(kwargs[attr], key_dict, arg_name=attr)
kwargs[attr] = feature_dict({feature: kwargs[attr]})
data = type(self).from_data(
feature_dict({feature: data}),
invert_data=not self.feature_oriented,
**kwargs,
)
on_merge_conflict = {k: "error" for k in kwargs if k not in self._key_dict_attrs}
on_merge_conflict["_def"] = "first"
if merge_kwargs is None:
merge_kwargs = {}
return self.merge(data, on_merge_conflict=on_merge_conflict, **merge_kwargs)
def add_symbol(
self: DataT,
symbol: tp.Symbol,
data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
pull_kwargs: tp.KwargsLike = None,
reuse_fetch_kwargs: bool = True,
merge_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with a new symbol added to this instance.
If `data` is None, uses `Data.pull`."""
if pull_kwargs is None:
pull_kwargs = {}
if data is None:
if isinstance(self.fetch_kwargs, symbol_dict) and reuse_fetch_kwargs:
pull_kwargs = merge_dicts(self.get_intersection_dict(self.fetch_kwargs), pull_kwargs)
data = type(self).pull(symbols=symbol, **pull_kwargs).get(symbol=symbol)
if isinstance(data, CustomTemplate):
data = data.substitute(dict(data=self), eval_id="data")
if isinstance(data, pd.Series) and self.feature_wrapper.ndim == 1:
data = data.copy(deep=False)
data.name = self.features[0]
for attr in self._key_dict_attrs:
if attr in kwargs:
checks.assert_not_instance_of(kwargs[attr], key_dict, arg_name=attr)
kwargs[attr] = symbol_dict({symbol: kwargs[attr]})
data = type(self).from_data(
symbol_dict({symbol: data}),
invert_data=not self.symbol_oriented,
**kwargs,
)
on_merge_conflict = {k: "error" for k in kwargs if k not in self._key_dict_attrs}
on_merge_conflict["_def"] = "first"
if merge_kwargs is None:
merge_kwargs = {}
return self.merge(data, on_merge_conflict=on_merge_conflict, **merge_kwargs)
def add_key(
self: DataT,
key: tp.Key,
data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with a new key added to this instance."""
if self.feature_oriented:
return self.add_feature(key, data=data, **kwargs)
return self.add_symbol(key, data=data, **kwargs)
def add_column(
self: DataT,
column: tp.Column,
data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with a new column added to this instance."""
if self.feature_oriented:
return self.add_symbol(column, data=data, **kwargs)
return self.add_feature(column, data=data, **kwargs)
def add(
self: DataT,
key: tp.Key,
data: tp.Union[None, tp.SeriesFrame, CustomTemplate] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with a new feature or symbol added to this instance.
Will try to determine the orientation automatically."""
if data is not None:
if isinstance(data, CustomTemplate):
data = data.substitute(dict(data=self), eval_id="data")
if isinstance(data, pd.Series):
columns = [data.name]
else:
columns = data.columns
feature_columns = set(self.resolve_features(columns, raise_error=False))
symbol_columns = set(self.resolve_symbols(columns, raise_error=False))
features_and_columns = set(self.features).intersection(feature_columns)
symbols_and_columns = set(self.symbols).intersection(symbol_columns)
if features_and_columns and not symbols_and_columns:
return self.add_symbol(key, data=data, **kwargs)
if symbols_and_columns and not features_and_columns:
return self.add_feature(key, data=data, **kwargs)
raise ValueError("Cannot determine orientation. Use add_feature or add_symbol.")
@classmethod
def rename_in_dict(cls, dct: dict, rename: tp.Dict[tp.Key, tp.Key]) -> dict:
"""Rename keys in a dict."""
return type(dct)({rename.get(k, k): v for k, v in dct.items()})
def rename_keys(
self: DataT,
rename: tp.Union[tp.MaybeKeys, tp.Dict[tp.Key, tp.Key]],
to: tp.Optional[tp.MaybeKeys] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with keys renamed."""
if to is not None:
if self.has_multiple_keys(to):
rename = dict(zip(rename, to))
else:
rename = {rename: to}
rename = dict(zip(self.resolve_keys(list(rename.keys())), rename.values()))
attr_dicts = dict()
for attr in self._key_dict_attrs:
attr_value = getattr(self, attr)
if isinstance(attr_value, self.dict_type):
attr_dicts[attr] = self.rename_in_dict(attr_value, rename)
return self.replace(data=self.rename_in_dict(self.data, rename), **attr_dicts, **kwargs)
def rename_columns(
self: DataT,
rename: tp.Union[tp.MaybeColumns, tp.Dict[tp.Column, tp.Column]],
to: tp.Optional[tp.MaybeColumns] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with columns renamed."""
if to is not None:
if self.has_multiple_keys(to):
rename = dict(zip(rename, to))
else:
rename = {rename: to}
rename = dict(zip(self.resolve_columns(list(rename.keys())), rename.values()))
attr_dicts = dict()
for attr in self._key_dict_attrs:
attr_value = getattr(self, attr)
if isinstance(attr_value, self.column_type):
attr_dicts[attr] = self.rename_in_dict(attr_value, rename)
new_wrapper = self.wrapper.replace(columns=self.wrapper.columns.map(lambda x: rename.get(x, x)))
return self.replace(wrapper=new_wrapper, **attr_dicts, **kwargs)
def rename_features(
self: DataT,
rename: tp.Union[tp.MaybeFeatures, tp.Dict[tp.Feature, tp.Feature]],
to: tp.Optional[tp.MaybeFeatures] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with features renamed."""
if self.feature_oriented:
return self.rename_keys(rename, to=to, **kwargs)
return self.rename_columns(rename, to=to, **kwargs)
def rename_symbols(
self: DataT,
rename: tp.Union[tp.MaybeSymbols, tp.Dict[tp.Symbol, tp.Symbol]],
to: tp.Optional[tp.MaybeSymbols] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with symbols renamed."""
if self.feature_oriented:
return self.rename_columns(rename, to=to, **kwargs)
return self.rename_keys(rename, to=to, **kwargs)
def rename(
self: DataT,
rename: tp.Union[tp.MaybeKeys, tp.Dict[tp.Key, tp.Key]],
to: tp.Optional[tp.MaybeKeys] = None,
**kwargs,
) -> DataT:
"""Create a new `Data` instance with features or symbols renamed.
Will try to determine the orientation automatically."""
if to is not None:
if self.has_multiple_keys(to):
rename = dict(zip(rename, to))
else:
rename = {rename: to}
feature_keys = set(self.resolve_features(list(rename.keys()), raise_error=False))
symbol_keys = set(self.resolve_symbols(list(rename.keys()), raise_error=False))
features_and_keys = set(self.features).intersection(feature_keys)
symbols_and_keys = set(self.symbols).intersection(symbol_keys)
if features_and_keys and not symbols_and_keys:
return self.rename_features(rename, **kwargs)
if symbols_and_keys and not features_and_keys:
return self.rename_symbols(rename, **kwargs)
raise ValueError("Cannot determine orientation. Use rename_features or rename_symbols.")
def remove_features(self: DataT, features: tp.MaybeFeatures, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more features removed from this instance."""
if self.has_multiple_keys(features):
remove_feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
else:
remove_feature_idxs = [self.get_feature_idx(features, raise_error=True)]
keep_feature_idxs = [i for i in range(len(self.features)) if i not in remove_feature_idxs]
if len(keep_feature_idxs) == 0:
raise ValueError("No features will be left after this operation")
return self.select_feature_idxs(keep_feature_idxs, **kwargs)
def remove_symbols(self: DataT, symbols: tp.MaybeSymbols, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more symbols removed from this instance."""
if self.has_multiple_keys(symbols):
remove_symbol_idxs = [self.get_symbol_idx(k, raise_error=True) for k in symbols]
else:
remove_symbol_idxs = [self.get_symbol_idx(symbols, raise_error=True)]
keep_symbol_idxs = [i for i in range(len(self.symbols)) if i not in remove_symbol_idxs]
if len(keep_symbol_idxs) == 0:
raise ValueError("No symbols will be left after this operation")
return self.select_symbol_idxs(keep_symbol_idxs, **kwargs)
def remove_keys(self: DataT, keys: tp.MaybeKeys, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more keys removed from this instance."""
if self.feature_oriented:
return self.remove_features(keys, **kwargs)
return self.remove_symbols(keys, **kwargs)
def remove_columns(self: DataT, columns: tp.MaybeColumns, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more columns removed from this instance."""
if self.feature_oriented:
return self.remove_symbols(columns, **kwargs)
return self.remove_features(columns, **kwargs)
def remove(self: DataT, keys: tp.MaybeKeys, **kwargs) -> DataT:
"""Create a new `Data` instance with one or more features or symbols removed from this instance.
Will try to determine the orientation automatically."""
if not self.has_multiple_keys(keys):
keys = [keys]
feature_keys = set(self.resolve_features(keys, raise_error=False))
symbol_keys = set(self.resolve_symbols(keys, raise_error=False))
features_and_keys = set(self.features).intersection(feature_keys)
symbols_and_keys = set(self.symbols).intersection(symbol_keys)
if features_and_keys and not symbols_and_keys:
return self.remove_features(keys, **kwargs)
if symbols_and_keys and not features_and_keys:
return self.remove_symbols(keys, **kwargs)
raise ValueError("Cannot determine orientation. Use remove_features or remove_symbols.")
@hybrid_method
def merge(
cls_or_self: tp.MaybeType[DataT],
*datas: DataT,
rename: tp.Optional[tp.Dict[tp.Key, tp.Key]] = None,
**kwargs,
) -> DataT:
"""Merge multiple `Data` instances.
Can merge both symbols and features. Data is overridden in the order provided in `datas`."""
if len(datas) == 1 and not isinstance(datas[0], Data):
datas = datas[0]
datas = list(datas)
if not isinstance(cls_or_self, type):
datas = (cls_or_self, *datas)
data_type = None
data = {}
single_key = True
attr_dicts = dict()
for instance in datas:
if data_type is None:
data_type = type(instance.data)
elif not isinstance(instance.data, data_type):
raise TypeError("Objects to be merged must have the same dict type for data")
if not instance.single_key:
single_key = False
for k in instance.keys:
if rename is None:
new_k = k
else:
new_k = rename[k]
if new_k in data:
obj1 = instance.data[k]
obj2 = data[new_k]
both_were_series = True
if isinstance(obj1, pd.Series):
obj1 = obj1.to_frame()
else:
both_were_series = False
if isinstance(obj2, pd.Series):
obj2 = obj2.to_frame()
else:
both_were_series = False
new_obj = obj1.combine_first(obj2)
new_columns = []
for c in obj2.columns:
new_columns.append(c)
for c in obj1.columns:
if c not in new_columns:
new_columns.append(c)
new_obj = new_obj[new_columns]
if new_obj.shape[1] == 1 and both_were_series:
new_obj = new_obj.iloc[:, 0]
data[new_k] = new_obj
else:
data[new_k] = instance.data[k]
for attr in cls_or_self._key_dict_attrs:
attr_value = getattr(instance, attr)
if (issubclass(data_type, symbol_dict) and isinstance(attr_value, symbol_dict)) or (
issubclass(data_type, feature_dict) and isinstance(attr_value, feature_dict)
):
if k in attr_value:
if attr not in attr_dicts:
attr_dicts[attr] = type(attr_value)()
elif not isinstance(attr_value, type(attr_dicts[attr])):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
attr_dicts[attr][new_k] = attr_value[k]
for attr in cls_or_self._key_dict_attrs:
attr_value = getattr(instance, attr)
if (issubclass(data_type, symbol_dict) and isinstance(attr_value, feature_dict)) or (
issubclass(data_type, feature_dict) and isinstance(attr_value, symbol_dict)
):
if attr not in attr_dicts:
attr_dicts[attr] = type(attr_value)()
elif not isinstance(attr_value, type(attr_dicts[attr])):
raise TypeError(f"Objects to be merged must have the same dict type for '{attr}'")
attr_dicts[attr].update(**attr_value)
if "missing_index" not in kwargs:
kwargs["missing_index"] = "nan"
if "missing_columns" not in kwargs:
kwargs["missing_columns"] = "nan"
kwargs = cls_or_self.resolve_merge_kwargs(
*[instance.config for instance in datas],
wrapper=None,
data=data_type(data),
single_key=single_key,
**attr_dicts,
**kwargs,
)
kwargs.pop("wrapper", None)
return cls_or_self.from_data(**kwargs)
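# --- Editor's usage sketch (not part of the original module): merging ---
# Assumes `data_a` and `data_b` are `Data` instances sharing the same dict type
# (both symbol- or both feature-oriented). Overlapping keys are combined with
# `combine_first`, so instances later in the argument list take precedence.
from vectorbtpro.data.base import Data

merged = Data.merge(data_a, data_b, missing_index="nan", missing_columns="nan")
merged_alt = data_a.merge(data_b)  # equivalent instance-bound call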
# ############# Fetching ############# #
@classmethod
def fetch_feature(
cls,
feature: tp.Feature,
**kwargs,
) -> tp.FeatureData:
"""Fetch a feature.
Can also return a dictionary that will be accessible in `Data.returned_kwargs`.
If there are keyword arguments `tz_localize`, `tz_convert`, or `freq` in this dict,
will pop them and use them to override global settings.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
@classmethod
def try_fetch_feature(
cls,
feature: tp.Feature,
skip_on_error: bool = False,
silence_warnings: bool = False,
fetch_kwargs: tp.KwargsLike = None,
) -> tp.FeatureData:
"""Try to fetch a feature using `Data.fetch_feature`."""
if fetch_kwargs is None:
fetch_kwargs = {}
try:
out = cls.fetch_feature(feature, **fetch_kwargs)
if out is None:
if not silence_warnings:
warn(f"Feature '{str(feature)}' returned None. Skipping.")
return out
except Exception as e:
if not skip_on_error:
raise e
if not silence_warnings:
warn(traceback.format_exc())
warn(f"Feature '{str(feature)}' raised an exception. Skipping.")
return None
@classmethod
def fetch_symbol(
cls,
symbol: tp.Symbol,
**kwargs,
) -> tp.SymbolData:
"""Fetch a symbol.
Can also return a dictionary that will be accessible in `Data.returned_kwargs`.
If there are keyword arguments `tz_localize`, `tz_convert`, or `freq` in this dict,
will pop them and use them to override global settings.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
@classmethod
def try_fetch_symbol(
cls,
symbol: tp.Symbol,
skip_on_error: bool = False,
silence_warnings: bool = False,
fetch_kwargs: tp.KwargsLike = None,
) -> tp.SymbolData:
"""Try to fetch a symbol using `Data.fetch_symbol`."""
if fetch_kwargs is None:
fetch_kwargs = {}
try:
out = cls.fetch_symbol(symbol, **fetch_kwargs)
if out is None:
if not silence_warnings:
warn(f"Symbol '{str(symbol)}' returned None. Skipping.")
return out
except Exception as e:
if not skip_on_error:
raise e
if not silence_warnings:
warn(traceback.format_exc())
warn(f"Symbol '{str(symbol)}' raised an exception. Skipping.")
return None
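# --- Editor's usage sketch (not part of the original module): a custom data class ---
# `Data.pull` calls `fetch_symbol` once per symbol; a subclass only needs to return
# a pandas object (optionally a (data, returned_kwargs) tuple). A minimal synthetic
# example, with a hypothetical per-symbol `value` argument routed via `symbol_dict`:
import numpy as np
import pandas as pd
from vectorbtpro.data.base import Data, symbol_dict

class ConstData(Data):
    @classmethod
    def fetch_symbol(cls, symbol, value=1.0, periods=5, **kwargs):
        index = pd.date_range("2020-01-01", periods=periods, tz="utc")
        return pd.Series(np.full(periods, float(value)), index=index, name="Close")

const_data = ConstData.pull(["A", "B"], value=symbol_dict({"A": 1.0, "B": 2.0}))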
@classmethod
def resolve_keys_meta(
cls,
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
) -> tp.Kwargs:
"""Resolve metadata for keys."""
if keys is not None and features is not None:
raise ValueError("Must provide either keys or features, not both")
if keys is not None and symbols is not None:
raise ValueError("Must provide either keys or symbols, not both")
if features is not None and symbols is not None:
raise ValueError("Must provide either features or symbols, not both")
if keys is None:
if features is not None:
if isinstance(features, dict):
cls.check_dict_type(features, "features", dict_type=feature_dict)
keys = features
keys_are_features = True
dict_type = feature_dict
elif symbols is not None:
if isinstance(symbols, dict):
cls.check_dict_type(symbols, "symbols", dict_type=symbol_dict)
keys = symbols
keys_are_features = False
dict_type = symbol_dict
else:
keys = symbols
keys_are_features = False
dict_type = symbol_dict
else:
if isinstance(keys, feature_dict):
if keys_are_features is not None and not keys_are_features:
raise TypeError("Keys are of type feature_dict but keys_are_features is False")
keys_are_features = True
elif isinstance(keys, symbol_dict):
if keys_are_features is not None and keys_are_features:
raise TypeError("Keys are of type symbol_dict but keys_are_features is True")
keys_are_features = False
keys_are_features = cls.resolve_base_setting(keys_are_features, "keys_are_features")
if keys_are_features:
dict_type = feature_dict
else:
dict_type = symbol_dict
return dict(
keys=keys,
keys_are_features=keys_are_features,
dict_type=dict_type,
)
@classmethod
def pull(
cls: tp.Type[DataT],
keys: tp.Union[None, dict, tp.MaybeKeys] = None,
*,
keys_are_features: tp.Optional[bool] = None,
features: tp.Union[None, dict, tp.MaybeFeatures] = None,
symbols: tp.Union[None, dict, tp.MaybeSymbols] = None,
classes: tp.Optional[tp.MaybeSequence[tp.Union[tp.Hashable, dict]]] = None,
level_name: tp.Union[None, bool, tp.MaybeIterable[tp.Hashable]] = None,
tz_localize: tp.Union[None, bool, tp.TimezoneLike] = None,
tz_convert: tp.Union[None, bool, tp.TimezoneLike] = None,
missing_index: tp.Optional[str] = None,
missing_columns: tp.Optional[str] = None,
wrapper_kwargs: tp.KwargsLike = None,
skip_on_error: tp.Optional[bool] = None,
silence_warnings: tp.Optional[bool] = None,
execute_kwargs: tp.KwargsLike = None,
return_raw: bool = False,
**kwargs,
) -> tp.Union[DataT, tp.List[tp.Any]]:
"""Pull data.
Fetches each feature/symbol with `Data.fetch_feature`/`Data.fetch_symbol` and prepares it with `Data.from_data`.
Iteration over features/symbols is done using `vectorbtpro.utils.execution.execute`.
That is, it can be distributed and parallelized when needed.
Args:
keys (hashable, sequence of hashable, or dict): One or multiple keys.
Depending on `keys_are_features` will be set to `features` or `symbols`.
keys_are_features (bool): Whether `keys` are considered features.
features (hashable, sequence of hashable, or dict): One or multiple features.
If provided as a dictionary, will use keys as features and values as keyword arguments.
!!! note
A tuple is considered a single feature (tuples are hashable).
symbols (hashable, sequence of hashable, or dict): One or multiple symbols.
If provided as a dictionary, will use keys as symbols and values as keyword arguments.
!!! note
A tuple is considered a single symbol (tuples are hashable).
classes (feature_dict or symbol_dict): See `Data.classes`.
Can be a hashable (single value), a dictionary (class names as keys and
class values as values), or a sequence of such.
!!! note
A tuple is considered a single class (tuples are hashable).
level_name (bool, hashable or iterable of hashable): See `Data.level_name`.
tz_localize (any): See `Data.from_data`.
tz_convert (any): See `Data.from_data`.
missing_index (str): See `Data.from_data`.
missing_columns (str): See `Data.from_data`.
wrapper_kwargs (dict): See `Data.from_data`.
skip_on_error (bool): Whether to skip the feature/symbol when an exception is raised.
silence_warnings (bool): Whether to silence all warnings.
Will also forward this argument to `Data.fetch_feature`/`Data.fetch_symbol` if in the signature.
execute_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.execution.execute`.
return_raw (bool): Whether to return the raw outputs.
**kwargs: Passed to `Data.fetch_feature`/`Data.fetch_symbol`.
If two features/symbols require different keyword arguments, pass
`key_dict` or `feature_dict`/`symbol_dict` for each argument.
For defaults, see `vectorbtpro._settings.data`.
"""
keys_meta = cls.resolve_keys_meta(
keys=keys,
keys_are_features=keys_are_features,
features=features,
symbols=symbols,
)
keys = keys_meta["keys"]
keys_are_features = keys_meta["keys_are_features"]
dict_type = keys_meta["dict_type"]
fetch_kwargs = dict_type()
if isinstance(keys, dict):
new_keys = []
for k, key_fetch_kwargs in keys.items():
new_keys.append(k)
fetch_kwargs[k] = key_fetch_kwargs
keys = new_keys
single_key = False
elif cls.has_multiple_keys(keys):
keys = list(keys)
if len(set(keys)) < len(keys):
raise ValueError("Duplicate keys provided")
single_key = False
else:
single_key = True
keys = [keys]
if classes is not None:
cls.check_dict_type(classes, arg_name="classes", dict_type=dict_type)
if not isinstance(classes, key_dict):
new_classes = {}
single_class = checks.is_hashable(classes) or isinstance(classes, dict)
if single_class:
for k in keys:
if isinstance(classes, dict):
new_classes[k] = classes
else:
if keys_are_features:
new_classes[k] = {"feature_class": classes}
else:
new_classes[k] = {"symbol_class": classes}
else:
for i, k in enumerate(keys):
_classes = classes[i]
if not isinstance(_classes, dict):
if keys_are_features:
_classes = {"feature_class": _classes}
else:
_classes = {"symbol_class": _classes}
new_classes[k] = _classes
classes = new_classes
wrapper_kwargs = cls.resolve_base_setting(wrapper_kwargs, "wrapper_kwargs", merge=True)
skip_on_error = cls.resolve_base_setting(skip_on_error, "skip_on_error")
silence_warnings = cls.resolve_base_setting(silence_warnings, "silence_warnings")
execute_kwargs = cls.resolve_base_setting(execute_kwargs, "execute_kwargs", merge=True)
execute_kwargs = merge_dicts(dict(show_progress=not single_key), execute_kwargs)
tasks = []
if keys_are_features:
func_arg_names = get_func_arg_names(cls.fetch_feature)
else:
func_arg_names = get_func_arg_names(cls.fetch_symbol)
for k in keys:
if keys_are_features:
key_fetch_func = cls.try_fetch_feature
key_fetch_kwargs = cls.select_feature_kwargs(k, kwargs)
else:
key_fetch_func = cls.try_fetch_symbol
key_fetch_kwargs = cls.select_symbol_kwargs(k, kwargs)
if "silence_warnings" in func_arg_names:
key_fetch_kwargs["silence_warnings"] = silence_warnings
if k in fetch_kwargs:
key_fetch_kwargs = merge_dicts(key_fetch_kwargs, fetch_kwargs[k])
tasks.append(
Task(
key_fetch_func,
k,
skip_on_error=skip_on_error,
silence_warnings=silence_warnings,
fetch_kwargs=key_fetch_kwargs,
)
)
fetch_kwargs[k] = key_fetch_kwargs
key_index = cls.get_key_index(keys=keys, level_name=level_name, feature_oriented=keys_are_features)
outputs = execute(tasks, size=len(keys), keys=key_index, **execute_kwargs)
if return_raw:
return outputs
data = dict_type()
returned_kwargs = dict_type()
common_tz_localize = None
common_tz_convert = None
common_freq = None
for i, out in enumerate(outputs):
k = keys[i]
if out is not None:
if isinstance(out, tuple):
_data = out[0]
_returned_kwargs = out[1]
else:
_data = out
_returned_kwargs = {}
_data = to_any_array(_data)
_tz = _returned_kwargs.pop("tz", None)
_tz_localize = _returned_kwargs.pop("tz_localize", None)
_tz_convert = _returned_kwargs.pop("tz_convert", None)
_freq = _returned_kwargs.pop("freq", None)
if _tz is not None:
if _tz_localize is None:
_tz_localize = _tz
if _tz_convert is None:
_tz_convert = _tz
if _tz_localize is not None:
if common_tz_localize is None:
common_tz_localize = _tz_localize
elif common_tz_localize != _tz_localize:
raise ValueError("Returned objects have different timezones (tz_localize)")
if _tz_convert is not None:
if common_tz_convert is None:
common_tz_convert = _tz_convert
elif common_tz_convert != _tz_convert:
if not silence_warnings:
warn(f"Returned objects have different timezones (tz_convert). Setting to UTC.")
common_tz_convert = "utc"
if _freq is not None:
if common_freq is None:
common_freq = _freq
elif common_freq != _freq:
raise ValueError("Returned objects have different frequencies (freq)")
if _data.size == 0:
if not silence_warnings:
if keys_are_features:
warn(f"Feature '{str(k)}' returned an empty array. Skipping.")
else:
warn(f"Symbol '{str(k)}' returned an empty array. Skipping.")
else:
data[k] = _data
returned_kwargs[k] = _returned_kwargs
if tz_localize is None and common_tz_localize is not None:
tz_localize = common_tz_localize
if tz_convert is None and common_tz_convert is not None:
tz_convert = common_tz_convert
if wrapper_kwargs.get("freq", None) is None and common_freq is not None:
wrapper_kwargs["freq"] = common_freq
if len(data) == 0:
if keys_are_features:
raise ValueError("No features could be fetched")
else:
raise ValueError("No symbols could be fetched")
return cls.from_data(
data,
single_key=single_key,
classes=classes,
level_name=level_name,
tz_localize=tz_localize,
tz_convert=tz_convert,
missing_index=missing_index,
missing_columns=missing_columns,
wrapper_kwargs=wrapper_kwargs,
fetch_kwargs=fetch_kwargs,
returned_kwargs=returned_kwargs,
silence_warnings=silence_warnings,
)
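# --- Editor's usage sketch (not part of the original module): per-symbol `pull` kwargs ---
# Symbols can be given as a dict mapping each symbol to its own fetch kwargs, while
# shared kwargs are passed normally. Assumes `YFData` (see
# `vectorbtpro.data.custom.yf`) accepts `start`/`end` and network access is available.
from vectorbtpro.data.custom.yf import YFData

yf_data = YFData.pull(
    {"BTC-USD": dict(start="2020-01-01"), "ETH-USD": dict(start="2021-01-01")},
    end="2022-01-01",
)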
@classmethod
def download(cls: tp.Type[DataT], *args, **kwargs) -> tp.Union[DataT, tp.List[tp.Any]]:
"""Exists for backward compatibility. Use `Data.pull` instead."""
return cls.pull(*args, **kwargs)
@classmethod
def fetch(cls: tp.Type[DataT], *args, **kwargs) -> tp.Union[DataT, tp.List[tp.Any]]:
"""Exists for backward compatibility. Use `Data.pull` instead."""
return cls.pull(*args, **kwargs)
@classmethod
def from_data_str(cls: tp.Type[DataT], data_str: str) -> DataT:
"""Parse a `Data` instance from a string.
For example: `YFData:BTC-USD` or just `BTC-USD` where the data class is
`vectorbtpro.data.custom.yf.YFData` by default."""
from vectorbtpro.data import custom
if ":" in data_str:
cls_name, symbol = data_str.split(":")
cls_name = cls_name.strip()
symbol = symbol.strip()
return getattr(custom, cls_name).pull(symbol)
return custom.YFData.pull(data_str.strip())
# ############# Updating ############# #
def update_feature(
self,
feature: tp.Feature,
**kwargs,
) -> tp.FeatureData:
"""Update a feature.
Can also return a dictionary that will be accessible in `Data.returned_kwargs`.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
def try_update_feature(
self,
feature: tp.Feature,
skip_on_error: bool = False,
silence_warnings: bool = False,
update_kwargs: tp.KwargsLike = None,
) -> tp.FeatureData:
"""Try to update a feature using `Data.update_feature`."""
if update_kwargs is None:
update_kwargs = {}
try:
out = self.update_feature(feature, **update_kwargs)
if out is None:
if not silence_warnings:
warn(f"Feature '{str(feature)}' returned None. Skipping.")
return out
except Exception as e:
if not skip_on_error:
raise e
if not silence_warnings:
warn(traceback.format_exc())
warn(f"Feature '{str(feature)}' raised an exception. Skipping.")
return None
def update_symbol(
self,
symbol: tp.Symbol,
**kwargs,
) -> tp.SymbolData:
"""Update a symbol.
Can also return a dictionary that will be accessible in `Data.returned_kwargs`.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
def try_update_symbol(
self,
symbol: tp.Symbol,
skip_on_error: bool = False,
silence_warnings: bool = False,
update_kwargs: tp.KwargsLike = None,
) -> tp.SymbolData:
"""Try to update a symbol using `Data.update_symbol`."""
if update_kwargs is None:
update_kwargs = {}
try:
out = self.update_symbol(symbol, **update_kwargs)
if out is None:
if not silence_warnings:
warn(f"Symbol '{str(symbol)}' returned None. Skipping.")
return out
except Exception as e:
if not skip_on_error:
raise e
if not silence_warnings:
warn(traceback.format_exc())
warn(f"Symbol '{str(symbol)}' raised an exception. Skipping.")
return None
def update(
self: DataT,
*,
concat: bool = True,
skip_on_error: tp.Optional[bool] = None,
silence_warnings: tp.Optional[bool] = None,
execute_kwargs: tp.KwargsLike = None,
return_raw: bool = False,
**kwargs,
) -> tp.Union[DataT, tp.List[tp.Any]]:
"""Update data.
Fetches new data for each feature/symbol using `Data.update_feature`/`Data.update_symbol`.
Args:
concat (bool): Whether to concatenate existing and updated/new data.
skip_on_error (bool): Whether to skip the feature/symbol when an exception is raised.
silence_warnings (bool): Whether to silence all warnings.
Will also forward this argument to `Data.update_feature`/`Data.update_symbol`
if accepted by `Data.fetch_feature`/`Data.fetch_symbol`.
execute_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.execution.execute`.
return_raw (bool): Whether to return the raw outputs.
**kwargs: Passed to `Data.update_feature`/`Data.update_symbol`.
If two features/symbols require different keyword arguments,
pass `key_dict` or `feature_dict`/`symbol_dict` for each argument.
!!! note
Returns a new `Data` instance instead of changing the data in place.
"""
skip_on_error = self.resolve_base_setting(skip_on_error, "skip_on_error")
silence_warnings = self.resolve_base_setting(silence_warnings, "silence_warnings")
execute_kwargs = self.resolve_base_setting(execute_kwargs, "execute_kwargs", merge=True)
execute_kwargs = merge_dicts(dict(show_progress=False), execute_kwargs)
if self.feature_oriented:
func_arg_names = get_func_arg_names(self.fetch_feature)
else:
func_arg_names = get_func_arg_names(self.fetch_symbol)
if "show_progress" in func_arg_names and "show_progress" not in kwargs:
kwargs["show_progress"] = False
checks.assert_instance_of(self.last_index, self.dict_type, "last_index")
checks.assert_instance_of(self.delisted, self.dict_type, "delisted")
tasks = []
key_indices = []
for i, k in enumerate(self.keys):
if not self.delisted.get(k, False):
if self.feature_oriented:
key_update_func = self.try_update_feature
key_update_kwargs = self.select_feature_kwargs(k, kwargs)
else:
key_update_func = self.try_update_symbol
key_update_kwargs = self.select_symbol_kwargs(k, kwargs)
if "silence_warnings" in func_arg_names:
key_update_kwargs["silence_warnings"] = silence_warnings
tasks.append(
Task(
key_update_func,
k,
skip_on_error=skip_on_error,
silence_warnings=silence_warnings,
update_kwargs=key_update_kwargs,
)
)
key_indices.append(i)
outputs = execute(tasks, size=len(self.keys), keys=self.key_index, **execute_kwargs)
if return_raw:
return outputs
new_data = {}
new_last_index = {}
new_returned_kwargs = {}
i = 0
for k, obj in self.data.items():
if self.delisted.get(k, False):
out = None
else:
out = outputs[i]
i += 1
skip_key = False
if out is not None:
if isinstance(out, tuple):
new_obj = out[0]
new_returned_kwargs[k] = out[1]
else:
new_obj = out
new_returned_kwargs[k] = {}
new_obj = to_any_array(new_obj)
if new_obj.size == 0:
if not silence_warnings:
if self.feature_oriented:
warn(f"Feature '{str(k)}' returned an empty array. Skipping.")
else:
warn(f"Symbol '{str(k)}' returned an empty array. Skipping.")
skip_key = True
else:
if not isinstance(new_obj, (pd.Series, pd.DataFrame)):
new_obj = to_pd_array(new_obj)
new_obj.index = pd.RangeIndex(
start=obj.index[-1],
stop=obj.index[-1] + new_obj.shape[0],
step=1,
)
new_obj = self.prepare_tzaware_index(
new_obj,
tz_localize=self.tz_localize,
tz_convert=self.tz_convert,
)
if new_obj.index.is_monotonic_decreasing:
new_obj = new_obj.iloc[::-1]
elif not new_obj.index.is_monotonic_increasing:
new_obj = new_obj.sort_index()
if new_obj.index.has_duplicates:
new_obj = new_obj[~new_obj.index.duplicated(keep="last")]
new_data[k] = new_obj
if len(new_obj.index) > 0:
new_last_index[k] = new_obj.index[-1]
else:
new_last_index[k] = self.last_index[k]
else:
skip_key = True
if skip_key:
new_data[k] = obj.iloc[0:0]
new_last_index[k] = self.last_index[k]
# Get the last index in the old data from where the new data should begin
from_index = None
for k, new_obj in new_data.items():
if len(new_obj.index) > 0:
index = new_obj.index[0]
else:
continue
if from_index is None or index < from_index:
from_index = index
if from_index is None:
if not silence_warnings:
if self.feature_oriented:
warn(f"None of the features were updated")
else:
warn(f"None of the symbols were updated")
return self.copy()
# Concatenate the updated old data and the new data
for k, new_obj in new_data.items():
if len(new_obj.index) > 0:
to_index = new_obj.index[0]
else:
to_index = None
obj = self.data[k]
if isinstance(obj, pd.DataFrame) and isinstance(new_obj, pd.DataFrame):
shared_columns = obj.columns.intersection(new_obj.columns)
obj = obj[shared_columns]
new_obj = new_obj[shared_columns]
elif isinstance(new_obj, pd.DataFrame):
if obj.name is not None:
new_obj = new_obj[obj.name]
else:
new_obj = new_obj[0]
elif isinstance(obj, pd.DataFrame):
if new_obj.name is not None:
obj = obj[new_obj.name]
else:
obj = obj[0]
obj = obj.loc[from_index:to_index]
new_obj = pd.concat((obj, new_obj), axis=0)
if new_obj.index.has_duplicates:
new_obj = new_obj[~new_obj.index.duplicated(keep="last")]
new_data[k] = new_obj
# Align the index and columns in the new data
new_data = self.align_index(new_data, missing=self.missing_index, silence_warnings=silence_warnings)
new_data = self.align_columns(new_data, missing=self.missing_columns, silence_warnings=silence_warnings)
# Align the columns and data type in the old and new data
for k, new_obj in new_data.items():
obj = self.data[k]
if isinstance(obj, pd.DataFrame) and isinstance(new_obj, pd.DataFrame):
new_obj = new_obj[obj.columns]
elif isinstance(new_obj, pd.DataFrame):
if obj.name is not None:
new_obj = new_obj[obj.name]
else:
new_obj = new_obj[0]
if isinstance(obj, pd.DataFrame):
new_obj = new_obj.astype(obj.dtypes, errors="ignore")
else:
new_obj = new_obj.astype(obj.dtype, errors="ignore")
new_data[k] = new_obj
if not concat:
# Do not concatenate with the old data
for k, new_obj in new_data.items():
if isinstance(new_obj.index, pd.DatetimeIndex):
new_obj.index.freq = new_obj.index.inferred_freq
new_index = new_data[self.keys[0]].index
return self.replace(
wrapper=self.wrapper.replace(index=new_index),
data=self.dict_type(new_data),
returned_kwargs=self.dict_type(new_returned_kwargs),
last_index=self.dict_type(new_last_index),
)
# Append the new data to the old data
for k, new_obj in new_data.items():
obj = self.data[k]
obj = obj.loc[:from_index]
if obj.index[-1] == from_index:
obj = obj.iloc[:-1]
new_obj = pd.concat((obj, new_obj), axis=0)
if isinstance(new_obj.index, pd.DatetimeIndex):
new_obj.index.freq = new_obj.index.inferred_freq
new_data[k] = new_obj
new_index = new_data[self.keys[0]].index
return self.replace(
wrapper=self.wrapper.replace(index=new_index),
data=self.dict_type(new_data),
returned_kwargs=self.dict_type(new_returned_kwargs),
last_index=self.dict_type(new_last_index),
)
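# --- Editor's usage sketch (not part of the original module): updating ---
# `update` fetches new rows per feature/symbol and returns a NEW instance; with
# `concat=False` only the updated slice is kept. Assumes a data class that
# implements `update_symbol` and accepts `end`, e.g. the `YFData` instance
# pulled in the sketch above.
yf_updated = yf_data.update(end="2022-06-01")                # append new rows
yf_latest = yf_data.update(end="2022-06-01", concat=False)   # keep only the update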
# ############# Transforming ############# #
def transform(
self: DataT,
transform_func: tp.Callable,
*args,
per_feature: bool = False,
per_symbol: bool = False,
pass_frame: bool = False,
key_wrapper_kwargs: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> DataT:
"""Transform data.
If there is only one key (i.e., a single feature or symbol), passes the entire Series/DataFrame. If `per_feature` is True,
passes the Series/DataFrame of each feature. If `per_symbol` is True, passes the Series/DataFrame
of each symbol. If both are True, passes each feature and symbol combination as a Series
if `pass_frame` is False or as a DataFrame with one column if `pass_frame` is True.
If both are False, concatenates all features and symbols into a single DataFrame
and calls `transform_func` on it. Then, splits the data by key and builds a new `Data` instance.
Keyword arguments `key_wrapper_kwargs` are passed to `Data.get_key_wrapper` to control,
for example, attachment of classes.
After the transformation, the new data is aligned using `Data.align_data`. If the data is not
a Pandas object, it's broadcasted to the original data with `broadcast_kwargs`.
!!! note
The returned object must have the same type and dimensionality as the input object.
Number of columns (i.e., features and symbols) and their names must stay the same.
To remove columns, use either indexing or `Data.select` (depending on the data orientation).
To add new columns, use either column stacking or `Data.merge`.
Index, on the other hand, can be changed freely."""
if key_wrapper_kwargs is None:
key_wrapper_kwargs = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
def _transform(data, _template_context=None):
_transform_func = substitute_templates(transform_func, _template_context, eval_id="transform_func")
_args = substitute_templates(args, _template_context, eval_id="args")
_kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs")
out = _transform_func(data, *_args, **_kwargs)
if not isinstance(out, (pd.Series, pd.DataFrame)):
out = broadcast_to(out, data, **broadcast_kwargs)
return out
if (self.feature_oriented and (per_feature and not per_symbol)) or (
self.symbol_oriented and (per_symbol and not per_feature)
):
new_data = self.dict_type()
for k in self.keys:
if self.feature_oriented:
_template_context = merge_dicts(dict(key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(key=k, symbol=k), template_context)
new_data[k] = _transform(self.data[k], _template_context)
checks.assert_meta_equal(new_data[k], self.data[k], axis=1)
elif (self.feature_oriented and (per_symbol and not per_feature)) or (
self.symbol_oriented and (per_feature and not per_symbol)
):
first_data = self.data[list(self.data.keys())[0]]
if isinstance(first_data, pd.Series):
concat_data = pd.concat(self.data.values(), axis=1)
key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs)
concat_data.columns = key_wrapper.columns
if self.feature_oriented:
_template_context = merge_dicts(
dict(column=self.wrapper.columns[0], symbol=self.wrapper.columns[0]),
template_context,
)
else:
_template_context = merge_dicts(
dict(column=self.wrapper.columns[0], feature=self.wrapper.columns[0]),
template_context,
)
new_concat_data = _transform(concat_data, _template_context)
checks.assert_meta_equal(new_concat_data, concat_data, axis=1)
new_data = self.dict_type()
for i, k in enumerate(self.keys):
new_data[k] = new_concat_data.iloc[:, i]
new_data[k].name = first_data.name
else:
all_concat_data = []
for i in range(len(first_data.columns)):
concat_data = pd.concat([self.data[k].iloc[:, [i]] for k in self.keys], axis=1)
key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs)
concat_data.columns = key_wrapper.columns
if self.feature_oriented:
_template_context = merge_dicts(
dict(column=self.wrapper.columns[i], symbol=self.wrapper.columns[i]),
template_context,
)
else:
_template_context = merge_dicts(
dict(column=self.wrapper.columns[i], feature=self.wrapper.columns[i]),
template_context,
)
new_concat_data = _transform(concat_data, _template_context)
checks.assert_meta_equal(new_concat_data, concat_data, axis=1)
all_concat_data.append(new_concat_data)
new_data = self.dict_type()
for i, k in enumerate(self.keys):
new_objs = []
for c in range(len(first_data.columns)):
new_objs.append(all_concat_data[c].iloc[:, [i]])
new_data[k] = pd.concat(new_objs, axis=1)
new_data[k].columns = first_data.columns
else:
key_wrapper = self.get_key_wrapper(**key_wrapper_kwargs)
concat_data = pd.concat(self.data.values(), axis=1, keys=key_wrapper.columns)
if (self.feature_oriented and (per_feature and per_symbol)) or (
self.symbol_oriented and (per_symbol and per_feature)
):
new_concat_data = []
for i in range(len(concat_data.columns)):
if self.feature_oriented:
_template_context = merge_dicts(
dict(
key=self.keys[i // len(self.wrapper.columns)],
column=self.wrapper.columns[i % len(self.wrapper.columns)],
feature=self.keys[i // len(self.wrapper.columns)],
symbol=self.wrapper.columns[i % len(self.wrapper.columns)],
),
template_context,
)
else:
_template_context = merge_dicts(
dict(
key=self.wrapper.columns[i % len(self.wrapper.columns)],
column=self.keys[i // len(self.wrapper.columns)],
feature=self.wrapper.columns[i % len(self.wrapper.columns)],
symbol=self.keys[i // len(self.wrapper.columns)],
),
template_context,
)
if pass_frame:
new_obj = _transform(concat_data.iloc[:, [i]], _template_context)
checks.assert_meta_equal(new_obj, concat_data.iloc[:, [i]], axis=1)
else:
new_obj = _transform(concat_data.iloc[:, i], _template_context)
checks.assert_meta_equal(new_obj, concat_data.iloc[:, i], axis=1)
new_concat_data.append(new_obj)
new_concat_data = pd.concat(new_concat_data, axis=1)
else:
new_concat_data = _transform(concat_data)
checks.assert_meta_equal(new_concat_data, concat_data, axis=1)
native_concat_data = pd.concat(self.data.values(), axis=1, keys=None)
new_concat_data.columns = native_concat_data.columns
new_data = self.dict_type()
first_data = self.data[list(self.data.keys())[0]]
for i, k in enumerate(self.keys):
if isinstance(first_data, pd.Series):
new_data[k] = new_concat_data.iloc[:, i]
new_data[k].name = first_data.name
else:
start_i = first_data.shape[1] * i
stop_i = first_data.shape[1] * (1 + i)
new_data[k] = new_concat_data.iloc[:, start_i:stop_i]
new_data[k].columns = first_data.columns
new_data = self.align_data(new_data)
first_data = new_data[list(new_data.keys())[0]]
new_wrapper = self.wrapper.replace(index=first_data.index)
return self.replace(
wrapper=new_wrapper,
data=new_data,
)
def dropna(self: DataT, **kwargs) -> DataT:
"""Drop missing values.
Keyword arguments are passed to `Data.transform` and then to `pd.Series.dropna`
or `pd.DataFrame.dropna`."""
def _dropna(df, **_kwargs):
return df.dropna(**_kwargs)
return self.transform(_dropna, **kwargs)
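# --- Editor's usage sketch (not part of the original module): transforming ---
# By default `transform` passes one concatenated DataFrame; `per_feature`/`per_symbol`
# control granularity, and the result must keep the same columns. Continuing the
# synthetic two-symbol sketch above (editor's assumption: `data` defined there).
import numpy as np

log_data = data.transform(lambda df: np.log(df))  # element-wise, shape preserved
clean_data = data.dropna(how="any")               # forwarded to pandas dropna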
def mirror_ohlc(
self: DataT,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
start_value: tp.ArrayLike = np.nan,
ref_feature: tp.ArrayLike = -1,
) -> DataT:
"""Mirror OHLC features."""
if isinstance(ref_feature, str):
ref_feature = map_enum_fields(ref_feature, PriceFeature)
open_name = self.resolve_feature("Open")
high_name = self.resolve_feature("High")
low_name = self.resolve_feature("Low")
close_name = self.resolve_feature("Close")
func = jit_reg.resolve_option(mirror_ohlc_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_open, new_high, new_low, new_close = func(
self.symbol_wrapper.shape_2d,
open=to_2d_array(self.get_feature(open_name)) if open_name is not None else None,
high=to_2d_array(self.get_feature(high_name)) if high_name is not None else None,
low=to_2d_array(self.get_feature(low_name)) if low_name is not None else None,
close=to_2d_array(self.get_feature(close_name)) if close_name is not None else None,
start_value=to_1d_array(start_value),
ref_feature=to_1d_array(ref_feature),
)
def _mirror_ohlc(df, feature, **_kwargs):
if open_name is not None and feature == open_name:
return new_open
if high_name is not None and feature == high_name:
return new_high
if low_name is not None and feature == low_name:
return new_low
if close_name is not None and feature == close_name:
return new_close
return df
return self.transform(_mirror_ohlc, Rep("feature"), per_feature=True)
def resample(self: DataT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> DataT:
"""Perform resampling on `Data`.
Features "open", "high", "low", "close", "volume", "trade count", and "vwap" (case-insensitive)
are recognized and resampled automatically.
Looks for `resample_func` of each feature in `Data.feature_config`. The function must
accept the `Data` instance, object, and resampler."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
def _resample_feature(obj, feature, symbol=None):
resample_func = self.feature_config.get(feature, {}).get("resample_func", None)
if resample_func is not None:
if isinstance(resample_func, str):
return obj.vbt.resample_apply(wrapper_meta["resampler"], resample_func)
return resample_func(self, obj, wrapper_meta["resampler"])
if isinstance(feature, str) and feature.lower() == "open":
return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.first_reduce_nb)
if isinstance(feature, str) and feature.lower() == "high":
return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.max_reduce_nb)
if isinstance(feature, str) and feature.lower() == "low":
return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.min_reduce_nb)
if isinstance(feature, str) and feature.lower() == "close":
return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.last_reduce_nb)
if isinstance(feature, str) and feature.lower() == "volume":
return obj.vbt.resample_apply(wrapper_meta["resampler"], generic_nb.sum_reduce_nb)
if isinstance(feature, str) and feature.lower() == "trade count":
return obj.vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.sum_reduce_nb,
wrap_kwargs=dict(dtype=int),
)
if isinstance(feature, str) and feature.lower() == "vwap":
volume_obj = None
for feature2 in self.features:
if isinstance(feature2, str) and feature2.lower() == "volume":
if self.feature_oriented:
volume_obj = self.data[feature2]
else:
volume_obj = self.data[symbol][feature2]
if volume_obj is None:
raise ValueError("Volume is required to resample VWAP")
return pd.DataFrame.vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.wmean_range_reduce_meta_nb,
to_2d_array(obj),
to_2d_array(volume_obj),
wrapper=self.wrapper[feature],
)
raise ValueError(f"Cannot resample feature '{feature}'. Specify resample_func in feature_config.")
new_data = self.dict_type()
if self.feature_oriented:
for feature in self.features:
new_data[feature] = _resample_feature(self.data[feature], feature)
else:
for symbol, obj in self.data.items():
_new_obj = []
for feature in self.features:
if self.single_feature:
_new_obj.append(_resample_feature(obj, feature, symbol=symbol))
else:
_new_obj.append(_resample_feature(obj[[feature]], feature, symbol=symbol))
if self.single_feature:
new_obj = _new_obj[0]
else:
new_obj = pd.concat(_new_obj, axis=1)
new_data[symbol] = new_obj
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
data=new_data,
)
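# --- Editor's usage sketch (not part of the original module): resampling ---
# "Open"/"High"/"Low"/"Close"/"Volume" (case-insensitive) are aggregated with
# first/max/min/last/sum respectively; any other feature needs a `resample_func`
# entry in `Data.feature_config`. Continuing the synthetic sketch above.
two_day_data = data.resample("2d")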
def realign(
self: DataT,
rule: tp.Optional[tp.AnyRuleLike] = None,
*args,
wrapper_meta: tp.DictLike = None,
ffill: bool = True,
**kwargs,
) -> DataT:
"""Perform realigning on `Data`.
Looks for `realign_func` of each feature in `Data.feature_config`. If no function is provided,
realigns the feature "open" with `vectorbtpro.generic.accessors.GenericAccessor.realign_opening`
and other features with `vectorbtpro.generic.accessors.GenericAccessor.realign_closing`."""
if rule is None:
rule = self.wrapper.freq
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(rule, *args, **kwargs)
def _realign_feature(obj, feature, symbol=None):
realign_func = self.feature_config.get(feature, {}).get("realign_func", None)
if realign_func is not None:
if isinstance(realign_func, str):
return getattr(obj.vbt, realign_func)(wrapper_meta["resampler"], ffill=ffill)
return realign_func(self, obj, wrapper_meta["resampler"], ffill=ffill)
if isinstance(feature, str) and feature.lower() == "open":
return obj.vbt.realign_opening(wrapper_meta["resampler"], ffill=ffill)
return obj.vbt.realign_closing(wrapper_meta["resampler"], ffill=ffill)
new_data = self.dict_type()
if self.feature_oriented:
for feature in self.features:
new_data[feature] = _realign_feature(self.data[feature], feature)
else:
for symbol, obj in self.data.items():
_new_obj = []
for feature in self.features:
if self.single_feature:
_new_obj.append(_realign_feature(obj, feature, symbol=symbol))
else:
_new_obj.append(_realign_feature(obj[[feature]], feature, symbol=symbol))
if self.single_feature:
new_obj = _new_obj[0]
else:
new_obj = pd.concat(_new_obj, axis=1)
new_data[symbol] = new_obj
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
data=new_data,
)
# ############# Running ############# #
@classmethod
def try_run(
cls,
data: "Data",
func_name: str,
*args,
raise_errors: bool = False,
silence_warnings: bool = False,
**kwargs,
) -> tp.Any:
"""Try to run a function on data."""
try:
return data.run(*args, **kwargs)
except Exception as e:
if raise_errors:
raise e
if not silence_warnings:
warn(func_name + ": " + str(e))
return NoResult
@classmethod
def select_run_func_args(cls, i: int, func_name: str, args: tp.Args) -> tuple:
"""Select positional arguments that correspond to a runnable function index or name."""
_args = ()
for v in args:
if isinstance(v, run_func_dict):
if func_name in v:
_args += (v[func_name],)
elif i in v:
_args += (v[i],)
elif "_def" in v:
_args += (v["_def"],)
else:
_args += (v,)
return _args
@classmethod
def select_run_func_kwargs(cls, i: int, func_name: str, kwargs: tp.Kwargs) -> dict:
"""Select keyword arguments that correspond to a runnable function index or name."""
_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, run_func_dict):
if func_name in v:
_kwargs[k] = v[func_name]
elif i in v:
_kwargs[k] = v[i]
elif "_def" in v:
_kwargs[k] = v["_def"]
elif isinstance(v, run_arg_dict):
if func_name == k or i == k:
_kwargs.update(v)
else:
_kwargs[k] = v
return _kwargs
def run(
self,
func: tp.MaybeIterable[tp.Union[tp.Hashable, tp.Callable]],
*args,
on_features: tp.Optional[tp.MaybeFeatures] = None,
on_symbols: tp.Optional[tp.MaybeSymbols] = None,
func_args: tp.ArgsLike = None,
func_kwargs: tp.KwargsLike = None,
magnet_kwargs: tp.KwargsLike = None,
ignore_args: tp.Optional[tp.Sequence[str]] = None,
rename_args: tp.DictLike = None,
location: tp.Optional[str] = None,
prepend_location: tp.Optional[bool] = None,
unpack: tp.Union[bool, str] = False,
concat: bool = True,
data_kwargs: tp.KwargsLike = None,
silence_warnings: bool = False,
raise_errors: bool = False,
execute_kwargs: tp.KwargsLike = None,
filter_results: bool = True,
raise_no_results: bool = True,
merge_func: tp.MergeFuncLike = None,
merge_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
return_keys: bool = False,
_func_name: tp.Optional[str] = None,
**kwargs,
) -> tp.Any:
"""Run a function on data.
Looks into the signature of the function and searches for arguments with the name `data` or
those found among features or attributes.
For example, the argument `open` will be substituted by `Data.open`.
`func` can be one of the following:
* Location to compute all indicators from. See `vectorbtpro.indicators.factory.IndicatorFactory.list_locations`.
* Indicator name. See `vectorbtpro.indicators.factory.IndicatorFactory.get_indicator`.
* Simulation method. See `vectorbtpro.portfolio.base.Portfolio`.
* Any callable object
* Iterable with any of the above. Will be stacked as columns into a DataFrame.
Use `magnet_kwargs` to provide keyword arguments that will be passed only if found
in the signature of the function.
Use `rename_args` to rename arguments. For example, in `vectorbtpro.portfolio.base.Portfolio`,
data can be passed instead of `close`.
Set `unpack` to True, "dict", or "frame" to use
`vectorbtpro.indicators.factory.IndicatorBase.unpack`,
`vectorbtpro.indicators.factory.IndicatorBase.to_dict`, and
`vectorbtpro.indicators.factory.IndicatorBase.to_frame` respectively.
Any argument in `*args` and `**kwargs` can be wrapped with `run_func_dict`/`run_arg_dict`
to specify the value per function/argument name or index when `func` is iterable.
Multiple function calls are executed with `vectorbtpro.utils.execution.execute`."""
from vectorbtpro.indicators.factory import IndicatorBase, IndicatorFactory
from vectorbtpro.indicators.talib_ import talib_func
from vectorbtpro.portfolio.base import Portfolio
if magnet_kwargs is None:
magnet_kwargs = {}
if data_kwargs is None:
data_kwargs = {}
if execute_kwargs is None:
execute_kwargs = {}
if merge_kwargs is None:
merge_kwargs = {}
_self = self
if on_features is not None:
_self = _self.select_features(on_features)
if on_symbols is not None:
_self = _self.select_symbols(on_symbols)
if checks.is_complex_iterable(func):
tasks = []
keys = []
for i, f in enumerate(func):
_location = location
if callable(f):
func_name = f.__name__
elif isinstance(f, str):
if _location is not None:
func_name = f.lower().strip()
if func_name == "*":
func_name = "all"
if prepend_location is True:
func_name = _location + "_" + func_name
else:
_location, f = IndicatorFactory.split_indicator_name(f)
if f is None:
raise ValueError("Sequence of locations is not supported")
func_name = f.lower().strip()
if func_name == "*":
func_name = "all"
if _location is not None:
if prepend_location in (None, True):
func_name = _location + "_" + func_name
else:
func_name = f
new_args = _self.select_run_func_args(i, func_name, args)
new_args = (_self, func_name, f, *new_args)
new_kwargs = _self.select_run_func_kwargs(i, func_name, kwargs)
if concat and _location == "talib_func":
new_kwargs["unpack_to"] = "frame"
new_kwargs = {
**dict(
func_args=func_args,
func_kwargs=func_kwargs,
magnet_kwargs=magnet_kwargs,
ignore_args=ignore_args,
rename_args=rename_args,
location=_location,
prepend_location=prepend_location,
unpack="frame" if concat else unpack,
concat=concat,
data_kwargs=data_kwargs,
silence_warnings=silence_warnings,
raise_errors=raise_errors,
execute_kwargs=execute_kwargs,
merge_func=merge_func,
merge_kwargs=merge_kwargs,
template_context=template_context,
return_keys=return_keys,
_func_name=func_name,
),
**new_kwargs,
}
tasks.append(Task(self.try_run, *new_args, **new_kwargs))
keys.append(str(func_name))
keys = pd.Index(keys, name="run_func")
results = execute(tasks, size=len(keys), keys=keys, **execute_kwargs)
if filter_results:
try:
results, keys = filter_out_no_results(results, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
no_results_filtered = True
else:
no_results_filtered = False
if merge_func is None and concat:
merge_func = "column_stack"
if merge_func is not None:
if is_merge_func_from_config(merge_func):
merge_kwargs = merge_dicts(
dict(
keys=keys,
filter_results=not no_results_filtered,
raise_no_results=raise_no_results,
),
merge_kwargs,
)
if isinstance(merge_func, MergeFunc):
merge_func = merge_func.replace(merge_kwargs=merge_kwargs, context=template_context)
else:
merge_func = MergeFunc(merge_func, merge_kwargs=merge_kwargs, context=template_context)
if return_keys:
return merge_func(results), keys
return merge_func(results)
if return_keys:
return results, keys
return results
if isinstance(func, str):
func_name = func.lower().strip()
if func_name.startswith("from_") and getattr(Portfolio, func_name):
func = getattr(Portfolio, func_name)
if func_args is None:
func_args = ()
if func_kwargs is None:
func_kwargs = {}
pf = func(_self, *args, *func_args, **kwargs, **func_kwargs)
if isinstance(pf, Portfolio) and unpack:
raise ValueError("Portfolio cannot be unpacked")
return pf
if location is None:
location, func_name = IndicatorFactory.split_indicator_name(func_name)
if location is not None and (func_name is None or func_name == "all" or func_name == "*"):
matched_location = IndicatorFactory.match_location(location)
if matched_location is not None:
location = matched_location
if func_name == "all" or func_name == "*":
if prepend_location is None:
prepend_location = True
else:
if prepend_location is None:
prepend_location = False
if location == "talib_func":
indicators = IndicatorFactory.list_indicators("talib", prepend_location=False)
else:
indicators = IndicatorFactory.list_indicators(location, prepend_location=False)
return _self.run(
indicators,
*args,
func_args=func_args,
func_kwargs=func_kwargs,
magnet_kwargs=magnet_kwargs,
ignore_args=ignore_args,
rename_args=rename_args,
location=location,
prepend_location=prepend_location,
unpack=unpack,
concat=concat,
data_kwargs=data_kwargs,
silence_warnings=silence_warnings,
raise_errors=raise_errors,
execute_kwargs=execute_kwargs,
merge_func=merge_func,
merge_kwargs=merge_kwargs,
template_context=template_context,
return_keys=return_keys,
**kwargs,
)
if location is not None:
matched_location = IndicatorFactory.match_location(location)
if matched_location is not None:
location = matched_location
if location == "talib_func":
func = talib_func(func_name)
else:
func = IndicatorFactory.get_indicator(func_name, location=location)
else:
func = IndicatorFactory.get_indicator(func_name)
if isinstance(func, type) and issubclass(func, IndicatorBase):
func = func.run
with_kwargs = {}
func_arg_names = get_func_arg_names(func)
for arg_name in func_arg_names:
real_arg_name = arg_name
if ignore_args is not None:
if arg_name in ignore_args:
continue
if rename_args is not None:
if arg_name in rename_args:
arg_name = rename_args[arg_name]
if real_arg_name not in kwargs:
if arg_name == "data":
with_kwargs[real_arg_name] = _self
elif arg_name == "wrapper":
with_kwargs[real_arg_name] = _self.symbol_wrapper
elif arg_name in ("input_shape", "shape"):
with_kwargs[real_arg_name] = _self.shape
elif arg_name in ("target_shape", "shape_2d"):
with_kwargs[real_arg_name] = _self.shape_2d
elif arg_name in ("input_index", "index"):
with_kwargs[real_arg_name] = _self.index
elif arg_name in ("input_columns", "columns"):
with_kwargs[real_arg_name] = _self.columns
elif arg_name == "freq":
with_kwargs[real_arg_name] = _self.freq
elif arg_name == "hlc3":
with_kwargs[real_arg_name] = _self.hlc3
elif arg_name == "ohlc4":
with_kwargs[real_arg_name] = _self.ohlc4
elif arg_name == "returns":
with_kwargs[real_arg_name] = _self.returns
else:
feature_idx = _self.get_feature_idx(arg_name)
if feature_idx != -1:
with_kwargs[real_arg_name] = _self.get_feature(feature_idx)
kwargs = dict(kwargs)
for k, v in magnet_kwargs.items():
if k in func_arg_names:
kwargs[k] = v
new_args, new_kwargs = extend_args(func, args, kwargs, **with_kwargs)
if func_args is None:
func_args = ()
if func_kwargs is None:
func_kwargs = {}
out = func(*new_args, *func_args, **new_kwargs, **func_kwargs)
if isinstance(unpack, bool):
if unpack:
if isinstance(out, IndicatorBase):
out = out.unpack()
elif isinstance(unpack, str) and unpack.lower() == "dict":
if isinstance(out, IndicatorBase):
out = out.to_dict()
else:
if _func_name is None:
feature_name = func.__name__
else:
feature_name = _func_name
out = {feature_name: out}
elif isinstance(unpack, str) and unpack.lower() == "frame":
if isinstance(out, IndicatorBase):
out = out.to_frame()
elif isinstance(out, pd.Series):
out = out.to_frame()
elif isinstance(unpack, str) and unpack.lower() == "data":
if isinstance(out, IndicatorBase):
out = feature_dict(out.to_dict())
else:
if _func_name is None:
feature_name = func.__name__
else:
feature_name = _func_name
out = feature_dict({feature_name: out})
out = Data.from_data(out, **data_kwargs)
else:
raise ValueError(f"Invalid unpack: '{unpack}'")
return out
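# Editor's usage sketch for `Data.run` (hypothetical; "ma", "rsi", and their `window`
# argument are assumptions about the indicators registered in this installation):
import vectorbtpro as vbt

data = vbt.YFData.pull(["BTC-USD", "ETH-USD"], start="2021-01-01 UTC", end="2021-06-01 UTC")
ma = data.run("ma", window=20)       # single indicator; `close` is filled from the data instance
both = data.run(["ma", "rsi"])       # iterable of indicators, column-stacked into one DataFrame
pf = data.run("from_holding")        # "from_*" strings dispatch to Portfolio class methods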
# ############# Persisting ############# #
def resolve_key_arg(
self,
arg: tp.Any,
k: tp.Key,
arg_name: str,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
is_kwargs: bool = False,
) -> tp.Any:
"""Resolve argument."""
if check_dict_type:
self.check_dict_type(arg, arg_name=arg_name)
if isinstance(arg, key_dict):
_arg = arg[k]
else:
if is_kwargs:
_arg = self.select_key_kwargs(k, arg, check_dict_type=check_dict_type)
else:
_arg = arg
if isinstance(_arg, CustomTemplate):
_arg = _arg.substitute(template_context, eval_id=arg_name)
elif is_kwargs:
_arg = substitute_templates(_arg, template_context, eval_id=arg_name)
return _arg
def to_csv(
self,
path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
ext: tp.Union[str, feature_dict, symbol_dict, CustomTemplate] = "csv",
mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
**kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
"""Save data to CSV file(s).
Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_csv.html
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
If `path_or_buf` is a path to a directory, will save each feature/symbol to a separate file.
If there's only one file, you can specify the file path via `path_or_buf`. If there are
multiple files, use the same argument but wrap the multiple paths with `key_dict`."""
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
sep = _kwargs.pop("sep", None)
_path_or_buf = self.resolve_key_arg(
path_or_buf,
k,
"path_or_buf",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if isinstance(_path_or_buf, str):
_path_or_buf = Path(_path_or_buf)
if isinstance(_path_or_buf, Path):
if (_path_or_buf.exists() and _path_or_buf.is_dir()) or _path_or_buf.suffix == "":
_ext = self.resolve_key_arg(
ext,
k,
"ext",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_path_or_buf /= f"{k}.{_ext}"
_mkdir_kwargs = self.resolve_key_arg(
mkdir_kwargs,
k,
"mkdir_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
check_mkdir(_path_or_buf.parent, **_mkdir_kwargs)
if _path_or_buf.suffix.lower() == ".csv":
if sep is None:
sep = ","
if _path_or_buf.suffix.lower() == ".tsv":
if sep is None:
sep = "\t"
_path_or_buf = str(_path_or_buf)
if sep is None:
sep = ","
meta[k] = {"path_or_buf": _path_or_buf, "sep": sep, **_kwargs}
v.to_csv(**meta[k])
if return_meta:
return meta
return None
@classmethod
def from_csv(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.csv.CSVData` to load data from CSV and switch the class back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.csv import CSVData
if fetch_kwargs is None:
fetch_kwargs = {}
data = CSVData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
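# Editor's round-trip sketch for CSV persistence (paths are hypothetical); with a directory
# as `path_or_buf`, every feature/symbol is written to its own "<key>.csv" file:
data.to_csv(path_or_buf="csv_data")                          # e.g. csv_data/BTC-USD.csv
data_back = vbt.YFData.from_csv("csv_data/BTC-USD.csv")      # loads via CSVData.pull, keeps the class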
def to_hdf(
self,
path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
key: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
format: str = "table",
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
**kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
"""Save data to an HDF file using PyTables.
Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_hdf.html
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
If `path_or_buf` exists and it's a directory, will create inside it a file named after this class."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("tables")
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_path_or_buf = self.resolve_key_arg(
path_or_buf,
k,
"path_or_buf",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if isinstance(_path_or_buf, str):
_path_or_buf = Path(_path_or_buf)
if isinstance(_path_or_buf, Path):
if (_path_or_buf.exists() and _path_or_buf.is_dir()) or _path_or_buf.suffix == "":
_path_or_buf /= type(self).__name__ + ".h5"
_mkdir_kwargs = self.resolve_key_arg(
mkdir_kwargs,
k,
"mkdir_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
check_mkdir(_path_or_buf.parent, **_mkdir_kwargs)
_path_or_buf = str(_path_or_buf)
if key is None:
_key = str(k)
else:
_key = self.resolve_key_arg(
key,
k,
"key",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
meta[k] = {"path_or_buf": _path_or_buf, "key": _key, "format": format, **_kwargs}
v.to_hdf(**meta[k])
if return_meta:
return meta
return None
@classmethod
def from_hdf(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.hdf.HDFData` to load data from HDF and switch the class back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.hdf import HDFData
if fetch_kwargs is None:
fetch_kwargs = {}
data = HDFData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
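# Editor's sketch for HDF persistence (hypothetical path); with a directory as `path_or_buf`,
# a single "<ClassName>.h5" file is created and each feature/symbol becomes a key inside it:
data.to_hdf(path_or_buf="hdf_data")                          # e.g. hdf_data/YFData.h5, key "BTC-USD"
data_back = vbt.YFData.from_hdf("hdf_data/YFData.h5")        # loads via HDFData.pull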
def to_feather(
self,
path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
**kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
"""Save data to Feather file(s) using PyArrow.
Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_feather.html
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
If `path_or_buf` is a path to a directory, will save each feature/symbol to a separate file.
If there's only one file, you can specify the file path via `path_or_buf`. If there are
multiple files, use the same argument but wrap the multiple paths with `key_dict`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pyarrow")
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_path_or_buf = self.resolve_key_arg(
path_or_buf,
k,
"path_or_buf",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if isinstance(_path_or_buf, str):
_path_or_buf = Path(_path_or_buf)
if isinstance(_path_or_buf, Path):
if (_path_or_buf.exists() and _path_or_buf.is_dir()) or _path_or_buf.suffix == "":
_path_or_buf /= f"{k}.feather"
_mkdir_kwargs = self.resolve_key_arg(
mkdir_kwargs,
k,
"mkdir_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
check_mkdir(_path_or_buf.parent, **_mkdir_kwargs)
_path_or_buf = str(_path_or_buf)
_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
meta[k] = {"path": _path_or_buf, **_kwargs}
if isinstance(v, pd.Series):
v = v.to_frame()
try:
v.to_feather(**meta[k])
except Exception as e:
if isinstance(e, ValueError) and "you can .reset_index()" in str(e):
v = v.reset_index()
v.to_feather(**meta[k])
else:
raise e
if return_meta:
return meta
return None
@classmethod
def from_feather(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.feather.FeatherData` to load data from Feather and
switch the class back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.feather import FeatherData
if fetch_kwargs is None:
fetch_kwargs = {}
data = FeatherData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
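# Editor's sketch for Feather persistence (hypothetical paths); note that Series are converted
# to DataFrames and non-serializable indexes are reset automatically before writing:
data.to_feather(path_or_buf="feather_data")                        # e.g. feather_data/BTC-USD.feather
data_back = vbt.YFData.from_feather("feather_data/BTC-USD.feather")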
def to_parquet(
self,
path_or_buf: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
partition_cols: tp.Union[None, tp.List[str], feature_dict, symbol_dict, CustomTemplate] = None,
partition_by: tp.Union[None, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = None,
period_index_to: tp.Union[str, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = "str",
groupby_kwargs: tp.Union[None, tp.AnyGroupByLike, feature_dict, symbol_dict, CustomTemplate] = None,
keep_groupby_names: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = False,
engine: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
**kwargs,
) -> tp.Union[None, feature_dict, symbol_dict]:
"""Save data to Parquet file(s) using PyArrow or FastParquet.
Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_parquet.html
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
If `path_or_buf` is a path to a directory, will save each feature/symbol to a separate file.
If there's only one file, you can specify the file path via `path_or_buf`. If there are
multiple files, use the same argument but wrap the multiple paths with `key_dict`.
If `partition_cols` and `partition_by` are None, `path_or_buf` must be a file, otherwise
it must be a directory. If `partition_by` is not None, will group the index by using
`vectorbtpro.base.wrapping.ArrayWrapper.get_index_grouper` with `**groupby_kwargs` and
put it inside `partition_cols`. In this case, `partition_cols` must be None."""
from vectorbtpro.utils.module_ import assert_can_import, assert_can_import_any
from vectorbtpro.data.custom.parquet import ParquetData
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_partition_cols = self.resolve_key_arg(
partition_cols,
k,
"partition_cols",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_partition_by = self.resolve_key_arg(
partition_by,
k,
"partition_by",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if _partition_cols is not None and _partition_by is not None:
raise ValueError("Must use either partition_cols or partition_by, not both")
_path_or_buf = self.resolve_key_arg(
path_or_buf,
k,
"path_or_buf",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if isinstance(_path_or_buf, str):
_path_or_buf = Path(_path_or_buf)
if isinstance(_path_or_buf, Path):
if (_path_or_buf.exists() and _path_or_buf.is_dir()) or _path_or_buf.suffix == "":
if _partition_cols is not None or _partition_by is not None:
_path_or_buf /= f"{k}"
else:
_path_or_buf /= f"{k}.parquet"
_mkdir_kwargs = self.resolve_key_arg(
mkdir_kwargs,
k,
"mkdir_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
check_mkdir(_path_or_buf.parent, **_mkdir_kwargs)
_path_or_buf = str(_path_or_buf)
_engine = self.resolve_key_arg(
ParquetData.resolve_custom_setting(engine, "engine"),
k,
"engine",
check_dict_type=check_dict_type,
template_context=_template_context,
)
if _engine == "pyarrow":
assert_can_import("pyarrow")
elif _engine == "fastparquet":
assert_can_import("fastparquet")
elif _engine == "auto":
assert_can_import_any("pyarrow", "fastparquet")
else:
raise ValueError(f"Invalid engine: '{_engine}'")
if isinstance(v, pd.Series):
v = v.to_frame()
if _partition_by is not None:
_period_index_to = self.resolve_key_arg(
period_index_to,
k,
"period_index_to",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_groupby_kwargs = self.resolve_key_arg(
groupby_kwargs,
k,
"groupby_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
_keep_groupby_names = self.resolve_key_arg(
keep_groupby_names,
k,
"keep_groupby_names",
check_dict_type=check_dict_type,
template_context=_template_context,
)
v = v.copy(deep=False)
grouper = self.wrapper.get_index_grouper(_partition_by, **_groupby_kwargs)
partition_index = grouper.get_stretched_index()
_partition_cols = []
def _convert_period_index(index):
if _period_index_to == "str":
return index.map(str)
return index.to_timestamp(how=_period_index_to)
if isinstance(partition_index, pd.MultiIndex):
for i in range(partition_index.nlevels):
partition_level = partition_index.get_level_values(i)
if _keep_groupby_names:
new_column_name = partition_level.name
else:
new_column_name = f"group_{i}"
if isinstance(partition_level, pd.PeriodIndex):
partition_level = _convert_period_index(partition_level)
v[new_column_name] = partition_level
_partition_cols.append(new_column_name)
else:
if _keep_groupby_names:
new_column_name = partition_index.name
else:
new_column_name = "group"
if isinstance(partition_index, pd.PeriodIndex):
partition_index = _convert_period_index(partition_index)
v[new_column_name] = partition_index
_partition_cols.append(new_column_name)
_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
meta[k] = {"path": _path_or_buf, "partition_cols": _partition_cols, "engine": _engine, **_kwargs}
v.to_parquet(**meta[k])
if return_meta:
return meta
return None
@classmethod
def from_parquet(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.parquet.ParquetData` to load data from Parquet and
switch the class back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.parquet import ParquetData
if fetch_kwargs is None:
fetch_kwargs = {}
data = ParquetData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
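# Editor's sketch for partitioned Parquet persistence (hypothetical path; grouping the index
# by year via `partition_by="Y"` is an assumption about what `get_index_grouper` accepts):
data.to_parquet(path_or_buf="parquet_data", partition_by="Y")   # writes parquet_data/<key>/group=.../
data_back = vbt.YFData.from_parquet("parquet_data/BTC-USD")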
def to_sql(
self,
engine: tp.Union[None, str, EngineT, feature_dict, symbol_dict, CustomTemplate] = None,
table: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
schema: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str], feature_dict, symbol_dict, CustomTemplate] = None,
remove_utc_tz: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = True,
attach_row_number: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = False,
from_row_number: tp.Union[None, int, feature_dict, symbol_dict, CustomTemplate] = None,
row_number_column: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
engine_config: tp.KwargsLike = None,
dispose_engine: tp.Optional[bool] = None,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
return_engine: bool = False,
**kwargs,
) -> tp.Union[None, feature_dict, symbol_dict, EngineT]:
"""Save data to a SQL database using SQLAlchemy.
Uses https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
Each feature/symbol gets saved to a separate table.
If `engine` is None or a string, will resolve an engine with
`vectorbtpro.data.custom.sql.SQLData.resolve_engine` and dispose it afterward if `dispose_engine`
is None or True. It can additionally return the engine if `return_engine` is True, or the entire
metadata (all passed arguments as `feature_dict` or `symbol_dict`) if `return_meta` is True.
In either case, the engine won't be disposed by default.
If `schema` is not None and it doesn't exist, will create a new schema.
For `to_utc` and `remove_utc_tz`, see `Data.prepare_dt`. If `to_utc` is None, uses the
corresponding setting of `vectorbtpro.data.custom.sql.SQLData`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("sqlalchemy")
from vectorbtpro.data.custom.sql import SQLData
if engine_config is None:
engine_config = {}
if (engine is None or isinstance(engine, str)) and not self.has_key_dict(engine_config):
engine_meta = SQLData.resolve_engine(
engine=engine,
return_meta=True,
**engine_config,
)
engine = engine_meta["engine"]
engine_name = engine_meta["engine_name"]
should_dispose = engine_meta["should_dispose"]
if dispose_engine is None:
if return_meta or return_engine:
dispose_engine = False
else:
dispose_engine = should_dispose
else:
engine_name = None
if return_engine:
raise ValueError("Engine can be returned only if URL was provided")
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_engine = self.resolve_key_arg(
engine,
k,
"engine",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_engine_config = self.resolve_key_arg(
engine_config,
k,
"engine_config",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
if _engine is None or isinstance(_engine, str):
_engine_meta = SQLData.resolve_engine(
engine=_engine,
return_meta=True,
**_engine_config,
)
_engine = _engine_meta["engine"]
_engine_name = _engine_meta["engine_name"]
_should_dispose = _engine_meta["should_dispose"]
if dispose_engine is None:
if return_meta or return_engine:
_dispose_engine = False
else:
_dispose_engine = _should_dispose
else:
_dispose_engine = dispose_engine
else:
_engine_name = engine_name
if dispose_engine is None:
_dispose_engine = False
else:
_dispose_engine = dispose_engine
if table is None:
_table = k
else:
_table = self.resolve_key_arg(
table,
k,
"table",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_schema = self.resolve_key_arg(
SQLData.resolve_engine_setting(schema, "schema", engine_name=_engine_name),
k,
"schema",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_to_utc = self.resolve_key_arg(
SQLData.resolve_engine_setting(to_utc, "to_utc", engine_name=_engine_name),
k,
"to_utc",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_remove_utc_tz = self.resolve_key_arg(
remove_utc_tz,
k,
"remove_utc_tz",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_attach_row_number = self.resolve_key_arg(
attach_row_number,
k,
"attach_row_number",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_from_row_number = self.resolve_key_arg(
from_row_number,
k,
"from_row_number",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_row_number_column = self.resolve_key_arg(
SQLData.resolve_engine_setting(row_number_column, "row_number_column", engine_name=_engine_name),
k,
"row_number_column",
check_dict_type=check_dict_type,
template_context=_template_context,
)
v = SQLData.prepare_dt(v, to_utc=_to_utc, remove_utc_tz=_remove_utc_tz, parse_dates=False)
_kwargs = self.select_key_kwargs(k, kwargs, check_dict_type=check_dict_type)
if _attach_row_number:
v = v.copy(deep=False)
if isinstance(v, pd.Series):
v = v.to_frame()
if _from_row_number is None:
if not SQLData.has_table(_table, schema=_schema, engine=_engine):
_from_row_number = 0
elif _kwargs.get("if_exists", "fail") != "append":
_from_row_number = 0
else:
last_row_number = SQLData.get_last_row_number(
_table,
schema=_schema,
row_number_column=_row_number_column,
engine=_engine,
)
_from_row_number = last_row_number + 1
v[_row_number_column] = np.arange(_from_row_number, _from_row_number + len(v.index))
if _schema is not None:
SQLData.create_schema(_schema, engine=_engine)
meta[k] = {"name": _table, "con": _engine, "schema": _schema, **_kwargs}
v.to_sql(**meta[k])
if _dispose_engine:
_engine.dispose()
if return_meta:
return meta
if return_engine:
return engine
return None
@classmethod
def from_sql(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.sql.SQLData` to load data from a SQL database and switch the class
back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.sql import SQLData
if fetch_kwargs is None:
fetch_kwargs = {}
data = SQLData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
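# Editor's sketch for SQL persistence (hypothetical SQLite URL; passing the table name as the
# first argument of `from_sql` is an assumption about `SQLData.pull`):
data.to_sql(engine="sqlite:///data.db", if_exists="replace")          # one table per feature/symbol
data_back = vbt.YFData.from_sql("BTC-USD", engine="sqlite:///data.db")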
def to_duckdb(
self,
connection: tp.Union[None, str, DuckDBPyConnectionT, feature_dict, symbol_dict, CustomTemplate] = None,
table: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
schema: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
catalog: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
write_format: tp.Union[None, str, feature_dict, symbol_dict, CustomTemplate] = None,
write_path: tp.Union[tp.PathLike, feature_dict, symbol_dict, CustomTemplate] = ".",
write_options: tp.Union[None, str, dict, feature_dict, symbol_dict, CustomTemplate] = None,
mkdir_kwargs: tp.Union[tp.KwargsLike, feature_dict, symbol_dict, CustomTemplate] = None,
to_utc: tp.Union[None, bool, str, tp.Sequence[str], feature_dict, symbol_dict, CustomTemplate] = None,
remove_utc_tz: tp.Union[bool, feature_dict, symbol_dict, CustomTemplate] = True,
if_exists: tp.Union[str, feature_dict, symbol_dict, CustomTemplate] = "fail",
connection_config: tp.KwargsLike = None,
check_dict_type: bool = True,
template_context: tp.KwargsLike = None,
return_meta: bool = False,
return_connection: bool = False,
) -> tp.Union[None, feature_dict, symbol_dict, DuckDBPyConnectionT]:
"""Save data to a DuckDB database.
Any argument can be provided per feature using `feature_dict` or per symbol using `symbol_dict`,
depending on the format of the data dictionary.
If `connection` is None or a string, will resolve a connection with
`vectorbtpro.data.custom.duckdb.DuckDBData.resolve_connection`. It can additionally return the
connection if `return_connection` is True, or the entire metadata (all passed arguments as
`feature_dict` or `symbol_dict`) if `return_meta` is True. In either case, the connection won't be closed by default.
If `write_format` is None and `write_path` is a directory (default), will persist each feature/symbol
to a table (see https://duckdb.org/docs/guides/python/import_pandas).
If `catalog` is not None, will make it default for this connection. If `schema` is not None,
and it doesn't exist, will create a new schema in the current catalog and make it default
for this connection. Any new table will be automatically created under this schema.
If `if_exists` is "fail", will raise an error if a table with the same name already exists.
If `if_exists` is "replace", will drop the existing table first. If `if_exists` is "append",
will append the new data to the existing table.
If `write_format` is not None, it must be either "csv", "parquet", or "json". If `write_path` is
a directory or has no suffix (meaning it's not a file), each feature/symbol will be saved to a
separate file under that path and with the provided `write_format` as extension. The data will be
saved using a `COPY` mechanism (see https://duckdb.org/docs/sql/statements/copy.html).
To provide options for the write operation, pass them as a dictionary or an already formatted
string (without brackets). For example, `dict(compression="gzip")` is the same as "COMPRESSION 'gzip'".
For `to_utc` and `remove_utc_tz`, see `Data.prepare_dt`. If `to_utc` is None, uses the
corresponding setting of `vectorbtpro.data.custom.duckdb.DuckDBData`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from vectorbtpro.data.custom.duckdb import DuckDBData
if connection_config is None:
connection_config = {}
if (connection is None or isinstance(connection, (str, Path))) and not self.has_key_dict(connection_config):
connection_meta = DuckDBData.resolve_connection(
connection=connection,
read_only=False,
return_meta=True,
**connection_config,
)
connection = connection_meta["connection"]
if return_meta or return_connection:
should_close = False
else:
should_close = connection_meta["should_close"]
elif return_connection:
raise ValueError("Connection can be returned only if URL was provided")
else:
should_close = False
meta = self.dict_type()
for k, v in self.data.items():
if self.feature_oriented:
_template_context = merge_dicts(dict(data=v, key=k, feature=k), template_context)
else:
_template_context = merge_dicts(dict(data=v, key=k, symbol=k), template_context)
_connection = self.resolve_key_arg(
connection,
k,
"connection",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_connection_config = self.resolve_key_arg(
connection_config,
k,
"connection_config",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
if _connection is None or isinstance(_connection, (str, Path)):
_connection_meta = DuckDBData.resolve_connection(
connection=_connection,
read_only=False,
return_meta=True,
**_connection_config,
)
_connection = _connection_meta["connection"]
_should_close = _connection_meta["should_close"]
else:
_should_close = False
if table is None:
_table = k
else:
_table = self.resolve_key_arg(
table,
k,
"table",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_schema = self.resolve_key_arg(
DuckDBData.resolve_custom_setting(schema, "schema"),
k,
"schema",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_catalog = self.resolve_key_arg(
DuckDBData.resolve_custom_setting(catalog, "catalog"),
k,
"catalog",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_write_format = self.resolve_key_arg(
write_format,
k,
"write_format",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_write_path = self.resolve_key_arg(
write_path,
k,
"write_path",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_write_path = Path(_write_path)
is_not_file = (_write_path.exists() and _write_path.is_dir()) or _write_path.suffix == ""
if _write_format is not None and is_not_file:
if _write_format.upper() == "CSV":
_write_path /= f"{k}.csv"
elif _write_format.upper() == "PARQUET":
_write_path /= f"{k}.parquet"
elif _write_format.upper() == "JSON":
_write_path /= f"{k}.json"
else:
raise ValueError(f"Invalid write format: '{_write_format}'")
if _write_path.suffix != "":
_mkdir_kwargs = self.resolve_key_arg(
mkdir_kwargs,
k,
"mkdir_kwargs",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=True,
)
check_mkdir(_write_path.parent, **_mkdir_kwargs)
_write_path = str(_write_path)
use_write = True
else:
use_write = False
_to_utc = self.resolve_key_arg(
DuckDBData.resolve_custom_setting(to_utc, "to_utc"),
k,
"to_utc",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_remove_utc_tz = self.resolve_key_arg(
remove_utc_tz,
k,
"remove_utc_tz",
check_dict_type=check_dict_type,
template_context=_template_context,
)
_if_exists = self.resolve_key_arg(
if_exists,
k,
"if_exists",
check_dict_type=check_dict_type,
template_context=_template_context,
)
v = DuckDBData.prepare_dt(v, to_utc=_to_utc, remove_utc_tz=_remove_utc_tz, parse_dates=False)
v = v.reset_index()
if use_write:
_write_options = self.resolve_key_arg(
write_options,
k,
"write_options",
check_dict_type=check_dict_type,
template_context=_template_context,
is_kwargs=isinstance(write_options, dict),
)
if _write_options is not None:
_write_options = DuckDBData.format_write_options(_write_options)
if _write_format is not None and _write_options is not None and "FORMAT" not in _write_options:
_write_options = f"FORMAT {_write_format.upper()}, " + _write_options
elif _write_format is not None and _write_options is None:
_write_options = f"FORMAT {_write_format.upper()}"
_connection.register("_" + k, v)
if _write_options is not None:
_connection.sql(f"COPY (SELECT * FROM \"_{k}\") TO '{_write_path}' ({_write_options})")
else:
_connection.sql(f"COPY (SELECT * FROM \"_{k}\") TO '{_write_path}'")
meta[k] = {"write_path": _write_path, "write_options": _write_options}
else:
if _catalog is not None:
_connection.sql(f"USE {_catalog}")
elif _schema is not None:
catalogs = DuckDBData.list_catalogs(connection=_connection)
if len(catalogs) > 1:
raise ValueError("Please select a catalog")
_catalog = catalogs[0]
_connection.sql(f"USE {_catalog}")
if _schema is not None:
_connection.sql(f'CREATE SCHEMA IF NOT EXISTS "{_schema}"')
_connection.sql(f"USE {_catalog}.{_schema}")
append = False
if _table in DuckDBData.list_tables(catalog=_catalog, schema=_schema, connection=_connection):
if _if_exists.lower() == "fail":
raise ValueError(f"Table '{_table}' already exists")
elif _if_exists.lower() == "replace":
_connection.sql(f'DROP TABLE "{_table}"')
elif _if_exists.lower() == "append":
append = True
_connection.register("_" + k, v)
if append:
_connection.sql(f'INSERT INTO "{_table}" SELECT * FROM "_{k}"')
else:
_connection.sql(f'CREATE TABLE "{_table}" AS SELECT * FROM "_{k}"')
meta[k] = {"table": _table, "schema": _schema, "catalog": _catalog}
if _should_close:
_connection.close()
if should_close:
connection.close()
if return_meta:
return meta
if return_connection:
return connection
return None
@classmethod
def from_duckdb(cls: tp.Type[DataT], *args, fetch_kwargs: tp.KwargsLike = None, **kwargs) -> DataT:
"""Use `vectorbtpro.data.custom.duckdb.DuckDBData` to load data from a DuckDB database and
switch the class back to this class.
Use `fetch_kwargs` to provide keyword arguments that were originally used in fetching."""
from vectorbtpro.data.custom.duckdb import DuckDBData
if fetch_kwargs is None:
fetch_kwargs = {}
data = DuckDBData.pull(*args, **kwargs)
data = data.switch_class(cls, clear_fetch_kwargs=True, clear_returned_kwargs=True)
data = data.update_fetch_kwargs(**fetch_kwargs)
return data
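# Editor's sketch for DuckDB persistence (hypothetical database file; passing the table name as
# the first argument of `from_duckdb` is an assumption about `DuckDBData.pull`):
data.to_duckdb(connection="data.duckdb", if_exists="replace")         # one table per feature/symbol
data_back = vbt.YFData.from_duckdb("BTC-USD", connection="data.duckdb")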
# ############# Querying ############# #
def sql(
self,
query: str,
dbcon: tp.Optional[DuckDBPyConnectionT] = None,
database: str = ":memory:",
db_config: tp.KwargsLike = None,
alias: str = "",
params: tp.KwargsLike = None,
other_objs: tp.Optional[dict] = None,
date_as_object: bool = False,
align_dtypes: bool = True,
squeeze: bool = True,
**kwargs,
) -> tp.SeriesFrame:
"""Run a SQL query on this instance using DuckDB.
First, a connection gets established. Then, `Data.get` gets invoked with `**kwargs` passed as
keyword arguments and `as_dict=True`, and each returned object gets registered within the
database. Finally, the query gets executed with `duckdb.sql` and the resulting relation is
returned as a DataFrame. If `squeeze` is True, a DataFrame with one column will be converted into a Series."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("duckdb")
from duckdb import connect
if db_config is None:
db_config = {}
if dbcon is None:
dbcon = connect(database=database, read_only=False, config=db_config)
if params is None:
params = {}
dtypes = {}
objs = self.get(**kwargs, as_dict=True)
for k, v in objs.items():
if not checks.is_default_index(v.index):
v = v.reset_index()
if isinstance(v, pd.Series):
v = v.to_frame()
for c in v.columns:
dtypes[c] = v[c].dtype
dbcon.register(k, v)
if other_objs is not None:
checks.assert_instance_of(other_objs, dict, arg_name="other_objs")
for k, v in other_objs.items():
if not checks.is_default_index(v.index):
v = v.reset_index()
if isinstance(v, pd.Series):
v = v.to_frame()
for c in v.columns:
dtypes[c] = v[c].dtype
dbcon.register(k, v)
df = dbcon.sql(query, alias=alias, params=params).df(date_as_object=date_as_object)
if align_dtypes:
for c in df.columns:
if c in dtypes:
df[c] = df[c].astype(dtypes[c])
if isinstance(self.index, pd.MultiIndex):
if set(self.index.names) <= set(df.columns):
df = df.set_index(self.index.names)
else:
if self.index.name is not None and self.index.name in df.columns:
df = df.set_index(self.index.name)
elif "index" in df.columns:
df = df.set_index("index")
df.index.name = None
if squeeze and len(df.columns) == 1:
df = df.iloc[:, 0]
return df
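# Editor's sketch for `Data.sql` (hypothetical; assumes a symbol-oriented instance whose
# features include "Close", so a table named "Close" with one column per symbol gets registered):
avg_close = data.sql('SELECT avg("BTC-USD") FROM "Close"')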
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Data.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.data`."""
return merge_dicts(Analyzable.stats_defaults.__get__(self), self.get_base_settings()["stats"])
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total_features=dict(
title="Total Features",
check_is_feature_oriented=True,
calc_func=lambda self: len(self.features),
agg_func=None,
tags="data",
),
total_symbols=dict(
title="Total Symbols",
check_is_symbol_oriented=True,
calc_func=lambda self: len(self.symbols),
tags="data",
),
null_counts=dict(
title="Null Counts",
calc_func=lambda self, group_by: {
k: v.isnull().vbt(wrapper=self.wrapper).sum(group_by=group_by) for k, v in self.data.items()
},
agg_func=lambda x: x.sum(),
tags="data",
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
column: tp.Optional[tp.Hashable] = None,
feature: tp.Optional[tp.Feature] = None,
symbol: tp.Optional[tp.Symbol] = None,
feature_map: tp.KwargsLike = None,
plot_volume: tp.Optional[bool] = None,
base: tp.Optional[float] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot either one feature of multiple symbols, or OHLC(V) of one symbol.
Args:
column (hashable): Name of the feature or symbol to plot.
Depends on the data orientation.
feature (hashable): Name of the feature to plot.
symbol (hashable): Name of the symbol to plot.
feature_map (dict): Dictionary mapping feature names to OHLCV.
Applied only if OHLC(V) is plotted.
plot_volume (bool): Whether to plot volume beneath.
Applied only if OHLC(V) is plotted.
base (float): Rebase all series of a feature to a given initial base.
!!! note
The feature must contain prices.
Applied only if lines are plotted.
kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`
for lines and to `vectorbtpro.ohlcv.accessors.OHLCVDFAccessor.plot` for OHLC(V).
Usage:
* Plot the lines of one feature across all symbols:
```pycon
>>> from vectorbtpro import *
>>> start = '2021-01-01 UTC' # crypto is in UTC
>>> end = '2021-06-01 UTC'
>>> data = vbt.YFData.pull(['BTC-USD', 'ETH-USD', 'ADA-USD'], start=start, end=end)
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> data.plot(feature='Close', base=1).show()
```
* Plot OHLC(V) of one symbol (only if data contains the respective features):
```pycon
>>> data.plot(symbol='BTC-USD').show()
```
"""
if column is not None:
if self.feature_oriented:
if symbol is not None:
raise ValueError("Either column or symbol can be provided, not both")
symbol = column
else:
if feature is not None:
raise ValueError("Either column or feature can be provided, not both")
feature = column
if feature is None and self.has_ohlc:
data = self.get(symbols=symbol, squeeze_symbols=True)
if isinstance(data, tuple):
raise ValueError("Cannot plot OHLC of multiple symbols. Select one symbol.")
return data.vbt.ohlcv(feature_map=feature_map).plot(plot_volume=plot_volume, **kwargs)
data = self.get(features=feature, symbols=symbol, squeeze_features=True, squeeze_symbols=True)
if isinstance(data, tuple):
raise ValueError("Cannot plot multiple features and symbols. Select one feature or symbol.")
if base is not None:
data = data.vbt.rebase(base)
return data.vbt.lineplot(**kwargs)
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Data.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.data`."""
return merge_dicts(Analyzable.plots_defaults.__get__(self), self.get_base_settings()["plots"])
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=RepEval(
"""
if symbols is None:
symbols = self.symbols
if not self.has_multiple_keys(symbols):
symbols = [symbols]
[
dict(
check_is_not_grouped=True,
plot_func="plot",
plot_volume=False,
symbol=s,
title=s,
pass_add_trace_kwargs=True,
xaxis_kwargs=dict(rangeslider_visible=False, showgrid=True),
yaxis_kwargs=dict(showgrid=True),
tags="data",
)
for s in symbols
]""",
context=dict(symbols=None),
)
),
)
@property
def subplots(self) -> Config:
return self._subplots
# ############# Docs ############# #
@classmethod
def build_feature_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build feature config documentation."""
if source_cls is None:
source_cls = Data
return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "feature_config").__doc__)).substitute(
{"feature_config": cls.feature_config.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_feature_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `Data.feature_config`."""
__pdoc__[cls.__name__ + ".feature_config"] = cls.build_feature_config_doc(source_cls=source_cls)
Data.override_feature_config_doc(__pdoc__)
Data.override_metrics_doc(__pdoc__)
Data.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for data."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import copy_dict
__all__ = []
def attach_symbol_dict_methods(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
"""Class decorator to attach methods for updating symbol dictionaries."""
checks.assert_subclass_of(cls, "Data")
DataT = tp.TypeVar("DataT", bound="Data")
for target_name in cls._key_dict_attrs:
def select_method(self, key: tp.Key, _target_name=target_name, **kwargs) -> tp.Any:
if _target_name.endswith("_kwargs"):
return self.select_key_kwargs(
key,
getattr(self, _target_name),
kwargs_name=_target_name,
**kwargs,
)
return self.select_key_from_dict(
key,
getattr(self, _target_name),
dct_name=_target_name,
**kwargs,
)
select_method.__name__ = "select_" + target_name
select_method.__module__ = cls.__module__
select_method.__qualname__ = f"{cls.__name__}.{select_method.__name__}"
select_method.__doc__ = f"""Select a feature or symbol from `Data.{target_name}`."""
setattr(cls, select_method.__name__, select_method)
for target_name in cls._updatable_attrs:
def update_method(self: DataT, _target_name=target_name, check_dict_type: bool = True, **kwargs) -> DataT:
from vectorbtpro.data.base import key_dict
new_kwargs = copy_dict(getattr(self, _target_name))
for s in self.get_keys(type(new_kwargs)):
if s not in new_kwargs:
new_kwargs[s] = dict()
for k, v in kwargs.items():
if check_dict_type:
self.check_dict_type(v, k, dict_type=type(new_kwargs))
if type(v) is key_dict or isinstance(v, type(new_kwargs)):
for s, _v in v.items():
new_kwargs[s][k] = _v
else:
for s in new_kwargs:
new_kwargs[s][k] = v
return self.replace(**{_target_name: new_kwargs})
update_method.__name__ = "update_" + target_name
update_method.__module__ = cls.__module__
update_method.__qualname__ = f"{cls.__name__}.{update_method.__name__}"
update_method.__doc__ = f"""Update `Data.{target_name}`. Returns a new instance."""
setattr(cls, update_method.__name__, update_method)
return cls
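# Editor's note (hypothetical usage; given a `Data` instance `data`, and assuming "fetch_kwargs"
# is among the attached attributes): the decorator generates "select_<attr>" and "update_<attr>"
# methods, so per-symbol settings can be read and replaced without mutating the instance:
btc_fetch_kwargs = data.select_fetch_kwargs("BTC-USD")      # per-symbol view
new_data = data.update_fetch_kwargs(timeframe="1h")         # broadcast to every symbol, new instance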
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for generating data.
Provides an arsenal of Numba-compiled functions that are used to generate data.
These only accept NumPy arrays and other Numba-compatible types."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = []
@register_jitted(cache=True)
def generate_random_data_1d_nb(
n_rows: int,
start_value: float = 100.0,
mean: float = 0.0,
std: float = 0.01,
symmetric: bool = False,
) -> tp.Array1d:
"""Generate data using cumulative product of returns drawn from normal (Gaussian) distribution.
Turn on `symmetric` to scale down negative returns so that a negative return exactly offsets a
positive return of the same magnitude multiplicatively. Otherwise, the majority of generated paths will drift downward."""
out = np.empty(n_rows, dtype=float_)
for i in range(n_rows):
if i == 0:
prev_value = start_value
else:
prev_value = out[i - 1]
return_ = np.random.normal(mean, std)
if symmetric and return_ < 0:
return_ = -abs(return_) / (1 + abs(return_))
out[i] = prev_value * (1 + return_)
return out
@register_jitted(cache=True, tags={"can_parallel"})
def generate_random_data_nb(
shape: tp.Shape,
start_value: tp.FlexArray1dLike = 100.0,
mean: tp.FlexArray1dLike = 0.0,
std: tp.FlexArray1dLike = 0.01,
symmetric: tp.FlexArray1dLike = False,
) -> tp.Array2d:
"""2-dim version of `generate_random_data_1d_nb`.
Each argument can be provided per column thanks to flexible indexing."""
start_value_ = to_1d_array_nb(np.asarray(start_value))
mean_ = to_1d_array_nb(np.asarray(mean))
std_ = to_1d_array_nb(np.asarray(std))
symmetric_ = to_1d_array_nb(np.asarray(symmetric))
out = np.empty(shape, dtype=float_)
for col in prange(shape[1]):
out[:, col] = generate_random_data_1d_nb(
shape[0],
start_value=flex_select_1d_pc_nb(start_value_, col),
mean=flex_select_1d_pc_nb(mean_, col),
std=flex_select_1d_pc_nb(std_, col),
symmetric=flex_select_1d_pc_nb(symmetric_, col),
)
return out
@register_jitted(cache=True)
def generate_gbm_data_1d_nb(
n_rows: int,
start_value: float = 100.0,
mean: float = 0.0,
std: float = 0.01,
dt: float = 1.0,
) -> tp.Array1d:
"""Generate data using Geometric Brownian Motion (GBM)."""
out = np.empty(n_rows, dtype=float_)
for i in range(n_rows):
if i == 0:
prev_value = start_value
else:
prev_value = out[i - 1]
rand = np.random.standard_normal()
out[i] = prev_value * np.exp((mean - 0.5 * std**2) * dt + std * np.sqrt(dt) * rand)
return out
@register_jitted(cache=True, tags={"can_parallel"})
def generate_gbm_data_nb(
shape: tp.Shape,
start_value: tp.FlexArray1dLike = 100.0,
mean: tp.FlexArray1dLike = 0.0,
std: tp.FlexArray1dLike = 0.01,
dt: tp.FlexArray1dLike = 1.0,
) -> tp.Array2d:
"""2-dim version of `generate_gbm_data_1d_nb`.
Each argument can be provided per column thanks to flexible indexing."""
start_value_ = to_1d_array_nb(np.asarray(start_value))
mean_ = to_1d_array_nb(np.asarray(mean))
std_ = to_1d_array_nb(np.asarray(std))
dt_ = to_1d_array_nb(np.asarray(dt))
out = np.empty(shape, dtype=float_)
for col in prange(shape[1]):
out[:, col] = generate_gbm_data_1d_nb(
shape[0],
start_value=flex_select_1d_pc_nb(start_value_, col),
mean=flex_select_1d_pc_nb(mean_, col),
std=flex_select_1d_pc_nb(std_, col),
dt=flex_select_1d_pc_nb(dt_, col),
)
return out
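# Editor's usage sketch (not part of the original module): both generators return plain NumPy
# arrays; the GBM step is S[i] = S[i - 1] * exp((mean - std ** 2 / 2) * dt + std * sqrt(dt) * Z)
# with Z drawn from the standard normal distribution.
rand_paths = generate_random_data_nb((1000, 3), start_value=100.0, std=0.02)
gbm_paths = generate_gbm_data_nb((1000, 3), start_value=np.array([50.0, 100.0, 150.0]))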
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for scheduling data saves."""
import logging
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import Data
from vectorbtpro.data.updater import DataUpdater
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"DataSaver",
"CSVDataSaver",
"HDFDataSaver",
"SQLDataSaver",
"DuckDBDataSaver",
]
logger = logging.getLogger(__name__)
class DataSaver(DataUpdater):
"""Base class for scheduling data saves.
Subclasses `vectorbtpro.data.updater.DataUpdater`.
Args:
data (Data): Data instance.
save_kwargs (dict): Default keyword arguments for `DataSaver.init_save_data` and `DataSaver.save_data`.
init_save_kwargs (dict): Default keyword arguments overriding `save_kwargs` for `DataSaver.init_save_data`.
**kwargs: Keyword arguments passed to the constructor of `DataUpdater`.
"""
def __init__(
self,
data: Data,
save_kwargs: tp.KwargsLike = None,
init_save_kwargs: tp.KwargsLike = None,
**kwargs,
) -> None:
DataUpdater.__init__(
self,
data=data,
save_kwargs=save_kwargs,
init_save_kwargs=init_save_kwargs,
**kwargs,
)
self._save_kwargs = save_kwargs
self._init_save_kwargs = init_save_kwargs
@property
def save_kwargs(self) -> tp.KwargsLike:
"""Keyword arguments passed to `DataSaver.save_data`."""
return self._save_kwargs
@property
def init_save_kwargs(self) -> tp.KwargsLike:
"""Keyword arguments passed to `DataSaver.init_save_data`."""
return self._init_save_kwargs
def init_save_data(self, **kwargs) -> None:
"""Save initial data.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
def save_data(self, **kwargs) -> None:
"""Save data.
This is an abstract method - override it to define custom logic."""
raise NotImplementedError
def update(self, save_kwargs: tp.KwargsLike = None, **kwargs) -> None:
"""Update and save data using `DataSaver.save_data`.
Override to do pre- and postprocessing.
To stop this method from running again, raise `vectorbtpro.utils.schedule_.CancelledError`."""
# In case the method was called by the user
kwargs = merge_dicts(
dict(save_kwargs=self.save_kwargs),
self.update_kwargs,
{"save_kwargs": save_kwargs, **kwargs},
)
save_kwargs = kwargs.pop("save_kwargs")
self._data = self.data.update(concat=False, **kwargs)
self.update_config(data=self.data)
if save_kwargs is None:
save_kwargs = {}
self.save_data(**save_kwargs)
def update_every(
self,
*args,
save_kwargs: tp.KwargsLike = None,
init_save: bool = False,
init_save_kwargs: tp.KwargsLike = None,
**kwargs,
) -> None:
"""Overrides `vectorbtpro.data.updater.DataUpdater` to save initial data prior to updating."""
if init_save:
init_save_kwargs = merge_dicts(
self.save_kwargs,
save_kwargs,
self.init_save_kwargs,
init_save_kwargs,
)
self.init_save_data(**init_save_kwargs)
DataUpdater.update_every(self, *args, save_kwargs=save_kwargs, **kwargs)
class CSVDataSaver(DataSaver):
"""Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_csv`."""
def init_save_data(self, **to_csv_kwargs) -> None:
"""Save initial data."""
# In case the method was called by the user
to_csv_kwargs = merge_dicts(
self.save_kwargs,
self.init_save_kwargs,
to_csv_kwargs,
)
self.data.to_csv(**to_csv_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
def save_data(self, **to_csv_kwargs) -> None:
"""Save data.
By default, appends new data without header."""
# In case the method was called by the user
to_csv_kwargs = merge_dicts(
dict(mode="a", header=False),
self.save_kwargs,
to_csv_kwargs,
)
self.data.to_csv(**to_csv_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
class HDFDataSaver(DataSaver):
"""Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_hdf`."""
def init_save_data(self, **to_hdf_kwargs) -> None:
"""Save initial data."""
# In case the method was called by the user
to_hdf_kwargs = merge_dicts(
self.save_kwargs,
self.init_save_kwargs,
to_hdf_kwargs,
)
self.data.to_hdf(**to_hdf_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
def save_data(self, **to_hdf_kwargs) -> None:
"""Save data.
By default, appends new data in a table format."""
# In case the method was called by the user
to_hdf_kwargs = merge_dicts(
dict(mode="a", append=True),
self.save_kwargs,
to_hdf_kwargs,
)
self.data.to_hdf(**to_hdf_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
class SQLDataSaver(DataSaver):
"""Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_sql`."""
def init_save_data(self, **to_sql_kwargs) -> None:
"""Save initial data."""
# In case the method was called by the user
to_sql_kwargs = merge_dicts(
self.save_kwargs,
self.init_save_kwargs,
to_sql_kwargs,
)
self.data.to_sql(**to_sql_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
def save_data(self, **to_sql_kwargs) -> None:
"""Save data.
By default, appends new data to the existing table."""
# In case the method was called by the user
to_sql_kwargs = merge_dicts(
dict(if_exists="append"),
self.save_kwargs,
to_sql_kwargs,
)
self.data.to_sql(**to_sql_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
class DuckDBDataSaver(DataSaver):
"""Subclass of `DataSaver` for saving data with `vectorbtpro.data.base.Data.to_duckdb`."""
def init_save_data(self, **to_duckdb_kwargs) -> None:
"""Save initial data."""
# In case the method was called by the user
to_duckdb_kwargs = merge_dicts(
self.save_kwargs,
self.init_save_kwargs,
to_duckdb_kwargs,
)
self.data.to_duckdb(**to_duckdb_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved initial {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
def save_data(self, **to_duckdb_kwargs) -> None:
"""Save data.
By default, appends new data to the existing table."""
# In case the method was called by the user
to_duckdb_kwargs = merge_dicts(
dict(if_exists="append"),
self.save_kwargs,
to_duckdb_kwargs,
)
self.data.to_duckdb(**to_duckdb_kwargs)
new_index = self.data.wrapper.index
logger.info(f"Saved {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for scheduling data updates."""
import logging
from vectorbtpro import _typing as tp
from vectorbtpro.data.base import Data
from vectorbtpro.utils.config import Configured, merge_dicts
from vectorbtpro.utils.schedule_ import ScheduleManager
__all__ = [
"DataUpdater",
]
logger = logging.getLogger(__name__)
class DataUpdater(Configured):
"""Base class for scheduling data updates.
Args:
data (Data): Data instance.
schedule_manager (ScheduleManager): Schedule manager instance.
    Defaults to a new `vectorbtpro.utils.schedule_.ScheduleManager` if not provided.
update_kwargs (dict): Default keyword arguments for `DataUpdater.update`.
**kwargs: Keyword arguments passed to the constructor of `Configured`.
"""
def __init__(
self,
data: Data,
schedule_manager: tp.Optional[ScheduleManager] = None,
update_kwargs: tp.KwargsLike = None,
**kwargs,
) -> None:
if schedule_manager is None:
schedule_manager = ScheduleManager()
Configured.__init__(
self,
data=data,
schedule_manager=schedule_manager,
update_kwargs=update_kwargs,
**kwargs,
)
self._data = data
self._schedule_manager = schedule_manager
self._update_kwargs = update_kwargs
@property
def data(self) -> Data:
"""Data instance.
See `vectorbtpro.data.base.Data`."""
return self._data
@property
def schedule_manager(self) -> ScheduleManager:
"""Schedule manager instance.
See `vectorbtpro.utils.schedule_.ScheduleManager`."""
return self._schedule_manager
@property
def update_kwargs(self) -> tp.KwargsLike:
"""Keyword arguments passed to `DataSaver.update`."""
return self._update_kwargs
def update(self, **kwargs) -> None:
"""Method that updates data.
Override to do pre- and postprocessing.
To stop this method from running again, raise `vectorbtpro.utils.schedule_.CancelledError`."""
# In case the method was called by the user
kwargs = merge_dicts(self.update_kwargs, kwargs)
self._data = self.data.update(**kwargs)
self.update_config(data=self.data)
new_index = self.data.wrapper.index
logger.info(f"New data has {len(new_index)} rows from {new_index[0]} to {new_index[-1]}")
def update_every(
self,
*args,
to: tp.Optional[int] = None,
tags: tp.Optional[tp.Iterable[tp.Hashable]] = None,
in_background: bool = False,
replace: bool = True,
start: bool = True,
start_kwargs: tp.KwargsLike = None,
**update_kwargs,
) -> None:
"""Schedule `DataUpdater.update` as a job.
For `*args`, `to` and `tags`, see `vectorbtpro.utils.schedule_.ScheduleManager.every`.
If `in_background` is set to True, starts in the background as an `asyncio` task.
The task can be stopped with `vectorbtpro.utils.schedule_.ScheduleManager.stop`.
If `replace` is True, will delete scheduled jobs with the same tags, or all jobs if tags are omitted.
If `start` is False, will add the job to the scheduler without starting.
`**update_kwargs` are merged over `DataUpdater.update_kwargs` and passed to `DataUpdater.update`."""
if replace:
self.schedule_manager.clear_jobs(tags)
update_kwargs = merge_dicts(self.update_kwargs, update_kwargs)
self.schedule_manager.every(*args, to=to, tags=tags).do(self.update, **update_kwargs)
if start:
if start_kwargs is None:
start_kwargs = {}
if in_background:
self.schedule_manager.start_in_background(**start_kwargs)
else:
self.schedule_manager.start(**start_kwargs)
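# A minimal usage sketch of `DataUpdater`: keep a `Data` instance fresh by updating it
# once per minute as a background task. The data source (`YFData.pull`) and the
# `1, "minutes"` interval are assumptions for illustration; see `ScheduleManager.every`
# for the supported arguments.
def _example_data_updater() -> None:
    """Update a data instance every minute in the background (sketch)."""
    import vectorbtpro as vbt

    data = vbt.YFData.pull("BTC-USD")  # hypothetical data source
    updater = DataUpdater(data)
    updater.update_every(1, "minutes", in_background=True)
    # The most recent data is then always accessible as `updater.data`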
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for generic data.
Provides an arsenal of Numba-compiled functions that are used by accessors
and in many other parts of a backtesting pipeline, such as technical indicators.
These only accept NumPy arrays and other Numba-compatible types.
!!! note
vectorbt treats matrices as first-class citizens and expects input arrays to be
2-dim, unless a function has the `_1d` suffix or is meant to be used as input to another function.
Data is processed along index (axis 0).
Rolling functions with `minp=None` have `min_periods` set to the window size.
All functions passed as arguments must be Numba-compiled.
Records must retain the order they were created in.
!!! warning
Make sure to use `parallel=True` only if your columns are independent.
"""
from vectorbtpro.generic.nb.apply_reduce import *
from vectorbtpro.generic.nb.base import *
from vectorbtpro.generic.nb.iter_ import *
from vectorbtpro.generic.nb.patterns import *
from vectorbtpro.generic.nb.records import *
from vectorbtpro.generic.nb.rolling import *
from vectorbtpro.generic.nb.sim_range import *
__all__ = []
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for mapping, applying, and reducing."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.generic.nb.base import nancov_1d_nb, nanstd_1d_nb, nancorr_1d_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
# ############# Map, apply, and reduce ############# #
@register_jitted
def map_1d_nb(arr: tp.Array1d, map_func_nb: tp.MapFunc, *args) -> tp.Array1d:
"""Map elements element-wise using `map_func_nb`.
`map_func_nb` must accept the element and `*args`. Must return a single value."""
i_0_out = map_func_nb(arr[0], *args)
out = np.empty_like(arr, dtype=np.asarray(i_0_out).dtype)
out[0] = i_0_out
for i in range(1, arr.shape[0]):
out[i] = map_func_nb(arr[i], *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
map_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def map_nb(arr: tp.Array2d, map_func_nb: tp.MapFunc, *args) -> tp.Array2d:
"""2-dim version of `map_1d_nb`."""
col_0_out = map_1d_nb(arr[:, 0], map_func_nb, *args)
out = np.empty_like(arr, dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, out.shape[1]):
out[:, col] = map_1d_nb(arr[:, col], map_func_nb, *args)
return out
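# A minimal usage sketch of `map_nb`: square every element of a 2-dim array. Any
# Numba-compiled function that takes a single element (plus optional arguments) and
# returns a single value can serve as `map_func_nb`.
def _example_map_nb() -> None:
    """Square every element (sketch)."""
    from numba import njit

    @njit
    def square_nb(x):
        return x ** 2

    arr = np.array([[1.0, 2.0], [3.0, 4.0]])
    print(map_nb(arr, square_nb))  # [[1. 4.] [9. 16.]]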
@register_jitted
def map_1d_meta_nb(n: int, col: int, map_func_nb: tp.MapMetaFunc, *args) -> tp.Array1d:
"""Meta version of `map_1d_nb`.
`map_func_nb` must accept the row index, the column index, and `*args`.
Must return a single value."""
i_0_out = map_func_nb(0, col, *args)
out = np.empty(n, dtype=np.asarray(i_0_out).dtype)
out[0] = i_0_out
for i in range(1, n):
out[i] = map_func_nb(i, col, *args)
return out
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
map_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def map_meta_nb(target_shape: tp.Shape, map_func_nb: tp.MapMetaFunc, *args) -> tp.Array2d:
"""2-dim version of `map_1d_meta_nb`."""
col_0_out = map_1d_meta_nb(target_shape[0], 0, map_func_nb, *args)
out = np.empty(target_shape, dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, out.shape[1]):
out[:, col] = map_1d_meta_nb(target_shape[0], col, map_func_nb, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def apply_nb(arr: tp.Array2d, apply_func_nb: tp.ApplyFunc, *args) -> tp.Array2d:
"""Apply function on each column of an object.
`apply_func_nb` must accept the array and `*args`.
Must return a single value or an array of shape `a.shape[1]`."""
col_0_out = apply_func_nb(arr[:, 0], *args)
out = np.empty_like(arr, dtype=np.asarray(col_0_out).dtype)
out[:, 0] = col_0_out
for col in prange(1, arr.shape[1]):
out[:, col] = apply_func_nb(arr[:, col], *args)
return out
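# A minimal usage sketch of `apply_nb`: subtract the column mean from each column.
# The function receives one full column at a time and may return either a scalar or
# an array of the same length as the column.
def _example_apply_nb() -> None:
    """Demean each column (sketch)."""
    from numba import njit

    @njit
    def demean_nb(col):
        return col - np.nanmean(col)

    arr = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
    print(apply_nb(arr, demean_nb))  # each column now has zero mean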
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def apply_meta_nb(target_shape: tp.Shape, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array2d:
"""Meta version of `apply_nb` that prepends the column index to the arguments of `apply_func_nb`."""
col_0_out = apply_func_nb(0, *args)
out = np.empty(target_shape, dtype=np.asarray(col_0_out).dtype)
out[:, 0] = col_0_out
for col in prange(1, target_shape[1]):
out[:, col] = apply_func_nb(col, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=0),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=0),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def row_apply_nb(arr: tp.Array2d, apply_func_nb: tp.ApplyFunc, *args) -> tp.Array2d:
"""`apply_nb` but applied on rows rather than columns."""
row_0_out = apply_func_nb(arr[0, :], *args)
out = np.empty_like(arr, dtype=np.asarray(row_0_out).dtype)
out[0, :] = row_0_out
for i in prange(1, arr.shape[0]):
out[i, :] = apply_func_nb(arr[i, :], *args)
return out
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=0),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=0),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def row_apply_meta_nb(target_shape: tp.Shape, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array2d:
"""Meta version of `row_apply_nb` that prepends the row index to the arguments of `apply_func_nb`."""
row_0_out = apply_func_nb(0, *args)
out = np.empty(target_shape, dtype=np.asarray(row_0_out).dtype)
out[0, :] = row_0_out
for i in prange(1, target_shape[0]):
out[i, :] = apply_func_nb(i, *args)
return out
@register_jitted
def rolling_reduce_1d_nb(
arr: tp.Array1d,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Provide rolling window calculations.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
if minp is None:
minp = window
out = np.empty_like(arr, dtype=float_)
nancnt_arr = np.empty(arr.shape[0], dtype=int_)
nancnt = 0
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
nancnt = nancnt + 1
nancnt_arr[i] = nancnt
if i < window:
valid_cnt = i + 1 - nancnt
else:
valid_cnt = window - (nancnt - nancnt_arr[i - window])
if valid_cnt < minp:
out[i] = np.nan
else:
from_i = max(0, i + 1 - window)
to_i = i + 1
arr_window = arr[from_i:to_i]
out[i] = reduce_func_nb(arr_window, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
window=None,
minp=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_nb(
arr: tp.Array2d,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `rolling_reduce_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_reduce_1d_nb(arr[:, col], window, minp, reduce_func_nb, *args)
return out
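# A minimal usage sketch of `rolling_reduce_nb`: a rolling mean over a window of
# 3 rows. With `minp=None` the minimum period defaults to the window size, so the
# first two rows of the result are NaN.
def _example_rolling_reduce_nb() -> None:
    """Rolling mean with a window of 3 (sketch)."""
    from numba import njit

    @njit
    def mean_nb(window_arr):
        return np.nanmean(window_arr)

    arr = np.arange(10, dtype=np.float64).reshape(5, 2)
    print(rolling_reduce_nb(arr, 3, None, mean_nb))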
@register_jitted
def rolling_reduce_two_1d_nb(
arr1: tp.Array1d,
arr2: tp.Array1d,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Provide rolling window calculations for two arrays.
`reduce_func_nb` must accept two arrays and `*args`. Must return a single value."""
if minp is None:
minp = window
out = np.empty_like(arr1, dtype=float_)
nancnt_arr = np.empty(arr1.shape[0], dtype=int_)
nancnt = 0
for i in range(arr1.shape[0]):
if np.isnan(arr1[i]) or np.isnan(arr2[i]):
nancnt = nancnt + 1
nancnt_arr[i] = nancnt
if i < window:
valid_cnt = i + 1 - nancnt
else:
valid_cnt = window - (nancnt - nancnt_arr[i - window])
if valid_cnt < minp:
out[i] = np.nan
else:
from_i = max(0, i + 1 - window)
to_i = i + 1
arr1_window = arr1[from_i:to_i]
arr2_window = arr2[from_i:to_i]
out[i] = reduce_func_nb(arr1_window, arr2_window, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
window=None,
minp=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_two_nb(
arr1: tp.Array2d,
arr2: tp.Array2d,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `rolling_reduce_two_1d_nb`."""
out = np.empty_like(arr1, dtype=float_)
for col in prange(arr1.shape[1]):
out[:, col] = rolling_reduce_two_1d_nb(arr1[:, col], arr2[:, col], window, minp, reduce_func_nb, *args)
return out
@register_jitted
def rolling_reduce_1d_meta_nb(
n: int,
col: int,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `rolling_reduce_1d_nb`.
`reduce_func_nb` must accept the start row index, the end row index, the column, and `*args`.
Must return a single value."""
if minp is None:
minp = window
out = np.empty(n, dtype=float_)
for i in range(n):
valid_cnt = min(i + 1, window)
if valid_cnt < minp:
out[i] = np.nan
else:
from_i = max(0, i + 1 - window)
to_i = i + 1
out[i] = reduce_func_nb(from_i, to_i, col, *args)
return out
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
window=None,
minp=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_reduce_meta_nb(
target_shape: tp.Shape,
window: int,
minp: tp.Optional[int],
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `rolling_reduce_1d_meta_nb`."""
out = np.empty(target_shape, dtype=float_)
for col in prange(target_shape[1]):
out[:, col] = rolling_reduce_1d_meta_nb(target_shape[0], col, window, minp, reduce_func_nb, *args)
return out
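# A minimal usage sketch of `rolling_reduce_meta_nb`: the meta variant passes row
# bounds and the column index instead of a pre-sliced window, so the data itself is
# forwarded through `*args` and sliced inside the reducer.
def _example_rolling_reduce_meta_nb() -> None:
    """Rolling mean implemented through the meta interface (sketch)."""
    from numba import njit

    @njit
    def mean_range_nb(from_i, to_i, col, arr):
        return np.nanmean(arr[from_i:to_i, col])

    arr = np.arange(10, dtype=np.float64).reshape(5, 2)
    print(rolling_reduce_meta_nb(arr.shape, 3, None, mean_range_nb, arr))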
@register_jitted
def rolling_freq_reduce_1d_nb(
index: tp.Array1d,
arr: tp.Array1d,
freq: np.timedelta64,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Provide rolling, frequency-based window calculations.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
out = np.empty_like(arr, dtype=float_)
from_i = 0
for i in range(arr.shape[0]):
if index[from_i] <= index[i] - freq:
for j in range(from_i + 1, index.shape[0]):
if index[j] > index[i] - freq:
from_i = j
break
to_i = i + 1
arr_window = arr[from_i:to_i]
out[i] = reduce_func_nb(arr_window, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
index=None,
arr=ch.ArraySlicer(axis=1),
freq=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_freq_reduce_nb(
index: tp.Array1d,
arr: tp.Array2d,
freq: np.timedelta64,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `rolling_reduce_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_freq_reduce_1d_nb(index, arr[:, col], freq, reduce_func_nb, *args)
return out
@register_jitted
def rolling_freq_reduce_1d_meta_nb(
col: int,
index: tp.Array1d,
freq: np.timedelta64,
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `rolling_freq_reduce_1d_nb`.
`reduce_func_nb` must accept the start row index, the end row index, the column, and `*args`.
Must return a single value."""
out = np.empty(index.shape[0], dtype=float_)
from_i = 0
for i in range(index.shape[0]):
if index[from_i] <= index[i] - freq:
for j in range(from_i + 1, index.shape[0]):
if index[j] > index[i] - freq:
from_i = j
break
to_i = i + 1
out[i] = reduce_func_nb(from_i, to_i, col, *args)
return out
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
index=None,
freq=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_freq_reduce_meta_nb(
n_cols: int,
index: tp.Array1d,
freq: np.timedelta64,
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `rolling_freq_reduce_1d_meta_nb`."""
out = np.empty((index.shape[0], n_cols), dtype=float_)
for col in prange(n_cols):
out[:, col] = rolling_freq_reduce_1d_meta_nb(col, index, freq, reduce_func_nb, *args)
return out
@register_jitted
def groupby_reduce_1d_nb(arr: tp.Array1d, group_map: tp.GroupMap, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array1d:
"""Provide group-by reduce calculations.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(arr[group_0_idxs], *args)
out = np.empty(group_lens.shape[0], dtype=np.asarray(group_0_out).dtype)
out[0] = group_0_out
for group in range(1, group_lens.shape[0]):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
idxs = group_idxs[start_idx : start_idx + group_len]
out[group] = reduce_func_nb(arr[idxs], *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
group_map=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def groupby_reduce_nb(arr: tp.Array2d, group_map: tp.GroupMap, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array2d:
"""2-dim version of `groupby_reduce_1d_nb`."""
col_0_out = groupby_reduce_1d_nb(arr[:, 0], group_map, reduce_func_nb, *args)
out = np.empty((col_0_out.shape[0], arr.shape[1]), dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, arr.shape[1]):
out[:, col] = groupby_reduce_1d_nb(arr[:, col], group_map, reduce_func_nb, *args)
return out
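# A minimal usage sketch of `groupby_reduce_nb`: sum rows 0-1 and rows 2-4 of each
# column. A group map is a tuple of (flattened indices, group lengths), built by hand
# here; in practice it typically comes from a grouper.
def _example_groupby_reduce_nb() -> None:
    """Sum two row groups per column (sketch)."""
    from numba import njit

    @njit
    def sum_nb(group_arr):
        return np.nansum(group_arr)

    arr = np.arange(10, dtype=np.float64).reshape(5, 2)
    group_map = (np.array([0, 1, 2, 3, 4]), np.array([2, 3]))
    print(groupby_reduce_nb(arr, group_map, sum_nb))  # one row per group, one column per input column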
@register_jitted
def groupby_reduce_1d_meta_nb(
col: int,
group_map: tp.GroupMap,
reduce_func_nb: tp.GroupByReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `groupby_reduce_1d_nb`.
`reduce_func_nb` must accept the array of indices in the group, the group index, the column index,
and `*args`. Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(group_0_idxs, 0, col, *args)
out = np.empty(group_lens.shape[0], dtype=np.asarray(group_0_out).dtype)
out[0] = group_0_out
for group in range(1, group_lens.shape[0]):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
idxs = group_idxs[start_idx : start_idx + group_len]
out[group] = reduce_func_nb(idxs, group, col, *args)
return out
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
group_map=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def groupby_reduce_meta_nb(
n_cols: int,
group_map: tp.GroupMap,
reduce_func_nb: tp.GroupByReduceMetaFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `groupby_reduce_1d_meta_nb`."""
col_0_out = groupby_reduce_1d_meta_nb(0, group_map, reduce_func_nb, *args)
out = np.empty((col_0_out.shape[0], n_cols), dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, n_cols):
out[:, col] = groupby_reduce_1d_meta_nb(col, group_map, reduce_func_nb, *args)
return out
@register_jitted(tags={"can_parallel"})
def groupby_transform_nb(
arr: tp.Array2d,
group_map: tp.GroupMap,
transform_func_nb: tp.GroupByTransformFunc,
*args,
) -> tp.Array2d:
"""Provide group-by transform calculations.
`transform_func_nb` must accept the 2-dim array of the group and `*args`. Must return a scalar
or an array that broadcasts against the group array's shape."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = transform_func_nb(arr[group_0_idxs], *args)
out = np.empty(arr.shape, dtype=np.asarray(group_0_out).dtype)
out[group_0_idxs] = group_0_out
for group in prange(1, group_lens.shape[0]):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
idxs = group_idxs[start_idx : start_idx + group_len]
out[idxs] = transform_func_nb(arr[idxs], *args)
return out
@register_jitted(tags={"can_parallel"})
def groupby_transform_meta_nb(
target_shape: tp.Shape,
group_map: tp.GroupMap,
transform_func_nb: tp.GroupByTransformMetaFunc,
*args,
) -> tp.Array2d:
"""Meta version of `groupby_transform_nb`.
`transform_func_nb` must accept the array of indices in the group, the group index, and `*args`.
Must return a scalar or an array that broadcasts against the group's shape."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = transform_func_nb(group_0_idxs, 0, *args)
out = np.empty(target_shape, dtype=np.asarray(group_0_out).dtype)
out[group_0_idxs] = group_0_out
for group in prange(1, group_lens.shape[0]):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
idxs = group_idxs[start_idx : start_idx + group_len]
out[idxs] = transform_func_nb(idxs, group, *args)
return out
@register_jitted
def reduce_index_ranges_1d_nb(
arr: tp.Array1d,
range_starts: tp.Array1d,
range_ends: tp.Array1d,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Reduce each index range.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
out = np.empty(range_starts.shape[0], dtype=float_)
for k in range(len(range_starts)):
from_i = range_starts[k]
to_i = range_ends[k]
if from_i == -1 or to_i == -1:
out[k] = np.nan
else:
out[k] = reduce_func_nb(arr[from_i:to_i], *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
range_starts=None,
range_ends=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_index_ranges_nb(
arr: tp.Array2d,
range_starts: tp.Array1d,
range_ends: tp.Array1d,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `reduce_index_ranges_1d_nb`."""
out = np.empty((range_starts.shape[0], arr.shape[1]), dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = reduce_index_ranges_1d_nb(arr[:, col], range_starts, range_ends, reduce_func_nb, *args)
return out
@register_jitted
def reduce_index_ranges_1d_meta_nb(
col: int,
range_starts: tp.Array1d,
range_ends: tp.Array1d,
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `reduce_index_ranges_1d_nb`.
`reduce_func_nb` must accept the start row index, the end row index, the column,
and `*args`. Must return a single value."""
out = np.empty(range_starts.shape[0], dtype=float_)
for k in range(len(range_starts)):
from_i = range_starts[k]
to_i = range_ends[k]
if from_i == -1 or to_i == -1:
out[k] = np.nan
else:
out[k] = reduce_func_nb(from_i, to_i, col, *args)
return out
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
range_starts=None,
range_ends=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_index_ranges_meta_nb(
n_cols: int,
range_starts: tp.Array1d,
range_ends: tp.Array1d,
reduce_func_nb: tp.RangeReduceMetaFunc,
*args,
) -> tp.Array2d:
"""2-dim version of `reduce_index_ranges_1d_meta_nb`."""
out = np.empty((range_starts.shape[0], n_cols), dtype=float_)
for col in prange(n_cols):
out[:, col] = reduce_index_ranges_1d_meta_nb(col, range_starts, range_ends, reduce_func_nb, *args)
return out
@register_jitted
def apply_and_reduce_1d_nb(
arr: tp.Array1d,
apply_func_nb: tp.ApplyFunc,
apply_args: tuple,
reduce_func_nb: tp.ReduceFunc,
reduce_args: tuple,
) -> tp.Scalar:
"""Apply `apply_func_nb` and reduce into a single value using `reduce_func_nb`.
`apply_func_nb` must accept the array and `*apply_args`.
Must return an array.
`reduce_func_nb` must accept the array of results from `apply_func_nb` and `*reduce_args`.
Must return a single value."""
temp = apply_func_nb(arr, *apply_args)
return reduce_func_nb(temp, *reduce_args)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
apply_func_nb=None,
apply_args=ch.ArgsTaker(),
reduce_func_nb=None,
reduce_args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_and_reduce_nb(
arr: tp.Array2d,
apply_func_nb: tp.ApplyFunc,
apply_args: tuple,
reduce_func_nb: tp.ReduceFunc,
reduce_args: tuple,
) -> tp.Array1d:
"""2-dim version of `apply_and_reduce_1d_nb`."""
col_0_out = apply_and_reduce_1d_nb(arr[:, 0], apply_func_nb, apply_args, reduce_func_nb, reduce_args)
out = np.empty(arr.shape[1], dtype=np.asarray(col_0_out).dtype)
out[0] = col_0_out
for col in prange(1, arr.shape[1]):
out[col] = apply_and_reduce_1d_nb(arr[:, col], apply_func_nb, apply_args, reduce_func_nb, reduce_args)
return out
@register_jitted
def apply_and_reduce_1d_meta_nb(
col: int,
apply_func_nb: tp.ApplyMetaFunc,
apply_args: tuple,
reduce_func_nb: tp.ReduceMetaFunc,
reduce_args: tuple,
) -> tp.Scalar:
"""Meta version of `apply_and_reduce_1d_nb`.
`apply_func_nb` must accept the column index, the array, and `*apply_args`.
Must return an array.
`reduce_func_nb` must accept the column index, the array of results from `apply_func_nb`, and `*reduce_args`.
Must return a single value."""
temp = apply_func_nb(col, *apply_args)
return reduce_func_nb(col, temp, *reduce_args)
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
apply_func_nb=None,
apply_args=ch.ArgsTaker(),
reduce_func_nb=None,
reduce_args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_and_reduce_meta_nb(
n_cols: int,
apply_func_nb: tp.ApplyMetaFunc,
apply_args: tuple,
reduce_func_nb: tp.ReduceMetaFunc,
reduce_args: tuple,
) -> tp.Array1d:
"""2-dim version of `apply_and_reduce_1d_meta_nb`."""
col_0_out = apply_and_reduce_1d_meta_nb(0, apply_func_nb, apply_args, reduce_func_nb, reduce_args)
out = np.empty(n_cols, dtype=np.asarray(col_0_out).dtype)
out[0] = col_0_out
for col in prange(1, n_cols):
out[col] = apply_and_reduce_1d_meta_nb(col, apply_func_nb, apply_args, reduce_func_nb, reduce_args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_nb(arr: tp.Array2d, reduce_func_nb: tp.ReduceFunc, *args) -> tp.Array1d:
"""Reduce each column into a single value using `reduce_func_nb`.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
col_0_out = reduce_func_nb(arr[:, 0], *args)
out = np.empty(arr.shape[1], dtype=np.asarray(col_0_out).dtype)
out[0] = col_0_out
for col in prange(1, arr.shape[1]):
out[col] = reduce_func_nb(arr[:, col], *args)
return out
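# A minimal usage sketch of `reduce_nb`: collapse each column into its value range
# (max minus min). The reducer receives one full column and must return a scalar.
def _example_reduce_nb() -> None:
    """Reduce each column to max minus min (sketch)."""
    from numba import njit

    @njit
    def range_nb(col):
        return np.nanmax(col) - np.nanmin(col)

    arr = np.array([[1.0, 5.0], [4.0, 2.0], [3.0, 8.0]])
    print(reduce_nb(arr, range_nb))  # [3. 6.]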
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_meta_nb(n_cols: int, reduce_func_nb: tp.ReduceMetaFunc, *args) -> tp.Array1d:
"""Meta version of `reduce_nb`.
`reduce_func_nb` must accept the column index and `*args`. Must return a single value."""
col_0_out = reduce_func_nb(0, *args)
out = np.empty(n_cols, dtype=np.asarray(col_0_out).dtype)
out[0] = col_0_out
for col in prange(1, n_cols):
out[col] = reduce_func_nb(col, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_to_array_nb(arr: tp.Array2d, reduce_func_nb: tp.ReduceToArrayFunc, *args) -> tp.Array2d:
"""Same as `reduce_nb` but `reduce_func_nb` must return an array."""
col_0_out = reduce_func_nb(arr[:, 0], *args)
out = np.empty((col_0_out.shape[0], arr.shape[1]), dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, arr.shape[1]):
out[:, col] = reduce_func_nb(arr[:, col], *args)
return out
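# A minimal usage sketch of `reduce_to_array_nb`: collapse each column into a
# (min, max) pair. Because the reducer returns an array, the result has one row per
# returned element and one column per input column.
def _example_reduce_to_array_nb() -> None:
    """Reduce each column to its min and max (sketch)."""
    from numba import njit

    @njit
    def min_max_nb(col):
        return np.array([np.nanmin(col), np.nanmax(col)])

    arr = np.array([[1.0, 5.0], [4.0, 2.0], [3.0, 8.0]])
    print(reduce_to_array_nb(arr, min_max_nb))  # [[1. 2.] [4. 8.]]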
@register_chunkable(
size=ch.ArgSizer(arg_query="n_cols"),
arg_take_spec=dict(
n_cols=ch.CountAdapter(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_to_array_meta_nb(n_cols: int, reduce_func_nb: tp.ReduceToArrayMetaFunc, *args) -> tp.Array2d:
"""Same as `reduce_meta_nb` but `reduce_func_nb` must return an array."""
col_0_out = reduce_func_nb(0, *args)
out = np.empty((col_0_out.shape[0], n_cols), dtype=col_0_out.dtype)
out[:, 0] = col_0_out
for col in prange(1, n_cols):
out[:, col] = reduce_func_nb(col, *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_nb(
arr: tp.Array2d,
group_map: tp.GroupMap,
reduce_func_nb: tp.ReduceGroupedFunc,
*args,
) -> tp.Array1d:
"""Reduce each group of columns into a single value using `reduce_func_nb`.
`reduce_func_nb` must accept the 2-dim array and `*args`. Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(arr[:, group_0_idxs], *args)
out = np.empty(len(group_lens), dtype=np.asarray(group_0_out).dtype)
out[0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
out[group] = reduce_func_nb(arr[:, col_idxs], *args)
return out
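# A minimal usage sketch of `reduce_grouped_nb`: average two column groups into one
# value each. The group map below assigns columns 0 and 1 to the first group and
# column 2 to the second; in practice it typically comes from a grouper.
def _example_reduce_grouped_nb() -> None:
    """Average each column group (sketch)."""
    from numba import njit

    @njit
    def group_mean_nb(group_arr):
        return np.nanmean(group_arr)

    arr = np.arange(9, dtype=np.float64).reshape(3, 3)
    group_map = (np.array([0, 1, 2]), np.array([2, 1]))
    print(reduce_grouped_nb(arr, group_map, group_mean_nb))  # [3.5 5. ]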
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
group_map=base_ch.GroupMapSlicer(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_meta_nb(group_map: tp.GroupMap, reduce_func_nb: tp.ReduceGroupedMetaFunc, *args) -> tp.Array1d:
"""Meta version of `reduce_grouped_nb`.
`reduce_func_nb` must accept the column indices of the group, the group index, and `*args`.
Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(group_0_idxs, 0, *args)
out = np.empty(len(group_lens), dtype=np.asarray(group_0_out).dtype)
out[0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
out[group] = reduce_func_nb(col_idxs, group, *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def flatten_forder_nb(arr: tp.Array2d) -> tp.Array1d:
"""Flatten the array in F order."""
out = np.empty(arr.shape[0] * arr.shape[1], dtype=arr.dtype)
for col in prange(arr.shape[1]):
out[col * arr.shape[0] : (col + 1) * arr.shape[0]] = arr[:, col]
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
in_c_order=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_flat_grouped_nb(
arr: tp.Array2d,
group_map: tp.GroupMap,
in_c_order: bool,
reduce_func_nb: tp.ReduceToArrayFunc,
*args,
) -> tp.Array1d:
"""Same as `reduce_grouped_nb` but passes flattened array."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
if in_c_order:
group_0_out = reduce_func_nb(arr[:, group_0_idxs].flatten(), *args)
else:
group_0_out = reduce_func_nb(flatten_forder_nb(arr[:, group_0_idxs]), *args)
out = np.empty(len(group_lens), dtype=np.asarray(group_0_out).dtype)
out[0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
if in_c_order:
out[group] = reduce_func_nb(arr[:, col_idxs].flatten(), *args)
else:
out[group] = reduce_func_nb(flatten_forder_nb(arr[:, col_idxs]), *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_to_array_nb(
arr: tp.Array2d,
group_map: tp.GroupMap,
reduce_func_nb: tp.ReduceGroupedToArrayFunc,
*args,
) -> tp.Array2d:
"""Same as `reduce_grouped_nb` but `reduce_func_nb` must return an array."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(arr[:, group_0_idxs], *args)
out = np.empty((group_0_out.shape[0], len(group_lens)), dtype=group_0_out.dtype)
out[:, 0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
out[:, group] = reduce_func_nb(arr[:, col_idxs], *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
group_map=base_ch.GroupMapSlicer(),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_grouped_to_array_meta_nb(
group_map: tp.GroupMap,
reduce_func_nb: tp.ReduceGroupedToArrayMetaFunc,
*args,
) -> tp.Array2d:
"""Same as `reduce_grouped_meta_nb` but `reduce_func_nb` must return an array."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_0_out = reduce_func_nb(group_0_idxs, 0, *args)
out = np.empty((group_0_out.shape[0], len(group_lens)), dtype=group_0_out.dtype)
out[:, 0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
out[:, group] = reduce_func_nb(col_idxs, group, *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
in_c_order=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_flat_grouped_to_array_nb(
arr: tp.Array2d,
group_map: tp.GroupMap,
in_c_order: bool,
reduce_func_nb: tp.ReduceToArrayFunc,
*args,
) -> tp.Array2d:
"""Same as `reduce_grouped_to_array_nb` but passes flattened array."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
if in_c_order:
group_0_out = reduce_func_nb(arr[:, group_0_idxs].flatten(), *args)
else:
group_0_out = reduce_func_nb(flatten_forder_nb(arr[:, group_0_idxs]), *args)
out = np.empty((group_0_out.shape[0], len(group_lens)), dtype=group_0_out.dtype)
out[:, 0] = group_0_out
for group in prange(1, len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
if in_c_order:
out[:, group] = reduce_func_nb(arr[:, col_idxs].flatten(), *args)
else:
out[:, group] = reduce_func_nb(flatten_forder_nb(arr[:, col_idxs]), *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
squeeze_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def squeeze_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, squeeze_func_nb: tp.ReduceFunc, *args) -> tp.Array2d:
"""Squeeze each group of columns into a single column using `squeeze_func_nb`.
`squeeze_func_nb` must accept the array of group elements at a single row and `*args`. Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_i_0_out = squeeze_func_nb(arr[0][group_0_idxs], *args)
out = np.empty((arr.shape[0], len(group_lens)), dtype=np.asarray(group_i_0_out).dtype)
out[0, 0] = group_i_0_out
for group in prange(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for i in range(arr.shape[0]):
if group == 0 and i == 0:
continue
out[i, group] = squeeze_func_nb(arr[i][col_idxs], *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
n_rows=None,
group_map=base_ch.GroupMapSlicer(),
squeeze_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def squeeze_grouped_meta_nb(
n_rows: int,
group_map: tp.GroupMap,
squeeze_func_nb: tp.GroupSqueezeMetaFunc,
*args,
) -> tp.Array2d:
"""Meta version of `squeeze_grouped_nb`.
`squeeze_func_nb` must accept the row index, the column indices of the group,
the group index, and `*args`. Must return a single value."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
group_0_idxs = group_idxs[group_start_idxs[0] : group_start_idxs[0] + group_lens[0]]
group_i_0_out = squeeze_func_nb(0, group_0_idxs, 0, *args)
out = np.empty((n_rows, len(group_lens)), dtype=np.asarray(group_i_0_out).dtype)
out[0, 0] = group_i_0_out
for group in prange(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for i in range(n_rows):
if group == 0 and i == 0:
continue
out[i, group] = squeeze_func_nb(i, col_idxs, group, *args)
return out
# ############# Flattening ############# #
@register_jitted(cache=True)
def flatten_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, in_c_order: bool) -> tp.Array2d:
"""Flatten each group of columns."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
out = np.full((arr.shape[0] * np.max(group_lens), len(group_lens)), np.nan, dtype=float_)
max_len = np.max(group_lens)
for group in range(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for k in range(group_len):
col = col_idxs[k]
if in_c_order:
out[k::max_len, group] = arr[:, col]
else:
out[k * arr.shape[0] : (k + 1) * arr.shape[0], group] = arr[:, col]
return out
@register_jitted(cache=True)
def flatten_uniform_grouped_nb(arr: tp.Array2d, group_map: tp.GroupMap, in_c_order: bool) -> tp.Array2d:
"""Flatten each group of columns of the same length."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
out = np.empty((arr.shape[0] * np.max(group_lens), len(group_lens)), dtype=arr.dtype)
max_len = np.max(group_lens)
for group in range(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for k in range(group_len):
col = col_idxs[k]
if in_c_order:
out[k::max_len, group] = arr[:, col]
else:
out[k * arr.shape[0] : (k + 1) * arr.shape[0], group] = arr[:, col]
return out
# ############# Proximity ############# #
@register_jitted(tags={"can_parallel"})
def proximity_reduce_nb(
arr: tp.Array2d,
window: int,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array2d:
"""Flatten `window` surrounding rows and columns and reduce them into a single value using `reduce_func_nb`.
`reduce_func_nb` must accept the array and `*args`. Must return a single value."""
out = np.empty_like(arr, dtype=float_)
for i in prange(arr.shape[0]):
for col in range(arr.shape[1]):
from_i = max(0, i - window)
to_i = min(i + window + 1, arr.shape[0])
from_col = max(0, col - window)
to_col = min(col + window + 1, arr.shape[1])
stride_arr = arr[from_i:to_i, from_col:to_col]
out[i, col] = reduce_func_nb(stride_arr.flatten(), *args)
return out
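# A minimal usage sketch of `proximity_reduce_nb`: smooth a 2-dim array by averaging
# each element's neighborhood. With `window=1` the neighborhood spans at most one row
# and one column in every direction (a 3x3 block, clipped at the edges).
def _example_proximity_reduce_nb() -> None:
    """Neighborhood mean with window=1 (sketch)."""
    from numba import njit

    @njit
    def mean_nb(window_arr):
        return np.nanmean(window_arr)

    arr = np.arange(16, dtype=np.float64).reshape(4, 4)
    print(proximity_reduce_nb(arr, 1, mean_nb))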
@register_jitted(tags={"can_parallel"})
def proximity_reduce_meta_nb(
target_shape: tp.Shape,
window: int,
reduce_func_nb: tp.ProximityReduceMetaFunc,
*args,
) -> tp.Array2d:
"""Meta version of `proximity_reduce_nb`.
`reduce_func_nb` must accept the start row index, the end row index, the start column index,
the end column index, and `*args`. Must return a single value."""
out = np.empty(target_shape, dtype=float_)
for i in prange(target_shape[0]):
for col in range(target_shape[1]):
from_i = max(0, i - window)
to_i = min(i + window + 1, target_shape[0])
from_col = max(0, col - window)
to_col = min(col + window + 1, target_shape[1])
out[i, col] = reduce_func_nb(from_i, to_i, from_col, to_col, *args)
return out
# ############# Reducers ############# #
@register_jitted(cache=True)
def nth_reduce_nb(arr: tp.Array1d, n: int) -> float:
"""Get n-th element."""
if (n < 0 and abs(n) > arr.shape[0]) or n >= arr.shape[0]:
raise ValueError("index is out of bounds")
return arr[n]
@register_jitted(cache=True)
def first_reduce_nb(arr: tp.Array1d) -> float:
"""Get first non-NA element."""
if arr.shape[0] == 0:
raise ValueError("index is out of bounds")
for i in range(len(arr)):
if not np.isnan(arr[i]):
return arr[i]
return np.nan
@register_jitted(cache=True)
def last_reduce_nb(arr: tp.Array1d) -> float:
"""Get last non-NA element."""
if arr.shape[0] == 0:
raise ValueError("index is out of bounds")
for i in range(len(arr) - 1, -1, -1):
if not np.isnan(arr[i]):
return arr[i]
return np.nan
@register_jitted(cache=True)
def first_index_reduce_nb(arr: tp.Array1d) -> int:
"""Get index of first non-NA element."""
if arr.shape[0] == 0:
raise ValueError("index is out of bounds")
for i in range(len(arr)):
if not np.isnan(arr[i]):
return i
return -1
@register_jitted(cache=True)
def last_index_reduce_nb(arr: tp.Array1d) -> int:
"""Get index of last non-NA element."""
if arr.shape[0] == 0:
raise ValueError("index is out of bounds")
for i in range(len(arr) - 1, -1, -1):
if not np.isnan(arr[i]):
return i
return -1
@register_jitted(cache=True)
def nth_index_reduce_nb(arr: tp.Array1d, n: int) -> int:
"""Get index of n-th element including NA elements."""
if (n < 0 and abs(n) > arr.shape[0]) or n >= arr.shape[0]:
raise ValueError("index is out of bounds")
if n >= 0:
return n
return arr.shape[0] + n
@register_jitted(cache=True)
def any_reduce_nb(arr: tp.Array1d) -> bool:
"""Get whether any of the elements are True."""
return np.any(arr)
@register_jitted(cache=True)
def all_reduce_nb(arr: tp.Array1d) -> bool:
"""Get whether all of the elements are True."""
return np.all(arr)
@register_jitted(cache=True)
def min_reduce_nb(arr: tp.Array1d) -> float:
"""Get min. Ignores NaN."""
return np.nanmin(arr)
@register_jitted(cache=True)
def max_reduce_nb(arr: tp.Array1d) -> float:
"""Get max. Ignores NaN."""
return np.nanmax(arr)
@register_jitted(cache=True)
def mean_reduce_nb(arr: tp.Array1d) -> float:
"""Get mean. Ignores NaN."""
return np.nanmean(arr)
@register_jitted(cache=True)
def median_reduce_nb(arr: tp.Array1d) -> float:
"""Get median. Ignores NaN."""
return np.nanmedian(arr)
@register_jitted(cache=True)
def std_reduce_nb(arr: tp.Array1d, ddof: int) -> float:
"""Get std. Ignores NaN."""
return nanstd_1d_nb(arr, ddof=ddof)
@register_jitted(cache=True)
def sum_reduce_nb(arr: tp.Array1d) -> float:
"""Get sum. Ignores NaN."""
return np.nansum(arr)
@register_jitted(cache=True)
def prod_reduce_nb(arr: tp.Array1d) -> float:
"""Get product. Ignores NaN."""
return np.nanprod(arr)
@register_jitted(cache=True)
def nonzero_prod_reduce_nb(arr: tp.Array1d) -> float:
"""Get product. Ignores zero and NaN. Default value is zero."""
prod = 0.0
for i in range(len(arr)):
if not np.isnan(arr[i]) and arr[i] != 0:
if prod == 0:
prod = 1.0
prod *= arr[i]
return prod
@register_jitted(cache=True)
def count_reduce_nb(arr: tp.Array1d) -> int:
"""Get count. Ignores NaN."""
return np.sum(~np.isnan(arr))
@register_jitted(cache=True)
def argmin_reduce_nb(arr: tp.Array1d) -> int:
"""Get position of min."""
arr = np.copy(arr)
mask = np.isnan(arr)
if np.all(mask):
raise ValueError("All-NaN slice encountered")
arr[mask] = np.inf
return np.argmin(arr)
@register_jitted(cache=True)
def argmax_reduce_nb(arr: tp.Array1d) -> int:
"""Get position of max."""
arr = np.copy(arr)
mask = np.isnan(arr)
if np.all(mask):
raise ValueError("All-NaN slice encountered")
arr[mask] = -np.inf
return np.argmax(arr)
@register_jitted(cache=True)
def describe_reduce_nb(arr: tp.Array1d, perc: tp.Array1d, ddof: int) -> tp.Array1d:
"""Get descriptive statistics. Ignores NaN.
Numba equivalent to `pd.Series(arr).describe(percentiles=perc)`."""
arr = arr[~np.isnan(arr)]
out = np.empty(5 + len(perc), dtype=float_)
out[0] = len(arr)
if len(arr) > 0:
out[1] = np.mean(arr)
out[2] = nanstd_1d_nb(arr, ddof=ddof)
out[3] = np.min(arr)
out[4:-1] = np.percentile(arr, perc * 100)
out[4 + len(perc)] = np.max(arr)
else:
out[1:] = np.nan
return out
@register_jitted(cache=True)
def cov_reduce_grouped_meta_nb(
group_idxs: tp.GroupIdxs,
group: int,
arr1: tp.Array2d,
arr2: tp.Array2d,
ddof: int,
) -> float:
"""Get correlation coefficient. Ignores NaN."""
return nancov_1d_nb(arr1[:, group_idxs].flatten(), arr2[:, group_idxs].flatten(), ddof=ddof)
@register_jitted(cache=True)
def corr_reduce_grouped_meta_nb(group_idxs: tp.GroupIdxs, group: int, arr1: tp.Array2d, arr2: tp.Array2d) -> float:
"""Get correlation coefficient. Ignores NaN."""
return nancorr_1d_nb(arr1[:, group_idxs].flatten(), arr2[:, group_idxs].flatten())
@register_jitted(cache=True)
def wmean_range_reduce_meta_nb(from_i: int, to_i: int, col: int, arr1: tp.Array2d, arr2: tp.Array2d) -> float:
"""Get the weighted average."""
nom_cumsum = 0
denum_cumsum = 0
for i in range(from_i, to_i):
nom_cumsum += arr1[i, col] * arr2[i, col]
denum_cumsum += arr2[i, col]
if denum_cumsum == 0:
return np.nan
return nom_cumsum / denum_cumsum
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for base operations."""
import numpy as np
from numba import prange
from numba.core.types import Type, Omitted
from numba.extending import overload
from numba.np.numpy_support import as_dtype
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
def _select_indices_1d_nb(arr, indices, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, indices, fill_value):
out = np.empty(indices.shape, dtype=dtype)
for i in range(indices.shape[0]):
if 0 <= indices[i] <= arr.shape[0] - 1:
out[i] = arr[indices[i]]
else:
out[i] = fill_value
return out
if not nb_enabled:
return impl(arr, indices, fill_value)
return impl
overload(_select_indices_1d_nb)(_select_indices_1d_nb)
@register_jitted(cache=True)
def select_indices_1d_nb(arr: tp.Array1d, indices: tp.Array1d, fill_value: tp.Scalar) -> tp.Array1d:
"""Set each element to a value by boolean mask."""
return _select_indices_1d_nb(arr, indices, fill_value)
def _select_indices_nb(arr, indices, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, indices, fill_value):
out = np.empty(indices.shape, dtype=dtype)
for col in range(indices.shape[1]):
for i in range(indices.shape[0]):
if 0 <= indices[i, col] <= arr.shape[0] - 1:
out[i, col] = arr[indices[i, col], col]
else:
out[i, col] = fill_value
return out
if not nb_enabled:
return impl(arr, indices, fill_value)
return impl
overload(_select_indices_nb)(_select_indices_nb)
@register_jitted(cache=True)
def select_indices_nb(arr: tp.Array2d, indices: tp.Array2d, fill_value: tp.Scalar) -> tp.Array2d:
"""2-dim version of `select_indices_1d_nb`."""
return _select_indices_nb(arr, indices, fill_value)
@register_jitted(cache=True)
def shuffle_1d_nb(arr: tp.Array1d, seed: tp.Optional[int] = None) -> tp.Array1d:
"""Shuffle each column in the array.
Specify `seed` to make output deterministic."""
if seed is not None:
np.random.seed(seed)
return np.random.permutation(arr)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), seed=None),
merge_func="column_stack",
)
@register_jitted(cache=True)
def shuffle_nb(arr: tp.Array2d, seed: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `shuffle_1d_nb`."""
if seed is not None:
np.random.seed(seed)
out = np.empty_like(arr, dtype=arr.dtype)
for col in range(arr.shape[1]):
out[:, col] = np.random.permutation(arr[:, col])
return out
def _set_by_mask_1d_nb(arr, mask, value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(value)
else:
a_dtype = arr.dtype
value_dtype = np.array(value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, mask, value):
out = arr.astype(dtype)
out[mask] = value
return out
if not nb_enabled:
return impl(arr, mask, value)
return impl
overload(_set_by_mask_1d_nb)(_set_by_mask_1d_nb)
@register_jitted(cache=True)
def set_by_mask_1d_nb(arr: tp.Array1d, mask: tp.Array1d, value: tp.Scalar) -> tp.Array1d:
"""Set each element to a value by boolean mask."""
return _set_by_mask_1d_nb(arr, mask, value)
def _set_by_mask_nb(arr, mask, value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(value)
else:
a_dtype = arr.dtype
value_dtype = np.array(value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, mask, value):
out = arr.astype(dtype)
for col in range(arr.shape[1]):
out[mask[:, col], col] = value
return out
if not nb_enabled:
return impl(arr, mask, value)
return impl
overload(_set_by_mask_nb)(_set_by_mask_nb)
@register_jitted(cache=True)
def set_by_mask_nb(arr: tp.Array2d, mask: tp.Array2d, value: tp.Scalar) -> tp.Array2d:
"""2-dim version of `set_by_mask_1d_nb`."""
return _set_by_mask_nb(arr, mask, value)
def _set_by_mask_mult_1d_nb(arr, mask, values):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(values.dtype)
else:
a_dtype = arr.dtype
value_dtype = values.dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, mask, values):
out = arr.astype(dtype)
out[mask] = values[mask]
return out
if not nb_enabled:
return impl(arr, mask, values)
return impl
overload(_set_by_mask_mult_1d_nb)(_set_by_mask_mult_1d_nb)
@register_jitted(cache=True)
def set_by_mask_mult_1d_nb(arr: tp.Array1d, mask: tp.Array1d, values: tp.Array1d) -> tp.Array1d:
"""Set each element in one array to the corresponding element in another by boolean mask.
`values` must be of the same shape as the array."""
return _set_by_mask_mult_1d_nb(arr, mask, values)
def _set_by_mask_mult_nb(arr, mask, values):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(values.dtype)
else:
a_dtype = arr.dtype
value_dtype = values.dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(arr, mask, values):
out = arr.astype(dtype)
for col in range(arr.shape[1]):
out[mask[:, col], col] = values[mask[:, col], col]
return out
if not nb_enabled:
return impl(arr, mask, values)
return impl
overload(_set_by_mask_mult_nb)(_set_by_mask_mult_nb)
@register_jitted(cache=True)
def set_by_mask_mult_nb(arr: tp.Array2d, mask: tp.Array2d, values: tp.Array2d) -> tp.Array2d:
"""2-dim version of `set_by_mask_mult_1d_nb`."""
return _set_by_mask_mult_nb(arr, mask, values)
@register_jitted(cache=True)
def first_valid_index_1d_nb(arr: tp.Array1d, check_inf: bool = True) -> int:
"""Get the index of the first valid value."""
for i in range(arr.shape[0]):
if not np.isnan(arr[i]) and (not check_inf or not np.isinf(arr[i])):
return i
return -1
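# Usage sketch (illustrative, hand-computed; not an official doctest). With `check_inf=True`
# (default), infinities are treated as invalid; with `check_inf=False`, only NaNs are skipped:
# >>> first_valid_index_1d_nb(np.array([np.nan, np.inf, 2.0]))
# 2
# >>> first_valid_index_1d_nb(np.array([np.nan, np.inf, 2.0]), check_inf=False)
# 1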
@register_jitted(cache=True)
def first_valid_index_nb(arr, check_inf: bool = True):
"""2-dim version of `first_valid_index_1d_nb`."""
out = np.empty(arr.shape[1], dtype=int_)
for col in range(arr.shape[1]):
out[col] = first_valid_index_1d_nb(arr[:, col], check_inf=check_inf)
return out
@register_jitted(cache=True)
def last_valid_index_1d_nb(arr: tp.Array1d, check_inf: bool = True) -> int:
"""Get the index of the last valid value."""
for i in range(arr.shape[0] - 1, -1, -1):
if not np.isnan(arr[i]) and (not check_inf or not np.isinf(arr[i])):
return i
return -1
@register_jitted(cache=True)
def last_valid_index_nb(arr, check_inf: bool = True):
"""2-dim version of `last_valid_index_1d_nb`."""
out = np.empty(arr.shape[1], dtype=int_)
for col in range(arr.shape[1]):
out[col] = last_valid_index_1d_nb(arr[:, col], check_inf=check_inf)
return out
@register_jitted(cache=True)
def fillna_1d_nb(arr: tp.Array1d, value: tp.Scalar) -> tp.Array1d:
"""Replace NaNs with value.
Numba equivalent to `pd.Series(arr).fillna(value)`."""
return set_by_mask_1d_nb(arr, np.isnan(arr), value)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), value=None),
merge_func="column_stack",
)
@register_jitted(cache=True)
def fillna_nb(arr: tp.Array2d, value: tp.Scalar) -> tp.Array2d:
"""2-dim version of `fillna_1d_nb`."""
return set_by_mask_nb(arr, np.isnan(arr), value)
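# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).fillna(0.0)`:
# >>> fillna_1d_nb(np.array([1.0, np.nan, 3.0]), 0.0)
# array([1., 0., 3.])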
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def fbfill_nb(arr: tp.Array2d) -> tp.Array2d:
"""Forward and backward fill NaN values.
!!! note
If there are no NaN values (or no values at all), returns `arr` as-is."""
if arr.size == 0:
return arr
need_fbfill = False
for col in range(arr.shape[1]):
for i in range(arr.shape[0]):
if np.isnan(arr[i, col]):
need_fbfill = True
break
if need_fbfill:
break
if not need_fbfill:
return arr
out = np.empty_like(arr)
for col in prange(arr.shape[1]):
last_valid = np.nan
for i in range(arr.shape[0]):
if not np.isnan(arr[i, col]):
last_valid = arr[i, col]
out[i, col] = last_valid
if np.isnan(out[0, col]):
last_valid = np.nan
for i in range(arr.shape[0] - 1, -1, -1):
if not np.isnan(arr[i, col]):
last_valid = arr[i, col]
if np.isnan(out[i, col]):
out[i, col] = last_valid
return out
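# Usage sketch (illustrative, hand-computed; not an official doctest). Interior NaNs are
# forward-filled, and a leading NaN is then backward-filled from the first valid value:
# >>> fbfill_nb(np.array([[np.nan], [2.0], [np.nan], [4.0]]))
# array([[2.],
#        [2.],
#        [2.],
#        [4.]])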
def _bshift_1d_nb(arr, n, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
if isinstance(fill_value, Omitted):
fill_value_dtype = np.asarray(fill_value.value).dtype
else:
fill_value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, fill_value_dtype)
def impl(arr, n, fill_value):
out = np.empty(arr.shape[0], dtype=dtype)
for i in range(out.shape[0]):
if i + n <= out.shape[0] - 1:
out[i] = arr[i + n]
else:
out[i] = fill_value
return out
if not nb_enabled:
return impl(arr, n, fill_value)
return impl
overload(_bshift_1d_nb)(_bshift_1d_nb)
@register_jitted(cache=True)
def bshift_1d_nb(arr: tp.Array1d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array1d:
"""Shift backward by `n` positions.
Numba equivalent to `pd.Series(arr).shift(-n)`.
!!! warning
This operation looks ahead."""
return _bshift_1d_nb(arr, n, fill_value)
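# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).shift(-1)`:
# >>> bshift_1d_nb(np.array([1.0, 2.0, 3.0]), n=1)
# array([ 2.,  3., nan])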
def _bshift_nb(arr, n, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
if isinstance(fill_value, Omitted):
fill_value_dtype = np.asarray(fill_value.value).dtype
else:
fill_value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, fill_value_dtype)
def impl(arr, n, fill_value):
out = np.empty_like(arr, dtype=dtype)
for col in range(arr.shape[1]):
out[:, col] = bshift_1d_nb(arr[:, col], n=n, fill_value=fill_value)
return out
if not nb_enabled:
return impl(arr, n, fill_value)
return impl
overload(_bshift_nb)(_bshift_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None, fill_value=None),
merge_func="column_stack",
)
@register_jitted(cache=True)
def bshift_nb(arr: tp.Array2d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array2d:
"""2-dim version of `bshift_1d_nb`."""
return _bshift_nb(arr, n, fill_value)
def _fshift_1d_nb(arr, n, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
if isinstance(fill_value, Omitted):
fill_value_dtype = np.asarray(fill_value.value).dtype
else:
fill_value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, fill_value_dtype)
def impl(arr, n, fill_value):
out = np.empty(arr.shape[0], dtype=dtype)
for i in range(out.shape[0]):
if i - n >= 0:
out[i] = arr[i - n]
else:
out[i] = fill_value
return out
if not nb_enabled:
return impl(arr, n, fill_value)
return impl
overload(_fshift_1d_nb)(_fshift_1d_nb)
@register_jitted(cache=True)
def fshift_1d_nb(arr: tp.Array1d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array1d:
"""Shift forward by `n` positions.
Numba equivalent to `pd.Series(arr).shift(n)`."""
return _fshift_1d_nb(arr, n, fill_value)
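# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).shift(1)`:
# >>> fshift_1d_nb(np.array([1.0, 2.0, 3.0]), n=1)
# array([nan,  1.,  2.])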
def _fshift_nb(arr, n, fill_value):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
if isinstance(fill_value, Omitted):
fill_value_dtype = np.asarray(fill_value.value).dtype
else:
fill_value_dtype = as_dtype(fill_value)
else:
a_dtype = arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(a_dtype, fill_value_dtype)
def impl(arr, n, fill_value):
out = np.empty_like(arr, dtype=dtype)
for col in range(arr.shape[1]):
out[:, col] = fshift_1d_nb(arr[:, col], n=n, fill_value=fill_value)
return out
if not nb_enabled:
return impl(arr, n, fill_value)
return impl
overload(_fshift_nb)(_fshift_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None, fill_value=None),
merge_func="column_stack",
)
@register_jitted(cache=True)
def fshift_nb(arr: tp.Array2d, n: int = 1, fill_value: tp.Scalar = np.nan) -> tp.Array2d:
"""2-dim version of `fshift_1d_nb`."""
return _fshift_nb(arr, n, fill_value)
@register_jitted(cache=True)
def diff_1d_nb(arr: tp.Array1d, n: int = 1) -> tp.Array1d:
"""Compute the 1-th discrete difference.
Numba equivalent to `pd.Series(arr).diff()`."""
out = np.empty(arr.shape[0], dtype=float_)
for i in range(out.shape[0]):
if i - n >= 0:
out[i] = arr[i] - arr[i - n]
else:
out[i] = np.nan
return out
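# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).diff()`:
# >>> diff_1d_nb(np.array([1.0, 2.0, 4.0]))
# array([nan,  1.,  2.])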
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def diff_nb(arr: tp.Array2d, n: int = 1) -> tp.Array2d:
"""2-dim version of `diff_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = diff_1d_nb(arr[:, col], n=n)
return out
@register_jitted(cache=True)
def pct_change_1d_nb(arr: tp.Array1d, n: int = 1) -> tp.Array1d:
"""Compute the percentage change.
Numba equivalent to `pd.Series(arr).pct_change()`."""
out = np.empty(arr.shape[0], dtype=float_)
for i in range(out.shape[0]):
if i - n >= 0:
out[i] = (arr[i] - arr[i - n]) / arr[i - n]
else:
out[i] = np.nan
return out
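# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).pct_change()`:
# >>> pct_change_1d_nb(np.array([1.0, 2.0, 3.0]))
# array([nan, 1. , 0.5])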
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), n=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pct_change_nb(arr: tp.Array2d, n: int = 1) -> tp.Array2d:
"""2-dim version of `pct_change_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = pct_change_1d_nb(arr[:, col], n=n)
return out
@register_jitted(cache=True)
def bfill_1d_nb(arr: tp.Array1d) -> tp.Array1d:
"""Fill NaNs by propagating first valid observation backward.
Numba equivalent to `pd.Series(arr).fillna(method='bfill')`.
!!! warning
This operation looks ahead."""
out = np.empty_like(arr, dtype=arr.dtype)
lastval = arr[-1]
for i in range(arr.shape[0] - 1, -1, -1):
if np.isnan(arr[i]):
out[i] = lastval
else:
lastval = out[i] = arr[i]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bfill_nb(arr: tp.Array2d) -> tp.Array2d:
"""2-dim version of `bfill_1d_nb`."""
out = np.empty_like(arr, dtype=arr.dtype)
for col in prange(arr.shape[1]):
out[:, col] = bfill_1d_nb(arr[:, col])
return out
@register_jitted(cache=True)
def ffill_1d_nb(arr: tp.Array1d) -> tp.Array1d:
"""Fill NaNs by propagating last valid observation forward.
Numba equivalent to `pd.Series(arr).fillna(method='ffill')`."""
out = np.empty_like(arr, dtype=arr.dtype)
lastval = arr[0]
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
out[i] = lastval
else:
lastval = out[i] = arr[i]
return out
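# Usage sketch (illustrative, hand-computed; not an official doctest). The last valid
# observation is carried forward; a leading NaN stays NaN because there is nothing to
# propagate yet:
# >>> ffill_1d_nb(np.array([np.nan, 2.0, np.nan, 4.0]))
# array([nan,  2.,  2.,  4.])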
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ffill_nb(arr: tp.Array2d) -> tp.Array2d:
"""2-dim version of `ffill_1d_nb`."""
out = np.empty_like(arr, dtype=arr.dtype)
for col in prange(arr.shape[1]):
out[:, col] = ffill_1d_nb(arr[:, col])
return out
def _nanprod_nb(arr):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
else:
a_dtype = arr.dtype
dtype = np.promote_types(a_dtype, int)
def impl(arr):
out = np.empty(arr.shape[1], dtype=dtype)
for col in prange(arr.shape[1]):
out[col] = np.nanprod(arr[:, col])
return out
if not nb_enabled:
return impl(arr)
return impl
overload(_nanprod_nb)(_nanprod_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True)
def nanprod_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nanprod` along axis 0."""
return _nanprod_nb(arr)
def _nancumsum_nb(arr):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
else:
a_dtype = arr.dtype
dtype = np.promote_types(a_dtype, int)
def impl(arr):
out = np.empty(arr.shape, dtype=dtype)
for col in prange(arr.shape[1]):
out[:, col] = np.nancumsum(arr[:, col])
return out
if not nb_enabled:
return impl(arr)
return impl
overload(_nancumsum_nb)(_nancumsum_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True)
def nancumsum_nb(arr: tp.Array2d) -> tp.Array2d:
"""Numba equivalent of `np.nancumsum` along axis 0."""
return _nancumsum_nb(arr)
def _nancumprod_nb(arr):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
else:
a_dtype = arr.dtype
dtype = np.promote_types(a_dtype, int)
def impl(arr):
out = np.empty(arr.shape, dtype=dtype)
for col in prange(arr.shape[1]):
out[:, col] = np.nancumprod(arr[:, col])
return out
if not nb_enabled:
return impl(arr)
return impl
overload(_nancumprod_nb)(_nancumprod_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True)
def nancumprod_nb(arr: tp.Array2d) -> tp.Array2d:
"""Numba equivalent of `np.nancumprod` along axis 0."""
return _nancumprod_nb(arr)
def _nansum_nb(arr):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
else:
a_dtype = arr.dtype
dtype = np.promote_types(a_dtype, int)
def impl(arr):
out = np.empty(arr.shape[1], dtype=dtype)
for col in prange(arr.shape[1]):
out[col] = np.nansum(arr[:, col])
return out
if not nb_enabled:
return impl(arr)
return impl
overload(_nansum_nb)(_nansum_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True)
def nansum_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nansum` along axis 0."""
return _nansum_nb(arr)
@register_jitted(cache=True)
def nancnt_1d_nb(arr: tp.Array1d) -> int:
"""Compute count while ignoring NaNs and not allocating any arrays."""
cnt = 0
for i in range(arr.shape[0]):
if not np.isnan(arr[i]):
cnt += 1
return cnt
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nancnt_nb(arr: tp.Array2d) -> tp.Array1d:
"""2-dim version of `nancnt_1d_nb`."""
out = np.empty(arr.shape[1], dtype=int_)
for col in prange(arr.shape[1]):
out[col] = nancnt_1d_nb(arr[:, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nanmin_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nanmin` along axis 0."""
out = np.empty(arr.shape[1], dtype=arr.dtype)
for col in prange(arr.shape[1]):
out[col] = np.nanmin(arr[:, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nanmax_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nanmax` along axis 0."""
out = np.empty(arr.shape[1], dtype=arr.dtype)
for col in prange(arr.shape[1]):
out[col] = np.nanmax(arr[:, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nanmean_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nanmean` along axis 0."""
out = np.empty(arr.shape[1], dtype=float_)
for col in prange(arr.shape[1]):
out[col] = np.nanmean(arr[:, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nanmedian_nb(arr: tp.Array2d) -> tp.Array1d:
"""Numba equivalent of `np.nanmedian` along axis 0."""
out = np.empty(arr.shape[1], dtype=float_)
for col in prange(arr.shape[1]):
out[col] = np.nanmedian(arr[:, col])
return out
@register_jitted(cache=True)
def nanpercentile_noarr_1d_nb(arr: tp.Array1d, q: float) -> float:
"""Numba equivalent of `np.nanpercentile` that does not allocate any arrays.
!!! note
Has worst case time complexity of O(N^2), which makes it much slower than `np.nanpercentile`,
but still faster if used in rolling calculations, especially for `q` near 0 and 100."""
if q < 0:
q = 0
elif q > 100:
q = 100
do_min = q < 50
if not do_min:
q = 100 - q
cnt = arr.shape[0]
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
cnt -= 1
if cnt == 0:
return np.nan
nth_float = q / 100 * (cnt - 1)
if nth_float % 1 == 0:
nth1 = nth2 = int(nth_float)
else:
nth1 = int(nth_float)
nth2 = nth1 + 1
found1 = np.nan
found2 = np.nan
k = 0
if do_min:
prev_val = -np.inf
else:
prev_val = np.inf
while True:
n_same = 0
if do_min:
curr_val = np.inf
for i in range(arr.shape[0]):
if not np.isnan(arr[i]):
if arr[i] > prev_val:
if arr[i] < curr_val:
curr_val = arr[i]
n_same = 0
if arr[i] == curr_val:
n_same += 1
else:
curr_val = -np.inf
for i in range(arr.shape[0]):
if not np.isnan(arr[i]):
if arr[i] < prev_val:
if arr[i] > curr_val:
curr_val = arr[i]
n_same = 0
if arr[i] == curr_val:
n_same += 1
prev_val = curr_val
k += n_same
if np.isnan(found1) and k >= nth1 + 1:
found1 = curr_val
if np.isnan(found2) and k >= nth2 + 1:
found2 = curr_val
break
if found1 == found2:
return found1
factor = (nth_float - nth1) / (nth2 - nth1)
return factor * (found2 - found1) + found1
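# Usage sketch (illustrative, hand-computed; not an official doctest), matching
# `np.nanpercentile` with linear interpolation between the two nearest order statistics:
# >>> nanpercentile_noarr_1d_nb(np.array([3.0, np.nan, 1.0, 2.0]), 25)
# 1.5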
@register_jitted(cache=True)
def nanpartition_mean_noarr_1d_nb(arr: tp.Array1d, q: float) -> float:
"""Average of `np.partition` that ignores NaN values and does not allocate any arrays.
!!! note
Has worst case time complexity of O(N^2), which makes it much slower than `np.partition`,
but still faster if used in rolling calculations, especially for `q` near 0."""
if q < 0:
q = 0
elif q > 100:
q = 100
cnt = arr.shape[0]
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
cnt -= 1
if cnt == 0:
return np.nan
nth = int(q / 100 * (cnt - 1))
prev_val = -np.inf
partition_sum = 0.0
partition_cnt = 0
k = 0
while True:
n_same = 0
curr_val = np.inf
for i in range(arr.shape[0]):
if not np.isnan(arr[i]):
if arr[i] > prev_val:
if arr[i] < curr_val:
curr_val = arr[i]
n_same = 0
if arr[i] == curr_val:
n_same += 1
if k + n_same >= nth + 1:
partition_sum += (nth + 1 - k) * curr_val
partition_cnt += nth + 1 - k
break
else:
partition_sum += n_same * curr_val
partition_cnt += n_same
prev_val = curr_val
k += n_same
return partition_sum / partition_cnt
@register_jitted(cache=True)
def nanvar_1d_nb(arr: tp.Array1d, ddof: int = 0) -> float:
"""Numba equivalent of `np.nanvar` that does not allocate any arrays."""
cnt = arr.shape[0]
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
cnt -= 1
rcount = max(cnt - ddof, 0)
if rcount == 0:
return np.nan
out = 0.0
a_mean = np.nanmean(arr)
for i in range(len(arr)):
if not np.isnan(arr[i]):
out += abs(arr[i] - a_mean) ** 2
return out / rcount
@register_jitted(cache=True)
def nanstd_1d_nb(arr: tp.Array1d, ddof: int = 0) -> float:
"""Numba equivalent of `np.nanstd`."""
return np.sqrt(nanvar_1d_nb(arr, ddof=ddof))
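# Usage sketch (illustrative, hand-computed; not an official doctest), matching
# `np.nanstd(arr, ddof=1)`:
# >>> nanstd_1d_nb(np.array([1.0, 2.0, 3.0, np.nan]), ddof=1)
# 1.0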
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), ddof=None),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nanstd_nb(arr: tp.Array2d, ddof: int = 0) -> tp.Array1d:
"""2-dim version of `nanstd_1d_nb`."""
out = np.empty(arr.shape[1], dtype=float_)
for col in prange(arr.shape[1]):
out[col] = nanstd_1d_nb(arr[:, col], ddof=ddof)
return out
@register_jitted(cache=True)
def nancov_1d_nb(arr1: tp.Array1d, arr2: tp.Array1d, ddof: int = 0) -> float:
"""Numba equivalent of `np.cov` that ignores NaN values."""
if len(arr1) != len(arr2):
raise ValueError("Arrays must have the same length")
arr1_sum = 0.0
arr2_sum = 0.0
arr12_sumprod = 0.0
k = 0
for i in range(arr1.shape[0]):
if not np.isnan(arr1[i]) and not np.isnan(arr2[i]):
arr1_sum += arr1[i]
arr2_sum += arr2[i]
arr12_sumprod += arr1[i] * arr2[i]
k += 1
if k - ddof <= 0:
return np.nan
arr1_mean = arr1_sum / k
arr2_mean = arr2_sum / k
return (arr12_sumprod - k * arr1_mean * arr2_mean) / (k - ddof)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1), ddof=None),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nancov_nb(arr1: tp.Array2d, arr2: tp.Array2d, ddof: int = 0) -> tp.Array1d:
"""2-dim version of `nancov_1d_nb`."""
out = np.empty(arr1.shape[1], dtype=float_)
for col in prange(arr1.shape[1]):
out[col] = nancov_1d_nb(arr1[:, col], arr2[:, col], ddof=ddof)
return out
@register_jitted(cache=True)
def nancorr_1d_nb(arr1: tp.Array1d, arr2: tp.Array1d) -> float:
"""Numba equivalent of `np.corrcoef` that ignores NaN values.
Numerically stable."""
if len(arr1) != len(arr2):
raise ValueError("Arrays must have the same length")
arr1_sum = 0.0
arr2_sum = 0.0
arr1_sumsq = 0.0
arr2_sumsq = 0.0
arr12_sumprod = 0.0
k = 0
for i in range(arr1.shape[0]):
if not np.isnan(arr1[i]) and not np.isnan(arr2[i]):
arr1_sum += arr1[i]
arr2_sum += arr2[i]
arr1_sumsq += float(arr1[i]) ** 2
arr2_sumsq += float(arr2[i]) ** 2
arr12_sumprod += arr1[i] * arr2[i]
k += 1
if k == 0:
return np.nan
arr1_mean = arr1_sum / k
arr2_mean = arr2_sum / k
arr1_meansq = arr1_sumsq / k
arr2_meansq = arr2_sumsq / k
arr12_meanprod = arr12_sumprod / k
num = arr12_meanprod - arr1_mean * arr2_mean
denom = np.sqrt((arr1_meansq - arr1_mean**2) * (arr2_meansq - arr2_mean**2))
if denom == 0:
return np.nan
return num / denom
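# Usage sketch (illustrative, hand-computed; not an official doctest). Perfectly linearly
# related inputs yield a correlation of 1.0:
# >>> nancorr_1d_nb(np.array([1.0, 2.0, 3.0, 4.0]), np.array([2.0, 4.0, 6.0, 8.0]))
# 1.0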
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nancorr_nb(arr1: tp.Array2d, arr2: tp.Array2d) -> tp.Array1d:
"""2-dim version of `nancorr_1d_nb`."""
out = np.empty(arr1.shape[1], dtype=float_)
for col in prange(arr1.shape[1]):
out[col] = nancorr_1d_nb(arr1[:, col], arr2[:, col])
return out
@register_jitted(cache=True)
def rank_1d_nb(arr: tp.Array1d, argsorted: tp.Optional[tp.Array1d] = None, pct: bool = False) -> tp.Array1d:
"""Compute numerical data ranks.
Numba equivalent to `pd.Series(arr).rank(pct=pct)`."""
if argsorted is None:
argsorted = np.argsort(arr)
out = np.empty_like(arr, dtype=float_)
rank_sum = 0
rank_cnt = 0
nan_cnt = 0
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
nan_cnt += 1
if nan_cnt == arr.shape[0]:
out[:] = np.nan
return out
valid_cnt = out.shape[0] - nan_cnt
for i in range(argsorted.shape[0]):
rank = i + 1
if np.isnan(arr[argsorted[i]]):
out[argsorted[i]] = np.nan
elif i < out.shape[0] - 1 and arr[argsorted[i]] == arr[argsorted[i + 1]]:
rank_sum += rank
rank_cnt += 1
if pct:
v = rank / valid_cnt
else:
v = rank
out[argsorted[i]] = v
elif rank_sum > 0:
rank_sum += rank
rank_cnt += 1
if pct:
v = rank_sum / rank_cnt / valid_cnt
else:
v = rank_sum / rank_cnt
out[argsorted[i - rank_cnt + 1 : i + 1]] = v
rank_sum = 0
rank_cnt = 0
else:
if pct:
v = rank / valid_cnt
else:
v = rank
out[argsorted[i]] = v
return out
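# Usage sketch (illustrative, hand-computed; not an official doctest), mirroring
# `pd.Series(arr).rank()` with ties receiving their average rank:
# >>> rank_1d_nb(np.array([3.0, 1.0, 2.0, 2.0]))
# array([4. , 1. , 2.5, 2.5])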
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), argsorted=ch.ArraySlicer(axis=1), pct=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rank_nb(arr: tp.Array2d, argsorted: tp.Optional[tp.Array2d] = None, pct: bool = False) -> tp.Array2d:
"""2-dim version of `rank_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
if argsorted is None:
out[:, col] = rank_1d_nb(arr[:, col], argsorted=None, pct=pct)
else:
out[:, col] = rank_1d_nb(arr[:, col], argsorted=argsorted[:, col], pct=pct)
return out
@register_jitted(cache=True)
def polyfit_1d_nb(x: tp.Array1d, y: tp.Array1d, deg: int, stabilize: bool = False) -> tp.Array1d:
"""Compute the least squares polynomial fit."""
if stabilize:
mat_ = np.ones(shape=(x.shape[0], deg + 1))
mat_[:, 1] = x
if deg > 1:
for n in range(2, deg + 1):
mat_[:, n] = mat_[:, n - 1] * x
scale_vect = np.empty((deg + 1,), dtype=float_)
for n in range(0, deg + 1):
col_norm = np.linalg.norm(mat_[:, n])
scale_vect[n] = col_norm
mat_[:, n] /= col_norm
det_ = np.linalg.lstsq(mat_, y, rcond=-1)[0] / scale_vect
else:
mat_ = np.zeros(shape=(x.shape[0], deg + 1))
const = np.ones_like(x)
mat_[:, 0] = const
mat_[:, 1] = x
if deg > 1:
for n in range(2, deg + 1):
mat_[:, n] = x**n
det_ = np.linalg.lstsq(mat_, y, rcond=-1)[0]
return det_[::-1]
@register_jitted(cache=True)
def fir_filter_1d_nb(b: tp.Array1d, x: tp.Array1d) -> tp.Array1d:
"""Filter data along one-dimension with an FIR filter."""
n = len(x)
m = len(b)
y = np.zeros(n)
for i in range(n):
for j in range(m):
if i - j >= 0:
y[i] += b[j] * x[i - j]
return y
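# Usage sketch (illustrative, hand-computed; not an official doctest). A two-tap
# moving-average filter:
# >>> fir_filter_1d_nb(np.array([0.5, 0.5]), np.array([1.0, 2.0, 3.0, 4.0]))
# array([0.5, 1.5, 2.5, 3.5])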
# ############# Value counts ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="codes", axis=1),
arg_take_spec=dict(
codes=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
n_uniques=None,
group_map=base_ch.GroupMapSlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def value_counts_nb(codes: tp.Array2d, n_uniques: int, group_map: tp.GroupMap) -> tp.Array2d:
"""Compute value counts per column/group."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
out = np.full((n_uniques, group_lens.shape[0]), 0, dtype=int_)
for group in prange(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for k in range(group_len):
col = col_idxs[k]
for i in range(codes.shape[0]):
out[codes[i, col], group] += 1
return out
@register_jitted(cache=True)
def value_counts_1d_nb(codes: tp.Array1d, n_uniques: int) -> tp.Array1d:
"""Compute value counts."""
out = np.full(n_uniques, 0, dtype=int_)
for i in range(codes.shape[0]):
out[codes[i]] += 1
return out
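# Usage sketch (illustrative, hand-computed; not an official doctest). `codes` must contain
# integer codes in the range [0, n_uniques):
# >>> value_counts_1d_nb(np.array([0, 1, 1, 2, 1]), 3)
# array([1, 3, 1])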
@register_chunkable(
size=ch.ArraySizer(arg_query="codes", axis=0),
arg_take_spec=dict(codes=ch.ArraySlicer(axis=0), n_uniques=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def value_counts_per_row_nb(codes: tp.Array2d, n_uniques: int) -> tp.Array2d:
"""Compute value counts per row."""
out = np.empty((n_uniques, codes.shape[0]), dtype=int_)
for i in prange(codes.shape[0]):
out[:, i] = value_counts_1d_nb(codes[i, :], n_uniques)
return out
# ############# Repartitioning ############# #
@register_jitted(cache=True)
def repartition_nb(arr: tp.Array2d, counts: tp.Array1d) -> tp.Array1d:
"""Repartition a 2-dimensional array into a 1-dimensional by removing empty elements."""
if arr.shape[0] == 0:
return arr.flatten()
out = np.empty(np.sum(counts), dtype=arr.dtype)
j = 0
for col in range(counts.shape[0]):
out[j : j + counts[col]] = arr[: counts[col], col]
j += counts[col]
return out
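# Usage sketch (illustrative, hand-computed; not an official doctest). Each column contributes
# its first `counts[col]` elements to the flattened output:
# >>> arr = np.array([[1.0, 4.0], [2.0, 0.0], [3.0, 0.0]])
# >>> repartition_nb(arr, np.array([3, 1]))
# array([1., 2., 3., 4.])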
# ############# Crossover ############# #
@register_jitted(cache=True)
def crossed_above_1d_nb(arr1: tp.Array1d, arr2: tp.FlexArray1dLike, wait: int = 0, dropna: bool = False) -> tp.Array1d:
"""Get the crossover of the first array going above the second array.
If `dropna` is True, produces the same results as if all rows with at least one NaN were dropped."""
arr2_ = to_1d_array_nb(np.asarray(arr2))
out = np.empty(arr1.shape, dtype=np.bool_)
was_below = False
confirmed = 0
for i in range(arr1.shape[0]):
_arr1 = arr1[i]
_arr2 = flex_select_1d_nb(arr2_, i)
if np.isnan(_arr1) or np.isnan(_arr2):
if not dropna:
was_below = False
confirmed = 0
out[i] = False
elif _arr1 > _arr2:
if was_below:
confirmed += 1
out[i] = confirmed == wait + 1
else:
out[i] = False
elif _arr1 == _arr2:
if confirmed > 0:
was_below = False
confirmed = 0
out[i] = False
elif _arr1 < _arr2:
confirmed = 0
was_below = True
out[i] = False
return out
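# Usage sketch (illustrative, hand-computed; not an official doctest). The signal fires only
# once the first array has been strictly below and then moves strictly above the second:
# >>> crossed_above_1d_nb(np.array([1.0, 2.0, 3.0]), np.array([2.0, 2.0, 2.0]))
# array([False, False,  True])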
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(
arr1=ch.ArraySlicer(axis=1),
arr2=base_ch.FlexArraySlicer(axis=1),
wait=None,
dropna=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def crossed_above_nb(arr1: tp.Array2d, arr2: tp.FlexArray2dLike, wait: int = 0, dropna: bool = False) -> tp.Array2d:
"""2-dim version of `crossed_above_1d_nb`."""
arr2_ = to_2d_array_nb(np.asarray(arr2))
out = np.empty(arr1.shape, dtype=np.bool_)
for col in prange(arr1.shape[1]):
_arr2 = flex_select_col_nb(arr2_, col)
out[:, col] = crossed_above_1d_nb(arr1[:, col], _arr2, wait=wait, dropna=dropna)
return out
@register_jitted(cache=True)
def crossed_below_1d_nb(arr1: tp.Array1d, arr2: tp.FlexArray1dLike, wait: int = 0, dropna: bool = False) -> tp.Array1d:
"""Get the crossover of the first array going below the second array.
If `dropna` is True, produces the same results as if all rows with at least one NaN were dropped."""
arr2_ = to_1d_array_nb(np.asarray(arr2))
out = np.empty(arr1.shape, dtype=np.bool_)
was_above = False
confirmed = 0
for i in range(arr1.shape[0]):
_arr1 = arr1[i]
_arr2 = flex_select_1d_nb(arr2_, i)
if np.isnan(_arr1) or np.isnan(_arr2):
if not dropna:
was_above = False
confirmed = 0
out[i] = False
elif _arr1 < _arr2:
if was_above:
confirmed += 1
out[i] = confirmed == wait + 1
else:
out[i] = False
elif _arr1 == _arr2:
if confirmed > 0:
was_above = False
confirmed = 0
out[i] = False
elif _arr1 > _arr2:
confirmed = 0
was_above = True
out[i] = False
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(
arr1=ch.ArraySlicer(axis=1),
arr2=base_ch.FlexArraySlicer(axis=1),
wait=None,
dropna=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def crossed_below_nb(arr1: tp.Array2d, arr2: tp.FlexArray2dLike, wait: int = 0, dropna: bool = False) -> tp.Array2d:
"""2-dim version of `crossed_below_1d_nb`."""
arr2_ = to_2d_array_nb(np.asarray(arr2))
out = np.empty(arr1.shape, dtype=np.bool_)
for col in prange(arr1.shape[1]):
_arr2 = flex_select_col_nb(arr2_, col)
out[:, col] = crossed_below_1d_nb(arr1[:, col], _arr2, wait=wait, dropna=dropna)
return out
# ############# Transforming ############# #
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="group_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1, mapper=base_ch.group_idxs_mapper),
group_map=base_ch.GroupMapSlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def demean_nb(arr: tp.Array2d, group_map: tp.GroupMap) -> tp.Array2d:
"""Demean each value within its group."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
out = np.empty_like(arr, dtype=float_)
for group in prange(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for i in range(arr.shape[0]):
group_sum = 0
group_cnt = 0
for k in range(group_len):
col = col_idxs[k]
if not np.isnan(arr[i, col]):
group_sum += arr[i, col]
group_cnt += 1
for k in range(group_len):
col = col_idxs[k]
if np.isnan(arr[i, col]) or group_cnt == 0:
out[i, col] = np.nan
else:
out[i, col] = arr[i, col] - group_sum / group_cnt
return out
@register_jitted(cache=True)
def to_renko_1d_nb(
arr: tp.Array1d,
brick_size: tp.FlexArray1dLike,
relative: tp.FlexArray1dLike = False,
start_value: tp.Optional[float] = None,
max_out_len: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
"""Convert to Renko format."""
brick_size_ = to_1d_array_nb(np.asarray(brick_size))
relative_ = to_1d_array_nb(np.asarray(relative))
if max_out_len is None:
out_n = arr.shape[0]
else:
out_n = max_out_len
arr_out = np.empty(out_n, dtype=float_)
idx_out = np.empty(out_n, dtype=int_)
uptrend_out = np.empty(out_n, dtype=np.bool_)
prev_value = np.nan
k = 0
trend = 0
for i in range(arr.shape[0]):
_brick_size = abs(flex_select_1d_nb(brick_size_, i))
_relative = flex_select_1d_nb(relative_, i)
curr_value = arr[i]
if np.isnan(curr_value):
continue
if np.isnan(prev_value):
if start_value is None:
if not _relative:
prev_value = curr_value - curr_value % _brick_size
else:
prev_value = curr_value
else:
prev_value = start_value
continue
if _relative:
diff = (curr_value - prev_value) / prev_value
else:
diff = curr_value - prev_value
while abs(diff) >= _brick_size:
prev_trend = trend
if diff >= 0:
if _relative:
prev_value *= 1 + _brick_size
else:
prev_value += _brick_size
trend = 1
else:
if _relative:
prev_value *= 1 - _brick_size
else:
prev_value -= _brick_size
trend = -1
if _relative:
diff = (curr_value - prev_value) / prev_value
else:
diff = curr_value - prev_value
if trend == -prev_trend:
continue
if k >= len(arr_out):
raise IndexError("Index out of range. Set a higher max_out_len.")
arr_out[k] = prev_value
idx_out[k] = i
uptrend_out[k] = trend == 1
k += 1
return arr_out[:k], idx_out[:k], uptrend_out[:k]
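# Usage sketch (illustrative, hand-computed; not an official doctest). With an absolute brick
# size of 1.0, two up-bricks are formed at indices 1 and 2:
# >>> to_renko_1d_nb(np.array([10.0, 11.0, 12.5, 11.9]), 1.0)
# (array([11., 12.]), array([1, 2]), array([ True,  True]))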
@register_jitted(cache=True)
def to_renko_ohlc_1d_nb(
arr: tp.Array1d,
brick_size: tp.FlexArray1dLike,
relative: tp.FlexArray1dLike = False,
start_value: tp.Optional[float] = None,
max_out_len: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array2d, tp.Array1d]:
"""Convert to Renko OHLC format."""
brick_size_ = to_1d_array_nb(np.asarray(brick_size))
relative_ = to_1d_array_nb(np.asarray(relative))
if max_out_len is None:
out_n = arr.shape[0]
else:
out_n = max_out_len
arr_out = np.empty((out_n, 4), dtype=float_)
idx_out = np.empty(out_n, dtype=int_)
prev_value = np.nan
k = 0
trend = 0
for i in range(arr.shape[0]):
_brick_size = abs(flex_select_1d_nb(brick_size_, i))
_relative = flex_select_1d_nb(relative_, i)
curr_value = arr[i]
if np.isnan(curr_value):
continue
if np.isnan(prev_value):
if start_value is None:
if not _relative:
prev_value = curr_value - curr_value % _brick_size
else:
prev_value = curr_value
else:
prev_value = start_value
continue
if _relative:
diff = (curr_value - prev_value) / prev_value
else:
diff = curr_value - prev_value
while abs(diff) >= _brick_size:
open_value = prev_value
prev_trend = trend
if diff >= 0:
if _relative:
prev_value *= 1 + _brick_size
else:
prev_value += _brick_size
trend = 1
else:
if _relative:
prev_value *= 1 - _brick_size
else:
prev_value -= _brick_size
trend = -1
if _relative:
diff = (curr_value - prev_value) / prev_value
else:
diff = curr_value - prev_value
if trend == -prev_trend:
continue
if k >= len(arr_out):
raise IndexError("Index out of range. Set a higher max_out_len.")
if trend == 1:
high_value = prev_value
low_value = open_value
else:
high_value = open_value
low_value = prev_value
close_value = prev_value
arr_out[k, 0] = open_value
arr_out[k, 1] = high_value
arr_out[k, 2] = low_value
arr_out[k, 3] = close_value
idx_out[k] = i
k += 1
return arr_out[:k], idx_out[:k]
# ############# Resampling ############# #
def _realign_1d_nb(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(nan_value)
else:
a_dtype = arr.dtype
value_dtype = np.array(nan_value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
):
out = np.empty(target_index.shape[0], dtype=dtype)
curr_j = -1
last_j = curr_j
last_valid = np.nan
for i in range(len(target_index)):
if i > 0 and target_index[i] < target_index[i - 1]:
raise ValueError("Target index must be increasing")
target_bound_inf = target_rbound and i == len(target_index) - 1 and target_freq is None
last_valid_at_i = np.nan
for j in range(curr_j + 1, source_index.shape[0]):
if j > 0 and source_index[j] < source_index[j - 1]:
raise ValueError("Array index must be increasing")
source_bound_inf = source_rbound and j == len(source_index) - 1 and source_freq is None
if source_bound_inf and target_bound_inf:
curr_j = j
if not np.isnan(arr[curr_j]):
last_valid_at_i = arr[curr_j]
break
if source_bound_inf:
break
if target_bound_inf:
curr_j = j
if not np.isnan(arr[curr_j]):
last_valid_at_i = arr[curr_j]
continue
if source_rbound and target_rbound:
if source_freq is None:
source_val = source_index[j + 1]
else:
source_val = source_index[j] + source_freq
if target_freq is None:
target_val = target_index[i + 1]
else:
target_val = target_index[i] + target_freq
if source_val > target_val:
break
elif source_rbound:
if source_freq is None:
source_val = source_index[j + 1]
else:
source_val = source_index[j] + source_freq
if source_val > target_index[i]:
break
elif target_rbound:
if target_freq is None:
target_val = target_index[i + 1]
else:
target_val = target_index[i] + target_freq
if source_index[j] >= target_val:
break
else:
if source_index[j] > target_index[i]:
break
curr_j = j
if not np.isnan(arr[curr_j]):
last_valid_at_i = arr[curr_j]
if ffill and not np.isnan(last_valid_at_i):
last_valid = last_valid_at_i
if curr_j == -1 or (not ffill and curr_j == last_j):
out[i] = nan_value
else:
if ffill:
if np.isnan(last_valid):
out[i] = nan_value
else:
out[i] = last_valid
else:
if np.isnan(last_valid_at_i):
out[i] = nan_value
else:
out[i] = last_valid_at_i
last_j = curr_j
return out
if not nb_enabled:
return impl(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
)
return impl
overload(_realign_1d_nb)(_realign_1d_nb)
@register_jitted(cache=True)
def realign_1d_nb(
arr: tp.Array1d,
source_index: tp.Array1d,
target_index: tp.Array1d,
source_freq: tp.Optional[tp.Scalar] = None,
target_freq: tp.Optional[tp.Scalar] = None,
source_rbound: bool = False,
target_rbound: bool = False,
nan_value: tp.Scalar = np.nan,
ffill: bool = True,
) -> tp.Array1d:
"""Get the latest in `arr` at each index in `target_index` based on `source_index`.
If `source_rbound` is True, then each element in `source_index` is effectively located at
the right bound, which is the frequency or the next element (excluding) if the frequency is None.
The same for `target_rbound` and `target_index`.
!!! note
Both index arrays must be increasing. Repeating values are allowed.
If `arr` contains bar data, both indexes must represent the opening time."""
return _realign_1d_nb(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
)
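# Usage sketch (illustrative, hand-computed; not an official doctest). Values recorded at a
# coarser source index are forward-filled onto a finer target index:
# >>> realign_1d_nb(
# ...     np.array([1.0, 2.0, 3.0]),
# ...     np.array([10, 20, 30]),
# ...     np.array([10, 15, 20, 25, 30]),
# ... )
# array([1., 1., 2., 2., 3.])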
def _realign_nb(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
):
nb_enabled = isinstance(arr, Type)
if nb_enabled:
a_dtype = as_dtype(arr.dtype)
value_dtype = as_dtype(nan_value)
else:
a_dtype = arr.dtype
value_dtype = np.array(nan_value).dtype
dtype = np.promote_types(a_dtype, value_dtype)
def impl(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
):
out = np.empty((target_index.shape[0], arr.shape[1]), dtype=dtype)
for col in prange(arr.shape[1]):
out[:, col] = realign_1d_nb(
arr[:, col],
source_index,
target_index,
source_freq=source_freq,
target_freq=target_freq,
source_rbound=source_rbound,
target_rbound=target_rbound,
nan_value=nan_value,
ffill=ffill,
)
return out
if not nb_enabled:
return impl(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
)
return impl
overload(_realign_nb)(_realign_nb)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
source_index=None,
target_index=None,
source_freq=None,
target_freq=None,
source_rbound=None,
target_rbound=None,
nan_value=None,
ffill=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True)
def realign_nb(
arr: tp.Array2d,
source_index: tp.Array1d,
target_index: tp.Array1d,
source_freq: tp.Optional[tp.Scalar] = None,
target_freq: tp.Optional[tp.Scalar] = None,
source_rbound: bool = False,
target_rbound: bool = False,
nan_value: tp.Scalar = np.nan,
ffill: bool = True,
) -> tp.Array2d:
"""2-dim version of `realign_1d_nb`."""
return _realign_nb(
arr,
source_index,
target_index,
source_freq,
target_freq,
source_rbound,
target_rbound,
nan_value,
ffill,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for iterative use."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.flex_indexing import flex_select_nb
from vectorbtpro.registries.jit_registry import register_jitted
@register_jitted(cache=True)
def iter_above_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
"""Check whether `arr1` is above `arr2` at specific row and column."""
if i < 0:
return False
arr1_now = flex_select_nb(arr1, i, col)
arr2_now = flex_select_nb(arr2, i, col)
if np.isnan(arr1_now) or np.isnan(arr2_now):
return False
return arr1_now > arr2_now
@register_jitted(cache=True)
def iter_below_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
"""Check whether `arr1` is below `arr2` at specific row and column."""
if i < 0:
return False
arr1_now = flex_select_nb(arr1, i, col)
arr2_now = flex_select_nb(arr2, i, col)
if np.isnan(arr1_now) or np.isnan(arr2_now):
return False
return arr1_now < arr2_now
@register_jitted(cache=True)
def iter_crossed_above_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
"""Check whether `arr1` crossed above `arr2` at specific row and column."""
if i < 0 or i - 1 < 0:
return False
arr1_prev = flex_select_nb(arr1, i - 1, col)
arr2_prev = flex_select_nb(arr2, i - 1, col)
arr1_now = flex_select_nb(arr1, i, col)
arr2_now = flex_select_nb(arr2, i, col)
if np.isnan(arr1_prev) or np.isnan(arr2_prev) or np.isnan(arr1_now) or np.isnan(arr2_now):
return False
return arr1_prev < arr2_prev and arr1_now > arr2_now
@register_jitted(cache=True)
def iter_crossed_below_nb(arr1: tp.FlexArray2d, arr2: tp.FlexArray2d, i: int, col: int) -> bool:
"""Check whether `arr1` crossed below `arr2` at specific row and column."""
if i < 0 or i - 1 < 0:
return False
arr1_prev = flex_select_nb(arr1, i - 1, col)
arr2_prev = flex_select_nb(arr2, i - 1, col)
arr1_now = flex_select_nb(arr1, i, col)
arr2_now = flex_select_nb(arr2, i, col)
if np.isnan(arr1_prev) or np.isnan(arr2_prev) or np.isnan(arr1_now) or np.isnan(arr2_now):
return False
return arr1_prev > arr2_prev and arr1_now < arr2_now
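# Usage sketch (illustrative, hand-computed; not an official doctest). The iterative helpers
# compare two flexible 2-dim arrays at a single (row, column) position:
# >>> close = np.array([[1.0], [3.0]])
# >>> ma = np.array([[2.0], [2.0]])
# >>> iter_crossed_above_nb(close, ma, 1, 0)
# True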
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for working with patterns."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.flex_indexing import flex_select_1d_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.generic.enums import RescaleMode, InterpMode, ErrorType, DistanceMeasure
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils.array_ import rescale_nb
@register_jitted(cache=True)
def linear_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
"""Get the value at a specific position in a target size using linear interpolation."""
if i == 0 or source_size == 1 or target_size == 1:
return float(flex_select_1d_nb(arr, 0))
if source_size == target_size:
return float(flex_select_1d_nb(arr, i))
if i == target_size - 1:
return float(flex_select_1d_nb(arr, source_size - 1))
mapped_i = i / (target_size - 1) * (source_size - 1)
left_i = int(np.floor(mapped_i))
right_i = int(np.ceil(mapped_i))
norm_mapped_i = mapped_i - left_i
left_elem = float(flex_select_1d_nb(arr, left_i))
right_elem = float(flex_select_1d_nb(arr, right_i))
return left_elem + norm_mapped_i * (right_elem - left_elem)
@register_jitted(cache=True)
def nearest_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
"""Get the value at a specific position in a target size using nearest-neighbor interpolation."""
if i == 0 or source_size == 1 or target_size == 1:
return float(flex_select_1d_nb(arr, 0))
if source_size == target_size:
return float(flex_select_1d_nb(arr, i))
if i == target_size - 1:
return float(flex_select_1d_nb(arr, source_size - 1))
mapped_i = i / (target_size - 1) * (source_size - 1)
return float(flex_select_1d_nb(arr, round(mapped_i)))
@register_jitted(cache=True)
def discrete_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
"""Get the value at a specific position in a target size using discrete interpolation."""
if source_size >= target_size:
return nearest_interp_nb(arr, i, source_size, target_size)
if i == 0 or source_size == 1 or target_size == 1:
return float(flex_select_1d_nb(arr, 0))
if i == target_size - 1:
return float(flex_select_1d_nb(arr, source_size - 1))
curr_float_mapped_i = i / (target_size - 1) * (source_size - 1)
curr_remainder = curr_float_mapped_i % 1
if curr_remainder == 0:
return float(flex_select_1d_nb(arr, int(curr_float_mapped_i)))
if curr_remainder <= 0.5:
prev_float_mapped_i = (i - 1) / (target_size - 1) * (source_size - 1)
if int(curr_float_mapped_i) != int(prev_float_mapped_i):
prev_remainder = prev_float_mapped_i % 1
if curr_remainder < 1 - prev_remainder:
return float(flex_select_1d_nb(arr, int(np.floor(curr_float_mapped_i))))
return np.nan
next_float_mapped_i = (i + 1) / (target_size - 1) * (source_size - 1)
if int(curr_float_mapped_i) != int(next_float_mapped_i):
next_remainder = next_float_mapped_i % 1
if 1 - curr_remainder <= next_remainder:
return float(flex_select_1d_nb(arr, int(np.ceil(curr_float_mapped_i))))
return np.nan
@register_jitted(cache=True)
def mixed_interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int) -> float:
"""Get the value at a specific position in a target size using mixed interpolation.
Mixed interpolation applies discrete interpolation and fills any resulting NaN values
using linear interpolation. This way, the vertical scale of the pattern array is respected."""
value = discrete_interp_nb(arr, i, source_size, target_size)
if np.isnan(value):
value = linear_interp_nb(arr, i, source_size, target_size)
return value
@register_jitted(cache=True)
def interp_nb(arr: tp.FlexArray1d, i: int, source_size: int, target_size: int, interp_mode: int) -> float:
"""Get the value at a specific position in a target size using an interpolation mode.
See `vectorbtpro.generic.enums.InterpMode`."""
if interp_mode == InterpMode.Linear:
return linear_interp_nb(arr, i, source_size, target_size)
if interp_mode == InterpMode.Nearest:
return nearest_interp_nb(arr, i, source_size, target_size)
if interp_mode == InterpMode.Discrete:
return discrete_interp_nb(arr, i, source_size, target_size)
if interp_mode == InterpMode.Mixed:
return mixed_interp_nb(arr, i, source_size, target_size)
raise ValueError("Invalid interpolation mode")
@register_jitted(cache=True)
def interp_resize_1d_nb(arr: tp.FlexArray1d, target_size: int, interp_mode: int) -> tp.Array1d:
"""Resize an array using `interp_nb`."""
out = np.empty(target_size, dtype=float_)
for i in range(target_size):
out[i] = interp_nb(arr, i, arr.size, target_size, interp_mode)
return out
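# Usage sketch (illustrative, hand-computed; not an official doctest). Resizing a two-point
# pattern to three points with linear interpolation inserts the midpoint:
# >>> interp_resize_1d_nb(np.array([1.0, 2.0]), 3, InterpMode.Linear)
# array([1. , 1.5, 2. ])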
@register_jitted(cache=True)
def fit_pattern_nb(
arr: tp.Array1d,
pattern: tp.Array1d,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Fit pattern.
Returns the resized and rescaled pattern and max error."""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if max_error_interp_mode is None:
max_error_interp_mode = interp_mode
fit_pattern = interp_resize_1d_nb(
pattern,
len(arr),
interp_mode,
)
fit_max_error = interp_resize_1d_nb(
max_error_,
len(arr),
max_error_interp_mode,
)
if np.isnan(vmin):
vmin = np.nanmin(arr)
else:
vmin = vmin
if np.isnan(vmax):
vmax = np.nanmax(arr)
else:
vmax = vmax
if np.isnan(pmin):
pmin = np.nanmin(fit_pattern)
else:
pmin = pmin
if np.isnan(pmax):
pmax = np.nanmax(fit_pattern)
else:
pmax = pmax
if invert:
fit_pattern = pmax + pmin - fit_pattern
if rescale_mode == RescaleMode.Rebase:
if not np.isnan(pmin):
if fit_pattern[0] == 0:
pmin = np.nan
else:
pmin = pmin / fit_pattern[0] * arr[0]
if not np.isnan(pmax):
if fit_pattern[0] == 0:
pmax = np.nan
else:
pmax = pmax / fit_pattern[0] * arr[0]
if rescale_mode == RescaleMode.Rebase:
if fit_pattern[0] == 0:
fit_pattern = np.full(fit_pattern.shape, np.nan)
else:
fit_pattern = fit_pattern / fit_pattern[0] * arr[0]
fit_max_error = fit_max_error * fit_pattern
fit_pattern = np.clip(fit_pattern, pmin, pmax)
if rescale_mode == RescaleMode.MinMax:
fit_pattern = rescale_nb(fit_pattern, (pmin, pmax), (vmin, vmax))
if error_type == ErrorType.Absolute:
if pmax - pmin == 0:
fit_max_error = np.full(fit_max_error.shape, np.nan)
else:
fit_max_error = fit_max_error / (pmax - pmin) * (vmax - vmin)
else:
fit_max_error = fit_max_error * fit_pattern
return fit_pattern, fit_max_error
@register_jitted(cache=True)
def pattern_similarity_nb(
arr: tp.Array1d,
pattern: tp.Array1d,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
distance_measure: int = DistanceMeasure.MAE,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = np.nan,
minp: tp.Optional[int] = None,
) -> float:
"""Get the similarity between an array and a pattern array.
At each position in the array, the value in `arr` is first mapped into the range of `pattern`.
Then, the distance between the actual and expected value is calculated (= error) using the
selected distance measure and error type. The errors are summed, divided by the sum of the
maximum possible errors at each position, and the result is subtracted from 1 to get the
similarity measure.
* For `interp_mode`, see `vectorbtpro.generic.enums.InterpMode`
* For `rescale_mode`, see `vectorbtpro.generic.enums.RescaleMode`
* For `error_type`, see `vectorbtpro.generic.enums.ErrorType`
* For `distance_measure`, see `vectorbtpro.generic.enums.DistanceMeasure`
"""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if len(arr) == 0:
return np.nan
if len(pattern) == 0:
return np.nan
if rescale_mode == RescaleMode.Rebase:
if np.isnan(pattern[0]):
return np.nan
if np.isnan(arr[0]):
return np.nan
if max_error_interp_mode is None or max_error_interp_mode == -1:
_max_error_interp_mode = interp_mode
else:
_max_error_interp_mode = max_error_interp_mode
max_size = max(arr.shape[0], pattern.shape[0])
if error_type != ErrorType.Absolute and error_type != ErrorType.Relative:
raise ValueError("Invalid error type")
if (
distance_measure != DistanceMeasure.MAE
and distance_measure != DistanceMeasure.MSE
and distance_measure != DistanceMeasure.RMSE
):
raise ValueError("Invalid distance mode")
if minp is None:
minp = arr.shape[0]
min_max_required = False
if rescale_mode == RescaleMode.MinMax:
min_max_required = True
if not np.isnan(min_pct_change):
min_max_required = True
if not np.isnan(max_pct_change):
min_max_required = True
if not max_error_as_maxdist:
min_max_required = True
if invert:
min_max_required = True
if min_max_required:
vmin_set = not np.isnan(vmin)
vmax_set = not np.isnan(vmax)
pmin_set = not np.isnan(pmin)
pmax_set = not np.isnan(pmax)
if not vmin_set or not vmax_set or not pmin_set or not pmax_set:
for i in range(max_size):
if arr.shape[0] >= pattern.shape[0]:
arr_elem = arr[i]
else:
arr_elem = interp_nb(arr, i, arr.shape[0], pattern.shape[0], interp_mode)
if pattern.shape[0] >= arr.shape[0]:
pattern_elem = pattern[i]
else:
pattern_elem = interp_nb(pattern, i, pattern.shape[0], arr.shape[0], interp_mode)
if not np.isnan(arr_elem):
if not vmin_set and (np.isnan(vmin) or arr_elem < vmin):
vmin = arr_elem
if not vmax_set and (np.isnan(vmax) or arr_elem > vmax):
vmax = arr_elem
if not np.isnan(pattern_elem):
if not pmin_set and (np.isnan(pmin) or pattern_elem < pmin):
pmin = pattern_elem
if not pmax_set and (np.isnan(pmax) or pattern_elem > pmax):
pmax = pattern_elem
if np.isnan(vmin) or np.isnan(vmax):
return np.nan
if np.isnan(pmin) or np.isnan(pmax):
return np.nan
if vmin == vmax and rescale_mode == RescaleMode.MinMax:
return np.nan
if pmin == pmax and rescale_mode == RescaleMode.MinMax:
return np.nan
if not np.isnan(min_pct_change) and vmin != 0 and (vmax - vmin) / vmin < min_pct_change:
return np.nan
if not np.isnan(max_pct_change) and vmin != 0 and (vmax - vmin) / vmin > max_pct_change:
return np.nan
first_pattern_elem = pattern[0]
if invert:
first_pattern_elem = pmax + pmin - first_pattern_elem
if rescale_mode == RescaleMode.Rebase:
if not np.isnan(pmin):
if first_pattern_elem == 0:
pmin = np.nan
else:
pmin = pmin / first_pattern_elem * arr[0]
if not np.isnan(pmax):
if first_pattern_elem == 0:
pmax = np.nan
else:
pmax = pmax / first_pattern_elem * arr[0]
if rescale_mode == RescaleMode.Rebase or rescale_mode == RescaleMode.Disable:
if not np.isnan(pmin) and not np.isnan(vmin):
_min = min(pmin, vmin)
else:
_min = vmin
if not np.isnan(pmax) and not np.isnan(vmax):
_max = max(pmax, vmax)
else:
_max = vmax
else:
_min = vmin
_max = vmax
distance_sum = 0.0
maxdistance_sum = 0.0
nan_count = 0
for i in range(max_size):
if i < arr.shape[0]:
if np.isnan(arr[i]):
nan_count += 1
if max_size - nan_count < minp:
return np.nan
if arr.shape[0] == pattern.shape[0]:
arr_elem = arr[i]
pattern_elem = pattern[i]
_max_error = flex_select_1d_nb(max_error_, i)
elif arr.shape[0] > pattern.shape[0]:
arr_elem = arr[i]
pattern_elem = interp_nb(pattern, i, pattern.shape[0], arr.shape[0], interp_mode)
_max_error = interp_nb(max_error_, i, pattern.shape[0], arr.shape[0], _max_error_interp_mode)
else:
arr_elem = interp_nb(arr, i, arr.shape[0], pattern.shape[0], interp_mode)
pattern_elem = pattern[i]
_max_error = flex_select_1d_nb(max_error_, i)
if not np.isnan(arr_elem) and not np.isnan(pattern_elem):
if invert:
pattern_elem = pmax + pmin - pattern_elem
if rescale_mode == RescaleMode.Rebase:
if first_pattern_elem == 0:
pattern_elem = np.nan
else:
pattern_elem = pattern_elem / first_pattern_elem * arr[0]
if error_type == ErrorType.Absolute:
_max_error = _max_error * pattern_elem
if not np.isnan(vmin) and arr_elem < vmin:
arr_elem = vmin
if not np.isnan(vmax) and arr_elem > vmax:
arr_elem = vmax
if not np.isnan(pmin) and pattern_elem < pmin:
pattern_elem = pmin
if not np.isnan(pmax) and pattern_elem > pmax:
pattern_elem = pmax
if rescale_mode == RescaleMode.MinMax:
if pmax - pmin == 0:
pattern_elem = np.nan
else:
pattern_elem = (pattern_elem - pmin) / (pmax - pmin) * (vmax - vmin) + vmin
if error_type == ErrorType.Absolute:
if pmax - pmin == 0:
_max_error = np.nan
else:
_max_error = _max_error / (pmax - pmin) * (vmax - vmin)
if distance_measure == DistanceMeasure.MAE:
if error_type == ErrorType.Absolute:
dist = abs(arr_elem - pattern_elem)
else:
if pattern_elem == 0:
dist = np.nan
else:
dist = abs(arr_elem - pattern_elem) / pattern_elem
else:
if error_type == ErrorType.Absolute:
dist = (arr_elem - pattern_elem) ** 2
else:
if pattern_elem == 0:
dist = np.nan
else:
dist = ((arr_elem - pattern_elem) / pattern_elem) ** 2
if max_error_as_maxdist:
if np.isnan(_max_error):
continue
maxdist = _max_error
else:
if distance_measure == DistanceMeasure.MAE:
if error_type == ErrorType.Absolute:
maxdist = max(pattern_elem - _min, _max - pattern_elem)
else:
if pattern_elem == 0:
maxdist = np.nan
else:
maxdist = max(pattern_elem - _min, _max - pattern_elem) / pattern_elem
else:
if error_type == ErrorType.Absolute:
maxdist = max(pattern_elem - _min, _max - pattern_elem) ** 2
else:
if pattern_elem == 0:
maxdist = np.nan
else:
maxdist = (max(pattern_elem - _min, _max - pattern_elem) / pattern_elem) ** 2
if dist > 0 and maxdist == 0:
return np.nan
if not np.isnan(_max_error) and dist > _max_error:
if max_error_strict:
return np.nan
dist = maxdist
if dist > maxdist:
dist = maxdist
distance_sum = distance_sum + dist
maxdistance_sum = maxdistance_sum + maxdist
if not np.isnan(min_similarity):
if not max_error_as_maxdist or max_error_.size == 1:
if max_error_as_maxdist:
if np.isnan(_max_error):
return np.nan
worst_maxdist = _max_error
else:
if distance_measure == DistanceMeasure.MAE:
if error_type == ErrorType.Absolute:
worst_maxdist = _max - _min
else:
if _min == 0:
worst_maxdist = np.nan
else:
worst_maxdist = (_max - _min) / _min
else:
if error_type == ErrorType.Absolute:
worst_maxdist = (_max - _min) ** 2
else:
if _min == 0:
worst_maxdist = np.nan
else:
worst_maxdist = ((_max - _min) / _min) ** 2
worst_maxdistance_sum = maxdistance_sum + worst_maxdist * (max_size - i - 1)
if worst_maxdistance_sum == 0:
return np.nan
if distance_measure == DistanceMeasure.RMSE:
if worst_maxdistance_sum == 0:
best_similarity = np.nan
else:
best_similarity = 1 - np.sqrt(distance_sum) / np.sqrt(worst_maxdistance_sum)
else:
if worst_maxdistance_sum == 0:
best_similarity = np.nan
else:
best_similarity = 1 - distance_sum / worst_maxdistance_sum
if best_similarity < min_similarity:
return np.nan
if distance_sum == 0:
return 1.0
if maxdistance_sum == 0:
return np.nan
if distance_measure == DistanceMeasure.RMSE:
if maxdistance_sum == 0:
similarity = np.nan
else:
similarity = 1 - np.sqrt(distance_sum) / np.sqrt(maxdistance_sum)
else:
if maxdistance_sum == 0:
similarity = np.nan
else:
similarity = 1 - distance_sum / maxdistance_sum
if not np.isnan(min_similarity):
if similarity < min_similarity:
return np.nan
return similarity
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for records."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb, flex_select_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.generic.enums import *
from vectorbtpro.generic.nb.base import repartition_nb
from vectorbtpro.generic.nb.patterns import pattern_similarity_nb
from vectorbtpro.generic.nb.sim_range import prepare_sim_range_nb
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.template import Rep
# ############# Ranges ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), gap_value=None),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_ranges_nb(arr: tp.Array2d, gap_value: tp.Scalar) -> tp.RecordArray:
"""Fill range records between gaps.
Usage:
* Find ranges in time series:
```pycon
>>> from vectorbtpro import *
>>> a = np.array([
... [np.nan, np.nan, np.nan, np.nan],
... [ 2, np.nan, np.nan, np.nan],
... [ 3, 3, np.nan, np.nan],
... [np.nan, 4, 4, np.nan],
... [ 5, np.nan, 5, 5],
... [ 6, 6, np.nan, 6]
... ])
>>> records = vbt.nb.get_ranges_nb(a, np.nan)
>>> pd.DataFrame.from_records(records)
id col start_idx end_idx status
0 0 0 1 3 1
1 1 0 4 5 0
2 0 1 2 4 1
3 1 1 5 5 0
4 0 2 3 5 1
5 0 3 4 5 0
```
"""
new_records = np.empty(arr.shape, dtype=range_dt)
counts = np.full(arr.shape[1], 0, dtype=int_)
for col in prange(arr.shape[1]):
range_started = False
start_idx = -1
end_idx = -1
store_record = False
status = -1
for i in range(arr.shape[0]):
cur_val = arr[i, col]
if cur_val == gap_value or np.isnan(cur_val) and np.isnan(gap_value):
if range_started:
# If stopped, save the current range
end_idx = i
range_started = False
store_record = True
status = RangeStatus.Closed
else:
if not range_started:
# If started, register a new range
start_idx = i
range_started = True
if i == arr.shape[0] - 1 and range_started:
# If still running, mark for save
end_idx = arr.shape[0] - 1
range_started = False
store_record = True
status = RangeStatus.Open
if store_record:
# Save range to the records
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = start_idx
new_records["end_idx"][r, col] = end_idx
new_records["status"][r, col] = status
counts[col] += 1
# Reset running vars for a new range
store_record = False
return repartition_nb(new_records, counts)
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
n_rows=None,
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
id_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
index=None,
delta=None,
delta_use_index=None,
shift=None,
),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_ranges_from_delta_nb(
n_rows: int,
idx_arr: tp.Array1d,
id_arr: tp.Array1d,
col_map: tp.GroupMap,
index: tp.Optional[tp.Array1d] = None,
delta: int = 0,
delta_use_index: bool = False,
shift: int = 0,
) -> tp.RecordArray:
"""Build delta ranges."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.empty(idx_arr.shape[0], dtype=range_dt)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
for r in ridxs:
r_idx = idx_arr[r] + shift
if r_idx < 0:
r_idx = 0
if r_idx > n_rows - 1:
r_idx = n_rows - 1
if delta >= 0:
start_idx = r_idx
if delta_use_index:
if index is None:
raise ValueError("Index is required")
end_idx = len(index) - 1
status = RangeStatus.Open
for i in range(start_idx, index.shape[0]):
if index[i] >= index[start_idx] + delta:
end_idx = i
status = RangeStatus.Closed
break
else:
if start_idx + delta < n_rows:
end_idx = start_idx + delta
status = RangeStatus.Closed
else:
end_idx = n_rows - 1
status = RangeStatus.Open
else:
end_idx = r_idx
status = RangeStatus.Closed
if delta_use_index:
if index is None:
raise ValueError("Index is required")
start_idx = 0
for i in range(end_idx, -1, -1):
if index[i] <= index[end_idx] + delta:
start_idx = i
break
else:
if end_idx + delta >= 0:
start_idx = end_idx + delta
else:
start_idx = 0
out["id"][r] = id_arr[r]
out["col"][r] = col
out["start_idx"][r] = start_idx
out["end_idx"][r] = end_idx
out["status"][r] = status
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="start_idx_arr", axis=0),
arg_take_spec=dict(
start_idx_arr=ch.ArraySlicer(axis=0),
end_idx_arr=ch.ArraySlicer(axis=0),
status_arr=ch.ArraySlicer(axis=0),
freq=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def range_duration_nb(
start_idx_arr: tp.Array1d,
end_idx_arr: tp.Array1d,
status_arr: tp.Array1d,
freq: int = 1,
) -> tp.Array1d:
"""Get duration of each range record."""
out = np.empty(start_idx_arr.shape[0], dtype=int_)
for r in prange(start_idx_arr.shape[0]):
if status_arr[r] == RangeStatus.Open:
out[r] = end_idx_arr[r] - start_idx_arr[r] + freq
else:
out[r] = end_idx_arr[r] - start_idx_arr[r]
return out
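# A minimal, illustrative sketch of `range_duration_nb` (not from the original docstring);
# the inputs below mimic fields produced by `get_ranges_nb` above:
# >>> start_idx = np.array([1, 4])
# >>> end_idx = np.array([3, 5])
# >>> status = np.array([RangeStatus.Closed, RangeStatus.Open])
# >>> range_duration_nb(start_idx, end_idx, status)  # open ranges add `freq` (default 1)
# array([2, 2])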
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
start_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
end_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
status_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
index_lens=ch.ArraySlicer(axis=0),
overlapping=None,
normalize=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def range_coverage_nb(
start_idx_arr: tp.Array1d,
end_idx_arr: tp.Array1d,
status_arr: tp.Array1d,
col_map: tp.GroupMap,
index_lens: tp.Array1d,
overlapping: bool = False,
normalize: bool = False,
) -> tp.Array1d:
"""Get coverage of range records.
Set `overlapping` to True to get the number of overlapping steps.
Set `normalize` to True to express the result relative either to the total number of steps
(when `overlapping=False`) or to the number of covered steps (when `overlapping=True`).
"""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_lens.shape[0], np.nan, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
temp = np.full(index_lens[col], 0, dtype=int_)
for r in ridxs:
if status_arr[r] == RangeStatus.Open:
temp[start_idx_arr[r] : end_idx_arr[r] + 1] += 1
else:
temp[start_idx_arr[r] : end_idx_arr[r]] += 1
if overlapping:
if normalize:
pos_temp_sum = np.sum(temp > 0)
if pos_temp_sum == 0:
out[col] = np.nan
else:
out[col] = np.sum(temp > 1) / pos_temp_sum
else:
out[col] = np.sum(temp > 1)
else:
if normalize:
if index_lens[col] == 0:
out[col] = np.nan
else:
out[col] = np.sum(temp > 0) / index_lens[col]
else:
out[col] = np.sum(temp > 0)
return out
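# An illustrative sketch of `range_coverage_nb` (not from the original docstring);
# one column with two closed ranges [0, 3) and [2, 5) over a 5-step index:
# >>> col_map = (np.array([0, 1]), np.array([2]))
# >>> start_idx = np.array([0, 2])
# >>> end_idx = np.array([3, 5])
# >>> status = np.array([RangeStatus.Closed, RangeStatus.Closed])
# >>> range_coverage_nb(start_idx, end_idx, status, col_map, np.array([5]))
# array([5.])  # 5 covered steps in total
# >>> range_coverage_nb(start_idx, end_idx, status, col_map, np.array([5]), overlapping=True)
# array([1.])  # only step 2 is covered twice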
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
start_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
end_idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
status_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
index_len=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ranges_to_mask_nb(
start_idx_arr: tp.Array1d,
end_idx_arr: tp.Array1d,
status_arr: tp.Array1d,
col_map: tp.GroupMap,
index_len: int,
) -> tp.Array2d:
"""Convert ranges to 2-dim mask."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full((index_len, col_lens.shape[0]), False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
for r in ridxs:
if status_arr[r] == RangeStatus.Open:
out[start_idx_arr[r] : end_idx_arr[r] + 1, col] = True
else:
out[start_idx_arr[r] : end_idx_arr[r], col] = True
return out
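# An illustrative sketch of `ranges_to_mask_nb` (not from the original docstring);
# a single closed range [1, 3) in one column over a 4-step index:
# >>> col_map = (np.array([0]), np.array([1]))
# >>> mask = ranges_to_mask_nb(np.array([1]), np.array([3]), np.array([RangeStatus.Closed]), col_map, 4)
# >>> mask[:, 0]
# array([False,  True,  True, False])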
@register_jitted(cache=True)
def map_ranges_to_projections_nb(
close: tp.Array2d,
col_arr: tp.Array1d,
start_idx_arr: tp.Array1d,
end_idx_arr: tp.Array1d,
status_arr: tp.Array1d,
index: tp.Optional[tp.Array1d] = None,
proj_start: int = 0,
proj_start_use_index: bool = False,
proj_period: tp.Optional[int] = None,
proj_period_use_index: bool = False,
incl_end_idx: bool = True,
extend: bool = False,
rebase: bool = True,
start_value: tp.FlexArray1dLike = 1.0,
ffill: bool = False,
remove_empty: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array2d]:
"""Map each range into a projection.
Returns two arrays:
1. One-dimensional array where elements are record indices
2. Two-dimensional array where rows are projections"""
start_value_ = to_1d_array_nb(np.asarray(start_value))
index_ranges_temp = np.empty((start_idx_arr.shape[0], 2), dtype=int_)
max_duration = 0
for r in range(start_idx_arr.shape[0]):
if proj_start_use_index:
if index is None:
raise ValueError("Index is required")
r_proj_start = len(index) - start_idx_arr[r]
for i in range(start_idx_arr[r], index.shape[0]):
if index[i] >= index[start_idx_arr[r]] + proj_start:
r_proj_start = i - start_idx_arr[r]
break
r_start_idx = start_idx_arr[r] + r_proj_start
else:
r_start_idx = start_idx_arr[r] + proj_start
if status_arr[r] == RangeStatus.Open:
if incl_end_idx:
r_duration = end_idx_arr[r] - start_idx_arr[r] + 1
else:
r_duration = end_idx_arr[r] - start_idx_arr[r]
else:
if incl_end_idx:
r_duration = end_idx_arr[r] - start_idx_arr[r]
else:
r_duration = end_idx_arr[r] - start_idx_arr[r] - 1
if proj_period is None:
r_end_idx = start_idx_arr[r] + r_duration
else:
if proj_period_use_index:
if index is None:
raise ValueError("Index is required")
r_proj_period = -1
for i in range(r_start_idx, index.shape[0]):
if index[i] <= index[r_start_idx] + proj_period:
r_proj_period = i - r_start_idx
else:
break
else:
r_proj_period = proj_period
if extend:
r_end_idx = r_start_idx + r_proj_period
else:
r_end_idx = min(start_idx_arr[r] + r_duration, r_start_idx + r_proj_period)
r_end_idx = r_end_idx + 1
if r_end_idx > close.shape[0]:
r_end_idx = close.shape[0]
if r_start_idx > r_end_idx:
r_start_idx = r_end_idx
if r_end_idx - r_start_idx > max_duration:
max_duration = r_end_idx - r_start_idx
index_ranges_temp[r, 0] = r_start_idx
index_ranges_temp[r, 1] = r_end_idx
ridx_out = np.empty((start_idx_arr.shape[0],), dtype=int_)
proj_out = np.empty((start_idx_arr.shape[0], max_duration), dtype=float_)
k = 0
for r in range(start_idx_arr.shape[0]):
if extend:
r_start_idx = index_ranges_temp[r, 0]
r_end_idx = index_ranges_temp[r, 0] + proj_out.shape[1]
else:
r_start_idx = index_ranges_temp[r, 0]
r_end_idx = index_ranges_temp[r, 1]
r_close = close[r_start_idx:r_end_idx, col_arr[r]]
any_set = False
for i in range(proj_out.shape[1]):
if i >= r_close.shape[0]:
proj_out[k, i] = np.nan
else:
if rebase:
if i == 0:
_start_value = flex_select_1d_pc_nb(start_value_, col_arr[r])
if _start_value == -1:
proj_out[k, i] = close[-1, col_arr[r]]
else:
proj_out[k, i] = _start_value
else:
if r_close[i - 1] == 0:
proj_out[k, i] = np.nan
else:
proj_out[k, i] = proj_out[k, i - 1] * r_close[i] / r_close[i - 1]
else:
proj_out[k, i] = r_close[i]
if not np.isnan(proj_out[k, i]) and i > 0:
any_set = True
if ffill and np.isnan(proj_out[k, i]) and i > 0:
proj_out[k, i] = proj_out[k, i - 1]
if any_set or not remove_empty:
ridx_out[k] = r
k += 1
if remove_empty:
return ridx_out[:k], proj_out[:k]
return ridx_out, proj_out
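# A hedged usage sketch of `map_ranges_to_projections_nb` (not from the original docstring).
# With `rebase=True` each projection row starts at `start_value` and is compounded by
# consecutive close-to-close returns, i.e. proj[i] = proj[i - 1] * close[i] / close[i - 1]:
# >>> close = np.array([[10.0], [11.0], [12.1], [11.0]])
# >>> ridxs, projs = map_ranges_to_projections_nb(
# ...     close,
# ...     col_arr=np.array([0]),
# ...     start_idx_arr=np.array([0]),
# ...     end_idx_arr=np.array([2]),
# ...     status_arr=np.array([RangeStatus.Closed]),
# ... )
# >>> # projs[0] then tracks the +10% close-to-close moves of the range, starting from 1.0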
@register_jitted(cache=True)
def find_pattern_1d_nb(
arr: tp.Array1d,
pattern: tp.Array1d,
window: tp.Optional[int] = None,
max_window: tp.Optional[int] = None,
row_select_prob: float = 1.0,
window_select_prob: float = 1.0,
roll_forward: bool = False,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
distance_measure: int = DistanceMeasure.MAE,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = 0.85,
minp: tp.Optional[int] = None,
overlap_mode: int = OverlapMode.Disallow,
max_records: tp.Optional[int] = None,
col: int = 0,
) -> tp.RecordArray:
"""Find all occurrences of a pattern in an array.
Uses `vectorbtpro.generic.nb.patterns.pattern_similarity_nb` to fill records of the type
`vectorbtpro.generic.enums.pattern_range_dt`.
Goes through the array and, for each window length between `window` and `max_window` (inclusive),
checks whether the windowed array values are similar enough to the pattern. If so, writes a new
range to the output array. If `window_select_prob` is set, each window is tested only with the
given probability; `row_select_prob` does the same at the row level.
If `roll_forward`, windows are rolled forward (`start_idx` is guaranteed to be sorted), otherwise
backward (`end_idx` is guaranteed to be sorted).
By default, creates an empty record array of the same size as the number of rows in `arr`.
This can be increased or decreased using `max_records`."""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if window is None:
window = pattern.shape[0]
if max_window is None:
max_window = window
if max_records is None:
records_out = np.empty(arr.shape[0], dtype=pattern_range_dt)
else:
records_out = np.empty(max_records, dtype=pattern_range_dt)
r = 0
min_max_required = False
if rescale_mode == RescaleMode.MinMax:
min_max_required = True
if not np.isnan(min_pct_change):
min_max_required = True
if not np.isnan(max_pct_change):
min_max_required = True
if not max_error_as_maxdist:
min_max_required = True
if min_max_required:
if np.isnan(pmin):
pmin = np.nanmin(pattern)
if np.isnan(pmax):
pmax = np.nanmax(pattern)
for i in range(arr.shape[0]):
if roll_forward:
from_i = i
to_i = i + window
if to_i > arr.shape[0]:
break
else:
from_i = i - window + 1
to_i = i + 1
if from_i < 0:
continue
if np.random.uniform(0, 1) < row_select_prob:
_vmin = vmin
_vmax = vmax
if min_max_required:
if np.isnan(_vmin) or np.isnan(_vmax):
for j in range(from_i, to_i):
if np.isnan(_vmin) or arr[j] < _vmin:
_vmin = arr[j]
if np.isnan(_vmax) or arr[j] > _vmax:
_vmax = arr[j]
for w in range(window, max_window + 1):
if roll_forward:
from_i = i
to_i = i + w
if to_i > arr.shape[0]:
break
if min_max_required:
if w > window:
if arr[to_i - 1] < _vmin:
_vmin = arr[to_i - 1]
if arr[to_i - 1] > _vmax:
_vmax = arr[to_i - 1]
else:
from_i = i - w + 1
to_i = i + 1
if from_i < 0:
continue
if min_max_required:
if w > window:
if arr[from_i] < _vmin:
_vmin = arr[from_i]
if arr[from_i] > _vmax:
_vmax = arr[from_i]
if np.random.uniform(0, 1) < window_select_prob:
arr_window = arr[from_i:to_i]
similarity = pattern_similarity_nb(
arr_window,
pattern,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=_vmin,
vmax=_vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
distance_measure=distance_measure,
max_error=max_error_,
max_error_interp_mode=max_error_interp_mode,
max_error_as_maxdist=max_error_as_maxdist,
max_error_strict=max_error_strict,
min_pct_change=min_pct_change,
max_pct_change=max_pct_change,
min_similarity=min_similarity,
minp=minp,
)
if not np.isnan(similarity):
skip = False
while True:
if r > 0:
if roll_forward:
prev_same_row = records_out["start_idx"][r - 1] == from_i
else:
prev_same_row = records_out["end_idx"][r - 1] == to_i
if overlap_mode != OverlapMode.AllowAll and prev_same_row:
if similarity > records_out["similarity"][r - 1]:
r -= 1
continue
else:
skip = True
break
elif overlap_mode >= 0:
overlap = records_out["end_idx"][r - 1] - from_i
if overlap > overlap_mode:
if similarity > records_out["similarity"][r - 1]:
r -= 1
continue
else:
skip = True
break
break
if skip:
continue
if r >= records_out.shape[0]:
raise IndexError("Records index out of range. Set a higher max_records.")
records_out["id"][r] = r
records_out["col"][r] = col
records_out["start_idx"][r] = from_i
if to_i <= arr.shape[0] - 1:
records_out["end_idx"][r] = to_i
records_out["status"][r] = RangeStatus.Closed
else:
records_out["end_idx"][r] = arr.shape[0] - 1
records_out["status"][r] = RangeStatus.Open
records_out["similarity"][r] = similarity
r += 1
return records_out[:r]
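# A hedged usage sketch of `find_pattern_1d_nb` (not from the original docstring):
# search a price series for windows resembling a V-shaped pattern. The exact records
# depend on the interpolation and rescaling settings, so no output is shown here.
# >>> arr = np.array([5.0, 4.0, 3.0, 2.0, 3.0, 4.0, 5.0, 6.0])
# >>> pattern = np.array([3.0, 1.0, 3.0])
# >>> records = find_pattern_1d_nb(arr, pattern, window=5, min_similarity=0.8)
# >>> # each record carries id, col, start_idx, end_idx, status, and similarity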
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
pattern=None,
window=None,
max_window=None,
row_select_prob=None,
window_select_prob=None,
roll_forward=None,
interp_mode=None,
rescale_mode=None,
vmin=None,
vmax=None,
pmin=None,
pmax=None,
invert=None,
error_type=None,
distance_measure=None,
max_error=None,
max_error_interp_mode=None,
max_error_as_maxdist=None,
max_error_strict=None,
min_pct_change=None,
max_pct_change=None,
min_similarity=None,
minp=None,
overlap_mode=None,
max_records=None,
),
merge_func=records_ch.merge_records,
)
@register_jitted(cache=True, tags={"can_parallel"})
def find_pattern_nb(
arr: tp.Array2d,
pattern: tp.Array1d,
window: tp.Optional[int] = None,
max_window: tp.Optional[int] = None,
row_select_prob: float = 1.0,
window_select_prob: float = 1.0,
roll_forward: bool = False,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
distance_measure: int = DistanceMeasure.MAE,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = 0.85,
minp: tp.Optional[int] = None,
overlap_mode: int = OverlapMode.Disallow,
max_records: tp.Optional[int] = None,
) -> tp.RecordArray:
"""2-dim version of `find_pattern_1d_nb`."""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if window is None:
window = pattern.shape[0]
if max_window is None:
max_window = window
if max_records is None:
records_out = np.empty((arr.shape[0], arr.shape[1]), dtype=pattern_range_dt)
else:
records_out = np.empty((max_records, arr.shape[1]), dtype=pattern_range_dt)
record_counts = np.full(arr.shape[1], 0, dtype=int_)
for col in prange(arr.shape[1]):
records = find_pattern_1d_nb(
arr[:, col],
pattern,
window=window,
max_window=max_window,
row_select_prob=row_select_prob,
window_select_prob=window_select_prob,
roll_forward=roll_forward,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=vmin,
vmax=vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
distance_measure=distance_measure,
max_error=max_error_,
max_error_interp_mode=max_error_interp_mode,
max_error_as_maxdist=max_error_as_maxdist,
max_error_strict=max_error_strict,
min_pct_change=min_pct_change,
max_pct_change=max_pct_change,
min_similarity=min_similarity,
minp=minp,
overlap_mode=overlap_mode,
max_records=max_records,
col=col,
)
record_counts[col] = records.shape[0]
records_out[: records.shape[0], col] = records
return repartition_nb(records_out, record_counts)
# ############# Drawdowns ############# #
@register_jitted(cache=True)
def drawdown_1d_nb(arr: tp.Array1d) -> tp.Array1d:
"""Compute drawdown."""
out = np.empty_like(arr, dtype=float_)
max_val = np.nan
for i in range(arr.shape[0]):
if np.isnan(max_val) or arr[i] > max_val:
max_val = arr[i]
if np.isnan(max_val) or max_val == 0:
out[i] = np.nan
else:
out[i] = arr[i] / max_val - 1
return out
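# An illustrative sketch of `drawdown_1d_nb` (not from the original docstring):
# each element is the relative distance from the running maximum:
# >>> drawdown_1d_nb(np.array([2.0, 3.0, 2.0, 4.0]))
# array([ 0.        ,  0.        , -0.33333333,  0.        ])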
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1)),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def drawdown_nb(arr: tp.Array2d) -> tp.Array2d:
"""2-dim version of `drawdown_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = drawdown_1d_nb(arr[:, col])
return out
@register_jitted(cache=True)
def fill_drawdown_record_nb(
new_records: tp.RecordArray2d,
counts: tp.Array2d,
i: int,
col: int,
start_idx: int,
valley_idx: int,
start_val: float,
valley_val: float,
end_val: float,
status: int,
):
"""Fill a drawdown record."""
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = start_idx
new_records["valley_idx"][r, col] = valley_idx
new_records["end_idx"][r, col] = i
new_records["start_val"][r, col] = start_val
new_records["valley_val"][r, col] = valley_val
new_records["end_val"][r, col] = end_val
new_records["status"][r, col] = status
counts[col] += 1
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
open=ch.ArraySlicer(axis=1),
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_drawdowns_nb(
open: tp.Optional[tp.Array2d],
high: tp.Optional[tp.Array2d],
low: tp.Optional[tp.Array2d],
close: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Fill drawdown records by analyzing a time series.
Only `close` is required; the other time series are optional.
Usage:
```pycon
>>> from vectorbtpro import *
>>> close = np.array([
... [1, 5, 1, 3],
... [2, 4, 2, 2],
... [3, 3, 3, 1],
... [4, 2, 2, 2],
... [5, 1, 1, 3]
... ])
>>> records = vbt.nb.get_drawdowns_nb(None, None, None, close)
>>> pd.DataFrame.from_records(records)
id col start_idx valley_idx end_idx start_val valley_val end_val \\
0 0 1 0 4 4 5.0 1.0 1.0
1 0 2 2 4 4 3.0 1.0 1.0
2 0 3 0 2 4 3.0 1.0 3.0
status
0 0
1 0
2 1
```
"""
new_records = np.empty(close.shape, dtype=drawdown_dt)
counts = np.full(close.shape[1], 0, dtype=int_)
sim_start_, sim_end_ = prepare_sim_range_nb(
sim_shape=close.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(close.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
drawdown_started = False
_close = close[0, col]
if open is None:
_open = np.nan
else:
_open = open[0, col]
start_idx = 0
valley_idx = 0
start_val = _open
valley_val = _open
for i in range(_sim_start, _sim_end):
_close = close[i, col]
if open is None:
_open = np.nan
else:
_open = open[i, col]
if high is None:
_high = np.nan
else:
_high = high[i, col]
if low is None:
_low = np.nan
else:
_low = low[i, col]
if np.isnan(_high):
if np.isnan(_open):
_high = _close
elif np.isnan(_close):
_high = _open
else:
_high = max(_open, _close)
if np.isnan(_low):
if np.isnan(_open):
_low = _close
elif np.isnan(_close):
_low = _open
else:
_low = min(_open, _close)
if drawdown_started:
if _open >= start_val:
drawdown_started = False
fill_drawdown_record_nb(
new_records=new_records,
counts=counts,
i=i,
col=col,
start_idx=start_idx,
valley_idx=valley_idx,
start_val=start_val,
valley_val=valley_val,
end_val=_open,
status=DrawdownStatus.Recovered,
)
start_idx = i
valley_idx = i
start_val = _open
valley_val = _open
if drawdown_started:
if _low < valley_val:
valley_idx = i
valley_val = _low
if _high >= start_val:
drawdown_started = False
fill_drawdown_record_nb(
new_records=new_records,
counts=counts,
i=i,
col=col,
start_idx=start_idx,
valley_idx=valley_idx,
start_val=start_val,
valley_val=valley_val,
end_val=_high,
status=DrawdownStatus.Recovered,
)
start_idx = i
valley_idx = i
start_val = _high
valley_val = _high
else:
if np.isnan(start_val) or _high >= start_val:
start_idx = i
valley_idx = i
start_val = _high
valley_val = _high
elif _low < valley_val:
if not np.isnan(valley_val):
drawdown_started = True
valley_idx = i
valley_val = _low
if drawdown_started:
if i == _sim_end - 1:
drawdown_started = False
fill_drawdown_record_nb(
new_records=new_records,
counts=counts,
i=i,
col=col,
start_idx=start_idx,
valley_idx=valley_idx,
start_val=start_val,
valley_val=valley_val,
end_val=_close,
status=DrawdownStatus.Active,
)
return repartition_nb(new_records, counts)
@register_chunkable(
size=ch.ArraySizer(arg_query="start_val_arr", axis=0),
arg_take_spec=dict(start_val_arr=ch.ArraySlicer(axis=0), valley_val_arr=ch.ArraySlicer(axis=0)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_drawdown_nb(start_val_arr: tp.Array1d, valley_val_arr: tp.Array1d) -> tp.Array1d:
"""Compute the drawdown of each drawdown record."""
out = np.empty(valley_val_arr.shape[0], dtype=float_)
for r in prange(valley_val_arr.shape[0]):
if start_val_arr[r] == 0:
out[r] = np.nan
else:
out[r] = (valley_val_arr[r] - start_val_arr[r]) / start_val_arr[r]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="start_idx_arr", axis=0),
arg_take_spec=dict(start_idx_arr=ch.ArraySlicer(axis=0), valley_idx_arr=ch.ArraySlicer(axis=0)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_decline_duration_nb(start_idx_arr: tp.Array1d, valley_idx_arr: tp.Array1d) -> tp.Array1d:
"""Compute the duration of the peak-to-valley phase of each drawdown record."""
out = np.empty(valley_idx_arr.shape[0], dtype=float_)
for r in prange(valley_idx_arr.shape[0]):
out[r] = valley_idx_arr[r] - start_idx_arr[r]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="valley_idx_arr", axis=0),
arg_take_spec=dict(valley_idx_arr=ch.ArraySlicer(axis=0), end_idx_arr=ch.ArraySlicer(axis=0)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_recovery_duration_nb(valley_idx_arr: tp.Array1d, end_idx_arr: tp.Array1d) -> tp.Array1d:
"""Compute the duration of the valley-to-recovery phase of each drawdown record."""
out = np.empty(end_idx_arr.shape[0], dtype=float_)
for r in prange(end_idx_arr.shape[0]):
out[r] = end_idx_arr[r] - valley_idx_arr[r]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="start_idx_arr", axis=0),
arg_take_spec=dict(
start_idx_arr=ch.ArraySlicer(axis=0),
valley_idx_arr=ch.ArraySlicer(axis=0),
end_idx_arr=ch.ArraySlicer(axis=0),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_recovery_duration_ratio_nb(
start_idx_arr: tp.Array1d,
valley_idx_arr: tp.Array1d,
end_idx_arr: tp.Array1d,
) -> tp.Array1d:
"""Compute the ratio of the recovery duration to the decline duration of each drawdown record."""
out = np.empty(start_idx_arr.shape[0], dtype=float_)
for r in prange(start_idx_arr.shape[0]):
if valley_idx_arr[r] - start_idx_arr[r] == 0:
out[r] = np.nan
else:
out[r] = (end_idx_arr[r] - valley_idx_arr[r]) / (valley_idx_arr[r] - start_idx_arr[r])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="valley_val_arr", axis=0),
arg_take_spec=dict(valley_val_arr=ch.ArraySlicer(axis=0), end_val_arr=ch.ArraySlicer(axis=0)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dd_recovery_return_nb(valley_val_arr: tp.Array1d, end_val_arr: tp.Array1d) -> tp.Array1d:
"""Compute the recovery return of each drawdown record."""
out = np.empty(end_val_arr.shape[0], dtype=float_)
for r in prange(end_val_arr.shape[0]):
if valley_val_arr[r] == 0:
out[r] = np.nan
else:
out[r] = (end_val_arr[r] - valley_val_arr[r]) / valley_val_arr[r]
return out
@register_jitted(cache=True)
def bar_price_nb(records: tp.RecordArray, price: tp.Optional[tp.FlexArray2d]) -> tp.Array1d:
"""Return the bar price."""
out = np.empty(len(records), dtype=float_)
for i in range(len(records)):
record = records[i]
if price is not None:
out[i] = float(flex_select_nb(price, record["idx"], record["col"]))
else:
out[i] = np.nan
return out
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for rolling and expanding windows."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.generic.enums import *
from vectorbtpro.generic.nb.base import rank_1d_nb
from vectorbtpro.generic.nb.patterns import pattern_similarity_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
# ############# Rolling functions ############# #
@register_jitted(cache=True)
def rolling_sum_acc_nb(in_state: RollSumAIS) -> RollSumAOS:
"""Accumulator of `rolling_sum_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollSumAIS` and returns
a state of type `vectorbtpro.generic.enums.RollSumAOS`."""
i = in_state.i
value = in_state.value
pre_window_value = in_state.pre_window_value
cumsum = in_state.cumsum
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if np.isnan(value):
nancnt = nancnt + 1
else:
cumsum = cumsum + value
if i < window:
window_len = i + 1 - nancnt
window_cumsum = cumsum
else:
if np.isnan(pre_window_value):
nancnt = nancnt - 1
else:
cumsum = cumsum - pre_window_value
window_len = window - nancnt
window_cumsum = cumsum
if window_len < minp:
value = np.nan
else:
value = window_cumsum
return RollSumAOS(cumsum=cumsum, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_sum_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling sum.
Uses `rolling_sum_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).sum()`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumsum = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = RollSumAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumsum=cumsum,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = rolling_sum_acc_nb(in_state)
cumsum = out_state.cumsum
nancnt = out_state.nancnt
out[i] = out_state.value
return out
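# An illustrative sketch of `rolling_sum_1d_nb` (not from the original docstring):
# NaNs are excluded from the sum but shrink the effective window length:
# >>> rolling_sum_1d_nb(np.array([1.0, 2.0, np.nan, 4.0]), 2, minp=1)
# array([1., 3., 2., 4.])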
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_sum_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_sum_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_sum_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def rolling_prod_acc_nb(in_state: RollProdAIS) -> RollProdAOS:
"""Accumulator of `rolling_prod_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollProdAIS` and returns
a state of type `vectorbtpro.generic.enums.RollProdAOS`."""
i = in_state.i
value = in_state.value
pre_window_value = in_state.pre_window_value
cumprod = in_state.cumprod
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if np.isnan(value):
nancnt = nancnt + 1
else:
cumprod = cumprod * value
if i < window:
window_len = i + 1 - nancnt
window_cumprod = cumprod
else:
if np.isnan(pre_window_value):
nancnt = nancnt - 1
else:
cumprod = cumprod / pre_window_value
window_len = window - nancnt
window_cumprod = cumprod
if window_len < minp:
value = np.nan
else:
value = window_cumprod
return RollProdAOS(cumprod=cumprod, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_prod_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling product.
Uses `rolling_prod_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).apply(np.prod)`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumprod = 1.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = RollProdAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumprod=cumprod,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = rolling_prod_acc_nb(in_state)
cumprod = out_state.cumprod
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_prod_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_prod_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_prod_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def rolling_mean_acc_nb(in_state: RollMeanAIS) -> RollMeanAOS:
"""Accumulator of `rolling_mean_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollMeanAIS` and returns
a state of type `vectorbtpro.generic.enums.RollMeanAOS`."""
i = in_state.i
value = in_state.value
pre_window_value = in_state.pre_window_value
cumsum = in_state.cumsum
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if np.isnan(value):
nancnt = nancnt + 1
else:
cumsum = cumsum + value
if i < window:
window_len = i + 1 - nancnt
window_cumsum = cumsum
else:
if np.isnan(pre_window_value):
nancnt = nancnt - 1
else:
cumsum = cumsum - pre_window_value
window_len = window - nancnt
window_cumsum = cumsum
if window_len < minp:
value = np.nan
else:
value = window_cumsum / window_len
return RollMeanAOS(cumsum=cumsum, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_mean_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling mean.
Uses `rolling_mean_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).mean()`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumsum = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = RollMeanAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumsum=cumsum,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = rolling_mean_acc_nb(in_state)
cumsum = out_state.cumsum
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_mean_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_mean_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_mean_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def rolling_std_acc_nb(in_state: RollStdAIS) -> RollStdAOS:
"""Accumulator of `rolling_std_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollStdAIS` and returns
a state of type `vectorbtpro.generic.enums.RollStdAOS`."""
i = in_state.i
value = in_state.value
pre_window_value = in_state.pre_window_value
cumsum = in_state.cumsum
cumsum_sq = in_state.cumsum_sq
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
ddof = in_state.ddof
if np.isnan(value):
nancnt = nancnt + 1
else:
cumsum = cumsum + value
cumsum_sq = cumsum_sq + value**2
if i < window:
window_len = i + 1 - nancnt
else:
if np.isnan(pre_window_value):
nancnt = nancnt - 1
else:
cumsum = cumsum - pre_window_value
cumsum_sq = cumsum_sq - pre_window_value**2
window_len = window - nancnt
if window_len < minp or window_len == ddof:
value = np.nan
else:
mean = cumsum / window_len
value = np.sqrt(np.abs(cumsum_sq - 2 * cumsum * mean + window_len * mean**2) / (window_len - ddof))
return RollStdAOS(cumsum=cumsum, cumsum_sq=cumsum_sq, nancnt=nancnt, window_len=window_len, value=value)
@register_jitted(cache=True)
def rolling_std_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None, ddof: int = 0) -> tp.Array1d:
"""Compute rolling standard deviation.
Uses `rolling_std_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).std(ddof=ddof)`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumsum = 0.0
cumsum_sq = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = RollStdAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumsum=cumsum,
cumsum_sq=cumsum_sq,
nancnt=nancnt,
window=window,
minp=minp,
ddof=ddof,
)
out_state = rolling_std_acc_nb(in_state)
cumsum = out_state.cumsum
cumsum_sq = out_state.cumsum_sq
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None, ddof=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_std_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None, ddof: int = 0) -> tp.Array2d:
"""2-dim version of `rolling_std_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_std_1d_nb(arr[:, col], window, minp=minp, ddof=ddof)
return out
@register_jitted(cache=True)
def rolling_zscore_acc_nb(in_state: RollZScoreAIS) -> RollZScoreAOS:
"""Accumulator of `rolling_zscore_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollZScoreAIS` and returns
a state of type `vectorbtpro.generic.enums.RollZScoreAOS`."""
mean_in_state = RollMeanAIS(
i=in_state.i,
value=in_state.value,
pre_window_value=in_state.pre_window_value,
cumsum=in_state.cumsum,
nancnt=in_state.nancnt,
window=in_state.window,
minp=in_state.minp,
)
std_in_state = RollStdAIS(
i=in_state.i,
value=in_state.value,
pre_window_value=in_state.pre_window_value,
cumsum=in_state.cumsum,
cumsum_sq=in_state.cumsum_sq,
nancnt=in_state.nancnt,
window=in_state.window,
minp=in_state.minp,
ddof=in_state.ddof,
)
mean_out_state = rolling_mean_acc_nb(mean_in_state)
std_out_state = rolling_std_acc_nb(std_in_state)
if std_out_state.value == 0:
value = np.nan
else:
value = (in_state.value - mean_out_state.value) / std_out_state.value
return RollZScoreAOS(
cumsum=std_out_state.cumsum,
cumsum_sq=std_out_state.cumsum_sq,
nancnt=std_out_state.nancnt,
window_len=std_out_state.window_len,
value=value,
)
@register_jitted(cache=True)
def rolling_zscore_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None, ddof: int = 0) -> tp.Array1d:
"""Compute rolling z-score.
Uses `rolling_zscore_acc_nb` at each iteration."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumsum = 0.0
cumsum_sq = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = RollZScoreAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumsum=cumsum,
cumsum_sq=cumsum_sq,
nancnt=nancnt,
window=window,
minp=minp,
ddof=ddof,
)
out_state = rolling_zscore_acc_nb(in_state)
cumsum = out_state.cumsum
cumsum_sq = out_state.cumsum_sq
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None, ddof=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_zscore_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None, ddof: int = 0) -> tp.Array2d:
"""2-dim version of `rolling_zscore_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_zscore_1d_nb(arr[:, col], window, minp=minp, ddof=ddof)
return out
@register_jitted(cache=True)
def wm_mean_acc_nb(in_state: WMMeanAIS) -> WMMeanAOS:
"""Accumulator of `wm_mean_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.WMMeanAIS` and returns
a state of type `vectorbtpro.generic.enums.WMMeanAOS`."""
i = in_state.i
value = in_state.value
pre_window_value = in_state.pre_window_value
cumsum = in_state.cumsum
wcumsum = in_state.wcumsum
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if i >= window and not np.isnan(pre_window_value):
wcumsum = wcumsum - cumsum
if np.isnan(value):
nancnt = nancnt + 1
else:
cumsum = cumsum + value
if i < window:
window_len = i + 1 - nancnt
else:
if np.isnan(pre_window_value):
nancnt = nancnt - 1
else:
cumsum = cumsum - pre_window_value
window_len = window - nancnt
if not np.isnan(value):
wcumsum = wcumsum + value * window_len
if window_len < minp:
value = np.nan
else:
value = wcumsum * 2 / (window_len + 1) / window_len
return WMMeanAOS(cumsum=cumsum, wcumsum=wcumsum, nancnt=nancnt, window_len=window_len, value=value)
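# A short note on the formula above (not from the original source; assumes the standard
# WMA weighting where weights 1, 2, ..., n run from the oldest to the newest observation):
#     WMA = sum_{k=1..n} k * x_k / (n * (n + 1) / 2) = wcumsum * 2 / ((n + 1) * n),
# which is exactly `wcumsum * 2 / (window_len + 1) / window_len`. For a full window of
# [1, 2, 3] this gives (1*1 + 2*2 + 3*3) / 6 = 14 / 6 ≈ 2.333.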
@register_jitted(cache=True)
def wm_mean_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute weighted moving average.
Uses `wm_mean_acc_nb` at each iteration."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
cumsum = 0.0
wcumsum = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = WMMeanAIS(
i=i,
value=arr[i],
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
cumsum=cumsum,
wcumsum=wcumsum,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = wm_mean_acc_nb(in_state)
cumsum = out_state.cumsum
wcumsum = out_state.wcumsum
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def wm_mean_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `wm_mean_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = wm_mean_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def alpha_from_com_nb(com: float) -> float:
"""Get the smoothing factor `alpha` from a center of mass."""
return 1.0 / (1.0 + com)
@register_jitted(cache=True)
def alpha_from_span_nb(span: float) -> float:
"""Get the smoothing factor `alpha` from a span."""
com = (span - 1) / 2.0
return alpha_from_com_nb(com)
@register_jitted(cache=True)
def alpha_from_halflife_nb(halflife: float) -> float:
"""Get the smoothing factor `alpha` from a half-life."""
return 1 - np.exp(-np.log(2) / halflife)
@register_jitted(cache=True)
def alpha_from_wilder_nb(period: int) -> float:
"""Get the smoothing factor `alpha` from a Wilder's period."""
return 1 / period
@register_jitted(cache=True)
def ewm_mean_acc_nb(in_state: EWMMeanAIS) -> EWMMeanAOS:
"""Accumulator of `ewm_mean_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.EWMMeanAIS` and returns
a state of type `vectorbtpro.generic.enums.EWMMeanAOS`."""
i = in_state.i
value = in_state.value
old_wt = in_state.old_wt
weighted_avg = in_state.weighted_avg
nobs = in_state.nobs
alpha = in_state.alpha
minp = in_state.minp
adjust = in_state.adjust
old_wt_factor = 1.0 - alpha
new_wt = 1.0 if adjust else alpha
if i > 0:
is_observation = not np.isnan(value)
nobs += is_observation
if not np.isnan(weighted_avg):
old_wt *= old_wt_factor
if is_observation:
# avoid numerical errors on constant series
if weighted_avg != value:
weighted_avg = ((old_wt * weighted_avg) + (new_wt * value)) / (old_wt + new_wt)
if adjust:
old_wt += new_wt
else:
old_wt = 1.0
elif is_observation:
weighted_avg = value
else:
is_observation = not np.isnan(weighted_avg)
nobs += int(is_observation)
value = weighted_avg if (nobs >= minp) else np.nan
return EWMMeanAOS(old_wt=old_wt, weighted_avg=weighted_avg, nobs=nobs, value=value)
@register_jitted(cache=True)
def ewm_mean_1d_nb(arr: tp.Array1d, span: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array1d:
"""Compute exponential weighted moving average.
Uses `ewm_mean_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).ewm(span=span, min_periods=minp, adjust=adjust).mean()`.
Adaptation of `pd._libs.window.aggregations.window_aggregations.ewma` with default arguments.
!!! note
In contrast to the Pandas implementation, `minp` is effective within `span`."""
if minp is None:
minp = span
if minp > span:
raise ValueError("minp must be <= span")
out = np.empty(len(arr), dtype=float_)
if len(arr) == 0:
return out
com = (span - 1) / 2.0
alpha = 1.0 / (1.0 + com)
weighted_avg = float(arr[0]) + 0.0 # cast to float_
nobs = 0
n_obs_lagged = 0
old_wt = 1.0
for i in range(len(arr)):
if i >= span:
if not np.isnan(arr[i - span]):
n_obs_lagged += 1
in_state = EWMMeanAIS(
i=i,
value=arr[i],
old_wt=old_wt,
weighted_avg=weighted_avg,
nobs=nobs - n_obs_lagged,
alpha=alpha,
minp=minp,
adjust=adjust,
)
out_state = ewm_mean_acc_nb(in_state)
old_wt = out_state.old_wt
weighted_avg = out_state.weighted_avg
nobs = out_state.nobs + n_obs_lagged
out[i] = out_state.value
return out
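# A short note on the smoothing factor used above (not from the original source):
# from a span, com = (span - 1) / 2 and alpha = 1 / (1 + com) = 2 / (span + 1),
# matching `alpha_from_span_nb` above. For span=10, alpha = 2 / 11 ≈ 0.1818.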
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), span=None, minp=None, adjust=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ewm_mean_nb(arr: tp.Array2d, span: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array2d:
"""2-dim version of `ewm_mean_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = ewm_mean_1d_nb(arr[:, col], span, minp=minp, adjust=adjust)
return out
@register_jitted(cache=True)
def ewm_std_acc_nb(in_state: EWMStdAIS) -> EWMStdAOS:
"""Accumulator of `ewm_std_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.EWMStdAIS` and returns
a state of type `vectorbtpro.generic.enums.EWMStdAOS`."""
i = in_state.i
value = in_state.value
mean_x = in_state.mean_x
mean_y = in_state.mean_y
cov = in_state.cov
sum_wt = in_state.sum_wt
sum_wt2 = in_state.sum_wt2
old_wt = in_state.old_wt
nobs = in_state.nobs
alpha = in_state.alpha
minp = in_state.minp
adjust = in_state.adjust
old_wt_factor = 1.0 - alpha
new_wt = 1.0 if adjust else alpha
cur_x = value
cur_y = value
is_observation = not np.isnan(cur_x) and not np.isnan(cur_y)
nobs += is_observation
if i > 0:
if not np.isnan(mean_x):
sum_wt *= old_wt_factor
sum_wt2 *= old_wt_factor * old_wt_factor
old_wt *= old_wt_factor
if is_observation:
old_mean_x = mean_x
old_mean_y = mean_y
# avoid numerical errors on constant series
if mean_x != cur_x:
mean_x = ((old_wt * old_mean_x) + (new_wt * cur_x)) / (old_wt + new_wt)
# avoid numerical errors on constant series
if mean_y != cur_y:
mean_y = ((old_wt * old_mean_y) + (new_wt * cur_y)) / (old_wt + new_wt)
cov = (
(old_wt * (cov + ((old_mean_x - mean_x) * (old_mean_y - mean_y))))
+ (new_wt * ((cur_x - mean_x) * (cur_y - mean_y)))
) / (old_wt + new_wt)
sum_wt += new_wt
sum_wt2 += new_wt * new_wt
old_wt += new_wt
if not adjust:
sum_wt /= old_wt
sum_wt2 /= old_wt * old_wt
old_wt = 1.0
elif is_observation:
mean_x = cur_x
mean_y = cur_y
else:
if not is_observation:
mean_x = np.nan
mean_y = np.nan
if nobs >= minp:
numerator = sum_wt * sum_wt
denominator = numerator - sum_wt2
if denominator > 0.0:
value = (numerator / denominator) * cov
else:
value = np.nan
else:
value = np.nan
return EWMStdAOS(
mean_x=mean_x,
mean_y=mean_y,
cov=cov,
sum_wt=sum_wt,
sum_wt2=sum_wt2,
old_wt=old_wt,
nobs=nobs,
value=value,
)
@register_jitted(cache=True)
def ewm_std_1d_nb(arr: tp.Array1d, span: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array1d:
"""Compute exponential weighted moving standard deviation.
Uses `ewm_std_acc_nb` at each iteration.
Numba equivalent to `pd.Series(arr).ewm(span=span, min_periods=minp).std()`.
Adaptation of `pd._libs.window.aggregations.window_aggregations.ewmcov` with default arguments.
!!! note
In contrast to the Pandas implementation, `minp` is effective within `span`."""
if minp is None:
minp = span
if minp > span:
raise ValueError("minp must be <= span")
out = np.empty(len(arr), dtype=float_)
if len(arr) == 0:
return out
com = (span - 1) / 2.0
alpha = 1.0 / (1.0 + com)
mean_x = float(arr[0]) + 0.0 # cast to float_
mean_y = float(arr[0]) + 0.0 # cast to float_
nobs = 0
n_obs_lagged = 0
cov = 0.0
sum_wt = 1.0
sum_wt2 = 1.0
old_wt = 1.0
for i in range(len(arr)):
if i >= span:
if not np.isnan(arr[i - span]):
n_obs_lagged += 1
in_state = EWMStdAIS(
i=i,
value=arr[i],
mean_x=mean_x,
mean_y=mean_y,
cov=cov,
sum_wt=sum_wt,
sum_wt2=sum_wt2,
old_wt=old_wt,
nobs=nobs - n_obs_lagged,
alpha=alpha,
minp=minp,
adjust=adjust,
)
out_state = ewm_std_acc_nb(in_state)
mean_x = out_state.mean_x
mean_y = out_state.mean_y
cov = out_state.cov
sum_wt = out_state.sum_wt
sum_wt2 = out_state.sum_wt2
old_wt = out_state.old_wt
nobs = out_state.nobs + n_obs_lagged
out[i] = out_state.value
return np.sqrt(out)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), span=None, minp=None, adjust=None, ddof=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ewm_std_nb(arr: tp.Array2d, span: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array2d:
"""2-dim version of `ewm_std_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = ewm_std_1d_nb(arr[:, col], span, minp=minp, adjust=adjust)
return out
@register_jitted(cache=True)
def wwm_mean_1d_nb(arr: tp.Array1d, period: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array1d:
"""Compute Wilder's exponential weighted moving average."""
if minp is None:
minp = period
return ewm_mean_1d_nb(arr, 2 * period - 1, minp=minp, adjust=adjust)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), period=None, minp=None, adjust=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def wwm_mean_nb(arr: tp.Array2d, period: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array2d:
"""2-dim version of `wwm_mean_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = wwm_mean_1d_nb(arr[:, col], period, minp=minp, adjust=adjust)
return out
@register_jitted(cache=True)
def wwm_std_1d_nb(arr: tp.Array1d, period: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array1d:
"""Compute Wilder's exponential weighted moving standard deviation."""
if minp is None:
minp = period
return ewm_std_1d_nb(arr, 2 * period - 1, minp=minp, adjust=adjust)
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), period=None, minp=None, adjust=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def wwm_std_nb(arr: tp.Array2d, period: int, minp: tp.Optional[int] = None, adjust: bool = False) -> tp.Array2d:
"""2-dim version of `wwm_std_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = wwm_std_1d_nb(arr[:, col], period, minp=minp, adjust=adjust)
return out
@register_jitted(cache=True)
def vidya_acc_nb(in_state: VidyaAIS) -> VidyaAOS:
"""Accumulator of `vidya_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.VidyaAIS` and returns
a state of type `vectorbtpro.generic.enums.VidyaAOS`."""
i = in_state.i
prev_value = in_state.prev_value
value = in_state.value
pre_window_prev_value = in_state.pre_window_prev_value
pre_window_value = in_state.pre_window_value
pos_cumsum = in_state.pos_cumsum
neg_cumsum = in_state.neg_cumsum
prev_vidya = in_state.prev_vidya
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
alpha = 2 / (window + 1)
diff = value - prev_value
if np.isnan(diff):
nancnt = nancnt + 1
else:
if diff > 0:
pos_value = diff
neg_value = 0.0
else:
pos_value = 0.0
neg_value = abs(diff)
pos_cumsum = pos_cumsum + pos_value
neg_cumsum = neg_cumsum + neg_value
if i < window:
window_len = i + 1 - nancnt
else:
pre_window_diff = pre_window_value - pre_window_prev_value
if np.isnan(pre_window_diff):
nancnt = nancnt - 1
else:
if pre_window_diff > 0:
pre_window_pos_value = pre_window_diff
pre_window_neg_value = 0.0
else:
pre_window_pos_value = 0.0
pre_window_neg_value = abs(pre_window_diff)
pos_cumsum = pos_cumsum - pre_window_pos_value
neg_cumsum = neg_cumsum - pre_window_neg_value
window_len = window - nancnt
window_pos_cumsum = pos_cumsum
window_neg_cumsum = neg_cumsum
if window_len < minp:
cmo = np.nan
vidya = np.nan
else:
sh = window_pos_cumsum
sl = window_neg_cumsum
if sh + sl == 0:
cmo = 0.0
else:
cmo = np.abs((sh - sl) / (sh + sl))
if np.isnan(prev_vidya):
prev_vidya = 0.0
vidya = alpha * cmo * value + prev_vidya * (1 - alpha * cmo)
return VidyaAOS(
pos_cumsum=pos_cumsum,
neg_cumsum=neg_cumsum,
nancnt=nancnt,
window_len=window_len,
cmo=cmo,
vidya=vidya,
)
@register_jitted(cache=True)
def vidya_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute VIDYA.
Uses `vidya_acc_nb` at each iteration."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
pos_cumsum = 0.0
neg_cumsum = 0.0
nancnt = 0
for i in range(arr.shape[0]):
in_state = VidyaAIS(
i=i,
prev_value=arr[i - 1] if i - 1 >= 0 else np.nan,
value=arr[i],
pre_window_prev_value=arr[i - window - 1] if i - window - 1 >= 0 else np.nan,
pre_window_value=arr[i - window] if i - window >= 0 else np.nan,
pos_cumsum=pos_cumsum,
neg_cumsum=neg_cumsum,
prev_vidya=out[i - 1] if i - 1 >= 0 else np.nan,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = vidya_acc_nb(in_state)
pos_cumsum = out_state.pos_cumsum
neg_cumsum = out_state.neg_cumsum
nancnt = out_state.nancnt
out[i] = out_state.vidya
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def vidya_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `vidya_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = vidya_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def ma_1d_nb(
arr: tp.Array1d,
window: int,
wtype: int = WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Compute a moving average based on the mode of the type `vectorbtpro.generic.enums.WType`."""
if wtype == WType.Simple:
return rolling_mean_1d_nb(arr, window, minp=minp)
if wtype == WType.Weighted:
return wm_mean_1d_nb(arr, window, minp=minp)
if wtype == WType.Exp:
return ewm_mean_1d_nb(arr, window, minp=minp, adjust=adjust)
if wtype == WType.Wilder:
return wwm_mean_1d_nb(arr, window, minp=minp, adjust=adjust)
if wtype == WType.Vidya:
return vidya_1d_nb(arr, window, minp=minp)
raise ValueError("Invalid rolling mode")
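# An illustrative sketch of the dispatch above (not from the original docstring):
# >>> arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
# >>> ma_1d_nb(arr, 3)                      # same as rolling_mean_1d_nb(arr, 3)
# >>> ma_1d_nb(arr, 3, wtype=WType.Exp)     # same as ewm_mean_1d_nb(arr, 3)
# >>> ma_1d_nb(arr, 3, wtype=WType.Wilder)  # same as wwm_mean_1d_nb(arr, 3)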
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, wtype=None, minp=None, adjust=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ma_nb(
arr: tp.Array2d,
window: int,
wtype: int = WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `ma_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = ma_1d_nb(arr[:, col], window, wtype=wtype, minp=minp, adjust=adjust)
return out
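# A minimal sketch of selecting the weighting scheme of `ma_1d_nb` via `WType`
# (the import path of the function itself is assumed). Each mode dispatches to the
# corresponding 1-dim helper defined above.
def _ma_wtype_sketch():
    import numpy as np
    from vectorbtpro.generic.enums import WType
    from vectorbtpro.generic.nb import ma_1d_nb

    arr = np.arange(20, dtype=np.float64)
    sma = ma_1d_nb(arr, 5, wtype=WType.Simple)   # rolling_mean_1d_nb
    ema = ma_1d_nb(arr, 5, wtype=WType.Exp)      # ewm_mean_1d_nb
    vidya = ma_1d_nb(arr, 5, wtype=WType.Vidya)  # vidya_1d_nb
    return sma, ema, vidya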
@register_jitted(cache=True)
def msd_1d_nb(
arr: tp.Array1d,
window: int,
wtype: int = WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array1d:
"""Compute a moving standard deviation based on the mode of the type `vectorbtpro.generic.enums.WType`."""
if wtype == WType.Simple:
return rolling_std_1d_nb(arr, window, minp=minp, ddof=ddof)
if wtype == WType.Weighted:
raise ValueError("Weighted mode is not supported for standard deviations")
if wtype == WType.Exp:
return ewm_std_1d_nb(arr, window, minp=minp, adjust=adjust)
if wtype == WType.Wilder:
return wwm_std_1d_nb(arr, window, minp=minp, adjust=adjust)
raise ValueError("Invalid rolling mode")
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, wtype=None, minp=None, adjust=None, ddof=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def msd_nb(
arr: tp.Array2d,
window: int,
wtype: int = WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array2d:
"""2-dim version of `msd_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = msd_1d_nb(arr[:, col], window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)
return out
@register_jitted(cache=True)
def rolling_cov_acc_nb(in_state: RollCovAIS) -> RollCovAOS:
"""Accumulator of `rolling_cov_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollCovAIS` and returns
a state of type `vectorbtpro.generic.enums.RollCovAOS`."""
i = in_state.i
value1 = in_state.value1
value2 = in_state.value2
pre_window_value1 = in_state.pre_window_value1
pre_window_value2 = in_state.pre_window_value2
cumsum1 = in_state.cumsum1
cumsum2 = in_state.cumsum2
cumsum_prod = in_state.cumsum_prod
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
ddof = in_state.ddof
if np.isnan(value1) or np.isnan(value2):
nancnt = nancnt + 1
else:
cumsum1 = cumsum1 + value1
cumsum2 = cumsum2 + value2
cumsum_prod = cumsum_prod + value1 * value2
if i < window:
window_len = i + 1 - nancnt
else:
if np.isnan(pre_window_value1) or np.isnan(pre_window_value2):
nancnt = nancnt - 1
else:
cumsum1 = cumsum1 - pre_window_value1
cumsum2 = cumsum2 - pre_window_value2
cumsum_prod = cumsum_prod - pre_window_value1 * pre_window_value2
window_len = window - nancnt
if window_len < minp or window_len == ddof:
value = np.nan
else:
window_prod_mean = cumsum_prod / (window_len - ddof)
window_mean1 = cumsum1 / window_len
window_mean2 = cumsum2 / window_len
window_mean_prod = window_mean1 * window_mean2 * window_len / (window_len - ddof)
value = window_prod_mean - window_mean_prod
return RollCovAOS(
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window_len=window_len,
value=value,
)
@register_jitted(cache=True)
def rolling_cov_1d_nb(
arr1: tp.Array1d,
arr2: tp.Array1d,
window: int,
minp: tp.Optional[int] = None,
ddof: int = 0,
) -> tp.Array1d:
"""Compute rolling covariance.
Numba equivalent to `pd.Series(arr1).rolling(window, min_periods=minp).cov(arr2)`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr1, dtype=float_)
cumsum1 = 0.0
cumsum2 = 0.0
cumsum_prod = 0.0
nancnt = 0
for i in range(arr1.shape[0]):
in_state = RollCovAIS(
i=i,
value1=arr1[i],
value2=arr2[i],
pre_window_value1=arr1[i - window] if i - window >= 0 else np.nan,
pre_window_value2=arr2[i - window] if i - window >= 0 else np.nan,
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window=window,
minp=minp,
ddof=ddof,
)
out_state = rolling_cov_acc_nb(in_state)
cumsum1 = out_state.cumsum1
cumsum2 = out_state.cumsum2
cumsum_prod = out_state.cumsum_prod
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1), window=None, minp=None, ddof=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_cov_nb(
arr1: tp.Array2d,
arr2: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
ddof: int = 0,
) -> tp.Array2d:
"""2-dim version of `rolling_cov_1d_nb`."""
out = np.empty_like(arr1, dtype=float_)
for col in prange(arr1.shape[1]):
out[:, col] = rolling_cov_1d_nb(arr1[:, col], arr2[:, col], window, minp=minp, ddof=ddof)
return out
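# A minimal sketch comparing `rolling_cov_1d_nb` with the pandas call from its docstring
# (import path assumed). Pandas computes the sample covariance (ddof=1), whereas this
# function defaults to ddof=0, so ddof=1 is passed for a like-for-like comparison.
def _rolling_cov_sketch():
    import numpy as np
    import pandas as pd
    from vectorbtpro.generic.nb import rolling_cov_1d_nb

    arr1 = np.random.normal(size=50)
    arr2 = np.random.normal(size=50)
    nb_cov = rolling_cov_1d_nb(arr1, arr2, 10, ddof=1)
    pd_cov = pd.Series(arr1).rolling(10).cov(pd.Series(arr2)).values
    return np.nanmax(np.abs(nb_cov - pd_cov))  # should be near zero (floating-point noise aside)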
@register_jitted(cache=True)
def rolling_corr_acc_nb(in_state: RollCorrAIS) -> RollCorrAOS:
"""Accumulator of `rolling_corr_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollCorrAIS` and returns
a state of type `vectorbtpro.generic.enums.RollCorrAOS`."""
i = in_state.i
value1 = in_state.value1
value2 = in_state.value2
pre_window_value1 = in_state.pre_window_value1
pre_window_value2 = in_state.pre_window_value2
cumsum1 = in_state.cumsum1
cumsum2 = in_state.cumsum2
cumsum_sq1 = in_state.cumsum_sq1
cumsum_sq2 = in_state.cumsum_sq2
cumsum_prod = in_state.cumsum_prod
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if np.isnan(value1) or np.isnan(value2):
nancnt = nancnt + 1
else:
cumsum1 = cumsum1 + value1
cumsum2 = cumsum2 + value2
cumsum_sq1 = cumsum_sq1 + value1**2
cumsum_sq2 = cumsum_sq2 + value2**2
cumsum_prod = cumsum_prod + value1 * value2
if i < window:
window_len = i + 1 - nancnt
else:
if np.isnan(pre_window_value1) or np.isnan(pre_window_value2):
nancnt = nancnt - 1
else:
cumsum1 = cumsum1 - pre_window_value1
cumsum2 = cumsum2 - pre_window_value2
cumsum_sq1 = cumsum_sq1 - pre_window_value1**2
cumsum_sq2 = cumsum_sq2 - pre_window_value2**2
cumsum_prod = cumsum_prod - pre_window_value1 * pre_window_value2
window_len = window - nancnt
if window_len < minp:
value = np.nan
else:
nom = window_len * cumsum_prod - cumsum1 * cumsum2
denom1 = np.sqrt(window_len * cumsum_sq1 - cumsum1**2)
denom2 = np.sqrt(window_len * cumsum_sq2 - cumsum2**2)
denom = denom1 * denom2
if denom == 0:
value = np.nan
else:
value = nom / denom
return RollCorrAOS(
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_sq1=cumsum_sq1,
cumsum_sq2=cumsum_sq2,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window_len=window_len,
value=value,
)
@register_jitted(cache=True)
def rolling_corr_1d_nb(arr1: tp.Array1d, arr2: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling correlation coefficient.
Numba equivalent to `pd.Series(arr1).rolling(window, min_periods=minp).corr(arr2)`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr1, dtype=float_)
cumsum1 = 0.0
cumsum2 = 0.0
cumsum_sq1 = 0.0
cumsum_sq2 = 0.0
cumsum_prod = 0.0
nancnt = 0
for i in range(arr1.shape[0]):
in_state = RollCorrAIS(
i=i,
value1=arr1[i],
value2=arr2[i],
pre_window_value1=arr1[i - window] if i - window >= 0 else np.nan,
pre_window_value2=arr2[i - window] if i - window >= 0 else np.nan,
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_sq1=cumsum_sq1,
cumsum_sq2=cumsum_sq2,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = rolling_corr_acc_nb(in_state)
cumsum1 = out_state.cumsum1
cumsum2 = out_state.cumsum2
cumsum_sq1 = out_state.cumsum_sq1
cumsum_sq2 = out_state.cumsum_sq2
cumsum_prod = out_state.cumsum_prod
nancnt = out_state.nancnt
out[i] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_corr_nb(arr1: tp.Array2d, arr2: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_corr_1d_nb`."""
out = np.empty_like(arr1, dtype=float_)
for col in prange(arr1.shape[1]):
out[:, col] = rolling_corr_1d_nb(arr1[:, col], arr2[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def rolling_ols_acc_nb(in_state: RollOLSAIS) -> RollOLSAOS:
"""Accumulator of `rolling_ols_1d_nb`.
Takes a state of type `vectorbtpro.generic.enums.RollOLSAIS` and returns
a state of type `vectorbtpro.generic.enums.RollOLSAOS`."""
i = in_state.i
value1 = in_state.value1
value2 = in_state.value2
pre_window_value1 = in_state.pre_window_value1
pre_window_value2 = in_state.pre_window_value2
validcnt = in_state.validcnt
cumsum1 = in_state.cumsum1
cumsum2 = in_state.cumsum2
cumsum_sq1 = in_state.cumsum_sq1
cumsum_prod = in_state.cumsum_prod
nancnt = in_state.nancnt
window = in_state.window
minp = in_state.minp
if np.isnan(value1) or np.isnan(value2):
nancnt = nancnt + 1
else:
validcnt = validcnt + 1
cumsum1 = cumsum1 + value1
cumsum2 = cumsum2 + value2
cumsum_sq1 = cumsum_sq1 + value1**2
cumsum_prod = cumsum_prod + value1 * value2
if i < window:
window_len = i + 1 - nancnt
else:
if np.isnan(pre_window_value1) or np.isnan(pre_window_value2):
nancnt = nancnt - 1
else:
validcnt = validcnt - 1
cumsum1 = cumsum1 - pre_window_value1
cumsum2 = cumsum2 - pre_window_value2
cumsum_sq1 = cumsum_sq1 - pre_window_value1**2
cumsum_prod = cumsum_prod - pre_window_value1 * pre_window_value2
window_len = window - nancnt
if window_len < minp:
slope_value = np.nan
intercept_value = np.nan
else:
slope_num = validcnt * cumsum_prod - cumsum1 * cumsum2
slope_denom = validcnt * cumsum_sq1 - cumsum1**2
if slope_denom != 0:
slope_value = slope_num / slope_denom
else:
slope_value = np.nan
intercept_num = cumsum2 - slope_value * cumsum1
intercept_denom = validcnt
if intercept_denom != 0:
intercept_value = intercept_num / intercept_denom
else:
intercept_value = np.nan
return RollOLSAOS(
validcnt=validcnt,
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_sq1=cumsum_sq1,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window_len=window_len,
slope_value=slope_value,
intercept_value=intercept_value,
)
@register_jitted(cache=True)
def rolling_ols_1d_nb(
arr1: tp.Array1d,
arr2: tp.Array1d,
window: int,
minp: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Compute rolling linear regression."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
slope_out = np.empty_like(arr1, dtype=float_)
intercept_out = np.empty_like(arr1, dtype=float_)
validcnt = 0
cumsum1 = 0.0
cumsum2 = 0.0
cumsum_sq1 = 0.0
cumsum_prod = 0.0
nancnt = 0
for i in range(arr1.shape[0]):
in_state = RollOLSAIS(
i=i,
value1=arr1[i],
value2=arr2[i],
pre_window_value1=arr1[i - window] if i - window >= 0 else np.nan,
pre_window_value2=arr2[i - window] if i - window >= 0 else np.nan,
validcnt=validcnt,
cumsum1=cumsum1,
cumsum2=cumsum2,
cumsum_sq1=cumsum_sq1,
cumsum_prod=cumsum_prod,
nancnt=nancnt,
window=window,
minp=minp,
)
out_state = rolling_ols_acc_nb(in_state)
validcnt = out_state.validcnt
cumsum1 = out_state.cumsum1
cumsum2 = out_state.cumsum2
cumsum_sq1 = out_state.cumsum_sq1
cumsum_prod = out_state.cumsum_prod
nancnt = out_state.nancnt
slope_out[i] = out_state.slope_value
intercept_out[i] = out_state.intercept_value
return slope_out, intercept_out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr1", axis=1),
arg_take_spec=dict(arr1=ch.ArraySlicer(axis=1), arr2=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_ols_nb(
arr1: tp.Array2d,
arr2: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""2-dim version of `rolling_ols_1d_nb`."""
slope_out = np.empty_like(arr1, dtype=float_)
intercept_out = np.empty_like(arr1, dtype=float_)
for col in prange(arr1.shape[1]):
slope_out[:, col], intercept_out[:, col] = rolling_ols_1d_nb(arr1[:, col], arr2[:, col], window, minp=minp)
return slope_out, intercept_out
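# A minimal sketch of `rolling_ols_1d_nb` (import path assumed): arr2 is regressed on arr1
# over the trailing window, yielding one slope and one intercept per bar.
def _rolling_ols_sketch():
    import numpy as np
    from vectorbtpro.generic.nb import rolling_ols_1d_nb

    x = np.arange(100, dtype=np.float64)
    y = 2.0 * x + 1.0 + np.random.normal(scale=0.1, size=100)
    slope, intercept = rolling_ols_1d_nb(x, y, 20)
    # Once the window is filled, slope should hover around 2 and intercept around 1.
    return slope, intercept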
@register_jitted(cache=True)
def rolling_rank_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None, pct: bool = False) -> tp.Array1d:
"""Rolling version of `rank_1d_nb`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
nancnt = 0
for i in range(arr.shape[0]):
if np.isnan(arr[i]):
nancnt = nancnt + 1
if i < window:
valid_cnt = i + 1 - nancnt
else:
if np.isnan(arr[i - window]):
nancnt = nancnt - 1
valid_cnt = window - nancnt
if valid_cnt < minp:
out[i] = np.nan
else:
from_i = max(0, i + 1 - window)
to_i = i + 1
arr_window = arr[from_i:to_i]
out[i] = rank_1d_nb(arr_window, pct=pct)[-1]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None, pct=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_rank_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None, pct: bool = False) -> tp.Array2d:
"""2-dim version of `rolling_rank_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_rank_1d_nb(arr[:, col], window, minp=minp, pct=pct)
return out
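# A minimal sketch of `rolling_rank_1d_nb` (import path assumed): each output is the rank
# of the current value within its trailing window, optionally expressed as a fraction.
def _rolling_rank_sketch():
    import numpy as np
    from vectorbtpro.generic.nb import rolling_rank_1d_nb

    arr = np.array([1.0, 2.0, 3.0, 2.5, 0.5])
    ranks = rolling_rank_1d_nb(arr, 3, minp=1)                 # rank of arr[i] among the last 3 values
    pct_ranks = rolling_rank_1d_nb(arr, 3, minp=1, pct=True)   # same, as a fraction of the window
    return ranks, pct_ranks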
@register_jitted(cache=True)
def rolling_min_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling min.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).min()`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
for i in range(arr.shape[0]):
from_i = max(i - window + 1, 0)
to_i = i + 1
minv = arr[from_i]
cnt = 0
for j in range(from_i, to_i):
if np.isnan(arr[j]):
continue
if np.isnan(minv) or arr[j] < minv:
minv = arr[j]
cnt += 1
if cnt < minp:
out[i] = np.nan
else:
out[i] = minv
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_min_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_min_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_min_1d_nb(arr[:, col], window, minp=minp)
return out
@register_jitted(cache=True)
def rolling_max_1d_nb(arr: tp.Array1d, window: int, minp: tp.Optional[int] = None) -> tp.Array1d:
"""Compute rolling max.
Numba equivalent to `pd.Series(arr).rolling(window, min_periods=minp).max()`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=float_)
for i in range(arr.shape[0]):
from_i = max(i - window + 1, 0)
to_i = i + 1
maxv = arr[from_i]
cnt = 0
for j in range(from_i, to_i):
if np.isnan(arr[j]):
continue
if np.isnan(maxv) or arr[j] > maxv:
maxv = arr[j]
cnt += 1
if cnt < minp:
out[i] = np.nan
else:
out[i] = maxv
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_max_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None) -> tp.Array2d:
"""2-dim version of `rolling_max_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_max_1d_nb(arr[:, col], window, minp=minp)
return out
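# A minimal sketch (import path assumed) checking the pandas equivalence stated in the
# docstrings of `rolling_min_1d_nb` and `rolling_max_1d_nb`.
def _rolling_min_max_sketch():
    import numpy as np
    import pandas as pd
    from vectorbtpro.generic.nb import rolling_min_1d_nb, rolling_max_1d_nb

    arr = np.random.normal(size=30)
    np.testing.assert_allclose(
        rolling_min_1d_nb(arr, 5, minp=1),
        pd.Series(arr).rolling(5, min_periods=1).min().values,
    )
    np.testing.assert_allclose(
        rolling_max_1d_nb(arr, 5, minp=1),
        pd.Series(arr).rolling(5, min_periods=1).max().values,
    )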
@register_jitted(cache=True)
def rolling_argmin_1d_nb(
arr: tp.Array1d,
window: int,
minp: tp.Optional[int] = None,
local: bool = False,
) -> tp.Array1d:
"""Compute rolling min index."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=int_)
for i in range(arr.shape[0]):
from_i = max(i - window + 1, 0)
to_i = i + 1
minv = arr[from_i]
if local:
mini = 0
else:
mini = from_i
cnt = 0
for k, j in enumerate(range(from_i, to_i)):
if np.isnan(arr[j]):
continue
if np.isnan(minv) or arr[j] < minv:
minv = arr[j]
if local:
mini = k
else:
mini = j
cnt += 1
if cnt < minp:
out[i] = -1
else:
out[i] = mini
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None, local=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_argmin_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None, local: bool = False) -> tp.Array2d:
"""2-dim version of `rolling_argmin_1d_nb`."""
out = np.empty_like(arr, dtype=int_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_argmin_1d_nb(arr[:, col], window, minp=minp, local=local)
return out
@register_jitted(cache=True)
def rolling_argmax_1d_nb(
arr: tp.Array1d,
window: int,
minp: tp.Optional[int] = None,
local: bool = False,
) -> tp.Array1d:
"""Compute rolling max index."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(arr, dtype=int_)
for i in range(arr.shape[0]):
from_i = max(i - window + 1, 0)
to_i = i + 1
maxv = arr[from_i]
if local:
maxi = 0
else:
maxi = from_i
cnt = 0
for k, j in enumerate(range(from_i, to_i)):
if np.isnan(arr[j]):
continue
if np.isnan(maxv) or arr[j] > maxv:
maxv = arr[j]
if local:
maxi = k
else:
maxi = j
cnt += 1
if cnt < minp:
out[i] = -1
else:
out[i] = maxi
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None, minp=None, local=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_argmax_nb(arr: tp.Array2d, window: int, minp: tp.Optional[int] = None, local: bool = False) -> tp.Array2d:
"""2-dim version of `rolling_argmax_1d_nb`."""
out = np.empty_like(arr, dtype=int_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_argmax_1d_nb(arr[:, col], window, minp=minp, local=local)
return out
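# A minimal sketch (import path assumed) of the `local` flag: with local=False the returned
# index points into the full array, with local=True it is relative to the window start;
# -1 marks bars with fewer than `minp` valid values.
def _rolling_argmax_local_sketch():
    import numpy as np
    from vectorbtpro.generic.nb import rolling_argmax_1d_nb

    arr = np.array([1.0, 3.0, 2.0, 5.0, 4.0])
    global_idx = rolling_argmax_1d_nb(arr, 3, minp=1)             # last entry -> 3 (position of 5.0)
    local_idx = rolling_argmax_1d_nb(arr, 3, minp=1, local=True)  # last entry -> 1 (offset within window)
    return global_idx, local_idx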
@register_jitted(cache=True)
def rolling_any_1d_nb(arr: tp.Array1d, window: int) -> tp.Array1d:
"""Compute rolling any."""
out = np.empty_like(arr, dtype=np.bool_)
last_true_i = -1
for i in range(arr.shape[0]):
if not np.isnan(arr[i]) and arr[i]:
last_true_i = i
from_i = max(0, i + 1 - window)
if last_true_i >= from_i:
out[i] = True
else:
out[i] = False
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_any_nb(arr: tp.Array2d, window: int) -> tp.Array2d:
"""2-dim version of `rolling_any_1d_nb`."""
out = np.empty_like(arr, dtype=np.bool_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_any_1d_nb(arr[:, col], window)
return out
@register_jitted(cache=True)
def rolling_all_1d_nb(arr: tp.Array1d, window: int) -> tp.Array1d:
"""Compute rolling all."""
out = np.empty_like(arr, dtype=np.bool_)
last_false_i = -1
for i in range(arr.shape[0]):
if np.isnan(arr[i]) or not arr[i]:
last_false_i = i
from_i = max(0, i + 1 - window)
if last_false_i >= from_i:
out[i] = False
else:
out[i] = True
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), window=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_all_nb(arr: tp.Array2d, window: int) -> tp.Array2d:
"""2-dim version of `rolling_all_1d_nb`."""
out = np.empty_like(arr, dtype=np.bool_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_all_1d_nb(arr[:, col], window)
return out
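# A minimal sketch (import path assumed): both functions track only the position of the last
# True/False value, so no window is rescanned. NaN counts as False, and zero is falsy.
def _rolling_any_all_sketch():
    import numpy as np
    from vectorbtpro.generic.nb import rolling_any_1d_nb, rolling_all_1d_nb

    arr = np.array([0.0, 1.0, 0.0, np.nan, 0.0])
    any_out = rolling_any_1d_nb(arr, 3)  # True while the 1.0 at index 1 remains inside the window
    all_out = rolling_all_1d_nb(arr, 3)  # all False here, since every window contains a falsy value
    return any_out, all_out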
@register_jitted(cache=True)
def rolling_pattern_similarity_1d_nb(
arr: tp.Array1d,
pattern: tp.Array1d,
window: tp.Optional[int] = None,
max_window: tp.Optional[int] = None,
row_select_prob: float = 1.0,
window_select_prob: float = 1.0,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
distance_measure: int = DistanceMeasure.MAE,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = 0.85,
minp: tp.Optional[int] = None,
) -> tp.Array1d:
"""Compute rolling pattern similarity.
Uses `vectorbtpro.generic.nb.patterns.pattern_similarity_nb`."""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if window is None:
window = pattern.shape[0]
if max_window is None:
max_window = window
out = np.full(arr.shape, np.nan, dtype=float_)
min_max_required = False
if rescale_mode == RescaleMode.MinMax:
min_max_required = True
if not np.isnan(min_pct_change):
min_max_required = True
if not np.isnan(max_pct_change):
min_max_required = True
if not max_error_as_maxdist:
min_max_required = True
if min_max_required:
if np.isnan(pmin):
pmin = np.nanmin(pattern)
if np.isnan(pmax):
pmax = np.nanmax(pattern)
for i in range(arr.shape[0]):
from_i = i - window + 1
to_i = i + 1
if from_i < 0:
continue
if np.random.uniform(0, 1) < row_select_prob:
_vmin = vmin
_vmax = vmax
if min_max_required:
if np.isnan(_vmin) or np.isnan(_vmax):
for j in range(from_i, to_i):
if np.isnan(_vmin) or arr[j] < _vmin:
_vmin = arr[j]
if np.isnan(_vmax) or arr[j] > _vmax:
_vmax = arr[j]
for w in range(window, max_window + 1):
from_i = i - w + 1
to_i = i + 1
if from_i < 0:
continue
if min_max_required:
if w > window:
if arr[from_i] < _vmin:
_vmin = arr[from_i]
if arr[from_i] > _vmax:
_vmax = arr[from_i]
if np.random.uniform(0, 1) < window_select_prob:
arr_window = arr[from_i:to_i]
similarity = pattern_similarity_nb(
arr_window,
pattern,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=_vmin,
vmax=_vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
distance_measure=distance_measure,
max_error=max_error_,
max_error_interp_mode=max_error_interp_mode,
max_error_as_maxdist=max_error_as_maxdist,
max_error_strict=max_error_strict,
min_pct_change=min_pct_change,
max_pct_change=max_pct_change,
min_similarity=min_similarity,
minp=minp,
)
if not np.isnan(similarity):
if not np.isnan(out[i]):
if similarity > out[i]:
out[i] = similarity
else:
out[i] = similarity
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
pattern=None,
window=None,
max_window=None,
row_select_prob=None,
window_select_prob=None,
interp_mode=None,
rescale_mode=None,
vmin=None,
vmax=None,
pmin=None,
pmax=None,
invert=None,
error_type=None,
distance_measure=None,
max_error=None,
max_error_interp_mode=None,
max_error_as_maxdist=None,
max_error_strict=None,
min_pct_change=None,
max_pct_change=None,
min_similarity=None,
minp=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_pattern_similarity_nb(
arr: tp.Array2d,
pattern: tp.Array1d,
window: tp.Optional[int] = None,
max_window: tp.Optional[int] = None,
row_select_prob: float = 1.0,
window_select_prob: float = 1.0,
interp_mode: int = InterpMode.Mixed,
rescale_mode: int = RescaleMode.MinMax,
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: int = ErrorType.Absolute,
distance_measure: int = DistanceMeasure.MAE,
max_error: tp.FlexArray1dLike = np.nan,
max_error_interp_mode: tp.Optional[int] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = 0.85,
minp: tp.Optional[int] = None,
) -> tp.Array2d:
"""2-dim version of `rolling_pattern_similarity_1d_nb`."""
max_error_ = to_1d_array_nb(np.asarray(max_error))
if window is None:
window = pattern.shape[0]
if max_window is None:
max_window = window
out = np.full(arr.shape, np.nan, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = rolling_pattern_similarity_1d_nb(
arr[:, col],
pattern,
window=window,
max_window=max_window,
row_select_prob=row_select_prob,
window_select_prob=window_select_prob,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=vmin,
vmax=vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
distance_measure=distance_measure,
max_error=max_error_,
max_error_interp_mode=max_error_interp_mode,
max_error_as_maxdist=max_error_as_maxdist,
max_error_strict=max_error_strict,
min_pct_change=min_pct_change,
max_pct_change=max_pct_change,
min_similarity=min_similarity,
minp=minp,
)
return out
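# A minimal sketch (import path assumed) of scanning a series for windows resembling a
# V-shaped pattern. Under the default MinMax rescaling only the relative shape matters;
# `min_similarity` is lowered here so that weaker matches are not filtered out as NaN.
def _rolling_pattern_sketch():
    import numpy as np
    from vectorbtpro.generic.nb import rolling_pattern_similarity_1d_nb

    arr = np.sin(np.linspace(0, 4 * np.pi, 100))
    pattern = np.array([2.0, 1.0, 2.0])  # a simple V shape
    sim = rolling_pattern_similarity_1d_nb(arr, pattern, window=10, min_similarity=0.0)
    return sim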
# ############# Expanding functions ############# #
@register_jitted(cache=True)
def expanding_min_1d_nb(arr: tp.Array1d, minp: int = 1) -> tp.Array1d:
"""Compute expanding min.
Numba equivalent to `pd.Series(arr).expanding(min_periods=minp).min()`."""
out = np.empty_like(arr, dtype=float_)
minv = arr[0]
cnt = 0
for i in range(arr.shape[0]):
if np.isnan(minv) or arr[i] < minv:
minv = arr[i]
if not np.isnan(arr[i]):
cnt += 1
if cnt < minp:
out[i] = np.nan
else:
out[i] = minv
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_min_nb(arr: tp.Array2d, minp: int = 1) -> tp.Array2d:
"""2-dim version of `expanding_min_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = expanding_min_1d_nb(arr[:, col], minp=minp)
return out
@register_jitted(cache=True)
def expanding_max_1d_nb(arr: tp.Array1d, minp: int = 1) -> tp.Array1d:
"""Compute expanding max.
Numba equivalent to `pd.Series(arr).expanding(min_periods=minp).max()`."""
out = np.empty_like(arr, dtype=float_)
maxv = arr[0]
cnt = 0
for i in range(arr.shape[0]):
if np.isnan(maxv) or arr[i] > maxv:
maxv = arr[i]
if not np.isnan(arr[i]):
cnt += 1
if cnt < minp:
out[i] = np.nan
else:
out[i] = maxv
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(arr=ch.ArraySlicer(axis=1), minp=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_max_nb(arr: tp.Array2d, minp: int = 1) -> tp.Array2d:
"""2-dim version of `expanding_max_1d_nb`."""
out = np.empty_like(arr, dtype=float_)
for col in prange(arr.shape[1]):
out[:, col] = expanding_max_1d_nb(arr[:, col], minp=minp)
return out
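# A minimal sketch (import path assumed) checking the pandas equivalence stated in the
# docstrings of `expanding_min_1d_nb` and `expanding_max_1d_nb`.
def _expanding_min_max_sketch():
    import numpy as np
    import pandas as pd
    from vectorbtpro.generic.nb import expanding_min_1d_nb, expanding_max_1d_nb

    arr = np.random.normal(size=20)
    np.testing.assert_allclose(
        expanding_min_1d_nb(arr, minp=1),
        pd.Series(arr).expanding(min_periods=1).min().values,
    )
    np.testing.assert_allclose(
        expanding_max_1d_nb(arr, minp=1),
        pd.Series(arr).expanding(min_periods=1).max().values,
    )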
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Generic Numba-compiled functions for simulation ranges.
!!! warning
Resolution is more flexible and may return None, while preparation always returns NumPy arrays.
Thus, use preparation, not resolution, in Numba-parallel workflows."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.registries.jit_registry import register_jitted
@register_jitted(cache=True)
def resolve_sim_start_nb(
sim_shape: tp.Shape,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve simulation start."""
if sim_start is None:
if allow_none:
return None
return np.full(sim_shape[1], 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if not check_bounds and len(sim_start_) == sim_shape[1]:
return sim_start_
sim_start_out = np.empty(sim_shape[1], dtype=int_)
can_be_none = True
for i in range(sim_shape[1]):
_sim_start = flex_select_1d_pc_nb(sim_start_, i)
if _sim_start < 0:
_sim_start = sim_shape[0] + _sim_start
elif _sim_start > sim_shape[0]:
_sim_start = sim_shape[0]
sim_start_out[i] = _sim_start
if _sim_start != 0:
can_be_none = False
if allow_none and can_be_none:
return None
return sim_start_out
@register_jitted(cache=True)
def resolve_sim_end_nb(
sim_shape: tp.Shape,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve simulation end."""
if sim_end is None:
if allow_none:
return None
return np.full(sim_shape[1], sim_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if not check_bounds and len(sim_end_) == sim_shape[1]:
return sim_end_
new_sim_end = np.empty(sim_shape[1], dtype=int_)
can_be_none = True
for i in range(sim_shape[1]):
_sim_end = flex_select_1d_pc_nb(sim_end_, i)
if _sim_end < 0:
_sim_end = sim_shape[0] + _sim_end
elif _sim_end > sim_shape[0]:
_sim_end = sim_shape[0]
new_sim_end[i] = _sim_end
if _sim_end != sim_shape[0]:
can_be_none = False
if allow_none and can_be_none:
return None
return new_sim_end
@register_jitted(cache=True)
def resolve_sim_range_nb(
sim_shape: tp.Shape,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]:
"""Resolve simulation start and end."""
new_sim_start = resolve_sim_start_nb(
sim_shape=sim_shape,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_end = resolve_sim_end_nb(
sim_shape=sim_shape,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
@register_jitted(cache=True)
def resolve_grouped_sim_start_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve grouped simulation start."""
if sim_start is None:
if allow_none:
return None
return np.full(len(group_lens), 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if len(sim_start_) == len(group_lens):
if not check_bounds:
return sim_start_
return resolve_sim_start_nb(
(target_shape[0], len(group_lens)),
sim_start=sim_start_,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_start = np.empty(len(group_lens), dtype=int_)
can_be_none = True
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
min_sim_start = target_shape[0]
for col in range(from_col, to_col):
_sim_start = flex_select_1d_pc_nb(sim_start_, col)
if _sim_start < 0:
_sim_start = target_shape[0] + _sim_start
elif _sim_start > target_shape[0]:
_sim_start = target_shape[0]
if _sim_start < min_sim_start:
min_sim_start = _sim_start
new_sim_start[group] = min_sim_start
if min_sim_start != 0:
can_be_none = False
if allow_none and can_be_none:
return None
return new_sim_start
@register_jitted(cache=True)
def resolve_grouped_sim_end_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve grouped simulation end."""
if sim_end is None:
if allow_none:
return None
return np.full(len(group_lens), target_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if len(sim_end_) == len(group_lens):
if not check_bounds:
return sim_end_
return resolve_sim_end_nb(
(target_shape[0], len(group_lens)),
sim_end=sim_end_,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_end = np.empty(len(group_lens), dtype=int_)
can_be_none = True
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
max_sim_end = 0
for col in range(from_col, to_col):
_sim_end = flex_select_1d_pc_nb(sim_end_, col)
if _sim_end < 0:
_sim_end = target_shape[0] + _sim_end
elif _sim_end > target_shape[0]:
_sim_end = target_shape[0]
if _sim_end > max_sim_end:
max_sim_end = _sim_end
new_sim_end[group] = max_sim_end
if max_sim_end != target_shape[0]:
can_be_none = False
if allow_none and can_be_none:
return None
return new_sim_end
@register_jitted(cache=True)
def resolve_grouped_sim_range_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]:
"""Resolve grouped simulation start and end."""
new_sim_start = resolve_grouped_sim_start_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_end = resolve_grouped_sim_end_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
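# A minimal sketch (assuming the module is importable as `vectorbtpro.generic.nb.sim_range`)
# of collapsing per-column bounds into per-group bounds: each group's start is the earliest
# column start and its end is the latest column end.
def _grouped_sim_range_sketch():
    import numpy as np
    from vectorbtpro.generic.nb.sim_range import resolve_grouped_sim_range_nb

    target_shape = (100, 4)
    group_lens = np.array([2, 2])           # two groups of two columns each
    sim_start = np.array([5, 10, 20, 30])   # per-column starts
    sim_end = np.array([90, 95, 80, 100])   # per-column ends
    start, end = resolve_grouped_sim_range_nb(target_shape, group_lens, sim_start, sim_end)
    # start == [5, 20], end == [95, 100]
    return start, end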
@register_jitted(cache=True)
def resolve_ungrouped_sim_start_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve ungrouped simulation start."""
if sim_start is None:
if allow_none:
return None
return np.full(target_shape[1], 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if len(sim_start_) == target_shape[1]:
if not check_bounds:
return sim_start_
return resolve_sim_start_nb(
target_shape,
sim_start=sim_start_,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_start = np.empty(target_shape[1], dtype=int_)
can_be_none = True
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
_sim_start = flex_select_1d_pc_nb(sim_start_, group)
if _sim_start < 0:
_sim_start = target_shape[0] + _sim_start
elif _sim_start > target_shape[0]:
_sim_start = target_shape[0]
for col in range(from_col, to_col):
new_sim_start[col] = _sim_start
if _sim_start != 0:
can_be_none = False
if allow_none and can_be_none:
return None
return new_sim_start
@register_jitted(cache=True)
def resolve_ungrouped_sim_end_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Optional[tp.Array1d]:
"""Resolve ungrouped simulation end."""
if sim_end is None:
if allow_none:
return None
return np.full(target_shape[1], target_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if len(sim_end_) == target_shape[1]:
if not check_bounds:
return sim_end_
return resolve_sim_end_nb(
target_shape,
sim_end=sim_end_,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_end = np.empty(target_shape[1], dtype=int_)
can_be_none = True
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
_sim_end = flex_select_1d_pc_nb(sim_end_, group)
if _sim_end < 0:
_sim_end = target_shape[0] + _sim_end
elif _sim_end > target_shape[0]:
_sim_end = target_shape[0]
for col in range(from_col, to_col):
new_sim_end[col] = _sim_end
if _sim_end != target_shape[0]:
can_be_none = False
if allow_none and can_be_none:
return None
return new_sim_end
@register_jitted(cache=True)
def resolve_ungrouped_sim_range_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
allow_none: bool = False,
check_bounds: bool = True,
) -> tp.Tuple[tp.Optional[tp.Array1d], tp.Optional[tp.Array1d]]:
"""Resolve ungrouped simulation start and end."""
new_sim_start = resolve_ungrouped_sim_start_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=check_bounds,
)
new_sim_end = resolve_ungrouped_sim_end_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
@register_jitted(cache=True)
def prepare_sim_start_nb(
sim_shape: tp.Shape,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare simulation start."""
if sim_start is None:
return np.full(sim_shape[1], 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if not check_bounds and len(sim_start_) == sim_shape[1]:
return sim_start_
sim_start_out = np.empty(sim_shape[1], dtype=int_)
for i in range(sim_shape[1]):
_sim_start = flex_select_1d_pc_nb(sim_start_, i)
if _sim_start < 0:
_sim_start = sim_shape[0] + _sim_start
elif _sim_start > sim_shape[0]:
_sim_start = sim_shape[0]
sim_start_out[i] = _sim_start
return sim_start_out
@register_jitted(cache=True)
def prepare_sim_end_nb(
sim_shape: tp.Shape,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare simulation end."""
if sim_end is None:
return np.full(sim_shape[1], sim_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if not check_bounds and len(sim_end_) == sim_shape[1]:
return sim_end_
new_sim_end = np.empty(sim_shape[1], dtype=int_)
for i in range(sim_shape[1]):
_sim_end = flex_select_1d_pc_nb(sim_end_, i)
if _sim_end < 0:
_sim_end = sim_shape[0] + _sim_end
elif _sim_end > sim_shape[0]:
_sim_end = sim_shape[0]
new_sim_end[i] = _sim_end
return new_sim_end
@register_jitted(cache=True)
def prepare_sim_range_nb(
sim_shape: tp.Shape,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Prepare simulation start and end."""
new_sim_start = prepare_sim_start_nb(
sim_shape=sim_shape,
sim_start=sim_start,
check_bounds=check_bounds,
)
new_sim_end = prepare_sim_end_nb(
sim_shape=sim_shape,
sim_end=sim_end,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
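# A minimal sketch (assuming the module is importable as `vectorbtpro.generic.nb.sim_range`)
# of the resolution vs. preparation distinction flagged in the module docstring: resolution
# may collapse to None when nothing restricts the range, while preparation always returns arrays.
def _sim_range_sketch():
    from vectorbtpro.generic.nb.sim_range import prepare_sim_range_nb, resolve_sim_range_nb

    sim_shape = (100, 3)
    # Negative bounds count from the end of the axis, as in regular NumPy indexing.
    start, end = prepare_sim_range_nb(sim_shape, sim_start=10, sim_end=-10)
    # start == [10, 10, 10], end == [90, 90, 90]
    none_start, none_end = resolve_sim_range_nb(sim_shape, allow_none=True)
    # Both are None, since every column would simulate the full range anyway.
    return start, end, none_start, none_end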
@register_jitted(cache=True)
def prepare_grouped_sim_start_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare grouped simulation start."""
if sim_start is None:
return np.full(len(group_lens), 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if len(sim_start_) == len(group_lens):
if not check_bounds:
return sim_start_
return prepare_sim_start_nb(
(target_shape[0], len(group_lens)),
sim_start=sim_start_,
check_bounds=check_bounds,
)
new_sim_start = np.empty(len(group_lens), dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
min_sim_start = target_shape[0]
for col in range(from_col, to_col):
_sim_start = flex_select_1d_pc_nb(sim_start_, col)
if _sim_start < 0:
_sim_start = target_shape[0] + _sim_start
elif _sim_start > target_shape[0]:
_sim_start = target_shape[0]
if _sim_start < min_sim_start:
min_sim_start = _sim_start
new_sim_start[group] = min_sim_start
return new_sim_start
@register_jitted(cache=True)
def prepare_grouped_sim_end_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare grouped simulation end."""
if sim_end is None:
return np.full(len(group_lens), target_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if len(sim_end_) == len(group_lens):
if not check_bounds:
return sim_end_
return prepare_sim_end_nb(
(target_shape[0], len(group_lens)),
sim_end=sim_end_,
check_bounds=check_bounds,
)
new_sim_end = np.empty(len(group_lens), dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
max_sim_end = 0
for col in range(from_col, to_col):
_sim_end = flex_select_1d_pc_nb(sim_end_, col)
if _sim_end < 0:
_sim_end = target_shape[0] + _sim_end
elif _sim_end > target_shape[0]:
_sim_end = target_shape[0]
if _sim_end > max_sim_end:
max_sim_end = _sim_end
new_sim_end[group] = max_sim_end
return new_sim_end
@register_jitted(cache=True)
def prepare_grouped_sim_range_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Prepare grouped simulation start and end."""
new_sim_start = prepare_grouped_sim_start_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start,
check_bounds=check_bounds,
)
new_sim_end = prepare_grouped_sim_end_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_end=sim_end,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
@register_jitted(cache=True)
def prepare_ungrouped_sim_start_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare ungrouped simulation start."""
if sim_start is None:
return np.full(target_shape[1], 0, dtype=int_)
sim_start_ = to_1d_array_nb(np.asarray(sim_start).astype(int_))
if len(sim_start_) == target_shape[1]:
if not check_bounds:
return sim_start_
return prepare_sim_start_nb(
target_shape,
sim_start=sim_start_,
check_bounds=check_bounds,
)
new_sim_start = np.empty(target_shape[1], dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
_sim_start = flex_select_1d_pc_nb(sim_start_, group)
if _sim_start < 0:
_sim_start = target_shape[0] + _sim_start
elif _sim_start > target_shape[0]:
_sim_start = target_shape[0]
for col in range(from_col, to_col):
new_sim_start[col] = _sim_start
return new_sim_start
@register_jitted(cache=True)
def prepare_ungrouped_sim_end_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Array1d:
"""Prepare ungrouped simulation end."""
if sim_end is None:
return np.full(target_shape[1], target_shape[0], dtype=int_)
sim_end_ = to_1d_array_nb(np.asarray(sim_end).astype(int_))
if len(sim_end_) == target_shape[1]:
if not check_bounds:
return sim_end_
return prepare_sim_end_nb(
target_shape,
sim_end=sim_end_,
check_bounds=check_bounds,
)
new_sim_end = np.empty(target_shape[1], dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
_sim_end = flex_select_1d_pc_nb(sim_end_, group)
if _sim_end < 0:
_sim_end = target_shape[0] + _sim_end
elif _sim_end > target_shape[0]:
_sim_end = target_shape[0]
for col in range(from_col, to_col):
new_sim_end[col] = _sim_end
return new_sim_end
@register_jitted(cache=True)
def prepare_ungrouped_sim_range_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
check_bounds: bool = True,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Prepare ungrouped simulation start and end."""
new_sim_start = prepare_ungrouped_sim_start_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start,
check_bounds=check_bounds,
)
new_sim_end = prepare_ungrouped_sim_end_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_end=sim_end,
check_bounds=check_bounds,
)
return new_sim_start, new_sim_end
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for splitting."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.generic.splitting.base import *
from vectorbtpro.generic.splitting.decorators import *
from vectorbtpro.generic.splitting.nb import *
from vectorbtpro.generic.splitting.purged import *
from vectorbtpro.generic.splitting.sklearn_ import *
__import_if_installed__ = dict()
__import_if_installed__["sklearn_"] = "sklearn"
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for splitting."""
import inspect
import math
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.accessors import BaseIDXAccessor
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.indexes import combine_indexes, stack_indexes
from vectorbtpro.base.indexing import hslice, PandasIndexer, get_index_ranges
from vectorbtpro.base.merging import row_stack_merge, column_stack_merge, is_merge_func_from_config
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import to_dict
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.splitting import nb
from vectorbtpro.generic.splitting.purged import BasePurgedCV, PurgedWalkForwardCV, PurgedKFoldCV
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.annotations import Annotatable, has_annotatables
from vectorbtpro.utils.array_ import is_range
from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import hybrid_method
from vectorbtpro.utils.eval_ import Evaluable
from vectorbtpro.utils.execution import Task, NoResult, NoResultsException, filter_out_no_results, execute
from vectorbtpro.utils.merging import parse_merge_func, MergeFunc
from vectorbtpro.utils.parsing import (
get_func_arg_names,
annotate_args,
flatten_ann_args,
unflatten_ann_args,
ann_args_to_args,
)
from vectorbtpro.utils.selection import PosSel, LabelSel
from vectorbtpro.utils.template import CustomTemplate, Rep, RepFunc, substitute_templates
from vectorbtpro.utils.warnings_ import warn
if tp.TYPE_CHECKING:
from sklearn.model_selection import BaseCrossValidator as BaseCrossValidatorT
else:
BaseCrossValidatorT = "BaseCrossValidator"
__all__ = [
"FixRange",
"RelRange",
"Takeable",
"Splitter",
]
__pdoc__ = {}
SplitterT = tp.TypeVar("SplitterT", bound="Splitter")
@define
class FixRange(DefineMixin):
"""Class that represents a fixed range."""
range_: tp.FixRangeLike = define.field()
"""Range."""
@define
class RelRange(DefineMixin):
"""Class that represents a relative range."""
offset: tp.Union[int, float, tp.TimedeltaLike] = define.field(default=0)
"""Offset.
Floating values between 0 and 1 are considered relative.
Can be negative."""
offset_anchor: str = define.field(default="prev_end")
"""Offset anchor.
Supported are
* 'start': Start of the range
* 'end': End of the range
* 'prev_start': Start of the previous range
* 'prev_end': End of the previous range
"""
offset_space: str = define.field(default="free")
"""Offset space.
Supported are
* 'all': All space
* 'free': Remaining space after the offset anchor
* 'prev': Length of the previous range
Applied only when `RelRange.offset` is a relative number."""
length: tp.Union[int, float, tp.TimedeltaLike] = define.field(default=1.0)
"""Length.
Floating values between 0 and 1 are considered relative.
Can be negative."""
length_space: str = define.field(default="free")
"""Length space.
Supported are
* 'all': All space
* 'free': Remaining space after the offset
* 'free_or_prev': Remaining space after the offset or the start/end of the previous range,
depending on what comes first in the direction of `RelRange.length`
Applied only when `RelRange.length` is a relative number."""
out_of_bounds: str = define.field(default="warn")
"""Check if start and stop are within bounds.
Supported are
* 'keep': Keep out-of-bounds values
* 'ignore': Ignore if out-of-bounds
* 'warn': Emit a warning if out-of-bounds
* 'raise": Raise an error if out-of-bounds
"""
is_gap: bool = define.field(default=False)
"""Whether the range acts as a gap."""
def __attrs_post_init__(self):
object.__setattr__(self, "offset_anchor", self.offset_anchor.lower())
if self.offset_anchor not in ("start", "end", "prev_start", "prev_end", "next_start", "next_end"):
raise ValueError(f"Invalid offset_anchor: '{self.offset_anchor}'")
object.__setattr__(self, "offset_space", self.offset_space.lower())
if self.offset_space not in ("all", "free", "prev"):
raise ValueError(f"Invalid offset_space: '{self.offset_space}'")
object.__setattr__(self, "length_space", self.length_space.lower())
if self.length_space not in ("all", "free", "free_or_prev"):
raise ValueError(f"Invalid length_space: '{self.length_space}'")
object.__setattr__(self, "out_of_bounds", self.out_of_bounds.lower())
if self.out_of_bounds not in ("keep", "ignore", "warn", "raise"):
raise ValueError(f"Invalid out_of_bounds: '{self.out_of_bounds}'")
def to_slice(
self,
total_len: int,
prev_start: int = 0,
prev_end: int = 0,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> slice:
"""Convert the relative range into a slice."""
if index is not None:
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
offset_anchor = self.offset_anchor
offset = self.offset
length = self.length
if not checks.is_number(offset) or not checks.is_number(length):
if not isinstance(index, pd.DatetimeIndex):
raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}")
if offset_anchor == "start":
if checks.is_number(offset):
offset_anchor = 0
else:
offset_anchor = index[0]
elif offset_anchor == "end":
if checks.is_number(offset):
offset_anchor = total_len
else:
if freq is None:
raise ValueError("Must provide frequency")
offset_anchor = index[-1] + freq
elif offset_anchor == "prev_start":
if checks.is_number(offset):
offset_anchor = prev_start
else:
offset_anchor = index[prev_start]
else:
if checks.is_number(offset):
offset_anchor = prev_end
else:
if prev_end < total_len:
offset_anchor = index[prev_end]
else:
if freq is None:
raise ValueError("Must provide frequency")
offset_anchor = index[-1] + freq
if checks.is_float(offset) and 0 <= abs(offset) <= 1:
if self.offset_space == "all":
offset = offset_anchor + int(offset * total_len)
elif self.offset_space == "free":
if offset < 0:
offset = int((1 + offset) * offset_anchor)
else:
offset = offset_anchor + int(offset * (total_len - offset_anchor))
else:
offset = offset_anchor + int(offset * (prev_end - prev_start))
else:
if checks.is_float(offset):
if not offset.is_integer():
raise TypeError(f"Floating number for offset ({offset}) must be between 0 and 1")
offset = offset_anchor + int(offset)
elif not checks.is_int(offset):
offset = offset_anchor + dt.to_freq(offset)
if index[0] <= offset <= index[-1]:
offset = index.get_indexer([offset], method="ffill")[0]
elif offset < index[0]:
if freq is None:
raise ValueError("Must provide frequency")
offset = -int((index[0] - offset) / freq)
else:
if freq is None:
raise ValueError("Must provide frequency")
offset = total_len - 1 + int((offset - index[-1]) / freq)
else:
offset = offset_anchor + offset
if checks.is_float(length) and 0 <= abs(length) <= 1:
if self.length_space == "all":
length = int(length * total_len)
elif self.length_space == "free":
if length < 0:
length = int(length * offset)
else:
length = int(length * (total_len - offset))
else:
if length < 0:
if offset > prev_end:
length = int(length * (offset - prev_end))
else:
length = int(length * offset)
else:
if offset < prev_start:
length = int(length * (prev_start - offset))
else:
length = int(length * (total_len - offset))
else:
if checks.is_float(length):
if not length.is_integer():
raise TypeError(f"Floating number for length ({length}) must be between 0 and 1")
length = int(length)
elif not checks.is_int(length):
length = dt.to_freq(length)
start = offset
if checks.is_int(length):
stop = start + length
else:
if 0 <= start < total_len:
stop = index[start] + length
elif start < 0:
if freq is None:
raise ValueError("Must provide frequency")
stop = index[0] + start * freq + length
else:
if freq is None:
raise ValueError("Must provide frequency")
stop = index[-1] + (start - total_len + 1) * freq + length
if stop <= index[-1]:
stop = index.get_indexer([stop], method="bfill")[0]
else:
if freq is None:
raise ValueError("Must provide frequency")
stop = total_len - 1 + int((stop - index[-1]) / freq)
if checks.is_int(length):
if length < 0:
start, stop = stop, start
else:
if length < pd.Timedelta(0):
start, stop = stop, start
if start < 0:
if self.out_of_bounds == "ignore":
start = 0
elif self.out_of_bounds == "warn":
warn(f"Range start ({start}) is out of bounds")
start = 0
elif self.out_of_bounds == "raise":
raise ValueError(f"Range start ({start}) is out of bounds")
if stop > total_len:
if self.out_of_bounds == "ignore":
stop = total_len
elif self.out_of_bounds == "warn":
warn(f"Range stop ({stop}) is out of bounds")
stop = total_len
elif self.out_of_bounds == "raise":
raise ValueError(f"Range stop ({stop}) is out of bounds")
if stop - start <= 0:
raise ValueError("Range length is negative or zero")
return slice(start, stop)
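# A minimal sketch using the `RelRange` class defined above: a relative offset of 0.5 anchored
# at the start of a 100-bar index, combined with a relative length of 0.25 of the remaining
# ("free") space, resolves to rows 50 through 62 (exclusive).
def _rel_range_sketch():
    rel = RelRange(offset=0.5, offset_anchor="start", length=0.25)
    return rel.to_slice(total_len=100)  # slice(50, 62)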
@define
class Takeable(Evaluable, Annotatable, DefineMixin):
"""Class that represents an object from which a range can be taken."""
obj: tp.Any = define.required_field()
"""Takeable object."""
remap_to_obj: bool = define.optional_field()
"""Whether to remap `Splitter.index` to the index of `Takeable.obj`.
Otherwise, will assume that the object has the same index."""
index: tp.Optional[tp.IndexLike] = define.optional_field()
"""Index of the object.
If not present, will be accessed using `Splitter.get_obj_index`."""
freq: tp.Optional[tp.FrequencyLike] = define.optional_field()
"""Frequency of `Takeable.index`."""
point_wise: bool = define.optional_field()
"""Whether to select one range point at a time and return a tuple."""
eval_id: tp.Optional[tp.MaybeSequence[tp.Hashable]] = define.field(default=None)
"""One or more identifiers at which to evaluate this instance."""
class ZeroLengthError(ValueError):
"""Thrown whenever a range has a length of zero."""
pass
class Splitter(Analyzable):
"""Base class for splitting."""
@classmethod
def from_splits(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
splits: tp.Splits,
squeeze: bool = False,
fix_ranges: bool = True,
wrap_with_fixrange: bool = False,
split_range_kwargs: tp.KwargsLike = None,
split_check_template: tp.Optional[tp.CustomTemplate] = None,
template_context: tp.KwargsLike = None,
split_labels: tp.Optional[tp.IndexLike] = None,
set_labels: tp.Optional[tp.IndexLike] = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from an iterable of splits.
Argument `splits` supports both absolute and relative ranges.
To transform relative ranges into the absolute format, enable `fix_ranges`.
Arguments `split_range_kwargs` are then passed to `Splitter.split_range`.
Enable `wrap_with_fixrange` to wrap any fixed range with `FixRange`. If the range
is an array, it will be wrapped regardless of this argument to avoid building a 3d array.
Pass a template via `split_check_template` to discard splits that do not fulfill certain criteria.
The current split will be available as `split`. The template should return a boolean (`False` to discard the split).
Labels for splits and sets can be provided via `split_labels` and `set_labels` respectively.
Both arguments can be provided as templates. The split array will be available as `splits`."""
index = dt.prepare_dt_index(index)
if split_range_kwargs is None:
split_range_kwargs = {}
new_splits = []
removed_indices = []
for i, split in enumerate(splits):
already_fixed = False
if checks.is_number(split) or checks.is_td_like(split):
split = cls.split_range(
slice(None),
split,
template_context=template_context,
index=index,
wrap_with_fixrange=False,
**split_range_kwargs,
)
already_fixed = True
new_split = split
ndim = 2
elif cls.is_range_relative(split):
new_split = [split]
ndim = 1
elif not checks.is_sequence(split):
new_split = [split]
ndim = 1
elif isinstance(split, np.ndarray):
new_split = [split]
ndim = 1
else:
new_split = split
ndim = 2
if fix_ranges and not already_fixed:
new_split = cls.split_range(
slice(None),
new_split,
template_context=template_context,
index=index,
wrap_with_fixrange=False,
**split_range_kwargs,
)
_new_split = []
for range_ in new_split:
if checks.is_number(range_) or checks.is_td_like(range_):
range_ = RelRange(length=range_)
if not isinstance(range_, (FixRange, RelRange)):
if wrap_with_fixrange or checks.is_sequence(range_):
_new_split.append(FixRange(range_))
else:
_new_split.append(range_)
else:
_new_split.append(range_)
if split_check_template is not None:
_template_context = merge_dicts(dict(index=index, i=i, split=_new_split), template_context)
split_ok = substitute_templates(split_check_template, _template_context, eval_id="split_check_template")
if not split_ok:
removed_indices.append(i)
continue
new_splits.append(_new_split)
if len(new_splits) == 0:
raise ValueError("Must provide at least one range")
new_splits_arr = np.asarray(new_splits, dtype=object)
if squeeze and new_splits_arr.shape[1] == 1:
ndim = 1
if split_labels is None:
split_labels = pd.RangeIndex(stop=new_splits_arr.shape[0], name="split")
else:
if isinstance(split_labels, CustomTemplate):
_template_context = merge_dicts(dict(index=index, splits_arr=new_splits_arr), template_context)
split_labels = substitute_templates(split_labels, _template_context, eval_id="split_labels")
if not isinstance(split_labels, pd.Index):
split_labels = pd.Index(split_labels, name="split")
else:
if not isinstance(split_labels, pd.Index):
split_labels = pd.Index(split_labels, name="split")
if len(removed_indices) > 0:
split_labels = split_labels.delete(removed_indices)
if set_labels is None:
set_labels = pd.Index(["set_%d" % i for i in range(new_splits_arr.shape[1])], name="set")
else:
if isinstance(set_labels, CustomTemplate):
_template_context = merge_dicts(dict(index=index, splits_arr=new_splits_arr), template_context)
set_labels = substitute_templates(set_labels, _template_context, eval_id="set_labels")
if not isinstance(set_labels, pd.Index):
set_labels = pd.Index(set_labels, name="set")
if wrapper_kwargs is None:
wrapper_kwargs = {}
wrapper = ArrayWrapper(index=split_labels, columns=set_labels, ndim=ndim, **wrapper_kwargs)
return cls(wrapper, index, new_splits_arr, **kwargs)
@classmethod
def from_single(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
split: tp.Optional[tp.SplitLike],
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a single split."""
if split_range_kwargs is None:
split_range_kwargs = {}
new_split = cls.split_range(
slice(None),
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
splits = [new_split]
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
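# A minimal usage sketch (assuming a daily index; 0.5 splits the full period in half):
#
#     >>> from vectorbtpro import *
#     >>> index = pd.date_range("2020", "2021", freq="D")
#     >>> splitter = vbt.Splitter.from_single(index, 0.5, set_labels=["train", "test"])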
@classmethod
def from_rolling(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
length: tp.Union[int, float, tp.TimedeltaLike],
offset: tp.Union[int, float, tp.TimedeltaLike] = 0,
offset_anchor: str = "prev_end",
offset_anchor_set: tp.Optional[int] = 0,
offset_space: str = "prev",
backwards: tp.Union[bool, str] = False,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
range_bounds_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a rolling range of a fixed length.
Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance.
Args:
index (index_like): Index.
length (int, float, or timedelta_like): See `RelRange.length`.
offset (int, float, or timedelta_like): See `RelRange.offset`.
offset_anchor (str): See `RelRange.offset_anchor`.
offset_anchor_set (int): Offset anchor set.
Selects the set from the previous range to be used as an offset anchor.
If None, the whole previous split is considered as a single range.
By default, it's the first set.
offset_space (str): See `RelRange.offset_space`.
backwards (bool or str): Whether to roll backwards.
If 'sorted', will roll backwards and sort the resulting splits by the start index.
split (any): Ranges to split the range into.
If None, will produce the entire range as a single range.
Otherwise, will use `Splitter.split_range` to split the range into multiple ranges.
split_range_kwargs (dict): Keyword arguments passed to `Splitter.split_range`.
range_bounds_kwargs (dict): Keyword arguments passed to `Splitter.get_range_bounds`.
template_context (dict): Context used to substitute templates in ranges.
freq (any): Index frequency in case it cannot be parsed from `index`.
If None, will be parsed using `vectorbtpro.base.accessors.BaseIDXAccessor.get_freq`.
**kwargs: Keyword arguments passed to the constructor of `Splitter`.
Usage:
* Divide a range into a set of non-overlapping ranges:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_rolling(index, 30)
>>> splitter.plot().show()
```
* Divide a range into ranges, each split into 1/2:
```pycon
>>> splitter = vbt.Splitter.from_rolling(
... index,
... 60,
... split=1/2,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
* Make the ranges above non-overlapping by using the right bound of the last
set as an offset anchor:
```pycon
>>> splitter = vbt.Splitter.from_rolling(
... index,
... 60,
... offset_anchor_set=-1,
... split=1/2,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if isinstance(backwards, str):
if backwards.lower() == "sorted":
sort_backwards = True
else:
raise ValueError(f"Invalid backwards: '{backwards}'")
backwards = True
else:
sort_backwards = False
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if range_bounds_kwargs is None:
range_bounds_kwargs = {}
splits = []
bounds = []
while True:
if len(splits) == 0:
new_split = RelRange(
length=-length if backwards else length,
offset_anchor="end" if backwards else "start",
out_of_bounds="keep",
).to_slice(total_len=len(index), index=index, freq=freq)
else:
if offset_anchor_set is None:
prev_start, prev_end = bounds[-1][0][0], bounds[-1][-1][1]
else:
prev_start, prev_end = bounds[-1][offset_anchor_set]
new_split = RelRange(
offset=offset,
offset_anchor=offset_anchor,
offset_space=offset_space,
length=-length if backwards else length,
length_space="all",
out_of_bounds="keep",
).to_slice(total_len=len(index), prev_start=prev_start, prev_end=prev_end, index=index, freq=freq)
if backwards:
if new_split.stop >= bounds[-1][-1][1]:
raise ValueError("Infinite loop detected. Provide a positive offset.")
else:
if new_split.start <= bounds[-1][0][0]:
raise ValueError("Infinite loop detected. Provide a positive offset.")
if backwards:
if new_split.start < 0:
break
if new_split.stop > len(index):
raise ValueError("Range stop cannot exceed index length")
else:
if new_split.start < 0:
raise ValueError("Range start cannot be negative")
if new_split.stop > len(index):
break
if split is not None:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
bounds.append(
tuple(
map(
lambda x: cls.get_range_bounds(
x,
template_context=template_context,
index=index,
**range_bounds_kwargs,
),
new_split,
)
)
)
else:
bounds.append(((new_split.start, new_split.stop),))
splits.append(new_split)
return cls.from_splits(
index,
splits[::-1] if sort_backwards else splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_n_rolling(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
n: int,
length: tp.Union[None, str, int, float, tp.TimedeltaLike] = None,
optimize_anchor_set: int = 1,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a number of rolling ranges of the same length.
If `length` is None, splits the index evenly into `n` non-overlapping ranges
using `Splitter.from_rolling`. Otherwise, picks `n` evenly-spaced, potentially overlapping
ranges of a fixed length. For other arguments, see `Splitter.from_rolling`.
If `length` is "optimize", searches for a length to cover the most of the index.
Use `optimize_anchor_set` to provide the index of a set that should become non-overlapping.
Usage:
* Roll 10 ranges of 100 elements each, and split each into 3/4:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_n_rolling(
... index,
... 10,
... length=100,
... split=3/4,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if length is None:
return cls.from_rolling(
index,
length=len(index) // n,
offset=0,
offset_anchor="prev_end",
offset_anchor_set=None,
split=split,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
if isinstance(length, str) and length.lower() == "optimize":
from scipy.optimize import minimize_scalar
if split is not None and not checks.is_float(split):
raise TypeError("Split must be a float when length='optimize'")
checks.assert_in(optimize_anchor_set, (0, 1))
if split is None:
ratio = 1.0
else:
ratio = split
def empty_len_objective(length):
length = math.ceil(length)
first_len = int(ratio * length)
second_len = length - first_len
if split is None or optimize_anchor_set == 0:
empty_len = len(index) - (n * first_len + second_len)
else:
empty_len = len(index) - (n * second_len + first_len)
if empty_len >= 0:
return empty_len
return len(index)
length = math.ceil(minimize_scalar(empty_len_objective).x)
if split is None or optimize_anchor_set == 0:
offset = int(ratio * length)
else:
offset = length - int(ratio * length)
return cls.from_rolling(
index,
length=length,
offset=offset,
offset_anchor="prev_start",
offset_anchor_set=None,
split=split,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
if checks.is_float(length):
if 0 <= abs(length) <= 1:
length = len(index) * length
elif not length.is_integer():
raise TypeError("Floating number for length must be between 0 and 1")
length = int(length)
if checks.is_int(length):
if length < 1 or length > len(index):
raise TypeError(f"Length must be within [{1}, {len(index)}]")
offsets = np.arange(len(index))
offsets = offsets[offsets + length <= len(index)]
else:
length = dt.to_freq(length)
if freq is None:
raise ValueError("Must provide freq")
if length < freq or length > index[-1] + freq - index[0]:
raise TypeError(f"Length must be within [{freq}, {index[-1] + freq - index[0]}]")
offsets = index[index + length <= index[-1] + freq] - index[0]
if n > len(offsets):
n = len(offsets)
rows = np.round(np.linspace(0, len(offsets) - 1, n)).astype(int)
offsets = offsets[rows]
splits = []
for offset in offsets:
new_split = RelRange(
offset=offset,
length=length,
).to_slice(len(index), index=index, freq=freq)
if split is not None:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
splits.append(new_split)
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_expanding(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
min_length: tp.Union[int, float, tp.TimedeltaLike],
offset: tp.Union[int, float, tp.TimedeltaLike],
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
range_bounds_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from an expanding range.
Argument `min_length` is the minimum length of the expanding range. Provide it as
a float between 0 and 1 to make it relative to the length of the index. Argument `offset` is
an offset after the right bound of the previous range from which the next range should start.
It can also be a float relative to the index length. For other arguments, see `Splitter.from_rolling`.
Usage:
* Roll an expanding range with a minimum length of 10 and an offset of 10, and split it into 3/4:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_expanding(
... index,
... 10,
... 10,
... split=3/4,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if range_bounds_kwargs is None:
range_bounds_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
splits = []
bounds = []
while True:
if len(splits) == 0:
new_split = RelRange(
length=min_length,
out_of_bounds="keep",
).to_slice(total_len=len(index), index=index, freq=freq)
else:
prev_end = bounds[-1][-1][-1]
new_split = RelRange(
offset=offset,
offset_anchor="prev_end",
offset_space="all",
length=-1.0,
out_of_bounds="keep",
).to_slice(total_len=len(index), prev_end=prev_end, index=index, freq=freq)
if new_split.stop <= prev_end:
raise ValueError("Infinite loop detected. Provide a positive offset.")
if new_split.start < 0:
raise ValueError("Range start cannot be negative")
if new_split.stop > len(index):
break
if split is not None:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
bounds.append(
tuple(
map(
lambda x: cls.get_range_bounds(
x,
template_context=template_context,
index=index,
**range_bounds_kwargs,
),
new_split,
)
)
)
else:
bounds.append(((new_split.start, new_split.stop),))
splits.append(new_split)
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_n_expanding(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
n: int,
min_length: tp.Union[None, int, float, tp.TimedeltaLike] = None,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a number of expanding ranges.
Picks `n` evenly-spaced, expanding ranges. Argument `min_length` defines the minimum
length for each range. For other arguments, see `Splitter.from_rolling`.
Usage:
* Roll 10 expanding ranges with a minimum length of 100, while reserving 50 elements for test:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_n_expanding(
... index,
... 10,
... min_length=100,
... split=-50,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if min_length is None:
min_length = len(index) // n
if checks.is_float(min_length):
if 0 <= abs(min_length) <= 1:
min_length = len(index) * min_length
elif not min_length.is_integer():
raise TypeError("Floating number for minimum length must be between 0 and 1")
if checks.is_int(min_length):
min_length = int(min_length)
if min_length < 1 or min_length > len(index):
raise TypeError(f"Minimum length must be within [{1}, {len(index)}]")
lengths = np.arange(1, len(index) + 1)
lengths = lengths[lengths >= min_length]
else:
min_length = dt.to_freq(min_length)
if freq is None:
raise ValueError("Must provide freq")
if min_length < freq or min_length > index[-1] + freq - index[0]:
raise TypeError(f"Minimum length must be within [{freq}, {index[-1] + freq - index[0]}]")
lengths = index[1:].append(index[[-1]] + freq) - index[0]
lengths = lengths[lengths >= min_length]
if n > len(lengths):
n = len(lengths)
rows = np.round(np.linspace(0, len(lengths) - 1, n)).astype(int)
lengths = lengths[rows]
splits = []
for length in lengths:
new_split = RelRange(length=length).to_slice(len(index), index=index, freq=freq)
if split is not None:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
splits.append(new_split)
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_ranges(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from ranges.
Uses `vectorbtpro.base.indexing.get_index_ranges` to generate start and end indices.
Passes only related keyword arguments found in `kwargs`.
Other keyword arguments will be passed to `Splitter.from_splits`. For details on
`split` and `split_range_kwargs`, see `Splitter.from_rolling`.
Usage:
* Translate each quarter into a range:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_ranges(index, every="QS")
>>> splitter.plot().show()
```
* In addition to the above, reserve the last month for testing purposes:
```pycon
>>> splitter = vbt.Splitter.from_ranges(
... index,
... every="QS",
... split=(1.0, lambda index: index.month == index.month[-1]),
... split_range_kwargs=dict(backwards=True)
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
if split_range_kwargs is None:
split_range_kwargs = {}
func_arg_names = get_func_arg_names(get_index_ranges)
ranges_kwargs = dict()
for k in list(kwargs.keys()):
if k in func_arg_names:
ranges_kwargs[k] = kwargs.pop(k)
start_idxs, stop_idxs = get_index_ranges(index, skip_not_found=True, **ranges_kwargs)
splits = []
for start, stop in zip(start_idxs, stop_idxs):
new_split = slice(start, stop)
if split is not None:
try:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
except ZeroLengthError:
continue
splits.append(new_split)
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_grouper(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
by: tp.AnyGroupByLike,
groupby_kwargs: tp.KwargsLike = None,
grouper_kwargs: tp.KwargsLike = None,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
split_labels: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a grouper.
See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance.
Usage:
* Map each month into a range:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> def is_month_end(index, split):
... last_range = split[-1]
... return index[last_range][-1].is_month_end
>>> splitter = vbt.Splitter.from_grouper(
... index,
... "M",
... split_check_template=vbt.RepFunc(is_month_end)
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if grouper_kwargs is None:
grouper_kwargs = {}
if isinstance(by, CustomTemplate):
_template_context = merge_dicts(dict(index=index), template_context)
by = substitute_templates(by, _template_context, eval_id="by")
grouper = BaseIDXAccessor(index).get_grouper(by, groupby_kwargs=groupby_kwargs, **grouper_kwargs)
splits = []
indices = []
for i, new_split in enumerate(grouper.iter_group_idxs()):
if split is not None:
try:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
except ZeroLengthError:
continue
else:
new_split = [new_split]
splits.append(new_split)
indices.append(i)
if split_labels is None:
split_labels = grouper.get_index()[indices]
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
split_labels=split_labels,
**kwargs,
)
@classmethod
def from_n_random(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
n: int,
min_length: tp.Union[int, float, tp.TimedeltaLike],
max_length: tp.Union[None, int, float, tp.TimedeltaLike] = None,
min_start: tp.Union[None, int, float, tp.DatetimeLike] = None,
max_end: tp.Union[None, int, float, tp.DatetimeLike] = None,
length_choice_func: tp.Optional[tp.Callable] = None,
start_choice_func: tp.Optional[tp.Callable] = None,
length_p_func: tp.Optional[tp.Callable] = None,
start_p_func: tp.Optional[tp.Callable] = None,
seed: tp.Optional[int] = None,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a number of random ranges.
Randomly picks the length of a range between `min_length` and `max_length` (both inclusive) using
`length_choice_func`, which receives an array of possible values and selects one. It defaults to
`numpy.random.Generator.choice`. The optional function `length_p_func` takes the same arguments as
`length_choice_func` and must return either None or an array of probabilities.
Randomly picks the start position of a range between `min_start` (inclusive) and `max_end`
(exclusive) minus the chosen length, using `start_choice_func`, which receives an array of possible
values and selects one. It defaults to `numpy.random.Generator.choice`. The optional function
`start_p_func` takes the same arguments as `start_choice_func` and must return either None or an array of probabilities.
!!! note
Each function must take two arguments: the iteration index and the array with possible values.
For other arguments, see `Splitter.from_rolling`.
Usage:
* Generate 20 random ranges with a length within [40, 100], and split each into 3/4:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_n_random(
... index,
... 20,
... min_length=40,
... max_length=100,
... split=3/4,
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if min_start is None:
min_start = 0
if min_start is not None:
if checks.is_float(min_start):
if 0 <= abs(min_start) <= 1:
min_start = len(index) * min_start
elif not min_start.is_integer():
raise TypeError("Floating number for minimum start must be between 0 and 1")
if checks.is_float(min_start):
min_start = int(min_start)
if checks.is_int(min_start):
if min_start < 0 or min_start > len(index) - 1:
raise TypeError(f"Minimum start must be within [{0}, {len(index) - 1}]")
else:
if not isinstance(index, pd.DatetimeIndex):
raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}")
min_start = dt.try_align_dt_to_index(min_start, index)
if not isinstance(min_start, pd.Timestamp):
raise ValueError(f"Minimum start ({min_start}) could not be parsed")
if min_start < index[0] or min_start > index[-1]:
raise TypeError(f"Minimum start must be within [{index[0]}, {index[-1]}]")
min_start = index.get_indexer([min_start], method="bfill")[0]
if max_end is None:
max_end = len(index)
if checks.is_float(max_end):
if 0 <= abs(max_end) <= 1:
max_end = len(index) * max_end
elif not max_end.is_integer():
raise TypeError("Floating number for maximum end must be between 0 and 1")
if checks.is_float(max_end):
max_end = int(max_end)
if checks.is_int(max_end):
if max_end < 1 or max_end > len(index):
raise TypeError(f"Maximum end must be within [{1}, {len(index)}]")
else:
if not isinstance(index, pd.DatetimeIndex):
raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}")
max_end = dt.try_align_dt_to_index(max_end, index)
if not isinstance(max_end, pd.Timestamp):
raise ValueError(f"Maximum end ({max_end}) could not be parsed")
if freq is None:
raise ValueError("Must provide freq")
if max_end < index[0] + freq or max_end > index[-1] + freq:
raise TypeError(f"Maximum end must be within [{index[0] + freq}, {index[-1] + freq}]")
if max_end > index[-1]:
max_end = len(index)
else:
max_end = index.get_indexer([max_end], method="bfill")[0]
space_len = max_end - min_start
if not checks.is_number(min_length):
index_min_start = index[min_start]
if max_end < len(index):
index_max_end = index[max_end]
else:
if freq is None:
raise ValueError("Must provide freq")
index_max_end = index[-1] + freq
index_space_len = index_max_end - index_min_start
else:
index_min_start = None
index_max_end = None
index_space_len = None
if checks.is_float(min_length):
if 0 <= abs(min_length) <= 1:
min_length = space_len * min_length
elif not min_length.is_integer():
raise TypeError("Floating number for minimum length must be between 0 and 1")
min_length = int(min_length)
if checks.is_int(min_length):
if min_length < 1 or min_length > space_len:
raise TypeError(f"Minimum length must be within [{1}, {space_len}]")
else:
min_length = dt.to_freq(min_length)
if freq is None:
raise ValueError("Must provide freq")
if min_length < freq or min_length > index_space_len:
raise TypeError(f"Minimum length must be within [{freq}, {index_space_len}]")
if max_length is not None:
if checks.is_float(max_length):
if 0 <= abs(max_length) <= 1:
max_length = space_len * max_length
elif not max_length.is_integer():
raise TypeError("Floating number for maximum length must be between 0 and 1")
max_length = int(max_length)
if checks.is_int(max_length):
if max_length < min_length or max_length > space_len:
raise TypeError(f"Maximum length must be within [{min_length}, {space_len}]")
else:
max_length = dt.to_freq(max_length)
if freq is None:
raise ValueError("Must provide freq")
if max_length < min_length or max_length > index_space_len:
raise TypeError(f"Maximum length must be within [{min_length}, {index_space_len}]")
else:
max_length = min_length
rng = np.random.default_rng(seed=seed)
if length_p_func is None:
length_p_func = lambda i, x: None
if start_p_func is None:
start_p_func = lambda i, x: None
if length_choice_func is None:
length_choice_func = lambda i, x: rng.choice(x, p=length_p_func(i, x))
else:
if seed is not None:
np.random.seed(seed)
if start_choice_func is None:
start_choice_func = lambda i, x: rng.choice(x, p=start_p_func(i, x))
else:
if seed is not None:
np.random.seed(seed)
if checks.is_int(min_length):
length_space = np.arange(min_length, max_length + 1)
else:
if freq is None:
raise ValueError("Must provide freq")
length_space = np.arange(min_length // freq, max_length // freq + 1) * freq
index_space = np.arange(len(index))
splits = []
for i in range(n):
length = length_choice_func(i, length_space)
if checks.is_int(length):
start = start_choice_func(i, index_space[min_start : max_end - length + 1])
else:
from_dt = index_min_start.to_datetime64()
to_dt = index_max_end.to_datetime64() - length
start = start_choice_func(i, index_space[(index.values >= from_dt) & (index.values <= to_dt)])
new_split = RelRange(offset=start, length=length).to_slice(len(index), index=index, freq=freq)
if split is not None:
new_split = cls.split_range(
new_split,
split,
template_context=template_context,
index=index,
**split_range_kwargs,
)
splits.append(new_split)
return cls.from_splits(
index,
splits,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def from_sklearn(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
skl_splitter: BaseCrossValidatorT,
groups: tp.Optional[tp.ArrayLike] = None,
split_labels: tp.Optional[tp.IndexLike] = None,
set_labels: tp.Optional[tp.IndexLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a scikit-learn's splitter.
The splitter must be an instance of `sklearn.model_selection.BaseCrossValidator`.
Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance."""
from sklearn.model_selection import BaseCrossValidator
index = dt.prepare_dt_index(index)
checks.assert_instance_of(skl_splitter, BaseCrossValidator)
if set_labels is None:
set_labels = ["train", "test"]
indices_generator = skl_splitter.split(np.arange(len(index))[:, None], groups=groups)
return cls.from_splits(
index,
list(indices_generator),
split_labels=split_labels,
set_labels=set_labels,
**kwargs,
)
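# A minimal sketch using scikit-learn's TimeSeriesSplit (assuming scikit-learn is installed):
#
#     >>> from vectorbtpro import *
#     >>> from sklearn.model_selection import TimeSeriesSplit
#     >>> index = pd.date_range("2020", "2021", freq="D")
#     >>> splitter = vbt.Splitter.from_sklearn(index, TimeSeriesSplit(n_splits=5))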
@classmethod
def from_purged(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
purged_splitter: BasePurgedCV,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
split_labels: tp.Optional[tp.IndexLike] = None,
set_labels: tp.Optional[tp.IndexLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a purged splitter.
The splitter must be an instance of `vectorbtpro.generic.splitting.purged.BasePurgedCV`.
Uses `Splitter.from_splits` to prepare the splits array and labels, and to build the instance."""
index = dt.prepare_dt_index(index)
checks.assert_instance_of(purged_splitter, BasePurgedCV)
if set_labels is None:
set_labels = ["train", "test"]
indices_generator = purged_splitter.split(
pd.Series(np.arange(len(index)), index=index),
pred_times=pred_times,
eval_times=eval_times,
)
return cls.from_splits(
index,
list(indices_generator),
split_labels=split_labels,
set_labels=set_labels,
**kwargs,
)
@classmethod
def from_purged_walkforward(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
n_folds: int = 10,
n_test_folds: int = 1,
min_train_folds: int = 2,
max_train_folds: tp.Optional[int] = None,
split_by_time: bool = False,
purge_td: tp.TimedeltaLike = 0,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from `vectorbtpro.generic.splitting.purged.PurgedWalkForwardCV`.
Keyword arguments are passed to `Splitter.from_purged`."""
index = dt.prepare_dt_index(index)
purged_splitter = PurgedWalkForwardCV(
n_folds=n_folds,
n_test_folds=n_test_folds,
min_train_folds=min_train_folds,
max_train_folds=max_train_folds,
split_by_time=split_by_time,
purge_td=purge_td,
)
return cls.from_purged(
index,
purged_splitter,
pred_times=pred_times,
eval_times=eval_times,
**kwargs,
)
@classmethod
def from_purged_kfold(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
n_folds: int = 10,
n_test_folds: int = 2,
purge_td: tp.TimedeltaLike = 0,
embargo_td: tp.TimedeltaLike = 0,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from `vectorbtpro.generic.splitting.purged.PurgedKFoldCV`.
Keyword arguments are passed to `Splitter.from_purged`."""
index = dt.prepare_dt_index(index)
purged_splitter = PurgedKFoldCV(
n_folds=n_folds,
n_test_folds=n_test_folds,
purge_td=purge_td,
embargo_td=embargo_td,
)
return cls.from_purged(
index,
purged_splitter,
pred_times=pred_times,
eval_times=eval_times,
**kwargs,
)
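# A minimal sketch of a purged k-fold split with purging and embargo periods
# (argument values are illustrative):
#
#     >>> from vectorbtpro import *
#     >>> index = pd.date_range("2020", "2021", freq="D")
#     >>> splitter = vbt.Splitter.from_purged_kfold(
#     ...     index, n_folds=10, n_test_folds=2, purge_td="3 days", embargo_td="3 days")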
@classmethod
def from_split_func(
cls: tp.Type[SplitterT],
index: tp.IndexLike,
split_func: tp.Callable,
split_args: tp.ArgsLike = None,
split_kwargs: tp.KwargsLike = None,
fix_ranges: bool = True,
split: tp.Optional[tp.SplitLike] = None,
split_range_kwargs: tp.KwargsLike = None,
range_bounds_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
**kwargs,
) -> SplitterT:
"""Create a `Splitter` instance from a custom split function.
In a while-loop, substitutes templates in `split_args` and `split_kwargs` and passes
them to `split_func`, which should return either a split (see `new_split` in `Splitter.split_range`,
also supports a single range if it's not an iterable) or None to terminate the while-loop.
If `fix_ranges` is True, the returned split is then converted into a fixed split using
`Splitter.split_range` and the bounds of its sets are measured using `Splitter.get_range_bounds`.
Each template substitution has the following information:
* `split_idx`: Current split index, starting at 0
* `splits`: Nested list of splits appended up to this point
* `bounds`: Nested list of bounds appended up to this point
* `prev_start`: Left bound of the previous split
* `prev_end`: Right bound of the previous split
* Arguments and keyword arguments passed to `Splitter.from_split_func`
Usage:
* Rolling window of 30 elements, 20 for train and 10 for test:
```pycon
>>> from vectorbtpro import *
>>> index = pd.date_range("2020", "2021", freq="D")
>>> def split_func(splits, bounds, index):
... if len(splits) == 0:
... new_split = (slice(0, 20), slice(20, 30))
... else:
... # Previous split, first set, right bound
... prev_end = bounds[-1][0][1]
... new_split = (
... slice(prev_end, prev_end + 20),
... slice(prev_end + 20, prev_end + 30)
... )
... if new_split[-1].stop > len(index):
... return None
... return new_split
>>> splitter = vbt.Splitter.from_split_func(
... index,
... split_func,
... split_args=(
... vbt.Rep("splits"),
... vbt.Rep("bounds"),
... vbt.Rep("index"),
... ),
... set_labels=["train", "test"]
... )
>>> splitter.plot().show()
```
"""
index = dt.prepare_dt_index(index)
freq = BaseIDXAccessor(index, freq=freq).get_freq(allow_numeric=False)
if split_range_kwargs is None:
split_range_kwargs = {}
if "freq" not in split_range_kwargs:
split_range_kwargs = dict(split_range_kwargs)
split_range_kwargs["freq"] = freq
if range_bounds_kwargs is None:
range_bounds_kwargs = {}
if split_args is None:
split_args = ()
if split_kwargs is None:
split_kwargs = {}
splits = []
bounds = []
split_idx = 0
n_sets = None
while True:
_template_context = merge_dicts(
dict(
split_idx=split_idx,
splits=splits,
bounds=bounds,
prev_start=bounds[-1][0][0] if len(bounds) > 0 else None,
prev_end=bounds[-1][-1][1] if len(bounds) > 0 else None,
index=index,
freq=freq,
fix_ranges=fix_ranges,
split_args=split_args,
split_kwargs=split_kwargs,
split_range_kwargs=split_range_kwargs,
range_bounds_kwargs=range_bounds_kwargs,
**kwargs,
),
template_context,
)
_split_func = substitute_templates(split_func, _template_context, eval_id="split_func")
_split_args = substitute_templates(split_args, _template_context, eval_id="split_args")
_split_kwargs = substitute_templates(split_kwargs, _template_context, eval_id="split_kwargs")
new_split = _split_func(*_split_args, **_split_kwargs)
if new_split is None:
break
if not checks.is_iterable(new_split):
new_split = (new_split,)
if fix_ranges or split is not None:
new_split = cls.split_range(
slice(None),
new_split,
template_context=_template_context,
index=index,
**split_range_kwargs,
)
if split is not None:
if len(new_split) > 1:
raise ValueError("Split function must return only one range if split is already provided")
new_split = cls.split_range(
new_split[0],
split,
template_context=_template_context,
index=index,
**split_range_kwargs,
)
if n_sets is None:
n_sets = len(new_split)
elif n_sets != len(new_split):
raise ValueError("All splits must have the same number of sets")
splits.append(new_split)
if fix_ranges:
split_bounds = tuple(
map(
lambda x: cls.get_range_bounds(
x,
template_context=_template_context,
index=index,
**range_bounds_kwargs,
),
new_split,
)
)
bounds.append(split_bounds)
split_idx += 1
return cls.from_splits(
index,
splits,
fix_ranges=fix_ranges,
split_range_kwargs=split_range_kwargs,
template_context=template_context,
**kwargs,
)
@classmethod
def guess_method(cls, **kwargs) -> tp.Optional[str]:
"""Guess the factory method based on keyword arguments.
Returns None if the method cannot be guessed."""
if len(kwargs) == 0:
return None
keys = {"index"} | set(kwargs.keys())
from_splits_arg_names = set(get_func_arg_names(cls.from_splits))
from_splits_arg_names.remove("splits")
matched_methods = []
n_args = []
for k in cls.__dict__:
if k.startswith("from_") and inspect.ismethod(getattr(cls, k)):
req_func_arg_names = set(get_func_arg_names(getattr(cls, k), req_only=True))
if len(req_func_arg_names) > 0:
if not (req_func_arg_names <= keys):
continue
opt_func_arg_names = set(get_func_arg_names(getattr(cls, k), opt_only=True))
func_arg_names = from_splits_arg_names | req_func_arg_names | opt_func_arg_names
if k == "from_ranges":
func_arg_names |= set(get_func_arg_names(get_index_ranges))
if len(func_arg_names) > 0:
if not (keys <= func_arg_names):
continue
matched_methods.append(k)
n_args.append(len(req_func_arg_names) + len(opt_func_arg_names))
if len(matched_methods) > 1:
if "from_n_rolling" in matched_methods:
return "from_n_rolling"
return sorted(zip(matched_methods, n_args), key=lambda x: x[1])[0][0]
if len(matched_methods) == 1:
return matched_methods[0]
return None
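# Illustrative expectations (assuming the factory signatures above): passing only `n`
# should resolve to "from_n_rolling", while `every` should resolve to "from_ranges".
#
#     >>> vbt.Splitter.guess_method(n=10)        # expected: 'from_n_rolling'
#     >>> vbt.Splitter.guess_method(every="QS")  # expected: 'from_ranges'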
@classmethod
def split_and_take(
cls,
index: tp.IndexLike,
obj: tp.Any,
splitter: tp.Union[None, str, SplitterT, tp.Callable] = None,
splitter_kwargs: tp.KwargsLike = None,
take_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
_splitter_kwargs: tp.KwargsLike = None,
_take_kwargs: tp.KwargsLike = None,
**var_kwargs,
) -> tp.Any:
"""Split an index and take from an object.
Argument `splitter` can be an actual `Splitter` instance, the name of a factory method
(such as "from_n_rolling"), or the factory method itself. If `splitter` is None,
the right method will be guessed based on the supplied arguments using `Splitter.guess_method`.
Keyword arguments `splitter_kwargs` are passed to the factory method. Keyword arguments
`take_kwargs` are passed to `Splitter.take`. If variable keyword arguments are provided, they
will be used as `take_kwargs` if `splitter` is already a `Splitter` instance; otherwise, they will
be distributed between the factory method and `Splitter.take` based on their signatures."""
if splitter_kwargs is None:
splitter_kwargs = {}
else:
splitter_kwargs = dict(splitter_kwargs)
if take_kwargs is None:
take_kwargs = {}
else:
take_kwargs = dict(take_kwargs)
if _splitter_kwargs is None:
_splitter_kwargs = {}
if _take_kwargs is None:
_take_kwargs = {}
if len(var_kwargs) > 0:
var_splitter_kwargs = {}
var_take_kwargs = {}
if splitter is None or not isinstance(splitter, cls):
take_arg_names = get_func_arg_names(cls.take)
if splitter is not None:
if isinstance(splitter, str):
splitter_arg_names = get_func_arg_names(getattr(cls, splitter))
else:
splitter_arg_names = get_func_arg_names(splitter)
for k, v in var_kwargs.items():
assigned = False
if k in splitter_arg_names:
var_splitter_kwargs[k] = v
assigned = True
if k != "split" and k in take_arg_names:
var_take_kwargs[k] = v
assigned = True
if not assigned:
raise ValueError(f"Argument '{k}' couldn't be assigned")
else:
for k, v in var_kwargs.items():
if k == "freq":
var_splitter_kwargs[k] = v
var_take_kwargs[k] = v
elif k == "split" or k not in take_arg_names:
var_splitter_kwargs[k] = v
else:
var_take_kwargs[k] = v
else:
var_take_kwargs = var_kwargs
splitter_kwargs = merge_dicts(var_splitter_kwargs, splitter_kwargs)
take_kwargs = merge_dicts(var_take_kwargs, take_kwargs)
if len(splitter_kwargs) > 0:
if splitter is None:
splitter = cls.guess_method(**splitter_kwargs)
if splitter is None:
raise ValueError("Splitter method couldn't be guessed")
else:
if splitter is None:
raise ValueError("Must provide splitter or splitter method")
if not isinstance(splitter, cls):
if isinstance(splitter, str):
splitter = getattr(cls, splitter)
for k, v in _splitter_kwargs.items():
if k not in splitter_kwargs:
splitter_kwargs[k] = v
splitter = splitter(index, template_context=template_context, **splitter_kwargs)
for k, v in _take_kwargs.items():
if k not in take_kwargs:
take_kwargs[k] = v
return splitter.take(obj, template_context=template_context, **take_kwargs)
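# A minimal sketch (assuming a daily index and a Series aligned to it); `n`, `length`,
# and `split` are routed to the guessed factory method (expected: `from_n_rolling`):
#
#     >>> from vectorbtpro import *
#     >>> index = pd.date_range("2020", "2021", freq="D")
#     >>> sr = pd.Series(np.arange(len(index)), index=index)
#     >>> taken = vbt.Splitter.split_and_take(index, sr, n=5, length=100, split=0.5)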
@classmethod
def split_and_apply(
cls,
index: tp.IndexLike,
apply_func: tp.Callable,
*apply_args,
splitter: tp.Union[None, str, SplitterT, tp.Callable] = None,
splitter_kwargs: tp.KwargsLike = None,
apply_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
_splitter_kwargs: tp.KwargsLike = None,
_apply_kwargs: tp.KwargsLike = None,
**var_kwargs,
) -> tp.Any:
"""Split an index and apply a function.
Argument `splitter` can be an actual `Splitter` instance, the name of a factory method
(such as "from_n_rolling"), or the factory method itself. If `splitter` is None,
the right method will be guessed based on the supplied arguments using `Splitter.guess_method`.
Keyword arguments `splitter_kwargs` are passed to the factory method. Keyword arguments
`apply_kwargs` are passed to `Splitter.apply`. If variable keyword arguments are provided, they
will be used as `apply_kwargs` if `splitter` is already a `Splitter` instance; otherwise, they will
be distributed between the factory method and `Splitter.apply` based on their signatures."""
if splitter_kwargs is None:
splitter_kwargs = {}
else:
splitter_kwargs = dict(splitter_kwargs)
if apply_kwargs is None:
apply_kwargs = {}
else:
apply_kwargs = dict(apply_kwargs)
if _splitter_kwargs is None:
_splitter_kwargs = {}
if _apply_kwargs is None:
_apply_kwargs = {}
if len(var_kwargs) > 0:
var_splitter_kwargs = {}
var_apply_kwargs = {}
if splitter is None or not isinstance(splitter, cls):
apply_arg_names = get_func_arg_names(cls.apply)
if splitter is not None:
if isinstance(splitter, str):
splitter_arg_names = get_func_arg_names(getattr(cls, splitter))
else:
splitter_arg_names = get_func_arg_names(splitter)
for k, v in var_kwargs.items():
assigned = False
if k in splitter_arg_names:
var_splitter_kwargs[k] = v
assigned = True
if k != "split" and k in apply_arg_names:
var_apply_kwargs[k] = v
assigned = True
if not assigned:
raise ValueError(f"Argument '{k}' couldn't be assigned")
else:
for k, v in var_kwargs.items():
if k == "freq":
var_splitter_kwargs[k] = v
var_apply_kwargs[k] = v
elif k == "split" or k not in apply_arg_names:
var_splitter_kwargs[k] = v
else:
var_apply_kwargs[k] = v
else:
var_apply_kwargs = var_kwargs
splitter_kwargs = merge_dicts(var_splitter_kwargs, splitter_kwargs)
apply_kwargs = merge_dicts(var_apply_kwargs, apply_kwargs)
if len(splitter_kwargs) > 0:
if splitter is None:
splitter = cls.guess_method(**splitter_kwargs)
if splitter is None:
raise ValueError("Splitter method couldn't be guessed")
else:
if splitter is None:
raise ValueError("Must provide splitter or splitter method")
if not isinstance(splitter, cls):
if isinstance(splitter, str):
splitter = getattr(cls, splitter)
for k, v in _splitter_kwargs.items():
if k not in splitter_kwargs:
splitter_kwargs[k] = v
splitter = splitter(index, template_context=template_context, **splitter_kwargs)
for k, v in _apply_kwargs.items():
if k not in apply_kwargs:
apply_kwargs[k] = v
return splitter.apply(apply_func, *apply_args, template_context=template_context, **apply_kwargs)
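# A minimal sketch mirroring `split_and_take`, but applying a function per range
# (assuming `Takeable` is exposed as `vbt.Takeable`; values are illustrative):
#
#     >>> from vectorbtpro import *
#     >>> index = pd.date_range("2020", "2021", freq="D")
#     >>> sr = pd.Series(np.arange(len(index)), index=index)
#     >>> means = vbt.Splitter.split_and_apply(
#     ...     index, lambda x: x.mean(), vbt.Takeable(sr), n=5, length=100)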
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[SplitterT],
*objs: tp.MaybeTuple[SplitterT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Splitter` after stacking along rows."""
if "splits_arr" not in kwargs:
kwargs["splits_arr"] = kwargs["wrapper"].row_stack_arrs(
*[obj.splits for obj in objs],
group_by=False,
wrap=False,
)
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[SplitterT],
*objs: tp.MaybeTuple[SplitterT],
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `Splitter` after stacking along columns."""
if "splits_arr" not in kwargs:
kwargs["splits_arr"] = kwargs["wrapper"].column_stack_arrs(
*[obj.splits for obj in objs],
reindex_kwargs=reindex_kwargs,
group_by=False,
wrap=False,
)
return kwargs
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[SplitterT],
*objs: tp.MaybeTuple[SplitterT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Stack multiple `Splitter` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Splitter):
raise TypeError("Each object to be merged must be an instance of Splitter")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(
*[obj.wrapper for obj in objs],
stack_columns=False,
**wrapper_kwargs,
)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[SplitterT],
*objs: tp.MaybeTuple[SplitterT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Stack multiple `Splitter` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Splitter):
raise TypeError("Each object to be merged must be an instance of Splitter")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
union_index=False,
**wrapper_kwargs,
)
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
def __init__(
self,
wrapper: ArrayWrapper,
index: tp.Index,
splits_arr: tp.SplitsArray,
**kwargs,
) -> None:
if wrapper.grouper.is_grouped():
raise ValueError("Splitter cannot be grouped")
index = dt.prepare_dt_index(index)
if splits_arr.shape[0] != wrapper.shape_2d[0]:
raise ValueError("Number of splits must match wrapper index")
if splits_arr.shape[1] != wrapper.shape_2d[1]:
raise ValueError("Number of sets must match wrapper columns")
Analyzable.__init__(
self,
wrapper,
index=index,
splits_arr=splits_arr,
**kwargs,
)
self._index = index
self._splits_arr = splits_arr
def indexing_func_meta(self, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform indexing on `Splitter` and return metadata."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
if wrapper_meta["rows_changed"] or wrapper_meta["columns_changed"]:
new_splits_arr = ArrayWrapper.select_from_flex_array(
self.splits_arr,
row_idxs=wrapper_meta["row_idxs"],
col_idxs=wrapper_meta["col_idxs"],
rows_changed=wrapper_meta["rows_changed"],
columns_changed=wrapper_meta["columns_changed"],
)
else:
new_splits_arr = self.splits_arr
return dict(
wrapper_meta=wrapper_meta,
new_splits_arr=new_splits_arr,
)
def indexing_func(self: SplitterT, *args, splitter_meta: tp.DictLike = None, **kwargs) -> SplitterT:
"""Perform indexing on `Splitter`."""
if splitter_meta is None:
splitter_meta = self.indexing_func_meta(*args, **kwargs)
return self.replace(
wrapper=splitter_meta["wrapper_meta"]["new_wrapper"],
splits_arr=splitter_meta["new_splits_arr"],
)
@property
def index(self) -> tp.Index:
"""Index."""
return self._index
@property
def splits_arr(self) -> tp.SplitsArray:
"""Two-dimensional, object-dtype DataFrame with splits.
First axis represents splits. Second axis represents sets. Elements represent ranges.
Range must be either a slice, a sequence of indices, a mask, or a callable that returns such."""
return self._splits_arr
@property
def splits(self) -> tp.Frame:
"""`Splitter.splits_arr` as a DataFrame."""
return self.wrapper.wrap(self._splits_arr, group_by=False)
@property
def split_labels(self) -> tp.Index:
"""Split labels."""
return self.wrapper.index
@property
def set_labels(self) -> tp.Index:
"""Set labels."""
return self.wrapper.columns
@property
def n_splits(self) -> int:
"""Number of splits."""
return self.splits_arr.shape[0]
@property
def n_sets(self) -> int:
"""Number of sets."""
return self.splits_arr.shape[1]
def get_split_grouper(self, split_group_by: tp.AnyGroupByLike = None) -> tp.Optional[Grouper]:
"""Get split grouper."""
if split_group_by is None:
return None
if isinstance(split_group_by, Grouper):
return split_group_by
return BaseIDXAccessor(self.split_labels).get_grouper(split_group_by, def_lvl_name="split_group")
def get_set_grouper(self, set_group_by: tp.AnyGroupByLike = None) -> tp.Optional[Grouper]:
"""Get set grouper."""
if set_group_by is None:
return None
if isinstance(set_group_by, Grouper):
return set_group_by
return BaseIDXAccessor(self.set_labels).get_grouper(set_group_by, def_lvl_name="set_group")
def get_n_splits(self, split_group_by: tp.AnyGroupByLike = None) -> int:
"""Get number of splits while considering the grouper."""
if split_group_by is not None:
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
return split_group_by.get_group_count()
return self.n_splits
def get_n_sets(self, set_group_by: tp.AnyGroupByLike = None) -> int:
"""Get number of sets while considering the grouper."""
if set_group_by is not None:
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
return set_group_by.get_group_count()
return self.n_sets
def get_split_labels(self, split_group_by: tp.AnyGroupByLike = None) -> tp.Index:
"""Get split labels while considering the grouper."""
if split_group_by is not None:
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
return split_group_by.get_index()
return self.split_labels
def get_set_labels(self, set_group_by: tp.AnyGroupByLike = None) -> tp.Index:
"""Get set labels while considering the grouper."""
if set_group_by is not None:
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
return set_group_by.get_index()
return self.set_labels
# ############# Conversion ############# #
def to_fixed(self: SplitterT, split_range_kwargs: tp.KwargsLike = None, **kwargs) -> SplitterT:
"""Convert relative ranges into fixed ones and return a new `Splitter` instance.
Keyword arguments `split_range_kwargs` are passed to `Splitter.split_range`."""
if split_range_kwargs is None:
split_range_kwargs = {}
split_range_kwargs = dict(split_range_kwargs)
wrap_with_fixrange = split_range_kwargs.pop("wrap_with_fixrange", None)
if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
raise ValueError("Argument wrap_with_fixrange must be True or None")
split_range_kwargs["wrap_with_fixrange"] = wrap_with_fixrange
new_splits_arr = []
for split in self.splits_arr:
new_split = self.split_range(slice(None), split, **split_range_kwargs)
new_splits_arr.append(new_split)
new_splits_arr = np.asarray(new_splits_arr, dtype=object)
return self.replace(splits_arr=new_splits_arr, **kwargs)
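# A minimal sketch (given an existing `splitter` built from relative ranges):
#
#     >>> fixed_splitter = splitter.to_fixed()  # all ranges become slices/arrays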
def to_grouped(
self: SplitterT,
split: tp.Optional[tp.Selection] = None,
set_: tp.Optional[tp.Selection] = None,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
merge_split_kwargs: tp.KwargsLike = None,
**kwargs,
) -> SplitterT:
"""Merge all ranges within the same group and return a new `Splitter` instance."""
if merge_split_kwargs is None:
merge_split_kwargs = {}
merge_split_kwargs = dict(merge_split_kwargs)
wrap_with_fixrange = merge_split_kwargs.pop("wrap_with_fixrange", None)
if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
raise ValueError("Argument wrap_with_fixrange must be True or None")
merge_split_kwargs["wrap_with_fixrange"] = wrap_with_fixrange
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
split_group_indices, set_group_indices, split_indices, set_indices = self.select_indices(
split=split,
set_=set_,
split_group_by=split_group_by,
set_group_by=set_group_by,
)
if split is not None:
split_labels = split_labels[split_group_indices]
if set_ is not None:
set_labels = set_labels[set_group_indices]
new_splits_arr = []
for i in split_group_indices:
new_splits_arr.append([])
for j in set_group_indices:
new_range = self.select_range(
split=PosSel(i),
set_=PosSel(j),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=merge_split_kwargs,
)
new_splits_arr[-1].append(new_range)
new_splits_arr = np.asarray(new_splits_arr, dtype=object)
if set_group_by is None or not set_group_by.is_grouped():
ndim = self.wrapper.ndim
else:
ndim = 1 if new_splits_arr.shape[1] == 1 else 2
wrapper = self.wrapper.replace(index=split_labels, columns=set_labels, ndim=ndim)
return self.replace(wrapper=wrapper, splits_arr=new_splits_arr, **kwargs)
# ############# Ranges ############# #
@classmethod
def is_range_relative(cls, range_: tp.RangeLike) -> bool:
"""Return whether a range is relative."""
return checks.is_number(range_) or checks.is_td_like(range_) or isinstance(range_, RelRange)
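# Illustrative expectations: numbers, timedelta-like objects, and `RelRange` instances
# count as relative, while slices, index arrays, and masks are fixed.
#
#     >>> vbt.Splitter.is_range_relative(0.5)                    # expected: True
#     >>> vbt.Splitter.is_range_relative(pd.Timedelta(days=30))  # expected: True
#     >>> vbt.Splitter.is_range_relative(slice(0, 10))           # expected: False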
@hybrid_method
def get_ready_range(
cls_or_self,
range_: tp.FixRangeLike,
allow_relative: bool = False,
allow_zero_len: bool = False,
range_format: str = "slice_or_any",
template_context: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
return_meta: bool = False,
) -> tp.Union[tp.RelRangeLike, tp.ReadyRangeLike, dict]:
"""Get a range that can be directly used in array indexing.
Such a range is either an integer or datetime-like slice (right bound is always exclusive!),
a one-dimensional NumPy array with integer indices or datetime-like objects,
or a one-dimensional NumPy mask of the same length as the index.
Argument `range_format` accepts the following options:
* 'any': Return any format
* 'indices': Return indices
* 'mask': Return mask of the same length as index
* 'slice': Return slice
* 'slice_or_indices': If slice fails, return indices
* 'slice_or_mask': If slice fails, return mask
* 'slice_or_any': If slice fails, return any format
"""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
if range_format.lower() not in (
"any",
"indices",
"mask",
"slice",
"slice_or_indices",
"slice_or_mask",
"slice_or_any",
):
raise ValueError(f"Invalid range_format: '{range_format}'")
meta = dict()
meta["was_fixed"] = False
meta["was_template"] = False
meta["was_callable"] = False
meta["was_relative"] = False
meta["was_hslice"] = False
meta["was_slice"] = False
meta["was_neg_slice"] = False
meta["was_datetime"] = False
meta["was_mask"] = False
meta["was_indices"] = False
meta["is_constant"] = False
meta["start"] = None
meta["stop"] = None
meta["length"] = None
if isinstance(range_, FixRange):
meta["was_fixed"] = True
range_ = range_.range_
if isinstance(range_, CustomTemplate):
meta["was_template"] = True
if template_context is None:
template_context = {}
if "index" not in template_context:
template_context["index"] = index
range_ = range_.substitute(context=template_context, eval_id="range")
if callable(range_):
meta["was_callable"] = True
range_ = range_(index)
if cls_or_self.is_range_relative(range_):
meta["was_relative"] = True
if allow_relative:
if return_meta:
meta["range_"] = range_
return meta
return range_
raise TypeError("Relative ranges must be converted to fixed")
if isinstance(range_, hslice):
meta["was_hslice"] = True
range_ = range_.to_slice()
if isinstance(range_, slice):
meta["was_slice"] = True
meta["is_constant"] = True
start = range_.start
stop = range_.stop
if range_.step is not None and range_.step > 1:
raise ValueError("Step must be either None or 1")
if start is not None and checks.is_int(start) and start < 0:
if stop is not None and checks.is_int(stop) and stop > 0:
raise ValueError("Slices must be either strictly negative or positive")
meta["was_neg_slice"] = True
start = len(index) + start
if stop is not None and checks.is_int(stop):
stop = len(index) + stop
if start is None:
start = 0
if stop is None:
stop = len(index)
if not checks.is_int(start):
if not isinstance(index, pd.DatetimeIndex):
raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}")
start = dt.try_align_dt_to_index(start, index)
if not isinstance(start, pd.Timestamp):
raise ValueError(f"Range start ({start}) could not be parsed")
meta["was_datetime"] = True
if not checks.is_int(stop):
if not isinstance(index, pd.DatetimeIndex):
raise TypeError(f"Index must be of type pandas.DatetimeIndex, not {index.dtype}")
stop = dt.try_align_dt_to_index(stop, index)
if not isinstance(stop, pd.Timestamp):
raise ValueError(f"Range start ({stop}) could not be parsed")
meta["was_datetime"] = True
if checks.is_int(start):
if start < 0:
start = 0
else:
if start < index[0]:
start = 0
else:
start = index.get_indexer([start], method="bfill")[0]
if start == -1:
raise ValueError(f"Range start ({start}) is out of bounds")
if checks.is_int(stop):
if stop > len(index):
stop = len(index)
else:
if stop > index[-1]:
stop = len(index)
else:
stop = index.get_indexer([stop], method="bfill")[0]
if stop == -1:
raise ValueError(f"Range stop ({stop}) is out of bounds")
range_ = slice(start, stop)
meta["start"] = start
meta["stop"] = stop
meta["length"] = stop - start
if not allow_zero_len and meta["length"] == 0:
raise ZeroLengthError("Range has zero length")
if range_format.lower() == "indices":
range_ = np.arange(*range_.indices(len(index)))
elif range_format.lower() == "mask":
mask = np.full(len(index), False)
mask[range_] = True
range_ = mask
else:
range_ = np.asarray(range_)
if np.issubdtype(range_.dtype, np.bool_):
if len(range_) != len(index):
raise ValueError("Mask must have the same length as index")
meta["was_mask"] = True
indices = np.flatnonzero(range_)
if len(indices) == 0:
if not allow_zero_len:
raise ZeroLengthError("Range has zero length")
meta["is_constant"] = True
meta["start"] = 0
meta["stop"] = 0
meta["length"] = 0
else:
meta["is_constant"] = is_range(indices)
meta["start"] = indices[0]
meta["stop"] = indices[-1] + 1
meta["length"] = len(indices)
if range_format.lower() == "indices":
range_ = indices
elif range_format.lower().startswith("slice"):
if not meta["is_constant"]:
if range_format.lower() == "slice":
raise ValueError("Cannot convert to slice: range is not constant")
if range_format.lower() == "slice_or_indices":
range_ = indices
else:
range_ = slice(meta["start"], meta["stop"])
else:
if not np.issubdtype(range_.dtype, np.integer):
range_ = dt.try_align_to_dt_index(range_, index)
if not isinstance(range_, pd.DatetimeIndex):
raise ValueError("Range array could not be parsed")
range_ = index.get_indexer(range_, method=None)
if -1 in range_:
raise ValueError(f"Range array has values that cannot be found in index")
if np.issubdtype(range_.dtype, np.integer):
meta["was_indices"] = True
if len(range_) == 0:
if not allow_zero_len:
raise ZeroLengthError("Range has zero length")
meta["is_constant"] = True
meta["start"] = 0
meta["stop"] = 0
meta["length"] = 0
else:
meta["is_constant"] = is_range(range_)
if meta["is_constant"]:
meta["start"] = range_[0]
meta["stop"] = range_[-1] + 1
else:
meta["start"] = np.min(range_)
meta["stop"] = np.max(range_) + 1
meta["length"] = len(range_)
if range_format.lower() == "mask":
mask = np.full(len(index), False)
mask[range_] = True
range_ = mask
elif range_format.lower().startswith("slice"):
if not meta["is_constant"]:
if range_format.lower() == "slice":
raise ValueError("Cannot convert to slice: range is not constant")
if range_format.lower() == "slice_or_mask":
mask = np.full(len(index), False)
mask[range_] = True
range_ = mask
else:
range_ = slice(meta["start"], meta["stop"])
else:
raise TypeError(f"Range array has invalid data type ({range_.dtype})")
if meta["start"] != meta["stop"]:
if meta["start"] > meta["stop"]:
raise ValueError(f"Range start ({meta['start']}) is higher than range stop ({meta['stop']})")
if meta["start"] < 0 or meta["start"] >= len(index):
raise ValueError(f"Range start ({meta['start']}) is out of bounds")
if meta["stop"] < 0 or meta["stop"] > len(index):
raise ValueError(f"Range stop ({meta['stop']}) is out of bounds")
if return_meta:
meta["range_"] = range_
return meta
return range_
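    # Illustrative sketch of `get_ready_range` (the index and expected outcomes below are
    # assumptions for demonstration, not verified library output):
    # >>> index = pd.date_range("2020-01-01", periods=10)
    # >>> vbt.Splitter.get_ready_range(slice("2020-01-03", "2020-01-06"), index=index)
    # A datetime slice gets aligned to the index and returned as an integer slice with an
    # exclusive right bound, here roughly slice(2, 5).
    # >>> vbt.Splitter.get_ready_range(np.array([2, 3, 4]), index=index, range_format="mask")
    # Integer indices get converted into a boolean mask of the same length as the index.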
@hybrid_method
def split_range(
cls_or_self,
range_: tp.FixRangeLike,
new_split: tp.SplitLike,
backwards: bool = False,
allow_zero_len: bool = False,
range_format: tp.Optional[str] = None,
wrap_with_template: bool = False,
wrap_with_fixrange: tp.Optional[bool] = False,
template_context: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.FixSplit:
"""Split a fixed range into a split of multiple fixed ranges.
Range must be either a template, a callable, a tuple (start and stop), a slice, a sequence
of indices, or a mask. This range will then be re-mapped into the index.
Each sub-range in `new_split` can be either a fixed or relative range, that is, an instance
of `RelRange` or a number that will be used as a length to create an `RelRange`.
Each sub-range will then be re-mapped into the main range. Argument `new_split` can also
be provided as an integer or a float indicating the length; in such a case the second part
(or the first one depending on `backwards`) will stretch. If `new_split` is a string,
the following options are supported:
* 'by_gap': Split `range_` by gap using `vectorbtpro.generic.splitting.nb.split_range_by_gap_nb`.
New ranges are returned relative to the index and in the same order as passed.
For `range_format`, see `Splitter.get_ready_range`. Enable `wrap_with_template` to wrap the
resulting ranges with a template of the type `vectorbtpro.utils.template.Rep`."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
# Prepare source range
range_meta = cls_or_self.get_ready_range(
range_,
allow_zero_len=allow_zero_len,
range_format="slice_or_indices",
template_context=template_context,
index=index,
return_meta=True,
)
range_ = range_meta["range_"]
range_was_hslice = range_meta["was_hslice"]
range_was_indices = range_meta["was_indices"]
range_was_mask = range_meta["was_mask"]
range_length = range_meta["length"]
if range_format is None:
if range_was_indices:
range_format = "slice_or_indices"
elif range_was_mask:
range_format = "slice_or_mask"
else:
range_format = "slice_or_any"
# Substitute template
if isinstance(new_split, CustomTemplate):
_template_context = merge_dicts(dict(index=index[range_]), template_context)
new_split = substitute_templates(new_split, _template_context, eval_id="new_split")
# Split by gap
if isinstance(new_split, str) and new_split.lower() == "by_gap":
if isinstance(range_, np.ndarray) and np.issubdtype(range_.dtype, np.integer):
range_arr = range_
else:
range_arr = np.arange(len(index))[range_]
start_idxs, stop_idxs = nb.split_range_by_gap_nb(range_arr)
new_split = list(map(lambda x: slice(x[0], x[1]), zip(start_idxs, stop_idxs)))
# Prepare target ranges
if checks.is_number(new_split):
if new_split < 0:
backwards = not backwards
new_split = abs(new_split)
if not backwards:
new_split = (new_split, 1.0)
else:
new_split = (1.0, new_split)
elif checks.is_td_like(new_split):
new_split = dt.to_freq(new_split)
if new_split < pd.Timedelta(0):
backwards = not backwards
new_split = abs(new_split)
if not backwards:
new_split = (new_split, 1.0)
else:
new_split = (1.0, new_split)
elif not checks.is_iterable(new_split):
new_split = (new_split,)
# Perform split
new_ranges = []
if backwards:
new_split = new_split[::-1]
prev_start = range_length
prev_end = range_length
else:
prev_start = 0
prev_end = 0
for new_range in new_split:
# Resolve new range
new_range_meta = cls_or_self.get_ready_range(
new_range,
allow_relative=True,
allow_zero_len=allow_zero_len,
range_format="slice_or_any",
template_context=template_context,
index=index[range_],
return_meta=True,
)
new_range = new_range_meta["range_"]
if checks.is_number(new_range) or checks.is_td_like(new_range):
new_range = RelRange(length=new_range)
if isinstance(new_range, RelRange):
new_range_is_gap = new_range.is_gap
new_range = new_range.to_slice(
range_length,
prev_start=range_length - prev_end if backwards else prev_start,
prev_end=range_length - prev_start if backwards else prev_end,
index=index,
freq=freq,
)
if backwards:
new_range = slice(range_length - new_range.stop, range_length - new_range.start)
else:
new_range_is_gap = False
# Update previous bounds
if isinstance(new_range, slice):
prev_start = new_range.start
prev_end = new_range.stop
else:
prev_start = new_range_meta["start"]
prev_end = new_range_meta["stop"]
# Remap new range to index
if new_range_is_gap:
continue
if isinstance(range_, slice) and isinstance(new_range, slice):
new_range = slice(
range_.start + new_range.start,
range_.start + new_range.stop,
)
else:
if isinstance(range_, slice):
new_range = np.arange(range_.start, range_.stop)[new_range]
else:
new_range = range_[new_range]
new_range = cls_or_self.get_ready_range(
new_range,
allow_zero_len=allow_zero_len,
range_format=range_format,
index=index,
)
if isinstance(new_range, slice) and range_was_hslice:
new_range = hslice.from_slice(new_range)
if wrap_with_template:
new_range = Rep("range_", context=dict(range_=new_range))
            if wrap_with_fixrange is None:
                _wrap_with_fixrange = checks.is_sequence(new_range)
            else:
                _wrap_with_fixrange = wrap_with_fixrange
if _wrap_with_fixrange:
new_range = FixRange(new_range)
new_ranges.append(new_range)
if backwards:
return tuple(new_ranges)[::-1]
return tuple(new_ranges)
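    # Illustrative sketch of `split_range` (the index is an assumption for demonstration):
    # >>> index = pd.date_range("2020-01-01", periods=10)
    # >>> vbt.Splitter.split_range(slice(None), (0.75, 0.25), index=index)
    # Splits the full range into two adjacent sub-ranges covering roughly 75% and 25% of it.
    # >>> vbt.Splitter.split_range(np.array([0, 1, 2, 5, 6]), "by_gap", index=index)
    # With 'by_gap', a range with holes gets split into its contiguous pieces.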
@hybrid_method
def merge_split(
cls_or_self,
split: tp.FixSplit,
range_format: tp.Optional[str] = None,
wrap_with_template: bool = False,
wrap_with_fixrange: tp.Optional[bool] = False,
wrap_with_hslice: tp.Optional[bool] = False,
template_context: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
) -> tp.FixRangeLike:
"""Merge a split of multiple fixed ranges into a fixed range.
Creates one mask and sets True for each range. If all input ranges are masks,
returns that mask. If all input ranges are slices, returns a slice if possible.
Otherwise, returns integer indices.
For `range_format`, see `Splitter.get_ready_range`. Enable `wrap_with_template` to wrap the
resulting range with a template of the type `vectorbtpro.utils.template.Rep`."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
all_hslices = True
all_masks = True
new_ranges = []
if len(split) == 1:
raise ValueError("Two or more ranges are required to be merged")
for range_ in split:
range_meta = cls_or_self.get_ready_range(
range_,
allow_zero_len=True,
range_format="any",
template_context=template_context,
index=index,
return_meta=True,
)
if not range_meta["was_hslice"]:
all_hslices = False
if not range_meta["was_mask"]:
all_masks = False
new_ranges.append(range_meta["range_"])
ranges = new_ranges
if range_format is None:
if all_masks:
range_format = "slice_or_mask"
else:
range_format = "slice_or_indices"
new_range = np.full(len(index), False)
for range_ in ranges:
new_range[range_] = True
new_range = cls_or_self.get_ready_range(
new_range,
range_format=range_format,
index=index,
)
if isinstance(new_range, slice) and all_hslices:
if wrap_with_hslice is None:
wrap_with_hslice = True
if wrap_with_hslice:
new_range = hslice.from_slice(new_range)
if wrap_with_template:
new_range = Rep("range_", context=dict(range_=new_range))
        if wrap_with_fixrange is None:
            _wrap_with_fixrange = checks.is_sequence(new_range)
        else:
            _wrap_with_fixrange = wrap_with_fixrange
if _wrap_with_fixrange:
new_range = FixRange(new_range)
return new_range
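    # Illustrative sketch of `merge_split` (the index is an assumption for demonstration):
    # >>> index = pd.date_range("2020-01-01", periods=10)
    # >>> vbt.Splitter.merge_split((slice(0, 5), slice(7, 10)), index=index)
    # Builds one mask covering both ranges; since the result is not contiguous, integer indices
    # are returned under the default 'slice_or_indices', otherwise a slice would be.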
# ############# Taking ############# #
def select_indices(
self,
split: tp.Optional[tp.Selection] = None,
set_: tp.Optional[tp.Selection] = None,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Get indices corresponding to selected splits and sets.
        Arguments `split` and `set_` can be either integers (positions) or labels. Multiple
        values are also accepted; in such a case, the corresponding ranges are merged.
        If the split/set labels are of an integer data type, the provided integer values are
        treated as labels rather than positions, unless they are wrapped with
        `vectorbtpro.utils.selection.PosSel`.
If `split_group_by` and/or `set_group_by` are provided, their groupers get
created using `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper` and
arguments `split` and `set_` become relative to the groups.
If `split`/`set_` is not provided, selects all indices.
Returns four arrays: split group indices, set group indices, split indices, and set indices."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
if split is None:
split_group_indices = np.arange(self.get_n_splits(split_group_by=split_group_by))
split_indices = np.arange(self.n_splits)
else:
kind = None
if isinstance(split, PosSel):
split = split.value
kind = "positions"
elif isinstance(split, LabelSel):
split = split.value
kind = "labels"
if checks.is_hashable(split):
split = [split]
if split_group_by is not None:
split_group_indices = []
groups, group_index = split_group_by.get_groups_and_index()
mask = None
for g in split:
if isinstance(g, PosSel):
g = g.value
kind = "positions"
elif isinstance(g, LabelSel):
g = g.value
kind = "labels"
if kind == "positions" or (
kind is None and checks.is_int(g) and not pd.api.types.is_integer_dtype(group_index)
):
i = g
else:
i = group_index.get_indexer([g])[0]
if i == -1:
raise ValueError(f"Split group '{g}' not found")
if mask is None:
mask = groups == i
else:
mask |= groups == i
split_group_indices.append(i)
split_group_indices = np.asarray(split_group_indices)
split_indices = np.arange(self.n_splits)[mask]
else:
split_indices = []
for s in split:
if isinstance(s, PosSel):
s = s.value
kind = "positions"
elif isinstance(s, LabelSel):
s = s.value
kind = "labels"
if kind == "positions" or (
kind is None and checks.is_int(s) and not pd.api.types.is_integer_dtype(self.split_labels)
):
i = s
else:
i = self.split_labels.get_indexer([s])[0]
if i == -1:
raise ValueError(f"Split '{s}' not found")
split_indices.append(i)
split_group_indices = split_indices = np.asarray(split_indices)
if set_ is None:
set_group_indices = np.arange(self.get_n_sets(set_group_by=set_group_by))
set_indices = np.arange(self.n_sets)
else:
kind = None
if isinstance(set_, PosSel):
set_ = set_.value
kind = "positions"
elif isinstance(set_, LabelSel):
set_ = set_.value
kind = "labels"
if checks.is_hashable(set_):
set_ = [set_]
if set_group_by is not None:
set_group_indices = []
groups, group_index = set_group_by.get_groups_and_index()
mask = None
for g in set_:
if isinstance(g, PosSel):
g = g.value
kind = "positions"
elif isinstance(g, LabelSel):
g = g.value
kind = "labels"
if kind == "positions" or (
kind is None and checks.is_int(g) and not pd.api.types.is_integer_dtype(group_index)
):
i = g
else:
i = group_index.get_indexer([g])[0]
if i == -1:
raise ValueError(f"Set group '{g}' not found")
if mask is None:
mask = groups == i
else:
mask |= groups == i
set_group_indices.append(i)
set_group_indices = np.asarray(set_group_indices)
set_indices = np.arange(self.n_sets)[mask]
else:
set_indices = []
for s in set_:
if isinstance(s, PosSel):
s = s.value
kind = "positions"
elif isinstance(s, LabelSel):
s = s.value
kind = "labels"
if kind == "positions" or (
kind is None and checks.is_int(s) and not pd.api.types.is_integer_dtype(self.set_labels)
):
i = s
else:
i = self.set_labels.get_indexer([s])[0]
if i == -1:
raise ValueError(f"Set '{s}' not found")
set_indices.append(i)
set_group_indices = set_indices = np.asarray(set_indices)
return split_group_indices, set_group_indices, split_indices, set_indices
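    # Illustrative sketch of `select_indices` (the splitter construction is an assumption):
    # >>> index = pd.date_range("2020-01-01", periods=10)
    # >>> splitter = vbt.Splitter.from_n_rolling(index, 3)
    # >>> splitter.select_indices(split=0)
    # Returns four arrays (split group indices, set group indices, split indices, set indices),
    # here selecting the first split and all sets.
    # >>> from vectorbtpro.utils.selection import PosSel
    # >>> splitter.select_indices(split=PosSel(0))
    # Wrapping with PosSel (or LabelSel) forces position-based (or label-based) selection.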
def select_range(self, merge_split_kwargs: tp.KwargsLike = None, **select_indices_kwargs) -> tp.RangeLike:
"""Select a range.
Passes `**select_indices_kwargs` to `Splitter.select_indices` to get the indices for selected
splits and sets. If multiple ranges correspond to those indices, merges them using
`Splitter.merge_split`."""
_, _, split_indices, set_indices = self.select_indices(**select_indices_kwargs)
ranges = []
for i in split_indices:
for j in set_indices:
ranges.append(self.splits_arr[i, j])
if len(ranges) == 1:
return ranges[0]
if merge_split_kwargs is None:
merge_split_kwargs = {}
return self.merge_split(ranges, **merge_split_kwargs)
@hybrid_method
def remap_range(
cls_or_self,
range_: tp.FixRangeLike,
target_index: tp.IndexLike,
target_freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
silence_warnings: bool = False,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.FixRangeLike:
"""Remap a range to a target index.
If `index` and `target_index` are the same, returns the range. Otherwise,
uses `vectorbtpro.base.resampling.base.Resampler.resample_source_mask` to resample
the range into the target index. In such a case, `freq` and `target_freq` must be provided."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
if target_index is None:
raise ValueError("Must provide target index")
target_index = dt.prepare_dt_index(target_index)
if index.equals(target_index):
return range_
mask = cls_or_self.get_range_mask(range_, template_context=template_context, index=index)
resampler = Resampler(
source_index=index,
target_index=target_index,
source_freq=freq,
target_freq=target_freq,
)
target_mask = resampler.resample_source_mask(mask, jitted=jitted, silence_warnings=silence_warnings)
return target_mask
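    # Illustrative sketch of `remap_range` (indexes and frequencies are assumptions):
    # >>> daily_index = pd.date_range("2020-01-01", periods=5, freq="D")
    # >>> hourly_index = pd.date_range("2020-01-01", periods=5 * 24, freq="h")
    # >>> vbt.Splitter.remap_range(
    # ...     slice(0, 2),
    # ...     target_index=hourly_index,
    # ...     target_freq="h",
    # ...     index=daily_index,
    # ...     freq="D",
    # ... )
    # Resamples the source mask of the first two days into a mask over the hourly index;
    # when both indexes are equal, the range is returned unchanged.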
@classmethod
def get_obj_index(cls, obj: tp.Any) -> tp.Index:
"""Get index from an object."""
if isinstance(obj, pd.Index):
return obj
if hasattr(obj, "index"):
return obj.index
if hasattr(obj, "wrapper"):
return obj.wrapper.index
raise ValueError("Must provide object index")
@hybrid_method
def get_ready_obj_range(
cls_or_self,
obj: tp.Any,
range_: tp.FixRangeLike,
remap_to_obj: bool = True,
obj_index: tp.Optional[tp.IndexLike] = None,
obj_freq: tp.Optional[tp.FrequencyLike] = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
silence_warnings: bool = False,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
return_obj_meta: bool = False,
**ready_range_kwargs,
) -> tp.Any:
"""Get a range that is ready to be mapped into an array-like object.
        If `remap_to_obj` is True and the object is Pandas-like (or `obj_index` is provided), resolves
        the object index, searching for it with `Splitter.get_obj_index` if `obj_index` is None. Then uses
        `Splitter.remap_range` to get the range that maps to the object index. Finally, uses
        `Splitter.get_ready_range` to convert the range into one that can be used directly in indexing."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
if remap_to_obj and (
isinstance(obj, (pd.Index, pd.Series, pd.DataFrame, PandasIndexer)) or obj_index is not None
):
if obj_index is None:
obj_index = cls_or_self.get_obj_index(obj)
target_range = cls_or_self.remap_range(
range_,
target_index=obj_index,
target_freq=obj_freq,
template_context=template_context,
jitted=jitted,
silence_warnings=silence_warnings,
index=index,
freq=freq,
)
else:
obj_index = index
obj_freq = freq
target_range = range_
ready_range_or_meta = cls_or_self.get_ready_range(
target_range,
template_context=template_context,
index=obj_index,
**ready_range_kwargs,
)
if return_obj_meta:
obj_meta = dict(index=obj_index, freq=obj_freq)
return obj_meta, ready_range_or_meta
return ready_range_or_meta
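    # Illustrative sketch of `get_ready_obj_range` (the objects are assumptions for demonstration):
    # >>> index = pd.date_range("2020-01-01", periods=10, freq="D")
    # >>> sr = pd.Series(np.arange(10 * 24), index=pd.date_range("2020-01-01", periods=10 * 24, freq="h"))
    # >>> vbt.Splitter.get_ready_obj_range(sr, slice(0, 3), index=index, freq="D", obj_freq="h")
    # The first three days get remapped to the hourly index of `sr` and returned in a form
    # that can be passed directly to `Splitter.take_range`.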
@classmethod
def take_range(cls, obj: tp.Any, ready_range: tp.ReadyRangeLike, point_wise: bool = False) -> tp.Any:
"""Take a ready range from an array-like object.
Set `point_wise` to True to select one range point at a time and return a tuple."""
if isinstance(obj, (pd.Series, pd.DataFrame, PandasIndexer)):
if point_wise:
return tuple(obj.iloc[i] for i in np.arange(len(obj))[ready_range])
return obj.iloc[ready_range]
if point_wise:
return tuple(obj[i] for i in np.arange(len(obj))[ready_range])
return obj[ready_range]
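    # Illustrative sketch of `take_range` (the data is an assumption for demonstration):
    # >>> sr = pd.Series(np.arange(10))
    # >>> vbt.Splitter.take_range(sr, slice(2, 5))
    # Equivalent to sr.iloc[2:5] for Pandas objects; plain arrays get indexed with obj[ready_range].
    # >>> vbt.Splitter.take_range(sr, slice(2, 5), point_wise=True)
    # Returns a tuple with one element per range point.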
@hybrid_method
def take_range_from_takeable(
cls_or_self,
takeable: Takeable,
range_: tp.FixRangeLike,
remap_to_obj: bool = True,
obj_index: tp.Optional[tp.IndexLike] = None,
obj_freq: tp.Optional[tp.FrequencyLike] = None,
point_wise: bool = False,
template_context: tp.KwargsLike = None,
return_obj_meta: bool = False,
return_meta: bool = False,
**ready_obj_range_kwargs,
) -> tp.Any:
"""Take a range from a takeable object."""
takeable.assert_field_not_missing("obj")
obj_meta, obj_range_meta = cls_or_self.get_ready_obj_range(
takeable.obj,
range_,
remap_to_obj=takeable.remap_to_obj if takeable.remap_to_obj is not MISSING else remap_to_obj,
obj_index=takeable.index if takeable.index is not MISSING else obj_index,
obj_freq=takeable.freq if takeable.freq is not MISSING else obj_freq,
template_context=template_context,
return_obj_meta=True,
return_meta=True,
**ready_obj_range_kwargs,
)
if isinstance(takeable.obj, CustomTemplate):
template_context = merge_dicts(
dict(
range_=obj_range_meta["range_"],
range_meta=obj_range_meta,
point_wise=takeable.point_wise if takeable.point_wise is not MISSING else point_wise,
),
template_context,
)
obj_slice = substitute_templates(takeable.obj, template_context, eval_id="take_range")
else:
obj_slice = cls_or_self.take_range(
takeable.obj,
obj_range_meta["range_"],
point_wise=takeable.point_wise if takeable.point_wise is not MISSING else point_wise,
)
if return_obj_meta and return_meta:
return obj_meta, obj_range_meta, obj_slice
if return_obj_meta:
return obj_meta, obj_slice
if return_meta:
return obj_range_meta, obj_slice
return obj_slice
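    # Illustrative sketch of `take_range_from_takeable` (the data is an assumption for demonstration):
    # >>> index = pd.date_range("2020-01-01", periods=10)
    # >>> takeable = vbt.Takeable(pd.Series(np.arange(10), index=index))
    # >>> vbt.Splitter.take_range_from_takeable(takeable, slice(0, 5), index=index)
    # Resolves the range against the takeable's own index and returns the selected slice;
    # fields set on the `Takeable` (such as `point_wise`) override the arguments passed here.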
def take(
self,
obj: tp.Any,
split: tp.Optional[tp.Selection] = None,
set_: tp.Optional[tp.Selection] = None,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
into: tp.Optional[str] = None,
remap_to_obj: bool = True,
obj_index: tp.Optional[tp.IndexLike] = None,
obj_freq: tp.Optional[tp.FrequencyLike] = None,
range_format: str = "slice_or_any",
point_wise: bool = False,
attach_bounds: tp.Union[bool, str] = False,
right_inclusive: bool = False,
template_context: tp.KwargsLike = None,
silence_warnings: bool = False,
index_combine_kwargs: tp.KwargsLike = None,
stack_axis: int = 1,
stack_kwargs: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.Any:
"""Take all ranges from an array-like object and optionally column-stack them.
Uses `Splitter.select_indices` to get the indices for selected splits and sets.
Arguments `split_group_by` and `set_group_by` can be used to group splits and sets respectively.
Ranges belonging to the same split and set group will be merged.
For each index pair, resolves the source range using `Splitter.select_range` and
`Splitter.get_ready_range`. Then, remaps this range into the object index using
`Splitter.get_ready_obj_range` and takes the slice from the object using `Splitter.take_range`.
        If the object is a custom template, substitutes it instead of calling `Splitter.take_range`.
Finally, uses `vectorbtpro.base.merging.column_stack_merge` (`stack_axis=1`) or
`vectorbtpro.base.merging.row_stack_merge` (`stack_axis=0`) with `stack_kwargs` to merge the taken slices.
If `attach_bounds` is enabled, measures the bounds of each range and makes it an additional
level in the final index hierarchy. The argument supports the following options:
* True, 'index', 'source', or 'source_index': Attach source (index) bounds
* 'target' or 'target_index': Attach target (index) bounds
* False: Do not attach
Argument `into` supports the following options:
* None: Series of range slices
* 'stacked': Stack all slices into a single object
* 'stacked_by_split': Stack set slices in each split and return a Series of objects
* 'stacked_by_set': Stack split slices in each set and return a Series of objects
* 'split_major_meta': Generator with ranges processed lazily in split-major order.
Returns meta with indices and labels, and the generator.
* 'set_major_meta': Generator with ranges processed lazily in set-major order.
Returns meta with indices and labels, and the generator.
Prepend any stacked option with "from_start_" (also "reset_") or "from_end_" to reset the index
from start and from end respectively.
Usage:
* Roll a window and stack it along columns by keeping the index:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... "BTC-USD",
... start="2020-01-01 UTC",
... end="2021-01-01 UTC"
... )
>>> splitter = vbt.Splitter.from_n_rolling(
... data.wrapper.index,
... 3,
... length=5
... )
>>> splitter.take(data.close, into="stacked")
split 0 1 2
Date
2020-01-01 00:00:00+00:00 7200.174316 NaN NaN
2020-01-02 00:00:00+00:00 6985.470215 NaN NaN
2020-01-03 00:00:00+00:00 7344.884277 NaN NaN
2020-01-04 00:00:00+00:00 7410.656738 NaN NaN
2020-01-05 00:00:00+00:00 7411.317383 NaN NaN
2020-06-29 00:00:00+00:00 NaN 9190.854492 NaN
2020-06-30 00:00:00+00:00 NaN 9137.993164 NaN
2020-07-01 00:00:00+00:00 NaN 9228.325195 NaN
2020-07-02 00:00:00+00:00 NaN 9123.410156 NaN
2020-07-03 00:00:00+00:00 NaN 9087.303711 NaN
2020-12-27 00:00:00+00:00 NaN NaN 26272.294922
2020-12-28 00:00:00+00:00 NaN NaN 27084.808594
2020-12-29 00:00:00+00:00 NaN NaN 27362.437500
2020-12-30 00:00:00+00:00 NaN NaN 28840.953125
2020-12-31 00:00:00+00:00 NaN NaN 29001.720703
```
        * Discard the index and attach index bounds to the column hierarchy:
```pycon
>>> splitter.take(
... data.close,
... into="reset_stacked",
... attach_bounds="index"
... )
split 0 1 \\
start 2020-01-01 00:00:00+00:00 2020-06-29 00:00:00+00:00
end 2020-01-06 00:00:00+00:00 2020-07-04 00:00:00+00:00
0 7200.174316 9190.854492
1 6985.470215 9137.993164
2 7344.884277 9228.325195
3 7410.656738 9123.410156
4 7411.317383 9087.303711
split 2
start 2020-12-27 00:00:00+00:00
end 2021-01-01 00:00:00+00:00
0 26272.294922
1 27084.808594
2 27362.437500
3 28840.953125
4 29001.720703
```
"""
if isinstance(attach_bounds, bool):
if attach_bounds:
attach_bounds = "source"
else:
attach_bounds = None
index_bounds = False
if attach_bounds is not None:
if attach_bounds.lower() == "index":
attach_bounds = "source"
index_bounds = True
if attach_bounds.lower() in ("source_index", "target_index"):
attach_bounds = attach_bounds.split("_")[0]
index_bounds = True
if attach_bounds.lower() not in ("source", "target"):
raise ValueError(f"Invalid attach_bounds: '{attach_bounds}'")
if index_combine_kwargs is None:
index_combine_kwargs = {}
if stack_axis not in (0, 1):
raise ValueError("Axis for stacking must be either 0 or 1")
if stack_kwargs is None:
stack_kwargs = {}
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
split_group_indices, set_group_indices, split_indices, set_indices = self.select_indices(
split=split,
set_=set_,
split_group_by=split_group_by,
set_group_by=set_group_by,
)
if split is not None:
split_labels = split_labels[split_group_indices]
if set_ is not None:
set_labels = set_labels[set_group_indices]
n_splits = len(split_group_indices)
n_sets = len(set_group_indices)
one_split = n_splits == 1 and squeeze_one_split
one_set = n_sets == 1 and squeeze_one_set
one_range = one_split and one_set
def _get_bounds(range_meta, obj_meta, obj_range_meta):
if attach_bounds is not None:
if attach_bounds.lower() == "source":
if index_bounds:
bounds = self.map_bounds_to_index(
range_meta["start"],
range_meta["stop"],
right_inclusive=right_inclusive,
freq=freq,
)
else:
if right_inclusive:
bounds = (range_meta["start"], range_meta["stop"] - 1)
else:
bounds = (range_meta["start"], range_meta["stop"])
else:
if index_bounds:
bounds = self.map_bounds_to_index(
obj_range_meta["start"],
obj_range_meta["stop"],
right_inclusive=right_inclusive,
index=obj_meta["index"],
freq=obj_meta["freq"],
)
else:
if right_inclusive:
bounds = (obj_range_meta["start"], obj_range_meta["stop"] - 1)
else:
bounds = (obj_range_meta["start"], obj_range_meta["stop"])
else:
bounds = (None, None)
return bounds
def _get_range_meta(i, j):
split_idx = split_group_indices[i]
set_idx = set_group_indices[j]
range_ = self.select_range(
split=PosSel(split_idx),
set_=PosSel(set_idx),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=dict(template_context=template_context),
)
range_meta = self.get_ready_range(
range_,
range_format=range_format,
template_context=template_context,
return_meta=True,
)
obj_meta, obj_range_meta = self.get_ready_obj_range(
obj,
range_meta["range_"],
remap_to_obj=remap_to_obj,
obj_index=obj_index,
obj_freq=obj_freq,
range_format=range_format,
template_context=template_context,
silence_warnings=silence_warnings,
freq=freq,
return_obj_meta=True,
return_meta=True,
)
if isinstance(obj, CustomTemplate):
_template_context = merge_dicts(
dict(
split_idx=split_idx,
set_idx=set_idx,
range_=obj_range_meta["range_"],
range_meta=obj_range_meta,
point_wise=point_wise,
),
template_context,
)
obj_slice = substitute_templates(obj, _template_context, eval_id="take_range")
else:
obj_slice = self.take_range(obj, obj_range_meta["range_"], point_wise=point_wise)
bounds = _get_bounds(range_meta, obj_meta, obj_range_meta)
return dict(
split_idx=split_idx,
set_idx=set_idx,
range_meta=range_meta,
obj_range_meta=obj_range_meta,
obj_slice=obj_slice,
bounds=bounds,
)
def _attach_bounds(keys, range_bounds):
range_bounds = pd.MultiIndex.from_tuples(range_bounds, names=["start", "end"])
if keys is None:
return range_bounds
clean_index_kwargs = dict(index_combine_kwargs)
clean_index_kwargs.pop("ignore_ranges", None)
return stack_indexes((keys, range_bounds), **clean_index_kwargs)
if into is None:
range_objs = []
range_bounds = []
for i in range(n_splits):
for j in range(n_sets):
range_meta = _get_range_meta(i, j)
range_objs.append(range_meta["obj_slice"])
range_bounds.append(range_meta["bounds"])
if one_range:
return range_objs[0]
if one_set:
keys = split_labels
elif one_split:
keys = set_labels
else:
keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
if attach_bounds is not None:
keys = _attach_bounds(keys, range_bounds)
return pd.Series(range_objs, index=keys, dtype=object)
if isinstance(into, str) and into.lower().startswith("reset_"):
if stack_axis == 0:
raise ValueError("Cannot use reset_index with stack_axis=0")
stack_kwargs["reset_index"] = "from_start"
into = into.lower().replace("reset_", "")
if isinstance(into, str) and into.lower().startswith("from_start_"):
if stack_axis == 0:
raise ValueError("Cannot use reset_index with stack_axis=0")
stack_kwargs["reset_index"] = "from_start"
into = into.lower().replace("from_start_", "")
if isinstance(into, str) and into.lower().startswith("from_end_"):
if stack_axis == 0:
raise ValueError("Cannot use reset_index with stack_axis=0")
stack_kwargs["reset_index"] = "from_end"
into = into.lower().replace("from_end_", "")
if isinstance(into, str) and into.lower() in ("split_major_meta", "set_major_meta"):
meta = {
"split_group_indices": split_group_indices,
"set_group_indices": set_group_indices,
"split_indices": split_indices,
"set_indices": set_indices,
"n_splits": n_splits,
"n_sets": n_sets,
"split_labels": split_labels,
"set_labels": set_labels,
}
if isinstance(into, str) and into.lower() == "split_major_meta":
def _get_generator():
for i in range(n_splits):
for j in range(n_sets):
yield _get_range_meta(i, j)
return meta, _get_generator()
if isinstance(into, str) and into.lower() == "set_major_meta":
def _get_generator():
for j in range(n_sets):
for i in range(n_splits):
yield _get_range_meta(i, j)
return meta, _get_generator()
if isinstance(into, str) and into.lower() == "stacked":
range_objs = []
range_bounds = []
for i in range(n_splits):
for j in range(n_sets):
range_meta = _get_range_meta(i, j)
range_objs.append(range_meta["obj_slice"])
range_bounds.append(range_meta["bounds"])
if one_range:
return range_objs[0]
if one_set:
keys = split_labels
elif one_split:
keys = set_labels
else:
keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
if attach_bounds is not None:
keys = _attach_bounds(keys, range_bounds)
_stack_kwargs = merge_dicts(dict(keys=keys), stack_kwargs)
if stack_axis == 0:
return row_stack_merge(range_objs, **_stack_kwargs)
return column_stack_merge(range_objs, **_stack_kwargs)
if isinstance(into, str) and into.lower() == "stacked_by_split":
new_split_objs = []
one_set_bounds = []
for i in range(n_splits):
range_objs = []
range_bounds = []
for j in range(n_sets):
range_meta = _get_range_meta(i, j)
range_objs.append(range_meta["obj_slice"])
range_bounds.append(range_meta["bounds"])
if one_set and squeeze_one_set:
new_split_objs.append(range_objs[0])
one_set_bounds.append(range_bounds[0])
else:
keys = set_labels
if attach_bounds is not None:
keys = _attach_bounds(keys, range_bounds)
_stack_kwargs = merge_dicts(dict(keys=keys), stack_kwargs)
if stack_axis == 0:
new_split_objs.append(row_stack_merge(range_objs, **_stack_kwargs))
else:
new_split_objs.append(column_stack_merge(range_objs, **_stack_kwargs))
if one_split and squeeze_one_split:
return new_split_objs[0]
if one_set and squeeze_one_set:
if attach_bounds is not None:
return pd.Series(new_split_objs, index=_attach_bounds(split_labels, one_set_bounds), dtype=object)
return pd.Series(new_split_objs, index=split_labels, dtype=object)
if isinstance(into, str) and into.lower() == "stacked_by_set":
new_set_objs = []
one_split_bounds = []
for j in range(n_sets):
range_objs = []
range_bounds = []
for i in range(n_splits):
range_meta = _get_range_meta(i, j)
range_objs.append(range_meta["obj_slice"])
range_bounds.append(range_meta["bounds"])
if one_split and squeeze_one_split:
new_set_objs.append(range_objs[0])
one_split_bounds.append(range_bounds[0])
else:
keys = split_labels
if attach_bounds:
keys = _attach_bounds(keys, range_bounds)
_stack_kwargs = merge_dicts(dict(keys=keys), stack_kwargs)
if stack_axis == 0:
new_set_objs.append(row_stack_merge(range_objs, **_stack_kwargs))
else:
new_set_objs.append(column_stack_merge(range_objs, **_stack_kwargs))
if one_set and squeeze_one_set:
return new_set_objs[0]
if one_split and squeeze_one_split:
if attach_bounds is not None:
return pd.Series(new_set_objs, index=_attach_bounds(set_labels, one_split_bounds), dtype=object)
return pd.Series(new_set_objs, index=set_labels, dtype=object)
raise ValueError(f"Invalid into: '{into}'")
# ############# Applying ############# #
@classmethod
def parse_and_inject_takeables(
cls,
flat_ann_args: tp.FlatAnnArgs,
eval_id: tp.Optional[tp.Hashable] = None,
) -> tp.FlatAnnArgs:
"""Parse `Takeable` instances from function annotations and inject them into flattened annotated arguments."""
new_flat_ann_args = dict()
for k, v in flat_ann_args.items():
new_flat_ann_args[k] = v = dict(v)
if "annotation" in v:
if isinstance(v["annotation"], type) and issubclass(v["annotation"], Takeable):
v["annotation"] = v["annotation"]()
if isinstance(v["annotation"], Takeable) and v["annotation"].meets_eval_id(eval_id):
if "value" in v:
if not isinstance(v["value"], Takeable):
v["value"] = v["annotation"].replace(obj=v["value"])
else:
v["value"] = v["value"].merge_over(v["annotation"])
return new_flat_ann_args
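    # Illustrative sketch of annotation-based takeables (behavior inferred from the method above;
    # the function and objects are assumptions for demonstration):
    # >>> def apply_func(data: vbt.Takeable):
    # ...     return data.close.iloc[-1] - data.close.iloc[0]
    # >>> splitter.apply(apply_func, data)
    # Because the parameter is annotated with `Takeable`, the plain `data` argument gets wrapped
    # into a `Takeable` before each range is taken, as done by `parse_and_inject_takeables`.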
def apply(
self,
apply_func: tp.Callable,
*apply_args,
split: tp.Optional[tp.Selection] = None,
set_: tp.Optional[tp.Selection] = None,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
remap_to_obj: bool = True,
obj_index: tp.Optional[tp.IndexLike] = None,
obj_freq: tp.Optional[tp.FrequencyLike] = None,
range_format: str = "slice_or_any",
point_wise: bool = False,
attach_bounds: tp.Union[bool, str] = False,
right_inclusive: bool = False,
template_context: tp.KwargsLike = None,
silence_warnings: bool = False,
index_combine_kwargs: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
iteration: str = "split_wise",
execute_kwargs: tp.KwargsLike = None,
filter_results: bool = True,
raise_no_results: bool = True,
merge_func: tp.Union[None, str, tuple, tp.Callable] = None,
merge_kwargs: tp.KwargsLike = None,
merge_all: bool = True,
wrap_results: bool = True,
eval_id: tp.Optional[tp.Hashable] = None,
**apply_kwargs,
) -> tp.Any:
"""Apply a function on each range.
Uses `Splitter.select_indices` to get the indices for selected splits and sets.
Arguments `split_group_by` and `set_group_by` can be used to group splits and sets respectively.
Ranges belonging to the same split and set group will be merged.
        For each index pair, in a lazy manner, resolves the source range using `Splitter.select_range`
        and `Splitter.get_ready_range`. Then, takes each argument from `apply_args` and `apply_kwargs`
        that is wrapped with `Takeable`, remaps the range into each object's index using `Splitter.get_ready_obj_range`,
and takes the slice from that object using `Splitter.take_range`. The original object will
be substituted by this slice. At the end, substitutes any templates in the prepared
`args` and `kwargs` and saves the function and arguments for execution.
For substitution, the following information is available:
* `split/set_group_indices`: Indices corresponding to the selected row/column groups
* `split/set_indices`: Indices corresponding to the selected rows/columns
* `n_splits/sets`: Number of the selected rows/columns
* `split/set_labels`: Labels corresponding to the selected row/column groups
* `split/set_idx`: Index of the selected row/column
* `split/set_label`: Label of the selected row/column
* `range_`: Selected range ready for indexing (see `Splitter.get_ready_range`)
* `range_meta`: Various information on the selected range
* `obj_range_meta`: Various information on the range taken from each takeable argument.
Positional arguments are denoted by position, keyword arguments are denoted by keys.
        * `apply_args`: Positional arguments with ranges already selected
        * `apply_kwargs`: Keyword arguments with ranges already selected
* `bounds`: A tuple of either integer or index bounds. Can be source or target depending on `attach_bounds`.
* `template_context`: Passed template context
Since each range is processed lazily (that is, upon request), there are multiple iteration
modes controlled by the argument `iteration`:
* 'split_major': Flatten all ranges in split-major order and iterate over them
* 'set_major': Flatten all ranges in set-major order and iterate over them
* 'split_wise': Iterate over splits, while ranges in each split are processed sequentially
* 'set_wise': Iterate over sets, while ranges in each set are processed sequentially
The execution is done using `vectorbtpro.utils.execution.execute` with `execute_kwargs`.
Once all results have been obtained, attempts to merge them using `merge_func` with `merge_kwargs`
(all templates in it will be substituted as well), which can also be a string or a tuple of
strings resolved using `vectorbtpro.base.merging.resolve_merge_func`. If `wrap_results` is enabled,
packs the results into a Pandas object. If `apply_func` returns something complex, the resulting
Pandas object will be of object data type. If `apply_func` returns a tuple (detected by the first
returned result), a Pandas object is built for each element of that tuple.
If `merge_all` is True, will merge all results in a flattened manner irrespective of the
iteration mode. Otherwise, will merge by split/set.
If `vectorbtpro.utils.execution.NoResult` is returned, will skip the current iteration and
remove it from the final index.
Usage:
* Get the return of each data range:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... "BTC-USD",
... start="2020-01-01 UTC",
... end="2021-01-01 UTC"
... )
>>> splitter = vbt.Splitter.from_n_rolling(data.wrapper.index, 5)
>>> def apply_func(data):
... return data.close.iloc[-1] - data.close.iloc[0]
>>> splitter.apply(apply_func, vbt.Takeable(data))
split
0 -1636.467285
1 3706.568359
2 2944.720703
3 -118.113281
4 17098.916016
dtype: float64
```
* The same but by indexing manually:
```pycon
>>> def apply_func(range_, data):
... data = data.iloc[range_]
... return data.close.iloc[-1] - data.close.iloc[0]
>>> splitter.apply(apply_func, vbt.Rep("range_"), data)
split
0 -1636.467285
1 3706.568359
2 2944.720703
3 -118.113281
4 17098.916016
dtype: float64
```
* Divide into two windows, each consisting of 50% train and 50% test, compute SMA for
each range, and row-stack the outputs of each set upon merging:
```pycon
>>> splitter = vbt.Splitter.from_n_rolling(data.wrapper.index, 2, split=0.5)
>>> def apply_func(data):
... return data.run("SMA", 10).real
>>> splitter.apply(
... apply_func,
... vbt.Takeable(data),
... merge_func="row_stack"
... ).unstack("set").vbt.drop_levels("split", axis=0).vbt.plot().show()
```
{: .iimg loading=lazy }
{: .iimg loading=lazy }
"""
if isinstance(attach_bounds, bool):
if attach_bounds:
attach_bounds = "source"
else:
attach_bounds = None
index_bounds = False
if attach_bounds is not None:
if attach_bounds.lower() == "index":
attach_bounds = "source"
index_bounds = True
if attach_bounds.lower() in ("source_index", "target_index"):
attach_bounds = attach_bounds.split("_")[0]
index_bounds = True
if attach_bounds.lower() not in ("source", "target"):
raise ValueError(f"Invalid attach_bounds: '{attach_bounds}'")
if index_combine_kwargs is None:
index_combine_kwargs = {}
if execute_kwargs is None:
execute_kwargs = {}
parsed_merge_func = parse_merge_func(apply_func, eval_id=eval_id)
if parsed_merge_func is not None:
if merge_func is not None:
raise ValueError(
f"Two conflicting merge functions: {parsed_merge_func} (annotations) and {merge_func} (merge_func)"
)
merge_func = parsed_merge_func
if merge_kwargs is None:
merge_kwargs = {}
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
split_group_indices, set_group_indices, split_indices, set_indices = self.select_indices(
split=split,
set_=set_,
split_group_by=split_group_by,
set_group_by=set_group_by,
)
if split is not None:
split_labels = split_labels[split_group_indices]
if set_ is not None:
set_labels = set_labels[set_group_indices]
n_splits = len(split_group_indices)
n_sets = len(set_group_indices)
one_split = n_splits == 1 and squeeze_one_split
one_set = n_sets == 1 and squeeze_one_set
one_range = one_split and one_set
template_context = merge_dicts(
{
"splitter": self,
"index": self.index,
"split_group_indices": split_group_indices,
"set_group_indices": set_group_indices,
"split_indices": split_indices,
"set_indices": set_indices,
"n_splits": n_splits,
"n_sets": n_sets,
"split_labels": split_labels,
"set_labels": set_labels,
"one_split": one_split,
"one_set": one_set,
"one_range": one_range,
},
template_context,
)
template_context["eval_id"] = eval_id
if has_annotatables(apply_func):
ann_args = annotate_args(
apply_func,
apply_args,
apply_kwargs,
attach_annotations=True,
)
flat_ann_args = flatten_ann_args(ann_args)
flat_ann_args = self.parse_and_inject_takeables(flat_ann_args, eval_id=eval_id)
ann_args = unflatten_ann_args(flat_ann_args)
apply_args, apply_kwargs = ann_args_to_args(ann_args)
def _get_range_meta(i, j, _template_context):
split_idx = split_group_indices[i]
set_idx = set_group_indices[j]
range_ = self.select_range(
split=PosSel(split_idx),
set_=PosSel(set_idx),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=dict(template_context=_template_context),
)
range_meta = self.get_ready_range(
range_,
range_format=range_format,
template_context=_template_context,
return_meta=True,
)
return range_meta
def _take_args(args, range_, _template_context):
obj_meta = {}
obj_range_meta = {}
new_args = ()
if args is not None:
for i, v in enumerate(args):
if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
_obj_meta, _obj_range_meta, obj_slice = self.take_range_from_takeable(
v,
range_,
remap_to_obj=remap_to_obj,
obj_index=obj_index,
obj_freq=obj_freq,
range_format=range_format,
point_wise=point_wise,
template_context=_template_context,
silence_warnings=silence_warnings,
freq=freq,
return_obj_meta=True,
return_meta=True,
)
new_args += (obj_slice,)
obj_meta[i] = _obj_meta
obj_range_meta[i] = _obj_range_meta
else:
new_args += (v,)
return obj_meta, obj_range_meta, new_args
def _take_kwargs(kwargs, range_, _template_context):
obj_meta = {}
obj_range_meta = {}
new_kwargs = {}
if kwargs is not None:
for k, v in kwargs.items():
if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
_obj_meta, _obj_range_meta, obj_slice = self.take_range_from_takeable(
v,
range_,
remap_to_obj=remap_to_obj,
obj_index=obj_index,
obj_freq=obj_freq,
range_format=range_format,
point_wise=point_wise,
template_context=_template_context,
silence_warnings=silence_warnings,
freq=freq,
return_obj_meta=True,
return_meta=True,
)
new_kwargs[k] = obj_slice
obj_meta[k] = _obj_meta
obj_range_meta[k] = _obj_range_meta
else:
new_kwargs[k] = v
return obj_meta, obj_range_meta, new_kwargs
def _get_bounds(range_meta, _template_context):
if attach_bounds is not None:
if isinstance(attach_bounds, str) and attach_bounds.lower() == "source":
if index_bounds:
bounds = self.map_bounds_to_index(
range_meta["start"],
range_meta["stop"],
right_inclusive=right_inclusive,
freq=freq,
)
else:
if right_inclusive:
bounds = (range_meta["start"], range_meta["stop"] - 1)
else:
bounds = (range_meta["start"], range_meta["stop"])
else:
obj_meta, obj_range_meta = self.get_ready_obj_range(
self.index,
range_meta["range_"],
remap_to_obj=remap_to_obj,
obj_index=obj_index,
obj_freq=obj_freq,
range_format=range_format,
template_context=_template_context,
silence_warnings=silence_warnings,
freq=freq,
return_obj_meta=True,
return_meta=True,
)
if index_bounds:
bounds = self.map_bounds_to_index(
obj_range_meta["start"],
obj_range_meta["stop"],
right_inclusive=right_inclusive,
index=obj_meta["index"],
freq=obj_meta["freq"],
)
else:
if right_inclusive:
bounds = (
obj_range_meta["start"],
obj_range_meta["stop"] - 1,
)
else:
bounds = (
obj_range_meta["start"],
obj_range_meta["stop"],
)
else:
bounds = (None, None)
return bounds
bounds = {}
def _get_task(i, j, _bounds=bounds):
split_idx = split_group_indices[i]
set_idx = set_group_indices[j]
_template_context = merge_dicts(
{
"split_idx": split_idx,
"split_label": split_labels[i],
"set_idx": set_idx,
"set_label": set_labels[j],
},
template_context,
)
range_meta = _get_range_meta(i, j, _template_context)
_template_context = merge_dicts(
dict(range_=range_meta["range_"], range_meta=range_meta),
_template_context,
)
obj_meta1, obj_range_meta1, _apply_args = _take_args(apply_args, range_meta["range_"], _template_context)
obj_meta2, obj_range_meta2, _apply_kwargs = _take_kwargs(
apply_kwargs, range_meta["range_"], _template_context
)
obj_meta = {**obj_meta1, **obj_meta2}
obj_range_meta = {**obj_range_meta1, **obj_range_meta2}
_bounds[(i, j)] = _get_bounds(range_meta, _template_context)
_template_context = merge_dicts(
dict(
obj_meta=obj_meta,
obj_range_meta=obj_range_meta,
apply_args=_apply_args,
apply_kwargs=_apply_kwargs,
bounds=_bounds[(i, j)],
),
_template_context,
)
_apply_func = substitute_templates(apply_func, _template_context, eval_id="apply_func")
_apply_args = substitute_templates(_apply_args, _template_context, eval_id="apply_args")
_apply_kwargs = substitute_templates(_apply_kwargs, _template_context, eval_id="apply_kwargs")
return Task(_apply_func, *_apply_args, **_apply_kwargs)
def _attach_bounds(keys, range_bounds):
range_bounds = pd.MultiIndex.from_tuples(range_bounds, names=["start", "end"])
if keys is None:
return range_bounds
clean_index_kwargs = dict(index_combine_kwargs)
clean_index_kwargs.pop("ignore_ranges", None)
return stack_indexes((keys, range_bounds), **clean_index_kwargs)
if iteration.lower() == "split_major":
def _get_task_generator():
for i in range(n_splits):
for j in range(n_sets):
yield _get_task(i, j)
tasks = _get_task_generator()
keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
if eval_id is not None:
new_keys = []
for key in keys:
if isinstance(keys, pd.MultiIndex):
new_keys.append((MISSING, *key))
else:
new_keys.append((MISSING, key))
keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names))
execute_kwargs = merge_dicts(dict(show_progress=False if one_split and one_set else None), execute_kwargs)
results = execute(tasks, size=n_splits * n_sets, keys=keys, **execute_kwargs)
elif iteration.lower() == "set_major":
def _get_task_generator():
for j in range(n_sets):
for i in range(n_splits):
yield _get_task(i, j)
tasks = _get_task_generator()
keys = combine_indexes((set_labels, split_labels), **index_combine_kwargs)
if eval_id is not None:
new_keys = []
for key in keys:
if isinstance(keys, pd.MultiIndex):
new_keys.append((MISSING, *key))
else:
new_keys.append((MISSING, key))
keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names))
execute_kwargs = merge_dicts(dict(show_progress=False if one_split and one_set else None), execute_kwargs)
results = execute(tasks, size=n_splits * n_sets, keys=keys, **execute_kwargs)
elif iteration.lower() == "split_wise":
def _process_chunk_tasks(chunk_tasks):
results = []
for func, args, kwargs in chunk_tasks:
results.append(func(*args, **kwargs))
return results
def _get_task_generator():
for i in range(n_splits):
chunk_tasks = []
for j in range(n_sets):
chunk_tasks.append(_get_task(i, j))
yield Task(_process_chunk_tasks, chunk_tasks)
tasks = _get_task_generator()
keys = split_labels
if eval_id is not None:
new_keys = []
for key in keys:
if isinstance(keys, pd.MultiIndex):
new_keys.append((MISSING, *key))
else:
new_keys.append((MISSING, key))
keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names))
execute_kwargs = merge_dicts(dict(show_progress=False if one_split else None), execute_kwargs)
results = execute(tasks, size=n_splits, keys=keys, **execute_kwargs)
elif iteration.lower() == "set_wise":
def _process_chunk_tasks(chunk_tasks):
results = []
for func, args, kwargs in chunk_tasks:
results.append(func(*args, **kwargs))
return results
def _get_task_generator():
for j in range(n_sets):
chunk_tasks = []
for i in range(n_splits):
chunk_tasks.append(_get_task(i, j))
yield Task(_process_chunk_tasks, chunk_tasks)
tasks = _get_task_generator()
keys = set_labels
if eval_id is not None:
new_keys = []
for key in keys:
if isinstance(keys, pd.MultiIndex):
new_keys.append((MISSING, *key))
else:
new_keys.append((MISSING, key))
keys = pd.MultiIndex.from_tuples(new_keys, names=(f"eval_id={eval_id}", *keys.names))
execute_kwargs = merge_dicts(dict(show_progress=False if one_set else None), execute_kwargs)
results = execute(tasks, size=n_sets, keys=keys, **execute_kwargs)
else:
raise ValueError(f"Invalid iteration: '{iteration}'")
if merge_all:
if iteration.lower() in ("split_wise", "set_wise"):
results = [result for _results in results for result in _results]
if one_range:
if results[0] is NoResult:
if raise_no_results:
raise NoResultsException
return NoResult
return results[0]
if iteration.lower() in ("split_major", "split_wise"):
if one_set:
keys = split_labels
elif one_split:
keys = set_labels
else:
keys = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
if attach_bounds is not None:
range_bounds = []
for i in range(n_splits):
for j in range(n_sets):
range_bounds.append(bounds[(i, j)])
keys = _attach_bounds(keys, range_bounds)
else:
if one_set:
keys = split_labels
elif one_split:
keys = set_labels
else:
keys = combine_indexes((set_labels, split_labels), **index_combine_kwargs)
if attach_bounds is not None:
range_bounds = []
for j in range(n_sets):
for i in range(n_splits):
range_bounds.append(bounds[(i, j)])
keys = _attach_bounds(keys, range_bounds)
if filter_results:
try:
results, keys = filter_out_no_results(results, keys=keys)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
no_results_filtered = True
else:
no_results_filtered = False
def _wrap_output(_results):
try:
return pd.Series(_results, index=keys)
except Exception as e:
return pd.Series(_results, index=keys, dtype=object)
if merge_func is not None:
template_context["tasks"] = tasks
template_context["keys"] = keys
if is_merge_func_from_config(merge_func):
merge_kwargs = merge_dicts(
dict(
keys=keys,
filter_results=not no_results_filtered,
raise_no_results=raise_no_results,
),
merge_kwargs,
)
if isinstance(merge_func, MergeFunc):
merge_func = merge_func.replace(
merge_kwargs=merge_kwargs,
context=template_context,
)
else:
merge_func = MergeFunc(
merge_func,
merge_kwargs=merge_kwargs,
context=template_context,
)
return merge_func(results)
if wrap_results:
if isinstance(results[0], tuple):
return tuple(map(_wrap_output, zip(*results)))
return _wrap_output(results)
return results
if iteration.lower() == "split_major":
new_results = []
for i in range(n_splits):
new_results.append(results[i * n_sets : (i + 1) * n_sets])
results = new_results
elif iteration.lower() == "set_major":
new_results = []
for i in range(n_sets):
new_results.append(results[i * n_splits : (i + 1) * n_splits])
results = new_results
if one_range:
if results[0][0] is NoResult:
if raise_no_results:
raise NoResultsException
return NoResult
return results[0][0]
split_bounds = []
if attach_bounds is not None:
for i in range(n_splits):
split_bounds.append([])
for j in range(n_sets):
split_bounds[-1].append(bounds[(i, j)])
set_bounds = []
if attach_bounds is not None:
for j in range(n_sets):
set_bounds.append([])
for i in range(n_splits):
set_bounds[-1].append(bounds[(i, j)])
if iteration.lower() in ("split_major", "split_wise"):
major_keys = split_labels
minor_keys = set_labels
major_bounds = split_bounds
minor_bounds = set_bounds
one_major = one_split
one_minor = one_set
else:
major_keys = set_labels
minor_keys = split_labels
major_bounds = set_bounds
minor_bounds = split_bounds
one_major = one_set
one_minor = one_split
if merge_func is not None:
merged_results = []
keep_major_indices = []
for i, _results in enumerate(results):
if one_minor:
if _results[0] is not NoResult:
merged_results.append(_results[0])
keep_major_indices.append(i)
else:
_template_context = dict(template_context)
_template_context["tasks"] = tasks
if attach_bounds is not None:
minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[i])
else:
minor_keys_wbounds = minor_keys
if filter_results:
_results, minor_keys_wbounds = filter_out_no_results(
_results,
keys=minor_keys_wbounds,
raise_error=False,
)
no_results_filtered = True
else:
no_results_filtered = False
if len(_results) > 0:
_template_context["keys"] = minor_keys_wbounds
if is_merge_func_from_config(merge_func):
_merge_kwargs = merge_dicts(
dict(
keys=minor_keys_wbounds,
filter_results=not no_results_filtered,
raise_no_results=False,
),
merge_kwargs,
)
else:
_merge_kwargs = merge_kwargs
if isinstance(merge_func, MergeFunc):
_merge_func = merge_func.replace(
merge_kwargs=_merge_kwargs,
context=_template_context,
)
else:
_merge_func = MergeFunc(
merge_func,
merge_kwargs=_merge_kwargs,
context=_template_context,
)
_result = _merge_func(_results)
if _result is not NoResult:
merged_results.append(_result)
keep_major_indices.append(i)
if len(merged_results) == 0:
if raise_no_results:
raise NoResultsException
return NoResult
if len(merged_results) < len(major_keys):
major_keys = major_keys[keep_major_indices]
if one_major:
return merged_results[0]
if wrap_results:
def _wrap_output(_results):
try:
return pd.Series(_results, index=major_keys)
except Exception as e:
return pd.Series(_results, index=major_keys, dtype=object)
if isinstance(merged_results[0], tuple):
return tuple(map(_wrap_output, zip(*merged_results)))
return _wrap_output(merged_results)
return merged_results
if one_major:
results = results[0]
elif one_minor:
results = [_results[0] for _results in results]
if wrap_results:
def _wrap_output(_results):
if one_minor:
if attach_bounds is not None:
major_keys_wbounds = _attach_bounds(major_keys, minor_bounds[0])
else:
major_keys_wbounds = major_keys
if filter_results:
try:
_results, major_keys_wbounds = filter_out_no_results(_results, keys=major_keys_wbounds)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
try:
return pd.Series(_results, index=major_keys_wbounds)
except Exception as e:
return pd.Series(_results, index=major_keys_wbounds, dtype=object)
if one_major:
if attach_bounds is not None:
minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[0])
else:
minor_keys_wbounds = minor_keys
if filter_results:
try:
_results, minor_keys_wbounds = filter_out_no_results(_results, keys=minor_keys_wbounds)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
try:
return pd.Series(_results, index=minor_keys_wbounds)
except Exception as e:
return pd.Series(_results, index=minor_keys_wbounds, dtype=object)
new_results = []
keep_major_indices = []
for i, r in enumerate(_results):
if attach_bounds is not None:
minor_keys_wbounds = _attach_bounds(minor_keys, major_bounds[i])
else:
minor_keys_wbounds = minor_keys
if filter_results:
r, minor_keys_wbounds = filter_out_no_results(
r,
keys=minor_keys_wbounds,
raise_error=False,
)
if len(r) > 0:
try:
new_r = pd.Series(r, index=minor_keys_wbounds)
except Exception as e:
new_r = pd.Series(r, index=minor_keys_wbounds, dtype=object)
new_results.append(new_r)
keep_major_indices.append(i)
if len(new_results) == 0:
if raise_no_results:
raise NoResultsException
return NoResult
if len(new_results) < len(major_keys):
_major_keys = major_keys[keep_major_indices]
else:
_major_keys = major_keys
try:
return pd.Series(new_results, index=_major_keys)
except Exception as e:
return pd.Series(new_results, index=_major_keys, dtype=object)
if one_major or one_minor:
n_results = 1
for r in results:
if isinstance(r, tuple):
n_results = len(r)
break
if n_results > 1:
new_results = []
for k in range(n_results):
new_results.append([])
for i in range(len(results)):
if results[i] is NoResult:
new_results[-1].append(results[i])
else:
new_results[-1].append(results[i][k])
return tuple(map(_wrap_output, new_results))
else:
n_results = 1
for r in results:
for _r in r:
if isinstance(_r, tuple):
n_results = len(_r)
break
if n_results > 1:
break
if n_results > 1:
new_results = []
for k in range(n_results):
new_results.append([])
for i in range(len(results)):
new_results[-1].append([])
for j in range(len(results[0])):
if results[i][j] is NoResult:
new_results[-1][-1].append(results[i][j])
else:
new_results[-1][-1].append(results[i][j][k])
return tuple(map(_wrap_output, new_results))
return _wrap_output(results)
if filter_results:
try:
results = filter_out_no_results(results)
except NoResultsException as e:
if raise_no_results:
raise e
return NoResult
return results
# ############# Splits ############# #
def shuffle_splits(
self: SplitterT,
size: tp.Union[None, str, int] = None,
replace: bool = False,
p: tp.Optional[tp.Array1d] = None,
seed: tp.Optional[int] = None,
wrapper_kwargs: tp.KwargsLike = None,
**init_kwargs,
) -> SplitterT:
"""Shuffle splits."""
if wrapper_kwargs is None:
wrapper_kwargs = {}
rng = np.random.default_rng(seed=seed)
if size is None:
size = self.n_splits
new_split_indices = rng.choice(np.arange(self.n_splits), size=size, replace=replace, p=p)
new_splits_arr = self.splits_arr[new_split_indices]
new_index = self.wrapper.index[new_split_indices]
if "index" not in wrapper_kwargs:
wrapper_kwargs["index"] = new_index
new_wrapper = self.wrapper.replace(**wrapper_kwargs)
return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
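# A minimal usage sketch for `shuffle_splits` (assumed setup, not part of the original file):
# shuffle the order of splits, optionally sampling only a subset of them.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=5)
# >>> shuffled = splitter.shuffle_splits(seed=42)           # same splits, random order
# >>> sampled = splitter.shuffle_splits(size=3, seed=42)    # pick 3 splits without replacement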
def break_up_splits(
self: SplitterT,
new_split: tp.SplitLike,
sort: bool = False,
template_context: tp.KwargsLike = None,
wrapper_kwargs: tp.KwargsLike = None,
init_kwargs: tp.KwargsLike = None,
**split_range_kwargs,
) -> SplitterT:
"""Split each split into multiple splits.
If there are multiple sets, make sure to merge them into one beforehand.
Arguments `new_split` and `**split_range_kwargs` are passed to `Splitter.split_range`."""
if self.n_sets > 1:
raise ValueError("Cannot break up splits with more than one set. Merge sets first.")
if wrapper_kwargs is None:
wrapper_kwargs = {}
if init_kwargs is None:
init_kwargs = {}
split_range_kwargs = dict(split_range_kwargs)
wrap_with_fixrange = split_range_kwargs.pop("wrap_with_fixrange", None)
if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
raise ValueError("Argument wrap_with_fixrange must be True or None")
split_range_kwargs["wrap_with_fixrange"] = wrap_with_fixrange
new_splits_arr = []
new_index = []
range_starts = []
for i, split in enumerate(self.splits_arr):
new_ranges = self.split_range(split[0], new_split, template_context=template_context, **split_range_kwargs)
for j, range_ in enumerate(new_ranges):
if sort:
range_starts.append(self.get_range_bounds(range_, template_context=template_context)[0])
new_splits_arr.append([range_])
if isinstance(self.split_labels, pd.MultiIndex):
new_index.append((*self.split_labels[i], j))
else:
new_index.append((self.split_labels[i], j))
new_splits_arr = np.asarray(new_splits_arr, dtype=object)
new_index = pd.MultiIndex.from_tuples(new_index, names=[*self.split_labels.names, "split_part"])
if sort:
sorted_indices = np.argsort(range_starts)
new_splits_arr = new_splits_arr[sorted_indices]
new_index = new_index[sorted_indices]
if "index" not in wrapper_kwargs:
wrapper_kwargs["index"] = new_index
new_wrapper = self.wrapper.replace(**wrapper_kwargs)
return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
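# A hedged sketch for `break_up_splits` (assumed setup, not part of the original file):
# each single-set split is divided by `Splitter.split_range`; here `0.5` is assumed to cut
# each range into two halves, doubling the number of splits.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3)
# >>> finer = splitter.break_up_splits(0.5, sort=True)
# >>> finer.n_splits   # expected to be 6 under the assumption above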
# ############# Sets ############# #
def split_set(
self: SplitterT,
new_split: tp.SplitLike,
column: tp.Optional[tp.Hashable] = None,
new_set_labels: tp.Optional[tp.Sequence[tp.Hashable]] = None,
wrapper_kwargs: tp.KwargsLike = None,
init_kwargs: tp.KwargsLike = None,
**split_range_kwargs,
) -> SplitterT:
"""Split a set (column) into multiple sets (columns).
Arguments `new_split` and `**split_range_kwargs` are passed to `Splitter.split_range`.
Column must be provided if there are two or more sets.
Use `new_set_labels` to specify the labels of the new sets; it must have the same length
as the number of new ranges in the new split. To provide the final labels directly, define
`columns` in `wrapper_kwargs`."""
if self.n_sets == 0:
raise ValueError("There are no sets to split")
if self.n_sets > 1:
if column is None:
raise ValueError("Must provide column for multiple sets")
if not isinstance(column, int):
column = self.set_labels.get_indexer([column])[0]
if column == -1:
raise ValueError(f"Column '{column}' not found")
else:
column = 0
if wrapper_kwargs is None:
wrapper_kwargs = {}
if init_kwargs is None:
init_kwargs = {}
split_range_kwargs = dict(split_range_kwargs)
wrap_with_fixrange = split_range_kwargs.pop("wrap_with_fixrange", None)
if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
raise ValueError("Argument wrap_with_fixrange must be True or None")
split_range_kwargs["wrap_with_fixrange"] = wrap_with_fixrange
new_splits_arr = []
for split in self.splits_arr:
new_ranges = self.split_range(split[column], new_split, **split_range_kwargs)
new_splits_arr.append([*split[:column], *new_ranges, *split[column + 1 :]])
new_splits_arr = np.asarray(new_splits_arr, dtype=object)
if "columns" not in wrapper_kwargs:
wrapper_kwargs = dict(wrapper_kwargs)
n_new_sets = new_splits_arr.shape[1] - self.n_sets + 1
if new_set_labels is None:
old_set_label = self.set_labels[column]
if isinstance(old_set_label, str) and len(old_set_label.split("+")) == n_new_sets:
new_set_labels = old_set_label.split("+")
else:
new_set_labels = [str(old_set_label) + "/" + str(i) for i in range(n_new_sets)]
if len(new_set_labels) != n_new_sets:
raise ValueError(f"Argument new_set_labels must have length {n_new_sets}, not {len(new_set_labels)}")
new_columns = self.set_labels.copy()
new_columns = new_columns.delete(column)
new_columns = new_columns.insert(column, new_set_labels)
wrapper_kwargs["columns"] = new_columns
new_wrapper = self.wrapper.replace(**wrapper_kwargs)
return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
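# A minimal sketch for `split_set` (assumed setup, not part of the original file): split the
# only set of each split into an in-sample and an out-of-sample part. The labels "IS" and "OOS"
# are arbitrary choices for this example.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3)
# >>> splitter = splitter.split_set(0.7, new_set_labels=["IS", "OOS"])
# >>> splitter.n_sets   # 2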
def merge_sets(
self: SplitterT,
columns: tp.Optional[tp.Iterable[tp.Hashable]] = None,
new_set_label: tp.Optional[tp.Hashable] = None,
insert_at_last: bool = False,
wrapper_kwargs: tp.KwargsLike = None,
init_kwargs: tp.KwargsLike = None,
**merge_split_kwargs,
) -> SplitterT:
"""Merge multiple sets (columns) into a set (column).
Arguments `**merge_split_kwargs` are passed to `Splitter.merge_split`.
If `columns` is not provided, all columns are merged. If provided and `insert_at_last` is True,
the new column is inserted at the position of the last merged column; otherwise, at the position
of the first one.
Use `new_set_label` to specify the label of the new set. To provide the final labels directly,
define `columns` in `wrapper_kwargs`."""
if self.n_sets < 2:
raise ValueError("There are no sets to merge")
if columns is None:
columns = range(len(self.set_labels))
new_columns = []
for column in columns:
if not isinstance(column, int):
column = self.set_labels.get_indexer([column])[0]
if column == -1:
raise ValueError(f"Column '{column}' not found")
new_columns.append(column)
columns = sorted(new_columns)
if wrapper_kwargs is None:
wrapper_kwargs = {}
if init_kwargs is None:
init_kwargs = {}
merge_split_kwargs = dict(merge_split_kwargs)
wrap_with_fixrange = merge_split_kwargs.pop("wrap_with_fixrange", None)
if isinstance(wrap_with_fixrange, bool) and not wrap_with_fixrange:
raise ValueError("Argument wrap_with_fixrange must be True or None")
merge_split_kwargs["wrap_with_fixrange"] = wrap_with_fixrange
new_splits_arr = []
for split in self.splits_arr:
split_to_merge = []
for j, range_ in enumerate(split):
if j in columns:
split_to_merge.append(range_)
new_range = self.merge_split(split_to_merge, **merge_split_kwargs)
new_split = []
for j in range(self.n_sets):
if j not in columns:
new_split.append(split[j])
else:
if insert_at_last:
if j == columns[-1]:
new_split.append(new_range)
else:
if j == columns[0]:
new_split.append(new_range)
new_splits_arr.append(new_split)
new_splits_arr = np.asarray(new_splits_arr, dtype=object)
if "columns" not in wrapper_kwargs:
wrapper_kwargs = dict(wrapper_kwargs)
if new_set_label is None:
old_set_labels = self.set_labels[columns]
can_aggregate = True
prefix = None
suffix = None
for i, old_set_label in enumerate(old_set_labels):
if not isinstance(old_set_label, str):
can_aggregate = False
break
_prefix = "/".join(old_set_label.split("/")[:-1])
_suffix = old_set_label.split("/")[-1]
if not _suffix.isdigit():
can_aggregate = False
break
_suffix = int(_suffix)
if prefix is None:
prefix = _prefix
suffix = _suffix
continue
if suffix != 0:
can_aggregate = False
break
if _prefix != prefix or _suffix != i:
can_aggregate = False
break
if can_aggregate and prefix + "/%d" % len(old_set_labels) not in self.set_labels:
new_set_label = prefix
else:
new_set_label = "+".join(map(str, old_set_labels))
new_columns = self.set_labels.copy()
new_columns = new_columns.delete(columns)
if insert_at_last:
new_columns = new_columns.insert(columns[-1] - len(columns) + 1, new_set_label)
else:
new_columns = new_columns.insert(columns[0], new_set_label)
wrapper_kwargs["columns"] = new_columns
if "ndim" not in wrapper_kwargs:
if len(wrapper_kwargs["columns"]) == 1:
wrapper_kwargs["ndim"] = 1
new_wrapper = self.wrapper.replace(**wrapper_kwargs)
return self.replace(wrapper=new_wrapper, splits_arr=new_splits_arr, **init_kwargs)
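# A minimal sketch for `merge_sets` (assumed setup, not part of the original file): merge all
# sets of a two-set splitter back into a single one.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3, split=0.5)   # two sets per split
# >>> merged = splitter.merge_sets(new_set_label="full")
# >>> merged.n_sets   # 1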
# ############# Bounds ############# #
@hybrid_method
def map_bounds_to_index(
cls_or_self,
start: int,
stop: int,
right_inclusive: bool = False,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.Tuple[tp.Any, tp.Any]:
"""Map bounds to index."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
if right_inclusive:
return index[start], index[stop - 1]
if stop == len(index):
freq = BaseIDXAccessor(index, freq=freq).any_freq
if freq is None:
raise ValueError("Must provide freq")
return index[start], index[stop - 1] + freq
return index[start], index[stop]
@hybrid_method
def get_range_bounds(
cls_or_self,
range_: tp.FixRangeLike,
index_bounds: bool = False,
right_inclusive: bool = False,
check_constant: bool = True,
template_context: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
) -> tp.Tuple[tp.Any, tp.Any]:
"""Get the left (inclusive) and right (exclusive) bound of a range.
!!! note
Even when mapped to the index, the right bound is always exclusive."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
range_meta = cls_or_self.get_ready_range(
range_,
template_context=template_context,
index=index,
return_meta=True,
)
if check_constant and not range_meta["is_constant"]:
raise ValueError("Range is not constant")
if index_bounds:
range_meta["start"], range_meta["stop"] = cls_or_self.map_bounds_to_index(
range_meta["start"],
range_meta["stop"],
right_inclusive=right_inclusive,
index=index,
freq=freq,
)
else:
if right_inclusive:
range_meta["stop"] = range_meta["stop"] - 1
return range_meta["start"], range_meta["stop"]
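# A minimal sketch for `get_range_bounds` (assumed, not part of the original file): measure a
# fixed range against an index; a slice is assumed to be a valid fixed range here.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> vbt.Splitter.get_range_bounds(slice(0, 30), index=index)
# -> (0, 30): left bound inclusive, right bound exclusive
# >>> vbt.Splitter.get_range_bounds(slice(0, 30), index=index, index_bounds=True)
# -> the same bounds mapped to index values (start timestamp, right-exclusive end timestamp)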
def get_bounds_arr(
self,
index_bounds: bool = False,
right_inclusive: bool = False,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
template_context: tp.KwargsLike = None,
**range_bounds_kwargs,
) -> tp.BoundsArray:
"""Three-dimensional integer array with bounds.
First axis represents splits. Second axis represents sets. Third axis represents bounds.
Each range is getting selected using `Splitter.select_range` and then measured using
`Splitter.get_range_bounds`. Keyword arguments `**kwargs` are passed to
`Splitter.get_range_bounds`."""
if index_bounds:
dtype = self.index.dtype
else:
dtype = int_
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
n_splits = self.get_n_splits(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
n_sets = self.get_n_sets(set_group_by=set_group_by)
try:
bounds = np.empty((n_splits, n_sets, 2), dtype=dtype)
except TypeError as e:
bounds = np.empty((n_splits, n_sets, 2), dtype=object)
for i in range(n_splits):
for j in range(n_sets):
range_ = self.select_range(
split=PosSel(i),
set_=PosSel(j),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=dict(template_context=template_context),
)
bounds[i, j, :] = self.get_range_bounds(
range_,
index_bounds=index_bounds,
right_inclusive=right_inclusive,
template_context=template_context,
**range_bounds_kwargs,
)
return bounds
@property
def bounds_arr(self) -> tp.BoundsArray:
"""`Splitter.get_bounds_arr` with default arguments."""
return self.get_bounds_arr()
def get_bounds(
self,
index_bounds: bool = False,
right_inclusive: bool = False,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
index_combine_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Boolean Series/DataFrame where index are bounds and columns are splits stacked together.
Keyword arguments `**kwargs` are passed to `Splitter.get_bounds_arr`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
bounds_arr = self.get_bounds_arr(
index_bounds=index_bounds,
right_inclusive=right_inclusive,
split_group_by=split_group_by,
set_group_by=set_group_by,
**kwargs,
)
out = bounds_arr.reshape((-1, 2))
one_split = len(split_labels) == 1 and squeeze_one_split
one_set = len(set_labels) == 1 and squeeze_one_set
new_columns = pd.Index(["start", "end"], name="bound")
if one_split and one_set:
return pd.Series(out[0], index=new_columns)
if one_split:
return pd.DataFrame(out, index=set_labels, columns=new_columns)
if one_set:
return pd.DataFrame(out, index=split_labels, columns=new_columns)
if index_combine_kwargs is None:
index_combine_kwargs = {}
new_index = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
return pd.DataFrame(out, index=new_index, columns=new_columns)
@property
def bounds(self) -> tp.Frame:
"""`Splitter.get_bounds` with default arguments."""
return self.get_bounds()
@property
def index_bounds(self) -> tp.Frame:
"""`Splitter.get_bounds` with `index_bounds=True`."""
return self.get_bounds(index_bounds=True)
def get_duration(self, **kwargs) -> tp.Series:
"""Get duration."""
bounds = self.get_bounds(right_inclusive=False, **kwargs)
return (bounds["end"] - bounds["start"]).rename("duration")
@property
def duration(self) -> tp.Series:
"""`Splitter.get_duration` with default arguments."""
return self.get_duration()
@property
def index_duration(self) -> tp.Series:
"""`Splitter.get_duration` with `index_bounds=True`."""
return self.get_duration(index_bounds=True)
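# A minimal sketch for the bounds and duration queries above (assumed setup, not part of the
# original file):
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3, split=0.5)
# >>> splitter.bounds          # integer positions, one row per (split, set)
# >>> splitter.index_bounds    # the same bounds mapped to index values
# >>> splitter.duration        # right-exclusive end minus start, per range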
# ############# Masks ############# #
@hybrid_method
def get_range_mask(
cls_or_self,
range_: tp.FixRangeLike,
template_context: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
) -> tp.Array1d:
"""Get the mask of a range."""
if index is None:
if isinstance(cls_or_self, type):
raise ValueError("Must provide index")
index = cls_or_self.index
else:
index = dt.prepare_dt_index(index)
range_ = cls_or_self.get_ready_range(
range_,
allow_zero_len=True,
template_context=template_context,
index=index,
)
if isinstance(range_, np.ndarray) and range_.dtype == np.bool_:
return range_
mask = np.full(len(index), False)
mask[range_] = True
return mask
def get_iter_split_mask_arrs(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Iterator[tp.Array2d]:
"""Generator of two-dimensional boolean arrays, one per split.
First axis represents sets. Second axis represents index.
Keyword arguments `**kwargs` are passed to `Splitter.get_range_mask`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
n_splits = self.get_n_splits(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
n_sets = self.get_n_sets(set_group_by=set_group_by)
for i in range(n_splits):
out = np.full((n_sets, len(self.index)), False)
for j in range(n_sets):
range_ = self.select_range(
split=PosSel(i),
set_=PosSel(j),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=dict(template_context=template_context),
)
out[j, :] = self.get_range_mask(range_, template_context=template_context, **kwargs)
yield out
@property
def iter_split_mask_arrs(self) -> tp.Iterator[tp.Array2d]:
"""`Splitter.get_iter_split_mask_arrs` with default arguments."""
return self.get_iter_split_mask_arrs()
def get_iter_set_mask_arrs(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Iterator[tp.Array2d]:
"""Generator of two-dimensional boolean arrays, one per set.
First axis represents splits. Second axis represents index.
Keyword arguments `**kwargs` are passed to `Splitter.get_range_mask`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
n_splits = self.get_n_splits(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
n_sets = self.get_n_sets(set_group_by=set_group_by)
for j in range(n_sets):
out = np.full((n_splits, len(self.index)), False)
for i in range(n_splits):
range_ = self.select_range(
split=PosSel(i),
set_=PosSel(j),
split_group_by=split_group_by,
set_group_by=set_group_by,
merge_split_kwargs=dict(template_context=template_context),
)
out[i, :] = self.get_range_mask(range_, template_context=template_context, **kwargs)
yield out
@property
def iter_set_mask_arrs(self) -> tp.Iterator[tp.Array2d]:
"""`Splitter.get_iter_set_mask_arrs` with default arguments."""
return self.get_iter_set_mask_arrs()
def get_iter_split_masks(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
**kwargs,
) -> tp.Iterator[tp.Frame]:
"""Generator of boolean DataFrames, one per split.
Keyword arguments `**kwargs` are passed to `Splitter.get_iter_split_mask_arrs`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
for mask in self.get_iter_split_mask_arrs(
split_group_by=split_group_by,
set_group_by=set_group_by,
**kwargs,
):
yield pd.DataFrame(np.moveaxis(mask, -1, 0), index=self.index, columns=set_labels)
@property
def iter_split_masks(self) -> tp.Iterator[tp.Frame]:
"""`Splitter.get_iter_split_masks` with default arguments."""
return self.get_iter_split_masks()
def get_iter_set_masks(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
**kwargs,
) -> tp.Iterator[tp.Frame]:
"""Generator of boolean DataFrames, one per set.
Keyword arguments `**kwargs` are passed to `Splitter.get_iter_set_mask_arrs`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
for mask in self.get_iter_set_mask_arrs(
split_group_by=split_group_by,
set_group_by=set_group_by,
**kwargs,
):
yield pd.DataFrame(np.moveaxis(mask, -1, 0), index=self.index, columns=split_labels)
@property
def iter_set_masks(self) -> tp.Iterator[tp.Frame]:
"""`Splitter.get_iter_set_masks` with default arguments."""
return self.get_iter_set_masks()
def get_mask_arr(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.SplitsMask:
"""Three-dimensional boolean array with splits.
First axis represents splits. Second axis represents sets. Third axis represents index.
Keyword arguments `**kwargs` are passed to `Splitter.get_iter_split_mask_arrs`."""
return np.array(
list(
self.get_iter_split_mask_arrs(
split_group_by=split_group_by,
set_group_by=set_group_by,
template_context=template_context,
**kwargs,
)
)
)
@property
def mask_arr(self) -> tp.SplitsMask:
"""`Splitter.get_mask_arr` with default arguments."""
return self.get_mask_arr()
def get_mask(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
index_combine_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Boolean Series/DataFrame where index is `Splitter.index` and columns are splits stacked together.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`.
!!! warning
Boolean arrays for a big number of splits may take a considerable amount of memory."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
out = np.moveaxis(mask_arr, -1, 0).reshape((len(self.index), -1))
one_split = len(split_labels) == 1 and squeeze_one_split
one_set = len(set_labels) == 1 and squeeze_one_set
if one_split and one_set:
return pd.Series(out[:, 0], index=self.index)
if one_split:
return pd.DataFrame(out, index=self.index, columns=set_labels)
if one_set:
return pd.DataFrame(out, index=self.index, columns=split_labels)
if index_combine_kwargs is None:
index_combine_kwargs = {}
new_columns = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
return pd.DataFrame(out, index=self.index, columns=new_columns)
@property
def mask(self) -> tp.Frame:
"""`Splitter.get_mask` with default arguments."""
return self.get_mask()
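# A minimal sketch for `get_mask` (assumed setup, not part of the original file): the full mask
# has `Splitter.index` as rows and one column per (split, set) combination, unless squeezed.
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3, split=0.5)
# >>> mask = splitter.get_mask()
# >>> mask.shape   # (len(index), n_splits * n_sets) = (367, 6) under this setup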
def get_split_coverage(
self,
overlapping: bool = False,
normalize: bool = True,
relative: bool = False,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
**kwargs,
) -> tp.MaybeSeries:
"""Get the coverage of each split mask.
If `overlapping` is True, returns the number of overlapping True values between sets in each split.
If `normalize` is True, returns the number of True values in each split relative to the
length of the index. If `normalize` and `relative` are True, returns the number of True values
in each split relative to the total number of True values across all splits.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
if overlapping:
if normalize:
coverage = (mask_arr.sum(axis=1) > 1).sum(axis=1) / mask_arr.any(axis=1).sum(axis=1)
else:
coverage = (mask_arr.sum(axis=1) > 1).sum(axis=1)
else:
if normalize:
if relative:
coverage = mask_arr.any(axis=1).sum(axis=1) / mask_arr.any(axis=(0, 1)).sum()
else:
coverage = mask_arr.any(axis=1).mean(axis=1)
else:
coverage = mask_arr.any(axis=1).sum(axis=1)
one_split = len(split_labels) == 1 and squeeze_one_split
if one_split:
return coverage[0]
return pd.Series(coverage, index=split_labels, name="split_coverage")
@property
def split_coverage(self) -> tp.Series:
"""`Splitter.get_split_coverage` with default arguments."""
return self.get_split_coverage()
def get_set_coverage(
self,
overlapping: bool = False,
normalize: bool = True,
relative: bool = False,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_set: bool = True,
**kwargs,
) -> tp.MaybeSeries:
"""Get the coverage of each set mask.
If `overlapping` is True, returns the number of overlapping True values between splits in each set.
If `normalize` is True, returns the number of True values in each set relative to the
length of the index. If `normalize` and `relative` are True, returns the number of True values
in each set relative to the total number of True values across all sets.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
if overlapping:
if normalize:
coverage = (mask_arr.sum(axis=0) > 1).sum(axis=1) / mask_arr.any(axis=0).sum(axis=1)
else:
coverage = (mask_arr.sum(axis=0) > 1).sum(axis=1)
else:
if normalize:
if relative:
coverage = mask_arr.any(axis=0).sum(axis=1) / mask_arr.any(axis=(0, 1)).sum()
else:
coverage = mask_arr.any(axis=0).mean(axis=1)
else:
coverage = mask_arr.any(axis=0).sum(axis=1)
one_set = len(set_labels) == 1 and squeeze_one_set
if one_set:
return coverage[0]
return pd.Series(coverage, index=set_labels, name="set_coverage")
@property
def set_coverage(self) -> tp.Series:
"""`Splitter.get_set_coverage` with default arguments."""
return self.get_set_coverage()
def get_range_coverage(
self,
normalize: bool = True,
relative: bool = False,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
index_combine_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get the coverage of each range mask.
If `normalize` is True, returns the number of True values in each range relative to the
length of the index. If `normalize` and `relative` are True, returns the number of True values
in each range relative to the total number of True values in its split.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
if normalize:
if relative:
coverage = (mask_arr.sum(axis=2) / mask_arr.any(axis=1).sum(axis=1)[:, None]).flatten()
else:
coverage = (mask_arr.sum(axis=2) / mask_arr.shape[2]).flatten()
else:
coverage = mask_arr.sum(axis=2).flatten()
one_split = len(split_labels) == 1 and squeeze_one_split
one_set = len(set_labels) == 1 and squeeze_one_set
if one_split and one_set:
return coverage[0]
if one_split:
return pd.Series(coverage, index=set_labels, name="range_coverage")
if one_set:
return pd.Series(coverage, index=split_labels, name="range_coverage")
if index_combine_kwargs is None:
index_combine_kwargs = {}
index = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
return pd.Series(coverage, index=index, name="range_coverage")
@property
def range_coverage(self) -> tp.Series:
"""`Splitter.get_range_coverage` with default arguments."""
return self.get_range_coverage()
def get_coverage(
self,
overlapping: bool = False,
normalize: bool = True,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
**kwargs,
) -> float:
"""Get the coverage of the entire mask.
If `overlapping` is True, returns the number of overlapping True values.
If `normalize` is True, returns the number of True values relative to the length of the index.
If `overlapping` and `normalize` are True, returns the number of overlapping True values relative
to the total number of True values.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`."""
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
if overlapping:
if normalize:
return (mask_arr.sum(axis=(0, 1)) > 1).sum() / mask_arr.any(axis=(0, 1)).sum()
return (mask_arr.sum(axis=(0, 1)) > 1).sum()
if normalize:
return mask_arr.any(axis=(0, 1)).mean()
return mask_arr.any(axis=(0, 1)).sum()
@property
def coverage(self) -> float:
"""`Splitter.get_coverage` with default arguments."""
return self.get_coverage()
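# A minimal sketch for the coverage queries above (assumed setup, not part of the original file):
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3, split=0.5)
# >>> splitter.coverage                        # fraction of the index covered by any range
# >>> splitter.get_coverage(overlapping=True)  # fraction of covered points hit by more than one range
# >>> splitter.split_coverage                  # coverage per split
# >>> splitter.set_coverage                    # coverage per set
# >>> splitter.range_coverage                  # coverage per (split, set) range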
def get_overlap_matrix(
self,
by: str = "split",
normalize: bool = True,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
jitted: tp.JittedOption = None,
squeeze_one_split: bool = True,
squeeze_one_set: bool = True,
index_combine_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Frame:
"""Get the overlap between each pair of ranges.
The argument `by` can be one of 'split', 'set', and 'range'.
If `normalize` is True, returns the number of True values in each overlap relative
to the total number of True values in both ranges.
Keyword arguments `**kwargs` are passed to `Splitter.get_mask_arr`."""
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
split_labels = self.get_split_labels(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
mask_arr = self.get_mask_arr(split_group_by=split_group_by, set_group_by=set_group_by, **kwargs)
one_split = len(split_labels) == 1 and squeeze_one_split
one_set = len(set_labels) == 1 and squeeze_one_set
if by.lower() == "split":
if normalize:
func = jit_reg.resolve_option(nb.norm_split_overlap_matrix_nb, jitted)
else:
func = jit_reg.resolve_option(nb.split_overlap_matrix_nb, jitted)
overlap_matrix = func(mask_arr)
if one_split:
return overlap_matrix[0, 0]
index = split_labels
elif by.lower() == "set":
if normalize:
func = jit_reg.resolve_option(nb.norm_set_overlap_matrix_nb, jitted)
else:
func = jit_reg.resolve_option(nb.set_overlap_matrix_nb, jitted)
overlap_matrix = func(mask_arr)
if one_set:
return overlap_matrix[0, 0]
index = set_labels
elif by.lower() == "range":
if normalize:
func = jit_reg.resolve_option(nb.norm_range_overlap_matrix_nb, jitted)
else:
func = jit_reg.resolve_option(nb.range_overlap_matrix_nb, jitted)
overlap_matrix = func(mask_arr)
if one_split and one_set:
return overlap_matrix[0, 0]
if one_split:
index = set_labels
elif one_set:
index = split_labels
else:
if index_combine_kwargs is None:
index_combine_kwargs = {}
index = combine_indexes((split_labels, set_labels), **index_combine_kwargs)
else:
raise ValueError(f"Invalid by: '{by}'")
return pd.DataFrame(overlap_matrix, index=index, columns=index)
@property
def split_overlap_matrix(self) -> tp.Frame:
"""`Splitter.get_overlap_matrix` with `by="split"`."""
return self.get_overlap_matrix(by="split")
@property
def set_overlap_matrix(self) -> tp.Frame:
"""`Splitter.get_overlap_matrix` with `by="set"`."""
return self.get_overlap_matrix(by="set")
@property
def range_overlap_matrix(self) -> tp.Frame:
"""`Splitter.get_overlap_matrix` with `by="range"`."""
return self.get_overlap_matrix(by="range")
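# A minimal sketch for the overlap matrices above (assumed setup, not part of the original file):
# >>> from vectorbtpro import *
# >>> index = pd.date_range("2020", "2021", freq="D")
# >>> splitter = vbt.Splitter.from_n_rolling(index, n=3, split=0.5)
# >>> splitter.split_overlap_matrix                            # normalized overlap between splits
# >>> splitter.get_overlap_matrix(by="set", normalize=False)   # raw overlap counts between sets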
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Splitter.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.splitter`."""
from vectorbtpro._settings import settings
splitter_stats_cfg = settings["splitter"]["stats"]
return merge_dicts(Analyzable.stats_defaults.__get__(self), splitter_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start=dict(
title="Index Start",
calc_func=lambda self: self.index[0],
agg_func=None,
tags=["splitter", "index"],
),
end=dict(
title="Index End",
calc_func=lambda self: self.index[-1],
agg_func=None,
tags=["splitter", "index"],
),
period=dict(
title="Index Length",
calc_func=lambda self: len(self.index),
agg_func=None,
tags=["splitter", "index"],
),
split_count=dict(
title="Splits",
calc_func="n_splits",
agg_func=None,
tags=["splitter", "splits"],
),
set_count=dict(
title="Sets",
calc_func="n_sets",
agg_func=None,
tags=["splitter", "splits"],
),
coverage=dict(
title=RepFunc(lambda normalize: "Coverage [%]" if normalize else "Coverage"),
calc_func="coverage",
overlapping=False,
post_calc_func=lambda self, out, settings: out * 100 if settings["normalize"] else out,
agg_func=None,
tags=["splitter", "splits", "coverage"],
),
set_coverage=dict(
title=RepFunc(lambda normalize: "Coverage [%]" if normalize else "Coverage"),
check_has_multiple_sets=True,
calc_func="set_coverage",
overlapping=False,
relative=False,
post_calc_func=lambda self, out, settings: to_dict(
out * 100 if settings["normalize"] else out, orient="index_series"
),
agg_func=None,
tags=["splitter", "splits", "coverage"],
),
set_mean_rel_coverage=dict(
title="Mean Rel Coverage [%]",
check_has_multiple_sets=True,
check_normalize=True,
calc_func="range_coverage",
relative=True,
post_calc_func=lambda self, out, settings: to_dict(
out.groupby(self.get_set_labels(set_group_by=settings.get("set_group_by", None)).names).mean()[
self.get_set_labels(set_group_by=settings.get("set_group_by", None))
]
* 100,
orient="index_series",
),
agg_func=None,
tags=["splitter", "splits", "coverage"],
),
overlap_coverage=dict(
title=RepFunc(lambda normalize: "Overlap Coverage [%]" if normalize else "Overlap Coverage"),
calc_func="coverage",
overlapping=True,
post_calc_func=lambda self, out, settings: out * 100 if settings["normalize"] else out,
agg_func=None,
tags=["splitter", "splits", "coverage"],
),
set_overlap_coverage=dict(
title=RepFunc(lambda normalize: "Overlap Coverage [%]" if normalize else "Overlap Coverage"),
check_has_multiple_sets=True,
calc_func="set_coverage",
overlapping=True,
post_calc_func=lambda self, out, settings: to_dict(
out * 100 if settings["normalize"] else out, orient="index_series"
),
agg_func=None,
tags=["splitter", "splits", "coverage"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
mask_kwargs: tp.KwargsLike = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot splits as rows and sets as colors.
Args:
split_group_by (any): Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
set_group_by (any): Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
mask_kwargs (dict): Keyword arguments passed to `Splitter.get_iter_set_masks`.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
Can be a sequence, one per set.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
* Plot a scikit-learn splitter:
```pycon
>>> from vectorbtpro import *
>>> from sklearn.model_selection import TimeSeriesSplit
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_sklearn(index, TimeSeriesSplit())
>>> splitter.plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure
import plotly.express as px
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
if fig.layout.colorway is not None:
colorway = fig.layout.colorway
else:
colorway = fig.layout.template.layout.colorway
if len(set_labels) > len(colorway):
colorway = px.colors.qualitative.Alphabet
if self.get_n_splits(split_group_by=split_group_by) > 0:
if self.get_n_sets(set_group_by=set_group_by) > 0:
if mask_kwargs is None:
mask_kwargs = {}
for i, mask in enumerate(
self.get_iter_set_masks(
split_group_by=split_group_by,
set_group_by=set_group_by,
**mask_kwargs,
)
):
df = mask.vbt.wrapper.fill()
df[mask] = i
color = adjust_opacity(colorway[i % len(colorway)], 0.8)
trace_name = str(set_labels[i])
_trace_kwargs = merge_dicts(
dict(
showscale=False,
showlegend=True,
legendgroup=str(set_labels[i]),
name=trace_name,
colorscale=[color, color],
hovertemplate="Index: %{x}<br>Split: %{y}<br>Set: " + trace_name,
),
resolve_dict(trace_kwargs, i=i),
)
fig = df.vbt.ts_heatmap(
trace_kwargs=_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
is_y_category=True,
fig=fig,
)
return fig
def plot_coverage(
self,
stacked: bool = True,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
mask_kwargs: tp.KwargsLike = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot index as rows and sets as lines.
Args:
stacked (bool): Whether to plot as an area plot.
split_group_by (any): Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
set_group_by (any): Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
mask_kwargs (dict): Keyword arguments passed to `Splitter.get_iter_set_masks`.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
Can be a sequence, one per set.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
* Area plot:
```pycon
>>> from vectorbtpro import *
>>> from sklearn.model_selection import TimeSeriesSplit
>>> index = pd.date_range("2020", "2021", freq="D")
>>> splitter = vbt.Splitter.from_sklearn(index, TimeSeriesSplit())
>>> splitter.plot_coverage().show()
```
* Line plot:
```pycon
>>> splitter.plot_coverage(stacked=False).show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure
import plotly.express as px
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
split_group_by = self.get_split_grouper(split_group_by=split_group_by)
set_group_by = self.get_set_grouper(set_group_by=set_group_by)
set_labels = self.get_set_labels(set_group_by=set_group_by)
if fig.layout.colorway is not None:
colorway = fig.layout.colorway
else:
colorway = fig.layout.template.layout.colorway
if len(set_labels) > len(colorway):
colorway = px.colors.qualitative.Alphabet
if self.get_n_splits(split_group_by=split_group_by) > 0:
if self.get_n_sets(set_group_by=set_group_by) > 0:
if mask_kwargs is None:
mask_kwargs = {}
for i, mask in enumerate(
self.get_iter_set_masks(
split_group_by=split_group_by,
set_group_by=set_group_by,
**mask_kwargs,
)
):
_trace_kwargs = merge_dicts(
dict(
stackgroup="coverage" if stacked else None,
legendgroup=str(set_labels[i]),
name=str(set_labels[i]),
line=dict(color=colorway[i % len(colorway)], shape="hv"),
),
resolve_dict(trace_kwargs, i=i),
)
fig = mask.sum(axis=1).vbt.lineplot(
trace_kwargs=_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Splitter.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.splitter`."""
from vectorbtpro._settings import settings
splitter_plots_cfg = settings["splitter"]["plots"]
return merge_dicts(Analyzable.plots_defaults.__get__(self), splitter_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="Splits",
yaxis_kwargs=dict(title="Split"),
plot_func="plot",
tags="splitter",
),
plot_coverage=dict(
title="Coverage",
yaxis_kwargs=dict(title="Count"),
plot_func="plot_coverage",
tags="splitter",
),
)
)
@property
def subplots(self) -> Config:
return self._subplots
Splitter.override_metrics_doc(__pdoc__)
Splitter.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Decorators for splitting."""
import inspect
from functools import wraps
from vectorbtpro import _typing as tp
from vectorbtpro.generic.splitting.base import Splitter, Takeable
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import FrozenConfig, merge_dicts
from vectorbtpro.utils.execution import NoResult, NoResultsException
from vectorbtpro.utils.params import parameterized
from vectorbtpro.utils.parsing import (
annotate_args,
flatten_ann_args,
unflatten_ann_args,
ann_args_to_args,
match_ann_arg,
get_func_arg_names,
)
from vectorbtpro.utils.template import Rep, RepEval, substitute_templates
__all__ = [
"split",
"cv_split",
]
def split(
*args,
splitter: tp.Union[None, str, Splitter, tp.Callable] = None,
splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
splitter_kwargs: tp.KwargsLike = None,
index: tp.Optional[tp.IndexLike] = None,
index_from: tp.Optional[tp.AnnArgQuery] = None,
takeable_args: tp.Optional[tp.MaybeIterable[tp.AnnArgQuery]] = None,
template_context: tp.KwargsLike = None,
forward_kwargs_as: tp.KwargsLike = None,
return_splitter: bool = False,
apply_kwargs: tp.KwargsLike = None,
**var_kwargs,
) -> tp.Callable:
"""Decorator that splits the inputs of a function.
Does the following:
1. Resolves the splitter of the type `vectorbtpro.generic.splitting.base.Splitter` using
the argument `splitter`. It can be either an already provided splitter instance, the
name of a factory method (such as "from_n_rolling"), or the factory method itself.
If `splitter` is None, the right method will be guessed based on the supplied arguments
using `vectorbtpro.generic.splitting.base.Splitter.guess_method`. To construct a splitter,
it will pass `index` and `**splitter_kwargs`. The index is resolved either from an already
provided `index`, from the argument under the name/position given by `index_from`,
or from the first argument in `takeable_args` (in this order).
2. Wraps the arguments in `takeable_args` with `vectorbtpro.generic.splitting.base.Takeable`.
3. Runs `vectorbtpro.generic.splitting.base.Splitter.apply` with the arguments passed
to the function as `args` and `kwargs`, as well as `apply_kwargs` (the ones passed to
the decorator).
Keyword arguments `splitter_kwargs` are passed to the factory method. Keyword arguments
`apply_kwargs` are passed to `vectorbtpro.generic.splitting.base.Splitter.apply`. If variable
keyword arguments are provided and a splitter instance hasn't been built yet, each of them is
assigned to `splitter_kwargs` and/or `apply_kwargs` based on the signatures of the factory method
and `Splitter.apply`; an error is raised if it matches neither. If a splitter instance is already
provided, variable keyword arguments are used as `apply_kwargs`.
Usage:
* Split a Series and return its sum:
```pycon
>>> from vectorbtpro import *
>>> @vbt.split(
... splitter="from_n_rolling",
... splitter_kwargs=dict(n=2),
... takeable_args=["sr"]
... )
... def f(sr):
... return sr.sum()
>>> index = pd.date_range("2020-01-01", "2020-01-06")
>>> sr = pd.Series(np.arange(len(index)), index=index)
>>> f(sr)
split
0 3
1 12
dtype: int64
```
* Perform a split manually:
```pycon
>>> @vbt.split(
... splitter="from_n_rolling",
... splitter_kwargs=dict(n=2),
... takeable_args=["index"]
... )
... def f(index, sr):
... return sr[index].sum()
>>> f(index, sr)
split
0 3
1 12
dtype: int64
```
* Construct splitter and mark arguments as "takeable" manually:
```pycon
>>> splitter = vbt.Splitter.from_n_rolling(index, n=2)
>>> @vbt.split(splitter=splitter)
... def f(sr):
... return sr.sum()
>>> f(vbt.Takeable(sr))
split
0 3
1 12
dtype: int64
```
* Split multiple timeframes using a custom index:
```pycon
>>> @vbt.split(
... splitter="from_n_rolling",
... splitter_kwargs=dict(n=2),
... index=index,
... takeable_args=["h12_sr", "d2_sr"]
... )
... def f(h12_sr, d2_sr):
... return h12_sr.sum() + d2_sr.sum()
>>> h12_index = pd.date_range("2020-01-01", "2020-01-06", freq="12H")
>>> d2_index = pd.date_range("2020-01-01", "2020-01-06", freq="2D")
>>> h12_sr = pd.Series(np.arange(len(h12_index)), index=h12_index)
>>> d2_sr = pd.Series(np.arange(len(d2_index)), index=d2_index)
>>> f(h12_sr, d2_sr)
split
0 15
1 42
dtype: int64
```
"""
def decorator(func: tp.Callable) -> tp.Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> tp.Any:
splitter = kwargs.pop("_splitter", wrapper.options["splitter"])
splitter_cls = kwargs.pop("_splitter_cls", wrapper.options["splitter_cls"])
splitter_kwargs = merge_dicts(wrapper.options["splitter_kwargs"], kwargs.pop("_splitter_kwargs", {}))
index = kwargs.pop("_index", wrapper.options["index"])
index_from = kwargs.pop("_index_from", wrapper.options["index_from"])
takeable_args = kwargs.pop("_takeable_args", wrapper.options["takeable_args"])
if takeable_args is None:
takeable_args = set()
elif checks.is_iterable(takeable_args) and not isinstance(takeable_args, str):
takeable_args = set(takeable_args)
else:
takeable_args = {takeable_args}
template_context = merge_dicts(wrapper.options["template_context"], kwargs.pop("_template_context", {}))
apply_kwargs = merge_dicts(wrapper.options["apply_kwargs"], kwargs.pop("_apply_kwargs", {}))
return_splitter = kwargs.pop("_return_splitter", wrapper.options["return_splitter"])
forward_kwargs_as = merge_dicts(wrapper.options["forward_kwargs_as"], kwargs.pop("_forward_kwargs_as", {}))
if len(forward_kwargs_as) > 0:
new_kwargs = dict()
for k, v in kwargs.items():
if k in forward_kwargs_as:
new_kwargs[forward_kwargs_as.pop(k)] = v
else:
new_kwargs[k] = v
kwargs = new_kwargs
if len(forward_kwargs_as) > 0:
for k, v in forward_kwargs_as.items():
kwargs[v] = locals()[k]
if splitter_cls is None:
splitter_cls = Splitter
if len(var_kwargs) > 0:
var_splitter_kwargs = {}
var_apply_kwargs = {}
if splitter is None or not isinstance(splitter, splitter_cls):
apply_arg_names = get_func_arg_names(splitter_cls.apply)
if splitter is not None:
if isinstance(splitter, str):
splitter_arg_names = get_func_arg_names(getattr(splitter_cls, splitter))
else:
splitter_arg_names = get_func_arg_names(splitter)
for k, v in var_kwargs.items():
assigned = False
if k in splitter_arg_names:
var_splitter_kwargs[k] = v
assigned = True
if k != "split" and k in apply_arg_names:
var_apply_kwargs[k] = v
assigned = True
if not assigned:
raise ValueError(f"Argument '{k}' couldn't be assigned")
else:
for k, v in var_kwargs.items():
if k == "freq":
var_splitter_kwargs[k] = v
var_apply_kwargs[k] = v
elif k == "split" or k not in apply_arg_names:
var_splitter_kwargs[k] = v
else:
var_apply_kwargs[k] = v
else:
var_apply_kwargs = var_kwargs
splitter_kwargs = merge_dicts(var_splitter_kwargs, splitter_kwargs)
apply_kwargs = merge_dicts(var_apply_kwargs, apply_kwargs)
if len(splitter_kwargs) > 0:
if splitter is None:
splitter = splitter_cls.guess_method(**splitter_kwargs)
if splitter is None:
raise ValueError("Splitter method couldn't be guessed")
else:
if splitter is None:
raise ValueError("Must provide splitter or splitter method")
if not isinstance(splitter, splitter_cls) and index is not None:
if isinstance(splitter, str):
splitter = getattr(splitter_cls, splitter)
splitter = splitter(index, template_context=template_context, **splitter_kwargs)
if return_splitter:
return splitter
ann_args = annotate_args(func, args, kwargs, attach_annotations=True)
flat_ann_args = flatten_ann_args(ann_args)
if isinstance(splitter, splitter_cls):
flat_ann_args = splitter.parse_and_inject_takeables(flat_ann_args)
else:
flat_ann_args = splitter_cls.parse_and_inject_takeables(flat_ann_args)
for k, v in flat_ann_args.items():
if isinstance(v["value"], Takeable):
takeable_args.add(k)
for takeable_arg in takeable_args:
arg_name = match_ann_arg(ann_args, takeable_arg, return_name=True)
if not isinstance(flat_ann_args[arg_name]["value"], Takeable):
flat_ann_args[arg_name]["value"] = Takeable(flat_ann_args[arg_name]["value"])
new_ann_args = unflatten_ann_args(flat_ann_args)
args, kwargs = ann_args_to_args(new_ann_args)
if not isinstance(splitter, splitter_cls):
if index is None and index_from is not None:
index = splitter_cls.get_obj_index(match_ann_arg(ann_args, index_from))
if index is None and len(takeable_args) > 0:
first_takeable = match_ann_arg(ann_args, list(takeable_args)[0])
if isinstance(first_takeable, Takeable):
first_takeable = first_takeable.obj
index = splitter_cls.get_obj_index(first_takeable)
if index is None:
raise ValueError("Must provide splitter, index, index_from, or takeable_args")
if isinstance(splitter, str):
splitter = getattr(splitter_cls, splitter)
splitter = splitter(index, template_context=template_context, **splitter_kwargs)
if return_splitter:
return splitter
return splitter.apply(
func,
*args,
**kwargs,
**apply_kwargs,
)
wrapper.func = func
wrapper.name = func.__name__
wrapper.is_split = True
wrapper.options = FrozenConfig(
splitter=splitter,
splitter_cls=splitter_cls,
splitter_kwargs=splitter_kwargs,
index=index,
index_from=index_from,
takeable_args=takeable_args,
template_context=template_context,
forward_kwargs_as=forward_kwargs_as,
return_splitter=return_splitter,
apply_kwargs=apply_kwargs,
var_kwargs=var_kwargs,
)
signature = inspect.signature(wrapper)
lists_var_kwargs = False
for k, v in signature.parameters.items():
if v.kind == v.VAR_KEYWORD:
lists_var_kwargs = True
break
if not lists_var_kwargs:
var_kwargs_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
new_parameters = tuple(signature.parameters.values()) + (var_kwargs_param,)
wrapper.__signature__ = signature.replace(parameters=new_parameters)
return wrapper
if len(args) == 0:
return decorator
elif len(args) == 1:
return decorator(args[0])
raise ValueError("Either function or keyword arguments must be passed")
def cv_split(
*args,
parameterized_kwargs: tp.KwargsLike = None,
selection: tp.Union[str, tp.Selection] = "max",
return_grid: tp.Union[bool, str] = False,
skip_errored: bool = False,
raise_no_results: bool = True,
template_context: tp.KwargsLike = None,
**split_kwargs,
) -> tp.Callable:
"""Decorator that combines `split` and `vectorbtpro.utils.params.parameterized` for cross-validation.
Creates a new apply function that is going to be decorated with `split` and thus applied
at each single range using `vectorbtpro.generic.splitting.base.Splitter.apply`. Inside
this apply function, there is a test whether the current range belongs to the first (training) set.
If yes, parameterizes the underlying function and runs it on the entire grid of parameters.
The returned results are then stored in a global list. These results are then read by the other
(testing) sets in the same split. If `selection` is a template, it can evaluate the grid results
(available as `grid_results`) and return the best parameter combination. This parameter combination
is then executed by each set (including training).
Argument `selection` also accepts "min" for `np.argmin` and "max" for `np.argmax`.
Keyword arguments `parameterized_kwargs` will be passed to `vectorbtpro.utils.params.parameterized`
and will have their templates substituted with a context that will also include the split-related context
(including `split_idx`, `set_idx`, etc., see `vectorbtpro.generic.splitting.base.Splitter.apply`).
If `return_grid` is True or 'first', returns both the grid and the selection. If `return_grid`
is 'all', executes the grid on each set and returns it along with the selection.
Otherwise, returns only the selection.
If `vectorbtpro.utils.execution.NoResultsException` is raised or `skip_errored` is True and
any exception is raised, will skip the current iteration and remove it from the final index.
Usage:
* Permute a series and pick the first value. Make the seed parameterizable.
Cross-validate based on the highest picked value:
```pycon
>>> from vectorbtpro import *
>>> @vbt.cv_split(
... splitter="from_n_rolling",
... splitter_kwargs=dict(n=3, split=0.5),
... takeable_args=["sr"],
... merge_func="concat",
... )
... def f(sr, seed):
... np.random.seed(seed)
... return np.random.permutation(sr)[0]
>>> index = pd.date_range("2020-01-01", "2020-02-01")
>>> np.random.seed(0)
>>> sr = pd.Series(np.random.permutation(np.arange(len(index))), index=index)
>>> f(sr, vbt.Param([41, 42, 43]))
split set seed
0 set_0 41 22
set_1 41 28
1 set_0 43 8
set_1 43 31
2 set_0 43 19
set_1 43 0
dtype: int64
```
* Extend the example above to also return the grid results of each set:
```pycon
>>> f(sr, vbt.Param([41, 42, 43]), _return_grid="all")
(split set seed
0 set_0 41 22
42 22
43 2
set_1 41 28
42 28
43 20
1 set_0 41 5
42 5
43 8
set_1 41 23
42 23
43 31
2 set_0 41 18
42 18
43 19
set_1 41 27
42 27
43 0
dtype: int64,
split set seed
0 set_0 41 22
set_1 41 28
1 set_0 43 8
set_1 43 31
2 set_0 43 19
set_1 43 0
dtype: int64)
```
"""
def decorator(func: tp.Callable) -> tp.Callable:
if getattr(func, "is_split", False) or getattr(func, "is_parameterized", False):
raise ValueError("Function is already decorated with split or parameterized")
@wraps(func)
def wrapper(*args, **kwargs) -> tp.Any:
parameterized_kwargs = merge_dicts(
wrapper.options["parameterized_kwargs"],
kwargs.pop("_parameterized_kwargs", {}),
)
selection = kwargs.pop("_selection", wrapper.options["selection"])
if isinstance(selection, str) and selection.lower() == "min":
selection = RepEval("[np.nanargmin(grid_results)]")
if isinstance(selection, str) and selection.lower() == "max":
selection = RepEval("[np.nanargmax(grid_results)]")
return_grid = kwargs.pop("_return_grid", wrapper.options["return_grid"])
if isinstance(return_grid, bool):
if return_grid:
return_grid = "first"
else:
return_grid = None
skip_errored = kwargs.pop("_skip_errored", wrapper.options["skip_errored"])
template_context = merge_dicts(
wrapper.options["template_context"],
kwargs.pop("_template_context", {}),
)
split_kwargs = merge_dicts(
wrapper.options["split_kwargs"],
kwargs.pop("_split_kwargs", {}),
)
if "merge_func" in split_kwargs and "merge_func" not in parameterized_kwargs:
parameterized_kwargs["merge_func"] = split_kwargs["merge_func"]
if "show_progress" not in parameterized_kwargs:
parameterized_kwargs["show_progress"] = False
all_grid_results = []
@wraps(func)
def apply_wrapper(*_args, __template_context=None, **_kwargs):
try:
__template_context = dict(__template_context)
__template_context["all_grid_results"] = all_grid_results
_parameterized_kwargs = substitute_templates(
parameterized_kwargs,
__template_context,
eval_id="parameterized_kwargs",
)
parameterized_func = parameterized(
func,
template_context=__template_context,
**_parameterized_kwargs,
)
if __template_context["set_idx"] == 0:
try:
grid_results = parameterized_func(*_args, **_kwargs)
all_grid_results.append(grid_results)
except Exception as e:
if skip_errored or isinstance(e, NoResultsException):
all_grid_results.append(NoResult)
raise e
if all_grid_results[-1] is NoResult:
if raise_no_results:
raise NoResultsException
return NoResult
result = parameterized_func(
*_args,
_selection=selection,
_template_context=dict(grid_results=all_grid_results[-1]),
**_kwargs,
)
if return_grid is not None:
if return_grid.lower() == "first":
return all_grid_results[-1], result
if return_grid.lower() == "all":
grid_results = parameterized_func(
*_args,
_template_context=dict(grid_results=all_grid_results[-1]),
**_kwargs,
)
return grid_results, result
else:
raise ValueError(f"Invalid return_grid: '{return_grid}'")
return result
except Exception as e:
if skip_errored or isinstance(e, NoResultsException):
return NoResult
raise e
signature = inspect.signature(apply_wrapper)
lists_var_kwargs = False
for k, v in signature.parameters.items():
if v.kind == v.VAR_KEYWORD:
lists_var_kwargs = True
break
if not lists_var_kwargs:
var_kwargs_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
new_parameters = tuple(signature.parameters.values()) + (var_kwargs_param,)
apply_wrapper.__signature__ = signature.replace(parameters=new_parameters)
split_func = split(apply_wrapper, template_context=template_context, **split_kwargs)
return split_func(*args, __template_context=Rep("context", eval_id="apply_kwargs"), **kwargs)
wrapper.func = func
wrapper.name = func.__name__
wrapper.is_parameterized = True
wrapper.is_split = True
wrapper.options = FrozenConfig(
parameterized_kwargs=parameterized_kwargs,
selection=selection,
return_grid=return_grid,
skip_errored=skip_errored,
template_context=template_context,
split_kwargs=split_kwargs,
)
signature = inspect.signature(wrapper)
lists_var_kwargs = False
for k, v in signature.parameters.items():
if v.kind == v.VAR_KEYWORD:
lists_var_kwargs = True
break
if not lists_var_kwargs:
var_kwargs_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
new_parameters = tuple(signature.parameters.values()) + (var_kwargs_param,)
wrapper.__signature__ = signature.replace(parameters=new_parameters)
return wrapper
if len(args) == 0:
return decorator
elif len(args) == 1:
return decorator(args[0])
raise ValueError("Either function or keyword arguments must be passed")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for splitting."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = []
@register_jitted(cache=True, tags={"can_parallel"})
def split_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the overlap matrix for splits."""
out = np.empty((mask_arr.shape[0], mask_arr.shape[0]), dtype=int_)
temp_mask = np.empty((mask_arr.shape[0], mask_arr.shape[2]), dtype=np.bool_)
for i in range(mask_arr.shape[0]):
for m in range(mask_arr.shape[2]):
if mask_arr[i, :, m].any():
temp_mask[i, m] = True
else:
temp_mask[i, m] = False
for i1 in prange(mask_arr.shape[0]):
for i2 in range(mask_arr.shape[0]):
intersection = (temp_mask[i1] & temp_mask[i2]).sum()
out[i1, i2] = intersection
return out
@register_jitted(cache=True, tags={"can_parallel"})
def norm_split_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the normalized overlap matrix for splits."""
out = np.empty((mask_arr.shape[0], mask_arr.shape[0]), dtype=float_)
temp_mask = np.empty((mask_arr.shape[0], mask_arr.shape[2]), dtype=np.bool_)
for i in range(mask_arr.shape[0]):
for m in range(mask_arr.shape[2]):
if mask_arr[i, :, m].any():
temp_mask[i, m] = True
else:
temp_mask[i, m] = False
for i1 in prange(mask_arr.shape[0]):
for i2 in range(mask_arr.shape[0]):
intersection = (temp_mask[i1] & temp_mask[i2]).sum()
union = (temp_mask[i1] | temp_mask[i2]).sum()
out[i1, i2] = intersection / union
return out
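# Illustrative sketch (an assumption, not part of the original module): with a tiny mask
# of shape (splits, sets, index points), the two functions above count and normalize the
# index points covered by both splits.
# >>> import numpy as np
# >>> from vectorbtpro.generic.splitting.nb import split_overlap_matrix_nb, norm_split_overlap_matrix_nb
# >>> mask = np.array([
# ...     [[True, True, False, False]],   # split 0 covers points 0 and 1
# ...     [[False, True, True, False]],   # split 1 covers points 1 and 2
# ... ])
# >>> split_overlap_matrix_nb(mask)       # raw counts of shared points
# array([[2, 1],
#        [1, 2]])
# >>> norm_split_overlap_matrix_nb(mask)  # intersection over union
# array([[1.        , 0.33333333],
#        [0.33333333, 1.        ]])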
@register_jitted(cache=True, tags={"can_parallel"})
def set_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the overlap matrix for sets."""
out = np.empty((mask_arr.shape[1], mask_arr.shape[1]), dtype=int_)
temp_mask = np.empty((mask_arr.shape[1], mask_arr.shape[2]), dtype=np.bool_)
for j in range(mask_arr.shape[1]):
for m in range(mask_arr.shape[2]):
if mask_arr[:, j, m].any():
temp_mask[j, m] = True
else:
temp_mask[j, m] = False
for j1 in prange(mask_arr.shape[1]):
for j2 in range(mask_arr.shape[1]):
intersection = (temp_mask[j1] & temp_mask[j2]).sum()
out[j1, j2] = intersection
return out
@register_jitted(cache=True, tags={"can_parallel"})
def norm_set_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the normalized overlap matrix for sets."""
out = np.empty((mask_arr.shape[1], mask_arr.shape[1]), dtype=float_)
temp_mask = np.empty((mask_arr.shape[1], mask_arr.shape[2]), dtype=np.bool_)
for j in range(mask_arr.shape[1]):
for m in range(mask_arr.shape[2]):
if mask_arr[:, j, m].any():
temp_mask[j, m] = True
else:
temp_mask[j, m] = False
for j1 in prange(mask_arr.shape[1]):
for j2 in range(mask_arr.shape[1]):
intersection = (temp_mask[j1] & temp_mask[j2]).sum()
union = (temp_mask[j1] | temp_mask[j2]).sum()
out[j1, j2] = intersection / union
return out
@register_jitted(cache=True, tags={"can_parallel"})
def range_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the overlap matrix for ranges."""
out = np.empty((mask_arr.shape[0] * mask_arr.shape[1], mask_arr.shape[0] * mask_arr.shape[1]), dtype=int_)
for k in prange(mask_arr.shape[0] * mask_arr.shape[1]):
i1 = k // mask_arr.shape[1]
j1 = k % mask_arr.shape[1]
for l in range(mask_arr.shape[0] * mask_arr.shape[1]):
i2 = l // mask_arr.shape[1]
j2 = l % mask_arr.shape[1]
intersection = (mask_arr[i1, j1] & mask_arr[i2, j2]).sum()
out[k, l] = intersection
return out
@register_jitted(cache=True, tags={"can_parallel"})
def norm_range_overlap_matrix_nb(mask_arr: tp.Array3d) -> tp.Array2d:
"""Compute the normalized overlap matrix for ranges."""
out = np.empty((mask_arr.shape[0] * mask_arr.shape[1], mask_arr.shape[0] * mask_arr.shape[1]), dtype=float_)
for k in prange(mask_arr.shape[0] * mask_arr.shape[1]):
i1 = k // mask_arr.shape[1]
j1 = k % mask_arr.shape[1]
for l in range(mask_arr.shape[0] * mask_arr.shape[1]):
i2 = l // mask_arr.shape[1]
j2 = l % mask_arr.shape[1]
intersection = (mask_arr[i1, j1] & mask_arr[i2, j2]).sum()
union = (mask_arr[i1, j1] | mask_arr[i2, j2]).sum()
out[k, l] = intersection / union
return out
@register_jitted(cache=True)
def split_range_by_gap_nb(range_: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Split a range with gaps into start and end indices."""
if len(range_) == 0:
raise ValueError("Range is empty")
start_idxs_out = np.empty(len(range_), dtype=int_)
stop_idxs_out = np.empty(len(range_), dtype=int_)
start_idxs_out[0] = 0
k = 0
for i in range(1, len(range_)):
if range_[i] - range_[i - 1] != 1:
stop_idxs_out[k] = i
k += 1
start_idxs_out[k] = i
stop_idxs_out[k] = len(range_)
return start_idxs_out[: k + 1], stop_idxs_out[: k + 1]
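# Illustrative sketch (an assumption, not part of the original module): given a range
# array with two gaps, the function returns the positional bounds of each contiguous block.
# >>> import numpy as np
# >>> from vectorbtpro.generic.splitting.nb import split_range_by_gap_nb
# >>> range_ = np.array([0, 1, 2, 5, 6, 9])
# >>> split_range_by_gap_nb(range_)
# (array([0, 3, 5]), array([3, 5, 6]))
# That is, range_[0:3], range_[3:5], and range_[5:6] are the contiguous sub-ranges.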
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
#
# MIT License
#
# Copyright (c) 2018 Samuel Monnier
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Classes for purged cross-validation in time series.
As described in Advances in Financial Machine Learning, Marcos Lopez de Prado, 2018."""
from abc import abstractmethod
from itertools import combinations
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.base import Base
__all__ = [
"PurgedWalkForwardCV",
"PurgedKFoldCV",
]
class BasePurgedCV(Base):
"""Abstract class for purged time series cross-validation.
    Time series cross-validation requires that each sample has a prediction time, at which the features are used to
predict the response, and an evaluation time, at which the response is known and the error can be
computed. Importantly, it means that unlike in standard sklearn cross-validation, the samples X, response y,
`pred_times` and `eval_times` must all be pandas DataFrames/Series having the same index. It is also
assumed that the samples are time-ordered with respect to the prediction time."""
def __init__(self, n_folds: int = 10, purge_td: tp.TimedeltaLike = 0) -> None:
self._n_folds = n_folds
self._pred_times = None
self._eval_times = None
self._indices = None
self._purge_td = dt.to_timedelta(purge_td)
@property
def n_folds(self) -> int:
"""Number of folds."""
return self._n_folds
@property
def purge_td(self) -> pd.Timedelta:
"""Purge period."""
return self._purge_td
@property
def pred_times(self) -> tp.Optional[pd.Series]:
"""Times at which predictions are made."""
return self._pred_times
@property
def eval_times(self) -> tp.Optional[pd.Series]:
"""Times at which the response becomes available and the error can be computed."""
return self._eval_times
@property
def indices(self) -> tp.Optional[tp.Array1d]:
"""Indices."""
return self._indices
def purge(
self,
train_indices: tp.Array1d,
test_fold_start: int,
test_fold_end: int,
) -> tp.Array1d:
"""Purge part of the train set.
        Given the left boundary index `test_fold_start` and the right boundary index `test_fold_end`
        of the test set, this method removes from the train set all samples whose evaluation time
        (shifted by the purge period) is posterior to the prediction time of the first test sample;
        train samples located after `test_fold_end` are kept.
time_test_fold_start = self.pred_times.iloc[test_fold_start]
eval_times = self.eval_times + self.purge_td
train_indices_1 = np.intersect1d(train_indices, self.indices[eval_times < time_test_fold_start])
train_indices_2 = np.intersect1d(train_indices, self.indices[test_fold_end:])
return np.concatenate((train_indices_1, train_indices_2))
@abstractmethod
def split(
self,
X: tp.SeriesFrame,
y: tp.Optional[tp.Series] = None,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
):
"""Yield the indices of the train and test sets."""
checks.assert_instance_of(X, (pd.Series, pd.DataFrame), arg_name="X")
if y is not None:
checks.assert_instance_of(y, pd.Series, arg_name="y")
if pred_times is None:
pred_times = X.index
if isinstance(pred_times, pd.Index):
pred_times = pd.Series(pred_times, index=X.index)
else:
checks.assert_instance_of(pred_times, pd.Series, arg_name="pred_times")
checks.assert_index_equal(X.index, pred_times.index, check_names=False)
if eval_times is None:
eval_times = X.index
if isinstance(eval_times, pd.Index):
eval_times = pd.Series(eval_times, index=X.index)
else:
checks.assert_instance_of(eval_times, pd.Series, arg_name="eval_times")
checks.assert_index_equal(X.index, eval_times.index, check_names=False)
self._pred_times = pred_times
self._eval_times = eval_times
self._indices = np.arange(X.shape[0])
class PurgedWalkForwardCV(BasePurgedCV):
"""Purged walk-forward cross-validation.
The samples are decomposed into `n_folds` folds containing equal numbers of samples, without
shuffling. In each cross validation round, `n_test_folds` contiguous folds are used as the
    test set, while the train set consists of between `min_train_folds` and `max_train_folds`
immediately preceding folds.
Each sample should be tagged with a prediction time and an evaluation time.
The split is such that the intervals [`pred_times`, `eval_times`] associated to samples in
the train and test set do not overlap. (The overlapping samples are dropped.)
With `split_by_time=True` in the `PurgedWalkForwardCV.split` method, it is also possible to
split the samples in folds spanning equal time intervals (using the prediction time as a time tag),
instead of folds containing equal numbers of samples."""
def __init__(
self,
n_folds: int = 10,
n_test_folds: int = 1,
min_train_folds: int = 2,
max_train_folds: tp.Optional[int] = None,
split_by_time: bool = False,
purge_td: tp.TimedeltaLike = 0,
) -> None:
BasePurgedCV.__init__(self, n_folds=n_folds, purge_td=purge_td)
if n_test_folds >= self.n_folds - 1:
raise ValueError("n_test_folds must be between 1 and n_folds - 1")
self._n_test_folds = n_test_folds
if min_train_folds >= self.n_folds - self.n_test_folds:
raise ValueError("min_train_folds must be between 1 and n_folds - n_test_folds")
self._min_train_folds = min_train_folds
if max_train_folds is None:
max_train_folds = self.n_folds - self.n_test_folds
if max_train_folds > self.n_folds - self.n_test_folds:
raise ValueError("max_train_split must be between 1 and n_folds - n_test_folds")
self._max_train_folds = max_train_folds
self._split_by_time = split_by_time
self._fold_bounds = []
@property
def n_test_folds(self) -> int:
"""Number of folds used in the test set."""
return self._n_test_folds
@property
def min_train_folds(self) -> int:
"""Minimal number of folds to be used in the train set."""
return self._min_train_folds
@property
def max_train_folds(self) -> int:
"""Maximal number of folds to be used in the train set."""
return self._max_train_folds
@property
    def split_by_time(self) -> bool:
"""Whether the folds span identical time intervals. Otherwise, the folds contain
an (approximately) equal number of samples."""
return self._split_by_time
@property
def fold_bounds(self) -> tp.List[int]:
"""Fold boundaries."""
return self._fold_bounds
def compute_fold_bounds(self) -> tp.List[int]:
"""Compute a list containing the fold (left) boundaries."""
if self.split_by_time:
full_time_span = self.pred_times.max() - self.pred_times.min()
fold_time_span = full_time_span / self.n_folds
fold_bounds_times = [self.pred_times.iloc[0] + fold_time_span * n for n in range(self.n_folds)]
return self.pred_times.searchsorted(fold_bounds_times)
else:
return [fold[0] for fold in np.array_split(self.indices, self.n_folds)]
def compute_train_set(self, fold_bound: int, count_folds: int) -> tp.Array1d:
"""Compute the position indices of the samples in the train set."""
if count_folds > self.max_train_folds:
start_train = self.fold_bounds[count_folds - self.max_train_folds]
else:
start_train = 0
train_indices = np.arange(start_train, fold_bound)
train_indices = self.purge(train_indices, fold_bound, self.indices[-1])
return train_indices
def compute_test_set(self, fold_bound: int, count_folds: int) -> tp.Array1d:
"""Compute the position indices of the samples in the test set."""
if self.n_folds - count_folds > self.n_test_folds:
end_test = self.fold_bounds[count_folds + self.n_test_folds]
else:
end_test = self.indices[-1] + 1
return np.arange(fold_bound, end_test)
def split(
self,
X: tp.SeriesFrame,
y: tp.Optional[tp.Series] = None,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
) -> tp.Iterable[tp.Tuple[tp.Array1d, tp.Array1d]]:
BasePurgedCV.split(self, X, y, pred_times=pred_times, eval_times=eval_times)
self._fold_bounds = self.compute_fold_bounds()
count_folds = 0
for fold_bound in self.fold_bounds:
if count_folds < self.min_train_folds:
count_folds = count_folds + 1
continue
if self.n_folds - count_folds < self.n_test_folds:
break
test_indices = self.compute_test_set(fold_bound, count_folds)
train_indices = self.compute_train_set(fold_bound, count_folds)
count_folds = count_folds + 1
yield train_indices, test_indices
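# Illustrative sketch (an assumption, not part of the original module): with a
# datetime-indexed frame and the defaults above, the generator should walk forward
# through the folds; with `purge_td=0` nothing gets purged.
# >>> import numpy as np, pandas as pd
# >>> from vectorbtpro.generic.splitting.purged import PurgedWalkForwardCV
# >>> X = pd.DataFrame(np.arange(20).reshape(10, 2), index=pd.date_range("2020", periods=10))
# >>> for train_idx, test_idx in PurgedWalkForwardCV(n_folds=5).split(X):
# ...     print(train_idx, test_idx)
# [0 1 2 3] [4 5]
# [0 1 2 3 4 5] [6 7]
# [0 1 2 3 4 5 6 7] [8 9]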
class PurgedKFoldCV(BasePurgedCV):
"""Purged and embargoed combinatorial cross-validation.
The samples are decomposed into `n_folds` folds containing equal numbers of samples, without shuffling.
In each cross validation round, `n_test_folds` folds are used as the test set, while the other folds
    are used as the train set. There are as many rounds as there are ways to choose `n_test_folds` folds among the `n_folds` folds.
Each sample should be tagged with a prediction time and an evaluation time. The split is such
that the intervals [`pred_times`, `eval_times`] associated to samples in the train and test set
do not overlap. (The overlapping samples are dropped.) In addition, an "embargo" period is defined,
giving the minimal time between an evaluation time in the test set and a prediction time in the
training set. This is to avoid, in the presence of temporal correlation, a contamination of the
test set by the train set."""
def __init__(
self,
n_folds: int = 10,
n_test_folds: int = 2,
purge_td: tp.TimedeltaLike = 0,
embargo_td: tp.TimedeltaLike = 0,
) -> None:
BasePurgedCV.__init__(self, n_folds=n_folds, purge_td=purge_td)
if n_test_folds > self.n_folds - 1:
raise ValueError("n_test_folds must be between 1 and n_folds - 1")
self._n_test_folds = n_test_folds
self._embargo_td = dt.to_timedelta(embargo_td)
@property
def n_test_folds(self) -> int:
"""Number of folds used in the test set."""
return self._n_test_folds
@property
def embargo_td(self) -> pd.Timedelta:
"""Embargo period."""
return self._embargo_td
def embargo(
self,
train_indices: tp.Array1d,
test_indices: tp.Array1d,
test_fold_end: int,
) -> tp.Array1d:
"""Apply the embargo procedure to part of the train set.
This amounts to dropping the train set samples whose prediction time occurs within
`PurgedKFoldCV.embargo_td` of the test set sample evaluation times. This method applies
the embargo only to the part of the training set immediately following the end of the test
set determined by `test_fold_end`."""
last_test_eval_time = self.eval_times.iloc[test_indices[test_indices <= test_fold_end]].max()
min_train_index = len(self.pred_times[self.pred_times <= last_test_eval_time + self.embargo_td])
if min_train_index < self.indices.shape[0]:
allowed_indices = np.concatenate((self.indices[:test_fold_end], self.indices[min_train_index:]))
train_indices = np.intersect1d(train_indices, allowed_indices)
return train_indices
def compute_train_set(
self,
test_fold_bounds: tp.List[tp.Tuple[int, int]],
test_indices: tp.Array1d,
) -> tp.Array1d:
"""Compute the position indices of the samples in the train set."""
train_indices = np.setdiff1d(self.indices, test_indices)
for test_fold_start, test_fold_end in test_fold_bounds:
train_indices = self.purge(train_indices, test_fold_start, test_fold_end)
train_indices = self.embargo(train_indices, test_indices, test_fold_end)
return train_indices
def compute_test_set(
self,
fold_bound_list: tp.List[tp.Tuple[int, int]],
) -> tp.Tuple[tp.List[tp.Tuple[int, int]], tp.Array1d]:
"""Compute the position indices of the samples in the test set."""
test_indices = np.empty(0)
test_fold_bounds = []
for fold_start, fold_end in fold_bound_list:
if not test_fold_bounds or fold_start != test_fold_bounds[-1][-1]:
test_fold_bounds.append((fold_start, fold_end))
elif fold_start == test_fold_bounds[-1][-1]:
test_fold_bounds[-1] = (test_fold_bounds[-1][0], fold_end)
test_indices = np.union1d(test_indices, self.indices[fold_start:fold_end]).astype(int)
return test_fold_bounds, test_indices
def split(
self,
X: tp.SeriesFrame,
y: tp.Optional[tp.Series] = None,
pred_times: tp.Union[None, tp.Index, tp.Series] = None,
eval_times: tp.Union[None, tp.Index, tp.Series] = None,
) -> tp.Iterable[tp.Tuple[tp.Array1d, tp.Array1d]]:
BasePurgedCV.split(self, X, y, pred_times=pred_times, eval_times=eval_times)
fold_bounds = [(fold[0], fold[-1] + 1) for fold in np.array_split(self.indices, self.n_folds)]
selected_fold_bounds = list(combinations(fold_bounds, self.n_test_folds))
selected_fold_bounds.reverse()
for fold_bound_list in selected_fold_bounds:
test_fold_bounds, test_indices = self.compute_test_set(fold_bound_list)
train_indices = self.compute_train_set(test_fold_bounds, test_indices)
yield train_indices, test_indices
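# Illustrative sketch (an assumption, not part of the original module): the number of
# rounds equals the number of ways to choose `n_test_folds` folds out of `n_folds`.
# Reusing `X` from the sketch above:
# >>> from vectorbtpro.generic.splitting.purged import PurgedKFoldCV
# >>> sum(1 for _ in PurgedKFoldCV(n_folds=4, n_test_folds=2).split(X))  # C(4, 2) rounds
# 6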
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Scikit-learn compatible class for splitting."""
import numpy as np
import pandas as pd
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils.validation import indexable
from vectorbtpro import _typing as tp
from vectorbtpro.generic.splitting.base import Splitter
from vectorbtpro.utils.base import Base
__all__ = [
"SplitterCV",
]
class SplitterCV(BaseCrossValidator, Base):
"""Scikit-learn compatible cross-validator based on `vectorbtpro.generic.splitting.base.Splitter`.
Usage:
* Replicate `TimeSeriesSplit` from scikit-learn:
```pycon
>>> from vectorbtpro import *
>>> X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
>>> y = np.array([1, 2, 3, 4])
>>> cv = vbt.SplitterCV(
... "from_expanding",
... min_length=2,
... offset=1,
... split=-1
... )
>>> for i, (train_indices, test_indices) in enumerate(cv.split(X)):
... print("Split %d:" % i)
... X_train, X_test = X[train_indices], X[test_indices]
... print(" X:", X_train.tolist(), X_test.tolist())
... y_train, y_test = y[train_indices], y[test_indices]
... print(" y:", y_train.tolist(), y_test.tolist())
Split 0:
X: [[1, 2]] [[3, 4]]
y: [1] [2]
Split 1:
X: [[1, 2], [3, 4]] [[5, 6]]
y: [1, 2] [3]
Split 2:
X: [[1, 2], [3, 4], [5, 6]] [[7, 8]]
y: [1, 2, 3] [4]
```
"""
def __init__(
self,
splitter: tp.Union[None, str, Splitter, tp.Callable] = None,
*,
splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
split_group_by: tp.AnyGroupByLike = None,
set_group_by: tp.AnyGroupByLike = None,
template_context: tp.KwargsLike = None,
**splitter_kwargs,
) -> None:
if splitter_cls is None:
splitter_cls = Splitter
if splitter is None:
splitter = splitter_cls.guess_method(**splitter_kwargs)
self._splitter = splitter
self._splitter_kwargs = splitter_kwargs
self._splitter_cls = splitter_cls
self._split_group_by = split_group_by
self._set_group_by = set_group_by
self._template_context = template_context
@property
def splitter(self) -> tp.Union[str, Splitter, tp.Callable]:
"""Splitter.
Either as a `vectorbtpro.generic.splitting.base.Splitter` instance, a factory method name,
or the factory method itself.
If None, will be determined automatically based on `SplitterCV.splitter_kwargs`."""
return self._splitter
@property
def splitter_cls(self) -> tp.Type[Splitter]:
"""Splitter class.
Defaults to `vectorbtpro.generic.splitting.base.Splitter`."""
return self._splitter_cls
@property
def splitter_kwargs(self) -> tp.KwargsLike:
"""Keyword arguments passed to the factory method."""
return self._splitter_kwargs
@property
def split_group_by(self) -> tp.AnyGroupByLike:
"""Split groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
Not passed to the factory method."""
return self._split_group_by
@property
def set_group_by(self) -> tp.AnyGroupByLike:
"""Set groups. See `vectorbtpro.base.accessors.BaseIDXAccessor.get_grouper`.
Not passed to the factory method."""
return self._set_group_by
@property
def template_context(self) -> tp.KwargsLike:
"""Mapping used to substitute templates in ranges.
Passed to the factory method."""
return self._template_context
def get_splitter(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> Splitter:
"""Get splitter of type `vectorbtpro.generic.splitting.base.Splitter`."""
X, y, groups = indexable(X, y, groups)
try:
index = self.splitter_cls.get_obj_index(X)
except ValueError as e:
index = pd.RangeIndex(stop=len(X))
if isinstance(self.splitter, str):
splitter = getattr(self.splitter_cls, self.splitter)
else:
splitter = self.splitter
splitter = splitter(
index,
template_context=self.template_context,
**self.splitter_kwargs,
)
if splitter.get_n_sets(set_group_by=self.set_group_by) != 2:
raise ValueError("Number of sets in the splitter must be 2: train and test")
return splitter
def _iter_masks(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
"""Generates boolean masks corresponding to train and test sets."""
splitter = self.get_splitter(X=X, y=y, groups=groups)
for mask_arr in splitter.get_iter_split_mask_arrs(
split_group_by=self.split_group_by,
set_group_by=self.set_group_by,
template_context=self.template_context,
):
yield mask_arr[0], mask_arr[1]
def _iter_train_masks(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Array1d]:
"""Generates boolean masks corresponding to train sets."""
for train_mask_arr, _ in self._iter_masks(X=X, y=y, groups=groups):
yield train_mask_arr
def _iter_test_masks(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Array1d]:
"""Generates boolean masks corresponding to test sets."""
for _, test_mask_arr in self._iter_masks(X=X, y=y, groups=groups):
yield test_mask_arr
def _iter_indices(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
"""Generates integer indices corresponding to train and test sets."""
for train_mask_arr, test_mask_arr in self._iter_masks(X=X, y=y, groups=groups):
yield np.flatnonzero(train_mask_arr), np.flatnonzero(test_mask_arr)
def _iter_train_indices(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Array1d]:
"""Generates integer indices corresponding to train sets."""
for train_indices, _ in self._iter_indices(X=X, y=y, groups=groups):
yield train_indices
def _iter_test_indices(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Array1d]:
"""Generates integer indices corresponding to test sets."""
for _, test_indices in self._iter_indices(X=X, y=y, groups=groups):
yield test_indices
def get_n_splits(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> int:
"""Returns the number of splitting iterations in the cross-validator."""
splitter = self.get_splitter(X=X, y=y, groups=groups)
return splitter.get_n_splits(split_group_by=self.split_group_by)
def split(
self,
X: tp.Any = None,
y: tp.Any = None,
groups: tp.Any = None,
) -> tp.Iterator[tp.Tuple[tp.Array1d, tp.Array1d]]:
"""Generate indices to split data into training and test set."""
return self._iter_indices(X=X, y=y, groups=groups)
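# Illustrative sketch (an assumption, not part of the original module): since `SplitterCV`
# subclasses scikit-learn's `BaseCrossValidator`, it should plug into any scikit-learn routine
# that accepts a `cv` object, for example (given sufficiently long `X` and `y`):
# >>> from sklearn.linear_model import LinearRegression
# >>> from sklearn.model_selection import cross_val_score
# >>> cross_val_score(LinearRegression(), X, y, cv=vbt.SplitterCV("from_n_rolling", n=3, split=0.5))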
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with generic time series.
In contrast to the `vectorbtpro.base` sub-package, focuses on the data itself."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.generic.nb import *
from vectorbtpro.generic.splitting import *
from vectorbtpro.generic.accessors import *
from vectorbtpro.generic.analyzable import *
from vectorbtpro.generic.decorators import *
from vectorbtpro.generic.drawdowns import *
from vectorbtpro.generic.plots_builder import *
from vectorbtpro.generic.plotting import *
from vectorbtpro.generic.price_records import *
from vectorbtpro.generic.ranges import *
from vectorbtpro.generic.sim_range import *
from vectorbtpro.generic.stats_builder import *
__exclude_from__all__ = [
"enums",
]
__import_if_installed__ = dict()
__import_if_installed__["plotting"] = "plotly"
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for generic data.
Methods can be accessed as follows:
* `GenericSRAccessor` -> `pd.Series.vbt.*`
* `GenericDFAccessor` -> `pd.DataFrame.vbt.*`
```pycon
>>> from vectorbtpro import *
>>> # vectorbtpro.generic.accessors.GenericAccessor.rolling_mean
>>> pd.Series([1, 2, 3, 4]).vbt.rolling_mean(2)
0 NaN
1 1.5
2 2.5
3 3.5
dtype: float64
```
The accessors inherit from `vectorbtpro.base.accessors` and are in turn inherited by more
specialized accessors, such as `vectorbtpro.signals.accessors` and `vectorbtpro.returns.accessors`.
!!! note
Grouping is only supported by the methods that accept the `group_by` argument.
Accessors do not utilize caching.
Run for the examples below:
```pycon
>>> df = pd.DataFrame({
... 'a': [1, 2, 3, 4, 5],
... 'b': [5, 4, 3, 2, 1],
... 'c': [1, 2, 3, 2, 1]
... }, index=pd.Index(pd.date_range("2020", periods=5)))
>>> df
a b c
2020-01-01 1 5 1
2020-01-02 2 4 2
2020-01-03 3 3 3
2020-01-04 4 2 2
2020-01-05 5 1 1
>>> sr = pd.Series(np.arange(10), index=pd.date_range("2020", periods=10))
>>> sr
2020-01-01 0
2020-01-02 1
2020-01-03 2
2020-01-04 3
2020-01-05 4
2020-01-06 5
2020-01-07 6
2020-01-08 7
2020-01-09 8
2020-01-10 9
dtype: int64
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `GenericAccessor.metrics`.
```pycon
>>> df2 = pd.DataFrame({
... 'a': [np.nan, 2, 3],
... 'b': [4, np.nan, 5],
... 'c': [6, 7, np.nan]
... }, index=['x', 'y', 'z'])
>>> df2.vbt(freq='d').stats(column='a')
Start x
End z
Period 3 days 00:00:00
Count 2
Mean 2.5
Std 0.707107
Min 2.0
Median 2.5
Max 3.0
Min Index y
Max Index z
Name: a, dtype: object
```
### Mapping
Mapping can be set both in `GenericAccessor` (preferred) and `GenericAccessor.stats`:
```pycon
>>> mapping = {x: 'test_' + str(x) for x in pd.unique(df2.values.flatten())}
>>> df2.vbt(freq='d', mapping=mapping).stats(column='a')
Start x
End z
Period 3 days 00:00:00
Count 2
Value Counts: test_2.0 1
Value Counts: test_3.0 1
Value Counts: test_4.0 0
Value Counts: test_5.0 0
Value Counts: test_6.0 0
Value Counts: test_7.0 0
Value Counts: test_nan 1
Name: a, dtype: object
>>> df2.vbt(freq='d').stats(column='a', settings=dict(mapping=mapping))
UserWarning: Changing the mapping will create a copy of this object.
Consider setting it upon object creation to re-use existing cache.
Start x
End z
Period 3 days 00:00:00
Count 2
Value Counts: test_2.0 1
Value Counts: test_3.0 1
Value Counts: test_4.0 0
Value Counts: test_5.0 0
Value Counts: test_6.0 0
Value Counts: test_7.0 0
Value Counts: test_nan 1
Name: a, dtype: object
```
Selecting a column before calling `stats` will consider uniques from this column only:
```pycon
>>> df2['a'].vbt(freq='d', mapping=mapping).stats()
Start x
End z
Period 3 days 00:00:00
Count 2
Value Counts: test_2.0 1
Value Counts: test_3.0 1
Value Counts: test_nan 1
Name: a, dtype: object
```
To include all keys from `mapping`, pass `incl_all_keys=True`:
```pycon
>>> df2['a'].vbt(freq='d', mapping=mapping).stats(settings=dict(incl_all_keys=True))
Start x
End z
Period 3 days 00:00:00
Count 2
Value Counts: test_2.0 1
Value Counts: test_3.0 1
Value Counts: test_4.0 0
Value Counts: test_5.0 0
Value Counts: test_6.0 0
Value Counts: test_7.0 0
Value Counts: test_nan 1
Name: a, dtype: object
```
`GenericAccessor.stats` also supports (re-)grouping:
```pycon
>>> df2.vbt(freq='d').stats(column=0, group_by=[0, 0, 1])
Start x
End z
Period 3 days 00:00:00
Count 4
Mean 3.5
Std 1.290994
Min 2.0
Median 3.5
Max 5.0
Min Index y
Max Index z
Name: 0, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `GenericAccessor.subplots`.
`GenericAccessor` class has a single subplot based on `GenericAccessor.plot`:
```pycon
>>> df2.vbt.plots().show()
```
"""
from functools import partial
import numpy as np
import pandas as pd
from pandas.core.resample import Resampler as PandasResampler
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro._settings import settings
from vectorbtpro.base import indexes, reshaping
from vectorbtpro.base.accessors import BaseAccessor, BaseDFAccessor, BaseSRAccessor
from vectorbtpro.base.indexes import repeat_index
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.generic import nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.decorators import attach_nb_methods, attach_transform_methods
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.generic.enums import WType, InterpMode, RescaleMode, ErrorType, DistanceMeasure
from vectorbtpro.generic.plots_builder import PlotsBuilderMixin
from vectorbtpro.generic.ranges import Ranges, PatternRanges
from vectorbtpro.generic.stats_builder import StatsBuilderMixin
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, chunking as ch, datetime_ as dt
from vectorbtpro.utils.colors import adjust_opacity, map_value_to_cmap
from vectorbtpro.utils.config import merge_dicts, resolve_dict, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.decorators import hybrid_method, hybrid_property
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.mapping import apply_mapping, to_value_mapping
from vectorbtpro.utils.template import substitute_templates
from vectorbtpro.utils.warnings_ import warn
try:
import bottleneck as bn
nanmean = bn.nanmean
nanstd = bn.nanstd
nansum = bn.nansum
nanmax = bn.nanmax
nanmin = bn.nanmin
nanmedian = bn.nanmedian
nanargmax = bn.nanargmax
nanargmin = bn.nanargmin
except ImportError:
# slower numpy
nanmean = np.nanmean
nanstd = np.nanstd
nansum = np.nansum
nanmax = np.nanmax
nanmin = np.nanmin
nanmedian = np.nanmedian
nanargmax = np.nanargmax
nanargmin = np.nanargmin
__all__ = [
"GenericAccessor",
"GenericSRAccessor",
"GenericDFAccessor",
]
__pdoc__ = {}
GenericAccessorT = tp.TypeVar("GenericAccessorT", bound="GenericAccessor")
SplitOutputT = tp.Union[tp.MaybeTuple[tp.Tuple[tp.Frame, tp.Index]], tp.BaseFigure]
class TransformerT(tp.Protocol):
def __init__(self, **kwargs) -> None: ...
def transform(self, *args, **kwargs) -> tp.Array2d: ...
def fit_transform(self, *args, **kwargs) -> tp.Array2d: ...
__pdoc__["TransformerT"] = False
nb_config = ReadonlyConfig(
{
"shuffle": dict(func=nb.shuffle_nb, disable_chunked=True),
"fillna": dict(func=nb.fillna_nb),
"bshift": dict(func=nb.bshift_nb),
"fshift": dict(func=nb.fshift_nb),
"diff": dict(func=nb.diff_nb),
"pct_change": dict(func=nb.pct_change_nb),
"ffill": dict(func=nb.ffill_nb),
"bfill": dict(func=nb.bfill_nb),
"fbfill": dict(func=nb.fbfill_nb),
"cumsum": dict(func=nb.nancumsum_nb),
"cumprod": dict(func=nb.nancumprod_nb),
"rolling_sum": dict(func=nb.rolling_sum_nb),
"rolling_prod": dict(func=nb.rolling_prod_nb),
"rolling_min": dict(func=nb.rolling_min_nb),
"rolling_max": dict(func=nb.rolling_max_nb),
"expanding_min": dict(func=nb.expanding_min_nb),
"expanding_max": dict(func=nb.expanding_max_nb),
"rolling_any": dict(func=nb.rolling_any_nb),
"rolling_all": dict(func=nb.rolling_all_nb),
"product": dict(func=nb.nanprod_nb, is_reducing=True),
}
)
"""_"""
__pdoc__[
"nb_config"
] = f"""Config of Numba methods to be attached to `GenericAccessor`.
```python
{nb_config.prettify()}
```
"""
@attach_nb_methods(nb_config)
class GenericAccessor(BaseAccessor, Analyzable):
"""Accessor on top of data of any type. For both, Series and DataFrames.
Accessible via `pd.Series.vbt` and `pd.DataFrame.vbt`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
mapping: tp.Optional[tp.MappingLike] = None,
**kwargs,
) -> None:
BaseAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs)
StatsBuilderMixin.__init__(self)
PlotsBuilderMixin.__init__(self)
self._mapping = mapping
@hybrid_property
def sr_accessor_cls(cls_or_self) -> tp.Type["GenericSRAccessor"]:
"""Accessor class for `pd.Series`."""
return GenericSRAccessor
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["GenericDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return GenericDFAccessor
# ############# Mapping ############# #
@property
def mapping(self) -> tp.Optional[tp.MappingLike]:
"""Mapping."""
return self._mapping
def resolve_mapping(self, mapping: tp.Union[None, bool, tp.MappingLike] = None) -> tp.Optional[tp.Mapping]:
"""Resolve mapping.
Set `mapping` to False to disable mapping completely."""
if mapping is None:
mapping = self.mapping
if isinstance(mapping, bool):
if not mapping:
return None
raise ValueError("Mapping cannot be True")
if isinstance(mapping, str):
if mapping.lower() == "index":
mapping = self.wrapper.index
elif mapping.lower() == "columns":
mapping = self.wrapper.columns
elif mapping.lower() == "groups":
mapping = self.wrapper.get_columns()
mapping = to_value_mapping(mapping)
return mapping
def apply_mapping(self, mapping: tp.Union[None, bool, tp.MappingLike] = None, **kwargs) -> tp.SeriesFrame:
"""See `vectorbtpro.utils.mapping.apply_mapping`."""
mapping = self.resolve_mapping(mapping)
return apply_mapping(self.obj, mapping, **kwargs)
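    # Illustrative sketch (an assumption, not part of the original module): with a mapping
    # defined on the accessor, `apply_mapping` should translate raw values into labels:
    # >>> pd.Series([0, 1, 2]).vbt(mapping={0: "x", 1: "y", 2: "z"}).apply_mapping()
    # 0    x
    # 1    y
    # 2    z
    # dtype: object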
# ############# Shifting ############# #
def ago(
self,
n: tp.Union[int, tp.FrequencyLike],
fill_value: tp.Scalar = np.nan,
get_indexer_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""For each value, get the value `n` periods ago."""
if checks.is_int(n):
return self.fshift(n, fill_value=fill_value, **kwargs)
if get_indexer_kwargs is None:
get_indexer_kwargs = {}
n = dt.to_timedelta(n)
indices = self.wrapper.index.get_indexer(self.wrapper.index - n, **get_indexer_kwargs)
new_obj = self.wrapper.fill(fill_value=fill_value)
found_mask = indices != -1
new_obj.iloc[np.flatnonzero(found_mask)] = self.obj.iloc[indices[found_mask]]
return new_obj
def any_ago(self, n: tp.Union[int, tp.FrequencyLike], **kwargs) -> tp.SeriesFrame:
"""For each value, check whether any value within a window of `n` last periods is True."""
wrap_kwargs = kwargs.pop("wrap_kwargs", {})
wrap_kwargs = merge_dicts(dict(fillna=False, dtype=bool), wrap_kwargs)
if checks.is_int(n):
return self.rolling_any(n, wrap_kwargs=wrap_kwargs, **kwargs)
return self.rolling_apply(n, "any", wrap_kwargs=wrap_kwargs, **kwargs)
def all_ago(self, n: tp.Union[int, tp.FrequencyLike], **kwargs) -> tp.SeriesFrame:
"""For each value, check whether all values within a window of `n` last periods are True."""
wrap_kwargs = kwargs.pop("wrap_kwargs", {})
wrap_kwargs = merge_dicts(dict(fillna=False, dtype=bool), wrap_kwargs)
if checks.is_int(n):
return self.rolling_all(n, wrap_kwargs=wrap_kwargs, **kwargs)
return self.rolling_apply(n, "all", wrap_kwargs=wrap_kwargs, **kwargs)
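    # Illustrative sketch (an assumption, not part of the original module): with an integer `n`,
    # `ago` falls back to `fshift` and thus behaves like a plain pandas shift, while a
    # frequency-like `n` is resolved against the index:
    # >>> sr.vbt.ago(2)     # roughly equivalent to sr.shift(2)
    # >>> sr.vbt.ago("2d")  # two calendar days ago, matched via the index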
# ############# Rolling ############# #
def rolling_idxmin(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
local: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_argmin_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_argmin_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp, local=local)
if not local:
wrap_kwargs = merge_dicts(dict(to_index=True), wrap_kwargs)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_idxmin(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_idxmin`."""
return self.rolling_idxmin(None, minp=minp, **kwargs)
def rolling_idxmax(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
local: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_argmax_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_argmax_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp, local=local)
if not local:
wrap_kwargs = merge_dicts(dict(to_index=True), wrap_kwargs)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_idxmax(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_idxmax`."""
return self.rolling_idxmax(None, minp=minp, **kwargs)
def rolling_mean(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_mean_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_mean_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_mean(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_mean`."""
return self.rolling_mean(None, minp=minp, **kwargs)
def rolling_std(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
ddof: int = 1,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_std_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_std_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp, ddof=ddof)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_std(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_std`."""
return self.rolling_std(None, minp=minp, **kwargs)
def rolling_zscore(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
ddof: int = 1,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_zscore_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_zscore_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp, ddof=ddof)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_zscore(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_zscore`."""
return self.rolling_zscore(None, minp=minp, **kwargs)
def wm_mean(
self,
span: int,
minp: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.wm_mean_nb`."""
func = jit_reg.resolve_option(nb.wm_mean_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), span, minp=minp)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def ewm_mean(
self,
span: int,
minp: tp.Optional[int] = 0,
adjust: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.ewm_mean_nb`."""
func = jit_reg.resolve_option(nb.ewm_mean_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), span, minp=minp, adjust=adjust)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def ewm_std(
self,
span: int,
minp: tp.Optional[int] = 0,
adjust: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.ewm_std_nb`."""
func = jit_reg.resolve_option(nb.ewm_std_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), span, minp=minp, adjust=adjust)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def wwm_mean(
self,
period: int,
minp: tp.Optional[int] = 0,
adjust: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.wwm_mean_nb`."""
func = jit_reg.resolve_option(nb.wwm_mean_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), period, minp=minp, adjust=adjust)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def wwm_std(
self,
period: int,
minp: tp.Optional[int] = 0,
adjust: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.wwm_std_nb`."""
func = jit_reg.resolve_option(nb.wwm_std_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), period, minp=minp, adjust=adjust)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def vidya(
self,
window: int,
minp: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.vidya_nb`."""
func = jit_reg.resolve_option(nb.vidya_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def ma(
self,
window: int,
wtype: tp.Union[int, str] = "simple",
minp: tp.Optional[int] = 0,
adjust: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.ma_nb`."""
if isinstance(wtype, str):
wtype = map_enum_fields(wtype, WType)
func = jit_reg.resolve_option(nb.ma_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, wtype=wtype, minp=minp, adjust=adjust)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
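    # Illustrative sketch (an assumption, not part of the original module): `wtype` selects the
    # weighting scheme either by enum value or by field name, for example:
    # >>> df.vbt.ma(3, wtype="exp")  # exponential variant of ma_nb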
def msd(
self,
window: int,
wtype: tp.Union[int, str] = "simple",
minp: tp.Optional[int] = 0,
adjust: bool = True,
ddof: int = 1,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.msd_nb`."""
if isinstance(wtype, str):
wtype = map_enum_fields(wtype, WType)
func = jit_reg.resolve_option(nb.msd_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def rolling_cov(
self,
other: tp.SeriesFrame,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
ddof: int = 1,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_cov_nb`."""
self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
if window is None:
window = self_obj.shape[0]
func = jit_reg.resolve_option(nb.rolling_cov_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(reshaping.to_2d_array(self_obj), reshaping.to_2d_array(other_obj), window, minp=minp, ddof=ddof)
return ArrayWrapper.from_obj(self_obj).wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_cov(self, other: tp.SeriesFrame, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_cov`."""
return self.rolling_cov(other, None, minp=minp, **kwargs)
def rolling_corr(
self,
other: tp.SeriesFrame,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_corr_nb`."""
self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
if window is None:
window = self_obj.shape[0]
func = jit_reg.resolve_option(nb.rolling_corr_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(reshaping.to_2d_array(self_obj), reshaping.to_2d_array(other_obj), window, minp=minp)
return ArrayWrapper.from_obj(self_obj).wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_corr(self, other: tp.SeriesFrame, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_corr`."""
return self.rolling_corr(other, None, minp=minp, **kwargs)
def rolling_ols(
self,
other: tp.SeriesFrame,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]:
"""See `vectorbtpro.generic.nb.rolling.rolling_ols_nb`.
Returns two arrays: slope and intercept."""
self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
if window is None:
window = self_obj.shape[0]
func = jit_reg.resolve_option(nb.rolling_ols_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
slope_out, intercept_out = func(
reshaping.to_2d_array(self_obj),
reshaping.to_2d_array(other_obj),
window,
minp=minp,
)
return (
ArrayWrapper.from_obj(self_obj).wrap(slope_out, group_by=False, **resolve_dict(wrap_kwargs)),
ArrayWrapper.from_obj(self_obj).wrap(intercept_out, group_by=False, **resolve_dict(wrap_kwargs)),
)
def expanding_ols(
self,
other: tp.SeriesFrame,
minp: tp.Optional[int] = 1,
**kwargs,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]:
"""Expanding version of `GenericAccessor.rolling_ols`."""
return self.rolling_ols(other, None, minp=minp, **kwargs)
def rolling_rank(
self,
window: tp.Optional[int],
minp: tp.Optional[int] = None,
pct: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_rank_nb`."""
if window is None:
window = self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.rolling_rank_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array(), window, minp=minp, pct=pct)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def expanding_rank(self, minp: tp.Optional[int] = 1, **kwargs) -> tp.SeriesFrame:
"""Expanding version of `GenericAccessor.rolling_rank`."""
return self.rolling_rank(None, minp=minp, **kwargs)
def rolling_pattern_similarity(
self,
pattern: tp.ArrayLike,
window: tp.Optional[int] = None,
max_window: tp.Optional[int] = None,
row_select_prob: float = 1.0,
window_select_prob: float = 1.0,
interp_mode: tp.Union[int, str] = "mixed",
rescale_mode: tp.Union[int, str] = "minmax",
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: tp.Union[int, str] = "absolute",
distance_measure: tp.Union[int, str] = "mae",
max_error: tp.ArrayLike = np.nan,
max_error_interp_mode: tp.Union[None, int, str] = None,
max_error_as_maxdist: bool = False,
max_error_strict: bool = False,
min_pct_change: float = np.nan,
max_pct_change: float = np.nan,
min_similarity: float = np.nan,
minp: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.rolling.rolling_pattern_similarity_nb`."""
if isinstance(interp_mode, str):
interp_mode = map_enum_fields(interp_mode, InterpMode)
if isinstance(rescale_mode, str):
rescale_mode = map_enum_fields(rescale_mode, RescaleMode)
if isinstance(error_type, str):
error_type = map_enum_fields(error_type, ErrorType)
if isinstance(distance_measure, str):
distance_measure = map_enum_fields(distance_measure, DistanceMeasure)
if max_error_interp_mode is not None and isinstance(max_error_interp_mode, str):
max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode)
if max_error_interp_mode is None:
max_error_interp_mode = interp_mode
func = jit_reg.resolve_option(nb.rolling_pattern_similarity_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
reshaping.to_1d_array(pattern),
window=window,
max_window=max_window,
row_select_prob=row_select_prob,
window_select_prob=window_select_prob,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=vmin,
vmax=vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
distance_measure=distance_measure,
max_error=reshaping.to_1d_array(max_error),
max_error_interp_mode=max_error_interp_mode,
max_error_as_maxdist=max_error_as_maxdist,
max_error_strict=max_error_strict,
min_pct_change=min_pct_change,
max_pct_change=max_pct_change,
min_similarity=min_similarity,
minp=minp,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Mapping ############# #
@hybrid_method
def map(
cls_or_self,
map_func_nb: tp.Union[str, tp.MapFunc, tp.MapMetaFunc],
*args,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.map_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.map_meta_nb`.
Usage:
* Using regular function:
```pycon
>>> prod_nb = njit(lambda a, x: a * x)
>>> df.vbt.map(prod_nb, 10)
a b c
2020-01-01 10 50 10
2020-01-02 20 40 20
2020-01-03 30 30 30
2020-01-04 40 20 20
2020-01-05 50 10 10
```
* Using meta function:
```pycon
>>> diff_meta_nb = njit(lambda i, col, a, b: a[i, col] / b[i, col])
>>> vbt.pd_acc.map(
... diff_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper
... )
a b c
2020-01-01 0.000000 0.666667 0.000000
2020-01-02 0.333333 0.600000 0.333333
2020-01-03 0.500000 0.500000 0.500000
2020-01-04 0.600000 0.333333 0.333333
2020-01-05 0.666667 0.000000 0.000000
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.map(
... diff_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
a b c
2020-01-01 1.0 0.5 0.333333
2020-01-02 2.0 1.0 0.666667
2020-01-03 3.0 1.5 1.000000
2020-01-04 4.0 2.0 1.333333
2020-01-05 5.0 2.5 1.666667
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(map_func_nb, str):
map_func_nb = getattr(nb, map_func_nb + "_map_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.map_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d, map_func_nb, *args)
else:
func = jit_reg.resolve_option(nb.map_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), map_func_nb, *args)
if wrapper is None:
wrapper = cls_or_self.wrapper
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Applying ############# #
@hybrid_method
def apply_along_axis(
cls_or_self,
apply_func_nb: tp.Union[str, tp.ApplyFunc, tp.ApplyMetaFunc],
*args,
axis: int = 1,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.apply_nb` for `axis=1` and
`vectorbtpro.generic.nb.apply_reduce.row_apply_nb` for `axis=0`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.apply_meta_nb`
for `axis=1` and `vectorbtpro.generic.nb.apply_reduce.row_apply_meta_nb` for `axis=0`.
Usage:
* Using regular function:
```pycon
>>> power_nb = njit(lambda a: np.power(a, 2))
>>> df.vbt.apply_along_axis(power_nb)
a b c
2020-01-01 1 25 1
2020-01-02 4 16 4
2020-01-03 9 9 9
2020-01-04 16 4 4
2020-01-05 25 1 1
```
* Using meta function:
```pycon
>>> ratio_meta_nb = njit(lambda col, a, b: a[:, col] / b[:, col])
>>> vbt.pd_acc.apply_along_axis(
... ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper
... )
a b c
2020-01-01 0.000000 0.666667 0.000000
2020-01-02 0.333333 0.600000 0.333333
2020-01-03 0.500000 0.500000 0.500000
2020-01-04 0.600000 0.333333 0.333333
2020-01-05 0.666667 0.000000 0.000000
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.apply_along_axis(
... ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
a b c
2020-01-01 1.0 0.5 0.333333
2020-01-02 2.0 1.0 0.666667
2020-01-03 3.0 1.5 1.000000
2020-01-04 4.0 2.0 1.333333
2020-01-05 5.0 2.5 1.666667
```
"""
checks.assert_in(axis, (0, 1))
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(apply_func_nb, str):
apply_func_nb = getattr(nb, apply_func_nb + "_apply_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper, axis=axis), template_context)
args = substitute_templates(args, template_context, eval_id="args")
if axis == 0:
func = jit_reg.resolve_option(nb.row_apply_meta_nb, jitted)
else:
func = jit_reg.resolve_option(nb.apply_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d, apply_func_nb, *args)
else:
if axis == 0:
func = jit_reg.resolve_option(nb.row_apply_nb, jitted)
else:
func = jit_reg.resolve_option(nb.apply_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), apply_func_nb, *args)
if wrapper is None:
wrapper = cls_or_self.wrapper
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def row_apply(self, *args, **kwargs) -> tp.SeriesFrame:
"""`GenericAccessor.apply_along_axis` with `axis=0`."""
return self.apply_along_axis(*args, axis=0, **kwargs)
@hybrid_method
def column_apply(self, *args, **kwargs) -> tp.SeriesFrame:
"""`GenericAccessor.apply_along_axis` with `axis=1`."""
return self.apply_along_axis(*args, axis=1, **kwargs)
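# As the docstring above explains, `column_apply` runs `apply_func_nb` once per column
# (axis=1) while `row_apply` runs it once per row (axis=0). A sketch using the example
# `df` and `power_nb` from `apply_along_axis`:
#
#   >>> df.vbt.column_apply(power_nb)   # equivalent to apply_along_axis(power_nb, axis=1)
#   >>> df.vbt.row_apply(power_nb)      # equivalent to apply_along_axis(power_nb, axis=0)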
# ############# Reducing ############# #
@hybrid_method
def rolling_apply(
cls_or_self,
window: tp.Optional[tp.FrequencyLike],
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.RangeReduceMetaFunc],
*args,
minp: tp.Optional[int] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.rolling_reduce_nb` for integer windows
and `vectorbtpro.generic.nb.apply_reduce.rolling_freq_reduce_nb` for frequency windows.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.rolling_reduce_meta_nb`
for integer windows and `vectorbtpro.generic.nb.apply_reduce.rolling_freq_reduce_meta_nb` for
frequency windows.
If `window` is None, it will become an expanding window.
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.rolling_apply(3, mean_nb)
a b c
2020-01-01 NaN NaN NaN
2020-01-02 NaN NaN NaN
2020-01-03 2.0 4.0 2.000000
2020-01-04 3.0 3.0 2.333333
2020-01-05 4.0 2.0 2.000000
```
* Using a frequency-based window:
```pycon
>>> df.vbt.rolling_apply("3d", mean_nb)
a b c
2020-01-01 1.0 5.0 1.000000
2020-01-02 1.5 4.5 1.500000
2020-01-03 2.0 4.0 2.000000
2020-01-04 3.0 3.0 2.333333
2020-01-05 4.0 2.0 2.000000
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\
... np.mean(a[from_i:to_i, col]) / np.mean(b[from_i:to_i, col]))
>>> vbt.pd_acc.rolling_apply(
... 3,
... mean_ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper,
... )
a b c
2020-01-01 NaN NaN NaN
2020-01-02 NaN NaN NaN
2020-01-03 0.333333 0.600000 0.333333
2020-01-04 0.500000 0.500000 0.400000
2020-01-05 0.600000 0.333333 0.333333
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.rolling_apply(
... 2,
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
a b c
2020-01-01 NaN NaN NaN
2020-01-02 1.5 0.75 0.500000
2020-01-03 2.5 1.25 0.833333
2020-01-04 3.5 1.75 1.166667
2020-01-05 4.5 2.25 1.500000
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
if window is not None:
if not isinstance(window, int):
window = dt.to_timedelta64(window)
if minp is None and window is None:
minp = 1
if window is None:
window = wrapper.shape[0]
if minp is None:
minp = window
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
template_context = merge_dicts(
broadcast_named_args,
dict(wrapper=wrapper, window=window, minp=minp),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
if isinstance(window, int):
func = jit_reg.resolve_option(nb.rolling_reduce_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d, window, minp, reduce_func_nb, *args)
else:
func = jit_reg.resolve_option(nb.rolling_freq_reduce_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d[1], wrapper.index.values, window, reduce_func_nb, *args)
else:
if isinstance(window, int):
func = jit_reg.resolve_option(nb.rolling_reduce_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), window, minp, reduce_func_nb, *args)
else:
func = jit_reg.resolve_option(nb.rolling_freq_reduce_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.index.values, cls_or_self.to_2d_array(), window, reduce_func_nb, *args)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def expanding_apply(cls_or_self, *args, **kwargs) -> tp.SeriesFrame:
"""`GenericAccessor.rolling_apply` but expanding."""
return cls_or_self.rolling_apply(None, *args, **kwargs)
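# An illustrative expanding reduction, assuming the example `df` and `mean_nb` from the
# `rolling_apply` docstring; equivalent to `rolling_apply(None, ...)`, which defaults to `minp=1`:
#
#   >>> df.vbt.expanding_apply(mean_nb)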
@hybrid_method
def groupby_apply(
cls_or_self,
by: tp.AnyGroupByLike,
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.GroupByReduceMetaFunc],
*args,
groupby_kwargs: tp.KwargsLike = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.groupby_reduce_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.groupby_reduce_meta_nb`.
Argument `by` can be an instance of `vectorbtpro.base.grouping.base.Grouper`,
`pandas.core.groupby.GroupBy`, `pandas.core.resample.Resampler`, or any other groupby-like
object accepted by `vectorbtpro.base.grouping.base.Grouper`; if that fails, it is passed
to `pd.DataFrame.groupby` with `groupby_kwargs` as keyword arguments.
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.groupby_apply([1, 1, 2, 2, 3], mean_nb)
a b c
1 1.5 4.5 1.5
2 3.5 2.5 2.5
3 5.0 1.0 1.0
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda idxs, group, col, a, b: \\
... np.mean(a[idxs, col]) / np.mean(b[idxs, col]))
>>> vbt.pd_acc.groupby_apply(
... [1, 1, 2, 2, 3],
... mean_ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper
... )
a b c
1 0.200000 0.636364 0.200000
2 0.555556 0.428571 0.428571
3 0.666667 0.000000 0.000000
```
* Using templates and broadcasting, let's split both input arrays into 2 groups of rows and
run the calculation function on each group:
```pycon
>>> from vectorbtpro.base.grouping.nb import group_by_evenly_nb
>>> vbt.pd_acc.groupby_apply(
... vbt.RepEval('group_by_evenly_nb(wrapper.shape[0], 2)'),
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... ),
... template_context=dict(group_by_evenly_nb=group_by_evenly_nb)
... )
a b c
0 2.0 1.00 0.666667
1 4.5 2.25 1.500000
```
The advantage of the approach above is its flexibility: we can pass two arrays of any
broadcastable shape, and everything else is done for us.
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
by = substitute_templates(by, template_context, eval_id="by")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
grouper = wrapper.get_index_grouper(by, **resolve_dict(groupby_kwargs))
if isinstance(cls_or_self, type):
group_map = grouper.get_group_map()
template_context = merge_dicts(dict(by=by, grouper=grouper), template_context)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.groupby_reduce_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d[1], group_map, reduce_func_nb, *args)
else:
group_map = grouper.get_group_map()
func = jit_reg.resolve_option(nb.groupby_reduce_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), group_map, reduce_func_nb, *args)
wrap_kwargs = merge_dicts(dict(name_or_index=grouper.get_index()), wrap_kwargs)
return wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
@hybrid_method
def groupby_transform(
cls_or_self,
by: tp.AnyGroupByLike,
transform_func_nb: tp.Union[str, tp.GroupByTransformFunc, tp.GroupByTransformMetaFunc],
*args,
groupby_kwargs: tp.KwargsLike = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.groupby_transform_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.groupby_transform_meta_nb`.
For argument `by`, see `GenericAccessor.groupby_apply`.
Usage:
* Using regular function:
```pycon
>>> zscore_nb = njit(lambda a: (a - np.nanmean(a)) / np.nanstd(a))
>>> df.vbt.groupby_transform([1, 1, 2, 2, 3], zscore_nb)
a b c
2020-01-01 -1.000000 1.666667 -1.000000
2020-01-02 -0.333333 1.000000 -0.333333
2020-01-03 0.242536 0.242536 0.242536
2020-01-04 1.697749 -1.212678 -1.212678
2020-01-05 1.414214 -0.707107 -0.707107
```
* Using meta function:
```pycon
>>> zscore_ratio_meta_nb = njit(lambda idxs, group, a, b: \\
... zscore_nb(a[idxs]) / zscore_nb(b[idxs]))
>>> vbt.pd_acc.groupby_transform(
... [1, 1, 2, 2, 3],
... zscore_ratio_meta_nb,
... df.vbt.to_2d_array(),
... df.vbt.to_2d_array()[::-1],
... wrapper=df.vbt.wrapper
... )
a b c
2020-01-01 -0.600000 -1.666667 1.0
2020-01-02 -0.333333 -3.000000 1.0
2020-01-03 1.000000 1.000000 1.0
2020-01-04 -1.400000 -0.714286 1.0
2020-01-05 -2.000000 -0.500000 1.0
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if wrap_kwargs is None:
wrap_kwargs = {}
if isinstance(transform_func_nb, str):
transform_func_nb = getattr(nb, transform_func_nb + "_transform_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
by = substitute_templates(by, template_context, eval_id="by")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
grouper = wrapper.get_index_grouper(by, **resolve_dict(groupby_kwargs))
if isinstance(cls_or_self, type):
group_map = grouper.get_group_map()
template_context = merge_dicts(dict(by=by, grouper=grouper), template_context)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.groupby_transform_meta_nb, jitted)
out = func(wrapper.shape_2d, group_map, transform_func_nb, *args)
else:
group_map = grouper.get_group_map()
func = jit_reg.resolve_option(nb.groupby_transform_nb, jitted)
out = func(cls_or_self.to_2d_array(), group_map, transform_func_nb, *args)
return wrapper.wrap(out, group_by=False, **wrap_kwargs)
@hybrid_method
def resample_apply(
cls_or_self,
rule: tp.AnyRuleLike,
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.GroupByReduceMetaFunc, tp.RangeReduceMetaFunc],
*args,
use_groupby_apply: bool = False,
freq: tp.Optional[tp.FrequencyLike] = None,
resample_kwargs: tp.KwargsLike = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Resample.
Argument `rule` can be an instance of `vectorbtpro.base.resampling.base.Resampler`,
`pandas.core.resample.Resampler`, or any other frequency-like object that can be accepted by
`pd.DataFrame.resample` with `resample_kwargs` passed as keyword arguments.
If `use_groupby_apply` is True, uses `GenericAccessor.groupby_apply` (with some post-processing).
Otherwise, uses `GenericAccessor.resample_to_index`.
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.resample_apply('2d', mean_nb)
a b c
2020-01-01 1.5 4.5 1.5
2020-01-03 3.5 2.5 2.5
2020-01-05 5.0 1.0 1.0
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda idxs, group, col, a, b: \\
... np.mean(a[idxs, col]) / np.mean(b[idxs, col]))
>>> vbt.pd_acc.resample_apply(
... '2d',
... mean_ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper
... )
a b c
2020-01-01 0.200000 0.636364 0.200000
2020-01-03 0.555556 0.428571 0.428571
2020-01-05 0.666667 0.000000 0.000000
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.resample_apply(
... '2d',
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
a b c
2020-01-01 1.5 0.75 0.500000
2020-01-03 3.5 1.75 1.166667
2020-01-05 5.0 2.50 1.666667
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
rule = substitute_templates(rule, template_context, eval_id="rule")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
if use_groupby_apply:
if isinstance(rule, Resampler):
raise TypeError("Resampler cannot be used with use_groupby_apply=True")
if not isinstance(rule, PandasResampler):
rule = pd.Series(index=wrapper.index, dtype=object).resample(rule, **resolve_dict(resample_kwargs))
out_obj = cls_or_self.groupby_apply(
rule,
reduce_func_nb,
*args,
template_context=template_context,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
new_index = rule.count().index.rename("group")
if pd.Index.equals(out_obj.index, new_index):
if new_index.freq is not None:
try:
out_obj.index.freq = new_index.freq
except ValueError as e:
pass
return out_obj
resampled_arr = np.full((rule.ngroups, wrapper.shape_2d[1]), np.nan)
resampled_obj = wrapper.wrap(
resampled_arr,
index=new_index,
**resolve_dict(wrap_kwargs),
)
resampled_obj.loc[out_obj.index] = out_obj.values
return resampled_obj
if not isinstance(rule, Resampler):
rule = wrapper.get_resampler(
rule,
freq=freq,
resample_kwargs=resample_kwargs,
return_pd_resampler=False,
)
return cls_or_self.resample_to_index(
rule,
reduce_func_nb,
*args,
template_context=template_context,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
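# An illustrative call of the groupby-based code path, assuming the example `df` and
# `mean_nb` from the docstring above; the result should match the default
# `resample_to_index` path, only the implementation differs:
#
#   >>> df.vbt.resample_apply('2d', mean_nb, use_groupby_apply=True)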
@hybrid_method
def apply_and_reduce(
cls_or_self,
apply_func_nb: tp.Union[str, tp.ApplyFunc, tp.ApplyMetaFunc],
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.ReduceMetaFunc],
apply_args: tp.Optional[tuple] = None,
reduce_args: tp.Optional[tuple] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""See `vectorbtpro.generic.nb.apply_reduce.apply_and_reduce_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.apply_and_reduce_meta_nb`.
Usage:
* Using regular function:
```pycon
>>> greater_nb = njit(lambda a: a[a > 2])
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.apply_and_reduce(greater_nb, mean_nb)
a 4.0
b 4.0
c 3.0
Name: apply_and_reduce, dtype: float64
```
* Using meta function:
```pycon
>>> and_meta_nb = njit(lambda col, a, b: a[:, col] & b[:, col])
>>> sum_meta_nb = njit(lambda col, x: np.sum(x))
>>> vbt.pd_acc.apply_and_reduce(
... and_meta_nb,
... sum_meta_nb,
... apply_args=(
... df.vbt.to_2d_array() > 1,
... df.vbt.to_2d_array() < 4
... ),
... wrapper=df.vbt.wrapper
... )
a 2
b 2
c 3
Name: apply_and_reduce, dtype: int64
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.apply_and_reduce(
... and_meta_nb,
... sum_meta_nb,
... apply_args=(
... vbt.Rep('mask_a'),
... vbt.Rep('mask_b')
... ),
... broadcast_named_args=dict(
... mask_a=pd.Series([True, True, True, False, False], index=df.index),
... mask_b=pd.DataFrame([[True, True, False]], columns=['a', 'b', 'c'])
... )
... )
a 3
b 3
c 0
Name: apply_and_reduce, dtype: int64
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(apply_func_nb, str):
apply_func_nb = getattr(nb, apply_func_nb + "_apply_nb")
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if apply_args is None:
apply_args = ()
if reduce_args is None:
reduce_args = ()
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
apply_args = substitute_templates(apply_args, template_context, eval_id="apply_args")
reduce_args = substitute_templates(reduce_args, template_context, eval_id="reduce_args")
func = jit_reg.resolve_option(nb.apply_and_reduce_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d[1], apply_func_nb, apply_args, reduce_func_nb, reduce_args)
else:
func = jit_reg.resolve_option(nb.apply_and_reduce_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), apply_func_nb, apply_args, reduce_func_nb, reduce_args)
if wrapper is None:
wrapper = cls_or_self.wrapper
wrap_kwargs = merge_dicts(dict(name_or_index="apply_and_reduce"), wrap_kwargs)
return wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
@hybrid_method
def reduce(
cls_or_self,
reduce_func_nb: tp.Union[
str,
tp.ReduceFunc,
tp.ReduceMetaFunc,
tp.ReduceToArrayFunc,
tp.ReduceToArrayMetaFunc,
tp.ReduceGroupedFunc,
tp.ReduceGroupedMetaFunc,
tp.ReduceGroupedToArrayFunc,
tp.ReduceGroupedToArrayMetaFunc,
],
*args,
returns_array: bool = False,
returns_idx: bool = False,
flatten: bool = False,
order: str = "C",
to_index: bool = True,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeriesFrame:
"""Reduce by column/group.
Set `flatten` to True when working with grouped data to pass a flattened array to `reduce_func_nb`.
The order in which to flatten the array can be specified using `order`.
Set `returns_array` to True if `reduce_func_nb` returns an array.
Set `returns_idx` to True if `reduce_func_nb` returns row index/position.
Set `to_index` to True to return labels instead of positions.
For implementation details, see
* `vectorbtpro.generic.nb.apply_reduce.reduce_flat_grouped_to_array_nb` if grouped, `returns_array` is True, and `flatten` is True
* `vectorbtpro.generic.nb.apply_reduce.reduce_flat_grouped_nb` if grouped, `returns_array` is False, and `flatten` is True
* `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_to_array_nb` if grouped, `returns_array` is True, and `flatten` is False
* `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_nb` if grouped, `returns_array` is False, and `flatten` is False
* `vectorbtpro.generic.nb.apply_reduce.reduce_to_array_nb` if not grouped and `returns_array` is True
* `vectorbtpro.generic.nb.apply_reduce.reduce_nb` if not grouped and `returns_array` is False
For implementation details on the meta versions, see
* `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_to_array_meta_nb` if grouped and `returns_array` is True
* `vectorbtpro.generic.nb.apply_reduce.reduce_grouped_meta_nb` if grouped and `returns_array` is False
* `vectorbtpro.generic.nb.apply_reduce.reduce_to_array_meta_nb` if not grouped and `returns_array` is True
* `vectorbtpro.generic.nb.apply_reduce.reduce_meta_nb` if not grouped and `returns_array` is False
`reduce_func_nb` can be a string denoting the suffix of a reducing function
from `vectorbtpro.generic.nb`. For example, "sum" will refer to "sum_reduce_nb".
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.reduce(mean_nb)
a 3.0
b 3.0
c 1.8
Name: reduce, dtype: float64
>>> argmax_nb = njit(lambda a: np.argmax(a))
>>> df.vbt.reduce(argmax_nb, returns_idx=True)
a 2020-01-05
b 2020-01-01
c 2020-01-03
Name: reduce, dtype: datetime64[ns]
>>> df.vbt.reduce(argmax_nb, returns_idx=True, to_index=False)
a 4
b 0
c 2
Name: reduce, dtype: int64
>>> min_max_nb = njit(lambda a: np.array([np.nanmin(a), np.nanmax(a)]))
>>> df.vbt.reduce(min_max_nb, returns_array=True, wrap_kwargs=dict(name_or_index=['min', 'max']))
a b c
min 1 1 1
max 5 5 3
>>> group_by = pd.Series(['first', 'first', 'second'], name='group')
>>> df.vbt.reduce(mean_nb, group_by=group_by)
group
first 3.0
second 1.8
dtype: float64
```
* Using meta function:
```pycon
>>> mean_meta_nb = njit(lambda col, a: np.nanmean(a[:, col]))
>>> pd.Series.vbt.reduce(
... mean_meta_nb,
... df['a'].vbt.to_2d_array(),
... wrapper=df['a'].vbt.wrapper
... )
3.0
>>> vbt.pd_acc.reduce(
... mean_meta_nb,
... df.vbt.to_2d_array(),
... wrapper=df.vbt.wrapper
... )
a 3.0
b 3.0
c 1.8
Name: reduce, dtype: float64
>>> grouped_mean_meta_nb = njit(lambda group_idxs, group, a: np.nanmean(a[:, group_idxs]))
>>> group_by = pd.Series(['first', 'first', 'second'], name='group')
>>> vbt.pd_acc.reduce(
... grouped_mean_meta_nb,
... df.vbt.to_2d_array(),
... wrapper=df.vbt.wrapper,
... group_by=group_by
... )
group
first 3.0
second 1.8
Name: reduce, dtype: float64
```
* Using templates and broadcasting:
```pycon
>>> mean_a_b_nb = njit(lambda col, a, b: \\
... np.array([np.nanmean(a[:, col]), np.nanmean(b[:, col])]))
>>> vbt.pd_acc.reduce(
... mean_a_b_nb,
... vbt.Rep('arr1'),
... vbt.Rep('arr2'),
... returns_array=True,
... broadcast_named_args=dict(
... arr1=pd.Series([1, 2, 3, 4, 5], index=df.index),
... arr2=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... ),
... wrap_kwargs=dict(name_or_index=['arr1', 'arr2'])
... )
a b c
arr1 3.0 3.0 3.0
arr2 1.0 2.0 3.0
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(
broadcast_named_args,
dict(
wrapper=wrapper,
group_by=group_by,
returns_array=returns_array,
returns_idx=returns_idx,
flatten=flatten,
order=order,
),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
if wrapper.grouper.is_grouped(group_by=group_by):
group_map = wrapper.grouper.get_group_map(group_by=group_by)
if returns_array:
func = jit_reg.resolve_option(nb.reduce_grouped_to_array_meta_nb, jitted)
else:
func = jit_reg.resolve_option(nb.reduce_grouped_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(group_map, reduce_func_nb, *args)
else:
if returns_array:
func = jit_reg.resolve_option(nb.reduce_to_array_meta_nb, jitted)
else:
func = jit_reg.resolve_option(nb.reduce_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d[1], reduce_func_nb, *args)
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
if wrapper.grouper.is_grouped(group_by=group_by):
group_map = wrapper.grouper.get_group_map(group_by=group_by)
if flatten:
checks.assert_in(order.upper(), ["C", "F"])
in_c_order = order.upper() == "C"
if returns_array:
func = jit_reg.resolve_option(nb.reduce_flat_grouped_to_array_nb, jitted)
else:
func = jit_reg.resolve_option(nb.reduce_flat_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), group_map, in_c_order, reduce_func_nb, *args)
if returns_idx:
if in_c_order:
out //= group_map[1] # flattened in C order
else:
out %= wrapper.shape[0] # flattened in F order
else:
if returns_array:
func = jit_reg.resolve_option(nb.reduce_grouped_to_array_nb, jitted)
else:
func = jit_reg.resolve_option(nb.reduce_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), group_map, reduce_func_nb, *args)
else:
if returns_array:
func = jit_reg.resolve_option(nb.reduce_to_array_nb, jitted)
else:
func = jit_reg.resolve_option(nb.reduce_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), reduce_func_nb, *args)
wrap_kwargs = merge_dicts(
dict(
name_or_index="reduce" if not returns_array else None,
to_index=returns_idx and to_index,
fillna=-1 if returns_idx else None,
dtype=int_ if returns_idx else None,
),
wrap_kwargs,
)
return wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs)
@hybrid_method
def proximity_apply(
cls_or_self,
window: int,
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.ProximityReduceMetaFunc],
*args,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Frame:
"""See `vectorbtpro.generic.nb.apply_reduce.proximity_reduce_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.proximity_reduce_meta_nb`.
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> df.vbt.proximity_apply(1, mean_nb)
a b c
2020-01-01 3.0 2.500000 3.000000
2020-01-02 3.0 2.666667 3.000000
2020-01-03 3.0 2.777778 2.666667
2020-01-04 3.0 2.666667 2.000000
2020-01-05 3.0 2.500000 1.500000
```
* Using meta function:
```pycon
>>> @njit
... def mean_ratio_meta_nb(from_i, to_i, from_col, to_col, a, b):
... a_mean = np.mean(a[from_i:to_i, from_col:to_col])
... b_mean = np.mean(b[from_i:to_i, from_col:to_col])
... return a_mean / b_mean
>>> vbt.pd_acc.proximity_apply(
... 1,
... mean_ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper,
... )
a b c
2020-01-01 0.5 0.428571 0.500000
2020-01-02 0.5 0.454545 0.500000
2020-01-03 0.5 0.470588 0.454545
2020-01-04 0.5 0.454545 0.333333
2020-01-05 0.5 0.428571 0.200000
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.proximity_apply(
... 1,
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... )
... )
a b c
2020-01-01 1.000000 0.75 0.6
2020-01-02 1.333333 1.00 0.8
2020-01-03 2.000000 1.50 1.2
2020-01-04 2.666667 2.00 1.6
2020-01-05 3.000000 2.25 1.8
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
template_context = merge_dicts(
broadcast_named_args,
dict(wrapper=wrapper, window=window),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.proximity_reduce_meta_nb, jitted)
out = func(wrapper.shape_2d, window, reduce_func_nb, *args)
else:
func = jit_reg.resolve_option(nb.proximity_reduce_nb, jitted)
out = func(cls_or_self.to_2d_array(), window, reduce_func_nb, *args)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Squeezing ############# #
@hybrid_method
def squeeze_grouped(
cls_or_self,
squeeze_func_nb: tp.Union[str, tp.ReduceFunc, tp.GroupSqueezeMetaFunc],
*args,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Squeeze each group of columns into a single column.
See `vectorbtpro.generic.nb.apply_reduce.squeeze_grouped_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.squeeze_grouped_meta_nb`.
Usage:
* Using regular function:
```pycon
>>> mean_nb = njit(lambda a: np.nanmean(a))
>>> group_by = pd.Series(['first', 'first', 'second'], name='group')
>>> df.vbt.squeeze_grouped(mean_nb, group_by=group_by)
group first second
2020-01-01 3.0 1.0
2020-01-02 3.0 2.0
2020-01-03 3.0 3.0
2020-01-04 3.0 2.0
2020-01-05 3.0 1.0
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda i, group_idxs, group, a, b: \\
... np.mean(a[i][group_idxs]) / np.mean(b[i][group_idxs]))
>>> vbt.pd_acc.squeeze_grouped(
... mean_ratio_meta_nb,
... df.vbt.to_2d_array() - 1,
... df.vbt.to_2d_array() + 1,
... wrapper=df.vbt.wrapper,
... group_by=group_by
... )
group first second
2020-01-01 0.5 0.000000
2020-01-02 0.5 0.333333
2020-01-03 0.5 0.500000
2020-01-04 0.5 0.333333
2020-01-05 0.5 0.000000
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.squeeze_grouped(
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=pd.Series([1, 2, 3, 4, 5], index=df.index),
... b=pd.DataFrame([[1, 2, 3]], columns=['a', 'b', 'c'])
... ),
... group_by=[0, 0, 1]
... )
0 1
2020-01-01 0.666667 0.333333
2020-01-02 1.333333 0.666667
2020-01-03 2.000000 1.000000
2020-01-04 2.666667 1.333333
2020-01-05 3.333333 1.666667
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(squeeze_func_nb, str):
squeeze_func_nb = getattr(nb, squeeze_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(
broadcast_named_args,
dict(wrapper=wrapper, group_by=group_by),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
if not wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping required")
group_map = wrapper.grouper.get_group_map(group_by=group_by)
func = jit_reg.resolve_option(nb.squeeze_grouped_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(wrapper.shape_2d[0], group_map, squeeze_func_nb, *args)
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
if not wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping required")
group_map = wrapper.grouper.get_group_map(group_by=group_by)
func = jit_reg.resolve_option(nb.squeeze_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.to_2d_array(), group_map, squeeze_func_nb, *args)
return wrapper.wrap(out, group_by=group_by, **resolve_dict(wrap_kwargs))
# ############# Flattening ############# #
def flatten_grouped(
self,
order: str = "C",
jitted: tp.JittedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Flatten each group of columns.
See `vectorbtpro.generic.nb.apply_reduce.flatten_grouped_nb`.
If all groups have the same length, see `vectorbtpro.generic.nb.apply_reduce.flatten_uniform_grouped_nb`.
!!! warning
Make sure that the distribution of group lengths is close to uniform; otherwise,
groups with fewer columns will be filled with NaN and needlessly occupy memory.
Usage:
```pycon
>>> group_by = pd.Series(['first', 'first', 'second'], name='group')
>>> df.vbt.flatten_grouped(group_by=group_by, order='C')
group first second
2020-01-01 1.0 1.0
2020-01-01 5.0 NaN
2020-01-02 2.0 2.0
2020-01-02 4.0 NaN
2020-01-03 3.0 3.0
2020-01-03 3.0 NaN
2020-01-04 4.0 2.0
2020-01-04 2.0 NaN
2020-01-05 5.0 1.0
2020-01-05 1.0 NaN
>>> df.vbt.flatten_grouped(group_by=group_by, order='F')
group first second
2020-01-01 1.0 1.0
2020-01-02 2.0 2.0
2020-01-03 3.0 3.0
2020-01-04 4.0 2.0
2020-01-05 5.0 1.0
2020-01-01 5.0 NaN
2020-01-02 4.0 NaN
2020-01-03 3.0 NaN
2020-01-04 2.0 NaN
2020-01-05 1.0 NaN
```
"""
if not self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping required")
checks.assert_in(order.upper(), ["C", "F"])
group_map = self.wrapper.grouper.get_group_map(group_by=group_by)
if np.all(group_map[1] == group_map[1].item(0)):
func = jit_reg.resolve_option(nb.flatten_uniform_grouped_nb, jitted)
else:
func = jit_reg.resolve_option(nb.flatten_grouped_nb, jitted)
if order.upper() == "C":
out = func(self.to_2d_array(), group_map, True)
new_index = indexes.repeat_index(self.wrapper.index, np.max(group_map[1]))
else:
out = func(self.to_2d_array(), group_map, False)
new_index = indexes.tile_index(self.wrapper.index, np.max(group_map[1]))
wrap_kwargs = merge_dicts(dict(index=new_index), wrap_kwargs)
return self.wrapper.wrap(out, group_by=group_by, **wrap_kwargs)
# ############# Resampling ############# #
def realign(
self,
index: tp.AnyRuleLike,
freq: tp.Union[None, bool, tp.FrequencyLike] = None,
nan_value: tp.Optional[tp.Scalar] = None,
ffill: bool = True,
source_rbound: tp.Union[bool, str, tp.IndexLike] = False,
target_rbound: tp.Union[bool, str, tp.IndexLike] = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.MaybeSeriesFrame:
"""See `vectorbtpro.generic.nb.base.realign_nb`.
`index` can be either an instance of `vectorbtpro.base.resampling.base.Resampler`
or any index-like object.
Gives the same results as `df.resample(closed='right', label='right').last().ffill()`
when applied to the target index of the resampler.
Usage:
* Downsampling:
```pycon
>>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h')
>>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d')
>>> h_sr = pd.Series(range(len(h_index)), index=h_index)
>>> h_sr.vbt.realign(d_index)
2020-01-01 0.0
2020-01-02 24.0
2020-01-03 48.0
2020-01-04 72.0
2020-01-05 96.0
Freq: D, dtype: float64
```
* Upsampling:
```pycon
>>> d_sr = pd.Series(range(len(d_index)), index=d_index)
>>> d_sr.vbt.realign(h_index)
2020-01-01 00:00:00 0.0
2020-01-01 01:00:00 0.0
2020-01-01 02:00:00 0.0
2020-01-01 03:00:00 0.0
2020-01-01 04:00:00 0.0
... ...
2020-01-04 20:00:00 3.0
2020-01-04 21:00:00 3.0
2020-01-04 22:00:00 3.0
2020-01-04 23:00:00 3.0
2020-01-05 00:00:00 4.0
Freq: H, Length: 97, dtype: float64
```
"""
resampler = self.wrapper.get_resampler(
index,
freq=freq,
return_pd_resampler=False,
silence_warnings=silence_warnings,
)
one_index = False
if len(resampler.target_index) == 1 and checks.is_dt_like(index):
if isinstance(index, str):
try:
dt.to_freq(index)
one_index = False
except Exception as e:
one_index = True
else:
one_index = True
if isinstance(source_rbound, bool):
use_source_rbound = source_rbound
else:
use_source_rbound = False
if isinstance(source_rbound, str):
if source_rbound == "pandas":
resampler = resampler.replace(source_index=resampler.source_rbound_index)
else:
raise ValueError(f"Invalid source_rbound: '{source_rbound}'")
else:
resampler = resampler.replace(source_index=source_rbound)
if isinstance(target_rbound, bool):
use_target_rbound = target_rbound
index = resampler.target_index
else:
use_target_rbound = False
index = resampler.target_index
if isinstance(target_rbound, str):
if target_rbound == "pandas":
resampler = resampler.replace(target_index=resampler.target_rbound_index)
else:
raise ValueError(f"Invalid target_rbound: '{target_rbound}'")
else:
resampler = resampler.replace(target_index=target_rbound)
if not use_source_rbound:
source_freq = None
else:
source_freq = resampler.get_np_source_freq()
if not use_target_rbound:
target_freq = None
else:
target_freq = resampler.get_np_target_freq()
if nan_value is None:
if self.mapping is not None:
nan_value = -1
else:
nan_value = np.nan
func = jit_reg.resolve_option(nb.realign_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
resampler.source_index.values,
resampler.target_index.values,
source_freq=source_freq,
target_freq=target_freq,
source_rbound=use_source_rbound,
target_rbound=use_target_rbound,
nan_value=nan_value,
ffill=ffill,
)
wrap_kwargs = merge_dicts(dict(index=index), wrap_kwargs)
out = self.wrapper.wrap(out, group_by=False, **wrap_kwargs)
if one_index:
return out.iloc[0]
return out
def realign_opening(self, *args, **kwargs) -> tp.MaybeSeriesFrame:
"""`GenericAccessor.realign` but creating a resampler and using the left bound
of the source and target index."""
return self.realign(*args, source_rbound=False, target_rbound=False, **kwargs)
def realign_closing(self, *args, **kwargs) -> tp.MaybeSeriesFrame:
"""`GenericAccessor.realign` but creating a resampler and using the right bound
of the source and target index.
!!! note
The timestamps in the source and target index should denote the open time."""
return self.realign(*args, source_rbound=True, target_rbound=True, **kwargs)
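# An illustrative contrast of both helpers, assuming `h_sr` and `d_index` from the
# `realign` docstring; `realign_opening` keeps the left bound (open time) of both indexes,
# while `realign_closing` switches both to their right bounds:
#
#   >>> h_sr.vbt.realign_opening(d_index)
#   >>> h_sr.vbt.realign_closing(d_index)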
@hybrid_method
def resample_to_index(
cls_or_self,
index: tp.AnyRuleLike,
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.RangeReduceMetaFunc],
*args,
freq: tp.Union[None, bool, tp.FrequencyLike] = None,
before: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
) -> tp.SeriesFrame:
"""Resample solely based on target index.
Applies `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_nb` on index ranges
from `vectorbtpro.base.resampling.nb.map_index_to_source_ranges_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_meta_nb`.
Usage:
* Downsampling:
```pycon
>>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h')
>>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d')
>>> h_sr = pd.Series(range(len(h_index)), index=h_index)
>>> h_sr.vbt.resample_to_index(d_index, njit(lambda x: x.mean()))
2020-01-01 11.5
2020-01-02 35.5
2020-01-03 59.5
2020-01-04 83.5
2020-01-05 96.0
Freq: D, dtype: float64
>>> h_sr.vbt.resample_to_index(d_index, njit(lambda x: x.mean()), before=True)
2020-01-01 0.0
2020-01-02 12.5
2020-01-03 36.5
2020-01-04 60.5
2020-01-05 84.5
Freq: D, dtype: float64
```
* Upsampling:
```pycon
>>> d_sr = pd.Series(range(len(d_index)), index=d_index)
>>> d_sr.vbt.resample_to_index(h_index, njit(lambda x: x[-1]))
2020-01-01 00:00:00 0.0
2020-01-01 01:00:00 NaN
2020-01-01 02:00:00 NaN
2020-01-01 03:00:00 NaN
2020-01-01 04:00:00 NaN
... ...
2020-01-04 20:00:00 NaN
2020-01-04 21:00:00 NaN
2020-01-04 22:00:00 NaN
2020-01-04 23:00:00 NaN
2020-01-05 00:00:00 4.0
Freq: H, Length: 97, dtype: float64
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\
... np.mean(a[from_i:to_i][col]) / np.mean(b[from_i:to_i][col]))
>>> vbt.pd_acc.resample_to_index(
... d_index,
... mean_ratio_meta_nb,
... h_sr.vbt.to_2d_array() - 1,
... h_sr.vbt.to_2d_array() + 1,
... wrapper=h_sr.vbt.wrapper
... )
2020-01-01 -1.000000
2020-01-02 0.920000
2020-01-03 0.959184
2020-01-04 0.972603
2020-01-05 0.979381
Freq: D, dtype: float64
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.resample_to_index(
... d_index,
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=h_sr - 1,
... b=h_sr + 1
... )
... )
2020-01-01 -1.000000
2020-01-02 0.920000
2020-01-03 0.959184
2020-01-04 0.972603
2020-01-05 0.979381
Freq: D, dtype: float64
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
index = substitute_templates(index, template_context, eval_id="index")
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
resampler = wrapper.get_resampler(
index,
freq=freq,
return_pd_resampler=False,
silence_warnings=silence_warnings,
)
index_ranges = resampler.map_index_to_source_ranges(before=before, jitted=jitted)
if isinstance(cls_or_self, type):
template_context = merge_dicts(
dict(
resampler=resampler,
index_ranges=index_ranges,
),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.reduce_index_ranges_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
wrapper.shape_2d[1],
index_ranges[0],
index_ranges[1],
reduce_func_nb,
*args,
)
else:
func = jit_reg.resolve_option(nb.reduce_index_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
cls_or_self.to_2d_array(),
index_ranges[0],
index_ranges[1],
reduce_func_nb,
*args,
)
wrap_kwargs = merge_dicts(dict(index=resampler.target_index), wrap_kwargs)
return wrapper.wrap(out, group_by=False, **wrap_kwargs)
@hybrid_method
def resample_between_bounds(
cls_or_self,
target_lbound_index: tp.IndexLike,
target_rbound_index: tp.IndexLike,
reduce_func_nb: tp.Union[str, tp.ReduceFunc, tp.RangeReduceMetaFunc],
*args,
closed_lbound: bool = True,
closed_rbound: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_with_lbound: tp.Optional[bool] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Resample between target index bounds.
Applies `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_nb` on index ranges
from `vectorbtpro.base.resampling.nb.map_bounds_to_source_ranges_nb`.
For details on the meta version, see `vectorbtpro.generic.nb.apply_reduce.reduce_index_ranges_meta_nb`.
Usage:
* Using regular function:
```pycon
>>> h_index = pd.date_range('2020-01-01', '2020-01-05', freq='1h')
>>> d_index = pd.date_range('2020-01-01', '2020-01-05', freq='1d')
>>> h_sr = pd.Series(range(len(h_index)), index=h_index)
>>> h_sr.vbt.resample_between_bounds(d_index, d_index.shift(), njit(lambda x: x.mean()))
2020-01-01 11.5
2020-01-02 35.5
2020-01-03 59.5
2020-01-04 83.5
2020-01-05 96.0
Freq: D, dtype: float64
```
* Using meta function:
```pycon
>>> mean_ratio_meta_nb = njit(lambda from_i, to_i, col, a, b: \\
... np.mean(a[from_i:to_i][col]) / np.mean(b[from_i:to_i][col]))
>>> vbt.pd_acc.resample_between_bounds(
... d_index,
... d_index.shift(),
... mean_ratio_meta_nb,
... h_sr.vbt.to_2d_array() - 1,
... h_sr.vbt.to_2d_array() + 1,
... wrapper=h_sr.vbt.wrapper
... )
2020-01-01 -1.000000
2020-01-02 0.920000
2020-01-03 0.959184
2020-01-04 0.972603
2020-01-05 0.979381
Freq: D, dtype: float64
```
* Using templates and broadcasting:
```pycon
>>> vbt.pd_acc.resample_between_bounds(
... d_index,
... d_index.shift(),
... mean_ratio_meta_nb,
... vbt.Rep('a'),
... vbt.Rep('b'),
... broadcast_named_args=dict(
... a=h_sr - 1,
... b=h_sr + 1
... )
... )
2020-01-01 -1.000000
2020-01-02 0.920000
2020-01-03 0.959184
2020-01-04 0.972603
2020-01-05 0.979381
Freq: D, dtype: float64
```
"""
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(nb, reduce_func_nb + "_reduce_nb")
if isinstance(cls_or_self, type):
if len(broadcast_named_args) > 0:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
if wrapper is not None:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=wrapper.shape_2d,
**broadcast_kwargs,
)
else:
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
template_context = merge_dicts(broadcast_named_args, dict(wrapper=wrapper), template_context)
target_lbound_index = substitute_templates(
target_lbound_index, template_context, eval_id="target_lbound_index"
)
target_rbound_index = substitute_templates(
target_rbound_index, template_context, eval_id="target_rbound_index"
)
else:
if wrapper is None:
wrapper = cls_or_self.wrapper
target_lbound_index = dt.prepare_dt_index(target_lbound_index)
target_rbound_index = dt.prepare_dt_index(target_rbound_index)
if len(target_lbound_index) == 1 and len(target_rbound_index) > 1:
target_lbound_index = repeat_index(target_lbound_index, len(target_rbound_index))
if wrap_with_lbound is None:
wrap_with_lbound = False
elif len(target_lbound_index) > 1 and len(target_rbound_index) == 1:
target_rbound_index = repeat_index(target_rbound_index, len(target_lbound_index))
if wrap_with_lbound is None:
wrap_with_lbound = True
index_ranges = Resampler.map_bounds_to_source_ranges(
source_index=wrapper.index.values,
target_lbound_index=target_lbound_index.values,
target_rbound_index=target_rbound_index.values,
closed_lbound=closed_lbound,
closed_rbound=closed_rbound,
skip_not_found=False,
jitted=jitted,
)
if isinstance(cls_or_self, type):
template_context = merge_dicts(
dict(
target_lbound_index=target_lbound_index,
target_rbound_index=target_rbound_index,
index_ranges=index_ranges,
),
template_context,
)
args = substitute_templates(args, template_context, eval_id="args")
func = jit_reg.resolve_option(nb.reduce_index_ranges_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
wrapper.shape_2d[1],
index_ranges[0],
index_ranges[1],
reduce_func_nb,
*args,
)
else:
func = jit_reg.resolve_option(nb.reduce_index_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
cls_or_self.to_2d_array(),
index_ranges[0],
index_ranges[1],
reduce_func_nb,
*args,
)
if wrap_with_lbound is None:
if closed_lbound:
wrap_with_lbound = True
elif closed_rbound:
wrap_with_lbound = False
else:
wrap_with_lbound = True
if wrap_with_lbound:
wrap_kwargs = merge_dicts(dict(index=target_lbound_index), wrap_kwargs)
else:
wrap_kwargs = merge_dicts(dict(index=target_rbound_index), wrap_kwargs)
return wrapper.wrap(out, group_by=False, **wrap_kwargs)
# ############# Describing ############# #
def min(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return min of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="min"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.min_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nanmin_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nanmin, axis=0)
else:
func = partial(nanmin, axis=0)
func = ch_reg.resolve_option(nb.nanmin_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def max(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return max of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="max"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.max_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nanmax_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nanmax, axis=0)
else:
func = partial(nanmax, axis=0)
func = ch_reg.resolve_option(nb.nanmax_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def mean(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return mean of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="mean"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.mean_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nanmean_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nanmean, axis=0)
else:
func = partial(nanmean, axis=0)
func = ch_reg.resolve_option(nb.nanmean_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def median(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return median of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="median"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.median_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nanmedian_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nanmedian, axis=0)
else:
func = partial(nanmedian, axis=0)
func = ch_reg.resolve_option(nb.nanmedian_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def std(
self,
ddof: int = 1,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return standard deviation of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="std"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.std_reduce_nb, jitted),
ddof,
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nanstd_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nanstd, axis=0)
else:
func = partial(nanstd, axis=0)
func = ch_reg.resolve_option(nb.nanstd_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr, ddof=ddof), group_by=False, **wrap_kwargs)
def sum(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return sum of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="sum"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.sum_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nansum_nb, jitted)
elif arr.dtype != int and arr.dtype != float:
# bottleneck can only consume int and float dtypes
func = partial(np.nansum, axis=0)
else:
func = partial(nansum, axis=0)
func = ch_reg.resolve_option(nb.nansum_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def count(
self,
use_jitted: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return count of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="count", dtype=int_), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.count_reduce_nb, jitted),
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
from vectorbtpro._settings import settings
generic_cfg = settings["generic"]
arr = self.to_2d_array()
if use_jitted is None:
use_jitted = generic_cfg["use_jitted"]
if use_jitted:
func = jit_reg.resolve_option(nb.nancnt_nb, jitted)
else:
func = lambda a: np.sum(~np.isnan(a), axis=0)
func = ch_reg.resolve_option(nb.nancnt_nb, chunked, target_func=func)
return self.wrapper.wrap_reduced(func(arr), group_by=False, **wrap_kwargs)
def cov(
self,
other: tp.SeriesFrame,
ddof: int = 1,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return covariance of non-NaN elements."""
self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
self_arr = reshaping.to_2d_array(self_obj)
other_arr = reshaping.to_2d_array(other_obj)
wrap_kwargs = merge_dicts(dict(name_or_index="cov"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return type(self).reduce(
jit_reg.resolve_option(nb.cov_reduce_grouped_meta_nb, jitted),
self_arr,
other_arr,
ddof,
flatten=True,
jitted=jitted,
chunked=chunked,
wrapper=ArrayWrapper.from_obj(self_obj),
group_by=self.wrapper.grouper.resolve_group_by(group_by=group_by),
wrap_kwargs=wrap_kwargs,
)
func = jit_reg.resolve_option(nb.nancov_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return self.wrapper.wrap_reduced(func(self_arr, other_arr, ddof=ddof), group_by=False, **wrap_kwargs)
def corr(
self,
other: tp.SeriesFrame,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return correlation coefficient of non-NaN elements."""
self_obj, other_obj = reshaping.broadcast(self.obj, other, **resolve_dict(broadcast_kwargs))
self_arr = reshaping.to_2d_array(self_obj)
other_arr = reshaping.to_2d_array(other_obj)
wrap_kwargs = merge_dicts(dict(name_or_index="corr"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return type(self).reduce(
jit_reg.resolve_option(nb.corr_reduce_grouped_meta_nb, jitted),
self_arr,
other_arr,
flatten=True,
jitted=jitted,
chunked=chunked,
wrapper=ArrayWrapper.from_obj(self_obj),
group_by=self.wrapper.grouper.resolve_group_by(group_by=group_by),
wrap_kwargs=wrap_kwargs,
)
func = jit_reg.resolve_option(nb.nancorr_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return self.wrapper.wrap_reduced(func(self_arr, other_arr), group_by=False, **wrap_kwargs)
def rank(
self,
pct: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Compute numerical data rank.
By default, equal values are assigned a rank that is the average of the ranks of those values."""
func = jit_reg.resolve_option(nb.rank_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
arr = self.to_2d_array()
argsorted = np.argsort(arr, axis=0)
rank = func(arr, argsorted=argsorted, pct=pct)
return self.wrapper.wrap(rank, group_by=False, **resolve_dict(wrap_kwargs))
def idxmin(
self,
order: str = "C",
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return labeled index of min of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="idxmin"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.argmin_reduce_nb, jitted),
returns_idx=True,
flatten=True,
order=order,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
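# Ungrouped fallback: per-column NumPy argmin over columns that aren't all-NaN,
# mapped back to index labels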
def func(arr, index):
out = np.full(arr.shape[1], np.nan, dtype=object)
nan_mask = np.all(np.isnan(arr), axis=0)
out[~nan_mask] = index[nanargmin(arr[:, ~nan_mask], axis=0)]
return out
chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(index=None))
func = ch_reg.resolve_option(nb.nanmin_nb, chunked, target_func=func)
out = func(self.to_2d_array(), self.wrapper.index)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def idxmax(
self,
order: str = "C",
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Return labeled index of max of non-NaN elements."""
wrap_kwargs = merge_dicts(dict(name_or_index="idxmax"), wrap_kwargs)
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.argmax_reduce_nb, jitted),
returns_idx=True,
flatten=True,
order=order,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
def func(arr, index):
out = np.full(arr.shape[1], np.nan, dtype=object)
nan_mask = np.all(np.isnan(arr), axis=0)
out[~nan_mask] = index[nanargmax(arr[:, ~nan_mask], axis=0)]
return out
chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(index=None))
func = ch_reg.resolve_option(nb.nanmax_nb, chunked, target_func=func)
out = func(self.to_2d_array(), self.wrapper.index)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def describe(
self,
percentiles: tp.Optional[tp.ArrayLike] = None,
ddof: int = 1,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.apply_reduce.describe_reduce_nb`.
For `percentiles`, see `pd.DataFrame.describe`.
Usage:
```pycon
>>> df.vbt.describe()
a b c
count 5.000000 5.000000 5.00000
mean 3.000000 3.000000 1.80000
std 1.581139 1.581139 0.83666
min 1.000000 1.000000 1.00000
25% 2.000000 2.000000 1.00000
50% 3.000000 3.000000 2.00000
75% 4.000000 4.000000 2.00000
max 5.000000 5.000000 3.00000
```
"""
if percentiles is not None:
percentiles = reshaping.to_1d_array(percentiles)
else:
percentiles = np.array([0.25, 0.5, 0.75])
percentiles = percentiles.tolist()
if 0.5 not in percentiles:
percentiles.append(0.5)
percentiles = np.unique(percentiles)
perc_formatted = pd.io.formats.format.format_percentiles(percentiles)
index = pd.Index(["count", "mean", "std", "min", *perc_formatted, "max"])
wrap_kwargs = merge_dicts(dict(name_or_index=index), wrap_kwargs)
chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(args=ch.ArgsTaker(None, None)))
if self.wrapper.grouper.is_grouped(group_by=group_by):
return self.reduce(
jit_reg.resolve_option(nb.describe_reduce_nb, jitted),
percentiles,
ddof,
returns_array=True,
flatten=True,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
else:
return self.reduce(
jit_reg.resolve_option(nb.describe_reduce_nb, jitted),
percentiles,
ddof,
returns_array=True,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
)
def digitize(
self,
bins: tp.ArrayLike = "auto",
right: bool = False,
return_mapping: bool = False,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.SeriesFrame, tp.Tuple[tp.SeriesFrame, dict]]:
"""Apply `np.digitize`.
Usage:
```pycon
>>> df.vbt.digitize(3)
a b c
2020-01-01 1 3 1
2020-01-02 1 3 1
2020-01-03 2 2 2
2020-01-04 3 1 1
2020-01-05 3 1 1
```
"""
if wrap_kwargs is None:
wrap_kwargs = {}
arr = self.to_2d_array()
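# A scalar `bins` is expanded into that many equal-width edges over the data range,
# mirroring the end-point adjustments made by `pd.cut`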
if not np.iterable(bins):
if np.isscalar(bins) and bins < 1:
raise ValueError("Bins must be a positive integer")
rng = (np.nanmin(self.obj.values), np.nanmax(self.obj.values))
mn, mx = (mi + 0.0 for mi in rng)
if np.isinf(mn) or np.isinf(mx):
raise ValueError("Cannot specify integer bins when input data contains infinity")
elif mn == mx: # adjust end points before binning
mn -= 0.001 * abs(mn) if mn != 0 else 0.001
mx += 0.001 * abs(mx) if mx != 0 else 0.001
bins = np.linspace(mn, mx, bins + 1, endpoint=True)
else: # adjust end points after binning
bins = np.linspace(mn, mx, bins + 1, endpoint=True)
adj = (mx - mn) * 0.001 # 0.1% of the range
if right:
bins[0] -= adj
else:
bins[-1] += adj
bin_edges = reshaping.to_1d_array(bins)
mapping = dict()
if right:
out = np.digitize(arr, bin_edges[1:], right=right)
if return_mapping:
for i in range(len(bin_edges) - 1):
mapping[i] = (bin_edges[i], bin_edges[i + 1])
else:
out = np.digitize(arr, bin_edges[:-1], right=right)
if return_mapping:
for i in range(1, len(bin_edges)):
mapping[i] = (bin_edges[i - 1], bin_edges[i])
if return_mapping:
return self.wrapper.wrap(out, **wrap_kwargs), mapping
return self.wrapper.wrap(out, **wrap_kwargs)
def value_counts(
self,
axis: int = 1,
normalize: bool = False,
sort_uniques: bool = True,
sort: bool = False,
ascending: bool = False,
dropna: bool = False,
group_by: tp.GroupByLike = None,
mapping: tp.Union[None, bool, tp.MappingLike] = None,
incl_all_keys: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Return a Series/DataFrame containing counts of unique values.
Args:
axis (int): 0 - counts per row, 1 - counts per column, and -1 - counts across the whole object.
normalize (bool): Whether to return the relative frequencies of the unique values.
sort_uniques (bool): Whether to sort uniques.
sort (bool): Whether to sort by frequency.
ascending (bool): Whether to sort in ascending order.
dropna (bool): Whether to exclude counts of NaN.
group_by (any): Group or ungroup columns.
See `vectorbtpro.base.grouping.base.Grouper`.
mapping (mapping_like): Mapping of values to labels.
incl_all_keys (bool): Whether to include all mapping keys, not only those that are present in the array.
jitted (any): Whether to JIT-compile `vectorbtpro.generic.nb.base.value_counts_nb` or options.
chunked (any): Whether to chunk `vectorbtpro.generic.nb.base.value_counts_nb` or options.
See `vectorbtpro.utils.chunking.resolve_chunked`.
wrap_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper.wrap`.
**kwargs: Keyword arguments passed to `vectorbtpro.utils.mapping.apply_mapping`.
Usage:
```pycon
>>> df.vbt.value_counts()
a b c
1 1 1 2
2 1 1 2
3 1 1 1
4 1 1 0
5 1 1 0
>>> df.vbt.value_counts(axis=-1)
1 4
2 4
3 3
4 2
5 2
Name: value_counts, dtype: int64
>>> mapping = {x: 'test_' + str(x) for x in pd.unique(df.values.flatten())}
>>> df.vbt.value_counts(mapping=mapping)
a b c
test_1 1 1 2
test_2 1 1 2
test_3 1 1 1
test_4 1 1 0
test_5 1 1 0
>>> sr = pd.Series([1, 2, 2, 3, 3, 3, np.nan])
>>> sr.vbt.value_counts(mapping=mapping)
test_1 1
test_2 2
test_3 3
NaN 1
dtype: int64
>>> sr.vbt.value_counts(mapping=mapping, dropna=True)
test_1 1
test_2 2
test_3 3
dtype: int64
>>> sr.vbt.value_counts(mapping=mapping, sort=True)
test_3 3
test_2 2
test_1 1
NaN 1
dtype: int64
>>> sr.vbt.value_counts(mapping=mapping, sort=True, ascending=True)
test_1 1
NaN 1
test_2 2
test_3 3
dtype: int64
>>> sr.vbt.value_counts(mapping=mapping, incl_all_keys=True)
test_1 1
test_2 2
test_3 3
test_4 0
test_5 0
NaN 1
dtype: int64
```
"""
checks.assert_in(axis, (-1, 0, 1))
mapping = self.resolve_mapping(mapping=mapping)
codes, uniques = pd.factorize(self.obj.values.flatten(), sort=False, use_na_sentinel=False)
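# Factorize all values once; counts are then taken per row (axis=0), per column or group
# (axis=1), or across the flattened object (axis=-1)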
if axis == 0:
func = jit_reg.resolve_option(nb.value_counts_per_row_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
value_counts = func(codes.reshape(self.wrapper.shape_2d), len(uniques))
elif axis == 1:
group_map = self.wrapper.grouper.get_group_map(group_by=group_by)
func = jit_reg.resolve_option(nb.value_counts_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
value_counts = func(codes.reshape(self.wrapper.shape_2d), len(uniques), group_map)
else:
func = jit_reg.resolve_option(nb.value_counts_1d_nb, jitted)
value_counts = func(codes, len(uniques))
if incl_all_keys and mapping is not None:
missing_keys = []
for x in mapping:
if pd.isnull(x) and pd.isnull(uniques).any():
continue
if x not in uniques:
missing_keys.append(x)
if axis == 0 or axis == 1:
value_counts = np.vstack((value_counts, np.full((len(missing_keys), value_counts.shape[1]), 0)))
else:
value_counts = np.concatenate((value_counts, np.full(len(missing_keys), 0)))
uniques = np.concatenate((uniques, np.array(missing_keys)))
nan_mask = np.isnan(uniques)
if dropna:
value_counts = value_counts[~nan_mask]
uniques = uniques[~nan_mask]
if sort_uniques:
new_indices = uniques.argsort()
value_counts = value_counts[new_indices]
uniques = uniques[new_indices]
if axis == 0 or axis == 1:
value_counts_sum = value_counts.sum(axis=1)
else:
value_counts_sum = value_counts
if normalize:
value_counts = value_counts / value_counts_sum.sum()
if sort:
if ascending:
new_indices = value_counts_sum.argsort()
else:
new_indices = (-value_counts_sum).argsort()
value_counts = value_counts[new_indices]
uniques = uniques[new_indices]
if axis == 0:
wrapper = ArrayWrapper.from_obj(value_counts)
value_counts_pd = wrapper.wrap(
value_counts,
index=uniques,
columns=self.wrapper.index,
**resolve_dict(wrap_kwargs),
)
elif axis == 1:
value_counts_pd = self.wrapper.wrap(
value_counts,
index=uniques,
group_by=group_by,
**resolve_dict(wrap_kwargs),
)
else:
wrapper = ArrayWrapper.from_obj(value_counts)
value_counts_pd = wrapper.wrap(
value_counts,
index=uniques,
**merge_dicts(dict(columns=["value_counts"]), wrap_kwargs),
)
if mapping is not None:
value_counts_pd.index = apply_mapping(value_counts_pd.index, mapping, **kwargs)
return value_counts_pd
# ############# Transforming ############# #
def demean(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.base.demean_nb`."""
func = jit_reg.resolve_option(nb.demean_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
group_map = self.wrapper.grouper.get_group_map(group_by=group_by)
out = func(self.to_2d_array(), group_map)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def transform(self, transformer: TransformerT, wrap_kwargs: tp.KwargsLike = None, **kwargs) -> tp.SeriesFrame:
"""Transform using a transformer.
A transformer can be any class instance that has `transform` and `fit_transform` methods,
ideally subclassing `sklearn.base.TransformerMixin` and `sklearn.base.BaseEstimator`.
Will fit `transformer` if not fitted.
`**kwargs` are passed to the `transform` or `fit_transform` method.
Usage:
```pycon
>>> from sklearn.preprocessing import MinMaxScaler
>>> df.vbt.transform(MinMaxScaler((-1, 1)))
a b c
2020-01-01 -1.0 1.0 -1.0
2020-01-02 -0.5 0.5 0.0
2020-01-03 0.0 0.0 1.0
2020-01-04 0.5 -0.5 0.0
2020-01-05 1.0 -1.0 -1.0
>>> fitted_scaler = MinMaxScaler((-1, 1)).fit(np.array([[2], [4]]))
>>> df.vbt.transform(fitted_scaler)
a b c
2020-01-01 -2.0 2.0 -2.0
2020-01-02 -1.0 1.0 -1.0
2020-01-03 0.0 0.0 0.0
2020-01-04 1.0 -1.0 -1.0
2020-01-05 2.0 -2.0 -2.0
```
"""
is_fitted = True
try:
check_is_fitted(transformer)
except NotFittedError:
is_fitted = False
if not is_fitted:
result = transformer.fit_transform(self.to_2d_array(), **kwargs)
else:
result = transformer.transform(self.to_2d_array(), **kwargs)
return self.wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs))
def zscore(self, **kwargs) -> tp.SeriesFrame:
"""Compute z-score using `sklearn.preprocessing.StandardScaler`."""
return self.scale(with_mean=True, with_std=True, **kwargs)
def rebase(
self,
base: float,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Rebase all series to the given base.
This makes comparing/plotting different series together easier.
Will forward and backward fill NaN values."""
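# Forward/backward fill NaNs, then scale each column so its first value equals `base`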
func = jit_reg.resolve_option(nb.fbfill_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
result = func(self.to_2d_array())
result = result / result[0] * base
return self.wrapper.wrap(result, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Conversion ############# #
def drawdown(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get drawdown series."""
func = jit_reg.resolve_option(nb.drawdown_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(self.to_2d_array())
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@property
def ranges(self) -> Ranges:
"""`GenericAccessor.get_ranges` with default arguments."""
return self.get_ranges()
def get_ranges(self, *args, wrapper_kwargs: tp.KwargsLike = None, **kwargs) -> Ranges:
"""Generate range records.
See `vectorbtpro.generic.ranges.Ranges.from_array`."""
wrapper_kwargs = merge_dicts(self.wrapper.config, wrapper_kwargs)
return Ranges.from_array(self.obj, *args, wrapper_kwargs=wrapper_kwargs, **kwargs)
@property
def drawdowns(self) -> Drawdowns:
"""`GenericAccessor.get_drawdowns` with default arguments."""
return self.get_drawdowns()
def get_drawdowns(self, *args, **kwargs) -> Drawdowns:
"""Generate drawdown records.
See `vectorbtpro.generic.drawdowns.Drawdowns.from_price`."""
return Drawdowns.from_price(self.obj, *args, wrapper=self.wrapper, **kwargs)
def to_mapped(
self,
dropna: bool = True,
dtype: tp.Optional[tp.DTypeLike] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> MappedArray:
"""Convert this object into an instance of `vectorbtpro.records.mapped_array.MappedArray`."""
mapped_arr = self.to_2d_array().flatten(order="F")
col_arr = np.repeat(np.arange(self.wrapper.shape_2d[1]), self.wrapper.shape_2d[0])
idx_arr = np.tile(np.arange(self.wrapper.shape_2d[0]), self.wrapper.shape_2d[1])
if dropna and np.isnan(mapped_arr).any():
not_nan_mask = ~np.isnan(mapped_arr)
mapped_arr = mapped_arr[not_nan_mask]
col_arr = col_arr[not_nan_mask]
idx_arr = idx_arr[not_nan_mask]
return MappedArray(
self.wrapper,
np.asarray(mapped_arr, dtype=dtype),
col_arr,
idx_arr=idx_arr,
**kwargs,
).regroup(group_by)
def to_returns(self, **kwargs) -> tp.SeriesFrame:
"""Get returns of this object."""
from vectorbtpro.returns.accessors import ReturnsAccessor
return ReturnsAccessor.from_value(
self._obj,
wrapper=self.wrapper,
return_values=True,
**kwargs,
)
def to_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Get log returns of this object."""
from vectorbtpro.returns.accessors import ReturnsAccessor
return ReturnsAccessor.from_value(
self._obj,
wrapper=self.wrapper,
return_values=True,
log_returns=True,
**kwargs,
)
def to_daily_returns(self, **kwargs) -> tp.SeriesFrame:
"""Get daily returns of this object."""
from vectorbtpro.returns.accessors import ReturnsAccessor
return ReturnsAccessor.from_value(
self._obj,
wrapper=self.wrapper,
return_values=False,
**kwargs,
).daily()
def to_daily_log_returns(self, **kwargs) -> tp.SeriesFrame:
"""Get daily log returns of this object."""
from vectorbtpro.returns.accessors import ReturnsAccessor
return ReturnsAccessor.from_value(
self._obj,
wrapper=self.wrapper,
return_values=False,
log_returns=True,
**kwargs,
).daily()
# ############# Patterns ############# #
def find_pattern(self, *args, **kwargs) -> PatternRanges:
"""Generate pattern range records.
See `vectorbtpro.generic.ranges.PatternRanges.from_pattern_search`."""
return PatternRanges.from_pattern_search(self.obj, *args, **kwargs)
# ############# Crossover ############# #
def crossed_above(
self,
other: tp.ArrayLike,
wait: int = 0,
dropna: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.base.crossed_above_nb`.
Usage:
```pycon
>>> df['b'].vbt.crossed_above(df['c'])
2020-01-01 False
2020-01-02 False
2020-01-03 False
2020-01-04 False
2020-01-05 False
dtype: bool
>>> df['a'].vbt.crossed_above(df['b'])
2020-01-01 False
2020-01-02 False
2020-01-03 False
2020-01-04 True
2020-01-05 False
dtype: bool
>>> df['a'].vbt.crossed_above(df['b'], wait=1)
2020-01-01 False
2020-01-02 False
2020-01-03 False
2020-01-04 False
2020-01-05 True
dtype: bool
```
"""
broadcastable_args = dict(obj=self.obj, other=other)
broadcast_kwargs = merge_dicts(dict(keep_flex=dict(obj=False, other=True)), broadcast_kwargs)
broadcasted_args, wrapper = reshaping.broadcast(
broadcastable_args,
to_pd=False,
min_ndim=2,
return_wrapper=True,
**broadcast_kwargs,
)
func = jit_reg.resolve_option(nb.crossed_above_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
broadcasted_args["obj"],
broadcasted_args["other"],
wait=wait,
dropna=dropna,
)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def crossed_below(
self,
other: tp.ArrayLike,
wait: int = 0,
dropna: bool = True,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.nb.base.crossed_below_nb`.
Also, see `GenericAccessor.crossed_above` for similar examples."""
broadcastable_args = dict(obj=self.obj, other=other)
broadcast_kwargs = merge_dicts(dict(keep_flex=dict(obj=False, other=True)), broadcast_kwargs)
broadcasted_args, wrapper = reshaping.broadcast(
broadcastable_args,
to_pd=False,
min_ndim=2,
return_wrapper=True,
**broadcast_kwargs,
)
func = jit_reg.resolve_option(nb.crossed_below_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
broadcasted_args["obj"],
broadcasted_args["other"],
wait=wait,
dropna=dropna,
)
return wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Resolution ############# #
def resolve_self(
self: GenericAccessorT,
cond_kwargs: tp.KwargsLike = None,
custom_arg_names: tp.Optional[tp.Set[str]] = None,
impacts_caching: bool = True,
silence_warnings: bool = False,
) -> GenericAccessorT:
"""Resolve self.
See `vectorbtpro.base.wrapping.Wrapping.resolve_self`.
Creates a copy of this instance if `mapping` in `cond_kwargs` differs from the attached mapping."""
if cond_kwargs is None:
cond_kwargs = {}
if custom_arg_names is None:
custom_arg_names = set()
reself = Wrapping.resolve_self(
self,
cond_kwargs=cond_kwargs,
custom_arg_names=custom_arg_names,
impacts_caching=impacts_caching,
silence_warnings=silence_warnings,
)
if "mapping" in cond_kwargs:
self_copy = reself.replace(mapping=cond_kwargs["mapping"])
if not checks.is_deep_equal(self_copy.mapping, reself.mapping):
if not silence_warnings:
warn(
f"Changing the mapping will create a copy of this object. "
f"Consider setting it upon object creation to re-use existing cache."
)
for alias in reself.self_aliases:
if alias not in custom_arg_names:
cond_kwargs[alias] = self_copy
cond_kwargs["mapping"] = self_copy.mapping
if impacts_caching:
cond_kwargs["use_caching"] = False
return self_copy
return reself
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `GenericAccessor.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.generic`."""
from vectorbtpro._settings import settings
generic_stats_cfg = settings["generic"]["stats"]
return merge_dicts(Analyzable.stats_defaults.__get__(self), generic_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
count=dict(title="Count", calc_func="count", inv_check_has_mapping=True, tags=["generic", "describe"]),
mean=dict(title="Mean", calc_func="mean", inv_check_has_mapping=True, tags=["generic", "describe"]),
std=dict(title="Std", calc_func="std", inv_check_has_mapping=True, tags=["generic", "describe"]),
min=dict(title="Min", calc_func="min", inv_check_has_mapping=True, tags=["generic", "describe"]),
median=dict(title="Median", calc_func="median", inv_check_has_mapping=True, tags=["generic", "describe"]),
max=dict(title="Max", calc_func="max", inv_check_has_mapping=True, tags=["generic", "describe"]),
idx_min=dict(
title="Min Index",
calc_func="idxmin",
agg_func=None,
inv_check_has_mapping=True,
tags=["generic", "index"],
),
idx_max=dict(
title="Max Index",
calc_func="idxmax",
agg_func=None,
inv_check_has_mapping=True,
tags=["generic", "index"],
),
value_counts=dict(
title="Value Counts",
calc_func=lambda value_counts: reshaping.to_dict(value_counts, orient="index_series"),
resolve_value_counts=True,
check_has_mapping=True,
tags=["generic", "value_counts"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
column: tp.Optional[tp.Label] = None,
trace_names: tp.TraceNames = None,
x_labels: tp.Optional[tp.Labels] = None,
return_fig: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create `vectorbtpro.generic.plotting.Scatter` and return the figure.
Usage:
```pycon
>>> df.vbt.plot().show()
```
"""
from vectorbtpro.generic.plotting import Scatter
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
if x_labels is None:
x_labels = _self.wrapper.index
if trace_names is None:
if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None):
trace_names = _self.wrapper.columns
scatter = Scatter(data=_self.to_2d_array(), trace_names=trace_names, x_labels=x_labels, **kwargs)
if return_fig:
return scatter.fig
return scatter
def lineplot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""`GenericAccessor.plot` with 'lines' mode.
Usage:
```pycon
>>> df.vbt.lineplot().show()
```
"""
return self.plot(column=column, **merge_dicts(dict(trace_kwargs=dict(mode="lines")), kwargs))
def scatterplot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""`GenericAccessor.plot` with 'markers' mode.
Usage:
```pycon
>>> df.vbt.scatterplot().show()
```
"""
return self.plot(column=column, **merge_dicts(dict(trace_kwargs=dict(mode="markers")), kwargs))
def barplot(
self,
column: tp.Optional[tp.Label] = None,
trace_names: tp.TraceNames = None,
x_labels: tp.Optional[tp.Labels] = None,
return_fig: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create `vectorbtpro.generic.plotting.Bar` and return the figure.
Usage:
```pycon
>>> df.vbt.barplot().show()
```
"""
from vectorbtpro.generic.plotting import Bar
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
if x_labels is None:
x_labels = _self.wrapper.index
if trace_names is None:
if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None):
trace_names = _self.wrapper.columns
bar = Bar(data=_self.to_2d_array(), trace_names=trace_names, x_labels=x_labels, **kwargs)
if return_fig:
return bar.fig
return bar
def histplot(
self,
column: tp.Optional[tp.Label] = None,
by_level: tp.Optional[tp.Level] = None,
trace_names: tp.TraceNames = None,
group_by: tp.GroupByLike = None,
return_fig: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create `vectorbtpro.generic.plotting.Histogram` and return the figure.
Usage:
```pycon
>>> df.vbt.histplot().show()
```
"""
from vectorbtpro.generic.plotting import Histogram
if by_level is not None:
return self.obj.unstack(by_level).vbt.histplot(
column=column,
trace_names=trace_names,
group_by=group_by,
return_fig=return_fig,
**kwargs,
)
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
if _self.wrapper.grouper.is_grouped(group_by=group_by):
return _self.flatten_grouped(group_by=group_by).vbt.histplot(trace_names=trace_names, **kwargs)
if trace_names is None:
if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None):
trace_names = _self.wrapper.columns
hist = Histogram(data=_self.to_2d_array(), trace_names=trace_names, **kwargs)
if return_fig:
return hist.fig
return hist
def boxplot(
self,
column: tp.Optional[tp.Label] = None,
by_level: tp.Optional[tp.Level] = None,
trace_names: tp.TraceNames = None,
group_by: tp.GroupByLike = None,
return_fig: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create `vectorbtpro.generic.plotting.Box` and return the figure.
Usage:
```pycon
>>> df.vbt.boxplot().show()
```
"""
from vectorbtpro.generic.plotting import Box
if by_level is not None:
return self.obj.unstack(by_level).vbt.boxplot(
column=column,
trace_names=trace_names,
group_by=group_by,
return_fig=return_fig,
**kwargs,
)
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
if _self.wrapper.grouper.is_grouped(group_by=group_by):
return _self.flatten_grouped(group_by=group_by).vbt.boxplot(trace_names=trace_names, **kwargs)
if trace_names is None:
if _self.is_frame() or (_self.is_series() and _self.wrapper.name is not None):
trace_names = _self.wrapper.columns
box = Box(data=_self.to_2d_array(), trace_names=trace_names, **kwargs)
if return_fig:
return box.fig
return box
def plot_against(
self,
other: tp.ArrayLike,
column: tp.Optional[tp.Label] = None,
trace_kwargs: tp.KwargsLike = None,
other_trace_kwargs: tp.Union[str, tp.KwargsLike] = None,
pos_trace_kwargs: tp.KwargsLike = None,
neg_trace_kwargs: tp.KwargsLike = None,
hidden_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot Series as a line against another line.
Args:
other (array_like): Second array. Will broadcast.
column (hashable): Column to plot.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
other_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `other`.
Set to 'hidden' to hide.
pos_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for positive line.
neg_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for negative line.
hidden_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for hidden lines.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> df['a'].vbt.plot_against(df['b']).show()
```
"""
from vectorbtpro.utils.figure import make_figure
if trace_kwargs is None:
trace_kwargs = {}
if other_trace_kwargs is None:
other_trace_kwargs = {}
if pos_trace_kwargs is None:
pos_trace_kwargs = {}
if neg_trace_kwargs is None:
neg_trace_kwargs = {}
if hidden_trace_kwargs is None:
hidden_trace_kwargs = {}
obj = self.obj
if isinstance(obj, pd.DataFrame):
obj = self.select_col_from_obj(obj, column=column)
if other is None:
other = pd.Series.vbt.empty_like(obj, 1)
else:
other = reshaping.to_pd_array(other)
if isinstance(other, pd.DataFrame):
other = self.select_col_from_obj(other, column=column)
obj, other = reshaping.broadcast(obj, other, columns_from="keep")
if other.name is None:
other = other.rename("Other")
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# TODO: Using masks feels hacky
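# Where `obj` exceeds `other`, plot an invisible baseline at `other` and fill up to a copy
# of `obj` clipped at `other`; the negative area below is filled analogously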
pos_mask = obj > other
if pos_mask.any():
# Fill positive area
pos_obj = obj.copy()
pos_obj[~pos_mask] = other[~pos_mask]
other.vbt.lineplot(
trace_kwargs=merge_dicts(
dict(
line=dict(color="rgba(0, 0, 0, 0)", width=0),
opacity=0,
hoverinfo="skip",
showlegend=False,
name=None,
),
hidden_trace_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
pos_obj.vbt.lineplot(
trace_kwargs=merge_dicts(
dict(
fillcolor="rgba(0, 128, 0, 0.25)",
line=dict(color="rgba(0, 0, 0, 0)", width=0),
opacity=0,
fill="tonexty",
connectgaps=False,
hoverinfo="skip",
showlegend=False,
name=None,
),
pos_trace_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
neg_mask = obj < other
if neg_mask.any():
# Fill negative area
neg_obj = obj.copy()
neg_obj[~neg_mask] = other[~neg_mask]
other.vbt.lineplot(
trace_kwargs=merge_dicts(
dict(
line=dict(color="rgba(0, 0, 0, 0)", width=0),
opacity=0,
hoverinfo="skip",
showlegend=False,
name=None,
),
hidden_trace_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
neg_obj.vbt.lineplot(
trace_kwargs=merge_dicts(
dict(
line=dict(color="rgba(0, 0, 0, 0)", width=0),
fillcolor="rgba(255, 0, 0, 0.25)",
opacity=0,
fill="tonexty",
connectgaps=False,
hoverinfo="skip",
showlegend=False,
name=None,
),
neg_trace_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
# Plot main traces
obj.vbt.lineplot(trace_kwargs=trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)
if other_trace_kwargs == "hidden":
other_trace_kwargs = dict(
line=dict(color="rgba(0, 0, 0, 0)", width=0),
opacity=0.0,
hoverinfo="skip",
showlegend=False,
name=None,
)
other.vbt.lineplot(trace_kwargs=other_trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)
return fig
def overlay_with_heatmap(
self,
other: tp.ArrayLike,
column: tp.Optional[tp.Label] = None,
trace_kwargs: tp.KwargsLike = None,
heatmap_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot Series as a line and overlays it with a heatmap.
Args:
other (array_like): Second array. Will broadcast.
column (hashable): Column to plot.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
heatmap_kwargs (dict): Keyword arguments passed to `GenericDFAccessor.heatmap`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> df['a'].vbt.overlay_with_heatmap(df['b']).show()
```
"""
from vectorbtpro.utils.figure import make_subplots
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if trace_kwargs is None:
trace_kwargs = {}
if heatmap_kwargs is None:
heatmap_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
obj = self.obj
if isinstance(obj, pd.DataFrame):
obj = self.select_col_from_obj(obj, column=column)
if other is None:
other = pd.Series.vbt.empty_like(obj, 1)
else:
other = reshaping.to_pd_array(other)
if isinstance(other, pd.DataFrame):
other = self.select_col_from_obj(other, column=column)
obj, other = reshaping.broadcast(obj, other, columns_from="keep")
if other.name is None:
other = other.rename("Other")
if fig is None:
fig = make_subplots(specs=[[{"secondary_y": True}]])
if "width" in plotting_cfg["layout"]:
fig.update_layout(width=plotting_cfg["layout"]["width"] + 100)
fig.update_layout(**layout_kwargs)
other.vbt.ts_heatmap(**heatmap_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)
obj.vbt.lineplot(
trace_kwargs=merge_dicts(dict(line=dict(color=plotting_cfg["color_schema"]["blue"])), trace_kwargs),
add_trace_kwargs=merge_dicts(dict(secondary_y=True), add_trace_kwargs),
fig=fig,
)
return fig
def heatmap(
self,
column: tp.Optional[tp.Label] = None,
x_level: tp.Optional[tp.Level] = None,
y_level: tp.Optional[tp.Level] = None,
symmetric: bool = False,
sort: bool = True,
x_labels: tp.Optional[tp.Labels] = None,
y_labels: tp.Optional[tp.Labels] = None,
slider_level: tp.Optional[tp.Level] = None,
active: int = 0,
slider_labels: tp.Optional[tp.Labels] = None,
return_fig: bool = True,
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create a heatmap figure based on object's multi-index and values.
If the object is two-dimensional or the index is not a multi-index, returns a regular heatmap.
If the multi-index contains more than two levels, or you want them in a specific order,
pass `x_level` and `y_level`, each an `int` (position) or `str` (name) corresponding
to an axis of the heatmap. Optionally, pass `slider_level` to use a level as a slider.
Creates `vectorbtpro.generic.plotting.Heatmap` and returns the figure.
Usage:
* Plotting a figure based on a regular index:
```pycon
>>> df = pd.DataFrame([
... [0, np.nan, np.nan],
... [np.nan, 1, np.nan],
... [np.nan, np.nan, 2]
... ])
>>> df.vbt.heatmap().show()
```
* Plotting a figure based on a multi-index:
```pycon
>>> multi_index = pd.MultiIndex.from_tuples([
... (1, 1),
... (2, 2),
... (3, 3)
... ])
>>> sr = pd.Series(np.arange(len(multi_index)), index=multi_index)
>>> sr
1 1 0
2 2 1
3 3 2
dtype: int64
>>> sr.vbt.heatmap().show()
```
"""
from vectorbtpro.generic.plotting import Heatmap
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
if _self.ndim == 2 or not isinstance(self.wrapper.index, pd.MultiIndex):
if x_labels is None:
x_labels = _self.wrapper.columns
if y_labels is None:
y_labels = _self.wrapper.index
heatmap = Heatmap(
data=_self.to_2d_array(),
x_labels=x_labels,
y_labels=y_labels,
fig=fig,
**kwargs,
)
if return_fig:
return heatmap.fig
return heatmap
(x_level, y_level), (slider_level,) = indexes.pick_levels(
_self.wrapper.index,
required_levels=(x_level, y_level),
optional_levels=(slider_level,),
)
x_level_vals = _self.wrapper.index.get_level_values(x_level)
y_level_vals = _self.wrapper.index.get_level_values(y_level)
x_name = x_level_vals.name if x_level_vals.name is not None else "x"
y_name = y_level_vals.name if y_level_vals.name is not None else "y"
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
hovertemplate=f"{x_name}: %{{x}} " + f"{y_name}: %{{y}} " + "value: %{z}"
),
xaxis_title=x_level_vals.name,
yaxis_title=y_level_vals.name,
),
kwargs,
)
if slider_level is None:
# No grouping
df = _self.unstack_to_df(
index_levels=y_level,
column_levels=x_level,
symmetric=symmetric,
sort=sort,
)
return df.vbt.heatmap(
x_labels=x_labels,
y_labels=y_labels,
fig=fig,
return_fig=return_fig,
**kwargs,
)
# Requires grouping
# See https://plotly.com/python/sliders/
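# One heatmap trace is built per slider group; the steps below toggle trace visibility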
if not return_fig:
raise ValueError("Cannot use return_fig=False and slider_level simultaneously")
_slider_labels = []
for i, (name, group) in enumerate(_self.obj.groupby(level=slider_level)):
if slider_labels is not None:
name = slider_labels[i]
_slider_labels.append(name)
df = group.vbt.unstack_to_df(
index_levels=y_level,
column_levels=x_level,
symmetric=symmetric,
sort=sort,
)
if x_labels is None:
x_labels = df.columns
if y_labels is None:
y_labels = df.index
_kwargs = merge_dicts(
dict(
trace_kwargs=dict(name=str(name) if name is not None else None, visible=False),
),
kwargs,
)
default_size = fig is None and "height" not in _kwargs
fig = Heatmap(
data=reshaping.to_2d_array(df),
x_labels=x_labels,
y_labels=y_labels,
fig=fig,
**_kwargs,
).fig
if default_size:
fig.layout["height"] += 100 # slider takes up space
fig.data[active].visible = True
steps = []
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)}, {}],
label=str(_slider_labels[i]) if _slider_labels[i] is not None else None,
)
step["args"][0]["visible"][i] = True
steps.append(step)
prefix = (
f"{_self.wrapper.index.names[slider_level]}: "
if _self.wrapper.index.names[slider_level] is not None
else None
)
sliders = [
dict(
active=active,
currentvalue={"prefix": prefix},
pad={"t": 50},
steps=steps,
)
]
fig.update_layout(sliders=sliders)
return fig
def ts_heatmap(
self,
column: tp.Optional[tp.Label] = None,
is_y_category: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Heatmap of time-series data."""
if column is not None:
obj = self.select_col_from_obj(self.obj, column=column)
else:
obj = self.obj
if isinstance(obj, pd.Series):
obj = obj.to_frame()
return obj.transpose().iloc[::-1].vbt.heatmap(is_y_category=is_y_category, **kwargs)
def volume(
self,
column: tp.Optional[tp.Label] = None,
x_level: tp.Optional[tp.Level] = None,
y_level: tp.Optional[tp.Level] = None,
z_level: tp.Optional[tp.Level] = None,
x_labels: tp.Optional[tp.Labels] = None,
y_labels: tp.Optional[tp.Labels] = None,
z_labels: tp.Optional[tp.Labels] = None,
slider_level: tp.Optional[tp.Level] = None,
slider_labels: tp.Optional[tp.Labels] = None,
active: int = 0,
scene_name: str = "scene",
fillna: tp.Optional[tp.Number] = None,
fig: tp.Optional[tp.BaseFigure] = None,
return_fig: bool = True,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Create a 3D volume figure based on object's multi-index and values.
If the multi-index contains more than three levels, or you want them in a specific order, pass
`x_level`, `y_level`, and `z_level`, each an `int` (position) or `str` (name) corresponding
to an axis of the volume. Optionally, pass `slider_level` to use a level as a slider.
Creates `vectorbtpro.generic.plotting.Volume` and returns the figure.
Usage:
```pycon
>>> multi_index = pd.MultiIndex.from_tuples([
... (1, 1, 1),
... (2, 2, 2),
... (3, 3, 3)
... ])
>>> sr = pd.Series(np.arange(len(multi_index)), index=multi_index)
>>> sr
1 1 1 0
2 2 2 1
3 3 3 2
dtype: int64
>>> sr.vbt.volume().show()
```
"""
from vectorbtpro.generic.plotting import Volume
self_col = self.select_col(column=column)
(x_level, y_level, z_level), (slider_level,) = indexes.pick_levels(
self_col.wrapper.index,
required_levels=(x_level, y_level, z_level),
optional_levels=(slider_level,),
)
x_level_vals = self_col.wrapper.index.get_level_values(x_level)
y_level_vals = self_col.wrapper.index.get_level_values(y_level)
z_level_vals = self_col.wrapper.index.get_level_values(z_level)
# Labels are just unique level values
if x_labels is None:
x_labels = np.unique(x_level_vals)
if y_labels is None:
y_labels = np.unique(y_level_vals)
if z_labels is None:
z_labels = np.unique(z_level_vals)
x_name = x_level_vals.name if x_level_vals.name is not None else "x"
y_name = y_level_vals.name if y_level_vals.name is not None else "y"
z_name = z_level_vals.name if z_level_vals.name is not None else "z"
def_kwargs = dict()
def_kwargs["trace_kwargs"] = dict(
hovertemplate=f"{x_name}: %{{x}} "
+ f"{y_name}: %{{y}} "
+ f"{z_name}: %{{z}} "
+ "value: %{value}"
)
def_kwargs[scene_name] = dict(
xaxis_title=x_level_vals.name,
yaxis_title=y_level_vals.name,
zaxis_title=z_level_vals.name,
)
def_kwargs["scene_name"] = scene_name
kwargs = merge_dicts(def_kwargs, kwargs)
contains_nan = False
if slider_level is None:
# No grouping
v = self_col.unstack_to_array(levels=(x_level, y_level, z_level))
if fillna is not None:
v = np.nan_to_num(v, nan=fillna)
if np.isnan(v).any():
contains_nan = True
volume = Volume(data=v, x_labels=x_labels, y_labels=y_labels, z_labels=z_labels, fig=fig, **kwargs)
if return_fig:
fig = volume.fig
else:
fig = volume
else:
# Requires grouping
# See https://plotly.com/python/sliders/
if not return_fig:
raise ValueError("Cannot use return_fig=False and slider_level simultaneously")
_slider_labels = []
for i, (name, group) in enumerate(self_col.obj.groupby(level=slider_level)):
if slider_labels is not None:
name = slider_labels[i]
_slider_labels.append(name)
v = group.vbt.unstack_to_array(levels=(x_level, y_level, z_level))
if fillna is not None:
v = np.nan_to_num(v, nan=fillna)
if np.isnan(v).any():
contains_nan = True
_kwargs = merge_dicts(
dict(trace_kwargs=dict(name=str(name) if name is not None else None, visible=False)),
kwargs,
)
default_size = fig is None and "height" not in _kwargs
fig = Volume(data=v, x_labels=x_labels, y_labels=y_labels, z_labels=z_labels, fig=fig, **_kwargs).fig
if default_size:
fig.layout["height"] += 100 # slider takes up space
fig.data[active].visible = True
steps = []
for i in range(len(fig.data)):
step = dict(
method="update",
args=[{"visible": [False] * len(fig.data)}, {}],
label=str(_slider_labels[i]) if _slider_labels[i] is not None else None,
)
step["args"][0]["visible"][i] = True
steps.append(step)
prefix = (
f"{self_col.wrapper.index.names[slider_level]}: "
if self_col.wrapper.index.names[slider_level] is not None
else None
)
sliders = [dict(active=active, currentvalue={"prefix": prefix}, pad={"t": 50}, steps=steps)]
fig.update_layout(sliders=sliders)
if contains_nan:
warn("Data contains NaNs. Use `fillna` argument or `show` method in case of visualization issues.")
return fig
def qqplot(
self,
column: tp.Optional[tp.Label] = None,
sparams: tp.Union[tp.Iterable, tuple, None] = (),
dist: str = "norm",
plot_line: bool = True,
line_shape_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot probability plot using `scipy.stats.probplot`.
`**kwargs` are passed to `GenericAccessor.scatterplot`.
Usage:
```pycon
>>> pd.Series(np.random.standard_normal(100)).vbt.qqplot().show()
```
"""
import scipy.stats as st
obj = self.select_col_from_obj(self.obj, column=column)
qq = st.probplot(obj, sparams=sparams, dist=dist)
fig = pd.Series(qq[0][1], index=qq[0][0]).vbt.scatterplot(fig=fig, **kwargs)
if plot_line:
if line_shape_kwargs is None:
line_shape_kwargs = {}
x = np.array([qq[0][0][0], qq[0][0][-1]])
y = qq[1][1] + qq[1][0] * x
fig.add_shape(
**merge_dicts(
dict(type="line", xref=xref, yref=yref, x0=x[0], y0=y[0], x1=x[1], y1=y[1], line=dict(color="red")),
line_shape_kwargs,
)
)
return fig
def areaplot(
self,
line_shape: str = "spline",
line_visible: bool = False,
colorway: tp.Union[None, str, tp.Sequence[str]] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot stacked area.
Args:
line_shape (str): Line shape.
line_visible (bool): Whether to make line visible.
colorway (str or sequence): Name of the built-in, qualitative colorway, or a list with colors.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> df.vbt.areaplot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure
import plotly.express as px
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if colorway is None:
if fig.layout.colorway is not None:
colorway = fig.layout.colorway
else:
colorway = fig.layout.template.layout.colorway
if len(self.wrapper.columns) > len(colorway):
colorway = px.colors.qualitative.Alphabet
elif isinstance(colorway, str):
colorway = getattr(px.colors.qualitative, colorway)
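# Positive and negative values are drawn as two separate stacked groups so that areas
# above and below zero don't cancel each other out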
pos_mask = self.obj.values > 0
pos_mask_any = pos_mask.any()
neg_mask = self.obj.values < 0
neg_mask_any = neg_mask.any()
pos_showlegend = False
neg_showlegend = False
if pos_mask_any:
pos_showlegend = True
elif neg_mask_any:
neg_showlegend = True
line_width = None if line_visible else 0
line_opacity = 0.3 if line_visible else 0.8
if pos_mask_any:
pos_df = self.obj.copy()
pos_df[neg_mask] = 0.0
fig = pos_df.vbt.lineplot(
trace_kwargs=[
merge_dicts(
dict(
legendgroup="area_" + str(c),
stackgroup="one",
line=dict(width=line_width, color=colorway[c % len(colorway)], shape=line_shape),
fillcolor=adjust_opacity(colorway[c % len(colorway)], line_opacity),
showlegend=pos_showlegend,
),
resolve_dict(trace_kwargs, i=c),
)
for c in range(len(self.wrapper.columns))
],
add_trace_kwargs=add_trace_kwargs,
use_gl=False,
fig=fig,
**layout_kwargs,
)
if neg_mask_any:
neg_df = self.obj.copy()
neg_df[pos_mask] = 0.0
fig = neg_df.vbt.lineplot(
trace_kwargs=[
merge_dicts(
dict(
legendgroup="area_" + str(c),
stackgroup="two",
line=dict(width=line_width, color=colorway[c % len(colorway)], shape=line_shape),
fillcolor=adjust_opacity(colorway[c % len(colorway)], line_opacity),
showlegend=neg_showlegend,
),
resolve_dict(trace_kwargs, i=c),
)
for c in range(len(self.wrapper.columns))
],
add_trace_kwargs=add_trace_kwargs,
use_gl=False,
fig=fig,
**layout_kwargs,
)
return fig
def plot_pattern(
self,
pattern: tp.ArrayLike,
interp_mode: tp.Union[int, str] = "mixed",
rescale_mode: tp.Union[int, str] = "minmax",
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: tp.Union[int, str] = "absolute",
max_error: tp.ArrayLike = np.nan,
max_error_interp_mode: tp.Union[None, int, str] = None,
column: tp.Optional[tp.Label] = None,
plot_obj: bool = True,
fill_distance: bool = False,
obj_trace_kwargs: tp.KwargsLike = None,
pattern_trace_kwargs: tp.KwargsLike = None,
lower_max_error_trace_kwargs: tp.KwargsLike = None,
upper_max_error_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot pattern.
Mimics the same similarity calculation procedure as implemented in
`vectorbtpro.generic.nb.patterns.pattern_similarity_nb`.
Usage:
```pycon
>>> sr = pd.Series([10, 11, 12, 13, 12, 13, 14, 15, 13, 14, 11])
>>> sr.vbt.plot_pattern([1, 2, 3, 2, 1]).show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if isinstance(interp_mode, str):
interp_mode = map_enum_fields(interp_mode, InterpMode)
if isinstance(rescale_mode, str):
rescale_mode = map_enum_fields(rescale_mode, RescaleMode)
if isinstance(error_type, str):
error_type = map_enum_fields(error_type, ErrorType)
if max_error_interp_mode is not None and isinstance(max_error_interp_mode, str):
max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode)
if max_error_interp_mode is None:
max_error_interp_mode = interp_mode
obj_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"])),
obj_trace_kwargs,
)
if pattern_trace_kwargs is None:
pattern_trace_kwargs = {}
if lower_max_error_trace_kwargs is None:
lower_max_error_trace_kwargs = {}
if upper_max_error_trace_kwargs is None:
upper_max_error_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
self_col = self.select_col(column=column)
if plot_obj:
# Plot object
fig = self_col.lineplot(
trace_kwargs=obj_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
# Reconstruct pattern and max error bands
pattern_sr, max_error_sr = self_col.fit_pattern(
pattern,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=vmin,
vmax=vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
max_error=max_error,
max_error_interp_mode=max_error_interp_mode,
)
# Plot pattern and max error bands
def_pattern_trace_kwargs = dict(
name=f"Pattern",
connectgaps=True,
)
if interp_mode == InterpMode.Discrete:
_pattern_trace_kwargs = merge_dicts(
def_pattern_trace_kwargs,
dict(
mode="lines+markers",
marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75)),
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75), dash="dot"),
),
pattern_trace_kwargs,
)
else:
if fill_distance:
_pattern_trace_kwargs = merge_dicts(
def_pattern_trace_kwargs,
dict(
mode="lines",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75)),
fill="tonexty",
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.25),
),
pattern_trace_kwargs,
)
else:
_pattern_trace_kwargs = merge_dicts(
def_pattern_trace_kwargs,
dict(
mode="lines",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["cyan"], 0.75), dash="dot"),
),
pattern_trace_kwargs,
)
fig = pattern_sr.rename(None).vbt.plot(
trace_kwargs=_pattern_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
# Plot max error bounds
if not np.isnan(max_error).all():
def_max_error_trace_kwargs = dict(
name="Max error",
connectgaps=True,
)
if max_error_interp_mode == InterpMode.Discrete:
_lower_max_error_trace_kwargs = merge_dicts(
def_max_error_trace_kwargs,
dict(
mode="markers+lines",
marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5)),
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5), dash="dot"),
),
lower_max_error_trace_kwargs,
)
_upper_max_error_trace_kwargs = merge_dicts(
def_max_error_trace_kwargs,
dict(
mode="markers+lines",
marker=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5)),
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5), dash="dot"),
showlegend=False,
),
upper_max_error_trace_kwargs,
)
else:
_lower_max_error_trace_kwargs = merge_dicts(
def_max_error_trace_kwargs,
dict(
mode="lines",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5), dash="dot"),
),
lower_max_error_trace_kwargs,
)
_upper_max_error_trace_kwargs = merge_dicts(
def_max_error_trace_kwargs,
dict(
mode="lines",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.5), dash="dot"),
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.1),
fill="tonexty",
showlegend=False,
),
upper_max_error_trace_kwargs,
)
fig = (
(pattern_sr - max_error_sr)
.rename(None)
.vbt.plot(
trace_kwargs=_lower_max_error_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
)
fig = (
(pattern_sr + max_error_sr)
.rename(None)
.vbt.plot(
trace_kwargs=_upper_max_error_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `GenericAccessor.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.generic`."""
from vectorbtpro._settings import settings
generic_plots_cfg = settings["generic"]["plots"]
return merge_dicts(Analyzable.plots_defaults.__get__(self), generic_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
check_is_not_grouped=True,
plot_func="plot",
pass_trace_names=False,
tags="generic",
)
),
)
@property
def subplots(self) -> Config:
return self._subplots
if settings["importing"]["sklearn"]:
from sklearn.exceptions import NotFittedError
from sklearn.preprocessing import (
Binarizer,
MinMaxScaler,
MaxAbsScaler,
Normalizer,
RobustScaler,
StandardScaler,
QuantileTransformer,
PowerTransformer,
)
from sklearn.utils.validation import check_is_fitted
transform_config = ReadonlyConfig(
{
"binarize": dict(transformer=Binarizer, docstring="See `sklearn.preprocessing.Binarizer`."),
"minmax_scale": dict(transformer=MinMaxScaler, docstring="See `sklearn.preprocessing.MinMaxScaler`."),
"maxabs_scale": dict(transformer=MaxAbsScaler, docstring="See `sklearn.preprocessing.MaxAbsScaler`."),
"normalize": dict(transformer=Normalizer, docstring="See `sklearn.preprocessing.Normalizer`."),
"robust_scale": dict(transformer=RobustScaler, docstring="See `sklearn.preprocessing.RobustScaler`."),
"scale": dict(transformer=StandardScaler, docstring="See `sklearn.preprocessing.StandardScaler`."),
"quantile_transform": dict(
transformer=QuantileTransformer,
docstring="See `sklearn.preprocessing.QuantileTransformer`.",
),
"power_transform": dict(
transformer=PowerTransformer,
docstring="See `sklearn.preprocessing.PowerTransformer`.",
),
}
)
"""_"""
__pdoc__[
"transform_config"
] = f"""Config of transform methods to be attached to `GenericAccessor`.
```python
{transform_config.prettify()}
```
"""
GenericAccessor = attach_transform_methods(transform_config)(GenericAccessor)
GenericAccessor.override_metrics_doc(__pdoc__)
GenericAccessor.override_subplots_doc(__pdoc__)
class GenericSRAccessor(GenericAccessor, BaseSRAccessor):
"""Accessor on top of data of any type. For Series only.
Accessible via `pd.Series.vbt`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
mapping: tp.Optional[tp.MappingLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
BaseSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
GenericAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs)
def fit_pattern(
self,
pattern: tp.ArrayLike,
interp_mode: tp.Union[int, str] = "mixed",
rescale_mode: tp.Union[int, str] = "minmax",
vmin: float = np.nan,
vmax: float = np.nan,
pmin: float = np.nan,
pmax: float = np.nan,
invert: bool = False,
error_type: tp.Union[int, str] = "absolute",
max_error: tp.ArrayLike = np.nan,
max_error_interp_mode: tp.Union[None, int, str] = None,
jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Series, tp.Series]:
"""See `vectorbtpro.generic.nb.patterns.fit_pattern_nb`."""
if isinstance(interp_mode, str):
interp_mode = map_enum_fields(interp_mode, InterpMode)
if isinstance(rescale_mode, str):
rescale_mode = map_enum_fields(rescale_mode, RescaleMode)
if isinstance(error_type, str):
error_type = map_enum_fields(error_type, ErrorType)
if max_error_interp_mode is not None and isinstance(max_error_interp_mode, str):
max_error_interp_mode = map_enum_fields(max_error_interp_mode, InterpMode)
if max_error_interp_mode is None:
max_error_interp_mode = interp_mode
pattern = reshaping.to_1d_array(pattern)
max_error = reshaping.broadcast_array_to(max_error, len(pattern))
func = jit_reg.resolve_option(nb.fit_pattern_nb, jitted)
fit_pattern, fit_max_error = func(
self.to_1d_array(),
pattern,
interp_mode=interp_mode,
rescale_mode=rescale_mode,
vmin=vmin,
vmax=vmax,
pmin=pmin,
pmax=pmax,
invert=invert,
error_type=error_type,
max_error=max_error,
max_error_interp_mode=max_error_interp_mode,
)
pattern_sr = self.wrapper.wrap(fit_pattern)
max_error_sr = self.wrapper.wrap(fit_max_error)
return pattern_sr, max_error_sr
def to_renko(
self,
brick_size: tp.ArrayLike,
relative: tp.ArrayLike = False,
start_value: tp.Optional[float] = None,
max_out_len: tp.Optional[int] = None,
reset_index: bool = False,
return_uptrend: bool = False,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.Series, tp.Tuple[tp.Series, tp.Series]]:
"""See `vectorbtpro.generic.nb.base.to_renko_1d_nb`."""
func = jit_reg.resolve_option(nb.to_renko_1d_nb, jitted)
arr_out, idx_out, uptrend_out = func(
self.to_1d_array(),
reshaping.broadcast_array_to(brick_size, self.wrapper.shape[0]),
relative=reshaping.broadcast_array_to(relative, self.wrapper.shape[0]),
start_value=start_value,
max_out_len=max_out_len,
)
if reset_index:
new_index = pd.RangeIndex(stop=len(idx_out))
else:
new_index = self.wrapper.index[idx_out]
wrap_kwargs = merge_dicts(
dict(index=new_index),
wrap_kwargs,
)
sr_out = self.wrapper.wrap(arr_out, group_by=False, **wrap_kwargs)
if return_uptrend:
uptrend_out = self.wrapper.wrap(uptrend_out, group_by=False, **wrap_kwargs)
return sr_out, uptrend_out
return sr_out
def to_renko_ohlc(
self,
brick_size: tp.ArrayLike,
relative: tp.ArrayLike = False,
start_value: tp.Optional[float] = None,
max_out_len: tp.Optional[int] = None,
reset_index: bool = False,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Frame:
"""See `vectorbtpro.generic.nb.base.to_renko_ohlc_1d_nb`."""
func = jit_reg.resolve_option(nb.to_renko_ohlc_1d_nb, jitted)
arr_out, idx_out = func(
self.to_1d_array(),
reshaping.broadcast_array_to(brick_size, self.wrapper.shape[0]),
relative=reshaping.broadcast_array_to(relative, self.wrapper.shape[0]),
start_value=start_value,
max_out_len=max_out_len,
)
if reset_index:
new_index = pd.RangeIndex(stop=len(idx_out))
else:
new_index = self.wrapper.index[idx_out]
wrap_kwargs = merge_dicts(
dict(index=new_index, columns=["Open", "High", "Low", "Close"]),
wrap_kwargs,
)
return self.wrapper.wrap(arr_out, group_by=False, **wrap_kwargs)
class GenericDFAccessor(GenericAccessor, BaseDFAccessor):
"""Accessor on top of data of any type. For DataFrames only.
Accessible via `pd.DataFrame.vbt`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
mapping: tp.Optional[tp.MappingLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
BaseDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
GenericAccessor.__init__(self, wrapper, obj=obj, mapping=mapping, **kwargs)
def band(self, band_name: str, return_meta: bool = False) -> tp.Union[tp.Series, dict]:
"""Calculate the band by its name.
Examples for the band name:
* "50%": 50th quantile
* "Q=50%": 50th quantile
* "Q=0.5": 50th quantile
* "Z=1.96": Z-score of 1.96
* "P=95%": One-tailed significance level of 0.95 (translated into z-score)
* "P=0.95": One-tailed significance level of 0.95 (translated into z-score)
* "median": Median (50th quantile)
* "mean": Mean across all columns
* "min": Min across all columns
* "max": Max across all columns
* "lowest": Column with the lowest final value
* "highest": Column with the highest final value
"""
band_name = band_name.lower().replace(" ", "")
if band_name == "median":
band_name = "50%"
if "%" in band_name and not band_name.startswith("q=") and not band_name.startswith("p="):
band_name = f"q={band_name}"
if band_name.startswith("q="):
if "%" in band_name:
q = float(band_name.replace("q=", "").replace("%", "")) / 100
else:
q = float(band_name.replace("q=", ""))
q_readable = np.around(q * 100, decimals=2)
if q_readable.is_integer():
q_readable = int(q_readable)
band_title = f"Q={q_readable}% (proj)"
def band_func(df, _q=q):
return df.quantile(_q, axis=1)
elif band_name.startswith("z="):
z = float(band_name.replace("z=", ""))
z_readable = np.around(z, decimals=2)
if z_readable.is_integer():
z_readable = int(z_readable)
band_title = f"Z={z_readable} (proj)"
def band_func(df, _z=z):
return df.mean(axis=1) + _z * df.std(axis=1)
elif band_name.startswith("p="):
import scipy.stats as st
if "%" in band_name:
p = float(band_name.replace("p=", "").replace("%", "")) / 100
else:
p = float(band_name.replace("p=", ""))
p_readable = np.around(p * 100, decimals=2)
if p_readable.is_integer():
p_readable = int(p_readable)
band_title = f"P={p_readable}% (proj)"
z = st.norm.ppf(p)
def band_func(df, _z=z):
return df.mean(axis=1) + _z * df.std(axis=1)
elif band_name == "mean":
band_title = "Mean (proj)"
def band_func(df):
return df.mean(axis=1)
elif band_name == "min":
band_title = "Min (proj)"
def band_func(df):
return df.min(axis=1)
elif band_name == "max":
band_title = "Max (proj)"
def band_func(df):
return df.max(axis=1)
elif band_name == "lowest":
band_title = "Lowest (proj)"
def band_func(df):
return df[df.ffill().iloc[-1].idxmin()]
elif band_name == "highest":
band_title = "Highest (proj)"
def band_func(df):
return df[df.ffill().iloc[-1].idxmax()]
else:
raise ValueError(f"Invalid band_name: '{band_name}'")
if return_meta:
return dict(band_name=band_name, band_title=band_title, band_func=band_func)
return band_func(self.obj)
def plot_projections(
self,
plot_projections: bool = True,
plot_bands: bool = True,
plot_lower: tp.Union[bool, str, tp.Callable] = True,
plot_middle: tp.Union[bool, str, tp.Callable] = True,
plot_upper: tp.Union[bool, str, tp.Callable] = True,
plot_aux_middle: tp.Union[bool, str, tp.Callable] = True,
plot_fill: bool = True,
colorize: tp.Union[bool, str, tp.Callable] = True,
rename_levels: tp.Union[None, dict, tp.Sequence] = None,
projection_trace_kwargs: tp.KwargsLike = None,
upper_trace_kwargs: tp.KwargsLike = None,
middle_trace_kwargs: tp.KwargsLike = None,
lower_trace_kwargs: tp.KwargsLike = None,
aux_middle_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot a DataFrame where each column is a projection.
If `plot_projections` is True, will plot each projection as a semi-transparent line.
The arguments `plot_lower`, `plot_middle`, `plot_aux_middle`, and `plot_upper` represent
bands and accept the following:
* True: Plot the band using the default quantile (20/50/80)
* False: Do not plot the band
* callable: Custom function that accepts DataFrame and reduces it across columns
* For other options see `GenericDFAccessor.band`
!!! note
When providing z-scores, the upper should be positive, the middle should be "mean", and
the lower should be negative. When providing significance levels, the middle should be "mean", while
the lower should be positive and lower than the upper, for example, 25% and 75%.
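For instance, z-score bands could be requested as follows (a hedged sketch; `df` stands for any projection frame):
```pycon
>>> df.vbt.plot_projections(
...     plot_lower="z=-1.96",
...     plot_middle="mean",
...     plot_upper="z=1.96"
... ).show()
```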
Argument `colorize` allows the following values:
* False: Do not colorize
* True or "median": Colorize by median
* "mean": Colorize by mean
* "last": Colorize by last value
* callable: Custom function that accepts a (rebased-to-0) Series/DataFrame with
NaNs already dropped and reduces it across rows
Colorization is performed by mapping the band's metric value to the range between the
minimum and maximum value across all projections, where 0 is always the middle point.
If none of the bands is plotted, the projections get colorized. Otherwise, the projections stay gray.
Usage:
```pycon
>>> df = pd.DataFrame({
... 0: [10, 11, 12, 11, 10],
... 1: [10, 12, 14, np.nan, np.nan],
... 2: [10, 12, 11, 12, np.nan],
... 3: [10, 9, 8, 9, 8],
... 4: [10, 11, np.nan, np.nan, np.nan],
... })
>>> df.vbt.plot_projections().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if projection_trace_kwargs is None:
projection_trace_kwargs = {}
if lower_trace_kwargs is None:
lower_trace_kwargs = {}
if upper_trace_kwargs is None:
upper_trace_kwargs = {}
if middle_trace_kwargs is None:
middle_trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
# Resolve band functions and names
if len(self.obj.columns) == 1:
plot_bands = False
if not plot_bands:
plot_lower = False
plot_middle = False
plot_upper = False
plot_aux_middle = False
if isinstance(plot_lower, bool):
if plot_lower:
plot_lower = "20%"
else:
plot_lower = None
if isinstance(plot_middle, bool):
if plot_middle:
plot_middle = "50%"
else:
plot_middle = None
if isinstance(plot_upper, bool):
if plot_upper:
plot_upper = "80%"
else:
plot_upper = None
if isinstance(plot_aux_middle, bool):
if plot_aux_middle:
plot_aux_middle = "mean"
else:
plot_aux_middle = None
def _resolve_band_and_name(band_func, arg_name):
band_title = None
if isinstance(band_func, str):
band_func_meta = self.band(band_func, return_meta=True)
band_title = band_func_meta["band_title"]
band_func = band_func_meta["band_func"]
if band_func is not None and not callable(band_func):
raise TypeError(f"Argument {arg_name} has wrong type '{type(band_func)}'")
return band_func, band_title
plot_lower, lower_name = _resolve_band_and_name(plot_lower, "plot_lower")
if lower_name is None:
lower_name = "Lower (proj)"
plot_middle, middle_name = _resolve_band_and_name(plot_middle, "plot_middle")
if middle_name is None:
middle_name = "Middle (proj)"
plot_upper, upper_name = _resolve_band_and_name(plot_upper, "plot_upper")
if upper_name is None:
upper_name = "Upper (proj)"
plot_aux_middle, aux_middle_name = _resolve_band_and_name(plot_aux_middle, "plot_aux_middle")
if aux_middle_name is None:
aux_middle_name = "Aux middle (proj)"
if isinstance(colorize, bool):
if colorize:
colorize = "median"
else:
colorize = None
if colorize is not None:
if isinstance(colorize, str):
colorize = colorize.lower().replace(" ", "")
if colorize == "median":
colorize = lambda x: x.median()
elif colorize == "mean":
colorize = lambda x: x.mean()
elif colorize == "last":
colorize = lambda x: x.ffill().iloc[-1]
else:
raise ValueError(f"Argument colorize has wrong value '{colorize}'")
if colorize is not None and not callable(colorize):
raise TypeError(f"Argument colorize has wrong type '{type(colorize)}'")
if colorize is not None:
proj_min = colorize(self.obj - self.obj.iloc[0]).min()
proj_max = colorize(self.obj - self.obj.iloc[0]).max()
else:
proj_min = None
proj_max = None
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if len(self.obj.columns) > 0:
if plot_projections:
# Plot projections
for col in range(self.wrapper.shape[1]):
proj_sr = self.obj.iloc[:, col].dropna()
hovertemplate = f"(%{{x}}, %{{y}})"
if not checks.is_default_index(self.wrapper.columns):
level_names = []
level_values = []
if isinstance(self.wrapper.columns, pd.MultiIndex):
for l in range(self.wrapper.columns.nlevels):
level_names.append(self.wrapper.columns.names[l])
level_values.append(self.wrapper.columns.get_level_values(l)[col])
else:
level_names.append(self.wrapper.columns.name)
level_values.append(self.wrapper.columns[col])
for l in range(len(level_names)):
level_name = level_names[l]
level_value = level_values[l]
if rename_levels is not None:
if isinstance(rename_levels, dict):
if level_name in rename_levels:
level_name = rename_levels[level_name]
elif l in rename_levels:
level_name = rename_levels[l]
else:
level_name = rename_levels[l]
if level_name is None:
level_name = f"Level {l}"
hovertemplate += f" {level_name}: {level_value}"
if colorize is not None:
proj_color = map_value_to_cmap(
colorize(proj_sr - proj_sr.iloc[0]),
[
plotting_cfg["color_schema"]["red"],
plotting_cfg["color_schema"]["yellow"],
plotting_cfg["color_schema"]["green"],
],
vmin=proj_min,
vcenter=0,
vmax=proj_max,
)
else:
proj_color = plotting_cfg["color_schema"]["gray"]
if not plot_bands:
proj_opacity = 0.5
else:
proj_opacity = 0.1
_projection_trace_kwargs = merge_dicts(
dict(
name=f"proj ({self.obj.shape[1]})",
line=dict(color=proj_color),
opacity=proj_opacity,
legendgroup="proj",
showlegend=col == 0,
hovertemplate=hovertemplate,
),
projection_trace_kwargs,
)
proj_sr.rename(None).vbt.lineplot(
trace_kwargs=_projection_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if plot_bands and len(self.obj.columns) > 1:
# Calculate bands
if plot_lower is not None:
lower_band = plot_lower(self.obj).dropna()
else:
lower_band = None
if plot_middle is not None:
middle_band = plot_middle(self.obj).dropna()
else:
middle_band = None
if plot_upper is not None:
upper_band = plot_upper(self.obj).dropna()
else:
upper_band = None
if plot_aux_middle is not None:
aux_middle_band = plot_aux_middle(self.obj).dropna()
else:
aux_middle_band = None
if lower_band is not None:
# Plot lower band
def_lower_trace_kwargs = dict(name=lower_name)
if colorize is not None:
lower_color = map_value_to_cmap(
colorize(lower_band - lower_band.iloc[0]),
[
plotting_cfg["color_schema"]["red"],
plotting_cfg["color_schema"]["yellow"],
plotting_cfg["color_schema"]["green"],
],
vmin=proj_min,
vcenter=0,
vmax=proj_max,
)
def_lower_trace_kwargs["line"] = dict(color=adjust_opacity(lower_color, 0.75))
else:
lower_color = plotting_cfg["color_schema"]["gray"]
def_lower_trace_kwargs["line"] = dict(color=adjust_opacity(lower_color, 0.5))
lower_band.rename(None).vbt.lineplot(
trace_kwargs=merge_dicts(def_lower_trace_kwargs, lower_trace_kwargs),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if middle_band is not None:
# Plot middle band
def_middle_trace_kwargs = dict(name=middle_name)
if colorize is not None:
middle_color = map_value_to_cmap(
colorize(middle_band - middle_band.iloc[0]),
[
plotting_cfg["color_schema"]["red"],
plotting_cfg["color_schema"]["yellow"],
plotting_cfg["color_schema"]["green"],
],
vmin=proj_min,
vcenter=0,
vmax=proj_max,
)
else:
middle_color = plotting_cfg["color_schema"]["gray"]
def_middle_trace_kwargs["line"] = dict(color=middle_color)
if plot_fill and lower_band is not None:
def_middle_trace_kwargs["fill"] = "tonexty"
def_middle_trace_kwargs["fillcolor"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.25)
middle_band.rename(None).vbt.lineplot(
trace_kwargs=merge_dicts(def_middle_trace_kwargs, middle_trace_kwargs),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if upper_band is not None:
# Plot upper band
def_upper_trace_kwargs = dict(name=upper_name)
if colorize is not None:
upper_color = map_value_to_cmap(
colorize(upper_band - upper_band.iloc[0]),
[
plotting_cfg["color_schema"]["red"],
plotting_cfg["color_schema"]["yellow"],
plotting_cfg["color_schema"]["green"],
],
vmin=proj_min,
vcenter=0,
vmax=proj_max,
)
def_upper_trace_kwargs["line"] = dict(color=adjust_opacity(upper_color, 0.75))
else:
upper_color = plotting_cfg["color_schema"]["gray"]
def_upper_trace_kwargs["line"] = dict(color=adjust_opacity(upper_color, 0.5))
if plot_fill and (lower_band is not None or middle_band is not None):
def_upper_trace_kwargs["fill"] = "tonexty"
def_upper_trace_kwargs["fillcolor"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.25)
upper_band.rename(None).vbt.lineplot(
trace_kwargs=merge_dicts(def_upper_trace_kwargs, upper_trace_kwargs),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if aux_middle_band is not None:
# Plot auxiliary band
def_aux_middle_trace_kwargs = dict(name=aux_middle_name)
if colorize is not None:
aux_middle_color = map_value_to_cmap(
colorize(aux_middle_band - aux_middle_band.iloc[0]),
[
plotting_cfg["color_schema"]["red"],
plotting_cfg["color_schema"]["yellow"],
plotting_cfg["color_schema"]["green"],
],
vmin=proj_min,
vcenter=0,
vmax=proj_max,
)
else:
aux_middle_color = plotting_cfg["color_schema"]["gray"]
def_aux_middle_trace_kwargs["line"] = dict(dash="dot", color=aux_middle_color)
aux_middle_band.rename(None).vbt.lineplot(
trace_kwargs=merge_dicts(def_aux_middle_trace_kwargs, aux_middle_trace_kwargs),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class for analyzing data."""
from vectorbtpro import _typing as tp
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.generic.plots_builder import PlotsBuilderMixin
from vectorbtpro.generic.stats_builder import StatsBuilderMixin
__all__ = [
"Analyzable",
]
class MetaAnalyzable(type(Wrapping), type(StatsBuilderMixin), type(PlotsBuilderMixin)):
"""Metaclass for `Analyzable`."""
pass
AnalyzableT = tp.TypeVar("AnalyzableT", bound="Analyzable")
class Analyzable(Wrapping, StatsBuilderMixin, PlotsBuilderMixin, metaclass=MetaAnalyzable):
"""Class that can be analyzed by computing and plotting attributes of any kind."""
def __init__(self, wrapper: ArrayWrapper, **kwargs) -> None:
Wrapping.__init__(self, wrapper, **kwargs)
StatsBuilderMixin.__init__(self)
PlotsBuilderMixin.__init__(self)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for generic accessors."""
import inspect
from vectorbtpro import _typing as tp
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts, Config
from vectorbtpro.utils.parsing import get_func_arg_names
__all__ = []
def attach_nb_methods(config: Config) -> tp.ClassWrapper:
"""Class decorator to attach Numba methods.
`config` must contain target method names (keys) and dictionaries (values) with the following keys:
* `func`: Function to be wrapped. Its first argument must accept a 2-dim array.
* `is_reducing`: Whether the function is reducing. Defaults to False.
* `disable_jitted`: Whether to disable the `jitted` option.
* `disable_chunked`: Whether to disable the `chunked` option.
* `replace_signature`: Whether to replace the target signature with the source signature. Defaults to True.
* `wrap_kwargs`: Default keyword arguments for wrapping. Will be merged with the dict supplied by the user.
Defaults to `dict(name_or_index=target_name)` for reducing functions.
The class must be a subclass of `vectorbtpro.base.wrapping.Wrapping`.
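A minimal sketch of attaching one Numba function (assuming `MyAccessor` is a `Wrapping`
subclass defined elsewhere; the choice of `nb.rolling_mean_nb` is illustrative):
```pycon
>>> from vectorbtpro.generic import nb
>>> MyAccessor = attach_nb_methods(Config(dict(
...     rolling_mean=dict(func=nb.rolling_mean_nb),
... )))(MyAccessor)
>>> # MyAccessor now exposes a `rolling_mean` method that wraps the result with its wrapper
```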
"""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
from vectorbtpro.base.wrapping import Wrapping
checks.assert_subclass_of(cls, Wrapping)
for target_name, settings in config.items():
func = settings["func"]
is_reducing = settings.get("is_reducing", False)
disable_jitted = settings.get("disable_jitted", False)
disable_chunked = settings.get("disable_chunked", False)
replace_signature = settings.get("replace_signature", True)
default_wrap_kwargs = settings.get("wrap_kwargs", dict(name_or_index=target_name) if is_reducing else None)
def new_method(
self,
*args,
_target_name: str = target_name,
_func: tp.Callable = func,
_is_reducing: bool = is_reducing,
_disable_jitted: bool = disable_jitted,
_disable_chunked: bool = disable_chunked,
_default_wrap_kwargs: tp.KwargsLike = default_wrap_kwargs,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
args = (self.to_2d_array(),) + args
inspect.signature(_func).bind(*args, **kwargs)
if not _disable_jitted:
_func = jit_reg.resolve_option(_func, jitted)
elif jitted is not None:
raise ValueError("This method doesn't support jitting")
if not _disable_chunked:
_func = ch_reg.resolve_option(_func, chunked)
elif chunked is not None:
raise ValueError("This method doesn't support chunking")
a = _func(*args, **kwargs)
wrap_kwargs = merge_dicts(_default_wrap_kwargs, wrap_kwargs)
if _is_reducing:
return self.wrapper.wrap_reduced(a, **wrap_kwargs)
return self.wrapper.wrap(a, **wrap_kwargs)
if replace_signature:
# Replace the function's signature with the original one
source_sig = inspect.signature(func)
new_method_params = tuple(inspect.signature(new_method).parameters.values())
self_arg = new_method_params[0]
jitted_arg = new_method_params[-4]
chunked_arg = new_method_params[-3]
wrap_kwargs_arg = new_method_params[-2]
new_parameters = (self_arg,) + tuple(source_sig.parameters.values())[1:]
if not disable_jitted:
new_parameters += (jitted_arg,)
if not disable_chunked:
new_parameters += (chunked_arg,)
new_parameters += (wrap_kwargs_arg,)
new_method.__signature__ = source_sig.replace(parameters=new_parameters)
new_method.__name__ = target_name
new_method.__module__ = cls.__module__
new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
new_method.__doc__ = f"See `{func.__module__ + '.' + func.__name__}`."
setattr(cls, target_name, new_method)
return cls
return wrapper
def attach_transform_methods(config: Config) -> tp.ClassWrapper:
"""Class decorator to add transformation methods.
`config` must contain target method names (keys) and dictionaries (values) with the following keys:
* `transformer`: Transformer class/object.
* `docstring`: Method docstring.
* `replace_signature`: Whether to replace the target signature. Defaults to True.
The class must be a subclass of `vectorbtpro.generic.accessors.GenericAccessor`.
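A minimal sketch mirroring how `GenericAccessor` attaches scikit-learn transformers
(the config content is an illustrative assumption):
```pycon
>>> from sklearn.preprocessing import MinMaxScaler
>>> GenericAccessor = attach_transform_methods(Config(dict(
...     minmax_scale=dict(transformer=MinMaxScaler),
... )))(GenericAccessor)
```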
"""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
from vectorbtpro.generic.accessors import TransformerT
checks.assert_subclass_of(cls, "GenericAccessor")
for target_name, settings in config.items():
transformer = settings["transformer"]
docstring = settings.get("docstring", f"See `{transformer.__name__}`.")
replace_signature = settings.get("replace_signature", True)
def new_method(
self,
_target_name: str = target_name,
_transformer: tp.Union[tp.Type[TransformerT], TransformerT] = transformer,
**kwargs,
) -> tp.SeriesFrame:
if inspect.isclass(_transformer):
arg_names = get_func_arg_names(_transformer.__init__)
transformer_kwargs = dict()
for arg_name in arg_names:
if arg_name in kwargs:
transformer_kwargs[arg_name] = kwargs.pop(arg_name)
return self.transform(_transformer(**transformer_kwargs), **kwargs)
return self.transform(_transformer, **kwargs)
if replace_signature:
source_sig = inspect.signature(transformer.__init__)
new_method_params = tuple(inspect.signature(new_method).parameters.values())
if inspect.isclass(transformer):
transformer_params = tuple(source_sig.parameters.values())
source_sig = inspect.Signature(
(new_method_params[0],) + transformer_params[1:] + (new_method_params[-1],),
)
new_method.__signature__ = source_sig
else:
source_sig = inspect.Signature((new_method_params[0],) + (new_method_params[-1],))
new_method.__signature__ = source_sig
new_method.__name__ = target_name
new_method.__module__ = cls.__module__
new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
new_method.__doc__ = docstring
setattr(cls, target_name, new_method)
return cls
return wrapper
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with drawdown records.
Drawdown records capture information on drawdowns. Since drawdowns are ranges,
they subclass `vectorbtpro.generic.ranges.Ranges`.
!!! warning
`Drawdowns` returns both recovered AND active drawdowns, which may skew your performance results.
To only consider recovered drawdowns, you should explicitly query the `status_recovered` attribute.
Using `Drawdowns.from_price`, you can generate drawdown records for any time series and analyze them right away.
```pycon
>>> from vectorbtpro import *
>>> price = vbt.YFData.pull(
... "BTC-USD",
... start="2019-10 UTC",
... end="2020-01 UTC"
... ).get('Close')
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> price = price.rename(None)
>>> drawdowns = vbt.Drawdowns.from_price(price, wrapper_kwargs=dict(freq='d'))
>>> drawdowns.readable
Drawdown Id Column Start Index Valley Index \\
0 0 0 2019-10-02 00:00:00+00:00 2019-10-06 00:00:00+00:00
1 1 0 2019-10-09 00:00:00+00:00 2019-10-24 00:00:00+00:00
2 2 0 2019-10-27 00:00:00+00:00 2019-12-17 00:00:00+00:00
End Index Peak Value Valley Value End Value Status
0 2019-10-09 00:00:00+00:00 8393.041992 7988.155762 8595.740234 Recovered
1 2019-10-25 00:00:00+00:00 8595.740234 7493.488770 8660.700195 Recovered
2 2019-12-31 00:00:00+00:00 9551.714844 6640.515137 7193.599121 Active
>>> drawdowns.duration.max(wrap_kwargs=dict(to_timedelta=True))
Timedelta('66 days 00:00:00')
```
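As the warning above suggests, recovered-only metrics can be queried explicitly
(a minimal sketch; output omitted):
```pycon
>>> recovered_mdd = drawdowns.status_recovered.max_drawdown
```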
## From accessors
Moreover, all generic accessors have a property `drawdowns` and a method `get_drawdowns`:
```pycon
>>> # vectorbtpro.generic.accessors.GenericAccessor.drawdowns.coverage
>>> price.vbt.drawdowns.coverage
0.967391304347826
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Drawdowns.metrics`.
```pycon
>>> df = pd.DataFrame({
... 'a': [1, 2, 1, 3, 2],
... 'b': [2, 3, 1, 2, 1]
... })
>>> drawdowns = df.vbt(freq='d').drawdowns
>>> drawdowns['a'].stats()
Start 0
End 4
Period 5 days 00:00:00
Coverage [%] 80.0
Total Records 2
Total Recovered Drawdowns 1
Total Active Drawdowns 1
Active Drawdown [%] 33.333333
Active Duration 2 days 00:00:00
Active Recovery [%] 0.0
Active Recovery Return [%] 0.0
Active Recovery Duration 0 days 00:00:00
Max Drawdown [%] 50.0
Avg Drawdown [%] 50.0
Max Drawdown Duration 2 days 00:00:00
Avg Drawdown Duration 2 days 00:00:00
Max Recovery Return [%] 200.0
Avg Recovery Return [%] 200.0
Max Recovery Duration 1 days 00:00:00
Avg Recovery Duration 1 days 00:00:00
Avg Recovery Duration Ratio 1.0
Name: a, dtype: object
```
By default, the metrics `max_dd`, `avg_dd`, `max_dd_duration`, and `avg_dd_duration` do
not include active drawdowns. To change that, pass `incl_active=True`:
```pycon
>>> drawdowns['a'].stats(settings=dict(incl_active=True))
Start 0
End 4
Period 5 days 00:00:00
Coverage [%] 80.0
Total Records 2
Total Recovered Drawdowns 1
Total Active Drawdowns 1
Active Drawdown [%] 33.333333
Active Duration 2 days 00:00:00
Active Recovery [%] 0.0
Active Recovery Return [%] 0.0
Active Recovery Duration 0 days 00:00:00
Max Drawdown [%] 50.0
Avg Drawdown [%] 41.666667
Max Drawdown Duration 2 days 00:00:00
Avg Drawdown Duration 2 days 00:00:00
Max Recovery Return [%] 200.0
Avg Recovery Return [%] 200.0
Max Recovery Duration 1 days 00:00:00
Avg Recovery Duration 1 days 00:00:00
Avg Recovery Duration Ratio 1.0
Name: a, dtype: object
```
`Drawdowns.stats` also supports (re-)grouping:
```pycon
>>> drawdowns['a'].stats(group_by=True)
UserWarning: Metric 'active_dd' does not support grouped data
UserWarning: Metric 'active_duration' does not support grouped data
UserWarning: Metric 'active_recovery' does not support grouped data
UserWarning: Metric 'active_recovery_return' does not support grouped data
UserWarning: Metric 'active_recovery_duration' does not support grouped data
Start 0
End 4
Period 5 days 00:00:00
Coverage [%] 80.0
Total Records 2
Total Recovered Drawdowns 1
Total Active Drawdowns 1
Max Drawdown [%] 50.0
Avg Drawdown [%] 50.0
Max Drawdown Duration 2 days 00:00:00
Avg Drawdown Duration 2 days 00:00:00
Max Recovery Return [%] 200.0
Avg Recovery Return [%] 200.0
Max Recovery Duration 1 days 00:00:00
Avg Recovery Duration 1 days 00:00:00
Avg Recovery Duration Ratio 1.0
Name: group, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Drawdowns.subplots`.
`Drawdowns` class has a single subplot based on `Drawdowns.plot`:
```pycon
>>> drawdowns['a'].plots().show()
```
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb
from vectorbtpro.generic.enums import DrawdownStatus, drawdown_dt, range_dt
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.records.decorators import override_field_config, attach_fields, attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.template import RepEval, RepFunc
__all__ = [
"Drawdowns",
]
__pdoc__ = {}
dd_field_config = ReadonlyConfig(
dict(
dtype=drawdown_dt,
settings=dict(
id=dict(title="Drawdown Id"),
valley_idx=dict(title="Valley Index", mapping="index"),
start_val=dict(
title="Start Value",
),
valley_val=dict(
title="Valley Value",
),
end_val=dict(
title="End Value",
),
status=dict(mapping=DrawdownStatus),
),
)
)
"""_"""
__pdoc__[
"dd_field_config"
] = f"""Field config for `Drawdowns`.
```python
{dd_field_config.prettify()}
```
"""
dd_attach_field_config = ReadonlyConfig(dict(status=dict(attach_filters=True)))
"""_"""
__pdoc__[
"dd_attach_field_config"
] = f"""Config of fields to be attached to `Drawdowns`.
```python
{dd_attach_field_config.prettify()}
```
"""
dd_shortcut_config = ReadonlyConfig(
dict(
ranges=dict(),
decline_ranges=dict(),
recovery_ranges=dict(),
drawdown=dict(obj_type="mapped_array"),
avg_drawdown=dict(obj_type="red_array"),
max_drawdown=dict(obj_type="red_array"),
recovery_return=dict(obj_type="mapped_array"),
avg_recovery_return=dict(obj_type="red_array"),
max_recovery_return=dict(obj_type="red_array"),
decline_duration=dict(obj_type="mapped_array"),
recovery_duration=dict(obj_type="mapped_array"),
recovery_duration_ratio=dict(obj_type="mapped_array"),
active_drawdown=dict(obj_type="red_array"),
active_duration=dict(obj_type="red_array"),
active_recovery=dict(obj_type="red_array"),
active_recovery_return=dict(obj_type="red_array"),
active_recovery_duration=dict(obj_type="red_array"),
)
)
"""_"""
__pdoc__[
"dd_shortcut_config"
] = f"""Config of shortcut properties to be attached to `Drawdowns`.
```python
{dd_shortcut_config.prettify()}
```
"""
DrawdownsT = tp.TypeVar("DrawdownsT", bound="Drawdowns")
@attach_shortcut_properties(dd_shortcut_config)
@attach_fields(dd_attach_field_config)
@override_field_config(dd_field_config)
class Drawdowns(Ranges):
"""Extends `vectorbtpro.generic.ranges.Ranges` for working with drawdown records.
Requires `records_arr` to have all fields defined in `vectorbtpro.generic.enums.drawdown_dt`."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def from_price(
cls: tp.Type[DrawdownsT],
close: tp.ArrayLike,
*,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
attach_data: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> DrawdownsT:
"""Build `Drawdowns` from price.
`**kwargs` will be passed to `Drawdowns.__init__`."""
if wrapper_kwargs is None:
wrapper_kwargs = {}
close_arr = to_2d_array(close)
open_arr = to_2d_array(open) if open is not None else None
high_arr = to_2d_array(high) if high is not None else None
low_arr = to_2d_array(low) if low is not None else None
func = jit_reg.resolve_option(nb.get_drawdowns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
records_arr = func(
open=open_arr,
high=high_arr,
low=low_arr,
close=close_arr,
sim_start=sim_start,
sim_end=sim_end,
)
if wrapper is None:
wrapper = ArrayWrapper.from_obj(close, **resolve_dict(wrapper_kwargs))
elif len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
return cls(
wrapper,
records_arr,
open=open if attach_data else None,
high=high if attach_data else None,
low=low if attach_data else None,
close=close if attach_data else None,
**kwargs,
)
def get_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for peak-to-end ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("start_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("end_idx").copy()
new_records_arr["status"][:] = self.get_field_arr("status").copy()
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
def get_decline_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for peak-to-valley ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("start_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("valley_idx").copy()
new_records_arr["status"][:] = self.get_field_arr("status").copy()
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
def get_recovery_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for valley-to-end ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("valley_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("end_idx").copy()
new_records_arr["status"][:] = self.get_field_arr("status").copy()
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
# ############# Drawdown ############# #
def get_drawdown(self, jitted: tp.JittedOption = None, chunked: tp.ChunkedOption = None, **kwargs) -> MappedArray:
"""See `vectorbtpro.generic.nb.records.dd_drawdown_nb`.
Takes into account both recovered and active drawdowns."""
func = jit_reg.resolve_option(nb.dd_drawdown_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
drawdown = func(self.get_field_arr("start_val"), self.get_field_arr("valley_val"))
return self.map_array(drawdown, **kwargs)
def get_avg_drawdown(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get average drawdown (ADD).
Based on `Drawdowns.drawdown`."""
wrap_kwargs = merge_dicts(dict(name_or_index="avg_drawdown"), wrap_kwargs)
return self.drawdown.mean(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)
def get_max_drawdown(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get maximum drawdown (MDD).
Based on `Drawdowns.drawdown`."""
wrap_kwargs = merge_dicts(dict(name_or_index="max_drawdown"), wrap_kwargs)
return self.drawdown.min(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)
# ############# Recovery ############# #
def get_recovery_return(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""See `vectorbtpro.generic.nb.records.dd_recovery_return_nb`.
Takes into account both recovered and active drawdowns."""
func = jit_reg.resolve_option(nb.dd_recovery_return_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
recovery_return = func(self.get_field_arr("valley_val"), self.get_field_arr("end_val"))
return self.map_array(recovery_return, **kwargs)
def get_avg_recovery_return(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get average recovery return.
Based on `Drawdowns.recovery_return`."""
wrap_kwargs = merge_dicts(dict(name_or_index="avg_recovery_return"), wrap_kwargs)
return self.recovery_return.mean(
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_max_recovery_return(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get maximum recovery return.
Based on `Drawdowns.recovery_return`."""
wrap_kwargs = merge_dicts(dict(name_or_index="max_recovery_return"), wrap_kwargs)
return self.recovery_return.max(
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
# ############# Duration ############# #
def get_decline_duration(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""See `vectorbtpro.generic.nb.records.dd_decline_duration_nb`.
Takes into account both recovered and active drawdowns."""
func = jit_reg.resolve_option(nb.dd_decline_duration_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
decline_duration = func(self.get_field_arr("start_idx"), self.get_field_arr("valley_idx"))
return self.map_array(decline_duration, **kwargs)
def get_recovery_duration(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""See `vectorbtpro.generic.nb.records.dd_recovery_duration_nb`.
Takes into account both recovered and active drawdowns."""
func = jit_reg.resolve_option(nb.dd_recovery_duration_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
recovery_duration = func(self.get_field_arr("valley_idx"), self.get_field_arr("end_idx"))
return self.map_array(recovery_duration, **kwargs)
def get_recovery_duration_ratio(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""See `vectorbtpro.generic.nb.records.dd_recovery_duration_ratio_nb`.
A value higher than 1 means the recovery was slower than the decline.
Takes into account both recovered and active drawdowns."""
func = jit_reg.resolve_option(nb.dd_recovery_duration_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
recovery_duration_ratio = func(
self.get_field_arr("start_idx"),
self.get_field_arr("valley_idx"),
self.get_field_arr("end_idx"),
)
return self.map_array(recovery_duration_ratio, **kwargs)
# ############# Status: Active ############# #
def get_active_drawdown(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get drawdown of the last active drawdown only.
Does not support grouping."""
if self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping is not supported by this method")
wrap_kwargs = merge_dicts(dict(name_or_index="active_drawdown"), wrap_kwargs)
active = self.status_active
curr_end_val = active.end_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
curr_start_val = active.start_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
curr_drawdown = (curr_end_val - curr_start_val) / curr_start_val
return self.wrapper.wrap_reduced(curr_drawdown, group_by=group_by, **wrap_kwargs)
def get_active_duration(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get duration of the last active drawdown only.
Does not support grouping."""
if self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping is not supported by this method")
wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="active_duration"), wrap_kwargs)
return self.status_active.duration.nth(
-1,
jitted=jitted,
chunked=chunked,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_active_recovery(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get recovery of the last active drawdown only.
Does not support grouping."""
if self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping is not supported by this method")
wrap_kwargs = merge_dicts(dict(name_or_index="active_recovery"), wrap_kwargs)
active = self.status_active
curr_start_val = active.start_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
curr_end_val = active.end_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
curr_valley_val = active.valley_val.nth(-1, group_by=group_by, jitted=jitted, chunked=chunked)
curr_recovery = (curr_end_val - curr_valley_val) / (curr_start_val - curr_valley_val)
return self.wrapper.wrap_reduced(curr_recovery, group_by=group_by, **wrap_kwargs)
def get_active_recovery_return(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get recovery return of the last active drawdown only.
Does not support grouping."""
if self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping is not supported by this method")
wrap_kwargs = merge_dicts(dict(name_or_index="active_recovery_return"), wrap_kwargs)
return self.status_active.recovery_return.nth(
-1,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_active_recovery_duration(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get recovery duration of the last active drawdown only.
Does not support grouping."""
if self.wrapper.grouper.is_grouped(group_by=group_by):
raise ValueError("Grouping is not supported by this method")
wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="active_recovery_duration"), wrap_kwargs)
return self.status_active.recovery_duration.nth(
-1,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Drawdowns.stats`.
Merges `vectorbtpro.generic.ranges.Ranges.stats_defaults` and
`stats` from `vectorbtpro._settings.drawdowns`."""
from vectorbtpro._settings import settings
drawdowns_stats_cfg = settings["drawdowns"]["stats"]
return merge_dicts(Ranges.stats_defaults.__get__(self), drawdowns_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
coverage=dict(
title="Coverage [%]",
calc_func="coverage",
post_calc_func=lambda self, out, settings: out * 100,
tags=["ranges", "duration"],
),
total_records=dict(title="Total Records", calc_func="count", tags="records"),
total_recovered=dict(
title="Total Recovered Drawdowns",
calc_func="status_recovered.count",
tags="drawdowns",
),
total_active=dict(title="Total Active Drawdowns", calc_func="status_active.count", tags="drawdowns"),
active_dd=dict(
title="Active Drawdown [%]",
calc_func="active_drawdown",
post_calc_func=lambda self, out, settings: -out * 100,
check_is_not_grouped=True,
tags=["drawdowns", "active"],
),
active_duration=dict(
title="Active Duration",
calc_func="active_duration",
fill_wrap_kwargs=True,
check_is_not_grouped=True,
tags=["drawdowns", "active", "duration"],
),
active_recovery=dict(
title="Active Recovery [%]",
calc_func="active_recovery",
post_calc_func=lambda self, out, settings: out * 100,
check_is_not_grouped=True,
tags=["drawdowns", "active"],
),
active_recovery_return=dict(
title="Active Recovery Return [%]",
calc_func="active_recovery_return",
post_calc_func=lambda self, out, settings: out * 100,
check_is_not_grouped=True,
tags=["drawdowns", "active"],
),
active_recovery_duration=dict(
title="Active Recovery Duration",
calc_func="active_recovery_duration",
fill_wrap_kwargs=True,
check_is_not_grouped=True,
tags=["drawdowns", "active", "duration"],
),
max_dd=dict(
title="Max Drawdown [%]",
calc_func=RepEval("'max_drawdown' if incl_active else 'status_recovered.get_max_drawdown'"),
post_calc_func=lambda self, out, settings: -out * 100,
tags=RepEval("['drawdowns'] if incl_active else ['drawdowns', 'recovered']"),
),
avg_dd=dict(
title="Avg Drawdown [%]",
calc_func=RepEval("'avg_drawdown' if incl_active else 'status_recovered.get_avg_drawdown'"),
post_calc_func=lambda self, out, settings: -out * 100,
tags=RepEval("['drawdowns'] if incl_active else ['drawdowns', 'recovered']"),
),
max_dd_duration=dict(
title="Max Drawdown Duration",
calc_func=RepEval("'max_duration' if incl_active else 'status_recovered.get_max_duration'"),
fill_wrap_kwargs=True,
tags=RepEval("['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"),
),
avg_dd_duration=dict(
title="Avg Drawdown Duration",
calc_func=RepEval("'avg_duration' if incl_active else 'status_recovered.get_avg_duration'"),
fill_wrap_kwargs=True,
tags=RepEval("['drawdowns', 'duration'] if incl_active else ['drawdowns', 'recovered', 'duration']"),
),
max_return=dict(
title="Max Recovery Return [%]",
calc_func="status_recovered.recovery_return.max",
post_calc_func=lambda self, out, settings: out * 100,
tags=["drawdowns", "recovered"],
),
avg_return=dict(
title="Avg Recovery Return [%]",
calc_func="status_recovered.recovery_return.mean",
post_calc_func=lambda self, out, settings: out * 100,
tags=["drawdowns", "recovered"],
),
max_recovery_duration=dict(
title="Max Recovery Duration",
calc_func="status_recovered.recovery_duration.max",
apply_to_timedelta=True,
tags=["drawdowns", "recovered", "duration"],
),
avg_recovery_duration=dict(
title="Avg Recovery Duration",
calc_func="status_recovered.recovery_duration.mean",
apply_to_timedelta=True,
tags=["drawdowns", "recovered", "duration"],
),
recovery_duration_ratio=dict(
title="Avg Recovery Duration Ratio",
calc_func="status_recovered.recovery_duration_ratio.mean",
tags=["drawdowns", "recovered"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
column: tp.Optional[tp.Label] = None,
top_n: tp.Optional[int] = 5,
plot_ohlc: bool = True,
plot_close: bool = True,
plot_markers: bool = True,
plot_zones: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
peak_trace_kwargs: tp.KwargsLike = None,
valley_trace_kwargs: tp.KwargsLike = None,
recovery_trace_kwargs: tp.KwargsLike = None,
active_trace_kwargs: tp.KwargsLike = None,
decline_shape_kwargs: tp.KwargsLike = None,
recovery_shape_kwargs: tp.KwargsLike = None,
active_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot drawdowns.
Args:
column (str): Name of the column to plot.
top_n (int): Filter top N drawdown records by maximum drawdown.
plot_ohlc (bool): Whether to plot OHLC.
plot_close (bool): Whether to plot close.
plot_markers (bool): Whether to plot markers.
plot_zones (bool): Whether to plot zones.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Drawdowns.close`.
peak_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for peak values.
valley_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for valley values.
recovery_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for recovery values.
active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for active recovery values.
decline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for decline zones.
recovery_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for recovery zones.
active_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for active recovery zones.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=8)
>>> price = pd.Series([1, 2, 1, 2, 3, 2, 1, 2], index=index)
>>> vbt.Drawdowns.from_price(price, wrapper_kwargs=dict(freq='1 day')).plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if top_n is not None:
            # Drawdown values are negative, thus top_n becomes bottom_n
self_col = self_col.apply_mask(self_col.drawdown.bottom_n_mask(top_n))
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if peak_trace_kwargs is None:
peak_trace_kwargs = {}
if valley_trace_kwargs is None:
valley_trace_kwargs = {}
if recovery_trace_kwargs is None:
recovery_trace_kwargs = {}
if active_trace_kwargs is None:
active_trace_kwargs = {}
if decline_shape_kwargs is None:
decline_shape_kwargs = {}
if recovery_shape_kwargs is None:
recovery_shape_kwargs = {}
if active_shape_kwargs is None:
active_shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
y_domain = get_domain(yref, fig)
plotting_ohlc = False
if (
plot_ohlc
and self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
plotting_ohlc = True
ohlc_df = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc_df.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and self_col._close is not None:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
id_ = self_col.get_field_arr("id")
start_idx = self_col.get_map_field_to_index("start_idx")
if not plotting_ohlc and self_col._close is not None:
start_val = self_col.close.loc[start_idx]
else:
start_val = self_col.get_field_arr("start_val")
valley_idx = self_col.get_map_field_to_index("valley_idx")
if not plotting_ohlc and self_col._close is not None:
valley_val = self_col.close.loc[valley_idx]
else:
valley_val = self_col.get_field_arr("valley_val")
end_idx = self_col.get_map_field_to_index("end_idx")
if not plotting_ohlc and self_col._close is not None:
end_val = self_col.close.loc[end_idx]
else:
end_val = self_col.get_field_arr("end_val")
drawdown = self_col.drawdown.values
recovery_return = self_col.recovery_return.values
status = self_col.get_field_arr("status")
decline_duration = to_1d_array(
self_col.wrapper.arr_to_timedelta(
self_col.decline_duration.values, to_pd=True, silence_warnings=True
).astype(str)
)
recovery_duration = to_1d_array(
self_col.wrapper.arr_to_timedelta(
self_col.recovery_duration.values, to_pd=True, silence_warnings=True
).astype(str)
)
duration = to_1d_array(
self_col.wrapper.arr_to_timedelta(self_col.duration.values, to_pd=True, silence_warnings=True).astype(
str
)
)
# Peak and recovery at same time -> recovery wins
peak_mask = (start_val != np.roll(end_val, 1)) | (start_idx != np.roll(end_idx, 1))
if peak_mask.any():
if plot_markers:
# Plot peak markers
peak_customdata, peak_hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "start_idx", "start_val"], mask=peak_mask
)
_peak_trace_kwargs = merge_dicts(
dict(
x=start_idx[peak_mask],
y=start_val[peak_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["blue"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["blue"])
),
),
name="Peak",
customdata=peak_customdata,
hovertemplate=peak_hovertemplate,
),
peak_trace_kwargs,
)
peak_scatter = go.Scatter(**_peak_trace_kwargs)
fig.add_trace(peak_scatter, **add_trace_kwargs)
recovered_mask = status == DrawdownStatus.Recovered
if recovered_mask.any():
if plot_markers:
# Plot valley markers
valley_customdata, valley_hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "valley_idx", "valley_val"],
append_info=[
(drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"),
(decline_duration, "Decline duration"),
],
mask=recovered_mask,
)
_valley_trace_kwargs = merge_dicts(
dict(
x=valley_idx[recovered_mask],
y=valley_val[recovered_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["red"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["red"])
),
),
name="Valley",
customdata=valley_customdata,
hovertemplate=valley_hovertemplate,
),
valley_trace_kwargs,
)
valley_scatter = go.Scatter(**_valley_trace_kwargs)
fig.add_trace(valley_scatter, **add_trace_kwargs)
if plot_markers:
# Plot recovery markers
recovery_customdata, recovery_hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "end_idx", "end_val"],
append_info=[
(drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"),
(duration, "Drawdown duration"),
(recovery_return, "Recovery return", "$title: %{customdata[$index]:,%}"),
(recovery_duration, "Recovery duration"),
],
mask=recovered_mask,
)
_recovery_trace_kwargs = merge_dicts(
dict(
x=end_idx[recovered_mask],
y=end_val[recovered_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["green"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"])
),
),
name="Recovery/Peak",
customdata=recovery_customdata,
hovertemplate=recovery_hovertemplate,
),
recovery_trace_kwargs,
)
recovery_scatter = go.Scatter(**_recovery_trace_kwargs)
fig.add_trace(recovery_scatter, **add_trace_kwargs)
active_mask = status == DrawdownStatus.Active
if active_mask.any():
if plot_markers:
# Plot active markers
active_customdata, active_hovertemplate = self_col.prepare_customdata(
incl_fields=["id"],
append_info=[
(drawdown, "Drawdown", "$title: %{customdata[$index]:,%}"),
(duration, "Drawdown duration"),
],
mask=active_mask,
)
_active_trace_kwargs = merge_dicts(
dict(
x=end_idx[active_mask],
y=end_val[active_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["orange"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["orange"])
),
),
name="Active",
customdata=active_customdata,
hovertemplate=active_hovertemplate,
),
active_trace_kwargs,
)
active_scatter = go.Scatter(**_active_trace_kwargs)
fig.add_trace(active_scatter, **add_trace_kwargs)
if plot_zones:
# Plot drawdown zones
self_col.status_recovered.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(
x0=RepFunc(lambda i: start_idx[recovered_mask][i]),
x1=RepFunc(lambda i: valley_idx[recovered_mask][i]),
fillcolor=plotting_cfg["contrast_color_schema"]["red"],
),
decline_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
# Plot recovery zones
self_col.status_recovered.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(
x0=RepFunc(lambda i: valley_idx[recovered_mask][i]),
x1=RepFunc(lambda i: end_idx[recovered_mask][i]),
fillcolor=plotting_cfg["contrast_color_schema"]["green"],
),
recovery_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
# Plot active drawdown zones
self_col.status_active.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(
x0=RepFunc(lambda i: start_idx[active_mask][i]),
x1=RepFunc(lambda i: end_idx[active_mask][i]),
fillcolor=plotting_cfg["contrast_color_schema"]["orange"],
),
active_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Drawdowns.plots`.
Merges `vectorbtpro.generic.ranges.Ranges.plots_defaults` and
`plots` from `vectorbtpro._settings.drawdowns`."""
from vectorbtpro._settings import settings
drawdowns_plots_cfg = settings["drawdowns"]["plots"]
return merge_dicts(Ranges.plots_defaults.__get__(self), drawdowns_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="Drawdowns",
check_is_not_grouped=True,
plot_func="plot",
tags="drawdowns",
)
),
)
@property
def subplots(self) -> Config:
return self._subplots
Drawdowns.override_field_config_doc(__pdoc__)
Drawdowns.override_metrics_doc(__pdoc__)
Drawdowns.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for generic data.
Defines enums and other schemas for `vectorbtpro.generic`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = [
"BarZone",
"WType",
"RangeStatus",
"InterpMode",
"RescaleMode",
"ErrorType",
"DistanceMeasure",
"OverlapMode",
"DrawdownStatus",
"range_dt",
"pattern_range_dt",
"drawdown_dt",
"RollSumAIS",
"RollSumAOS",
"RollProdAIS",
"RollProdAOS",
"RollMeanAIS",
"RollMeanAOS",
"RollStdAIS",
"RollStdAOS",
"RollZScoreAIS",
"RollZScoreAOS",
"WMMeanAIS",
"WMMeanAOS",
"EWMMeanAIS",
"EWMMeanAOS",
"EWMStdAIS",
"EWMStdAOS",
"VidyaAIS",
"VidyaAOS",
"RollCovAIS",
"RollCovAOS",
"RollCorrAIS",
"RollCorrAOS",
"RollOLSAIS",
"RollOLSAOS",
]
__pdoc__ = {}
# ############# Enums ############# #
class BarZoneT(tp.NamedTuple):
Open: int = 0
Middle: int = 1
Close: int = 2
BarZone = BarZoneT()
"""_"""
__pdoc__[
"BarZone"
] = f"""Bar zone.
```python
{prettify(BarZone)}
```
"""
class WTypeT(tp.NamedTuple):
Simple: int = 0
Weighted: int = 1
Exp: int = 2
Wilder: int = 3
Vidya: int = 4
WType = WTypeT()
"""_"""
__pdoc__[
"WType"
] = f"""Rolling window type.
```python
{prettify(WType)}
```
"""
class RangeStatusT(tp.NamedTuple):
Open: int = 0
Closed: int = 1
RangeStatus = RangeStatusT()
"""_"""
__pdoc__[
"RangeStatus"
] = f"""Range status.
```python
{prettify(RangeStatus)}
```
"""
class InterpModeT(tp.NamedTuple):
Linear: int = 0
Nearest: int = 1
Discrete: int = 2
Mixed: int = 3
InterpMode = InterpModeT()
"""_"""
__pdoc__[
"InterpMode"
] = f"""Interpolation mode.
```python
{prettify(InterpMode)}
```
Attributes:
    Linear: Linear interpolation.
For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.5, 2.0, 2.5, 3.0]`
Nearest: Nearest-neighbor interpolation.
For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.0, 2.0, 3.0, 3.0]`
Discrete: Discrete interpolation.
For example: `[1.0, 2.0, 3.0]` -> `[1.0, np.nan, 2.0, np.nan, 3.0]`
Mixed: Mixed interpolation.
For example: `[1.0, 2.0, 3.0]` -> `[1.0, 1.5, 2.0, 2.5, 3.0]`
"""
class RescaleModeT(tp.NamedTuple):
MinMax: int = 0
Rebase: int = 1
Disable: int = 2
RescaleMode = RescaleModeT()
"""_"""
__pdoc__[
"RescaleMode"
] = f"""Rescaling mode.
```python
{prettify(RescaleMode)}
```
Attributes:
MinMax: Array is rescaled from its min-max range to the min-max range of another array.
For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> `[12.0, 11.0, 10.0]`
Use this to search for patterns irrespective of their vertical scale.
Rebase: Array is rebased to the first value in another array.
        For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> approximately `[10.0, 6.67, 3.33]`
Use this to search for percentage changes.
Disable: Disable any rescaling.
For example: `[3.0, 2.0, 1.0]` to `[10, 11, 12]` -> `[3.0, 2.0, 1.0]`
Use this to search for particular numbers.
"""
class ErrorTypeT(tp.NamedTuple):
Absolute: int = 0
Relative: int = 1
ErrorType = ErrorTypeT()
"""_"""
__pdoc__[
"ErrorType"
] = f"""Error type.
```python
{prettify(ErrorType)}
```
Attributes:
Absolute: Absolute error, that is, `x1 - x0`.
Relative: Relative error, that is, `(x1 - x0) / x0`.
"""
class DistanceMeasureT(tp.NamedTuple):
MAE: int = 0
MSE: int = 1
RMSE: int = 2
DistanceMeasure = DistanceMeasureT()
"""_"""
__pdoc__[
"DistanceMeasure"
] = f"""Distance measure.
```python
{prettify(DistanceMeasure)}
```
Attributes:
MAE: Mean absolute error.
MSE: Mean squared error.
RMSE: Root mean squared error.
"""
class OverlapModeT(tp.NamedTuple):
AllowAll: int = -2
Allow: int = -1
Disallow: int = 0
OverlapMode = OverlapModeT()
"""_"""
__pdoc__[
"OverlapMode"
] = f"""Overlapping mode.
```python
{prettify(OverlapMode)}
```
Attributes:
AllowAll: Allow any overlapping ranges, even if they start at the same row.
Allow: Allow overlapping ranges, but only if they do not start at the same row.
Disallow: Disallow any overlapping ranges.
Any other positive number will check whether the intersection of each two consecutive ranges is
bigger than that number of rows, and if so, the range with the highest similarity will be selected.
"""
class DrawdownStatusT(tp.NamedTuple):
Active: int = 0
Recovered: int = 1
DrawdownStatus = DrawdownStatusT()
"""_"""
__pdoc__[
"DrawdownStatus"
] = f"""Drawdown status.
```python
{prettify(DrawdownStatus)}
```
"""
# ############# Records ############# #
range_dt = np.dtype(
[
("id", int_),
("col", int_),
("start_idx", int_),
("end_idx", int_),
("status", int_),
],
align=True,
)
"""_"""
__pdoc__[
"range_dt"
] = f"""`np.dtype` of range records.
```python
{prettify(range_dt)}
```
"""
pattern_range_dt = np.dtype(
[
("id", int_),
("col", int_),
("start_idx", int_),
("end_idx", int_),
("status", int_),
("similarity", float_),
],
align=True,
)
"""_"""
__pdoc__[
"pattern_range_dt"
] = f"""`np.dtype` of pattern range records.
```python
{prettify(pattern_range_dt)}
```
"""
drawdown_dt = np.dtype(
[
("id", int_),
("col", int_),
("start_idx", int_),
("valley_idx", int_),
("end_idx", int_),
("start_val", float_),
("valley_val", float_),
("end_val", float_),
("status", int_),
],
align=True,
)
"""_"""
__pdoc__[
"drawdown_dt"
] = f"""`np.dtype` of drawdown records.
```python
{prettify(drawdown_dt)}
```
"""
# ############# States ############# #
class RollSumAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumsum: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"RollSumAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_sum_acc_nb`."""
class RollSumAOS(tp.NamedTuple):
cumsum: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollSumAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_sum_acc_nb`."""
class RollProdAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumprod: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"RollProdAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_prod_acc_nb`."""
class RollProdAOS(tp.NamedTuple):
cumprod: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollProdAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_prod_acc_nb`."""
class RollMeanAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumsum: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"RollMeanAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_mean_acc_nb`."""
class RollMeanAOS(tp.NamedTuple):
cumsum: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollMeanAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_mean_acc_nb`."""
class RollStdAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumsum: float
cumsum_sq: float
nancnt: int
window: int
minp: tp.Optional[int]
ddof: int
__pdoc__[
"RollStdAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_std_acc_nb`."""
class RollStdAOS(tp.NamedTuple):
cumsum: float
cumsum_sq: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollStdAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_std_acc_nb`."""
class RollZScoreAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumsum: float
cumsum_sq: float
nancnt: int
window: int
minp: tp.Optional[int]
ddof: int
__pdoc__[
"RollZScoreAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_zscore_acc_nb`."""
class RollZScoreAOS(tp.NamedTuple):
cumsum: float
cumsum_sq: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollZScoreAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_zscore_acc_nb`."""
class WMMeanAIS(tp.NamedTuple):
i: int
value: float
pre_window_value: float
cumsum: float
wcumsum: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"WMMeanAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.wm_mean_acc_nb`."""
class WMMeanAOS(tp.NamedTuple):
cumsum: float
wcumsum: float
nancnt: int
window_len: int
value: float
__pdoc__[
"WMMeanAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.wm_mean_acc_nb`."""
class EWMMeanAIS(tp.NamedTuple):
i: int
value: float
old_wt: float
weighted_avg: float
nobs: int
alpha: float
minp: tp.Optional[int]
adjust: bool
__pdoc__[
"EWMMeanAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.ewm_mean_acc_nb`.
To get `alpha`, use one of the following:
* `vectorbtpro.generic.nb.rolling.alpha_from_com_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_span_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_halflife_nb`
* `vectorbtpro.generic.nb.rolling.alpha_from_wilder_nb`"""
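# Hedged note (not from the original source): these helpers presumably follow the standard
# pandas/Wilder conventions, i.e. alpha = 1 / (1 + com), alpha = 2 / (span + 1),
# alpha = 1 - exp(-ln(2) / halflife), and alpha = 1 / period for Wilder smoothing.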
class EWMMeanAOS(tp.NamedTuple):
old_wt: float
weighted_avg: float
nobs: int
value: float
__pdoc__[
"EWMMeanAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.ewm_mean_acc_nb`."""
class EWMStdAIS(tp.NamedTuple):
i: int
value: float
mean_x: float
mean_y: float
cov: float
sum_wt: float
sum_wt2: float
old_wt: float
nobs: int
alpha: float
minp: tp.Optional[int]
adjust: bool
__pdoc__[
"EWMStdAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.ewm_std_acc_nb`.
For tips on `alpha`, see `EWMMeanAIS`."""
class EWMStdAOS(tp.NamedTuple):
mean_x: float
mean_y: float
cov: float
sum_wt: float
sum_wt2: float
old_wt: float
nobs: int
value: float
__pdoc__[
"EWMStdAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.ewm_std_acc_nb`."""
class VidyaAIS(tp.NamedTuple):
i: int
prev_value: float
value: float
pre_window_prev_value: float
pre_window_value: float
pos_cumsum: float
neg_cumsum: float
prev_vidya: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"VidyaAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.vidya_acc_nb`."""
class VidyaAOS(tp.NamedTuple):
pos_cumsum: float
neg_cumsum: float
nancnt: int
window_len: int
cmo: float
vidya: float
__pdoc__[
"VidyaAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.vidya_acc_nb`."""
class RollCovAIS(tp.NamedTuple):
i: int
value1: float
value2: float
pre_window_value1: float
pre_window_value2: float
cumsum1: float
cumsum2: float
cumsum_prod: float
nancnt: int
window: int
minp: tp.Optional[int]
ddof: int
__pdoc__[
"RollCovAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_cov_acc_nb`."""
class RollCovAOS(tp.NamedTuple):
cumsum1: float
cumsum2: float
cumsum_prod: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollCovAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_cov_acc_nb`."""
class RollCorrAIS(tp.NamedTuple):
i: int
value1: float
value2: float
pre_window_value1: float
pre_window_value2: float
cumsum1: float
cumsum2: float
cumsum_sq1: float
cumsum_sq2: float
cumsum_prod: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"RollCorrAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_corr_acc_nb`."""
class RollCorrAOS(tp.NamedTuple):
cumsum1: float
cumsum2: float
cumsum_sq1: float
cumsum_sq2: float
cumsum_prod: float
nancnt: int
window_len: int
value: float
__pdoc__[
"RollCorrAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_corr_acc_nb`."""
class RollOLSAIS(tp.NamedTuple):
i: int
value1: float
value2: float
pre_window_value1: float
pre_window_value2: float
validcnt: int
cumsum1: float
cumsum2: float
cumsum_sq1: float
cumsum_prod: float
nancnt: int
window: int
minp: tp.Optional[int]
__pdoc__[
"RollOLSAIS"
] = """A named tuple representing the input state of
`vectorbtpro.generic.nb.rolling.rolling_ols_acc_nb`."""
class RollOLSAOS(tp.NamedTuple):
validcnt: int
cumsum1: float
cumsum2: float
cumsum_sq1: float
cumsum_prod: float
nancnt: int
window_len: int
slope_value: float
intercept_value: float
__pdoc__[
"RollOLSAOS"
] = """A named tuple representing the output state of
`vectorbtpro.generic.nb.rolling.rolling_ols_acc_nb`."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Mixin for building plots out of subplots."""
import inspect
import string
from collections import Counter
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexing import ParamLoc
from vectorbtpro.base.wrapping import Wrapping
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import get_dict_attr, AttrResolverMixin
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import Config, HybridConfig, merge_dicts
from vectorbtpro.utils.parsing import get_func_arg_names, get_forward_args
from vectorbtpro.utils.tagging import match_tags
from vectorbtpro.utils.template import substitute_templates, CustomTemplate
from vectorbtpro.utils.warnings_ import warn
__all__ = []
class MetaPlotsBuilderMixin(type):
"""Metaclass for `PlotsBuilderMixin`."""
@property
def subplots(cls) -> Config:
"""Subplots supported by `PlotsBuilderMixin.plots`."""
return cls._subplots
class PlotsBuilderMixin(Base, metaclass=MetaPlotsBuilderMixin):
"""Mixin that implements `PlotsBuilderMixin.plots`.
Required to be a subclass of `vectorbtpro.base.wrapping.Wrapping`."""
_writeable_attrs: tp.WriteableAttrs = {"_subplots"}
def __init__(self) -> None:
checks.assert_instance_of(self, Wrapping)
# Copy writeable attrs
self._subplots = type(self)._subplots.copy()
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `PlotsBuilderMixin.plots`."""
return dict(settings=dict(freq=self.wrapper.freq))
def resolve_plots_setting(
self,
value: tp.Optional[tp.Any],
key: str,
merge: bool = False,
) -> tp.Any:
"""Resolve a setting for `PlotsBuilderMixin.plots`."""
from vectorbtpro._settings import settings as _settings
plots_builder_cfg = _settings["plots_builder"]
if merge:
return merge_dicts(
plots_builder_cfg[key],
self.plots_defaults.get(key, {}),
value,
)
if value is not None:
return value
return self.plots_defaults.get(key, plots_builder_cfg[key])
_subplots: tp.ClassVar[Config] = HybridConfig(dict())
@property
def subplots(self) -> Config:
"""Subplots supported by `${cls_name}`.
```python
${subplots}
```
Returns `${cls_name}._subplots`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change subplots, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._subplots`."""
return self._subplots
def plots(
self,
subplots: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[str, tp.Kwargs]]]] = None,
tags: tp.Optional[tp.MaybeIterable[str]] = None,
column: tp.Optional[tp.Label] = None,
group_by: tp.GroupByLike = None,
per_column: tp.Optional[bool] = None,
split_columns: tp.Optional[bool] = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
settings: tp.KwargsLike = None,
filters: tp.KwargsLike = None,
subplot_settings: tp.KwargsLike = None,
show_titles: bool = None,
show_legend: tp.Optional[bool] = None,
show_column_label: tp.Optional[bool] = None,
hide_id_labels: bool = None,
group_id_labels: bool = None,
make_subplots_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.Optional[tp.BaseFigure]:
"""Plot various parts of this object.
Args:
subplots (str, tuple, iterable, or dict): Subplots to plot.
Each element can be either:
* Subplot name (see keys in `PlotsBuilderMixin.subplots`)
* Tuple of a subplot name and a settings dict as in `PlotsBuilderMixin.subplots`
* Tuple of a subplot name and a template of instance `vectorbtpro.utils.template.CustomTemplate`
* Tuple of a subplot name and a list of settings dicts to be expanded into multiple subplots
The settings dict can contain the following keys:
* `title`: Title of the subplot. Defaults to the name.
* `plot_func` (required): Plotting function for custom subplots.
Must write the supplied figure `fig` in-place and can return anything (it won't be used).
* `xaxis_kwargs`: Layout keyword arguments for the x-axis. Defaults to `dict(title='Index')`.
* `yaxis_kwargs`: Layout keyword arguments for the y-axis. Defaults to empty dict.
* `tags`, `check_{filter}`, `inv_check_{filter}`, `resolve_plot_func`, `pass_{arg}`,
`resolve_path_{arg}`, `resolve_{arg}` and `template_context`:
The same as in `vectorbtpro.generic.stats_builder.StatsBuilderMixin` for `calc_func`.
* Any other keyword argument that overrides the settings or is passed directly to `plot_func`.
If `resolve_plot_func` is True, the plotting function may "request" any of the
following arguments by accepting them or if `pass_{arg}` was found in the settings dict:
* Each of `vectorbtpro.utils.attr_.AttrResolverMixin.self_aliases`: original object
(ungrouped, with no column selected)
* `group_by`: won't be passed if it was used in resolving the first attribute of `plot_func`
specified as a path, use `pass_group_by=True` to pass anyway
* `column`
* `subplot_name`
* `trace_names`: list with the subplot name, can't be used in templates
* `add_trace_kwargs`: dict with subplot row and column index
* `xref`
* `yref`
* `xaxis`
* `yaxis`
* `x_domain`
* `y_domain`
* `fig`
* `silence_warnings`
* Any argument from `settings`
* Any attribute of this object if it meant to be resolved
(see `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`)
!!! note
Layout-related resolution arguments such as `add_trace_kwargs` are unavailable
before filtering and thus cannot be used in any templates but can still be overridden.
Pass `subplots='all'` to plot all supported subplots.
tags (str or iterable): See `tags` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
column (str): See `column` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
group_by (any): See `group_by` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
per_column (bool): See `per_column` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
split_columns (bool): See `split_columns` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
silence_warnings (bool): See `silence_warnings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
template_context (mapping): See `template_context` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
Applied on `settings`, `make_subplots_kwargs`, and `layout_kwargs`, and then on each subplot settings.
filters (dict): See `filters` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
settings (dict): See `settings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
subplot_settings (dict): See `metric_settings` in `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
show_titles (bool): Whether to show the title of each subplot.
show_legend (bool): Whether to show legend.
If None and plotting per column, becomes False, otherwise True.
show_column_label (bool): Whether to show the column label next to each legend label.
If None and plotting per column, becomes True, otherwise False.
hide_id_labels (bool): Whether to hide identical legend labels.
Two labels are identical if their name, marker style and line style match.
group_id_labels (bool): Whether to group identical legend labels.
make_subplots_kwargs (dict): Keyword arguments passed to `plotly.subplots.make_subplots`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments used to update the layout of the figure.
!!! note
`PlotsBuilderMixin` and `vectorbtpro.generic.stats_builder.StatsBuilderMixin` are very similar.
Some artifacts follow the same concept, just named differently:
* `plots_defaults` vs `stats_defaults`
* `subplots` vs `metrics`
* `subplot_settings` vs `metric_settings`
See further notes under `vectorbtpro.generic.stats_builder.StatsBuilderMixin`.
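        Usage:
            A minimal, illustrative call (assuming `dd` is a `vectorbtpro.generic.drawdowns.Drawdowns`
            instance, whose default subplot is named "plot"):
            ```pycon
            >>> dd.plots().show()
            >>> dd.plots(subplots=[("plot", dict(title="Top Drawdowns", top_n=3))]).show()
            ```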
"""
# Plot per column
if column is None:
if per_column is None:
per_column = self.resolve_plots_setting(per_column, "per_column")
if per_column:
columns = self.get_item_keys(group_by=group_by)
if len(columns) > 1:
if split_columns is None:
split_columns = self.resolve_plots_setting(split_columns, "split_columns")
if show_legend is None:
show_legend = self.resolve_plots_setting(show_legend, "show_legend")
if show_legend is None:
show_legend = False
if show_column_label is None:
show_column_label = self.resolve_plots_setting(show_column_label, "show_column_label")
if show_column_label is None:
show_column_label = True
fig = None
if split_columns:
for _, column_self in self.items(group_by=group_by, wrap=True):
_args, _kwargs = get_forward_args(column_self.plots, locals())
fig = column_self.plots(*_args, **_kwargs)
else:
for column in columns:
_args, _kwargs = get_forward_args(self.plots, locals())
fig = self.plots(*_args, **_kwargs)
return fig
from vectorbtpro.utils.figure import make_subplots, get_domain
from vectorbtpro._settings import settings as _settings
plotting_cfg = _settings["plotting"]
# Resolve defaults
silence_warnings = self.resolve_plots_setting(silence_warnings, "silence_warnings")
show_titles = self.resolve_plots_setting(show_titles, "show_titles")
show_legend = self.resolve_plots_setting(show_legend, "show_legend")
if show_legend is None:
show_legend = True
show_column_label = self.resolve_plots_setting(show_column_label, "show_column_label")
if show_column_label is None:
show_column_label = False
hide_id_labels = self.resolve_plots_setting(hide_id_labels, "hide_id_labels")
group_id_labels = self.resolve_plots_setting(group_id_labels, "group_id_labels")
template_context = self.resolve_plots_setting(template_context, "template_context", merge=True)
filters = self.resolve_plots_setting(filters, "filters", merge=True)
settings = self.resolve_plots_setting(settings, "settings", merge=True)
subplot_settings = self.resolve_plots_setting(subplot_settings, "subplot_settings", merge=True)
make_subplots_kwargs = self.resolve_plots_setting(make_subplots_kwargs, "make_subplots_kwargs", merge=True)
layout_kwargs = self.resolve_plots_setting(layout_kwargs, "layout_kwargs", merge=True)
# Replace templates globally (not used at subplot level)
if len(template_context) > 0:
sub_settings = substitute_templates(
settings,
context=template_context,
eval_id="sub_settings",
strict=False,
)
sub_make_subplots_kwargs = substitute_templates(
make_subplots_kwargs,
context=template_context,
eval_id="sub_make_subplots_kwargs",
)
sub_layout_kwargs = substitute_templates(
layout_kwargs,
context=template_context,
eval_id="sub_layout_kwargs",
)
else:
sub_settings = settings
sub_make_subplots_kwargs = make_subplots_kwargs
sub_layout_kwargs = layout_kwargs
# Resolve self
reself = self.resolve_self(
cond_kwargs=sub_settings,
impacts_caching=False,
silence_warnings=silence_warnings,
)
# Prepare subplots
if subplots is None:
subplots = reself.resolve_plots_setting(subplots, "subplots")
if subplots == "all":
subplots = reself.subplots
if isinstance(subplots, dict):
subplots = list(subplots.items())
if isinstance(subplots, (str, tuple)):
subplots = [subplots]
# Prepare tags
if tags is None:
tags = reself.resolve_plots_setting(tags, "tags")
if isinstance(tags, str) and tags == "all":
tags = None
if isinstance(tags, (str, tuple)):
tags = [tags]
# Bring to the same shape
new_subplots = []
for i, subplot in enumerate(subplots):
if isinstance(subplot, str):
subplot = (subplot, reself.subplots[subplot])
if not isinstance(subplot, tuple):
raise TypeError(f"Subplot at index {i} must be either a string or a tuple")
new_subplots.append(subplot)
subplots = new_subplots
# Expand subplots
new_subplots = []
for i, (subplot_name, _subplot_settings) in enumerate(subplots):
if isinstance(_subplot_settings, CustomTemplate):
subplot_context = merge_dicts(
template_context,
{name: reself for name in reself.self_aliases},
dict(
column=column,
group_by=group_by,
subplot_name=subplot_name,
silence_warnings=silence_warnings,
),
settings,
)
subplot_context = substitute_templates(
subplot_context,
context=subplot_context,
eval_id="subplot_context",
)
_subplot_settings = _subplot_settings.substitute(
context=subplot_context,
strict=True,
eval_id="subplot",
)
if isinstance(_subplot_settings, list):
for __subplot_settings in _subplot_settings:
new_subplots.append((subplot_name, __subplot_settings))
else:
new_subplots.append((subplot_name, _subplot_settings))
subplots = new_subplots
# Handle duplicate names
subplot_counts = Counter(list(map(lambda x: x[0], subplots)))
subplot_i = {k: -1 for k in subplot_counts.keys()}
subplots_dct = {}
for i, (subplot_name, _subplot_settings) in enumerate(subplots):
if subplot_counts[subplot_name] > 1:
subplot_i[subplot_name] += 1
subplot_name = subplot_name + "_" + str(subplot_i[subplot_name])
subplots_dct[subplot_name] = _subplot_settings
# Check subplot_settings
missed_keys = set(subplot_settings.keys()).difference(set(subplots_dct.keys()))
if len(missed_keys) > 0:
raise ValueError(f"Keys {missed_keys} in subplot_settings could not be matched with any subplot")
# Merge settings
opt_arg_names_dct = {}
custom_arg_names_dct = {}
resolved_self_dct = {}
context_dct = {}
for subplot_name, _subplot_settings in list(subplots_dct.items()):
opt_settings = merge_dicts(
{name: reself for name in reself.self_aliases},
dict(
column=column,
group_by=group_by,
subplot_name=subplot_name,
trace_names=[subplot_name],
silence_warnings=silence_warnings,
),
settings,
)
_subplot_settings = _subplot_settings.copy()
passed_subplot_settings = subplot_settings.get(subplot_name, {})
merged_settings = merge_dicts(opt_settings, _subplot_settings, passed_subplot_settings)
subplot_template_context = merged_settings.pop("template_context", {})
template_context_merged = merge_dicts(template_context, subplot_template_context)
template_context_merged = substitute_templates(
template_context_merged,
context=merged_settings,
eval_id="template_context_merged",
)
context = merge_dicts(template_context_merged, merged_settings)
# safe because we will use substitute_templates again once layout params are known
merged_settings = substitute_templates(
merged_settings,
context=context,
eval_id="merged_settings",
)
# Filter by tag
if tags is not None:
in_tags = merged_settings.get("tags", None)
if in_tags is None or not match_tags(tags, in_tags):
subplots_dct.pop(subplot_name, None)
continue
custom_arg_names = set(_subplot_settings.keys()).union(set(passed_subplot_settings.keys()))
opt_arg_names = set(opt_settings.keys())
custom_reself = reself.resolve_self(
cond_kwargs=merged_settings,
custom_arg_names=custom_arg_names,
impacts_caching=True,
silence_warnings=merged_settings["silence_warnings"],
)
subplots_dct[subplot_name] = merged_settings
custom_arg_names_dct[subplot_name] = custom_arg_names
opt_arg_names_dct[subplot_name] = opt_arg_names
resolved_self_dct[subplot_name] = custom_reself
context_dct[subplot_name] = context
# Filter subplots
for subplot_name, _subplot_settings in list(subplots_dct.items()):
custom_reself = resolved_self_dct[subplot_name]
context = context_dct[subplot_name]
_silence_warnings = _subplot_settings.get("silence_warnings")
subplot_filters = set()
for k in _subplot_settings.keys():
filter_name = None
if k.startswith("check_"):
filter_name = k[len("check_") :]
elif k.startswith("inv_check_"):
filter_name = k[len("inv_check_") :]
if filter_name is not None:
if filter_name not in filters:
raise ValueError(f"Metric '{subplot_name}' requires filter '{filter_name}'")
subplot_filters.add(filter_name)
for filter_name in subplot_filters:
filter_settings = filters[filter_name]
_filter_settings = substitute_templates(
filter_settings,
context=context,
eval_id="filter_settings",
)
filter_func = _filter_settings["filter_func"]
warning_message = _filter_settings.get("warning_message", None)
inv_warning_message = _filter_settings.get("inv_warning_message", None)
to_check = _subplot_settings.get("check_" + filter_name, False)
inv_to_check = _subplot_settings.get("inv_check_" + filter_name, False)
if to_check or inv_to_check:
whether_true = filter_func(custom_reself, _subplot_settings)
to_remove = (to_check and not whether_true) or (inv_to_check and whether_true)
if to_remove:
if to_check and warning_message is not None and not _silence_warnings:
warn(warning_message)
if inv_to_check and inv_warning_message is not None and not _silence_warnings:
warn(inv_warning_message)
subplots_dct.pop(subplot_name, None)
custom_arg_names_dct.pop(subplot_name, None)
opt_arg_names_dct.pop(subplot_name, None)
resolved_self_dct.pop(subplot_name, None)
context_dct.pop(subplot_name, None)
break
# Any subplots left?
if len(subplots_dct) == 0:
if not silence_warnings:
warn("No subplots to plot")
return None
# Set up figure
rows = sub_make_subplots_kwargs.pop("rows", len(subplots_dct))
cols = sub_make_subplots_kwargs.pop("cols", 1)
specs = sub_make_subplots_kwargs.pop(
"specs",
[[{} for _ in range(cols)] for _ in range(rows)],
)
row_col_tuples = []
for row, row_spec in enumerate(specs):
for col, col_spec in enumerate(row_spec):
if col_spec is not None:
row_col_tuples.append((row + 1, col + 1))
shared_xaxes = sub_make_subplots_kwargs.pop("shared_xaxes", True)
shared_yaxes = sub_make_subplots_kwargs.pop("shared_yaxes", False)
default_height = plotting_cfg["layout"]["height"]
default_width = plotting_cfg["layout"]["width"] + 50
min_space = 10 # space between subplots with no axis sharing
max_title_spacing = 30
max_xaxis_spacing = 50
max_yaxis_spacing = 100
legend_height = 50
if show_titles:
title_spacing = max_title_spacing
else:
title_spacing = 0
if not shared_xaxes and rows > 1:
xaxis_spacing = max_xaxis_spacing
else:
xaxis_spacing = 0
if not shared_yaxes and cols > 1:
yaxis_spacing = max_yaxis_spacing
else:
yaxis_spacing = 0
if "height" in sub_layout_kwargs:
height = sub_layout_kwargs.pop("height")
else:
height = default_height + title_spacing
if rows > 1:
height *= rows
height += min_space * rows - min_space
height += legend_height - legend_height * rows
if shared_xaxes:
height += max_xaxis_spacing - max_xaxis_spacing * rows
if "width" in sub_layout_kwargs:
width = sub_layout_kwargs.pop("width")
else:
width = default_width
if cols > 1:
width *= cols
width += min_space * cols - min_space
if shared_yaxes:
width += max_yaxis_spacing - max_yaxis_spacing * cols
if height is not None:
if "vertical_spacing" in sub_make_subplots_kwargs:
vertical_spacing = sub_make_subplots_kwargs.pop("vertical_spacing")
else:
vertical_spacing = min_space + title_spacing + xaxis_spacing
if vertical_spacing is not None and vertical_spacing > 1:
vertical_spacing /= height
legend_y = 1 + (min_space + title_spacing) / height
else:
vertical_spacing = sub_make_subplots_kwargs.pop("vertical_spacing", None)
legend_y = 1.02
if width is not None:
if "horizontal_spacing" in sub_make_subplots_kwargs:
horizontal_spacing = sub_make_subplots_kwargs.pop("horizontal_spacing")
else:
horizontal_spacing = min_space + yaxis_spacing
if horizontal_spacing is not None and horizontal_spacing > 1:
horizontal_spacing /= width
else:
horizontal_spacing = sub_make_subplots_kwargs.pop("horizontal_spacing", None)
if show_titles:
_subplot_titles = []
for i in range(len(subplots_dct)):
_subplot_titles.append("$title_" + str(i))
else:
_subplot_titles = None
if fig is None:
fig = make_subplots(
rows=rows,
cols=cols,
specs=specs,
shared_xaxes=shared_xaxes,
shared_yaxes=shared_yaxes,
subplot_titles=_subplot_titles,
vertical_spacing=vertical_spacing,
horizontal_spacing=horizontal_spacing,
**sub_make_subplots_kwargs,
)
sub_layout_kwargs = merge_dicts(
dict(
showlegend=True,
width=width,
height=height,
legend=dict(
orientation="h",
yanchor="bottom",
y=legend_y,
xanchor="right",
x=1,
traceorder="normal",
),
),
sub_layout_kwargs,
)
trace_start_idx = 0
else:
trace_start_idx = len(fig.data)
fig.update_layout(**sub_layout_kwargs)
# Plot subplots
arg_cache_dct = {}
for i, (subplot_name, _subplot_settings) in enumerate(subplots_dct.items()):
try:
final_kwargs = _subplot_settings.copy()
opt_arg_names = opt_arg_names_dct[subplot_name]
custom_arg_names = custom_arg_names_dct[subplot_name]
custom_reself = resolved_self_dct[subplot_name]
context = context_dct[subplot_name]
# Compute figure artifacts
row, col = row_col_tuples[i]
xref = "x" if i == 0 else "x" + str(i + 1)
yref = "y" if i == 0 else "y" + str(i + 1)
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
x_domain = get_domain(xref, fig)
y_domain = get_domain(yref, fig)
subplot_layout_kwargs = dict(
add_trace_kwargs=dict(row=row, col=col),
xref=xref,
yref=yref,
xaxis=xaxis,
yaxis=yaxis,
x_domain=x_domain,
y_domain=y_domain,
fig=fig,
pass_fig=True, # force passing fig
)
for k in subplot_layout_kwargs:
opt_arg_names.add(k)
if k in final_kwargs:
custom_arg_names.add(k)
final_kwargs = merge_dicts(subplot_layout_kwargs, final_kwargs)
context = merge_dicts(subplot_layout_kwargs, context)
final_kwargs = substitute_templates(final_kwargs, context=context, eval_id="final_kwargs")
# Clean up keys
for k, v in list(final_kwargs.items()):
if k.startswith("check_") or k.startswith("inv_check_") or k in ("tags",):
final_kwargs.pop(k, None)
# Get subplot-specific values
_column = final_kwargs.get("column")
_group_by = final_kwargs.get("group_by")
_silence_warnings = final_kwargs.get("silence_warnings")
title = final_kwargs.pop("title", subplot_name)
plot_func = final_kwargs.pop("plot_func", None)
xaxis_kwargs = final_kwargs.pop("xaxis_kwargs", None)
yaxis_kwargs = final_kwargs.pop("yaxis_kwargs", None)
resolve_plot_func = final_kwargs.pop("resolve_plot_func", True)
use_shortcuts = final_kwargs.pop("use_shortcuts", True)
use_caching = final_kwargs.pop("use_caching", True)
if plot_func is not None:
# Resolve plot_func
if resolve_plot_func:
if not callable(plot_func):
passed_kwargs_out = {}
def _getattr_func(
obj: tp.Any,
attr: str,
args: tp.ArgsLike = None,
kwargs: tp.KwargsLike = None,
call_attr: bool = True,
_final_kwargs: tp.Kwargs = final_kwargs,
_opt_arg_names: tp.Set[str] = opt_arg_names,
_custom_arg_names: tp.Set[str] = custom_arg_names,
_arg_cache_dct: tp.Kwargs = arg_cache_dct,
_use_shortcuts: bool = use_shortcuts,
_use_caching: bool = use_caching,
) -> tp.Any:
if attr in _final_kwargs:
return _final_kwargs[attr]
if args is None:
args = ()
if kwargs is None:
kwargs = {}
if obj is custom_reself:
resolve_path_arg = _final_kwargs.pop(
"resolve_path_" + attr,
True,
)
if resolve_path_arg:
if call_attr:
cond_kwargs = {
k: v for k, v in _final_kwargs.items() if k in _opt_arg_names
}
out = custom_reself.resolve_attr(
attr, # do not pass _attr, important for caching
args=args,
cond_kwargs=cond_kwargs,
kwargs=kwargs,
custom_arg_names=_custom_arg_names,
cache_dct=_arg_cache_dct,
use_caching=_use_caching,
passed_kwargs_out=passed_kwargs_out,
use_shortcuts=_use_shortcuts,
)
else:
if isinstance(obj, AttrResolverMixin):
cls_dir = obj.cls_dir
else:
cls_dir = dir(type(obj))
if "get_" + attr in cls_dir:
_attr = "get_" + attr
else:
_attr = attr
out = getattr(obj, _attr)
_select_col_arg = _final_kwargs.pop(
"select_col_" + attr,
False,
)
if _select_col_arg and _column is not None:
out = custom_reself.select_col_from_obj(
out,
_column,
wrapper=custom_reself.wrapper.regroup(_group_by),
)
passed_kwargs_out["group_by"] = _group_by
passed_kwargs_out["column"] = _column
return out
out = getattr(obj, attr)
if callable(out) and call_attr:
return out(*args, **kwargs)
return out
plot_func = custom_reself.deep_getattr(
plot_func,
getattr_func=_getattr_func,
call_last_attr=False,
)
if "group_by" in passed_kwargs_out:
if "pass_group_by" not in final_kwargs:
final_kwargs.pop("group_by", None)
if "column" in passed_kwargs_out:
if "pass_column" not in final_kwargs:
final_kwargs.pop("column", None)
if not callable(plot_func):
raise TypeError("plot_func must be callable")
# Resolve arguments
func_arg_names = get_func_arg_names(plot_func)
for k in func_arg_names:
if k not in final_kwargs:
resolve_arg = final_kwargs.pop("resolve_" + k, False)
use_shortcuts_arg = final_kwargs.pop("use_shortcuts_" + k, True)
select_col_arg = final_kwargs.pop("select_col_" + k, False)
if resolve_arg:
try:
arg_out = custom_reself.resolve_attr(
k,
cond_kwargs=final_kwargs,
custom_arg_names=custom_arg_names,
cache_dct=arg_cache_dct,
use_caching=use_caching,
use_shortcuts=use_shortcuts_arg,
)
except AttributeError:
continue
if select_col_arg and _column is not None:
arg_out = custom_reself.select_col_from_obj(
arg_out,
_column,
wrapper=custom_reself.wrapper.regroup(_group_by),
)
final_kwargs[k] = arg_out
for k in list(final_kwargs.keys()):
if k in opt_arg_names:
if "pass_" + k in final_kwargs:
if not final_kwargs.get("pass_" + k): # first priority
final_kwargs.pop(k, None)
elif k not in func_arg_names: # second priority
final_kwargs.pop(k, None)
for k in list(final_kwargs.keys()):
if k.startswith("pass_") or k.startswith("resolve_"):
final_kwargs.pop(k, None) # cleanup
# Call plot_func
plot_func(**final_kwargs)
else:
# Do not resolve plot_func
plot_func(custom_reself, _subplot_settings)
# Update global layout
for annotation in fig.layout.annotations:
if "text" in annotation and annotation["text"] == "$title_" + str(i):
annotation.update(text=title)
subplot_layout = dict()
subplot_layout[xaxis] = merge_dicts(dict(title="Index"), xaxis_kwargs)
subplot_layout[yaxis] = merge_dicts(dict(), yaxis_kwargs)
fig.update_layout(**subplot_layout)
except Exception as e:
warn(f"Subplot '{subplot_name}' raised an exception")
raise e
# Hide legend labels
if not show_legend:
for i in range(trace_start_idx, len(fig.data)):
fig.data[i].update(showlegend=False)
# Show column label
if show_column_label:
if column is not None:
_column = column
else:
_column = reself.wrapper.get_columns(group_by=group_by)[0]
for i in range(trace_start_idx, len(fig.data)):
trace = fig.data[i]
if trace["name"] is not None:
trace.update(name=trace["name"] + f" [{ParamLoc.encode_key(_column)}]")
# Remove duplicate legend labels
found_ids = dict()
unique_idx = trace_start_idx
for i in range(trace_start_idx, len(fig.data)):
trace = fig.data[i]
if trace["showlegend"] is not False and trace["legendgroup"] is None:
if "name" in trace:
name = trace["name"]
else:
name = None
if "marker" in trace:
marker = trace["marker"]
else:
marker = {}
if "symbol" in marker:
marker_symbol = marker["symbol"]
else:
marker_symbol = None
if "color" in marker:
marker_color = marker["color"]
else:
marker_color = None
if "line" in trace:
line = trace["line"]
else:
line = {}
if "dash" in line:
line_dash = line["dash"]
else:
line_dash = None
if "color" in line:
line_color = line["color"]
else:
line_color = None
id = (name, marker_symbol, marker_color, line_dash, line_color)
if id in found_ids:
if hide_id_labels:
trace.update(showlegend=False)
if group_id_labels:
trace.update(legendgroup=found_ids[id])
else:
if group_id_labels:
trace.update(legendgroup=unique_idx)
found_ids[id] = unique_idx
unique_idx += 1
# Hide identical legend labels
if hide_id_labels:
legendgroups = set()
for i in range(trace_start_idx, len(fig.data)):
trace = fig.data[i]
if trace["legendgroup"] is not None:
if trace["showlegend"]:
if trace["legendgroup"] in legendgroups:
trace.update(showlegend=False)
else:
legendgroups.add(trace["legendgroup"])
# Remove all except the last title if sharing the same axis
if shared_xaxes:
i = 0
for row in range(rows):
for col in range(cols):
if specs[row][col] is not None:
xaxis = "xaxis" if i == 0 else "xaxis" + str(i + 1)
if row < rows - 1:
fig.layout[xaxis].update(title=None)
i += 1
if shared_yaxes:
i = 0
for row in range(rows):
for col in range(cols):
if specs[row][col] is not None:
yaxis = "yaxis" if i == 0 else "yaxis" + str(i + 1)
if col > 0:
fig.layout[yaxis].update(title=None)
i += 1
# Return the figure
return fig
# ############# Docs ############# #
@classmethod
def build_subplots_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build subplots documentation."""
if source_cls is None:
source_cls = PlotsBuilderMixin
return string.Template(
inspect.cleandoc(get_dict_attr(source_cls, "subplots").__doc__),
).substitute(
{"subplots": cls.subplots.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_subplots_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `PlotsBuilderMixin.subplots`."""
__pdoc__[cls.__name__ + ".subplots"] = cls.build_subplots_doc(source_cls=source_cls)
__pdoc__ = dict()
PlotsBuilderMixin.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base plotting functions.
Provides functions for visualizing data in an efficient and convenient way.
Each creates a figure widget that is compatible with ipywidgets and enables interactive
data visualization in Jupyter Notebook and JupyterLab environments. For more details
on using Plotly, see [Getting Started with Plotly in Python](https://plotly.com/python/getting-started/).
!!! warning
Errors related to plotting in Jupyter environment usually appear in the logs, not under the cell."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import math
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.basedatatypes import BaseTraceType
from vectorbtpro import _typing as tp
from vectorbtpro.base import reshaping
from vectorbtpro.utils import checks
from vectorbtpro.utils.array_ import rescale
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.colors import map_value_to_cmap
from vectorbtpro.utils.config import Configured, resolve_dict, merge_dicts
from vectorbtpro.utils.figure import make_figure
__all__ = [
"TraceUpdater",
"Gauge",
"Bar",
"Scatter",
"Histogram",
"Box",
"Heatmap",
"Volume",
]
def clean_labels(labels: tp.Labels) -> tp.Labels:
"""Clean labels.
Plotly doesn't support multi-indexes."""
if isinstance(labels, pd.MultiIndex):
labels = labels.to_flat_index()
if isinstance(labels, pd.PeriodIndex):
labels = labels.map(str)
if len(labels) > 0 and isinstance(labels[0], tuple):
labels = list(map(str, labels))
return labels
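# Illustrative example of the conversion above (plain pandas, hypothetical input):
#
#     clean_labels(pd.MultiIndex.from_tuples([("a", 1), ("b", 2)]))
#     # -> ["('a', 1)", "('b', 2)"]  (tuples flattened into strings that Plotly can display)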
class TraceType(Configured):
"""Class representing a trace type."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
class TraceUpdater(Base):
"""Class for updating traces."""
def __init__(self, fig: tp.BaseFigure, traces: tp.Tuple[BaseTraceType, ...]) -> None:
self._fig = fig
self._traces = traces
@property
def fig(self) -> tp.BaseFigure:
"""Figure."""
return self._fig
@property
def traces(self) -> tp.Tuple[BaseTraceType, ...]:
"""Traces to update."""
return self._traces
@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, *args, **kwargs) -> None:
"""Update one trace."""
raise NotImplementedError
def update(self, *args, **kwargs) -> None:
"""Update all traces using new data."""
raise NotImplementedError
class Gauge(TraceType, TraceUpdater):
"""Gauge plot.
Args:
value (float): The value to be displayed.
label (str): The label to be displayed.
value_range (tuple of float): The value range of the gauge.
cmap_name (str): A matplotlib-compatible colormap name.
See the [list of available colormaps](https://matplotlib.org/tutorials/colors/colormaps.html).
trace_kwargs (dict): Keyword arguments passed to the `plotly.graph_objects.Indicator`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> from vectorbtpro import *
>>> gauge = vbt.Gauge(
... value=2,
... value_range=(1, 3),
... label='My Gauge'
... )
>>> gauge.fig.show()
```
"""
def __init__(
self,
value: tp.Optional[float] = None,
label: tp.Optional[str] = None,
value_range: tp.Optional[tp.Tuple[float, float]] = None,
cmap_name: str = "Spectral",
trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
value=value,
label=label,
value_range=value_range,
cmap_name=cmap_name,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
from vectorbtpro._settings import settings
layout_cfg = settings["plotting"]["layout"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
if "width" in layout_cfg:
# Calculate nice width and height
fig.update_layout(width=layout_cfg["width"] * 0.7, height=layout_cfg["width"] * 0.5, margin=dict(t=80))
fig.update_layout(**layout_kwargs)
_trace_kwargs = merge_dicts(
dict(
domain=dict(x=[0, 1], y=[0, 1]),
mode="gauge+number+delta",
title=dict(text=label),
),
trace_kwargs,
)
trace = go.Indicator(**_trace_kwargs)
if value is not None:
self.update_trace(trace, value, value_range=value_range, cmap_name=cmap_name)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, (fig.data[-1],))
self._value_range = value_range
self._cmap_name = cmap_name
@property
def value_range(self) -> tp.Tuple[float, float]:
"""The value range of the gauge."""
return self._value_range
@property
def cmap_name(self) -> str:
"""A matplotlib-compatible colormap name."""
return self._cmap_name
@classmethod
def update_trace(
cls,
trace: BaseTraceType,
value: float,
value_range: tp.Optional[tp.Tuple[float, float]] = None,
cmap_name: str = "Spectral",
) -> None:
        if value_range is not None:
            trace.gauge.axis.range = value_range
            if cmap_name is not None:
                # Guard: the colormap needs defined bounds to map the value to a color
                trace.gauge.bar.color = map_value_to_cmap(value, cmap_name, vmin=value_range[0], vmax=value_range[1])
trace.delta.reference = trace.value
trace.value = value
def update(self, value: float) -> None:
if self.value_range is None:
self._value_range = value, value
else:
self._value_range = min(self.value_range[0], value), max(self.value_range[1], value)
with self.fig.batch_update():
self.update_trace(
self.traces[0],
value=value,
value_range=self.value_range,
cmap_name=self.cmap_name,
)
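# Illustrative streaming usage of the class above (hypothetical values; assumes an interactive
# Plotly environment so that in-place trace updates are rendered):
#
#     gauge = Gauge(value=1.0, value_range=(0.0, 3.0), label="My Gauge")
#     gauge.update(2.5)  # moves the needle in-place and widens value_range if needed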
class Bar(TraceType, TraceUpdater):
"""Bar plot.
Args:
data (array_like): Data in any format that can be converted to NumPy.
Must be of shape (`x_labels`, `trace_names`).
trace_names (str or list of str): Trace names, corresponding to columns in pandas.
x_labels (array_like): X-axis labels, corresponding to index in pandas.
trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
Can be specified per trace as a sequence of dicts.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> from vectorbtpro import *
>>> bar = vbt.Bar(
... data=[[1, 2], [3, 4]],
... trace_names=['a', 'b'],
... x_labels=['x', 'y']
... )
>>> bar.fig.show()
```
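Traces created this way can also be refreshed with new data of the same shape; a minimal sketch:
```pycon
>>> bar.update([[2, 1], [4, 3]])
>>> bar.fig.show()
```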
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
trace_names: tp.TraceNames = None,
x_labels: tp.Optional[tp.Labels] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
trace_names=trace_names,
x_labels=x_labels,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
data = reshaping.to_2d_array(data)
if trace_names is not None:
checks.assert_shape_equal(data, trace_names, (1, 0))
else:
if trace_names is None:
raise ValueError("At least data or trace_names must be passed")
if trace_names is None:
trace_names = [None] * data.shape[1]
if isinstance(trace_names, str):
trace_names = [trace_names]
if x_labels is not None:
x_labels = clean_labels(x_labels)
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
fig.update_layout(**layout_kwargs)
for i, trace_name in enumerate(trace_names):
_trace_kwargs = resolve_dict(trace_kwargs, i=i)
trace_name = _trace_kwargs.pop("name", trace_name)
if trace_name is not None:
trace_name = str(trace_name)
_trace_kwargs = merge_dicts(
dict(x=x_labels, name=trace_name, showlegend=trace_name is not None),
_trace_kwargs,
)
trace = go.Bar(**_trace_kwargs)
if data is not None:
self.update_trace(trace, data, i)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, fig.data[-len(trace_names) :])
@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, i: int) -> None:
data = reshaping.to_2d_array(data)
trace.y = data[:, i]
if trace.marker.colorscale is not None:
trace.marker.color = data[:, i]
def update(self, data: tp.ArrayLike) -> None:
data = reshaping.to_2d_array(data)
with self.fig.batch_update():
for i, trace in enumerate(self.traces):
self.update_trace(trace, data, i)
class Scatter(TraceType, TraceUpdater):
"""Scatter plot.
Args:
data (array_like): Data in any format that can be converted to NumPy.
Must be of shape (`x_labels`, `trace_names`).
trace_names (str or list of str): Trace names, corresponding to columns in pandas.
x_labels (array_like): X-axis labels, corresponding to index in pandas.
trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
Can be specified per trace as a sequence of dicts.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
use_gl (bool): Whether to use `plotly.graph_objects.Scattergl`.
Defaults to the global setting. If the global setting is also None, becomes True
if the data contains at least 10,000 points.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> from vectorbtpro import *
>>> scatter = vbt.Scatter(
... data=[[1, 2], [3, 4]],
... trace_names=['a', 'b'],
... x_labels=['x', 'y']
... )
>>> scatter.fig.show()
```
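For larger arrays, WebGL rendering can be requested explicitly via `use_gl`; a minimal sketch reusing the data from above:
```pycon
>>> scatter = vbt.Scatter(
... data=[[1, 2], [3, 4]],
... trace_names=['a', 'b'],
... x_labels=['x', 'y'],
... use_gl=True
... )
>>> scatter.fig.show()
```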
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
trace_names: tp.TraceNames = None,
x_labels: tp.Optional[tp.Labels] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
use_gl: tp.Optional[bool] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
trace_names=trace_names,
x_labels=x_labels,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
data = reshaping.to_2d_array(data)
if trace_names is not None:
checks.assert_shape_equal(data, trace_names, (1, 0))
else:
if trace_names is None:
raise ValueError("At least data or trace_names must be passed")
if trace_names is None:
trace_names = [None] * data.shape[1]
if isinstance(trace_names, str):
trace_names = [trace_names]
if x_labels is not None:
x_labels = clean_labels(x_labels)
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
fig.update_layout(**layout_kwargs)
for i, trace_name in enumerate(trace_names):
_trace_kwargs = resolve_dict(trace_kwargs, i=i)
_use_gl = _trace_kwargs.pop("use_gl", use_gl)
if _use_gl is None:
_use_gl = plotting_cfg["use_gl"]
if _use_gl is None:
_use_gl = _use_gl is None and data is not None and data.size >= 10000
trace_name = _trace_kwargs.pop("name", trace_name)
if trace_name is not None:
trace_name = str(trace_name)
if _use_gl:
scatter_obj = go.Scattergl
else:
scatter_obj = go.Scatter
try:
from plotly_resampler.aggregation import AbstractFigureAggregator
if isinstance(fig, AbstractFigureAggregator):
use_resampler = True
else:
use_resampler = False
except ImportError:
use_resampler = False
if use_resampler:
if data is None:
raise ValueError("Cannot create empty scatter traces when using plotly-resampler")
_trace_kwargs = merge_dicts(
dict(name=trace_name, showlegend=trace_name is not None),
_trace_kwargs,
)
trace = scatter_obj(**_trace_kwargs)
fig.add_trace(trace, hf_x=x_labels, hf_y=data[:, i], **add_trace_kwargs)
else:
_trace_kwargs = merge_dicts(
dict(x=x_labels, name=trace_name, showlegend=trace_name is not None),
_trace_kwargs,
)
trace = scatter_obj(**_trace_kwargs)
if data is not None:
self.update_trace(trace, data, i)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, fig.data[-len(trace_names) :])
@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, i: int) -> None:
data = reshaping.to_2d_array(data)
trace.y = data[:, i]
def update(self, data: tp.ArrayLike) -> None:
data = reshaping.to_2d_array(data)
with self.fig.batch_update():
for i, trace in enumerate(self.traces):
self.update_trace(trace, data, i)
class Histogram(TraceType, TraceUpdater):
"""Histogram plot.
Args:
data (array_like): Data in any format that can be converted to NumPy.
Must be of shape (any, `trace_names`).
trace_names (str or list of str): Trace names, corresponding to columns in pandas.
horizontal (bool): Whether to plot horizontally.
remove_nan (bool): Whether to remove NaN values.
from_quantile (float): Filter out data points before this quantile.
Must be in range `[0, 1]`.
to_quantile (float): Filter out data points after this quantile.
Must be in range `[0, 1]`.
trace_kwargs (dict or list of dict): Keyword arguments passed to `plotly.graph_objects.Histogram`.
Can be specified per trace as a sequence of dicts.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> from vectorbtpro import *
>>> hist = vbt.Histogram(
... data=[[1, 2], [3, 4], [2, 1]],
... trace_names=['a', 'b']
... )
>>> hist.fig.show()
```
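Outliers can be clipped with the quantile arguments; a minimal sketch that drops everything above the 90% quantile:
```pycon
>>> hist = vbt.Histogram(
... data=np.random.normal(size=(1000, 2)),
... trace_names=['a', 'b'],
... to_quantile=0.9
... )
>>> hist.fig.show()
```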
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
trace_names: tp.TraceNames = None,
horizontal: bool = False,
remove_nan: bool = True,
from_quantile: tp.Optional[float] = None,
to_quantile: tp.Optional[float] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
trace_names=trace_names,
horizontal=horizontal,
remove_nan=remove_nan,
from_quantile=from_quantile,
to_quantile=to_quantile,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
data = reshaping.to_2d_array(data)
if trace_names is not None:
checks.assert_shape_equal(data, trace_names, (1, 0))
else:
if trace_names is None:
raise ValueError("At least data or trace_names must be passed")
if trace_names is None:
trace_names = [None] * data.shape[1]
if isinstance(trace_names, str):
trace_names = [trace_names]
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
fig.update_layout(barmode="overlay")
fig.update_layout(**layout_kwargs)
for i, trace_name in enumerate(trace_names):
_trace_kwargs = resolve_dict(trace_kwargs, i=i)
trace_name = _trace_kwargs.pop("name", trace_name)
if trace_name is not None:
trace_name = str(trace_name)
_trace_kwargs = merge_dicts(
dict(
opacity=0.75 if len(trace_names) > 1 else 1,
name=trace_name,
showlegend=trace_name is not None,
),
_trace_kwargs,
)
trace = go.Histogram(**_trace_kwargs)
if data is not None:
self.update_trace(
trace,
data,
i,
horizontal=horizontal,
remove_nan=remove_nan,
from_quantile=from_quantile,
to_quantile=to_quantile,
)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, fig.data[-len(trace_names) :])
self._horizontal = horizontal
self._remove_nan = remove_nan
self._from_quantile = from_quantile
self._to_quantile = to_quantile
@property
def horizontal(self) -> bool:
"""Whether to plot horizontally."""
return self._horizontal
@property
def remove_nan(self) -> bool:
"""Whether to remove NaN values."""
return self._remove_nan
@property
def from_quantile(self) -> float:
"""Filter out data points before this quantile."""
return self._from_quantile
@property
def to_quantile(self) -> float:
"""Filter out data points after this quantile."""
return self._to_quantile
@classmethod
def update_trace(
cls,
trace: BaseTraceType,
data: tp.ArrayLike,
i: int,
horizontal: bool = False,
remove_nan: bool = True,
from_quantile: tp.Optional[float] = None,
to_quantile: tp.Optional[float] = None,
) -> None:
data = reshaping.to_2d_array(data)
d = data[:, i]
if remove_nan:
d = d[~np.isnan(d)]
mask = np.full(d.shape, True)
if from_quantile is not None:
mask &= d >= np.quantile(d, from_quantile)
if to_quantile is not None:
mask &= d <= np.quantile(d, to_quantile)
d = d[mask]
if horizontal:
trace.x = None
trace.y = d
else:
trace.x = d
trace.y = None
def update(self, data: tp.ArrayLike) -> None:
data = reshaping.to_2d_array(data)
with self.fig.batch_update():
for i, trace in enumerate(self.traces):
self.update_trace(
trace,
data,
i,
horizontal=self.horizontal,
remove_nan=self.remove_nan,
from_quantile=self.from_quantile,
to_quantile=self.to_quantile,
)
class Box(TraceType, TraceUpdater):
"""Box plot.
For keyword arguments, see `Histogram`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> box = vbt.Box(
... data=[[1, 2], [3, 4], [2, 1]],
... trace_names=['a', 'b']
... )
>>> box.fig.show()
```
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
trace_names: tp.TraceNames = None,
horizontal: bool = False,
remove_nan: bool = True,
from_quantile: tp.Optional[float] = None,
to_quantile: tp.Optional[float] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
trace_names=trace_names,
horizontal=horizontal,
remove_nan=remove_nan,
from_quantile=from_quantile,
to_quantile=to_quantile,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
data = reshaping.to_2d_array(data)
if trace_names is not None:
checks.assert_shape_equal(data, trace_names, (1, 0))
else:
if trace_names is None:
raise ValueError("At least data or trace_names must be passed")
if trace_names is None:
trace_names = [None] * data.shape[1]
if isinstance(trace_names, str):
trace_names = [trace_names]
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
fig.update_layout(**layout_kwargs)
for i, trace_name in enumerate(trace_names):
_trace_kwargs = resolve_dict(trace_kwargs, i=i)
trace_name = _trace_kwargs.pop("name", trace_name)
if trace_name is not None:
trace_name = str(trace_name)
_trace_kwargs = merge_dicts(
dict(name=trace_name, showlegend=trace_name is not None, boxmean="sd"),
_trace_kwargs,
)
trace = go.Box(**_trace_kwargs)
if data is not None:
self.update_trace(
trace,
data,
i,
horizontal=horizontal,
remove_nan=remove_nan,
from_quantile=from_quantile,
to_quantile=to_quantile,
)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, fig.data[-len(trace_names) :])
self._horizontal = horizontal
self._remove_nan = remove_nan
self._from_quantile = from_quantile
self._to_quantile = to_quantile
@property
def horizontal(self) -> bool:
"""Whether to plot horizontally."""
return self._horizontal
@property
def remove_nan(self) -> bool:
"""Whether to remove NaN values."""
return self._remove_nan
@property
def from_quantile(self) -> float:
"""Filter out data points before this quantile."""
return self._from_quantile
@property
def to_quantile(self) -> float:
"""Filter out data points after this quantile."""
return self._to_quantile
@classmethod
def update_trace(
cls,
trace: BaseTraceType,
data: tp.ArrayLike,
i: int,
horizontal: bool = False,
remove_nan: bool = True,
from_quantile: tp.Optional[float] = None,
to_quantile: tp.Optional[float] = None,
) -> None:
data = reshaping.to_2d_array(data)
d = data[:, i]
if remove_nan:
d = d[~np.isnan(d)]
mask = np.full(d.shape, True)
if from_quantile is not None:
mask &= d >= np.quantile(d, from_quantile)
if to_quantile is not None:
mask &= d <= np.quantile(d, to_quantile)
d = d[mask]
if horizontal:
trace.x = d
trace.y = None
else:
trace.x = None
trace.y = d
def update(self, data: tp.ArrayLike) -> None:
data = reshaping.to_2d_array(data)
with self.fig.batch_update():
for i, trace in enumerate(self.traces):
self.update_trace(
trace,
data,
i,
horizontal=self.horizontal,
remove_nan=self.remove_nan,
from_quantile=self.from_quantile,
to_quantile=self.to_quantile,
)
class Heatmap(TraceType, TraceUpdater):
"""Heatmap plot.
Args:
data (array_like): Data in any format that can be converted to NumPy.
Must be of shape (`y_labels`, `x_labels`).
x_labels (array_like): X-axis labels, corresponding to columns in pandas.
y_labels (array_like): Y-axis labels, corresponding to index in pandas.
is_x_category (bool): Whether X-axis is a categorical axis.
is_y_category (bool): Whether Y-axis is a categorical axis.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> from vectorbtpro import *
>>> heatmap = vbt.Heatmap(
... data=[[1, 2], [3, 4]],
... x_labels=['a', 'b'],
... y_labels=['x', 'y']
... )
>>> heatmap.fig.show()
```
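Being a single trace, the heatmap can be refreshed with a new matrix of the same shape; a minimal sketch:
```pycon
>>> heatmap.update([[4, 3], [2, 1]])
>>> heatmap.fig.show()
```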
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
x_labels: tp.Optional[tp.Labels] = None,
y_labels: tp.Optional[tp.Labels] = None,
is_x_category: bool = False,
is_y_category: bool = False,
trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
x_labels=x_labels,
y_labels=y_labels,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
from vectorbtpro._settings import settings
layout_cfg = settings["plotting"]["layout"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
data = reshaping.to_2d_array(data)
if x_labels is not None:
checks.assert_shape_equal(data, x_labels, (1, 0))
if y_labels is not None:
checks.assert_shape_equal(data, y_labels, (0, 0))
else:
if x_labels is None or y_labels is None:
raise ValueError("At least data, or x_labels and y_labels must be passed")
if x_labels is not None:
x_labels = clean_labels(x_labels)
if y_labels is not None:
y_labels = clean_labels(y_labels)
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
if "width" in layout_cfg:
# Calculate nice width and height
max_width = layout_cfg["width"]
if data is not None:
x_len = data.shape[1]
y_len = data.shape[0]
else:
x_len = len(x_labels)
y_len = len(y_labels)
width = math.ceil(rescale(x_len / (x_len + y_len), (0, 1), (0.3 * max_width, max_width)))
width = min(width + 150, max_width) # account for colorbar
height = math.ceil(rescale(y_len / (x_len + y_len), (0, 1), (0.3 * max_width, max_width)))
height = min(height, max_width * 0.7) # limit height
fig.update_layout(width=width, height=height)
_trace_kwargs = merge_dicts(
dict(hoverongaps=False, colorscale="Plasma", x=x_labels, y=y_labels),
trace_kwargs,
)
trace = go.Heatmap(**_trace_kwargs)
if data is not None:
self.update_trace(trace, data)
fig.add_trace(trace, **add_trace_kwargs)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
axis_kwargs = dict()
if is_x_category:
axis_kwargs[xaxis] = dict(type="category")
if is_y_category:
axis_kwargs[yaxis] = dict(type="category")
fig.update_layout(**axis_kwargs)
fig.update_layout(**layout_kwargs)
TraceUpdater.__init__(self, fig, (fig.data[-1],))
@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, *args, **kwargs) -> None:
trace.z = reshaping.to_2d_array(data)
def update(self, data: tp.ArrayLike) -> None:
with self.fig.batch_update():
self.update_trace(self.traces[0], data)
class Volume(TraceType, TraceUpdater):
"""Volume plot.
Args:
data (array_like): Data in any format that can be converted to NumPy.
Must be a 3-dim array.
x_labels (array_like): X-axis labels.
y_labels (array_like): Y-axis labels.
z_labels (array_like): Z-axis labels.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Volume`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
scene_name (str): Reference to the 3D scene.
make_figure_kwargs (dict): Keyword arguments passed to `vectorbtpro.utils.figure.make_figure`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
!!! note
Figure widgets currently have problems displaying NaNs.
Use the `.show()` method for rendering.
Usage:
```pycon
>>> from vectorbtpro import *
>>> volume = vbt.Volume(
... data=np.random.randint(1, 10, size=(3, 3, 3)),
... x_labels=['a', 'b', 'c'],
... y_labels=['d', 'e', 'f'],
... z_labels=['g', 'h', 'i']
... )
>>> volume.fig.show()
```
"""
def __init__(
self,
data: tp.Optional[tp.ArrayLike] = None,
x_labels: tp.Optional[tp.Labels] = None,
y_labels: tp.Optional[tp.Labels] = None,
z_labels: tp.Optional[tp.Labels] = None,
trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
scene_name: str = "scene",
make_figure_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> None:
TraceType.__init__(
self,
data=data,
x_labels=x_labels,
y_labels=y_labels,
z_labels=z_labels,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
scene_name=scene_name,
make_figure_kwargs=make_figure_kwargs,
fig=fig,
**layout_kwargs,
)
from vectorbtpro._settings import settings
layout_cfg = settings["plotting"]["layout"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if data is not None:
checks.assert_ndim(data, 3)
data = np.asarray(data)
x_len, y_len, z_len = data.shape
if x_labels is not None:
checks.assert_shape_equal(data, x_labels, (0, 0))
if y_labels is not None:
checks.assert_shape_equal(data, y_labels, (1, 0))
if z_labels is not None:
checks.assert_shape_equal(data, z_labels, (2, 0))
else:
if x_labels is None or y_labels is None or z_labels is None:
raise ValueError("At least data, or x_labels, y_labels and z_labels must be passed")
x_len = len(x_labels)
y_len = len(y_labels)
z_len = len(z_labels)
if x_labels is None:
x_labels = np.arange(x_len)
else:
x_labels = clean_labels(x_labels)
if y_labels is None:
y_labels = np.arange(y_len)
else:
y_labels = clean_labels(y_labels)
if z_labels is None:
z_labels = np.arange(z_len)
else:
z_labels = clean_labels(z_labels)
x_labels = np.asarray(x_labels)
y_labels = np.asarray(y_labels)
z_labels = np.asarray(z_labels)
if fig is None:
fig = make_figure(**resolve_dict(make_figure_kwargs))
if "width" in layout_cfg:
# Calculate nice width and height
fig.update_layout(width=layout_cfg["width"], height=0.7 * layout_cfg["width"])
# Non-numeric data types are not supported by go.Volume, so use ticktext
# Note: Plotly currently displays the entire tick array; future versions may handle this more sensibly
more_layout = dict()
more_layout[scene_name] = dict()
if not np.issubdtype(x_labels.dtype, np.number):
x_ticktext = x_labels
x_labels = np.arange(x_len)
more_layout[scene_name]["xaxis"] = dict(ticktext=x_ticktext, tickvals=x_labels, tickmode="array")
if not np.issubdtype(y_labels.dtype, np.number):
y_ticktext = y_labels
y_labels = np.arange(y_len)
more_layout[scene_name]["yaxis"] = dict(ticktext=y_ticktext, tickvals=y_labels, tickmode="array")
if not np.issubdtype(z_labels.dtype, np.number):
z_ticktext = z_labels
z_labels = np.arange(z_len)
more_layout[scene_name]["zaxis"] = dict(ticktext=z_ticktext, tickvals=z_labels, tickmode="array")
fig.update_layout(**more_layout)
fig.update_layout(**layout_kwargs)
# Arrays must have the same length as the flattened data array
x = np.repeat(x_labels, len(y_labels) * len(z_labels))
y = np.tile(np.repeat(y_labels, len(z_labels)), len(x_labels))
z = np.tile(z_labels, len(x_labels) * len(y_labels))
_trace_kwargs = merge_dicts(
dict(x=x, y=y, z=z, opacity=0.2, surface_count=15, colorscale="Plasma"),
trace_kwargs,
)
trace = go.Volume(**_trace_kwargs)
if data is not None:
self.update_trace(trace, data)
fig.add_trace(trace, **add_trace_kwargs)
TraceUpdater.__init__(self, fig, (fig.data[-1],))
@classmethod
def update_trace(cls, trace: BaseTraceType, data: tp.ArrayLike, *args, **kwargs) -> None:
trace.value = np.asarray(data).flatten()
def update(self, data: tp.ArrayLike) -> None:
with self.fig.batch_update():
self.update_trace(self.traces[0], data)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with records that can make use of OHLC data."""
from vectorbtpro import _typing as tp
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb
from vectorbtpro.records.base import Records
from vectorbtpro.records.decorators import attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import ReadonlyConfig
__all__ = [
"PriceRecords",
]
__pdoc__ = {}
price_records_shortcut_config = ReadonlyConfig(
dict(
bar_open_time=dict(obj_type="mapped"),
bar_close_time=dict(obj_type="mapped"),
bar_open=dict(obj_type="mapped"),
bar_high=dict(obj_type="mapped"),
bar_low=dict(obj_type="mapped"),
bar_close=dict(obj_type="mapped"),
)
)
"""_"""
__pdoc__[
"price_records_shortcut_config"
] = f"""Config of shortcut properties to be attached to `PriceRecords`.
```python
{price_records_shortcut_config.prettify()}
```
"""
PriceRecordsT = tp.TypeVar("PriceRecordsT", bound="PriceRecords")
@attach_shortcut_properties(price_records_shortcut_config)
class PriceRecords(Records):
"""Extends `vectorbtpro.records.base.Records` for records that can make use of OHLC data."""
@classmethod
def from_records(
cls: tp.Type[PriceRecordsT],
wrapper: ArrayWrapper,
records: tp.RecordArray,
data: tp.Optional["Data"] = None,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
close: tp.Optional[tp.ArrayLike] = None,
attach_data: bool = True,
**kwargs,
) -> PriceRecordsT:
"""Build `PriceRecords` from records."""
if open is None and data is not None:
open = data.open
if high is None and data is not None:
high = data.high
if low is None and data is not None:
low = data.low
if close is None and data is not None:
close = data.close
return cls(
wrapper,
records,
open=open if attach_data else None,
high=high if attach_data else None,
low=low if attach_data else None,
close=close if attach_data else None,
**kwargs,
)
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[PriceRecordsT],
*objs: tp.MaybeTuple[PriceRecordsT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `PriceRecords` after stacking along columns."""
kwargs = Records.resolve_row_stack_kwargs(*objs, **kwargs)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PriceRecords):
raise TypeError("Each object to be merged must be an instance of PriceRecords")
for price_name in ("open", "high", "low", "close"):
if price_name not in kwargs:
price_objs = []
stack_price_objs = True
for obj in objs:
if getattr(obj, price_name) is not None:
price_objs.append(getattr(obj, price_name))
else:
stack_price_objs = False
break
if stack_price_objs:
kwargs[price_name] = kwargs["wrapper"].row_stack_arrs(
*price_objs,
group_by=False,
wrap=False,
)
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[PriceRecordsT],
*objs: tp.MaybeTuple[PriceRecordsT],
reindex_kwargs: tp.KwargsLike = None,
ffill_close: bool = False,
fbfill_close: bool = False,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `PriceRecords` after stacking along columns."""
kwargs = Records.resolve_column_stack_kwargs(*objs, reindex_kwargs=reindex_kwargs, **kwargs)
kwargs.pop("reindex_kwargs", None)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PriceRecords):
raise TypeError("Each object to be merged must be an instance of PriceRecords")
for price_name in ("open", "high", "low", "close"):
if price_name not in kwargs:
price_objs = []
stack_price_objs = True
for obj in objs:
if getattr(obj, "_" + price_name) is not None:
price_objs.append(getattr(obj, price_name))
else:
stack_price_objs = False
break
if stack_price_objs:
new_price = kwargs["wrapper"].column_stack_arrs(
*price_objs,
reindex_kwargs=reindex_kwargs,
group_by=False,
wrap=True,
)
if price_name == "close":
if fbfill_close:
new_price = new_price.vbt.fbfill()
elif ffill_close:
new_price = new_price.vbt.ffill()
kwargs[price_name] = new_price.values
return kwargs
def __init__(
self,
wrapper: ArrayWrapper,
records_arr: tp.RecordArray,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
close: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
Records.__init__(
self,
wrapper,
records_arr,
open=open,
high=high,
low=low,
close=close,
**kwargs,
)
if open is not None:
open = to_2d_array(open)
if high is not None:
high = to_2d_array(high)
if low is not None:
low = to_2d_array(low)
if close is not None:
close = to_2d_array(close)
self._open = open
self._high = high
self._low = low
self._close = close
def indexing_func_meta(self, *args, records_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform indexing on `PriceRecords` and return metadata."""
if records_meta is None:
records_meta = Records.indexing_func_meta(self, *args, **kwargs)
prices = {}
for price_name in ("open", "high", "low", "close"):
if getattr(self, "_" + price_name) is not None:
new_price = ArrayWrapper.select_from_flex_array(
getattr(self, "_" + price_name),
row_idxs=records_meta["wrapper_meta"]["row_idxs"],
col_idxs=records_meta["wrapper_meta"]["col_idxs"],
rows_changed=records_meta["wrapper_meta"]["rows_changed"],
columns_changed=records_meta["wrapper_meta"]["columns_changed"],
)
else:
new_price = None
prices[price_name] = new_price
return {**records_meta, **prices}
def indexing_func(self: PriceRecordsT, *args, price_records_meta: tp.DictLike = None, **kwargs) -> PriceRecordsT:
"""Perform indexing on `PriceRecords`."""
if price_records_meta is None:
price_records_meta = self.indexing_func_meta(*args, **kwargs)
return self.replace(
wrapper=price_records_meta["wrapper_meta"]["new_wrapper"],
records_arr=price_records_meta["new_records_arr"],
open=price_records_meta["open"],
high=price_records_meta["high"],
low=price_records_meta["low"],
close=price_records_meta["close"],
)
def resample(
self: PriceRecordsT,
*args,
ffill_close: bool = False,
fbfill_close: bool = False,
records_meta: tp.DictLike = None,
**kwargs,
) -> PriceRecordsT:
"""Perform resampling on `PriceRecords`."""
if records_meta is None:
records_meta = self.resample_meta(*args, **kwargs)
if self._open is None:
new_open = None
else:
new_open = self.open.vbt.resample_apply(
records_meta["wrapper_meta"]["resampler"],
nb.first_reduce_nb,
)
if self._high is None:
new_high = None
else:
new_high = self.high.vbt.resample_apply(
records_meta["wrapper_meta"]["resampler"],
nb.max_reduce_nb,
)
if self._low is None:
new_low = None
else:
new_low = self.low.vbt.resample_apply(
records_meta["wrapper_meta"]["resampler"],
nb.min_reduce_nb,
)
if self._close is None:
new_close = None
else:
new_close = self.close.vbt.resample_apply(
records_meta["wrapper_meta"]["resampler"],
nb.last_reduce_nb,
)
if fbfill_close:
new_close = new_close.vbt.fbfill()
elif ffill_close:
new_close = new_close.vbt.ffill()
return self.replace(
wrapper=records_meta["wrapper_meta"]["new_wrapper"],
records_arr=records_meta["new_records_arr"],
open=new_open,
high=new_high,
low=new_low,
close=new_close,
)
@property
def open(self) -> tp.Optional[tp.SeriesFrame]:
"""Open price."""
if self._open is None:
return None
return self.wrapper.wrap(self._open, group_by=False)
@property
def high(self) -> tp.Optional[tp.SeriesFrame]:
"""High price."""
if self._high is None:
return None
return self.wrapper.wrap(self._high, group_by=False)
@property
def low(self) -> tp.Optional[tp.SeriesFrame]:
"""Low price."""
if self._low is None:
return None
return self.wrapper.wrap(self._low, group_by=False)
@property
def close(self) -> tp.Optional[tp.SeriesFrame]:
"""Close price."""
if self._close is None:
return None
return self.wrapper.wrap(self._close, group_by=False)
def get_bar_open_time(self, **kwargs) -> MappedArray:
"""Get a mapped array with the opening time of the bar."""
return self.map_array(self.wrapper.index[self.idx_arr], **kwargs)
def get_bar_close_time(self, **kwargs) -> MappedArray:
"""Get a mapped array with the closing time of the bar."""
if self.wrapper.freq is None:
raise ValueError("Must provide frequency")
return self.map_array(
Resampler.get_rbound_index(index=self.wrapper.index[self.idx_arr], freq=self.wrapper.freq), **kwargs
)
def get_bar_open(self, **kwargs) -> MappedArray:
"""Get a mapped array with the opening price of the bar."""
return self.apply(nb.bar_price_nb, self._open, **kwargs)
def get_bar_high(self, **kwargs) -> MappedArray:
"""Get a mapped array with the high price of the bar."""
return self.apply(nb.bar_price_nb, self._high, **kwargs)
def get_bar_low(self, **kwargs) -> MappedArray:
"""Get a mapped array with the low price of the bar."""
return self.apply(nb.bar_price_nb, self._low, **kwargs)
def get_bar_close(self, **kwargs) -> MappedArray:
"""Get a mapped array with the closing price of the bar."""
return self.apply(nb.bar_price_nb, self._close, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with range records.
Range records capture information on ranges. They are useful for analyzing the duration of processes,
such as drawdowns, trades, and positions. They also come in handy when analyzing the distance between events,
such as entry and exit signals.
Each range has a starting point and an ending point. For example, the points for `range(20)`
are 0 and 20 (not 19!) respectively.
!!! note
Be aware that if a range hasn't ended in a column, its `end_idx` will point at the latest index.
Make sure to account for this when computing custom metrics involving duration.
```pycon
>>> from vectorbtpro import *
>>> start = '2019-01-01 UTC' # crypto is in UTC
>>> end = '2020-01-01 UTC'
>>> price = vbt.YFData.pull('BTC-USD', start=start, end=end).get('Close')
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> fast_ma = vbt.MA.run(price, 10)
>>> slow_ma = vbt.MA.run(price, 50)
>>> fast_below_slow = fast_ma.ma_above(slow_ma)
>>> ranges = vbt.Ranges.from_array(fast_below_slow, wrapper_kwargs=dict(freq='d'))
>>> ranges.readable
Range Id Column Start Timestamp End Timestamp \\
0 0 0 2019-02-19 00:00:00+00:00 2019-07-25 00:00:00+00:00
1 1 0 2019-08-08 00:00:00+00:00 2019-08-19 00:00:00+00:00
2 2 0 2019-11-01 00:00:00+00:00 2019-11-20 00:00:00+00:00
Status
0 Closed
1 Closed
2 Closed
>>> ranges.duration.max(wrap_kwargs=dict(to_timedelta=True))
Timedelta('156 days 00:00:00')
```
## From accessors
Moreover, all generic accessors have a property `ranges` and a method `get_ranges`:
```pycon
>>> # vectorbtpro.generic.accessors.GenericAccessor.ranges.coverage
>>> fast_below_slow.vbt.ranges.coverage
0.5081967213114754
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Ranges.metrics`.
```pycon
>>> df = pd.DataFrame({
... 'a': [1, 2, np.nan, np.nan, 5, 6],
... 'b': [np.nan, 2, np.nan, 4, np.nan, 6]
... })
>>> ranges = df.vbt(freq='d').ranges
>>> ranges['a'].stats()
Start 0
End 5
Period 6 days 00:00:00
Total Records 2
Coverage 0.666667
Overlap Coverage 0.0
Duration: Min 2 days 00:00:00
Duration: Median 2 days 00:00:00
Duration: Max 2 days 00:00:00
Name: a, dtype: object
```
`Ranges.stats` also supports (re-)grouping:
```pycon
>>> ranges.stats(group_by=True)
Start 0
End 5
Period 6 days 00:00:00
Total Records 5
Coverage 0.416667
Overlap Coverage 0.4
Duration: Min 1 days 00:00:00
Duration: Median 1 days 00:00:00
Duration: Max 2 days 00:00:00
Name: group, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Ranges.subplots`.
`Ranges` class has a single subplot based on `Ranges.plot`:
```pycon
>>> ranges['a'].plots().show()
```
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import stack_indexes, combine_indexes, tile_index
from vectorbtpro.base.reshaping import to_pd_array, to_1d_array, to_2d_array, tile
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb, enums
from vectorbtpro.generic.price_records import PriceRecords
from vectorbtpro.records.base import Records
from vectorbtpro.records.decorators import override_field_config, attach_fields, attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import DefineMixin, define, MISSING
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.params import combine_params, Param
from vectorbtpro.utils.parsing import get_func_kwargs
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import substitute_templates
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"Ranges",
"PatternRanges",
"PSC",
]
__pdoc__ = {}
# ############# Ranges ############# #
ranges_field_config = ReadonlyConfig(
dict(
dtype=enums.range_dt,
settings=dict(
id=dict(title="Range Id"),
idx=dict(name="end_idx"), # remap field of Records
start_idx=dict(title="Start Index", mapping="index"),
end_idx=dict(title="End Index", mapping="index"),
status=dict(title="Status", mapping=enums.RangeStatus),
),
)
)
"""_"""
__pdoc__[
"ranges_field_config"
] = f"""Field config for `Ranges`.
```python
{ranges_field_config.prettify()}
```
"""
ranges_attach_field_config = ReadonlyConfig(dict(status=dict(attach_filters=True)))
"""_"""
__pdoc__[
"ranges_attach_field_config"
] = f"""Config of fields to be attached to `Ranges`.
```python
{ranges_attach_field_config.prettify()}
```
"""
ranges_shortcut_config = ReadonlyConfig(
dict(
valid=dict(),
invalid=dict(),
first_pd_mask=dict(obj_type="array"),
last_pd_mask=dict(obj_type="array"),
ranges_pd_mask=dict(obj_type="array"),
first_idx=dict(obj_type="mapped_array"),
last_idx=dict(obj_type="mapped_array"),
duration=dict(obj_type="mapped_array"),
real_duration=dict(obj_type="mapped_array"),
avg_duration=dict(obj_type="red_array"),
max_duration=dict(obj_type="red_array"),
coverage=dict(obj_type="red_array"),
overlap_coverage=dict(method_name="get_coverage", obj_type="red_array", method_kwargs=dict(overlapping=True)),
projections=dict(obj_type="array"),
)
)
"""_"""
__pdoc__[
"ranges_shortcut_config"
] = f"""Config of shortcut properties to be attached to `Ranges`.
```python
{ranges_shortcut_config.prettify()}
```
"""
RangesT = tp.TypeVar("RangesT", bound="Ranges")
@attach_shortcut_properties(ranges_shortcut_config)
@attach_fields(ranges_attach_field_config)
@override_field_config(ranges_field_config)
class Ranges(PriceRecords):
"""Extends `vectorbtpro.generic.price_records.PriceRecords` for working with range records.
Requires `records_arr` to have all fields defined in `vectorbtpro.generic.enums.range_dt`."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def from_array(
cls: tp.Type[RangesT],
arr: tp.ArrayLike,
gap_value: tp.Optional[tp.Scalar] = None,
attach_as_close: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> RangesT:
"""Build `Ranges` from an array.
Searches for sequences of
* True values in boolean data (False acts as a gap),
* positive values in integer data (-1 acts as a gap), and
* non-NaN values in any other data (NaN acts as a gap).
If `attach_as_close` is True, will attach `arr` as `close`.
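A minimal sketch, assuming a boolean mask where False acts as the gap value:
```pycon
>>> from vectorbtpro import *
>>> mask = pd.Series([False, True, True, False, True])
>>> ranges = vbt.Ranges.from_array(mask, wrapper_kwargs=dict(freq='d'))
>>> ranges.count()
2
```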
`**kwargs` will be passed to `Ranges.__init__`."""
if wrapper_kwargs is None:
wrapper_kwargs = {}
wrapper = ArrayWrapper.from_obj(arr, **wrapper_kwargs)
arr = to_2d_array(arr)
if gap_value is None:
if np.issubdtype(arr.dtype, np.bool_):
gap_value = False
elif np.issubdtype(arr.dtype, np.integer):
gap_value = -1
else:
gap_value = np.nan
func = jit_reg.resolve_option(nb.get_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
records_arr = func(arr, gap_value)
if attach_as_close and "close" not in kwargs:
kwargs["close"] = arr
return cls(wrapper, records_arr, **kwargs)
@classmethod
def from_delta(
cls: tp.Type[RangesT],
records_or_mapped: tp.Union[Records, MappedArray],
delta: tp.Union[str, int, tp.FrequencyLike],
shift: tp.Optional[int] = None,
idx_field_or_arr: tp.Union[None, str, tp.Array1d] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> RangesT:
"""Build `Ranges` from a record/mapped array with a timedelta applied on its index field.
See `vectorbtpro.generic.nb.records.get_ranges_from_delta_nb`.
Set `delta` to an integer to wait a certain number of rows. Set it to anything else to
wait a timedelta. The conversion is done using `vectorbtpro.utils.datetime_.to_timedelta64`.
The second option requires the index to be datetime-like, or at least the frequency to be set.
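A minimal sketch, assuming `ranges` is an existing `Ranges` instance with a datetime-like index:
```pycon
>>> delta_ranges = vbt.Ranges.from_delta(ranges, "3 days")  # or an integer number of rows
```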
`**kwargs` will be passed to `Ranges.__init__`."""
if idx_field_or_arr is None:
if isinstance(records_or_mapped, Records):
idx_field_or_arr = records_or_mapped.get_field_arr("idx")
else:
idx_field_or_arr = records_or_mapped.idx_arr
if isinstance(idx_field_or_arr, str):
if isinstance(records_or_mapped, Records):
idx_field_or_arr = records_or_mapped.get_field_arr(idx_field_or_arr)
else:
raise ValueError("Providing an index field is allowed for records only")
if isinstance(records_or_mapped, Records):
id_arr = records_or_mapped.get_field_arr("id")
else:
id_arr = records_or_mapped.id_arr
if isinstance(delta, int):
delta_use_index = False
index = None
else:
delta = dt.to_ns(dt.to_timedelta64(delta))
if isinstance(records_or_mapped.wrapper.index, pd.DatetimeIndex):
index = dt.to_ns(records_or_mapped.wrapper.index)
else:
freq = dt.to_ns(dt.to_timedelta64(records_or_mapped.wrapper.freq))
index = np.arange(records_or_mapped.wrapper.shape[0]) * freq
delta_use_index = True
if shift is None:
shift = 0
col_map = records_or_mapped.col_mapper.get_col_map(group_by=False)
func = jit_reg.resolve_option(nb.get_ranges_from_delta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_records_arr = func(
records_or_mapped.wrapper.shape[0],
idx_field_or_arr,
id_arr,
col_map,
index=index,
delta=delta,
delta_use_index=delta_use_index,
shift=shift,
)
if isinstance(records_or_mapped, PriceRecords):
kwargs = merge_dicts(
dict(
open=records_or_mapped._open,
high=records_or_mapped._high,
low=records_or_mapped._low,
close=records_or_mapped._close,
),
kwargs,
)
return Ranges.from_records(records_or_mapped.wrapper, new_records_arr, **kwargs)
def with_delta(self, *args, **kwargs):
"""Pass self to `Ranges.from_delta`."""
return Ranges.from_delta(self, *args, **kwargs)
def crop(self) -> RangesT:
"""Remove any data outside the minimum start index and the maximum end index."""
min_start_idx = np.min(self.get_field_arr("start_idx"))
max_end_idx = np.max(self.get_field_arr("end_idx")) + 1
return self.iloc[min_start_idx:max_end_idx]
# ############# Filtering ############# #
def filter_min_duration(
self: RangesT,
min_duration: tp.Union[str, int, tp.FrequencyLike],
real: bool = False,
**kwargs,
) -> RangesT:
"""Filter out ranges that last less than a minimum duration."""
if isinstance(min_duration, int):
return self.apply_mask(self.duration.values >= min_duration, **kwargs)
min_duration = dt.to_timedelta64(min_duration)
if real:
return self.apply_mask(self.real_duration.values >= min_duration, **kwargs)
return self.apply_mask(self.duration.values * self.wrapper.freq >= min_duration, **kwargs)
def filter_max_duration(
self: RangesT,
max_duration: tp.Union[str, int, tp.FrequencyLike],
real: bool = False,
**kwargs,
) -> RangesT:
"""Filter out ranges that last more than a maximum duration."""
if isinstance(max_duration, int):
return self.apply_mask(self.duration.values <= max_duration, **kwargs)
max_duration = dt.to_timedelta64(max_duration)
if real:
return self.apply_mask(self.real_duration.values <= max_duration, **kwargs)
return self.apply_mask(self.duration.values * self.wrapper.freq <= max_duration, **kwargs)
# ############# Masking ############# #
def get_first_pd_mask(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Get mask from `Ranges.get_first_idx`."""
return self.get_pd_mask(idx_arr=self.first_idx.values, group_by=group_by, **kwargs)
def get_last_pd_mask(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.SeriesFrame:
"""Get mask from `Ranges.get_last_idx`."""
out = self.get_pd_mask(idx_arr=self.last_idx.values, group_by=group_by, **kwargs)
return out
def get_ranges_pd_mask(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get mask from ranges.
See `vectorbtpro.generic.nb.records.ranges_to_mask_nb`."""
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.ranges_to_mask_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mask = func(
self.get_field_arr("start_idx"),
self.get_field_arr("end_idx"),
self.get_field_arr("status"),
col_map,
len(self.wrapper.index),
)
return self.wrapper.wrap(mask, group_by=group_by, **resolve_dict(wrap_kwargs))
# ############# Stats ############# #
def get_valid(self: RangesT, **kwargs) -> RangesT:
"""Get valid ranges.
A valid range has neither its start index nor its end index set to -1."""
filter_mask = (self.get_field_arr("start_idx") != -1) & (self.get_field_arr("end_idx") != -1)
return self.apply_mask(filter_mask, **kwargs)
def get_invalid(self: RangesT, **kwargs) -> RangesT:
"""Get invalid ranges.
An invalid range has the start and/or end index set to -1."""
filter_mask = (self.get_field_arr("start_idx") == -1) | (self.get_field_arr("end_idx") == -1)
return self.apply_mask(filter_mask, **kwargs)
def get_first_idx(self, **kwargs):
"""Get the first index in each range."""
return self.map_field("start_idx", **kwargs)
def get_last_idx(self, **kwargs):
"""Get the last index in each range."""
last_idx = self.get_field_arr("end_idx", copy=True)
status = self.get_field_arr("status")
last_idx[status == enums.RangeStatus.Closed] -= 1
return self.map_array(last_idx, **kwargs)
def get_duration(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""Get the effective duration of each range in integer format."""
func = jit_reg.resolve_option(nb.range_duration_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
duration = func(
self.get_field_arr("start_idx"),
self.get_field_arr("end_idx"),
self.get_field_arr("status"),
freq=1,
)
return self.map_array(duration, **kwargs)
def get_real_duration(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""Get the real duration of each range in timedelta format."""
func = jit_reg.resolve_option(nb.range_duration_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
duration = func(
dt.to_ns(self.get_map_field_to_index("start_idx")),
dt.to_ns(self.get_map_field_to_index("end_idx")),
self.get_field_arr("status"),
freq=dt.to_ns(dt.to_timedelta64(self.wrapper.freq)),
).astype("timedelta64[ns]")
return self.map_array(duration, **kwargs)
def get_avg_duration(
self,
real: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get average range duration (as timedelta)."""
if real:
duration = self.real_duration
duration = duration.replace(mapped_arr=dt.to_ns(duration.mapped_arr))
wrap_kwargs = merge_dicts(dict(name_or_index="avg_real_duration", dtype="timedelta64[ns]"), wrap_kwargs)
else:
duration = self.duration
wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="avg_duration"), wrap_kwargs)
return duration.mean(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)
def get_max_duration(
self,
real: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get maximum range duration (as timedelta)."""
if real:
duration = self.real_duration
duration = duration.replace(mapped_arr=dt.to_ns(duration.mapped_arr))
wrap_kwargs = merge_dicts(dict(name_or_index="max_real_duration", dtype="timedelta64[ns]"), wrap_kwargs)
else:
duration = self.duration
wrap_kwargs = merge_dicts(dict(to_timedelta=True, name_or_index="max_duration"), wrap_kwargs)
return duration.max(group_by=group_by, jitted=jitted, chunked=chunked, wrap_kwargs=wrap_kwargs, **kwargs)
def get_coverage(
self,
overlapping: bool = False,
normalize: bool = True,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get coverage, that is, the number of steps that are covered by all ranges.
See `vectorbtpro.generic.nb.records.range_coverage_nb`."""
col_map = self.col_mapper.get_col_map(group_by=group_by)
index_lens = self.wrapper.grouper.get_group_lens(group_by=group_by) * self.wrapper.shape[0]
func = jit_reg.resolve_option(nb.range_coverage_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
coverage = func(
self.get_field_arr("start_idx"),
self.get_field_arr("end_idx"),
self.get_field_arr("status"),
col_map,
index_lens,
overlapping=overlapping,
normalize=normalize,
)
wrap_kwargs = merge_dicts(dict(name_or_index="coverage"), wrap_kwargs)
return self.wrapper.wrap_reduced(coverage, group_by=group_by, **wrap_kwargs)
def get_projections(
self,
close: tp.Optional[tp.ArrayLike] = None,
proj_start: tp.Union[None, str, int, tp.FrequencyLike] = None,
proj_period: tp.Union[None, str, int, tp.FrequencyLike] = None,
incl_end_idx: bool = True,
extend: bool = False,
rebase: bool = True,
start_value: tp.ArrayLike = 1.0,
ffill: bool = False,
remove_empty: bool = True,
return_raw: bool = False,
start_index: tp.Optional[pd.Timestamp] = None,
id_level: tp.Union[None, str, tp.IndexLike] = None,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
clean_index_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.Tuple[tp.Array1d, tp.Array2d], tp.Frame]:
"""Generate a projection for each range record.
See `vectorbtpro.generic.nb.records.map_ranges_to_projections_nb`.
Set `proj_start` to an integer to generate a projection after a certain row
after the start row. Set it to anything else to wait a timedelta.
The conversion is done using `vectorbtpro.utils.datetime_.to_timedelta64`.
The second option requires the index to be datetime-like, or at least the frequency to be set.
Set `proj_period` the same way as `proj_start` to generate a projection of a certain length.
Unless `extend` is True, it still respects the duration of the range.
Set `extend` to True to extend the projection even after the end of the range.
The extending period is taken from the longest range duration if `proj_period` is None,
and from the longest `proj_period` if it's not None.
Set `rebase` to True to make each projection start with 1, otherwise, each projection
will consist of original `close` values during the projected period. Use `start_value`
to replace 1 with another start value. It can also be a flexible array with elements per column.
If `start_value` is -1, will set it to the latest row in `close`.
Set `ffill` to True to forward fill NaN values, even if they are NaN in `close` itself.
Set `remove_empty` to True to remove projections that are either NaN or with only one element.
The index of each projection is still being tracked and will appear in the multi-index of the
returned DataFrame.
!!! note
As opposed to the Numba-compiled function, the returned DataFrame will have
projections stacked along columns rather than rows. Set `return_raw` to True
to return them in the original format.
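A minimal sketch, assuming `ranges` holds closed ranges with `close` attached and a frequency set:
```pycon
>>> projections = ranges.get_projections(proj_period="30 days", extend=True)
>>> projections.mean(axis=1)  # aggregate across the projected ranges
```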
"""
if close is None:
close = self.close
checks.assert_not_none(close, arg_name="close")
else:
close = self.wrapper.wrap(close, group_by=False)
if proj_start is None:
proj_start = 0
if isinstance(proj_start, int):
proj_start_use_index = False
index = None
else:
proj_start = dt.to_ns(dt.to_timedelta64(proj_start))
if isinstance(self.wrapper.index, pd.DatetimeIndex):
index = dt.to_ns(self.wrapper.index)
else:
freq = dt.to_ns(dt.to_timedelta64(self.wrapper.freq))
index = np.arange(self.wrapper.shape[0]) * freq
proj_start_use_index = True
if proj_period is not None:
if isinstance(proj_period, int):
proj_period_use_index = False
else:
proj_period = dt.to_ns(dt.to_timedelta64(proj_period))
if index is None:
if isinstance(self.wrapper.index, pd.DatetimeIndex):
index = dt.to_ns(self.wrapper.index)
else:
freq = dt.to_ns(dt.to_timedelta64(self.wrapper.freq))
index = np.arange(self.wrapper.shape[0]) * freq
proj_period_use_index = True
else:
proj_period_use_index = False
func = jit_reg.resolve_option(nb.map_ranges_to_projections_nb, jitted)
ridxs, projections = func(
to_2d_array(close),
self.get_field_arr("col"),
self.get_field_arr("start_idx"),
self.get_field_arr("end_idx"),
self.get_field_arr("status"),
index=index,
proj_start=proj_start,
proj_start_use_index=proj_start_use_index,
proj_period=proj_period,
proj_period_use_index=proj_period_use_index,
incl_end_idx=incl_end_idx,
extend=extend,
rebase=rebase,
start_value=to_1d_array(start_value),
ffill=ffill,
remove_empty=remove_empty,
)
if return_raw:
return ridxs, projections
projections = projections.T
wrapper = ArrayWrapper.from_obj(projections, freq=self.wrapper.freq)
if id_level is None:
id_level = pd.Index(self.id_arr, name="range_id")
elif isinstance(id_level, str):
mapping = self.get_field_mapping(id_level)
if isinstance(mapping, str) and mapping == "index":
id_level = self.get_map_field_to_index(id_level).rename(id_level)
else:
id_level = pd.Index(self.get_apply_mapping_arr(id_level), name=id_level)
else:
if not isinstance(id_level, pd.Index):
id_level = pd.Index(id_level, name="range_id")
if start_index is None:
start_index = close.index[-1]
wrap_kwargs = merge_dicts(
dict(
index=pd.date_range(
start=start_index,
periods=projections.shape[0],
freq=self.wrapper.freq,
),
columns=stack_indexes(
self.wrapper.columns[self.col_arr[ridxs]],
id_level[ridxs],
**resolve_dict(clean_index_kwargs),
),
),
wrap_kwargs,
)
return wrapper.wrap(projections, **wrap_kwargs)
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Ranges.stats`.
Merges `vectorbtpro.records.base.Records.stats_defaults` and
`stats` from `vectorbtpro._settings.ranges`."""
from vectorbtpro._settings import settings
ranges_stats_cfg = settings["ranges"]["stats"]
return merge_dicts(Records.stats_defaults.__get__(self), ranges_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total_records=dict(title="Total Records", calc_func="count", tags="records"),
coverage=dict(
title="Coverage",
calc_func="coverage",
overlapping=False,
tags=["ranges", "coverage"],
),
overlap_coverage=dict(
title="Overlap Coverage",
calc_func="coverage",
overlapping=True,
tags=["ranges", "coverage"],
),
duration=dict(
title="Duration",
calc_func="duration.describe",
post_calc_func=lambda self, out, settings: {
"Min": out.loc["min"],
"Median": out.loc["50%"],
"Max": out.loc["max"],
},
apply_to_timedelta=True,
tags=["ranges", "duration"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot_projections(
self,
column: tp.Optional[tp.Label] = None,
min_duration: tp.Union[str, int, tp.FrequencyLike] = None,
max_duration: tp.Union[str, int, tp.FrequencyLike] = None,
last_n: tp.Optional[int] = None,
top_n: tp.Optional[int] = None,
random_n: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
proj_start: tp.Union[None, str, int, tp.FrequencyLike] = "current_or_0",
proj_period: tp.Union[None, str, int, tp.FrequencyLike] = "max",
incl_end_idx: bool = True,
extend: bool = False,
ffill: bool = False,
plot_past_period: tp.Union[None, str, int, tp.FrequencyLike] = "current_or_proj_period",
plot_ohlc: tp.Union[bool, tp.Frame] = True,
plot_close: tp.Union[bool, tp.Series] = True,
plot_projections: bool = True,
plot_bands: bool = True,
plot_lower: tp.Union[bool, str, tp.Callable] = True,
plot_middle: tp.Union[bool, str, tp.Callable] = True,
plot_upper: tp.Union[bool, str, tp.Callable] = True,
plot_aux_middle: tp.Union[bool, str, tp.Callable] = True,
plot_fill: bool = True,
colorize: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
projection_trace_kwargs: tp.KwargsLike = None,
lower_trace_kwargs: tp.KwargsLike = None,
middle_trace_kwargs: tp.KwargsLike = None,
upper_trace_kwargs: tp.KwargsLike = None,
aux_middle_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot projections.
Combines generation of projections using `Ranges.get_projections` and
their plotting using `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
Args:
column (str): Name of the column to plot.
min_duration (str, int, or frequency_like): Filter range records by minimum duration.
max_duration (str, int, or frequency_like): Filter range records by maximum duration.
last_n (int): Select last N range records.
top_n (int): Select top N range records by maximum duration.
random_n (int): Select N range records randomly.
seed (int): Seed to make output deterministic.
proj_start (str, int, or frequency_like): See `Ranges.get_projections`.
Allows an additional option "current_or_{value}", which sets `proj_start` to
the duration of the current open range, and to the specified value if there is no open range.
proj_period (str, int, or frequency_like): See `Ranges.get_projections`.
Allows additional options "current_or_{option}", "mean", "min", "max", "median", or
a percentage such as "50%" representing a quantile. All of those options are based
on the duration of all the closed ranges filtered by the arguments above.
incl_end_idx (bool): See `Ranges.get_projections`.
extend (bool): See `Ranges.get_projections`.
ffill (bool): See `Ranges.get_projections`.
plot_past_period (str, int, or frequency_like): Past period to plot.
Allows the same options as `proj_period` plus "proj_period" and "current_or_proj_period".
plot_ohlc (bool or DataFrame): Whether to plot OHLC.
plot_close (bool or Series): Whether to plot close.
plot_projections (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_bands (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_lower (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_middle (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_upper (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_aux_middle (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
plot_fill (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
colorize (bool, str, or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`.
projection_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for projections.
lower_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for lower band.
middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for middle band.
upper_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for upper band.
aux_middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for auxiliary middle band.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> price = pd.Series(
... [11, 12, 13, 14, 11, 12, 13, 12, 11, 12],
... index=pd.date_range("2020", periods=10),
... )
>>> vbt.Ranges.from_array(
... price >= 12,
... attach_as_close=False,
... close=price,
... ).plot_projections(
... proj_start=0,
... proj_period=4,
... extend=True,
... plot_past_period=None
... ).show()
```
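A minimal additional sketch (same `price` series as above) that resolves the
projection period from the duration distribution of the filtered closed ranges
by passing a quantile such as "50%":
```pycon
>>> vbt.Ranges.from_array(
... price >= 12,
... attach_as_close=False,
... close=price,
... ).plot_projections(
... proj_start=0,
... proj_period="50%",
... extend=True,
... plot_past_period=None
... ).show()
```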
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
self_col_open = self_col.status_open
self_col = self_col.status_closed
if proj_start is not None:
if isinstance(proj_start, str) and proj_start.startswith("current_or_"):
if self_col_open.count() > 0:
if self_col_open.count() > 1:
raise ValueError("Only one open range is allowed")
proj_start = int(self_col_open.duration.values[0])
else:
proj_start = proj_start.replace("current_or_", "")
if proj_start.isnumeric():
proj_start = int(proj_start)
if proj_start != 0:
self_col = self_col.filter_min_duration(proj_start, real=True)
if min_duration is not None:
self_col = self_col.filter_min_duration(min_duration, real=True)
if max_duration is not None:
self_col = self_col.filter_max_duration(max_duration, real=True)
if last_n is not None:
self_col = self_col.last_n(last_n)
if top_n is not None:
self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n))
if random_n is not None:
self_col = self_col.random_n(random_n, seed=seed)
if self_col.count() == 0:
warn("No ranges to plot. Relax the requirements.")
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if isinstance(plot_ohlc, bool):
if (
self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
else:
ohlc = None
else:
ohlc = plot_ohlc
plot_ohlc = True
if isinstance(plot_close, bool):
if ohlc is not None:
close = ohlc.vbt.ohlcv.close
else:
close = self_col.close
else:
close = plot_close
plot_close = True
if close is None:
raise ValueError("Close cannot be None")
# Resolve windows
def _resolve_period(period):
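"""Resolve a period option ("median", "mean", "min", "max", a percentage such as "50%", or "current_or_{option}") into an integer number of bars based on the durations of the filtered closed ranges."""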
if self_col.count() == 0:
period = None
if period is not None:
if isinstance(period, str):
period = period.lower().replace(" ", "")
if period == "median":
period = "50%"
if "%" in period:
period = int(
np.quantile(
self_col.duration.values,
float(period.replace("%", "")) / 100,
)
)
elif period.startswith("current_or_"):
if self_col_open.count() > 0:
if self_col_open.count() > 1:
raise ValueError("Only one open range is allowed")
period = int(self_col_open.duration.values[0])
else:
period = period.replace("current_or_", "")
return _resolve_period(period)
elif period == "mean":
period = int(np.mean(self_col.duration.values))
elif period == "min":
period = int(np.min(self_col.duration.values))
elif period == "max":
period = int(np.max(self_col.duration.values))
return period
proj_period = _resolve_period(proj_period)
if isinstance(proj_period, int) and proj_period == 0:
warn("Projection period is zero. Setting to maximum.")
proj_period = int(np.max(self_col.duration.values))
if plot_past_period is not None and isinstance(plot_past_period, str):
plot_past_period = plot_past_period.lower().replace(" ", "")
if plot_past_period == "proj_period":
plot_past_period = proj_period
elif plot_past_period == "current_or_proj_period":
if self_col_open.count() > 0:
if self_col_open.count() > 1:
raise ValueError("Only one open range is allowed")
plot_past_period = int(self_col_open.duration.values[0])
else:
plot_past_period = proj_period
plot_past_period = _resolve_period(plot_past_period)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# Plot OHLC/close
if plot_ohlc and ohlc is not None:
if plot_past_period is not None:
if isinstance(plot_past_period, int):
_ohlc = ohlc.iloc[-plot_past_period:]
else:
plot_past_period = dt.to_timedelta(plot_past_period)
_ohlc = ohlc[ohlc.index > ohlc.index[-1] - plot_past_period]
else:
_ohlc = ohlc
if _ohlc.size > 0:
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = _ohlc.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close:
if plot_past_period is not None:
if isinstance(plot_past_period, int):
_close = close.iloc[-plot_past_period:]
else:
plot_past_period = dt.to_timedelta(plot_past_period)
_close = close[close.index > close.index[-1] - plot_past_period]
else:
_close = close
if _close.size > 0:
fig = _close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Get projections
projections = self_col.get_projections(
close=close,
proj_start=proj_start,
proj_period=proj_period,
incl_end_idx=incl_end_idx,
extend=extend,
rebase=True,
start_value=close.iloc[-1],
ffill=ffill,
remove_empty=True,
return_raw=False,
)
if len(projections.columns) > 0:
# Plot projections
rename_levels = dict(range_id=self_col.get_field_title("id"))
fig = projections.vbt.plot_projections(
plot_projections=plot_projections,
plot_bands=plot_bands,
plot_lower=plot_lower,
plot_middle=plot_middle,
plot_upper=plot_upper,
plot_aux_middle=plot_aux_middle,
plot_fill=plot_fill,
colorize=colorize,
rename_levels=rename_levels,
projection_trace_kwargs=projection_trace_kwargs,
upper_trace_kwargs=upper_trace_kwargs,
middle_trace_kwargs=middle_trace_kwargs,
lower_trace_kwargs=lower_trace_kwargs,
aux_middle_trace_kwargs=aux_middle_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
def plot_shapes(
self,
column: tp.Optional[tp.Label] = None,
plot_ohlc: tp.Union[bool, tp.Frame] = True,
plot_close: tp.Union[bool, tp.Series] = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot range shapes.
Args:
column (str): Name of the column to plot.
plot_ohlc (bool or DataFrame): Whether to plot OHLC.
plot_close (bool or Series): Whether to plot close.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`.
shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for shapes.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
* Plot zones colored by duration:
```pycon
>>> price = pd.Series(
... [1, 2, 1, 2, 3, 2, 1, 2, 3],
... index=pd.date_range("2020", periods=9),
... )
>>> def get_opacity(self_col, i):
... real_duration = self_col.get_real_duration().values
... return real_duration[i] / real_duration.max() * 0.5
>>> vbt.Ranges.from_array(price >= 2).plot_shapes(
... shape_kwargs=dict(fillcolor="teal", opacity=vbt.RepFunc(get_opacity))
... ).show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if shape_kwargs is None:
shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if isinstance(plot_ohlc, bool):
if (
self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
else:
ohlc = None
else:
ohlc = plot_ohlc
plot_ohlc = True
if isinstance(plot_close, bool):
if ohlc is not None:
close = ohlc.vbt.ohlcv.close
else:
close = self_col.close
else:
close = plot_close
plot_close = True
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
x_domain = get_domain(xref, fig)
y_domain = get_domain(yref, fig)
# Plot OHLC/close
if plot_ohlc and ohlc is not None:
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and close is not None:
fig = close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True)
end_idx = self_col.get_map_field_to_index("end_idx")
for i in range(len(self_col.values)):
start_index = start_idx[i]
end_index = end_idx[i]
_shape_kwargs = substitute_templates(
shape_kwargs,
context=dict(
self_col=self_col,
i=i,
record=self_col.values[i],
start_index=start_index,
end_index=end_index,
xref=xref,
yref=yref,
x_domain=x_domain,
y_domain=y_domain,
close=close,
ohlc=ohlc,
),
eval_id="shape_kwargs",
)
_shape_kwargs = merge_dicts(
dict(
type="rect",
xref=xref,
yref="paper",
x0=start_index,
y0=y_domain[0],
x1=end_index,
y1=y_domain[1],
fillcolor="gray",
opacity=0.15,
layer="below",
line_width=0,
),
_shape_kwargs,
)
fig.add_shape(**_shape_kwargs)
return fig
def plot(
self,
column: tp.Optional[tp.Label] = None,
top_n: tp.Optional[int] = None,
plot_ohlc: tp.Union[bool, tp.Frame] = True,
plot_close: tp.Union[bool, tp.Series] = True,
plot_markers: bool = True,
plot_zones: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
start_trace_kwargs: tp.KwargsLike = None,
end_trace_kwargs: tp.KwargsLike = None,
open_shape_kwargs: tp.KwargsLike = None,
closed_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
return_close: bool = False,
**layout_kwargs,
) -> tp.Union[tp.BaseFigure, tp.Tuple[tp.BaseFigure, tp.Series]]:
"""Plot ranges.
Args:
column (str): Name of the column to plot.
top_n (int): Filter top N range records by maximum duration.
plot_ohlc (bool or DataFrame): Whether to plot OHLC.
plot_close (bool or Series): Whether to plot close.
plot_markers (bool): Whether to plot markers.
plot_zones (bool): Whether to plot zones.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Ranges.close`.
start_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for start values.
end_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for end values.
open_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for open zones.
closed_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for closed zones.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
return_close (bool): Whether to return the close series along with the figure.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> price = pd.Series(
... [1, 2, 1, 2, 3, 2, 1, 2, 3],
... index=pd.date_range("2020", periods=9),
... )
>>> vbt.Ranges.from_array(price >= 2).plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if top_n is not None:
self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n))
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if start_trace_kwargs is None:
start_trace_kwargs = {}
if end_trace_kwargs is None:
end_trace_kwargs = {}
if open_shape_kwargs is None:
open_shape_kwargs = {}
if closed_shape_kwargs is None:
closed_shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if isinstance(plot_ohlc, bool):
if (
self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
else:
ohlc = None
else:
ohlc = plot_ohlc
plot_ohlc = True
if isinstance(plot_close, bool):
if ohlc is not None:
close = ohlc.vbt.ohlcv.close
else:
close = self_col.close
else:
close = plot_close
plot_close = True
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
y_domain = get_domain(yref, fig)
# Plot OHLC/close
plotting_ohlc = False
if plot_ohlc and ohlc is not None:
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
plotting_ohlc = True
elif plot_close and close is not None:
fig = close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True)
if plotting_ohlc and self_col.open is not None:
start_val = self_col.open.loc[start_idx]
elif close is not None:
start_val = close.loc[start_idx]
else:
start_val = np.full(len(start_idx), 0)
end_idx = self_col.get_map_field_to_index("end_idx")
if close is not None:
end_val = close.loc[end_idx]
else:
end_val = np.full(len(end_idx), 0)
status = self_col.get_field_arr("status")
if plot_markers:
# Plot start markers
start_customdata, start_hovertemplate = self_col.prepare_customdata(incl_fields=["id", "start_idx"])
_start_trace_kwargs = merge_dicts(
dict(
x=start_idx,
y=start_val,
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["blue"],
size=7,
line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["blue"])),
),
name="Start",
customdata=start_customdata,
hovertemplate=start_hovertemplate,
),
start_trace_kwargs,
)
start_scatter = go.Scatter(**_start_trace_kwargs)
fig.add_trace(start_scatter, **add_trace_kwargs)
closed_mask = status == enums.RangeStatus.Closed
if closed_mask.any():
if plot_markers:
# Plot end markers
closed_end_customdata, closed_end_hovertemplate = self_col.prepare_customdata(mask=closed_mask)
_end_trace_kwargs = merge_dicts(
dict(
x=end_idx[closed_mask],
y=end_val[closed_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["green"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"])
),
),
name="Closed",
customdata=closed_end_customdata,
hovertemplate=closed_end_hovertemplate,
),
end_trace_kwargs,
)
closed_end_scatter = go.Scatter(**_end_trace_kwargs)
fig.add_trace(closed_end_scatter, **add_trace_kwargs)
open_mask = status == enums.RangeStatus.Open
if open_mask.any():
if plot_markers:
# Plot end markers
open_end_customdata, open_end_hovertemplate = self_col.prepare_customdata(
excl_fields=["end_idx"], mask=open_mask
)
_end_trace_kwargs = merge_dicts(
dict(
x=end_idx[open_mask],
y=end_val[open_mask],
mode="markers",
marker=dict(
symbol="diamond",
color=plotting_cfg["contrast_color_schema"]["orange"],
size=7,
line=dict(
width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["orange"])
),
),
name="Open",
customdata=open_end_customdata,
hovertemplate=open_end_hovertemplate,
),
end_trace_kwargs,
)
open_end_scatter = go.Scatter(**_end_trace_kwargs)
fig.add_trace(open_end_scatter, **add_trace_kwargs)
if plot_zones:
# Plot closed range zones
self_col.status_closed.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["green"]),
closed_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
# Plot open range zones
self_col.status_open.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["orange"]),
open_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
if return_close:
return fig, close
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Ranges.plots`.
Merges `vectorbtpro.records.base.Records.plots_defaults` and
`plots` from `vectorbtpro._settings.ranges`."""
from vectorbtpro._settings import settings
ranges_plots_cfg = settings["ranges"]["plots"]
return merge_dicts(Records.plots_defaults.__get__(self), ranges_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="Ranges",
check_is_not_grouped=True,
plot_func="plot",
tags="ranges",
)
),
)
@property
def subplots(self) -> Config:
return self._subplots
Ranges.override_field_config_doc(__pdoc__)
Ranges.override_metrics_doc(__pdoc__)
Ranges.override_subplots_doc(__pdoc__)
# ############# Pattern ranges ############# #
PatternRangesT = tp.TypeVar("PatternRangesT", bound="PatternRanges")
@define
class PSC(DefineMixin):
"""Class that represents a pattern search config.
Every field will be resolved into a format suitable for Numba.
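Usage:
* A minimal illustrative sketch (field values are arbitrary; assumes `PSC` is exported
at the package root, as is typical for this library):
```pycon
>>> psc = vbt.PSC(
... pattern=[1, 2, 3, 2, 1],
... window=20,
... max_window=40,
... min_similarity=0.9,
... name="triangle"
... )
```
"""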
pattern: tp.Union[tp.ArrayLike] = define.required_field()
"""Flexible pattern array.
Can be smaller or bigger than the source array; in such a case, the values of the smaller array
will be "stretched" by interpolation of the type in `PSC.interp_mode`."""
window: tp.Optional[int] = define.optional_field()
"""Minimum window.
Defaults to the length of `PSC.pattern`."""
max_window: tp.Optional[int] = define.optional_field()
"""Maximum window (including)."""
row_select_prob: tp.Union[float] = define.optional_field()
"""Row selection probability."""
window_select_prob: tp.Union[float] = define.optional_field()
"""Window selection probability."""
roll_forward: tp.Union[bool] = define.optional_field()
"""Whether to roll windows to the left of the current row, otherwise to the right."""
interp_mode: tp.Union[int, str] = define.optional_field()
"""Interpolation mode. See `vectorbtpro.generic.enums.InterpMode`."""
rescale_mode: tp.Union[int, str] = define.optional_field()
"""Rescaling mode. See `vectorbtpro.generic.enums.RescaleMode`."""
vmin: tp.Union[float] = define.optional_field()
"""Minimum value of any window. Should only be used when the array has fixed bounds.
Used in rescaling using `RescaleMode.MinMax` and checking against `PSC.min_pct_change` and `PSC.max_pct_change`.
If `np.nan`, gets calculated dynamically."""
vmax: tp.Union[float] = define.optional_field()
"""Maximum value of any window. Should only be used when the array has fixed bounds.
Used in rescaling using `RescaleMode.MinMax` and checking against `PSC.min_pct_change` and `PSC.max_pct_change`.
If `np.nan`, gets calculated dynamically."""
pmin: tp.Union[float] = define.optional_field()
"""Value to be considered as the minimum of `PSC.pattern`.
Used in rescaling using `RescaleMode.MinMax` and calculating the maximum distance at each point
if `PSC.max_error_as_maxdist` is disabled.
If `np.nan`, gets calculated dynamically."""
pmax: tp.Union[float] = define.optional_field()
"""Value to be considered as the maximum of `PSC.pattern`.
Used in rescaling using `RescaleMode.MinMax` and calculating the maximum distance at each point
if `PSC.max_error_as_maxdist` is disabled.
If `np.nan`, gets calculated dynamically."""
invert: tp.Union[bool] = define.optional_field()
"""Whether to invert the pattern vertically."""
error_type: tp.Union[int, str] = define.optional_field()
"""Error type. See `vectorbtpro.generic.enums.ErrorType`."""
distance_measure: tp.Union[int, str] = define.optional_field()
"""Distance measure. See `vectorbtpro.generic.enums.DistanceMeasure`."""
max_error: tp.Union[tp.ArrayLike] = define.optional_field()
"""Maximum error at each point. Can be provided as a flexible array.
If `max_error` is an array, it must be of the same size as the pattern array.
It also should be provided within the same scale as the pattern."""
max_error_interp_mode: tp.Union[None, int, str] = define.optional_field()
"""Interpolation mode for `PSC.max_error`. See `vectorbtpro.generic.enums.InterpMode`.
If None, defaults to `PSC.interp_mode`."""
max_error_as_maxdist: tp.Union[bool] = define.optional_field()
"""Whether `PSC.max_error` should be used as the maximum distance at each point.
If False, crossing `PSC.max_error` will set the distance to the maximum distance
based on `PSC.pmin`, `PSC.pmax`, and the pattern value at that point.
If True and any of the points in a window is `np.nan`, the point will be skipped."""
max_error_strict: tp.Union[bool] = define.optional_field()
"""Whether crossing `PSC.max_error` even once should yield the similarity of `np.nan`."""
min_pct_change: tp.Union[float] = define.optional_field()
"""Minimum percentage change of the window to stay a candidate for search.
If any window doesn't cross this mark, its similarity becomes `np.nan`."""
max_pct_change: tp.Union[float] = define.optional_field()
"""Maximum percentage change of the window to stay a candidate for search.
If any window crosses this mark, its similarity becomes `np.nan`."""
min_similarity: tp.Union[float] = define.optional_field()
"""Minimum similarity.
If any window doesn't cross this mark, its similarity becomes `np.nan`."""
minp: tp.Optional[int] = define.optional_field()
"""Minimum number of observations in price window required to have a value."""
overlap_mode: tp.Union[int, str] = define.optional_field()
"""Overlapping mode. See `vectorbtpro.generic.enums.OverlapMode`."""
max_records: tp.Optional[int] = define.optional_field()
"""Maximum number of records expected to be filled.
Set to avoid creating empty arrays larger than needed."""
name: tp.Optional[str] = define.field(default=None)
"""Name of the config."""
def __eq__(self, other):
return checks.is_deep_equal(self, other)
def __hash__(self):
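# Convert array-valued fields to plain tuples so the config becomes hashable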
dct = self.asdict()
if isinstance(dct["pattern"], np.ndarray):
dct["pattern"] = tuple(dct["pattern"])
else:
dct["pattern"] = (dct["pattern"],)
if isinstance(dct["max_error"], np.ndarray):
dct["max_error"] = tuple(dct["max_error"])
else:
dct["max_error"] = (dct["max_error"],)
return hash(tuple(dct.items()))
pattern_ranges_field_config = ReadonlyConfig(
dict(
dtype=enums.pattern_range_dt,
settings=dict(
id=dict(title="Pattern Range Id"),
similarity=dict(title="Similarity"),
),
)
)
"""_"""
__pdoc__[
"pattern_ranges_field_config"
] = f"""Field config for `PatternRanges`.
```python
{pattern_ranges_field_config.prettify()}
```
"""
@attach_fields
@override_field_config(pattern_ranges_field_config)
class PatternRanges(Ranges):
"""Extends `Ranges` for working with range records generated from pattern search."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def resolve_search_config(cls, search_config: tp.Union[None, dict, PSC] = None, **kwargs) -> PSC:
"""Resolve search config for `PatternRanges.from_pattern_search`.
Converts array-like objects into arrays and enums into integers.
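Usage:
* A rough sketch (argument values are arbitrary; assumes `PatternRanges` is exported
at the package root): string enums such as `interp_mode` get mapped to integers,
the pattern becomes a 1-dim NumPy array, and unset fields fall back to the defaults
of `PatternRanges.from_pattern_search`:
```pycon
>>> psc = vbt.PatternRanges.resolve_search_config(dict(
... pattern=[1, 2, 3, 2, 1],
... interp_mode="linear",
... ))
```
"""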
if search_config is None:
search_config = dict()
if isinstance(search_config, dict):
search_config = PSC(**search_config)
search_config = search_config.asdict()
defaults = {}
for k, v in get_func_kwargs(cls.from_pattern_search).items():
if k in search_config:
defaults[k] = v
defaults = merge_dicts(defaults, kwargs)
for k, v in search_config.items():
if v is MISSING:
v = defaults[k]
if k == "pattern":
if v is None:
raise ValueError("Must provide pattern")
v = to_1d_array(v)
elif k == "max_error":
v = to_1d_array(v)
elif k == "interp_mode":
v = map_enum_fields(v, enums.InterpMode)
elif k == "rescale_mode":
v = map_enum_fields(v, enums.RescaleMode)
elif k == "error_type":
v = map_enum_fields(v, enums.ErrorType)
elif k == "distance_measure":
v = map_enum_fields(v, enums.DistanceMeasure)
elif k == "max_error_interp_mode":
if v is None:
v = search_config["interp_mode"]
else:
v = map_enum_fields(v, enums.InterpMode)
elif k == "overlap_mode":
v = map_enum_fields(v, enums.OverlapMode)
search_config[k] = v
return PSC(**search_config)
@classmethod
def from_pattern_search(
cls: tp.Type[PatternRangesT],
arr: tp.ArrayLike,
pattern: tp.Union[Param, tp.ArrayLike] = None,
window: tp.Union[Param, None, int] = None,
max_window: tp.Union[Param, None, int] = None,
row_select_prob: tp.Union[Param, float] = 1.0,
window_select_prob: tp.Union[Param, float] = 1.0,
roll_forward: tp.Union[Param, bool] = False,
interp_mode: tp.Union[Param, int, str] = "mixed",
rescale_mode: tp.Union[Param, int, str] = "minmax",
vmin: tp.Union[Param, float] = np.nan,
vmax: tp.Union[Param, float] = np.nan,
pmin: tp.Union[Param, float] = np.nan,
pmax: tp.Union[Param, float] = np.nan,
invert: bool = False,
error_type: tp.Union[Param, int, str] = "absolute",
distance_measure: tp.Union[Param, int, str] = "mae",
max_error: tp.Union[Param, tp.ArrayLike] = np.nan,
max_error_interp_mode: tp.Union[Param, None, int, str] = None,
max_error_as_maxdist: tp.Union[Param, bool] = False,
max_error_strict: tp.Union[Param, bool] = False,
min_pct_change: tp.Union[Param, float] = np.nan,
max_pct_change: tp.Union[Param, float] = np.nan,
min_similarity: tp.Union[Param, float] = 0.85,
minp: tp.Union[Param, None, int] = None,
overlap_mode: tp.Union[Param, int, str] = "disallow",
max_records: tp.Union[Param, None, int] = None,
random_subset: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
search_configs: tp.Optional[tp.Sequence[tp.MaybeSequence[PSC]]] = None,
jitted: tp.JittedOption = None,
execute_kwargs: tp.KwargsLike = None,
attach_as_close: bool = True,
clean_index_kwargs: tp.KwargsLike = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PatternRangesT:
"""Build `PatternRanges` from all occurrences of a pattern in an array.
Searches for parameters of the type `vectorbtpro.utils.params.Param`, and if found, broadcasts
and combines them using `vectorbtpro.utils.params.combine_params`. Then, converts them
into a list of search configurations. If none of such parameters was found among the passed arguments,
builds one search configuration using the passed arguments. If `search_configs` is not None, uses it
instead. In all cases, it uses the defaults defined in the signature of this method to augment
search configurations. For example, passing `min_similarity` of 95% will use it in all search
configurations except where it was explicitly overridden.
Argument `search_configs` must be provided as a sequence of `PSC` instances.
If any element is a list of `PSC` instances itself, it will be used per column in `arr`,
otherwise per entire `arr`. Each configuration will be resolved using `PatternRanges.resolve_search_config`
to prepare arguments for the use in Numba.
After all the search configurations have been resolved, uses `vectorbtpro.utils.execution.execute`
to loop over each configuration and execute it using `vectorbtpro.generic.nb.records.find_pattern_1d_nb`.
The results are then concatenated into a single records array and wrapped with `PatternRanges`.
If `attach_as_close` is True, will attach `arr` as `close`.
`**kwargs` will be passed to `PatternRanges.__init__`.
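Usage:
* A minimal sketch (data and pattern are arbitrary) that searches a price series for a
triangle-like shape; to test several values of an argument, wrap it in
`vectorbtpro.utils.params.Param`, for example `min_similarity=vbt.Param([0.8, 0.9])`:
```pycon
>>> price = pd.Series(
... [1, 2, 3, 2, 1, 2, 3, 4, 3, 2],
... index=pd.date_range("2020", periods=10),
... )
>>> pattern_ranges = vbt.PatternRanges.from_pattern_search(
... price,
... pattern=[1, 2, 3, 2, 1],
... min_similarity=0.8,
... )
```
"""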
if seed is not None:
set_seed(seed)
if clean_index_kwargs is None:
clean_index_kwargs = {}
arr = to_pd_array(arr)
arr_2d = to_2d_array(arr)
arr_wrapper = ArrayWrapper.from_obj(arr)
psc_keys = [a.name for a in PSC.fields if a.name != "name"]
method_locals = locals()
method_locals = {k: v for k, v in method_locals.items() if k in psc_keys}
# Flatten search configs
flat_search_configs = []
psc_names = []
psc_names_none = True
n_configs = 0
if search_configs is not None:
for maybe_search_config in search_configs:
if isinstance(maybe_search_config, dict):
maybe_search_config = PSC(**maybe_search_config)
if isinstance(maybe_search_config, PSC):
for col in range(arr_2d.shape[1]):
flat_search_configs.append(maybe_search_config)
if maybe_search_config.name is not None:
psc_names.append(maybe_search_config.name)
psc_names_none = False
else:
psc_names.append(n_configs)
n_configs += 1
else:
if len(maybe_search_config) != arr_2d.shape[1]:
raise ValueError("Sub-list with PSC instances must match the number of columns")
for col, search_config in enumerate(maybe_search_config):
if isinstance(search_config, dict):
search_config = PSC(**search_config)
flat_search_configs.append(search_config)
if search_config.name is not None:
psc_names.append(search_config.name)
psc_names_none = False
else:
psc_names.append(n_configs)
n_configs += 1
# Combine parameters
param_dct = {}
for k, v in method_locals.items():
if k in psc_keys and isinstance(v, Param):
param_dct[k] = v
param_columns = None
if len(param_dct) > 0:
param_product, param_columns = combine_params(
param_dct,
random_subset=random_subset,
clean_index_kwargs=clean_index_kwargs,
)
if len(flat_search_configs) == 0:
flat_search_configs = []
for i in range(len(param_columns)):
search_config = dict()
for k, v in param_product.items():
search_config[k] = v[i]
for col in range(arr_2d.shape[1]):
flat_search_configs.append(PSC(**search_config))
else:
new_flat_search_configs = []
for i in range(len(param_columns)):
for search_config in flat_search_configs:
new_search_config = dict()
for k, v in search_config.asdict().items():
if v is not MISSING:
if k in param_product:
raise ValueError(f"Parameter '{k}' is re-defined in a search configuration")
new_search_config[k] = v
if k in param_product:
new_search_config[k] = param_product[k][i]
new_flat_search_configs.append(PSC(**new_search_config))
flat_search_configs = new_flat_search_configs
# Create config from arguments if empty
if len(flat_search_configs) == 0:
single_group = True
for col in range(arr_2d.shape[1]):
flat_search_configs.append(PSC())
else:
single_group = False
# Prepare function and arguments
tasks = []
func = jit_reg.resolve_option(nb.find_pattern_1d_nb, jitted)
def_func_kwargs = get_func_kwargs(func)
new_search_configs = []
for c in range(len(flat_search_configs)):
func_kwargs = {
"col": c,
"arr": arr_2d[:, c % arr_2d.shape[1]],
}
new_search_config = cls.resolve_search_config(flat_search_configs[c], **method_locals)
for k, v in new_search_config.asdict().items():
if k == "name":
continue
if isinstance(v, Param):
raise TypeError(f"Cannot use Param inside search configs")
if k in def_func_kwargs:
if v is not def_func_kwargs[k]:
func_kwargs[k] = v
else:
func_kwargs[k] = v
tasks.append(Task(func, **func_kwargs))
new_search_configs.append(new_search_config)
# Build column hierarchy
n_config_params = len(psc_names) // arr_2d.shape[1]
if param_columns is not None:
if n_config_params == 0 or (n_config_params == 1 and psc_names_none):
new_columns = combine_indexes((param_columns, arr_wrapper.columns), **clean_index_kwargs)
else:
search_config_index = pd.Index(psc_names, name="search_config")
base_columns = stack_indexes(
(search_config_index, tile_index(arr_wrapper.columns, n_config_params)),
**clean_index_kwargs,
)
new_columns = combine_indexes((param_columns, base_columns), **clean_index_kwargs)
else:
if n_config_params == 0 or (n_config_params == 1 and psc_names_none):
new_columns = arr_wrapper.columns
else:
search_config_index = pd.Index(psc_names, name="search_config")
new_columns = stack_indexes(
(search_config_index, tile_index(arr_wrapper.columns, n_config_params)),
**clean_index_kwargs,
)
# Execute each configuration
execute_kwargs = merge_dicts(dict(show_progress=False if single_group else None), execute_kwargs)
result_list = execute(tasks, keys=new_columns, **execute_kwargs)
records_arr = np.concatenate(result_list)
# Wrap with class
wrapper = ArrayWrapper(
**merge_dicts(
dict(
index=arr_wrapper.index,
columns=new_columns,
),
wrapper_kwargs,
)
)
if attach_as_close and "close" not in kwargs:
kwargs["close"] = arr
if "open" in kwargs and kwargs["open"] is not None:
kwargs["open"] = to_2d_array(kwargs["open"])
kwargs["open"] = tile(kwargs["open"], len(wrapper.columns) // kwargs["open"].shape[1])
if "high" in kwargs and kwargs["high"] is not None:
kwargs["high"] = to_2d_array(kwargs["high"])
kwargs["high"] = tile(kwargs["high"], len(wrapper.columns) // kwargs["high"].shape[1])
if "low" in kwargs and kwargs["low"] is not None:
kwargs["low"] = to_2d_array(kwargs["low"])
kwargs["low"] = tile(kwargs["low"], len(wrapper.columns) // kwargs["low"].shape[1])
if "close" in kwargs and kwargs["close"] is not None:
kwargs["close"] = to_2d_array(kwargs["close"])
kwargs["close"] = tile(kwargs["close"], len(wrapper.columns) // kwargs["close"].shape[1])
return cls(wrapper, records_arr, new_search_configs, **kwargs)
def with_delta(self, *args, **kwargs):
"""Pass self to `Ranges.from_delta` but with the index set to the last index."""
if "idx_field_or_arr" not in kwargs:
kwargs["idx_field_or_arr"] = self.last_idx.values
return Ranges.from_delta(self, *args, **kwargs)
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[PatternRangesT],
*objs: tp.MaybeTuple[PatternRangesT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `PatternRanges` after stacking along columns."""
kwargs = Ranges.resolve_row_stack_kwargs(*objs, **kwargs)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PatternRanges):
raise TypeError("Each object to be merged must be an instance of PatternRanges")
new_search_configs = []
for obj in objs:
if len(obj.search_configs) == 1:
new_search_configs.append(obj.search_configs * len(kwargs["wrapper"].columns))
else:
new_search_configs.append(obj.search_configs)
if len(new_search_configs) >= 2:
if new_search_configs[-1] != new_search_configs[0]:
raise ValueError(f"Objects to be merged must have compatible PSC instances. Pass to override.")
kwargs["search_configs"] = new_search_configs[0]
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[PatternRangesT],
*objs: tp.MaybeTuple[PatternRangesT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `PatternRanges` after stacking along columns."""
kwargs = Ranges.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs.pop("reindex_kwargs", None)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PatternRanges):
raise TypeError("Each object to be merged must be an instance of PatternRanges")
kwargs["search_configs"] = [search_config for obj in objs for search_config in obj.search_configs]
return kwargs
def __init__(
self,
wrapper: ArrayWrapper,
records_arr: tp.RecordArray,
search_configs: tp.List[PSC],
**kwargs,
) -> None:
Ranges.__init__(
self,
wrapper,
records_arr,
search_configs=search_configs,
**kwargs,
)
self._search_configs = search_configs
def indexing_func(self: PatternRangesT, *args, ranges_meta: tp.DictLike = None, **kwargs) -> PatternRangesT:
"""Perform indexing on `PatternRanges`."""
if ranges_meta is None:
ranges_meta = Ranges.indexing_func_meta(self, *args, **kwargs)
col_idxs = ranges_meta["wrapper_meta"]["col_idxs"]
if not isinstance(col_idxs, slice):
col_idxs = to_1d_array(col_idxs)
col_idxs = np.arange(self.wrapper.shape_2d[1])[col_idxs]
new_search_configs = []
for i in col_idxs:
new_search_configs.append(self.search_configs[i])
return self.replace(
wrapper=ranges_meta["wrapper_meta"]["new_wrapper"],
records_arr=ranges_meta["new_records_arr"],
search_configs=new_search_configs,
open=ranges_meta["open"],
high=ranges_meta["high"],
low=ranges_meta["low"],
close=ranges_meta["close"],
)
@property
def search_configs(self) -> tp.List[PSC]:
"""List of `PSC` instances, one per column."""
return self._search_configs
# ############# Stats ############# #
_metrics: tp.ClassVar[Config] = HybridConfig(
{
**Ranges.metrics,
"similarity": dict(
title="Similarity",
calc_func="similarity.describe",
post_calc_func=lambda self, out, settings: {
"Min": out.loc["min"],
"Median": out.loc["50%"],
"Max": out.loc["max"],
},
tags=["pattern_ranges", "similarity"],
),
}
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plots ############# #
def plot(
self,
column: tp.Optional[tp.Label] = None,
top_n: tp.Optional[int] = None,
fit_ranges: tp.Union[bool, tp.MaybeSequence[int]] = False,
plot_patterns: bool = True,
plot_max_error: bool = False,
fill_distance: bool = True,
pattern_trace_kwargs: tp.KwargsLike = None,
lower_max_error_trace_kwargs: tp.KwargsLike = None,
upper_max_error_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot pattern ranges.
Based on `Ranges.plot` and `vectorbtpro.generic.accessors.GenericSRAccessor.plot_pattern`.
Args:
column (str): Name of the column to plot.
top_n (int): Filter top N range records by maximum duration.
fit_ranges (bool, int, or sequence of int): Whether or which range records to fit.
True to fit to all range records, integer or a sequence of such to fit to specific range records.
plot_patterns (bool): Whether to plot `PSC.pattern`.
plot_max_error (bool): Whether to plot `PSC.max_error`.
fill_distance (bool): Whether to fill the space between close and pattern.
Visible for every interpolation mode except discrete.
pattern_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for pattern.
lower_max_error_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for lower max error.
upper_max_error_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for upper max error.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**kwargs: Keyword arguments passed to `Ranges.plot`.
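Usage:
* A minimal sketch (assumes `pattern_ranges` holds a single column, e.g. as built in the
example of `PatternRanges.from_pattern_search`):
```pycon
>>> pattern_ranges.plot(plot_patterns=True).show()
```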
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if top_n is not None:
self_col = self_col.apply_mask(self_col.duration.top_n_mask(top_n))
search_config = self_col.search_configs[0]
if isinstance(fit_ranges, bool) and not fit_ranges:
fit_ranges = None
if fit_ranges is not None:
if fit_ranges is True:
self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][-1] + 1]
elif checks.is_int(fit_ranges):
self_col = self_col.apply_mask(self_col.id_arr == fit_ranges)
self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][0] + 1]
else:
self_col = self_col.apply_mask(np.isin(self_col.id_arr, fit_ranges))
self_col = self_col.iloc[self_col.values["start_idx"][0] : self_col.values["end_idx"][-1] + 1]
if pattern_trace_kwargs is None:
pattern_trace_kwargs = {}
if lower_max_error_trace_kwargs is None:
lower_max_error_trace_kwargs = {}
if upper_max_error_trace_kwargs is None:
upper_max_error_trace_kwargs = {}
open_shape_kwargs = merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["blue"]),
kwargs.pop("open_shape_kwargs", None),
)
closed_shape_kwargs = merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["blue"]),
kwargs.pop("closed_shape_kwargs", None),
)
fig, close = Ranges.plot(
self_col,
return_close=True,
open_shape_kwargs=open_shape_kwargs,
closed_shape_kwargs=closed_shape_kwargs,
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
**kwargs,
)
if self_col.count() > 0:
# Extract information
start_idx = self_col.get_map_field_to_index("start_idx", minus_one_to_zero=True)
end_idx = self_col.get_map_field_to_index("end_idx")
status = self_col.get_field_arr("status")
if plot_patterns:
# Plot pattern
for r in range(len(start_idx)):
_start_idx = start_idx[r]
_end_idx = end_idx[r]
if close is None:
raise ValueError("Must provide close to overlay patterns")
arr_sr = close.loc[_start_idx:_end_idx]
if status[r] == enums.RangeStatus.Closed:
arr_sr = arr_sr.iloc[:-1]
if fill_distance:
obj_trace_kwargs = dict(
line=dict(color="rgba(0, 0, 0, 0)", width=0),
opacity=0,
hoverinfo="skip",
showlegend=False,
name=None,
)
else:
obj_trace_kwargs = None
_pattern_trace_kwargs = merge_dicts(
dict(
legendgroup="pattern",
showlegend=r == 0,
),
pattern_trace_kwargs,
)
_lower_max_error_trace_kwargs = merge_dicts(
dict(
legendgroup="max_error",
showlegend=r == 0,
),
lower_max_error_trace_kwargs,
)
_upper_max_error_trace_kwargs = merge_dicts(
dict(
legendgroup="max_error",
showlegend=False,
),
upper_max_error_trace_kwargs,
)
fig = arr_sr.vbt.plot_pattern(
pattern=search_config.pattern,
interp_mode=search_config.interp_mode,
rescale_mode=search_config.rescale_mode,
vmin=search_config.vmin,
vmax=search_config.vmax,
pmin=search_config.pmin,
pmax=search_config.pmax,
invert=search_config.invert,
error_type=search_config.error_type,
max_error=search_config.max_error if plot_max_error else np.nan,
max_error_interp_mode=search_config.max_error_interp_mode,
plot_obj=fill_distance,
fill_distance=fill_distance,
obj_trace_kwargs=obj_trace_kwargs,
pattern_trace_kwargs=_pattern_trace_kwargs,
lower_max_error_trace_kwargs=_lower_max_error_trace_kwargs,
upper_max_error_trace_kwargs=_upper_max_error_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
PatternRanges.override_field_config_doc(__pdoc__)
PatternRanges.override_metrics_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Mixin class for working with simulation ranges."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.indexing import AutoIdxr
from vectorbtpro.base.reshaping import broadcast_array_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb
from vectorbtpro.utils import checks
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.decorators import hybrid_method
SimRangeMixinT = tp.TypeVar("SimRangeMixinT", bound="SimRangeMixin")
class SimRangeMixin(Base):
"""Mixin class for working with simulation ranges.
Should be subclassed by a subclass of `vectorbtpro.base.wrapping.Wrapping`.
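Usage:
* A rough sketch (wrapper contents are arbitrary) of resolving a human-readable start
point into a row index at the class level, using the module path of this file:
```pycon
>>> from vectorbtpro.generic.sim_range import SimRangeMixin
>>> wrapper = vbt.ArrayWrapper.from_obj(
... pd.Series(range(5), index=pd.date_range("2020", periods=5))
... )
>>> sim_start = SimRangeMixin.resolve_sim_start(sim_start="2020-01-03", wrapper=wrapper)
```
"""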
@classmethod
def row_stack_sim_start(
cls,
new_wrapper: ArrayWrapper,
*objs: tp.MaybeTuple[SimRangeMixinT],
) -> tp.Optional[tp.ArrayLike]:
"""Row-stack simulation start."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if objs[0]._sim_start is not None:
new_sim_start = broadcast_array_to(objs[0]._sim_start, len(new_wrapper.columns))
else:
new_sim_start = None
for obj in objs[1:]:
if obj._sim_start is not None:
raise ValueError("Objects to be merged (except the first one) must have 'sim_start=None'")
return new_sim_start
@classmethod
def row_stack_sim_end(
cls,
new_wrapper: ArrayWrapper,
*objs: tp.MaybeTuple[SimRangeMixinT],
) -> tp.Optional[tp.ArrayLike]:
"""Row-stack simulation end."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if objs[-1]._sim_end is not None:
new_sim_end = len(new_wrapper.index) - len(objs[-1].wrapper.index) + objs[-1]._sim_end
new_sim_end = broadcast_array_to(new_sim_end, len(new_wrapper.columns))
else:
new_sim_end = None
for obj in objs[:-1]:
if obj._sim_end is not None:
raise ValueError("Objects to be merged (except the last one) must have 'sim_end=None'")
return new_sim_end
@classmethod
def column_stack_sim_start(
cls,
new_wrapper: ArrayWrapper,
*objs: tp.MaybeTuple[SimRangeMixinT],
) -> tp.Optional[tp.ArrayLike]:
"""Column-stack simulation start."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
stack_sim_start_objs = False
for obj in objs:
if obj._sim_start is not None:
stack_sim_start_objs = True
break
if stack_sim_start_objs:
obj_sim_starts = []
for obj in objs:
obj_sim_start = np.empty(len(obj._sim_start), dtype=int_)
for i in range(len(obj._sim_start)):
if obj._sim_start[i] == 0:
obj_sim_start[i] = 0
elif obj._sim_start[i] == len(obj.wrapper.index):
obj_sim_start[i] = len(new_wrapper.index)
else:
_obj_sim_start = new_wrapper.index.get_indexer([obj.wrapper.index[obj._sim_start[i]]])[0]
if _obj_sim_start == -1:
_obj_sim_start = 0
obj_sim_start[i] = _obj_sim_start
obj_sim_starts.append(obj_sim_start)
new_sim_start = new_wrapper.concat_arrs(*obj_sim_starts, wrap=False)
else:
new_sim_start = None
return new_sim_start
@classmethod
def column_stack_sim_end(
cls,
new_wrapper: ArrayWrapper,
*objs: tp.MaybeTuple[SimRangeMixinT],
) -> tp.Optional[tp.ArrayLike]:
"""Column-stack simulation end."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
stack_sim_end_objs = False
for obj in objs:
if obj._sim_end is not None:
stack_sim_end_objs = True
break
if stack_sim_end_objs:
obj_sim_ends = []
for obj in objs:
obj_sim_end = np.empty(len(obj._sim_end), dtype=int_)
for i in range(len(obj._sim_end)):
if obj._sim_end[i] == 0:
obj_sim_end[i] = 0
elif obj._sim_end[i] == len(obj.wrapper.index):
obj_sim_end[i] = len(new_wrapper.index)
else:
_obj_sim_end = new_wrapper.index.get_indexer([obj.wrapper.index[obj._sim_end[i]]])[0]
if _obj_sim_end == -1:
_obj_sim_end = 0
obj_sim_end[i] = _obj_sim_end
obj_sim_ends.append(obj_sim_end)
new_sim_end = new_wrapper.concat_arrs(*obj_sim_ends, wrap=False)
else:
new_sim_end = None
return new_sim_end
def __init__(
self,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
) -> None:
sim_start = type(self).resolve_sim_start(sim_start=sim_start, wrapper=self.wrapper, group_by=False)
sim_end = type(self).resolve_sim_end(sim_end=sim_end, wrapper=self.wrapper, group_by=False)
self._sim_start = sim_start
self._sim_end = sim_end
def sim_start_indexing_func(self, wrapper_meta: dict) -> tp.Optional[tp.ArrayLike]:
"""Indexing function for simulation start."""
if self._sim_start is None:
new_sim_start = None
elif not wrapper_meta["rows_changed"]:
new_sim_start = self._sim_start
else:
if checks.is_int(wrapper_meta["row_idxs"]):
new_sim_start = self._sim_start - wrapper_meta["row_idxs"]
elif isinstance(wrapper_meta["row_idxs"], slice):
new_sim_start = self._sim_start - wrapper_meta["row_idxs"].start
else:
new_sim_start = self._sim_start - wrapper_meta["row_idxs"][0]
new_sim_start = np.clip(new_sim_start, 0, len(wrapper_meta["new_wrapper"].index))
return new_sim_start
def sim_end_indexing_func(self, wrapper_meta: dict) -> tp.Optional[tp.ArrayLike]:
"""Indexing function for simulation end."""
if self._sim_end is None:
new_sim_end = None
elif not wrapper_meta["rows_changed"]:
new_sim_end = self._sim_end
else:
if checks.is_int(wrapper_meta["row_idxs"]):
new_sim_end = self._sim_end - wrapper_meta["row_idxs"]
elif isinstance(wrapper_meta["row_idxs"], slice):
new_sim_end = self._sim_end - wrapper_meta["row_idxs"].start
else:
new_sim_end = self._sim_end - wrapper_meta["row_idxs"][0]
new_sim_end = np.clip(new_sim_end, 0, len(wrapper_meta["new_wrapper"].index))
return new_sim_end
def resample_sim_start(self, new_wrapper: ArrayWrapper) -> tp.Optional[tp.ArrayLike]:
"""Resample simulation start."""
if self._sim_start is not None:
new_sim_start = np.empty(len(self._sim_start), dtype=int_)
for i in range(len(self._sim_start)):
if self._sim_start[i] == 0:
new_sim_start[i] = 0
elif self._sim_start[i] == len(self.wrapper.index):
new_sim_start[i] = len(new_wrapper.index)
else:
_new_sim_start = new_wrapper.index.get_indexer(
[self.wrapper.index[self._sim_start[i]]],
method="ffill",
)[0]
if _new_sim_start == -1:
_new_sim_start = 0
new_sim_start[i] = _new_sim_start
else:
new_sim_start = None
return new_sim_start
def resample_sim_end(self, new_wrapper: ArrayWrapper) -> tp.Optional[tp.ArrayLike]:
"""Resample simulation end."""
if self._sim_end is not None:
new_sim_end = np.empty(len(self._sim_end), dtype=int_)
for i in range(len(self._sim_end)):
if self._sim_end[i] == 0:
new_sim_end[i] = 0
elif self._sim_end[i] == len(self.wrapper.index):
new_sim_end[i] = len(new_wrapper.index)
else:
_new_sim_end = new_wrapper.index.get_indexer(
[self.wrapper.index[self._sim_end[i]]],
method="bfill",
)[0]
if _new_sim_end == -1:
_new_sim_end = len(new_wrapper.index)
new_sim_end[i] = _new_sim_end
else:
new_sim_end = None
return new_sim_end
@hybrid_method
def resolve_sim_start_value(
cls_or_self,
value: tp.Scalar,
wrapper: tp.Optional[ArrayWrapper] = None,
) -> int:
"""Resolve a single value of simulation start."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
auto_idxr = AutoIdxr(value, indexer_method="bfill", below_to_zero=True)
return auto_idxr.get(wrapper.index, freq=wrapper.freq)
@hybrid_method
def resolve_sim_end_value(
cls_or_self,
value: tp.Scalar,
wrapper: tp.Optional[ArrayWrapper] = None,
) -> int:
"""Resolve a single value of simulation end."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
auto_idxr = AutoIdxr(value, indexer_method="bfill", above_to_len=True)
return auto_idxr.get(wrapper.index, freq=wrapper.freq)
@hybrid_method
def resolve_sim_start(
cls_or_self,
sim_start: tp.Optional[tp.ArrayLike] = None,
allow_none: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> tp.Optional[tp.ArrayLike]:
"""Resolve simulation start."""
already_resolved = False
if not isinstance(cls_or_self, type):
if sim_start is None:
sim_start = cls_or_self._sim_start
already_resolved = True
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
if sim_start is False:
sim_start = None
if allow_none and sim_start is None:
return None
if not already_resolved and sim_start is not None:
sim_start_arr = np.asarray(sim_start)
if not np.issubdtype(sim_start_arr.dtype, np.integer):
if sim_start_arr.ndim == 0:
sim_start = cls_or_self.resolve_sim_start_value(sim_start, wrapper=wrapper)
else:
new_sim_start = np.empty(len(sim_start), dtype=int_)
for i in range(len(sim_start)):
new_sim_start[i] = cls_or_self.resolve_sim_start_value(sim_start[i], wrapper=wrapper)
sim_start = new_sim_start
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
sim_start = nb.resolve_grouped_sim_start_nb(
wrapper.shape_2d,
group_lens,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=not already_resolved,
)
elif not already_resolved and wrapper.grouper.is_grouped():
group_lens = wrapper.grouper.get_group_lens()
sim_start = nb.resolve_ungrouped_sim_start_nb(
wrapper.shape_2d,
group_lens,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=not already_resolved,
)
else:
sim_start = nb.resolve_sim_start_nb(
wrapper.shape_2d,
sim_start=sim_start,
allow_none=allow_none,
check_bounds=not already_resolved,
)
return sim_start
@hybrid_method
def resolve_sim_end(
cls_or_self,
sim_end: tp.Optional[tp.ArrayLike] = None,
allow_none: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> tp.Optional[tp.ArrayLike]:
"""Resolve simulation end."""
already_resolved = False
if not isinstance(cls_or_self, type):
if sim_end is None:
sim_end = cls_or_self._sim_end
already_resolved = True
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
if sim_end is False:
sim_end = None
if allow_none and sim_end is None:
return None
if not already_resolved and sim_end is not None:
sim_end_arr = np.asarray(sim_end)
if not np.issubdtype(sim_end_arr.dtype, np.integer):
if sim_end_arr.ndim == 0:
sim_end = cls_or_self.resolve_sim_end_value(sim_end, wrapper=wrapper)
else:
new_sim_end = np.empty(len(sim_end), dtype=int_)
for i in range(len(sim_end)):
new_sim_end[i] = cls_or_self.resolve_sim_end_value(sim_end[i], wrapper=wrapper)
sim_end = new_sim_end
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
sim_end = nb.resolve_grouped_sim_end_nb(
wrapper.shape_2d,
group_lens,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=not already_resolved,
)
elif not already_resolved and wrapper.grouper.is_grouped():
group_lens = wrapper.grouper.get_group_lens()
sim_end = nb.resolve_ungrouped_sim_end_nb(
wrapper.shape_2d,
group_lens,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=not already_resolved,
)
else:
sim_end = nb.resolve_sim_end_nb(
wrapper.shape_2d,
sim_end=sim_end,
allow_none=allow_none,
check_bounds=not already_resolved,
)
return sim_end
@hybrid_method
def get_sim_start(
cls_or_self,
sim_start: tp.Optional[tp.ArrayLike] = None,
keep_flex: bool = False,
allow_none: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, tp.Array1d, tp.Series]:
"""Get simulation start."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(
sim_start=sim_start,
allow_none=allow_none,
wrapper=wrapper,
group_by=group_by,
)
if sim_start is None:
return None
if keep_flex:
return sim_start
wrap_kwargs = merge_dicts(dict(name_or_index="sim_end"), wrap_kwargs)
return wrapper.wrap_reduced(sim_start, group_by=group_by, **wrap_kwargs)
@property
def sim_start(self) -> tp.Series:
"""`SimRangeMixin.get_sim_start` with default arguments."""
return self.get_sim_start()
@hybrid_method
def get_sim_end(
cls_or_self,
sim_end: tp.Optional[tp.ArrayLike] = None,
keep_flex: bool = False,
allow_none: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, tp.Array1d, tp.Series]:
"""Get simulation end."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_end = cls_or_self.resolve_sim_end(
sim_end=sim_end,
allow_none=allow_none,
wrapper=wrapper,
group_by=group_by,
)
if sim_end is None:
return None
if keep_flex:
return sim_end
wrap_kwargs = merge_dicts(dict(name_or_index="sim_start"), wrap_kwargs)
return wrapper.wrap_reduced(sim_end, group_by=group_by, **wrap_kwargs)
@property
def sim_end(self) -> tp.Series:
"""`SimRangeMixin.get_sim_end` with default arguments."""
return self.get_sim_end()
@hybrid_method
def get_sim_start_index(
cls_or_self,
sim_start: tp.Optional[tp.ArrayLike] = None,
allow_none: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.Series]:
"""Get index of simulation start."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(
sim_start=sim_start,
allow_none=allow_none,
wrapper=wrapper,
group_by=group_by,
)
if sim_start is None:
return None
start_index = []
for i in range(len(sim_start)):
_sim_start = sim_start[i]
if _sim_start == 0:
start_index.append(wrapper.index[0])
elif _sim_start == len(wrapper.index):
if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
start_index.append(wrapper.index[-1] + wrapper.freq)
elif isinstance(wrapper.index, pd.RangeIndex):
start_index.append(wrapper.index[-1] + 1)
else:
start_index.append(None)
else:
start_index.append(wrapper.index[_sim_start])
wrap_kwargs = merge_dicts(dict(name_or_index="sim_start_index"), wrap_kwargs)
return wrapper.wrap_reduced(pd.Index(start_index), group_by=group_by, **wrap_kwargs)
@property
def sim_start_index(self) -> tp.Series:
"""`SimRangeMixin.get_sim_start_index` with default arguments."""
return self.get_sim_start_index()
@hybrid_method
def get_sim_end_index(
cls_or_self,
sim_end: tp.Optional[tp.ArrayLike] = None,
allow_none: bool = False,
inclusive: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.Series]:
"""Get index of simulation end."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_end = cls_or_self.resolve_sim_end(
sim_end=sim_end,
allow_none=allow_none,
wrapper=wrapper,
group_by=group_by,
)
if sim_end is None:
return None
end_index = []
for i in range(len(sim_end)):
_sim_end = sim_end[i]
if _sim_end == 0:
if inclusive:
end_index.append(None)
else:
end_index.append(wrapper.index[0])
elif _sim_end == len(wrapper.index):
if inclusive:
end_index.append(wrapper.index[-1])
else:
if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
end_index.append(wrapper.index[-1] + wrapper.freq)
elif isinstance(wrapper.index, pd.RangeIndex):
end_index.append(wrapper.index[-1] + 1)
else:
end_index.append(None)
else:
if inclusive:
end_index.append(wrapper.index[_sim_end - 1])
else:
end_index.append(wrapper.index[_sim_end])
wrap_kwargs = merge_dicts(dict(name_or_index="sim_end_index"), wrap_kwargs)
return wrapper.wrap_reduced(pd.Index(end_index), group_by=group_by, **wrap_kwargs)
@property
def sim_end_index(self) -> tp.Series:
"""`SimRangeMixin.get_sim_end_index` with default arguments."""
return self.get_sim_end_index()
@hybrid_method
def get_sim_duration(
cls_or_self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.Series]:
"""Get duration of simulation range."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(
sim_start=sim_start,
allow_none=False,
wrapper=wrapper,
group_by=group_by,
)
sim_end = cls_or_self.resolve_sim_end(
sim_end=sim_end,
allow_none=False,
wrapper=wrapper,
group_by=group_by,
)
total_duration = sim_end - sim_start
wrap_kwargs = merge_dicts(dict(name_or_index="sim_duration"), wrap_kwargs)
return wrapper.wrap_reduced(total_duration, group_by=group_by, **wrap_kwargs)
@property
def sim_duration(self) -> tp.Series:
"""`SimRangeMixin.get_sim_duration` with default arguments."""
return self.get_sim_duration()
@hybrid_method
def fit_fig_to_sim_range(
cls_or_self,
fig: tp.BaseFigure,
column: tp.Optional[tp.Label] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
xref: tp.Optional[str] = None,
) -> tp.BaseFigure:
"""Fit figure to simulation range."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.get_sim_start(
sim_start=sim_start,
allow_none=True,
wrapper=wrapper,
group_by=group_by,
)
sim_end = cls_or_self.get_sim_end(
sim_end=sim_end,
allow_none=True,
wrapper=wrapper,
group_by=group_by,
)
if sim_start is not None:
sim_start = wrapper.select_col_from_obj(sim_start, column=column, group_by=group_by)
if sim_end is not None:
sim_end = wrapper.select_col_from_obj(sim_end, column=column, group_by=group_by)
if sim_start is not None or sim_end is not None:
if sim_start == len(wrapper.index) or sim_end == 0 or sim_start == sim_end:
return fig
if sim_start is None:
sim_start = 0
if sim_start > 0:
sim_start_index = wrapper.index[sim_start - 1]
else:
if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
sim_start_index = wrapper.index[0] - wrapper.freq
elif isinstance(wrapper.index, pd.RangeIndex):
sim_start_index = wrapper.index[0] - 1
else:
sim_start_index = wrapper.index[0]
if sim_end is None:
sim_end = len(wrapper.index)
if sim_end < len(wrapper.index):
sim_end_index = wrapper.index[sim_end]
else:
if isinstance(wrapper.index, pd.DatetimeIndex) and wrapper.freq is not None:
sim_end_index = wrapper.index[-1] + wrapper.freq
elif isinstance(wrapper.index, pd.RangeIndex):
sim_end_index = wrapper.index[-1] + 1
else:
sim_end_index = wrapper.index[-1]
if xref is not None:
xaxis = "xaxis" + xref[1:]
fig.update_layout(**{xaxis: dict(range=[sim_start_index, sim_end_index])})
else:
fig.update_xaxes(range=[sim_start_index, sim_end_index])
return fig
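# Illustrative sketch (hypothetical usage; `pf` stands for any SimRangeMixin subclass
# bound to an ArrayWrapper, e.g. a simulated portfolio; all attributes below are defined
# in the class above):
#
# >>> pf.sim_start          # integer start position per column/group, as a Series
# >>> pf.sim_end_index      # label of the last simulated bar (inclusive), as a Series
# >>> pf.sim_duration       # number of bars between start and end, as a Series
# >>> fig = pf.fit_fig_to_sim_range(fig, column="a")  # clip the x-axis of an existing figure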
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Mixin for building statistics out of performance metrics."""
import inspect
import string
from collections import Counter
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.wrapping import Wrapping
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import get_dict_attr, AttrResolverMixin
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.parsing import get_func_arg_names, get_forward_args
from vectorbtpro.utils.tagging import match_tags
from vectorbtpro.utils.template import substitute_templates, CustomTemplate
from vectorbtpro.utils.warnings_ import warn
__all__ = []
class MetaStatsBuilderMixin(type):
"""Metaclass for `StatsBuilderMixin`."""
@property
def metrics(cls) -> Config:
"""Metrics supported by `StatsBuilderMixin.stats`."""
return cls._metrics
class StatsBuilderMixin(Base, metaclass=MetaStatsBuilderMixin):
"""Mixin that implements `StatsBuilderMixin.stats`.
Required to be a subclass of `vectorbtpro.base.wrapping.Wrapping`."""
_writeable_attrs: tp.WriteableAttrs = {"_metrics"}
def __init__(self) -> None:
checks.assert_instance_of(self, Wrapping)
# Copy writeable attrs
self._metrics = type(self)._metrics.copy()
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `StatsBuilderMixin.stats`."""
return dict(settings=dict(freq=self.wrapper.freq))
def resolve_stats_setting(
self,
value: tp.Optional[tp.Any],
key: str,
merge: bool = False,
) -> tp.Any:
"""Resolve a setting for `StatsBuilderMixin.stats`."""
from vectorbtpro._settings import settings as _settings
stats_builder_cfg = _settings["stats_builder"]
if merge:
return merge_dicts(
stats_builder_cfg[key],
self.stats_defaults.get(key, {}),
value,
)
if value is not None:
return value
return self.stats_defaults.get(key, stats_builder_cfg[key])
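# Note (illustrative): resolution precedence. With merge=True the result is
# merge_dicts(global "stats_builder" config[key], self.stats_defaults.get(key, {}), value),
# so an explicitly passed value wins over instance defaults, which win over global defaults.
# With merge=False, a non-None value is returned as-is; otherwise the instance default,
# otherwise the global default. For example, resolve_stats_setting(None, "silence_warnings")
# falls back to the global setting unless stats_defaults overrides it.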
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
)
)
@property
def metrics(self) -> Config:
"""Metrics supported by `${cls_name}`.
```python
${metrics}
```
Returns `${cls_name}._metrics`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change metrics, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._metrics`."""
return self._metrics
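# Illustrative sketch (hypothetical subclass `MyAnalyzer`): metrics can be extended
# in-place on the class-level config, mirroring the entries defined in `_metrics` above.
#
# >>> MyAnalyzer.metrics["n_cols"] = dict(
# ...     title="Columns",
# ...     calc_func=lambda self: len(self.wrapper.columns),
# ...     agg_func=None,
# ...     tags="wrapper",
# ... )
# >>> my_analyzer.stats()  # now includes a "Columns" row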
def stats(
self,
metrics: tp.Optional[tp.MaybeIterable[tp.Union[str, tp.Tuple[str, tp.Kwargs]]]] = None,
tags: tp.Optional[tp.MaybeIterable[str]] = None,
column: tp.Optional[tp.Label] = None,
group_by: tp.GroupByLike = None,
per_column: tp.Optional[bool] = None,
split_columns: tp.Optional[bool] = None,
agg_func: tp.Optional[tp.Callable] = np.mean,
dropna: tp.Optional[bool] = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
settings: tp.KwargsLike = None,
filters: tp.KwargsLike = None,
metric_settings: tp.KwargsLike = None,
) -> tp.Optional[tp.SeriesFrame]:
"""Compute various metrics on this object.
Args:
metrics (str, tuple, iterable, or dict): Metrics to calculate.
Each element can be either:
* Metric name (see keys in `StatsBuilderMixin.metrics`)
* Tuple of a metric name and a settings dict as in `StatsBuilderMixin.metrics`
* Tuple of a metric name and a template (an instance of `vectorbtpro.utils.template.CustomTemplate`)
* Tuple of a metric name and a list of settings dicts to be expanded into multiple metrics
The settings dict can contain the following keys:
* `title`: Title of the metric. Defaults to the name.
* `tags`: Single or multiple tags to associate this metric with.
If any of these tags is in `tags`, keeps this metric.
* `check_{filter}` and `inv_check_{filter}`: Whether to check this metric against a
filter defined in `filters`. True (or False for inverse) means to keep this metric.
* `calc_func` (required): Calculation function for custom metrics.
Must return either a scalar for one column/group, pd.Series for multiple columns/groups,
or a dict of such for multiple sub-metrics.
* `resolve_calc_func`: Whether to resolve `calc_func`. If the function can be accessed
by traversing attributes of this object, you can specify the path to this function
as a string (see `vectorbtpro.utils.attr_.deep_getattr` for the path format).
If `calc_func` is a function, arguments from merged metric settings are matched with
arguments in the signature (see below). If `resolve_calc_func` is False, `calc_func`
must accept (resolved) self and dictionary of merged metric settings.
Defaults to True.
* `use_shortcuts`: Whether to use shortcut properties whenever possible when
resolving `calc_func`. Defaults to True.
* `post_calc_func`: Function to post-process the result of `calc_func`.
Must accept (resolved) self, output of `calc_func`, and dictionary of merged metric settings,
and return whatever is acceptable to be returned by `calc_func`. Defaults to None.
* `fill_wrap_kwargs`: Whether to fill `wrap_kwargs` with `to_timedelta` and `silence_warnings`.
Defaults to False.
* `apply_to_timedelta`: Whether to apply `vectorbtpro.base.wrapping.ArrayWrapper.arr_to_timedelta`
on the result. To disable this globally, pass `to_timedelta=False` in `settings`.
Defaults to False.
* `pass_{arg}`: Whether to pass any argument from the settings (see below). Defaults to True if
this argument was found in the function's signature. Set to False to not pass.
If argument to be passed was not found, `pass_{arg}` is removed.
* `resolve_path_{arg}`: Whether to resolve an argument that is meant to be an attribute of
this object and is the first part of the path of `calc_func`. Passes only optional arguments.
Defaults to True. See `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`.
* `resolve_{arg}`: Whether to resolve an argument that is meant to be an attribute of
this object and is present in the function's signature. Defaults to False.
See `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`.
* `use_shortcuts_{arg}`: Whether to use shortcut properties whenever possible when resolving
an argument. Defaults to True.
* `select_col_{arg}`: Whether to select the column from an argument that is meant to be
an attribute of this object. Defaults to False.
* `template_context`: Mapping to replace templates in metric settings. Used across all settings.
* Any other keyword argument that overrides the settings or is passed directly to `calc_func`.
If `resolve_calc_func` is True, the calculation function may "request" any of the
following arguments by accepting them or if `pass_{arg}` was found in the settings dict:
* Each of `vectorbtpro.utils.attr_.AttrResolverMixin.self_aliases`: original object
(ungrouped, with no column selected)
* `group_by`: won't be passed if it was used in resolving the first attribute of `calc_func`
specified as a path, use `pass_group_by=True` to pass anyway
* `column`
* `metric_name`
* `agg_func`
* `silence_warnings`
* `to_timedelta`: replaced by True if None and frequency is set
* Any argument from `settings`
* Any attribute of this object if it meant to be resolved
(see `vectorbtpro.utils.attr_.AttrResolverMixin.resolve_attr`)
Pass `metrics='all'` to calculate all supported metrics.
tags (str or iterable): Tags to select.
See `vectorbtpro.utils.tagging.match_tags`.
column (str): Name of the column/group.
!!! hint
There are two ways to select a column: `obj['a'].stats()` and `obj.stats(column='a')`.
They both accomplish the same thing but in different ways: `obj['a'].stats()` computes
statistics of the column 'a' only, while `obj.stats(column='a')` computes statistics of
all columns first and only then selects the column 'a'. The first method is preferred
when you have a lot of data or caching is disabled. The second method is preferred when
most attributes have already been cached.
group_by (any): Group or ungroup columns. See `vectorbtpro.base.grouping.base.Grouper`.
per_column (bool): Whether to compute per column and then stack along columns.
split_columns (bool): Whether to split this instance into multiple columns when `per_column` is True.
Otherwise, iterates over columns and passes `column` to the whole instance.
agg_func (callable): Aggregation function to aggregate statistics across all columns.
By default, takes the mean of all columns. If None, returns all columns as a DataFrame.
Must take a `pd.Series` and return a constant.
Has no effect if `column` was specified or this object contains only one column of data.
If `agg_func` has been overridden by a metric:
* Takes effect if global `agg_func` is not None
* Raises a warning if it's None but the result of calculation has multiple values
dropna (bool): Whether to hide metrics that are all NaN.
silence_warnings (bool): Whether to silence all warnings.
template_context (mapping): Context used to substitute templates.
Gets merged over `template_context` from `vectorbtpro._settings.stats_builder` and
`StatsBuilderMixin.stats_defaults`.
Applied on `settings` and then on each metric settings.
filters (dict): Filters to apply.
Each item consists of the filter name and settings dict.
The settings dict can contain the following keys:
* `filter_func`: Filter function that must accept resolved self and
merged settings for a metric, and return either True or False.
* `warning_message`: Warning message to be shown when skipping a metric.
Can be a template that will be substituted using merged metric settings as context.
Defaults to None.
* `inv_warning_message`: Same as `warning_message` but for inverse checks.
Gets merged over `filters` from `vectorbtpro._settings.stats_builder` and
`StatsBuilderMixin.stats_defaults`.
settings (dict): Global settings and resolution arguments.
Extends/overrides `settings` from `vectorbtpro._settings.stats_builder` and
`StatsBuilderMixin.stats_defaults`.
Gets extended/overridden by metric settings.
metric_settings (dict): Keyword arguments for each metric.
Extends/overrides all global and metric settings.
For template logic, see `vectorbtpro.utils.template`.
For defaults, see `vectorbtpro._settings.stats_builder` and `StatsBuilderMixin.stats_defaults`.
!!! hint
There are two types of arguments: optional (or resolution) and mandatory arguments.
Optional arguments are only passed if they are found in the function's signature.
Mandatory arguments are passed regardless of this. Optional arguments can only be defined
using `settings` (that is, globally), while mandatory arguments can be defined both using
default metric settings and `{metric_name}_kwargs`. Overriding optional arguments using default
metric settings or `{metric_name}_kwargs` won't turn them into mandatory. For this, pass `pass_{arg}=True`.
!!! hint
Make sure to resolve and then to re-use as many object attributes as possible to
utilize built-in caching (even if global caching is disabled).
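Usage:
A minimal sketch (`obj` is a placeholder for any subclass of
`vectorbtpro.base.wrapping.Wrapping` that mixes in `StatsBuilderMixin`; the metrics
below are the defaults defined in `StatsBuilderMixin._metrics`):
```pycon
>>> obj.stats(metrics=["start_index", "end_index", "total_duration"], column="a")
```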
"""
# Compute per column
if column is None:
if per_column is None:
per_column = self.resolve_stats_setting(per_column, "per_column")
if per_column:
columns = self.get_item_keys(group_by=group_by)
if len(columns) > 1:
results = []
if split_columns:
for _, column_self in self.items(group_by=group_by, wrap=True):
_args, _kwargs = get_forward_args(column_self.stats, locals())
results.append(column_self.stats(*_args, **_kwargs))
else:
for column in columns:
_args, _kwargs = get_forward_args(self.stats, locals())
results.append(self.stats(*_args, **_kwargs))
return pd.concat(results, keys=columns, axis=1)
# Resolve defaults
dropna = self.resolve_stats_setting(dropna, "dropna")
silence_warnings = self.resolve_stats_setting(silence_warnings, "silence_warnings")
template_context = self.resolve_stats_setting(template_context, "template_context", merge=True)
filters = self.resolve_stats_setting(filters, "filters", merge=True)
settings = self.resolve_stats_setting(settings, "settings", merge=True)
metric_settings = self.resolve_stats_setting(metric_settings, "metric_settings", merge=True)
# Replace templates globally (not used at metric level)
if len(template_context) > 0:
sub_settings = substitute_templates(
settings,
context=template_context,
eval_id="sub_settings",
strict=False,
)
else:
sub_settings = settings
# Resolve self
reself = self.resolve_self(
cond_kwargs=sub_settings,
impacts_caching=False,
silence_warnings=silence_warnings,
)
# Prepare metrics
metrics = reself.resolve_stats_setting(metrics, "metrics")
if metrics == "all":
metrics = reself.metrics
if isinstance(metrics, dict):
metrics = list(metrics.items())
if isinstance(metrics, (str, tuple)):
metrics = [metrics]
# Prepare tags
tags = reself.resolve_stats_setting(tags, "tags")
if isinstance(tags, str) and tags == "all":
tags = None
if isinstance(tags, (str, tuple)):
tags = [tags]
# Bring to the same shape
new_metrics = []
for i, metric in enumerate(metrics):
if isinstance(metric, str):
metric = (metric, reself.metrics[metric])
if not isinstance(metric, tuple):
raise TypeError(f"Metric at index {i} must be either a string or a tuple")
new_metrics.append(metric)
metrics = new_metrics
# Expand metrics
new_metrics = []
for i, (metric_name, _metric_settings) in enumerate(metrics):
if isinstance(_metric_settings, CustomTemplate):
metric_context = merge_dicts(
template_context,
{name: reself for name in reself.self_aliases},
dict(
column=column,
group_by=group_by,
metric_name=metric_name,
agg_func=agg_func,
silence_warnings=silence_warnings,
to_timedelta=None,
),
settings,
)
metric_context = substitute_templates(
metric_context,
context=metric_context,
eval_id="metric_context",
)
_metric_settings = _metric_settings.substitute(
context=metric_context,
strict=True,
eval_id="metric",
)
if isinstance(_metric_settings, list):
for __metric_settings in _metric_settings:
new_metrics.append((metric_name, __metric_settings))
else:
new_metrics.append((metric_name, _metric_settings))
metrics = new_metrics
# Handle duplicate names
metric_counts = Counter(list(map(lambda x: x[0], metrics)))
metric_i = {k: -1 for k in metric_counts.keys()}
metrics_dct = {}
for i, (metric_name, _metric_settings) in enumerate(metrics):
if metric_counts[metric_name] > 1:
metric_i[metric_name] += 1
metric_name = metric_name + "_" + str(metric_i[metric_name])
metrics_dct[metric_name] = _metric_settings
# Check metric_settings
missed_keys = set(metric_settings.keys()).difference(set(metrics_dct.keys()))
if len(missed_keys) > 0:
raise ValueError(f"Keys {missed_keys} in metric_settings could not be matched with any metric")
# Merge settings
opt_arg_names_dct = {}
custom_arg_names_dct = {}
resolved_self_dct = {}
context_dct = {}
for metric_name, _metric_settings in list(metrics_dct.items()):
opt_settings = merge_dicts(
{name: reself for name in reself.self_aliases},
dict(
column=column,
group_by=group_by,
metric_name=metric_name,
agg_func=agg_func,
silence_warnings=silence_warnings,
to_timedelta=None,
),
settings,
)
_metric_settings = _metric_settings.copy()
passed_metric_settings = metric_settings.get(metric_name, {})
merged_settings = merge_dicts(opt_settings, _metric_settings, passed_metric_settings)
metric_template_context = merged_settings.pop("template_context", {})
template_context_merged = merge_dicts(template_context, metric_template_context)
template_context_merged = substitute_templates(
template_context_merged,
context=merged_settings,
eval_id="template_context_merged",
)
context = merge_dicts(template_context_merged, merged_settings)
merged_settings = substitute_templates(
merged_settings,
context=context,
eval_id="merged_settings",
)
# Filter by tag
if tags is not None:
in_tags = merged_settings.get("tags", None)
if in_tags is None or not match_tags(tags, in_tags):
metrics_dct.pop(metric_name, None)
continue
custom_arg_names = set(_metric_settings.keys()).union(set(passed_metric_settings.keys()))
opt_arg_names = set(opt_settings.keys())
custom_reself = reself.resolve_self(
cond_kwargs=merged_settings,
custom_arg_names=custom_arg_names,
impacts_caching=True,
silence_warnings=merged_settings["silence_warnings"],
)
metrics_dct[metric_name] = merged_settings
custom_arg_names_dct[metric_name] = custom_arg_names
opt_arg_names_dct[metric_name] = opt_arg_names
resolved_self_dct[metric_name] = custom_reself
context_dct[metric_name] = context
# Filter metrics
for metric_name, _metric_settings in list(metrics_dct.items()):
custom_reself = resolved_self_dct[metric_name]
context = context_dct[metric_name]
_silence_warnings = _metric_settings.get("silence_warnings")
metric_filters = set()
for k in _metric_settings.keys():
filter_name = None
if k.startswith("check_"):
filter_name = k[len("check_") :]
elif k.startswith("inv_check_"):
filter_name = k[len("inv_check_") :]
if filter_name is not None:
if filter_name not in filters:
raise ValueError(f"Metric '{metric_name}' requires filter '{filter_name}'")
metric_filters.add(filter_name)
for filter_name in metric_filters:
filter_settings = filters[filter_name]
_filter_settings = substitute_templates(
filter_settings,
context=context,
eval_id="filter_settings",
)
filter_func = _filter_settings["filter_func"]
warning_message = _filter_settings.get("warning_message", None)
inv_warning_message = _filter_settings.get("inv_warning_message", None)
to_check = _metric_settings.get("check_" + filter_name, False)
inv_to_check = _metric_settings.get("inv_check_" + filter_name, False)
if to_check or inv_to_check:
whether_true = filter_func(custom_reself, _metric_settings)
to_remove = (to_check and not whether_true) or (inv_to_check and whether_true)
if to_remove:
if to_check and warning_message is not None and not _silence_warnings:
warn(warning_message)
if inv_to_check and inv_warning_message is not None and not _silence_warnings:
warn(inv_warning_message)
metrics_dct.pop(metric_name, None)
custom_arg_names_dct.pop(metric_name, None)
opt_arg_names_dct.pop(metric_name, None)
resolved_self_dct.pop(metric_name, None)
context_dct.pop(metric_name, None)
break
# Any metrics left?
if len(metrics_dct) == 0:
if not silence_warnings:
warn("No metrics to calculate")
return None
# Compute stats
arg_cache_dct = {}
stats_dct = {}
used_agg_func = False
for i, (metric_name, _metric_settings) in enumerate(metrics_dct.items()):
try:
final_kwargs = _metric_settings.copy()
opt_arg_names = opt_arg_names_dct[metric_name]
custom_arg_names = custom_arg_names_dct[metric_name]
custom_reself = resolved_self_dct[metric_name]
# Clean up keys
for k, v in list(final_kwargs.items()):
if k.startswith("check_") or k.startswith("inv_check_") or k in ("tags",):
final_kwargs.pop(k, None)
# Get metric-specific values
_column = final_kwargs.get("column")
_group_by = final_kwargs.get("group_by")
_agg_func = final_kwargs.get("agg_func")
_silence_warnings = final_kwargs.get("silence_warnings")
if final_kwargs["to_timedelta"] is None:
final_kwargs["to_timedelta"] = custom_reself.wrapper.freq is not None
to_timedelta = final_kwargs.get("to_timedelta")
title = final_kwargs.pop("title", metric_name)
calc_func = final_kwargs.pop("calc_func")
resolve_calc_func = final_kwargs.pop("resolve_calc_func", True)
post_calc_func = final_kwargs.pop("post_calc_func", None)
use_shortcuts = final_kwargs.pop("use_shortcuts", True)
use_caching = final_kwargs.pop("use_caching", True)
fill_wrap_kwargs = final_kwargs.pop("fill_wrap_kwargs", False)
if fill_wrap_kwargs:
final_kwargs["wrap_kwargs"] = merge_dicts(
dict(to_timedelta=to_timedelta, silence_warnings=_silence_warnings),
final_kwargs.get("wrap_kwargs", None),
)
apply_to_timedelta = final_kwargs.pop("apply_to_timedelta", False)
# Resolve calc_func
if resolve_calc_func:
if not callable(calc_func):
passed_kwargs_out = {}
def _getattr_func(
obj: tp.Any,
attr: str,
args: tp.ArgsLike = None,
kwargs: tp.KwargsLike = None,
call_attr: bool = True,
_final_kwargs: tp.Kwargs = final_kwargs,
_opt_arg_names: tp.Set[str] = opt_arg_names,
_custom_arg_names: tp.Set[str] = custom_arg_names,
_arg_cache_dct: tp.Kwargs = arg_cache_dct,
_use_shortcuts: bool = use_shortcuts,
_use_caching: bool = use_caching,
) -> tp.Any:
if attr in _final_kwargs:
return _final_kwargs[attr]
if args is None:
args = ()
if kwargs is None:
kwargs = {}
if obj is custom_reself:
resolve_path_arg = _final_kwargs.pop("resolve_path_" + attr, True)
if resolve_path_arg:
if call_attr:
cond_kwargs = {k: v for k, v in _final_kwargs.items() if k in _opt_arg_names}
out = custom_reself.resolve_attr(
attr, # do not pass _attr, important for caching
args=args,
cond_kwargs=cond_kwargs,
kwargs=kwargs,
custom_arg_names=_custom_arg_names,
cache_dct=_arg_cache_dct,
use_caching=_use_caching,
passed_kwargs_out=passed_kwargs_out,
use_shortcuts=_use_shortcuts,
)
else:
if isinstance(obj, AttrResolverMixin):
cls_dir = obj.cls_dir
else:
cls_dir = dir(type(obj))
if "get_" + attr in cls_dir:
_attr = "get_" + attr
else:
_attr = attr
out = getattr(obj, _attr)
_select_col_arg = _final_kwargs.pop("select_col_" + attr, False)
if _select_col_arg and _column is not None:
out = custom_reself.select_col_from_obj(
out,
_column,
wrapper=custom_reself.wrapper.regroup(_group_by),
)
passed_kwargs_out["group_by"] = _group_by
passed_kwargs_out["column"] = _column
return out
out = getattr(obj, attr)
if callable(out) and call_attr:
return out(*args, **kwargs)
return out
calc_func = custom_reself.deep_getattr(
calc_func,
getattr_func=_getattr_func,
call_last_attr=False,
)
if "group_by" in passed_kwargs_out:
if "pass_group_by" not in final_kwargs:
final_kwargs.pop("group_by", None)
if "column" in passed_kwargs_out:
if "pass_column" not in final_kwargs:
final_kwargs.pop("column", None)
# Resolve arguments
if callable(calc_func):
func_arg_names = get_func_arg_names(calc_func)
for k in func_arg_names:
if k not in final_kwargs:
resolve_arg = final_kwargs.pop("resolve_" + k, False)
use_shortcuts_arg = final_kwargs.pop("use_shortcuts_" + k, True)
select_col_arg = final_kwargs.pop("select_col_" + k, False)
if resolve_arg:
try:
arg_out = custom_reself.resolve_attr(
k,
cond_kwargs=final_kwargs,
custom_arg_names=custom_arg_names,
cache_dct=arg_cache_dct,
use_caching=use_caching,
use_shortcuts=use_shortcuts_arg,
)
except AttributeError:
continue
if select_col_arg and _column is not None:
arg_out = custom_reself.select_col_from_obj(
arg_out,
_column,
wrapper=custom_reself.wrapper.regroup(_group_by),
)
final_kwargs[k] = arg_out
for k in list(final_kwargs.keys()):
if k in opt_arg_names:
if "pass_" + k in final_kwargs:
if not final_kwargs.get("pass_" + k): # first priority
final_kwargs.pop(k, None)
elif k not in func_arg_names: # second priority
final_kwargs.pop(k, None)
for k in list(final_kwargs.keys()):
if k.startswith("pass_") or k.startswith("resolve_"):
final_kwargs.pop(k, None) # cleanup
# Call calc_func
out = calc_func(**final_kwargs)
else:
# calc_func is already a result
out = calc_func
else:
# Do not resolve calc_func
out = calc_func(custom_reself, _metric_settings)
# Call post_calc_func
if post_calc_func is not None:
out = post_calc_func(custom_reself, out, _metric_settings)
# Post-process and store the metric
multiple = True
if not isinstance(out, dict):
multiple = False
out = {None: out}
for k, v in out.items():
# Resolve title
if multiple:
if title is None:
t = str(k)
else:
t = title + ": " + str(k)
else:
t = title
# Check result type
if checks.is_any_array(v) and not checks.is_series(v):
raise TypeError(
"calc_func must return either a scalar for one column/group, "
"pd.Series for multiple columns/groups, or a dict of such. "
f"Not {type(v)}."
)
# Handle apply_to_timedelta
if apply_to_timedelta and to_timedelta:
v = custom_reself.wrapper.arr_to_timedelta(v, silence_warnings=_silence_warnings)
# Select column or aggregate
if checks.is_series(v):
if _column is None and v.shape[0] == 1:
v = v.iloc[0]
elif _column is not None:
v = custom_reself.select_col_from_obj(
v,
_column,
wrapper=custom_reself.wrapper.regroup(_group_by),
)
elif _agg_func is not None and agg_func is not None:
v = _agg_func(v)
if _agg_func is agg_func:
used_agg_func = True
elif _agg_func is None and agg_func is not None:
if not _silence_warnings:
warn(
f"Metric '{metric_name}' returned multiple values "
"despite having no aggregation function",
)
continue
# Store metric
if t in stats_dct:
if not _silence_warnings:
warn(f"Duplicate metric title '{t}'")
stats_dct[t] = v
except Exception as e:
warn(f"Metric '{metric_name}' raised an exception")
raise e
# Return the stats
if reself.wrapper.get_ndim(group_by=group_by) == 1:
sr = pd.Series(
stats_dct,
name=reself.wrapper.get_name(group_by=group_by),
dtype=object,
)
if dropna:
sr.replace([np.inf, -np.inf], np.nan, inplace=True)
return sr.dropna()
return sr
if column is not None:
sr = pd.Series(stats_dct, name=column, dtype=object)
if dropna:
sr.replace([np.inf, -np.inf], np.nan, inplace=True)
return sr.dropna()
return sr
if agg_func is not None:
if used_agg_func and not silence_warnings:
warn(
f"Object has multiple columns. Aggregated some metrics using {agg_func}. "
"Pass either agg_func=None or per_column=True to return statistics per column. "
"Pass column to select a single column or group.",
)
sr = pd.Series(stats_dct, name="agg_stats", dtype=object)
if dropna:
sr.replace([np.inf, -np.inf], np.nan, inplace=True)
return sr.dropna()
return sr
new_index = reself.wrapper.grouper.get_index(group_by=group_by)
df = pd.DataFrame(stats_dct, index=new_index)
if dropna:
df.replace([np.inf, -np.inf], np.nan, inplace=True)
return df.dropna(axis=1, how="all")
return df
# ############# Docs ############# #
@classmethod
def build_metrics_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build metrics documentation."""
if source_cls is None:
source_cls = StatsBuilderMixin
return string.Template(
inspect.cleandoc(get_dict_attr(source_cls, "metrics").__doc__),
).substitute(
{"metrics": cls.metrics.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_metrics_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `StatsBuilderMixin.metrics`."""
__pdoc__[cls.__name__ + ".metrics"] = cls.build_metrics_doc(source_cls=source_cls)
__pdoc__ = dict()
StatsBuilderMixin.override_metrics_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with custom indicators built with the indicator factory.
You can access all the indicators by `vbt.*`.
Run the following to prepare data for the examples:
```pycon
>>> ohlcv = vbt.YFData.pull(
... "BTC-USD",
... start="2019-03-01",
... end="2019-09-01"
... ).get()
```
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.indicators.custom.adx import *
from vectorbtpro.indicators.custom.atr import *
from vectorbtpro.indicators.custom.bbands import *
from vectorbtpro.indicators.custom.hurst import *
from vectorbtpro.indicators.custom.ma import *
from vectorbtpro.indicators.custom.macd import *
from vectorbtpro.indicators.custom.msd import *
from vectorbtpro.indicators.custom.obv import *
from vectorbtpro.indicators.custom.ols import *
from vectorbtpro.indicators.custom.patsim import *
from vectorbtpro.indicators.custom.pivotinfo import *
from vectorbtpro.indicators.custom.rsi import *
from vectorbtpro.indicators.custom.sigdet import *
from vectorbtpro.indicators.custom.stoch import *
from vectorbtpro.indicators.custom.supertrend import *
from vectorbtpro.indicators.custom.vwap import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `ADX`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"ADX",
]
__pdoc__ = {}
ADX = IndicatorFactory(
class_name="ADX",
module_name=__name__,
input_names=["high", "low", "close"],
param_names=["window", "wtype"],
output_names=["plus_di", "minus_di", "dx", "adx"],
).with_apply_func(
nb.adx_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="wilder",
minp=None,
adjust=False,
)
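# Illustrative sketch (hypothetical inputs `high`, `low`, `close`): the factory above
# generates a `run` class method; passing a list of parameters builds one column per
# parameter combination on top of the input columns.
#
# >>> adx = ADX.run(high, low, close, window=[14, 21])
# >>> adx.adx       # DataFrame with one column per window/wtype combination
# >>> adx.plus_di   # same layout for +DI; see also adx.minus_di and adx.dx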
class _ADX(ADX):
"""Average Directional Movement Index (ADX).
The indicator is used by some traders to determine the strength of a trend.
See [Average Directional Index (ADX)](https://www.investopedia.com/terms/a/adx.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plus_di_trace_kwargs: tp.KwargsLike = None,
minus_di_trace_kwargs: tp.KwargsLike = None,
adx_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `ADX.plus_di`, `ADX.minus_di`, and `ADX.adx`.
Args:
column (str): Name of the column to plot.
plus_di_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.plus_di`.
minus_di_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.minus_di`.
adx_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ADX.adx`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.ADX.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if plus_di_trace_kwargs is None:
plus_di_trace_kwargs = {}
if minus_di_trace_kwargs is None:
minus_di_trace_kwargs = {}
if adx_trace_kwargs is None:
adx_trace_kwargs = {}
plus_di_trace_kwargs = merge_dicts(
dict(name="+DI", line=dict(color=plotting_cfg["color_schema"]["green"], dash="dot")),
plus_di_trace_kwargs,
)
minus_di_trace_kwargs = merge_dicts(
dict(name="-DI", line=dict(color=plotting_cfg["color_schema"]["red"], dash="dot")),
minus_di_trace_kwargs,
)
adx_trace_kwargs = merge_dicts(
dict(name="ADX", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
adx_trace_kwargs,
)
fig = self_col.plus_di.vbt.lineplot(
trace_kwargs=plus_di_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.minus_di.vbt.lineplot(
trace_kwargs=minus_di_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.adx.vbt.lineplot(
trace_kwargs=adx_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(ADX, "__doc__", _ADX.__doc__)
setattr(ADX, "plot", _ADX.plot)
ADX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `ATR`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"ATR",
]
__pdoc__ = {}
ATR = IndicatorFactory(
class_name="ATR",
module_name=__name__,
input_names=["high", "low", "close"],
param_names=["window", "wtype"],
output_names=["tr", "atr"],
).with_apply_func(
nb.atr_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="wilder",
minp=None,
adjust=False,
)
class _ATR(ATR):
"""Average True Range (ATR).
The indicator provides an indication of the degree of price volatility.
Strong moves, in either direction, are often accompanied by large ranges,
or large True Ranges.
See [Average True Range - ATR](https://www.investopedia.com/terms/a/atr.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
tr_trace_kwargs: tp.KwargsLike = None,
atr_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `ATR.tr` and `ATR.atr`.
Args:
column (str): Name of the column to plot.
tr_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ATR.tr`.
atr_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ATR.atr`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.ATR.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if tr_trace_kwargs is None:
tr_trace_kwargs = {}
if atr_trace_kwargs is None:
atr_trace_kwargs = {}
tr_trace_kwargs = merge_dicts(
dict(name="TR", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
tr_trace_kwargs,
)
atr_trace_kwargs = merge_dicts(
dict(name="ATR", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
atr_trace_kwargs,
)
fig = self_col.tr.vbt.lineplot(
trace_kwargs=tr_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.atr.vbt.lineplot(
trace_kwargs=atr_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(ATR, "__doc__", _ATR.__doc__)
setattr(ATR, "plot", _ATR.plot)
ATR.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BBANDS`."""
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"BBANDS",
]
__pdoc__ = {}
BBANDS = IndicatorFactory(
class_name="BBANDS",
module_name=__name__,
short_name="bb",
input_names=["close"],
param_names=["window", "wtype", "alpha"],
output_names=["upper", "middle", "lower"],
lazy_outputs=dict(
percent_b=lambda self: self.wrapper.wrap(
nb.bbands_percent_b_nb(
to_2d_array(self.close),
to_2d_array(self.upper),
to_2d_array(self.lower),
),
),
bandwidth=lambda self: self.wrapper.wrap(
nb.bbands_bandwidth_nb(
to_2d_array(self.upper),
to_2d_array(self.middle),
to_2d_array(self.lower),
),
),
),
).with_apply_func(
nb.bbands_nb,
kwargs_as_args=["minp", "adjust", "ddof"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
alpha=2,
minp=None,
adjust=False,
ddof=0,
)
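# Illustrative sketch (hypothetical `close` series): the lazy outputs declared above are
# computed on first access rather than during `run`.
#
# >>> bb = BBANDS.run(close, window=20, alpha=2)
# >>> bb.percent_b   # %B, conventionally (close - lower) / (upper - lower)
# >>> bb.bandwidth   # conventionally (upper - lower) / middle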
class _BBANDS(BBANDS):
"""Bollinger Bands (BBANDS).
A Bollinger Band® is a technical analysis tool defined by a set of lines plotted two standard
deviations (positively and negatively) away from a simple moving average (SMA) of the security's
price, but can be adjusted to user preferences.
See [Bollinger Band®](https://www.investopedia.com/terms/b/bollingerbands.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
upper_trace_kwargs: tp.KwargsLike = None,
middle_trace_kwargs: tp.KwargsLike = None,
lower_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `BBANDS.upper`, `BBANDS.middle`, and `BBANDS.lower` against `BBANDS.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `BBANDS.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.close`.
upper_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.upper`.
middle_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.middle`.
lower_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `BBANDS.lower`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.BBANDS.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if upper_trace_kwargs is None:
upper_trace_kwargs = {}
if middle_trace_kwargs is None:
middle_trace_kwargs = {}
if lower_trace_kwargs is None:
lower_trace_kwargs = {}
lower_trace_kwargs = merge_dicts(
dict(
name="Lower band",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5)),
),
lower_trace_kwargs,
)
upper_trace_kwargs = merge_dicts(
dict(
name="Upper band",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5)),
fill="tonexty",
fillcolor="rgba(128, 128, 128, 0.2)",
),
upper_trace_kwargs,
) # default kwargs
middle_trace_kwargs = merge_dicts(
dict(name="Middle band", line=dict(color=plotting_cfg["color_schema"]["lightblue"])), middle_trace_kwargs
)
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
fig = self_col.lower.vbt.lineplot(
trace_kwargs=lower_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.upper.vbt.lineplot(
trace_kwargs=upper_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.middle.vbt.lineplot(
trace_kwargs=middle_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(BBANDS, "__doc__", _BBANDS.__doc__)
setattr(BBANDS, "plot", _BBANDS.plot)
BBANDS.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `HURST`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.enums import HurstMethod
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"HURST",
]
__pdoc__ = {}
HURST = IndicatorFactory(
class_name="HURST",
module_name=__name__,
input_names=["close"],
param_names=[
"window",
"method",
"max_lag",
"min_log",
"max_log",
"log_step",
"min_chunk",
"max_chunk",
"num_chunks",
],
output_names=["hurst"],
).with_apply_func(
nb.rolling_hurst_nb,
kwargs_as_args=["minp", "stabilize"],
param_settings=dict(
method=dict(
dtype=HurstMethod,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=200,
method="standard",
max_lag=20,
min_log=1,
max_log=2,
log_step=0.25,
min_chunk=8,
max_chunk=100,
num_chunks=5,
minp=None,
stabilize=False,
)
class _HURST(HURST):
"""Moving Hurst exponent (HURST)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
hurst_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `HURST.hurst`.
Args:
column (str): Name of the column to plot.
hurst_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `HURST.hurst`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> ohlcv = vbt.YFData.pull(
... "BTC-USD",
... start="2020-01-01",
... end="2024-01-01"
... ).get()
>>> vbt.HURST.run(ohlcv["Close"]).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if hurst_trace_kwargs is None:
hurst_trace_kwargs = {}
hurst_trace_kwargs = merge_dicts(
dict(name="HURST", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
hurst_trace_kwargs,
)
fig = self_col.hurst.vbt.lineplot(
trace_kwargs=hurst_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**layout_kwargs,
)
return fig
setattr(HURST, "__doc__", _HURST.__doc__)
setattr(HURST, "plot", _HURST.plot)
HURST.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `MA`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"MA",
]
__pdoc__ = {}
MA = IndicatorFactory(
class_name="MA",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype"],
output_names=["ma"],
).with_apply_func(
nb.ma_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
minp=None,
adjust=False,
)
class _MA(MA):
"""Moving Average (MA).
A moving average is a widely used indicator in technical analysis that helps smooth out
price action by filtering out the “noise” from random short-term price fluctuations.
See [Moving Average (MA)](https://www.investopedia.com/terms/m/movingaverage.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
ma_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `MA.ma` against `MA.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `MA.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MA.close`.
ma_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MA.ma`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.MA.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if ma_trace_kwargs is None:
ma_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
ma_trace_kwargs = merge_dicts(
dict(name="MA", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
ma_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.ma.vbt.lineplot(
trace_kwargs=ma_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(MA, "__doc__", _MA.__doc__)
setattr(MA, "plot", _MA.plot)
MA.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `MACD`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"MACD",
]
__pdoc__ = {}
MACD = IndicatorFactory(
class_name="MACD",
module_name=__name__,
input_names=["close"],
param_names=["fast_window", "slow_window", "signal_window", "wtype", "macd_wtype", "signal_wtype"],
output_names=["macd", "signal"],
lazy_outputs=dict(
hist=lambda self: self.wrapper.wrap(
nb.macd_hist_nb(
to_2d_array(self.macd),
to_2d_array(self.signal),
),
),
),
).with_apply_func(
nb.macd_nb,
kwargs_as_args=["minp", "macd_minp", "signal_minp", "adjust", "macd_adjust", "signal_adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
macd_wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
signal_wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
),
fast_window=12,
slow_window=26,
signal_window=9,
wtype="exp",
macd_wtype=None,
signal_wtype=None,
minp=None,
macd_minp=None,
signal_minp=None,
adjust=False,
macd_adjust=None,
signal_adjust=None,
)
class _MACD(MACD):
"""Moving Average Convergence Divergence (MACD).
A trend-following momentum indicator that shows the relationship between
two moving averages of a security's price.
See [Moving Average Convergence Divergence – MACD](https://www.investopedia.com/terms/m/macd.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
macd_trace_kwargs: tp.KwargsLike = None,
signal_trace_kwargs: tp.KwargsLike = None,
hist_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `MACD.macd`, `MACD.signal` and `MACD.hist`.
Args:
column (str): Name of the column to plot.
macd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MACD.macd`.
signal_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MACD.signal`.
hist_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar` for `MACD.hist`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.MACD.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(bargap=0)
fig.update_layout(**layout_kwargs)
if macd_trace_kwargs is None:
macd_trace_kwargs = {}
if signal_trace_kwargs is None:
signal_trace_kwargs = {}
if hist_trace_kwargs is None:
hist_trace_kwargs = {}
macd_trace_kwargs = merge_dicts(
dict(name="MACD", line=dict(color=plotting_cfg["color_schema"]["lightblue"])), macd_trace_kwargs
)
signal_trace_kwargs = merge_dicts(
dict(name="Signal", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])), signal_trace_kwargs
)
fig = self_col.macd.vbt.lineplot(
trace_kwargs=macd_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.signal.vbt.lineplot(
trace_kwargs=signal_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
# Plot hist
hist = self_col.hist.values
hist_diff = generic_nb.diff_1d_nb(hist)
marker_colors = np.full(hist.shape, adjust_opacity("silver", 0.75), dtype=object)
marker_colors[(hist > 0) & (hist_diff > 0)] = adjust_opacity("green", 0.75)
marker_colors[(hist > 0) & (hist_diff <= 0)] = adjust_opacity("lightgreen", 0.75)
marker_colors[(hist < 0) & (hist_diff < 0)] = adjust_opacity("red", 0.75)
marker_colors[(hist < 0) & (hist_diff >= 0)] = adjust_opacity("lightcoral", 0.75)
_hist_trace_kwargs = merge_dicts(
dict(
name="Histogram",
x=self_col.hist.index,
y=self_col.hist.values,
marker_color=marker_colors,
marker_line_width=0,
),
hist_trace_kwargs,
)
hist_bar = go.Bar(**_hist_trace_kwargs)
if add_trace_kwargs is None:
add_trace_kwargs = {}
fig.add_trace(hist_bar, **add_trace_kwargs)
return fig
setattr(MACD, "__doc__", _MACD.__doc__)
setattr(MACD, "plot", _MACD.plot)
MACD.fix_docstrings(__pdoc__)
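# A standalone sketch (synthetic data): the MACD line is the fast minus the slow moving
# average, the signal line smooths it, and `hist` is their difference, so a sign change
# in `hist` marks a MACD/signal crossover; the comparisons below are plain pandas.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(
        100 + rng.normal(size=300).cumsum(),
        index=pd.date_range("2024-01-01", periods=300, freq="D"),
    )
    macd = vbt.MACD.run(close)  # fast_window=12, slow_window=26, signal_window=9
    crossed_up = (macd.hist > 0) & (macd.hist.shift(1) <= 0)
    print(int(crossed_up.sum()), "bullish MACD/signal crossovers")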
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `MSD`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"MSD",
]
__pdoc__ = {}
MSD = IndicatorFactory(
class_name="MSD",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype"],
output_names=["msd"],
).with_apply_func(
nb.msd_nb,
kwargs_as_args=["minp", "adjust", "ddof"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
minp=None,
adjust=False,
ddof=0,
)
class _MSD(MSD):
"""Moving Standard Deviation (MSD).
Standard deviation is an indicator that measures the size of an asset's recent price moves
in order to predict how volatile the price may be in the future."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
msd_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `MSD.msd`.
Args:
column (str): Name of the column to plot.
msd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `MSD.msd`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.MSD.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if msd_trace_kwargs is None:
msd_trace_kwargs = {}
msd_trace_kwargs = merge_dicts(
dict(name="MSD", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
msd_trace_kwargs,
)
fig = self_col.msd.vbt.lineplot(
trace_kwargs=msd_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**layout_kwargs,
)
return fig
setattr(MSD, "__doc__", _MSD.__doc__)
setattr(MSD, "plot", _MSD.plot)
MSD.fix_docstrings(__pdoc__)
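# A standalone sketch (synthetic data): with the default simple weighting and ddof=0,
# the output is expected to track a plain rolling standard deviation, which makes a quick
# cross-check against pandas straightforward.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=200).cumsum())
    msd = vbt.MSD.run(close, window=20)
    reference = close.rolling(20).std(ddof=0)  # pandas reference with the same window
    print((msd.msd - reference).abs().max())   # should be near zero up to float error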
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `OBV`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"OBV",
]
__pdoc__ = {}
OBV = IndicatorFactory(
class_name="OBV",
module_name=__name__,
short_name="obv",
input_names=["close", "volume"],
param_names=[],
output_names=["obv"],
).with_custom_func(nb.obv_nb)
class _OBV(OBV):
"""On-balance volume (OBV).
It relates price and volume in the stock market: OBV keeps a cumulative running total of volume,
adding volume on up days and subtracting it on down days.
See [On-Balance Volume (OBV)](https://www.investopedia.com/terms/o/onbalancevolume.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
obv_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `OBV.obv`.
Args:
column (str): Name of the column to plot.
obv_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OBV.obv`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.OBV.run(ohlcv['Close'], ohlcv['Volume']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if obv_trace_kwargs is None:
obv_trace_kwargs = {}
obv_trace_kwargs = merge_dicts(
dict(name="OBV", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
obv_trace_kwargs,
)
fig = self_col.obv.vbt.lineplot(
trace_kwargs=obv_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(OBV, "__doc__", _OBV.__doc__)
setattr(OBV, "plot", _OBV.plot)
OBV.fix_docstrings(__pdoc__)
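# A standalone sketch (synthetic data): conceptually, OBV adds the bar's volume when the
# close rises and subtracts it when the close falls, keeping a running total. The pandas
# expression below is only a conceptual reference and may treat edge cases (first bar,
# unchanged closes) differently than `nb.obv_nb`.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=200).cumsum())
    volume = pd.Series(rng.integers(1_000, 10_000, size=200).astype(float))
    obv = vbt.OBV.run(close, volume)
    reference = (np.sign(close.diff()) * volume).fillna(0).cumsum()
    print(obv.obv.tail(), reference.tail(), sep="\n")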
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `OLS`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"OLS",
]
__pdoc__ = {}
OLS = IndicatorFactory(
class_name="OLS",
module_name=__name__,
short_name="ols",
input_names=["x", "y"],
param_names=["window", "norm_window"],
output_names=["slope", "intercept", "zscore"],
lazy_outputs=dict(
pred=lambda self: self.wrapper.wrap(
nb.ols_pred_nb(
to_2d_array(self.x),
to_2d_array(self.slope),
to_2d_array(self.intercept),
),
),
error=lambda self: self.wrapper.wrap(
nb.ols_error_nb(
to_2d_array(self.y),
to_2d_array(self.pred),
),
),
angle=lambda self: self.wrapper.wrap(
nb.ols_angle_nb(
to_2d_array(self.slope),
),
),
),
).with_apply_func(
nb.ols_nb,
kwargs_as_args=["minp", "ddof", "with_zscore"],
window=14,
norm_window=None,
minp=None,
ddof=0,
with_zscore=True,
)
class _OLS(OLS):
"""Rolling Ordinary Least Squares (OLS).
The indicator can be used to detect changes in the behavior of stocks against the market or against each other.
See [The Linear Regression of Time and Price](https://www.investopedia.com/articles/trading/09/linear-regression-time-price.asp).
"""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_y: bool = True,
y_trace_kwargs: tp.KwargsLike = None,
pred_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `OLS.pred` against `OLS.y`.
Args:
column (str): Name of the column to plot.
plot_y (bool): Whether to plot `OLS.y`.
y_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.y`.
pred_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.pred`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.OLS.run(np.arange(len(ohlcv)), ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if y_trace_kwargs is None:
y_trace_kwargs = {}
if pred_trace_kwargs is None:
pred_trace_kwargs = {}
y_trace_kwargs = merge_dicts(
dict(name="Y", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
y_trace_kwargs,
)
pred_trace_kwargs = merge_dicts(
dict(name="Pred", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
pred_trace_kwargs,
)
if plot_y:
fig = self_col.y.vbt.lineplot(
trace_kwargs=y_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.pred.vbt.lineplot(
trace_kwargs=pred_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
def plot_zscore(
self,
column: tp.Optional[tp.Label] = None,
alpha: float = 0.05,
zscore_trace_kwargs: tp.KwargsLike = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `OLS.zscore` with confidence intervals.
Args:
column (str): Name of the column to plot.
alpha (float): The alpha level for the confidence interval.
The default alpha = .05 returns a 95% confidence interval.
zscore_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `OLS.zscore`.
add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape`
when adding the range between both confidence intervals.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.OLS.run(np.arange(len(ohlcv)), ohlcv['Close']).plot_zscore().show()
```
"""
import scipy.stats as st
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
zscore_trace_kwargs = merge_dicts(
dict(name="Z-score", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
zscore_trace_kwargs,
)
fig = self_col.zscore.vbt.lineplot(
trace_kwargs=zscore_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
# Fill void between limits
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
add_shape_kwargs = merge_dicts(
dict(
type="rect",
xref=xref,
yref=yref,
x0=self_col.wrapper.index[0],
y0=st.norm.ppf(1 - alpha / 2),
x1=self_col.wrapper.index[-1],
y1=st.norm.ppf(alpha / 2),
fillcolor="mediumslateblue",
opacity=0.2,
layer="below",
line_width=0,
),
add_shape_kwargs,
)
fig.add_shape(**add_shape_kwargs)
return fig
setattr(OLS, "__doc__", _OLS.__doc__)
setattr(OLS, "plot", _OLS.plot)
setattr(OLS, "plot_zscore", _OLS.plot_zscore)
OLS.fix_docstrings(__pdoc__)
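# A standalone sketch (synthetic data): regress one random walk on another over a rolling
# window; `pred` is the fitted value and `zscore` standardizes how far `y` currently sits
# from that fit, which is the quantity typically thresholded in pairs-trading setups.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    x = pd.Series(rng.normal(size=300).cumsum())
    y = 0.5 * x + pd.Series(rng.normal(scale=0.1, size=300)).cumsum()
    ols = vbt.OLS.run(x, y, window=50)
    stretched = ols.zscore.abs() > 2  # plain pandas comparison on the z-score output
    print(float(ols.slope.iloc[-1]), int(stretched.sum()))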
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `PATSIM`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"PATSIM",
]
__pdoc__ = {}
PATSIM = IndicatorFactory(
class_name="PATSIM",
module_name=__name__,
short_name="patsim",
input_names=["close"],
param_names=[
"pattern",
"window",
"max_window",
"row_select_prob",
"window_select_prob",
"interp_mode",
"rescale_mode",
"vmin",
"vmax",
"pmin",
"pmax",
"invert",
"error_type",
"distance_measure",
"max_error",
"max_error_interp_mode",
"max_error_as_maxdist",
"max_error_strict",
"min_pct_change",
"max_pct_change",
"min_similarity",
],
output_names=["similarity"],
).with_apply_func(
generic_nb.rolling_pattern_similarity_nb,
param_settings=dict(
pattern=dict(is_array_like=True, min_one_dim=True),
interp_mode=dict(
dtype=generic_enums.InterpMode,
post_index_func=lambda index: index.str.lower(),
),
rescale_mode=dict(
dtype=generic_enums.RescaleMode,
post_index_func=lambda index: index.str.lower(),
),
error_type=dict(
dtype=generic_enums.ErrorType,
post_index_func=lambda index: index.str.lower(),
),
distance_measure=dict(
dtype=generic_enums.DistanceMeasure,
post_index_func=lambda index: index.str.lower(),
),
max_error=dict(is_array_like=True, min_one_dim=True),
max_error_interp_mode=dict(
dtype=generic_enums.InterpMode,
post_index_func=lambda index: index.str.lower(),
),
),
window=None,
max_window=None,
row_select_prob=1.0,
window_select_prob=1.0,
interp_mode="mixed",
rescale_mode="minmax",
vmin=np.nan,
vmax=np.nan,
pmin=np.nan,
pmax=np.nan,
invert=False,
error_type="absolute",
distance_measure="mae",
max_error=np.nan,
max_error_interp_mode=None,
max_error_as_maxdist=False,
max_error_strict=False,
min_pct_change=np.nan,
max_pct_change=np.nan,
min_similarity=np.nan,
)
class _PATSIM(PATSIM):
"""Rolling pattern similarity.
Based on `vectorbtpro.generic.nb.rolling.rolling_pattern_similarity_nb`."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
similarity_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `PATSIM.similarity` against `PATSIM.close`.
Args:
column (str): Name of the column to plot.
similarity_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `PATSIM.similarity`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.PATSIM.run(ohlcv['Close'], np.array([1, 2, 3, 2, 1]), 30).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
similarity_trace_kwargs = merge_dicts(
dict(name="Similarity", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
similarity_trace_kwargs,
)
fig = self_col.similarity.vbt.lineplot(
trace_kwargs=similarity_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
default_layout = dict()
default_layout[yaxis] = dict(tickformat=",.0%")
fig.update_layout(**default_layout)
fig.update_layout(**layout_kwargs)
return fig
def overlay_with_heatmap(
self,
column: tp.Optional[tp.Label] = None,
close_trace_kwargs: tp.KwargsLike = None,
similarity_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Overlay `PATSIM.similarity` as a heatmap on top of `PATSIM.close`.
Args:
column (str): Name of the column to plot.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `PATSIM.close`.
similarity_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Heatmap` for `PATSIM.similarity`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.PATSIM.run(ohlcv['Close'], np.array([1, 2, 3, 2, 1]), 30).overlay_with_heatmap().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if similarity_trace_kwargs is None:
similarity_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
similarity_trace_kwargs = merge_dicts(
dict(
colorbar=dict(tickformat=",.0%"),
colorscale=[
[0.0, "rgba(0, 0, 0, 0)"],
[1.0, plotting_cfg["color_schema"]["lightpurple"]],
],
zmin=0,
zmax=1,
),
similarity_trace_kwargs,
)
fig = self_col.close.vbt.overlay_with_heatmap(
self_col.similarity,
trace_kwargs=close_trace_kwargs,
heatmap_kwargs=dict(y_labels=["Similarity"], trace_kwargs=similarity_trace_kwargs),
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**layout_kwargs,
)
return fig
setattr(PATSIM, "__doc__", _PATSIM.__doc__)
setattr(PATSIM, "plot", _PATSIM.plot)
setattr(PATSIM, "overlay_with_heatmap", _PATSIM.overlay_with_heatmap)
PATSIM.fix_docstrings(__pdoc__)
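# A standalone sketch (synthetic data): score how similar each rolling 30-bar window is to
# a triangular "up-then-down" pattern; the output is a similarity score (1 means a perfect
# fit), so a plain pandas threshold yields candidate matches.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=400).cumsum())
    patsim = vbt.PATSIM.run(close, pattern=np.array([1, 2, 3, 2, 1]), window=30)
    matches = patsim.similarity > 0.85
    print(int(matches.sum()), "windows with similarity above 0.85")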
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `PIVOTINFO`."""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.enums import Pivot, TrendMode
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"PIVOTINFO",
]
__pdoc__ = {}
PIVOTINFO = IndicatorFactory(
class_name="PIVOTINFO",
module_name=__name__,
short_name="pivotinfo",
input_names=["high", "low"],
param_names=["up_th", "down_th"],
output_names=["conf_pivot", "conf_idx", "last_pivot", "last_idx"],
lazy_outputs=dict(
conf_value=lambda self: self.wrapper.wrap(
nb.pivot_value_nb(
to_2d_array(self.high),
to_2d_array(self.low),
to_2d_array(self.conf_pivot),
to_2d_array(self.conf_idx),
)
),
last_value=lambda self: self.wrapper.wrap(
nb.pivot_value_nb(
to_2d_array(self.high),
to_2d_array(self.low),
to_2d_array(self.last_pivot),
to_2d_array(self.last_idx),
)
),
pivots=lambda self: self.wrapper.wrap(
nb.pivots_nb(
to_2d_array(self.conf_pivot),
to_2d_array(self.conf_idx),
to_2d_array(self.last_pivot),
)
),
modes=lambda self: self.wrapper.wrap(
nb.modes_nb(
to_2d_array(self.pivots),
)
),
),
attr_settings=dict(
conf_pivot=dict(dtype=Pivot, enum_unkval=0),
last_pivot=dict(dtype=Pivot, enum_unkval=0),
pivots=dict(dtype=Pivot, enum_unkval=0),
modes=dict(dtype=TrendMode, enum_unkval=0),
),
).with_apply_func(
nb.pivot_info_nb,
param_settings=dict(
up_th=flex_elem_param_config,
down_th=flex_elem_param_config,
),
)
class _PIVOTINFO(PIVOTINFO):
"""Indicator that returns various information on pivots identified based on thresholds.
* `conf_pivot` (`vectorbtpro.indicators.enums.Pivot`): the type of the latest confirmed pivot (running)
* `conf_idx`: the index of the latest confirmed pivot (running)
* `conf_value`: the high/low value under the latest confirmed pivot (running)
* `last_pivot` (`vectorbtpro.indicators.enums.Pivot`): the type of the latest pivot (running)
* `last_idx`: the index of the latest pivot (running)
* `last_value`: the high/low value under the latest pivot (running)
* `pivots` (`vectorbtpro.indicators.enums.Pivot`): confirmed pivots stored under their indices
(looking ahead - use only for plotting!)
* `modes` (`vectorbtpro.indicators.enums.TrendMode`): modes between confirmed pivot points
(looking ahead - use only for plotting!)
"""
def plot(
self,
column: tp.Optional[tp.Label] = None,
conf_value_trace_kwargs: tp.KwargsLike = None,
last_value_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `PIVOTINFO.conf_value` and `PIVOTINFO.last_value`.
Args:
column (str): Name of the column to plot.
conf_value_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `PIVOTINFO.conf_value` line.
last_value_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `PIVOTINFO.last_value` line.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> fig = ohlcv.vbt.ohlcv.plot()
>>> vbt.PIVOTINFO.run(ohlcv['High'], ohlcv['Low'], 0.1, 0.1).plot(fig=fig).show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if conf_value_trace_kwargs is None:
conf_value_trace_kwargs = {}
if last_value_trace_kwargs is None:
last_value_trace_kwargs = {}
conf_value_trace_kwargs = merge_dicts(
dict(name="Confirmed value", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
conf_value_trace_kwargs,
)
last_value_trace_kwargs = merge_dicts(
dict(name="Last value", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
last_value_trace_kwargs,
)
fig = self_col.conf_value.vbt.lineplot(
trace_kwargs=conf_value_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.last_value.vbt.lineplot(
trace_kwargs=last_value_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
def plot_zigzag(
self,
column: tp.Optional[tp.Label] = None,
zigzag_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot zig-zag line.
Args:
column (str): Name of the column to plot.
zigzag_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for zig-zag line.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> fig = ohlcv.vbt.ohlcv.plot()
>>> vbt.PIVOTINFO.run(ohlcv['High'], ohlcv['Low'], 0.1, 0.1).plot_zigzag(fig=fig).show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if zigzag_trace_kwargs is None:
zigzag_trace_kwargs = {}
zigzag_trace_kwargs = merge_dicts(
dict(name="ZigZag", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
zigzag_trace_kwargs,
)
pivots = self_col.pivots
highs = self_col.high[pivots == Pivot.Peak]
lows = self_col.low[pivots == Pivot.Valley]
fig = (
pd.concat((highs, lows))
.sort_index()
.vbt.lineplot(
trace_kwargs=zigzag_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
)
return fig
setattr(PIVOTINFO, "__doc__", _PIVOTINFO.__doc__)
setattr(PIVOTINFO, "plot", _PIVOTINFO.plot)
setattr(PIVOTINFO, "plot_zigzag", _PIVOTINFO.plot_zigzag)
PIVOTINFO.fix_docstrings(__pdoc__)
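# A standalone sketch (synthetic data): the `conf_*` and `last_*` outputs are computed in a
# running fashion and are safe for signals, while `pivots` and `modes` look ahead and, as
# the class docstring warns, should only be used for plotting. Comparing against the
# `Pivot` enum keeps such checks readable.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=300).cumsum())
    high, low = close + rng.uniform(0, 1, size=300), close - rng.uniform(0, 1, size=300)
    pivotinfo = vbt.PIVOTINFO.run(high, low, up_th=0.05, down_th=0.05)
    n_peaks = (pivotinfo.pivots == Pivot.Peak).sum()  # plotting-only output
    print(int(n_peaks), "confirmed peaks")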
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RSI`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"RSI",
]
__pdoc__ = {}
RSI = IndicatorFactory(
class_name="RSI",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype"],
output_names=["rsi"],
).with_apply_func(
nb.rsi_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="wilder",
minp=None,
adjust=False,
)
class _RSI(RSI):
"""Relative Strength Index (RSI).
Compares the magnitude of recent gains and losses over a specified time
period to measure speed and change of price movements of a security. It is
primarily used to attempt to identify overbought or oversold conditions in
the trading of an asset.
See [Relative Strength Index (RSI)](https://www.investopedia.com/terms/r/rsi.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
limits: tp.Tuple[float, float] = (30, 70),
rsi_trace_kwargs: tp.KwargsLike = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `RSI.rsi`.
Args:
column (str): Name of the column to plot.
limits (tuple of float): Tuple of the lower and upper limit.
rsi_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `RSI.rsi`.
add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding the range between both limits.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.RSI.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if rsi_trace_kwargs is None:
rsi_trace_kwargs = {}
rsi_trace_kwargs = merge_dicts(
dict(name="RSI", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
rsi_trace_kwargs,
)
fig = self_col.rsi.vbt.lineplot(
trace_kwargs=rsi_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
default_layout = dict()
default_layout[yaxis] = dict(range=[-5, 105])
fig.update_layout(**default_layout)
fig.update_layout(**layout_kwargs)
# Fill void between limits
add_shape_kwargs = merge_dicts(
dict(
type="rect",
xref=xref,
yref=yref,
x0=self_col.wrapper.index[0],
y0=limits[0],
x1=self_col.wrapper.index[-1],
y1=limits[1],
fillcolor="mediumslateblue",
opacity=0.2,
layer="below",
line_width=0,
),
add_shape_kwargs,
)
fig.add_shape(**add_shape_kwargs)
return fig
setattr(RSI, "__doc__", _RSI.__doc__)
setattr(RSI, "plot", _RSI.plot)
RSI.fix_docstrings(__pdoc__)
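# A standalone sketch (synthetic data): the classic reading compares RSI against fixed
# limits; dipping below 30 and recovering above it is often interpreted as a long setup.
# The comparisons below are plain pandas on the `rsi` output.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=300).cumsum())
    rsi = vbt.RSI.run(close)  # window=14 with Wilder's smoothing by default
    oversold = rsi.rsi < 30
    recovering = oversold.shift(1, fill_value=False) & (rsi.rsi >= 30)
    print(int(recovering.sum()), "exits from the oversold zone")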
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SIGDET`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"SIGDET",
]
__pdoc__ = {}
SIGDET = IndicatorFactory(
class_name="SIGDET",
module_name=__name__,
short_name="sigdet",
input_names=["close"],
param_names=["lag", "factor", "influence", "up_factor", "down_factor", "mean_influence", "std_influence"],
output_names=["signal", "upper_band", "lower_band"],
).with_apply_func(
nb.signal_detection_nb,
param_settings=dict(
factor=flex_elem_param_config,
influence=flex_elem_param_config,
up_factor=flex_elem_param_config,
down_factor=flex_elem_param_config,
mean_influence=flex_elem_param_config,
std_influence=flex_elem_param_config,
),
lag=14,
factor=1.0,
influence=1.0,
up_factor=None,
down_factor=None,
mean_influence=None,
std_influence=None,
)
class _SIGDET(SIGDET):
"""Robust peak detection algorithm (using z-scores).
See https://stackoverflow.com/a/22640362"""
def plot(
self,
column: tp.Optional[tp.Label] = None,
signal_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `SIGDET.signal` against `SIGDET.close`.
Args:
column (str): Name of the column to plot.
signal_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SIGDET.signal`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.SIGDET.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
signal_trace_kwargs = merge_dicts(
dict(name="Signal", line=dict(color=plotting_cfg["color_schema"]["lightblue"], shape="hv")),
signal_trace_kwargs,
)
fig = self_col.signal.vbt.lineplot(
trace_kwargs=signal_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**layout_kwargs,
)
return fig
def plot_bands(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
upper_band_trace_kwargs: tp.KwargsLike = None,
lower_band_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `SIGDET.upper_band` and `SIGDET.lower_band` against `SIGDET.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `SIGDET.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SIGDET.close`.
upper_band_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SIGDET.upper_band`.
lower_band_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SIGDET.lower_band`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.SIGDET.run(ohlcv['Close']).plot_bands().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if upper_band_trace_kwargs is None:
upper_band_trace_kwargs = {}
if lower_band_trace_kwargs is None:
lower_band_trace_kwargs = {}
lower_band_trace_kwargs = merge_dicts(
dict(
name="Lower band",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5)),
),
lower_band_trace_kwargs,
)
upper_band_trace_kwargs = merge_dicts(
dict(
name="Upper band",
line=dict(color=adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.5)),
fill="tonexty",
fillcolor="rgba(128, 128, 128, 0.2)",
),
upper_band_trace_kwargs,
) # default kwargs
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
fig = self_col.lower_band.vbt.lineplot(
trace_kwargs=lower_band_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.upper_band.vbt.lineplot(
trace_kwargs=upper_band_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(SIGDET, "__doc__", _SIGDET.__doc__)
setattr(SIGDET, "plot", _SIGDET.plot)
setattr(SIGDET, "plot_bands", _SIGDET.plot_bands)
SIGDET.fix_docstrings(__pdoc__)
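# A standalone sketch (synthetic data): `signal` flips to +1 or -1 while the input pokes
# above or below the adaptive bands and stays 0 otherwise, so counting nonzero bars gives
# a rough measure of how often the series leaves its recent "normal" range.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=300).cumsum())
    sigdet = vbt.SIGDET.run(close, lag=20, factor=2.0, influence=0.5)
    print(int((sigdet.signal != 0).sum()), "bars flagged as peaks or troughs")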
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `STOCH`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"STOCH",
]
__pdoc__ = {}
STOCH = IndicatorFactory(
class_name="STOCH",
module_name=__name__,
input_names=["high", "low", "close"],
param_names=["fast_k_window", "slow_k_window", "slow_d_window", "wtype", "slow_k_wtype", "slow_d_wtype"],
output_names=["fast_k", "slow_k", "slow_d"],
).with_apply_func(
nb.stoch_nb,
kwargs_as_args=["minp", "fast_k_minp", "slow_k_minp", "slow_d_minp", "adjust", "slow_k_adjust", "slow_d_adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
slow_k_wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
slow_d_wtype=dict(
dtype=generic_enums.WType,
dtype_kwargs=dict(enum_unkval=None),
post_index_func=lambda index: index.str.lower(),
),
),
fast_k_window=14,
slow_k_window=3,
slow_d_window=3,
wtype="simple",
slow_k_wtype=None,
slow_d_wtype=None,
minp=None,
fast_k_minp=None,
slow_k_minp=None,
slow_d_minp=None,
adjust=False,
slow_k_adjust=None,
slow_d_adjust=None,
)
class _STOCH(STOCH):
"""Stochastic Oscillator (STOCH).
A stochastic oscillator is a momentum indicator comparing a particular closing price
of a security to a range of its prices over a certain period of time. It is used to
generate overbought and oversold trading signals, utilizing a 0-100 bounded range of values.
See [Stochastic Oscillator](https://www.investopedia.com/terms/s/stochasticoscillator.asp)."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
limits: tp.Tuple[float, float] = (20, 80),
fast_k_trace_kwargs: tp.KwargsLike = None,
slow_k_trace_kwargs: tp.KwargsLike = None,
slow_d_trace_kwargs: tp.KwargsLike = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `STOCH.slow_k` and `STOCH.slow_d`.
Args:
column (str): Name of the column to plot.
limits (tuple of float): Tuple of the lower and upper limit.
fast_k_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `STOCH.fast_k`.
slow_k_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `STOCH.slow_k`.
slow_d_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `STOCH.slow_d`.
add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding the range between both limits.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.STOCH.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fast_k_trace_kwargs is None:
fast_k_trace_kwargs = {}
if slow_k_trace_kwargs is None:
slow_k_trace_kwargs = {}
if slow_d_trace_kwargs is None:
slow_d_trace_kwargs = {}
fast_k_trace_kwargs = merge_dicts(
dict(name="Fast %K", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
fast_k_trace_kwargs,
)
slow_k_trace_kwargs = merge_dicts(
dict(name="Slow %K", line=dict(color=plotting_cfg["color_schema"]["lightpurple"])),
slow_k_trace_kwargs,
)
slow_d_trace_kwargs = merge_dicts(
dict(name="Slow %D", line=dict(color=plotting_cfg["color_schema"]["lightpink"])),
slow_d_trace_kwargs,
)
fig = self_col.fast_k.vbt.lineplot(
trace_kwargs=fast_k_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.slow_k.vbt.lineplot(
trace_kwargs=slow_k_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.slow_d.vbt.lineplot(
trace_kwargs=slow_d_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
default_layout = dict()
default_layout[yaxis] = dict(range=[-5, 105])
fig.update_layout(**default_layout)
fig.update_layout(**layout_kwargs)
# Fill void between limits
add_shape_kwargs = merge_dicts(
dict(
type="rect",
xref=xref,
yref=yref,
x0=self_col.wrapper.index[0],
y0=limits[0],
x1=self_col.wrapper.index[-1],
y1=limits[1],
fillcolor="mediumslateblue",
opacity=0.2,
layer="below",
line_width=0,
),
add_shape_kwargs,
)
fig.add_shape(**add_shape_kwargs)
return fig
setattr(STOCH, "__doc__", _STOCH.__doc__)
setattr(STOCH, "plot", _STOCH.plot)
STOCH.fix_docstrings(__pdoc__)
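# A standalone sketch (synthetic data): a common reading is a %K/%D crossover inside the
# oversold region; everything below is plain pandas on the indicator outputs.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=300).cumsum())
    high, low = close + rng.uniform(0, 1, size=300), close - rng.uniform(0, 1, size=300)
    stoch = vbt.STOCH.run(high, low, close)
    cross_up = (stoch.slow_k > stoch.slow_d) & (stoch.slow_k.shift(1) <= stoch.slow_d.shift(1))
    print(int((cross_up & (stoch.slow_k < 20)).sum()), "bullish crossovers below 20")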
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `SUPERTREND`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"SUPERTREND",
]
__pdoc__ = {}
SUPERTREND = IndicatorFactory(
class_name="SUPERTREND",
module_name=__name__,
short_name="supertrend",
input_names=["high", "low", "close"],
param_names=["period", "multiplier"],
output_names=["trend", "direction", "long", "short"],
).with_apply_func(nb.supertrend_nb, period=7, multiplier=3)
class _SUPERTREND(SUPERTREND):
"""Supertrend indicator."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
superl_trace_kwargs: tp.KwargsLike = None,
supers_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `SUPERTREND.long` and `SUPERTREND.short` against `SUPERTREND.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `SUPERTREND.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SUPERTREND.close`.
superl_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SUPERTREND.long`.
supers_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `SUPERTREND.short`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.SUPERTREND.run(ohlcv['High'], ohlcv['Low'], ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if superl_trace_kwargs is None:
superl_trace_kwargs = {}
if supers_trace_kwargs is None:
supers_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
superl_trace_kwargs = merge_dicts(
dict(name="Long", line=dict(color=plotting_cfg["color_schema"]["green"])),
superl_trace_kwargs,
)
supers_trace_kwargs = merge_dicts(
dict(name="Short", line=dict(color=plotting_cfg["color_schema"]["red"])),
supers_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.long.vbt.lineplot(
trace_kwargs=superl_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.short.vbt.lineplot(
trace_kwargs=supers_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(SUPERTREND, "__doc__", _SUPERTREND.__doc__)
setattr(SUPERTREND, "plot", _SUPERTREND.plot)
SUPERTREND.fix_docstrings(__pdoc__)
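# A standalone sketch (synthetic data; assumes `long` holds the trailing level only during
# uptrends and `short` only during downtrends, being NaN otherwise): under that assumption,
# a simple regime filter is a one-liner.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    close = pd.Series(100 + rng.normal(size=300).cumsum())
    high, low = close + rng.uniform(0, 1, size=300), close - rng.uniform(0, 1, size=300)
    st = vbt.SUPERTREND.run(high, low, close, period=10, multiplier=3)
    uptrend_share = st.long.notna().mean()  # fraction of bars spent in an uptrend
    print(f"{uptrend_share:.1%} of bars in an uptrend")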
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `VWAP`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.indicators import nb
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.template import RepFunc
__all__ = [
"VWAP",
]
__pdoc__ = {}
def substitute_anchor(wrapper: ArrayWrapper, anchor: tp.Optional[tp.FrequencyLike]) -> tp.Array1d:
"""Substitute reset frequency by group lens."""
if anchor is None:
return np.array([wrapper.shape[0]])
return wrapper.get_index_grouper(anchor).get_group_lens()
VWAP = IndicatorFactory(
class_name="VWAP",
module_name=__name__,
short_name="vwap",
input_names=["high", "low", "close", "volume"],
param_names=["anchor"],
output_names=["vwap"],
).with_apply_func(
nb.vwap_nb,
param_settings=dict(
anchor=dict(template=RepFunc(substitute_anchor)),
),
anchor="D",
)
class _VWAP(VWAP):
"""Volume-Weighted Average Price (VWAP).
VWAP is a technical analysis indicator used on intraday charts that resets at the start
of every new trading session.
See [Volume-Weighted Average Price (VWAP)](https://www.investopedia.com/terms/v/vwap.asp).
Anchor can be any index grouper."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
vwap_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `VWAP.vwap` against `VWAP.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `VWAP.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `VWAP.close`.
vwap_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `VWAP.vwap`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.VWAP.run(
... ohlcv['High'],
... ohlcv['Low'],
... ohlcv['Close'],
... ohlcv['Volume'],
... anchor="W"
... ).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if vwap_trace_kwargs is None:
vwap_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
vwap_trace_kwargs = merge_dicts(
dict(name="VWAP", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
vwap_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.vwap.vbt.lineplot(
trace_kwargs=vwap_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(VWAP, "__doc__", _VWAP.__doc__)
setattr(VWAP, "plot", _VWAP.plot)
VWAP.fix_docstrings(__pdoc__)
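# A standalone sketch (synthetic hourly data): `anchor` can be any index grouper, so "D"
# resets the cumulative average each session while "W" would reset it each week; see
# `substitute_anchor` above, which turns the anchor into group lengths.
if __name__ == "__main__":
    import numpy as np
    import pandas as pd
    import vectorbtpro as vbt

    rng = np.random.default_rng(0)
    index = pd.date_range("2024-01-01", periods=3 * 24, freq="h")  # three days, hourly
    close = pd.Series(100 + rng.normal(size=len(index)).cumsum(), index=index)
    high, low = close + 0.5, close - 0.5
    volume = pd.Series(rng.integers(1_000, 5_000, size=len(index)).astype(float), index=index)
    vwap = vbt.VWAP.run(high, low, close, volume, anchor="D")
    print(vwap.vwap.groupby(vwap.vwap.index.date).head(1))  # first value of each session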
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for building and running indicators.
Technical indicators are used to analyze past trends and anticipate future moves.
See [Using Technical Indicators to Develop Trading Strategies](https://www.investopedia.com/articles/trading/11/indicators-and-strategies-explained.asp)."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.indicators.configs import *
from vectorbtpro.indicators.custom import *
from vectorbtpro.indicators.expr import *
from vectorbtpro.indicators.factory import *
from vectorbtpro.indicators.nb import *
from vectorbtpro.indicators.talib_ import *
__exclude_from__all__ = [
"enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Configs for custom indicators."""
from vectorbtpro.utils.config import ReadonlyConfig
__all__ = [
"flex_col_param_config",
"flex_elem_param_config",
]
flex_elem_param_config = ReadonlyConfig(
dict(
is_array_like=True,
bc_to_input=True,
broadcast_kwargs=dict(keep_flex=True, min_ndim=2),
)
)
"""Config for flexible element-wise parameters."""
flex_row_param_config = ReadonlyConfig(
dict(
is_array_like=True,
bc_to_input=0,
broadcast_kwargs=dict(keep_flex=True, min_ndim=1),
)
)
"""Config for flexible row-wise parameters."""
flex_col_param_config = ReadonlyConfig(
dict(
is_array_like=True,
bc_to_input=1,
per_column=True,
broadcast_kwargs=dict(keep_flex=True, min_ndim=1),
)
)
"""Config for flexible column-wise parameters."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for indicators."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = [
"Pivot",
"TrendMode",
"HurstMethod",
"SuperTrendAIS",
"SuperTrendAOS",
]
__pdoc__ = {}
# ############# Enums ############# #
class PivotT(tp.NamedTuple):
Valley: int = -1
Peak: int = 1
Pivot = PivotT()
"""_"""
__pdoc__[
"Pivot"
] = f"""Pivot.
```python
{prettify(Pivot)}
```
"""
class TrendModeT(tp.NamedTuple):
Downtrend: int = -1
Uptrend: int = 1
TrendMode = TrendModeT()
"""_"""
__pdoc__[
"TrendMode"
] = f"""Trend mode.
```python
{prettify(TrendMode)}
```
"""
class HurstMethodT(tp.NamedTuple):
Standard: int = 0
LogRS: int = 1
RS: int = 2
DMA: int = 3
DSOD: int = 4
HurstMethod = HurstMethodT()
"""_"""
__pdoc__[
"HurstMethod"
] = f"""Hurst method.
```python
{prettify(HurstMethod)}
```
"""
# ############# States ############# #
class SuperTrendAIS(tp.NamedTuple):
i: int
high: float
low: float
close: float
prev_close: float
prev_upper: float
prev_lower: float
prev_direction: int
nobs: int
weighted_avg: float
old_wt: float
period: int
multiplier: float
__pdoc__[
"SuperTrendAIS"
] = """A named tuple representing the input state of
`vectorbtpro.indicators.nb.supertrend_acc_nb`."""
class SuperTrendAOS(tp.NamedTuple):
nobs: int
weighted_avg: float
old_wt: float
upper: float
lower: float
trend: float
direction: int
long: float
short: float
__pdoc__[
"SuperTrendAOS"
] = """A named tuple representing the output state of
`vectorbtpro.indicators.nb.supertrend_acc_nb`."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions and config for evaluating indicator expressions."""
import math
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.generic.nb import (
fshift_nb,
diff_nb,
rank_nb,
rolling_sum_nb,
rolling_mean_nb,
rolling_std_nb,
wm_mean_nb,
rolling_rank_nb,
rolling_prod_nb,
rolling_min_nb,
rolling_max_nb,
rolling_argmin_nb,
rolling_argmax_nb,
rolling_cov_nb,
rolling_corr_nb,
demean_nb,
)
from vectorbtpro.indicators.nb import vwap_nb
from vectorbtpro.returns.nb import returns_nb
from vectorbtpro.utils.config import HybridConfig
__all__ = []
# ############# Delay ############# #
def delay(x: tp.Array2d, d: float) -> tp.Array2d:
"""Value of `x` `d` days ago."""
return fshift_nb(x, math.floor(d))
def delta(x: tp.Array2d, d: float) -> tp.Array2d:
"""Today’s value of `x` minus the value of `x` `d` days ago."""
return diff_nb(x, math.floor(d))
# ############# Cross-section ############# #
def cs_rescale(x: tp.Array2d) -> tp.Array2d:
"""Rescale `x` such that `sum(abs(x)) = 1`."""
return (x.T / np.abs(x).sum(axis=1)).T
def cs_rank(x: tp.Array2d) -> tp.Array2d:
"""Rank cross-sectionally."""
return rank_nb(x.T, pct=True).T
def cs_demean(x: tp.Array2d, g: tp.GroupByLike, context: tp.KwargsLike = None) -> tp.Array2d:
"""Demean `x` against groups `g` cross-sectionally."""
group_map = Grouper(context["wrapper"].columns, g).get_group_map()
return demean_nb(x, group_map)
# ############# Rolling ############# #
def ts_min(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling min."""
return rolling_min_nb(x, math.floor(d))
def ts_max(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling max."""
return rolling_max_nb(x, math.floor(d))
def ts_argmin(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling argmin."""
argmin = rolling_argmin_nb(x, math.floor(d), local=True)
if -1 in argmin:
argmin = np.where(argmin != -1, argmin, np.nan)
return np.add(argmin, 1)
def ts_argmax(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling argmax."""
argmax = rolling_argmax_nb(x, math.floor(d), local=True)
if -1 in argmax:
argmax = np.where(argmax != -1, argmax, np.nan)
return np.add(argmax, 1)
def ts_rank(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling rank."""
return rolling_rank_nb(x, math.floor(d), pct=True)
def ts_sum(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling sum."""
return rolling_sum_nb(x, math.floor(d))
def ts_product(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling product."""
return rolling_prod_nb(x, math.floor(d))
def ts_mean(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling mean."""
return rolling_mean_nb(x, math.floor(d))
def ts_wmean(x: tp.Array2d, d: float) -> tp.Array2d:
"""Weighted moving average over the past `d` days with linearly decaying weight."""
return wm_mean_nb(x, math.floor(d))
def ts_std(x: tp.Array2d, d: float) -> tp.Array2d:
"""Return the rolling standard deviation."""
return rolling_std_nb(x, math.floor(d))
def ts_corr(x: tp.Array2d, y: tp.Array2d, d: float) -> tp.Array2d:
"""Time-serial correlation of `x` and `y` for the past `d` days."""
return rolling_corr_nb(x, y, math.floor(d))
def ts_cov(x: tp.Array2d, y: tp.Array2d, d: float) -> tp.Array2d:
"""Time-serial covariance of `x` and `y` for the past `d` days."""
return rolling_cov_nb(x, y, math.floor(d))
def adv(d: float, context: tp.KwargsLike = None) -> tp.Array2d:
"""Average daily dollar volume for the past `d` days."""
return ts_mean(context["volume"], math.floor(d))
# ############# Substitutions ############# #
def returns(context: tp.KwargsLike = None) -> tp.Array2d:
"""Daily close-to-close returns."""
return returns_nb(context["close"])
def vwap(context: tp.KwargsLike = None) -> tp.Array2d:
"""VWAP."""
if isinstance(context["wrapper"].index, pd.DatetimeIndex):
group_lens = context["wrapper"].get_index_grouper("D").get_group_lens()
else:
group_lens = np.array([context["wrapper"].shape[0]])
return vwap_nb(context["high"], context["low"], context["close"], context["volume"], group_lens)
def cap(context: tp.KwargsLike = None) -> tp.Array2d:
"""Market capitalization."""
return context["close"] * context["volume"]
# ############# Configs ############# #
__pdoc__ = {}
expr_func_config = HybridConfig(
dict(
delay=dict(func=delay),
delta=dict(func=delta),
cs_rescale=dict(func=cs_rescale),
cs_rank=dict(func=cs_rank),
cs_demean=dict(func=cs_demean),
ts_min=dict(func=ts_min),
ts_max=dict(func=ts_max),
ts_argmin=dict(func=ts_argmin),
ts_argmax=dict(func=ts_argmax),
ts_rank=dict(func=ts_rank),
ts_sum=dict(func=ts_sum),
ts_product=dict(func=ts_product),
ts_mean=dict(func=ts_mean),
ts_wmean=dict(func=ts_wmean),
ts_std=dict(func=ts_std),
ts_corr=dict(func=ts_corr),
ts_cov=dict(func=ts_cov),
adv=dict(func=adv, magnet_inputs=["volume"]),
)
)
"""_"""
__pdoc__[
"expr_func_config"
] = f"""Config for functions used in indicator expressions.
Can be modified.
```python
{expr_func_config.prettify()}
```
"""
expr_res_func_config = HybridConfig(
dict(
returns=dict(func=returns, magnet_inputs=["close"]),
vwap=dict(func=vwap, magnet_inputs=["high", "low", "close", "volume"]),
cap=dict(func=cap, magnet_inputs=["close", "volume"]),
)
)
"""_"""
__pdoc__[
"expr_res_func_config"
] = f"""Config for resolvable functions used in indicator expressions.
Can be modified.
```python
{expr_res_func_config.prettify()}
```
"""
wqa101_expr_config = HybridConfig(
{
1: "cs_rank(ts_argmax(power(where(returns < 0, ts_std(returns, 20), close), 2.), 5)) - 0.5",
2: "-ts_corr(cs_rank(delta(log(volume), 2)), cs_rank((close - open) / open), 6)",
3: "-ts_corr(cs_rank(open), cs_rank(volume), 10)",
4: "-ts_rank(cs_rank(low), 9)",
5: "cs_rank(open - (ts_sum(vwap, 10) / 10)) * (-abs(cs_rank(close - vwap)))",
6: "-ts_corr(open, volume, 10)",
7: "where(adv(20) < volume, (-ts_rank(abs(delta(close, 7)), 60)) * sign(delta(close, 7)), -1)",
8: "-cs_rank((ts_sum(open, 5) * ts_sum(returns, 5)) - delay(ts_sum(open, 5) * ts_sum(returns, 5), 10))",
9: (
"where(0 < ts_min(delta(close, 1), 5), delta(close, 1), where(ts_max(delta(close, 1), 5) < 0, delta(close,"
" 1), -delta(close, 1)))"
),
10: (
"cs_rank(where(0 < ts_min(delta(close, 1), 4), delta(close, 1), where(ts_max(delta(close, 1), 4) < 0,"
" delta(close, 1), -delta(close, 1))))"
),
11: "(cs_rank(ts_max(vwap - close, 3)) + cs_rank(ts_min(vwap - close, 3))) * cs_rank(delta(volume, 3))",
12: "sign(delta(volume, 1)) * (-delta(close, 1))",
13: "-cs_rank(ts_cov(cs_rank(close), cs_rank(volume), 5))",
14: "(-cs_rank(delta(returns, 3))) * ts_corr(open, volume, 10)",
15: "-ts_sum(cs_rank(ts_corr(cs_rank(high), cs_rank(volume), 3)), 3)",
16: "-cs_rank(ts_cov(cs_rank(high), cs_rank(volume), 5))",
17: (
"((-cs_rank(ts_rank(close, 10))) * cs_rank(delta(delta(close, 1), 1))) * cs_rank(ts_rank(volume /"
" adv(20), 5))"
),
18: "-cs_rank((ts_std(abs(close - open), 5) + (close - open)) + ts_corr(close, open, 10))",
19: "(-sign((close - delay(close, 7)) + delta(close, 7))) * (1 + cs_rank(1 + ts_sum(returns, 250)))",
20: "((-cs_rank(open - delay(high, 1))) * cs_rank(open - delay(close, 1))) * cs_rank(open - delay(low, 1))",
21: (
"where(((ts_sum(close, 8) / 8) + ts_std(close, 8)) < (ts_sum(close, 2) / 2), -1, where((ts_sum(close, 2) /"
" 2) < ((ts_sum(close, 8) / 8) - ts_std(close, 8)), 1, where(volume / adv(20) >= 1, 1, -1)))"
),
22: "-(delta(ts_corr(high, volume, 5), 5) * cs_rank(ts_std(close, 20)))",
23: "where((ts_sum(high, 20) / 20) < high, -delta(high, 2), 0)",
24: (
"where((delta(ts_sum(close, 100) / 100, 100) / delay(close, 100)) <= 0.05, (-(close - ts_min(close, 100))),"
" -delta(close, 3))"
),
25: "cs_rank((((-returns) * adv(20)) * vwap) * (high - close))",
26: "-ts_max(ts_corr(ts_rank(volume, 5), ts_rank(high, 5), 5), 3)",
27: "where(0.5 < cs_rank(ts_sum(ts_corr(cs_rank(volume), cs_rank(vwap), 6), 2) / 2.0), -1, 1)",
28: "cs_rescale((ts_corr(adv(20), low, 5) + ((high + low) / 2)) - close)",
29: (
"ts_min(ts_product(cs_rank(cs_rank(cs_rescale(log(ts_sum(ts_min(cs_rank(cs_rank(-cs_rank(delta(close - 1,"
" 5)))), 2), 1))))), 1), 5) + ts_rank(delay(-returns, 6), 5)"
),
30: (
"((1.0 - cs_rank((sign(close - delay(close, 1)) + sign(delay(close, 1) - delay(close, 2))) +"
" sign(delay(close, 2) - delay(close, 3)))) * ts_sum(volume, 5)) / ts_sum(volume, 20)"
),
31: (
"(cs_rank(cs_rank(cs_rank(ts_wmean(-cs_rank(cs_rank(delta(close, 10))), 10)))) + cs_rank(-delta(close, 3)))"
" + sign(cs_rescale(ts_corr(adv(20), low, 12)))"
),
32: "cs_rescale((ts_sum(close, 7) / 7) - close) + (20 * cs_rescale(ts_corr(vwap, delay(close, 5), 230)))",
33: "cs_rank(-(1 - (open / close)))",
34: "cs_rank((1 - cs_rank(ts_std(returns, 2) / ts_std(returns, 5))) + (1 - cs_rank(delta(close, 1))))",
35: "(ts_rank(volume, 32) * (1 - ts_rank((close + high) - low, 16))) * (1 - ts_rank(returns, 32))",
36: (
"((((2.21 * cs_rank(ts_corr(close - open, delay(volume, 1), 15))) + (0.7 * cs_rank(open - close))) + (0.73"
" * cs_rank(ts_rank(delay(-returns, 6), 5)))) + cs_rank(abs(ts_corr(vwap, adv(20), 6)))) + (0.6 *"
" cs_rank(((ts_sum(close, 200) / 200) - open) * (close - open)))"
),
37: "cs_rank(ts_corr(delay(open - close, 1), close, 200)) + cs_rank(open - close)",
38: "(-cs_rank(ts_rank(close, 10))) * cs_rank(close / open)",
39: (
"(-cs_rank(delta(close, 7) * (1 - cs_rank(ts_wmean(volume / adv(20), 9))))) * (1 + cs_rank(ts_sum(returns,"
" 250)))"
),
40: "(-cs_rank(ts_std(high, 10))) * ts_corr(high, volume, 10)",
41: "((high * low) ** 0.5) - vwap",
42: "cs_rank(vwap - close) / cs_rank(vwap + close)",
43: "ts_rank(volume / adv(20), 20) * ts_rank(-delta(close, 7), 8)",
44: "-ts_corr(high, cs_rank(volume), 5)",
45: (
"-((cs_rank(ts_sum(delay(close, 5), 20) / 20) * ts_corr(close, volume, 2)) * cs_rank(ts_corr(ts_sum(close,"
" 5), ts_sum(close, 20), 2)))"
),
46: (
"where(0.25 < (((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)), -1,"
" where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < 0, 1, -(close"
" - delay(close, 1))))"
),
47: (
"(((cs_rank(1 / close) * volume) / adv(20)) * ((high * cs_rank(high - close)) / (ts_sum(high, 5) / 5))) -"
" cs_rank(vwap - delay(vwap, 5))"
),
48: (
"cs_demean((ts_corr(delta(close, 1), delta(delay(close, 1), 1), 250) * delta(close, 1)) / close,"
" 'subindustry') / ts_sum((delta(close, 1) / delay(close, 1)) ** 2, 250)"
),
49: (
"where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < (-0.1), 1,"
" -(close - delay(close, 1)))"
),
50: "-ts_max(cs_rank(ts_corr(cs_rank(volume), cs_rank(vwap), 5)), 5)",
51: (
"where((((delay(close, 20) - delay(close, 10)) / 10) - ((delay(close, 10) - close) / 10)) < (-0.05), 1,"
" -(close - delay(close, 1)))"
),
52: (
"(((-ts_min(low, 5)) + delay(ts_min(low, 5), 5)) * cs_rank((ts_sum(returns, 240) - ts_sum(returns, 20)) /"
" 220)) * ts_rank(volume, 5)"
),
53: "-delta(((close - low) - (high - close)) / (close - low), 9)",
54: "(-((low - close) * (open ** 5))) / ((low - high) * (close ** 5))",
55: "-ts_corr(cs_rank((close - ts_min(low, 12)) / (ts_max(high, 12) - ts_min(low, 12))), cs_rank(volume), 6)",
56: "0 - (1 * (cs_rank(ts_sum(returns, 10) / ts_sum(ts_sum(returns, 2), 3)) * cs_rank(returns * cap)))",
57: "0 - (1 * ((close - vwap) / ts_wmean(cs_rank(ts_argmax(close, 30)), 2)))",
58: "-ts_rank(ts_wmean(ts_corr(cs_demean(vwap, 'sector'), volume, 3.92795), 7.89291), 5.50322)",
59: (
"-ts_rank(ts_wmean(ts_corr(cs_demean((vwap * 0.728317) + (vwap * (1 - 0.728317)), 'industry'), volume,"
" 4.25197), 16.2289), 8.19648)"
),
60: (
"0 - (1 * ((2 * cs_rescale(cs_rank((((close - low) - (high - close)) / (high - low)) * volume))) -"
" cs_rescale(cs_rank(ts_argmax(close, 10)))))"
),
61: "cs_rank(vwap - ts_min(vwap, 16.1219)) < cs_rank(ts_corr(vwap, adv(180), 17.9282))",
62: (
"(cs_rank(ts_corr(vwap, ts_sum(adv(20), 22.4101), 9.91009)) < cs_rank((cs_rank(open) + cs_rank(open)) <"
" (cs_rank((high + low) / 2) + cs_rank(high)))) * (-1)"
),
63: (
"(cs_rank(ts_wmean(delta(cs_demean(close, 'industry'), 2.25164), 8.22237)) - cs_rank(ts_wmean(ts_corr((vwap"
" * 0.318108) + (open * (1 - 0.318108)), ts_sum(adv(180), 37.2467), 13.557), 12.2883))) * (-1)"
),
64: (
"(cs_rank(ts_corr(ts_sum((open * 0.178404) + (low * (1 - 0.178404)), 12.7054), ts_sum(adv(120), 12.7054),"
" 16.6208)) < cs_rank(delta((((high + low) / 2) * 0.178404) + (vwap * (1 - 0.178404)), 3.69741))) * (-1)"
),
65: (
"(cs_rank(ts_corr((open * 0.00817205) + (vwap * (1 - 0.00817205)), ts_sum(adv(60), 8.6911), 6.40374)) <"
" cs_rank(open - ts_min(open, 13.635))) * (-1)"
),
66: (
"(cs_rank(ts_wmean(delta(vwap, 3.51013), 7.23052)) + ts_rank(ts_wmean((((low * 0.96633) + (low * (1 -"
" 0.96633))) - vwap) / (open - ((high + low) / 2)), 11.4157), 6.72611)) * (-1)"
),
67: (
"(cs_rank(high - ts_min(high, 2.14593)) ** cs_rank(ts_corr(cs_demean(vwap, 'sector'), cs_demean(adv(20),"
" 'subindustry'), 6.02936))) * (-1)"
),
68: (
"(ts_rank(ts_corr(cs_rank(high), cs_rank(adv(15)), 8.91644), 13.9333) < cs_rank(delta((close * 0.518371) +"
" (low * (1 - 0.518371)), 1.06157))) * (-1)"
),
69: (
"(cs_rank(ts_max(delta(cs_demean(vwap, 'industry'), 2.72412), 4.79344)) ** ts_rank(ts_corr((close *"
" 0.490655) + (vwap * (1 - 0.490655)), adv(20), 4.92416), 9.0615)) * (-1)"
),
70: (
"(cs_rank(delta(vwap, 1.29456)) ** ts_rank(ts_corr(cs_demean(close, 'industry'), adv(50), 17.8256),"
" 17.9171)) * (-1)"
),
71: (
"maximum(ts_rank(ts_wmean(ts_corr(ts_rank(close, 3.43976), ts_rank(adv(180), 12.0647), 18.0175), 4.20501),"
" 15.6948), ts_rank(ts_wmean(cs_rank((low + open) - (vwap + vwap)) ** 2, 16.4662), 4.4388))"
),
72: (
"cs_rank(ts_wmean(ts_corr((high + low) / 2, adv(40), 8.93345), 10.1519)) /"
" cs_rank(ts_wmean(ts_corr(ts_rank(vwap, 3.72469), ts_rank(volume, 18.5188), 6.86671), 2.95011))"
),
73: (
"maximum(cs_rank(ts_wmean(delta(vwap, 4.72775), 2.91864)), ts_rank(ts_wmean((delta((open * 0.147155) + (low"
" * (1 - 0.147155)), 2.03608) / ((open * 0.147155) + (low * (1 - 0.147155)))) * (-1), 3.33829), 16.7411)) *"
" (-1)"
),
74: (
"(cs_rank(ts_corr(close, ts_sum(adv(30), 37.4843), 15.1365)) < cs_rank(ts_corr(cs_rank((high * 0.0261661) +"
" (vwap * (1 - 0.0261661))), cs_rank(volume), 11.4791))) * (-1)"
),
75: "cs_rank(ts_corr(vwap, volume, 4.24304)) < cs_rank(ts_corr(cs_rank(low), cs_rank(adv(50)), 12.4413))",
76: (
"maximum(cs_rank(ts_wmean(delta(vwap, 1.24383), 11.8259)), ts_rank(ts_wmean(ts_rank(ts_corr(cs_demean(low,"
" 'sector'), adv(81), 8.14941), 19.569), 17.1543), 19.383)) * (-1)"
),
77: (
"minimum(cs_rank(ts_wmean((((high + low) / 2) + high) - (vwap + high), 20.0451)),"
" cs_rank(ts_wmean(ts_corr((high + low) / 2, adv(40), 3.1614), 5.64125)))"
),
78: (
"cs_rank(ts_corr(ts_sum((low * 0.352233) + (vwap * (1 - 0.352233)), 19.7428), ts_sum(adv(40), 19.7428),"
" 6.83313)) ** cs_rank(ts_corr(cs_rank(vwap), cs_rank(volume), 5.77492))"
),
79: (
"cs_rank(delta(cs_demean((close * 0.60733) + (open * (1 - 0.60733)), 'sector'), 1.23438)) <"
" cs_rank(ts_corr(ts_rank(vwap, 3.60973), ts_rank(adv(150), 9.18637), 14.6644))"
),
80: (
"(cs_rank(sign(delta(cs_demean((open * 0.868128) + (high * (1 - 0.868128)), 'industry'), 4.04545))) **"
" ts_rank(ts_corr(high, adv(10), 5.11456), 5.53756)) * (-1)"
),
81: (
"(cs_rank(log(ts_product(cs_rank(cs_rank(ts_corr(vwap, ts_sum(adv(10), 49.6054), 8.47743)) ** 4),"
" 14.9655))) < cs_rank(ts_corr(cs_rank(vwap), cs_rank(volume), 5.07914))) * (-1)"
),
82: (
"minimum(cs_rank(ts_wmean(delta(open, 1.46063), 14.8717)), ts_rank(ts_wmean(ts_corr(cs_demean(volume,"
" 'sector'), ((open * 0.634196) + (open * (1 - 0.634196))), 17.4842), 6.92131), 13.4283)) * (-1)"
),
83: (
"(cs_rank(delay((high - low) / (ts_sum(close, 5) / 5), 2)) * cs_rank(cs_rank(volume))) / (((high - low) /"
" (ts_sum(close, 5) / 5)) / (vwap - close))"
),
84: "power(ts_rank(vwap - ts_max(vwap, 15.3217), 20.7127), delta(close, 4.96796))",
85: (
"cs_rank(ts_corr((high * 0.876703) + (close * (1 - 0.876703)), adv(30), 9.61331)) **"
" cs_rank(ts_corr(ts_rank((high + low) / 2, 3.70596), ts_rank(volume, 10.1595), 7.11408))"
),
86: (
"(ts_rank(ts_corr(close, ts_sum(adv(20), 14.7444), 6.00049), 20.4195) < cs_rank((open + close) - (vwap +"
" open))) * (-1)"
),
87: (
"maximum(cs_rank(ts_wmean(delta((close * 0.369701) + (vwap * (1 - 0.369701)), 1.91233), 2.65461)),"
" ts_rank(ts_wmean(abs(ts_corr(cs_demean(adv(81), 'industry'), close, 13.4132)), 4.89768), 14.4535)) * (-1)"
),
88: (
"minimum(cs_rank(ts_wmean((cs_rank(open) + cs_rank(low)) - (cs_rank(high) + cs_rank(close)), 8.06882)),"
" ts_rank(ts_wmean(ts_corr(ts_rank(close, 8.44728), ts_rank(adv(60), 20.6966), 8.01266), 6.65053),"
" 2.61957))"
),
89: (
"ts_rank(ts_wmean(ts_corr((low * 0.967285) + (low * (1 - 0.967285)), adv(10), 6.94279), 5.51607), 3.79744)"
" - ts_rank(ts_wmean(delta(cs_demean(vwap, 'industry'), 3.48158), 10.1466), 15.3012)"
),
90: (
"(cs_rank(close - ts_max(close, 4.66719)) ** ts_rank(ts_corr(cs_demean(adv(40), 'subindustry'), low,"
" 5.38375), 3.21856)) * (-1)"
),
91: (
"(ts_rank(ts_wmean(ts_wmean(ts_corr(cs_demean(close, 'industry'), volume, 9.74928), 16.398), 3.83219),"
" 4.8667) - cs_rank(ts_wmean(ts_corr(vwap, adv(30), 4.01303), 2.6809))) * (-1)"
),
92: (
"minimum(ts_rank(ts_wmean((((high + low) / 2) + close) < (low + open), 14.7221), 18.8683),"
" ts_rank(ts_wmean(ts_corr(cs_rank(low), cs_rank(adv(30)), 7.58555), 6.94024), 6.80584))"
),
93: (
"ts_rank(ts_wmean(ts_corr(cs_demean(vwap, 'industry'), adv(81), 17.4193), 19.848), 7.54455) /"
" cs_rank(ts_wmean(delta((close * 0.524434) + (vwap * (1 - 0.524434)), 2.77377), 16.2664))"
),
94: (
"(cs_rank(vwap - ts_min(vwap, 11.5783)) ** ts_rank(ts_corr(ts_rank(vwap, 19.6462), ts_rank(adv(60),"
" 4.02992), 18.0926), 2.70756)) * (-1)"
),
95: (
"cs_rank(open - ts_min(open, 12.4105)) < ts_rank(cs_rank(ts_corr(ts_sum((high + low) / 2, 19.1351),"
" ts_sum(adv(40), 19.1351), 12.8742)) ** 5, 11.7584)"
),
96: (
"maximum(ts_rank(ts_wmean(ts_corr(cs_rank(vwap), cs_rank(volume), 3.83878), 4.16783), 8.38151),"
" ts_rank(ts_wmean(ts_argmax(ts_corr(ts_rank(close, 7.45404), ts_rank(adv(60), 4.13242), 3.65459),"
" 12.6556), 14.0365), 13.4143)) * (-1)"
),
97: (
"(cs_rank(ts_wmean(delta(cs_demean((low * 0.721001) + (vwap * (1 - 0.721001)), 'industry'), 3.3705),"
" 20.4523)) - ts_rank(ts_wmean(ts_rank(ts_corr(ts_rank(low, 7.87871), ts_rank(adv(60), 17.255), 4.97547),"
" 18.5925), 15.7152), 6.71659)) * (-1)"
),
98: (
"cs_rank(ts_wmean(ts_corr(vwap, ts_sum(adv(5), 26.4719), 4.58418), 7.18088)) -"
" cs_rank(ts_wmean(ts_rank(ts_argmin(ts_corr(cs_rank(open), cs_rank(adv(15)), 20.8187), 8.62571), 6.95668),"
" 8.07206))"
),
99: (
"(cs_rank(ts_corr(ts_sum((high + low) / 2, 19.8975), ts_sum(adv(60), 19.8975), 8.8136)) <"
" cs_rank(ts_corr(low, volume, 6.28259))) * (-1)"
),
100: (
"0 - (1 * (((1.5 * cs_rescale(cs_demean(cs_demean(cs_rank((((close - low) - (high - close)) / (high - low))"
" * volume), 'subindustry'), 'subindustry'))) - cs_rescale(cs_demean(ts_corr(close, cs_rank(adv(20)), 5) -"
" cs_rank(ts_argmin(close, 30)), 'subindustry'))) * (volume / adv(20))))"
),
101: "(close - open) / ((high - low) + .001)",
}
)
"""_"""
__pdoc__[
"wqa101_expr_config"
] = f"""Config with WorldQuant's 101 alpha expressions.
See [101 Formulaic Alphas](https://arxiv.org/abs/1601.00991).
Can be modified.
```python
{wqa101_expr_config.prettify()}
```
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Factory for building indicators.
Run the following to set up the examples below:
```pycon
>>> from vectorbtpro import *
>>> price = pd.DataFrame({
... 'a': [1, 2, 3, 4, 5],
... 'b': [5, 4, 3, 2, 1]
... }, index=pd.date_range("2020", periods=5)).astype(float)
>>> price
a b
2020-01-01 1.0 5.0
2020-01-02 2.0 4.0
2020-01-03 3.0 3.0
2020-01-04 4.0 2.0
2020-01-05 5.0 1.0
```"""
import fnmatch
import functools
import inspect
import itertools
import re
from collections import Counter, OrderedDict
from types import ModuleType, FunctionType
import numpy as np
import pandas as pd
from numba import njit
from numba.typed import List
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import indexes, reshaping, combining
from vectorbtpro.base.indexing import build_param_indexer
from vectorbtpro.base.merging import row_stack_arrays, column_stack_arrays
from vectorbtpro.base.reshaping import broadcast_array_to, Default, resolve_ref
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.accessors import BaseAccessor
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.indicators.expr import expr_func_config, expr_res_func_config, wqa101_expr_config
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.array_ import build_nan_mask, squeeze_nan, unsqueeze_nan
from vectorbtpro.utils.config import merge_dicts, resolve_dict, Config, Configured, HybridConfig
from vectorbtpro.utils.decorators import class_property, cacheable_property, hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.eval_ import evaluate
from vectorbtpro.utils.execution import Task
from vectorbtpro.utils.formatting import camel_to_snake_case, prettify
from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods
from vectorbtpro.utils.mapping import to_value_mapping, apply_mapping
from vectorbtpro.utils.module_ import search_package
from vectorbtpro.utils.params import (
to_typed_list,
broadcast_params,
create_param_product,
is_single_param_value,
params_to_list,
)
from vectorbtpro.utils.parsing import get_expr_var_names, get_func_arg_names, get_func_kwargs, suppress_stdout
from vectorbtpro.utils.random_ import set_seed
from vectorbtpro.utils.template import has_templates, substitute_templates, Rep
from vectorbtpro.utils.warnings_ import warn, WarningsFiltered
__all__ = [
"IndicatorBase",
"IndicatorFactory",
"IF",
"indicator",
"talib",
"pandas_ta",
"ta",
"wqa101",
"technical",
"techcon",
"smc",
]
__pdoc__ = {}
try:
if not tp.TYPE_CHECKING:
raise ImportError
from ta.utils import IndicatorMixin as IndicatorMixinT
except ImportError:
IndicatorMixinT = "IndicatorMixin"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from technical.consensus import Consensus as ConsensusT
except ImportError:
ConsensusT = "Consensus"
def prepare_params(
params: tp.MaybeParams,
param_names: tp.Sequence[str],
param_settings: tp.Sequence[tp.KwargsLike],
input_shape: tp.Optional[tp.Shape] = None,
to_2d: bool = False,
context: tp.KwargsLike = None,
) -> tp.Tuple[tp.Params, bool]:
"""Prepare parameters.
Resolves references and performs broadcasting to the input shape.
Returns prepared parameters as well as whether the user provided a single parameter combination."""
# Resolve references
if context is None:
context = {}
pool = dict(zip(param_names, params))
for k in pool:
pool[k] = resolve_ref(pool, k)
params = [pool[k] for k in param_names]
new_params = []
single_comb = True
for i, param_values in enumerate(params):
# Resolve settings
_param_settings = resolve_dict(param_settings[i])
is_tuple = _param_settings.get("is_tuple", False)
dtype = _param_settings.get("dtype", None)
if checks.is_mapping_like(dtype):
dtype_kwargs = _param_settings.get("dtype_kwargs", None)
if dtype_kwargs is None:
dtype_kwargs = {}
if checks.is_namedtuple(dtype):
param_values = map_enum_fields(param_values, dtype, **dtype_kwargs)
else:
param_values = apply_mapping(param_values, dtype, **dtype_kwargs)
is_array_like = _param_settings.get("is_array_like", False)
min_one_dim = _param_settings.get("min_one_dim", False)
bc_to_input = _param_settings.get("bc_to_input", False)
broadcast_kwargs = merge_dicts(
dict(require_kwargs=dict(requirements="W")),
_param_settings.get("broadcast_kwargs", None),
)
template = _param_settings.get("template", None)
if not is_single_param_value(param_values, is_tuple, is_array_like):
single_comb = False
new_param_values = params_to_list(param_values, is_tuple, is_array_like)
if template is not None:
new_param_values = [
template.substitute(context={param_names[i]: new_param_values[j], **context})
for j in range(len(new_param_values))
]
if not bc_to_input:
if is_array_like:
if min_one_dim:
new_param_values = list(map(reshaping.to_1d_array, new_param_values))
else:
new_param_values = list(map(np.asarray, new_param_values))
else:
# Broadcast to input or its axis
if is_tuple:
raise ValueError("Cannot broadcast to input if tuple")
if input_shape is None:
raise ValueError("Cannot broadcast to input if input shape is unknown. Pass input_shape.")
if bc_to_input is True:
to_shape = input_shape
else:
checks.assert_in(bc_to_input, (0, 1))
# Note that input_shape can be 1D
if bc_to_input == 0:
to_shape = (input_shape[0],)
else:
to_shape = (input_shape[1],) if len(input_shape) > 1 else (1,)
_new_param_values = reshaping.broadcast(*new_param_values, to_shape=to_shape, **broadcast_kwargs)
if len(new_param_values) == 1:
_new_param_values = [_new_param_values]
else:
_new_param_values = list(_new_param_values)
if to_2d and bc_to_input is True:
# If inputs are meant to reshape to 2D, do the same to parameters
# But only to those that fully resemble inputs (= not raw)
__new_param_values = _new_param_values.copy()
for j, param in enumerate(__new_param_values):
keep_flex = broadcast_kwargs.get("keep_flex", False)
if keep_flex is False or (isinstance(keep_flex, (tuple, list)) and not keep_flex[j]):
__new_param_values[j] = reshaping.to_2d(param)
new_param_values = __new_param_values
else:
new_param_values = _new_param_values
new_params.append(new_param_values)
return new_params, single_comb
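# Example (illustrative, output shown approximately): without broadcasting, each parameter is
# simply normalized into a list of values; `single_comb` is True only if every parameter
# was provided as a single value.
# >>> prepare_params([10, [10, 20]], ["fast", "slow"], [{}, {}])
# ([[10], [10, 20]], False)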
def build_columns(
params: tp.Params,
input_columns: tp.IndexLike,
level_names: tp.Optional[tp.Sequence[str]] = None,
hide_levels: tp.Optional[tp.Sequence[tp.Union[str, int]]] = None,
single_value: tp.Optional[tp.Sequence[bool]] = None,
param_settings: tp.KwargsLikeSequence = None,
per_column: bool = False,
ignore_ranges: bool = False,
**kwargs,
) -> dict:
"""For each parameter in `params`, create a new column level with parameter values
and stack it on top of `input_columns`."""
if level_names is not None:
checks.assert_len_equal(params, level_names)
if hide_levels is None:
hide_levels = []
input_columns = indexes.to_any_index(input_columns)
param_indexes = []
rep_param_indexes = []
vis_param_indexes = []
vis_rep_param_indexes = []
has_per_column = False
for i in range(len(params)):
param_values = params[i]
level_name = None
if level_names is not None:
level_name = level_names[i]
_single_value = False
if single_value is not None:
_single_value = single_value[i]
_param_settings = resolve_dict(param_settings, i=i)
dtype = _param_settings.get("dtype", None)
if checks.is_mapping_like(dtype):
if checks.is_namedtuple(dtype):
dtype = to_value_mapping(dtype, reverse=False)
else:
dtype = to_value_mapping(dtype, reverse=True)
param_values = apply_mapping(param_values, dtype)
_per_column = _param_settings.get("per_column", False)
_post_index_func = _param_settings.get("post_index_func", None)
if per_column:
param_index = indexes.index_from_values(param_values, single_value=_single_value, name=level_name)
repeat_index = False
has_per_column = True
elif _per_column:
param_index = None
for p in param_values:
bc_param = broadcast_array_to(p, len(input_columns))
_param_index = indexes.index_from_values(bc_param, single_value=False, name=level_name)
if param_index is None:
param_index = _param_index
else:
param_index = param_index.append(_param_index)
if len(param_index) == 1 and len(input_columns) > 1:
param_index = indexes.repeat_index(param_index, len(input_columns), ignore_ranges=ignore_ranges)
repeat_index = False
has_per_column = True
else:
param_index = indexes.index_from_values(param_values, single_value=_single_value, name=level_name)
repeat_index = True
if _post_index_func is not None:
param_index = _post_index_func(param_index)
if repeat_index:
rep_param_index = indexes.repeat_index(param_index, len(input_columns), ignore_ranges=ignore_ranges)
else:
rep_param_index = param_index
param_indexes.append(param_index)
rep_param_indexes.append(rep_param_index)
if i not in hide_levels and (level_names is None or level_names[i] not in hide_levels):
vis_param_indexes.append(param_index)
vis_rep_param_indexes.append(rep_param_index)
if not per_column:
n_param_values = len(params[0]) if len(params) > 0 else 1
input_columns = indexes.tile_index(input_columns, n_param_values, ignore_ranges=ignore_ranges)
if len(vis_param_indexes) > 0:
if has_per_column:
param_index = None
else:
param_index = indexes.stack_indexes(vis_param_indexes, **kwargs)
final_index = indexes.stack_indexes([*vis_rep_param_indexes, input_columns], **kwargs)
else:
param_index = None
final_index = input_columns
return dict(
param_indexes=param_indexes,
rep_param_indexes=rep_param_indexes,
vis_param_indexes=vis_param_indexes,
vis_rep_param_indexes=vis_rep_param_indexes,
param_index=param_index,
final_index=final_index,
)
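# Example (illustrative, output shown approximately): one "window" parameter with two values
# over two input columns yields a tiled column hierarchy.
# >>> meta = build_columns([[2, 3]], pd.Index(["a", "b"]), level_names=["window"])
# >>> meta["final_index"]
# MultiIndex([(2, 'a'), (2, 'b'), (3, 'a'), (3, 'b')], names=['window', None])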
def combine_objs(
obj: tp.SeriesFrame,
other: tp.MaybeTupleList[tp.Union[tp.ArrayLike, BaseAccessor]],
combine_func: tp.Callable,
*args,
level_name: tp.Optional[str] = None,
keys: tp.Optional[tp.IndexLike] = None,
allow_multiple: bool = True,
**kwargs,
) -> tp.SeriesFrame:
"""Combines/compares `obj` to `other`, for example, to generate signals.
Both will broadcast together.
Pass `other` as a tuple or a list to compare with multiple arguments.
In this case, a new column level will be created with the name `level_name`.
See `vectorbtpro.base.accessors.BaseAccessor.combine`."""
if allow_multiple and isinstance(other, (tuple, list)):
if keys is None:
keys = indexes.index_from_values(other, name=level_name)
return obj.vbt.combine(other, combine_func, *args, keys=keys, allow_multiple=allow_multiple, **kwargs)
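# Example (illustrative, using the `price` frame from the module docstring above): comparing a
# price frame against several thresholds; one block of columns is produced per threshold,
# under a new column level named "threshold".
# >>> signals = combine_objs(price, (2, 4), np.greater, level_name="threshold")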
IndicatorBaseT = tp.TypeVar("IndicatorBaseT", bound="IndicatorBase")
CacheOutputT = tp.Any
RawOutputT = tp.Tuple[
tp.List[tp.Array2d],
tp.List[tp.Tuple[tp.ParamValue, ...]],
int,
tp.List[tp.Any],
]
InputListT = tp.List[tp.Array2d]
InputMapperT = tp.Optional[tp.Array1d]
InOutputListT = tp.List[tp.Array2d]
OutputListT = tp.List[tp.Array2d]
ParamListT = tp.List[tp.List[tp.ParamValue]]
MapperListT = tp.List[tp.Index]
OtherListT = tp.List[tp.Any]
PipelineOutputT = tp.Tuple[
ArrayWrapper,
InputListT,
InputMapperT,
InOutputListT,
OutputListT,
ParamListT,
MapperListT,
OtherListT,
]
RunOutputT = tp.Union[IndicatorBaseT, tp.Tuple[tp.Any, ...], RawOutputT, CacheOutputT]
RunCombsOutputT = tp.Tuple[IndicatorBaseT, ...]
def combine_indicator_with_other(
self: IndicatorBaseT,
other: tp.Union["IndicatorBase", tp.ArrayLike],
np_func: tp.Callable[[tp.ArrayLike, tp.ArrayLike], tp.Array1d],
) -> tp.SeriesFrame:
"""Combine `IndicatorBase` with other compatible object."""
if isinstance(other, IndicatorBase):
other = other.main_output
return np_func(self.main_output, other)
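# Example (illustrative, assuming the bundled `vbt.RSI` indicator whose single output is named
# `rsi`): binary operators are applied to the main output, so an indicator instance can be
# compared directly with scalars, arrays, or other indicators.
# >>> rsi = vbt.RSI.run(price, window=2)
# >>> overbought = rsi > 70  # equivalent to np.greater(rsi.rsi, 70)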
@attach_binary_magic_methods(combine_indicator_with_other)
@attach_unary_magic_methods(lambda self, np_func: np_func(self.main_output))
class IndicatorBase(Analyzable):
"""Indicator base class.
Properties should be set before instantiation."""
_short_name: tp.ClassVar[str]
_input_names: tp.ClassVar[tp.Tuple[str, ...]]
_param_names: tp.ClassVar[tp.Tuple[str, ...]]
_in_output_names: tp.ClassVar[tp.Tuple[str, ...]]
_output_names: tp.ClassVar[tp.Tuple[str, ...]]
_lazy_output_names: tp.ClassVar[tp.Tuple[str, ...]]
_output_flags: tp.ClassVar[tp.Kwargs]
def __getattr__(self, k: str) -> tp.Any:
"""Redirect queries targeted at a generic output name by "output" or the short name of the indicator."""
try:
return object.__getattribute__(self, k)
except AttributeError:
pass
if k == "vbt":
return self.main_output.vbt
short_name = object.__getattribute__(self, "short_name")
output_names = object.__getattribute__(self, "output_names")
if len(output_names) == 1:
if k.startswith("output") and "output" not in output_names:
new_k = k[len("output") :]
if len(new_k) == 0 or not new_k[0].isalnum():
try:
return object.__getattribute__(self, output_names[0] + new_k)
except AttributeError:
pass
if k.startswith(short_name) and short_name not in output_names:
new_k = k[len(short_name) :]
if len(new_k) == 0 or not new_k[0].isalnum():
try:
return object.__getattribute__(self, output_names[0] + new_k)
except AttributeError:
pass
if k.lower().startswith(short_name.lower()) and short_name.lower() not in output_names:
new_k = k[len(short_name) :].lower()
if len(new_k) == 0 or not new_k[0].isalnum():
try:
return object.__getattribute__(self, output_names[0] + new_k)
except AttributeError:
pass
try:
return object.__getattribute__(self, output_names[0] + "_" + k)
except AttributeError:
pass
elif short_name in output_names:
try:
return object.__getattribute__(self, short_name + "_" + k)
except AttributeError:
pass
elif short_name.lower() in output_names:
try:
return object.__getattribute__(self, short_name.lower() + "_" + k)
except AttributeError:
pass
return object.__getattribute__(self, k)
@property
def main_output(self) -> tp.SeriesFrame:
"""Get main output.
It's either the only output or an output that matches the short name of the indicator."""
if len(self.output_names) == 1:
return getattr(self, self.output_names[0])
if self.short_name in self.output_names:
return getattr(self, self.short_name)
if self.short_name.lower() in self.output_names:
return getattr(self, self.short_name.lower())
raise ValueError(f"Indicator {self} has no main output")
def __array__(self, dtype: tp.Optional[tp.DTypeLike] = None) -> tp.Array:
"""Convert main output to NumPy array."""
return np.asarray(self.main_output, dtype=dtype)
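# Example (illustrative, assuming the bundled `vbt.MA` indicator whose single output is named
# `ma`, run on the `price` frame from the module docstring): with a single output, generic
# attribute access and `main_output` resolve to the same array.
# >>> ma = vbt.MA.run(price, window=2)
# >>> ma.output.equals(ma.ma) and ma.main_output.equals(ma.ma)
# True
# >>> np.asarray(ma).shape  # __array__ converts the main output
# (5, 2)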
@classmethod
def run_pipeline(
cls,
num_ret_outputs: int,
custom_func: tp.Callable,
*args,
require_input_shape: bool = False,
input_shape: tp.Optional[tp.ShapeLike] = None,
input_index: tp.Optional[tp.IndexLike] = None,
input_columns: tp.Optional[tp.IndexLike] = None,
inputs: tp.Optional[tp.MappingSequence[tp.ArrayLike]] = None,
in_outputs: tp.Optional[tp.MappingSequence[tp.ArrayLike]] = None,
in_output_settings: tp.Optional[tp.MappingSequence[tp.KwargsLike]] = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
params: tp.Optional[tp.MaybeParams] = None,
param_product: bool = False,
random_subset: tp.Optional[int] = None,
param_settings: tp.Optional[tp.MappingSequence[tp.KwargsLike]] = None,
run_unique: bool = False,
silence_warnings: bool = False,
per_column: tp.Optional[bool] = None,
keep_pd: bool = False,
to_2d: bool = True,
pass_packed: bool = False,
pass_input_shape: tp.Optional[bool] = None,
pass_wrapper: bool = False,
pass_param_index: bool = False,
pass_final_index: bool = False,
pass_single_comb: bool = False,
level_names: tp.Optional[tp.Sequence[str]] = None,
hide_levels: tp.Optional[tp.Sequence[tp.Union[str, int]]] = None,
build_col_kwargs: tp.KwargsLike = None,
return_raw: tp.Union[bool, str] = False,
use_raw: tp.Optional[RawOutputT] = None,
wrapper_kwargs: tp.KwargsLike = None,
seed: tp.Optional[int] = None,
**kwargs,
) -> tp.Union[CacheOutputT, RawOutputT, PipelineOutputT]:
"""A pipeline for running an indicator, used by `IndicatorFactory`.
Args:
num_ret_outputs (int): The number of output arrays returned by `custom_func`.
custom_func (callable): A custom calculation function.
See `IndicatorFactory.with_custom_func`.
*args: Arguments passed to the `custom_func`.
require_input_shape (bool): Whether the input shape is required.
Will set `pass_input_shape` to True and raise an error if `input_shape` is None.
input_shape (tuple): Shape to broadcast each input to.
Can be passed to `custom_func`. See `pass_input_shape`.
input_index (index_like): Sets index of each input.
Can be used to label index if no inputs passed.
input_columns (index_like): Sets columns of each input.
Can be used to label columns if no inputs passed.
inputs (mapping or sequence of array_like): A mapping or sequence of input arrays.
Use mapping to also supply names. If sequence, will convert to a mapping using `input_{i}` key.
in_outputs (mapping or sequence of array_like): A mapping or sequence of in-place output arrays.
Use mapping to also supply names. If sequence, will convert to a mapping using `in_output_{i}` key.
in_output_settings (dict or sequence of dict): Settings corresponding to each in-place output.
If mapping, should contain keys from `in_outputs`.
Following keys are accepted:
* `dtype`: Create this array using this data type and `np.empty`. Default is None.
broadcast_named_args (dict): Dictionary with named arguments to broadcast together with inputs.
You can then pass argument names wrapped with `vectorbtpro.utils.template.Rep`
and this method will substitute them by their corresponding broadcasted objects.
broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`
to broadcast inputs.
template_context (dict): Context used to substitute templates in `args` and `kwargs`.
params (mapping or sequence of any): A mapping or sequence of parameters.
Use mapping to also supply names. If sequence, will convert to a mapping using `param_{i}` key.
Each element is either an array-like object or a single value of any type.
param_product (bool): Whether to build a Cartesian product out of all parameters.
random_subset (int): Number of parameter combinations to pick randomly.
param_settings (dict or sequence of dict): Settings corresponding to each parameter.
If mapping, should contain keys from `params`.
Following keys are accepted:
* `dtype`: If data type is an enumerated type or other mapping, and a string as parameter
value was passed, will convert it first.
* `dtype_kwargs`: Keyword arguments passed to the function processing the data type.
If data type is enumerated, it will be `vectorbtpro.utils.enum_.map_enum_fields`.
* `is_tuple`: If tuple was passed, it will be considered as a single value.
To treat it as multiple values, pack it into a list.
* `is_array_like`: If array-like object was passed, it will be considered as a single value.
To treat it as multiple values, pack it into a list.
* `template`: Template to substitute each parameter value with, before broadcasting to input.
* `min_one_dim`: Whether to convert any scalar into a one-dimensional array.
Works only if `bc_to_input` is False.
* `bc_to_input`: Whether to broadcast parameter to input size. You can also broadcast
parameter to an axis by passing an integer.
* `broadcast_kwargs`: Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
* `per_column`: Whether each parameter value can be split by columns such that it can
be better reflected in a multi-index. Does not affect broadcasting.
* `post_index_func`: Function to convert the final index level of the parameter. Defaults to None.
run_unique (bool): Whether to run only on unique parameter combinations.
Disable if two identical parameter combinations can lead to different results
(e.g., due to randomness) or if inputs are large and `custom_func` is fast.
!!! note
Cache, raw output, and output objects outside of `num_ret_outputs` will be returned
for unique parameter combinations only.
silence_warnings (bool): Whether to hide warnings such as coming from `run_unique`.
per_column (bool): Whether the values of each parameter should be split by columns.
Defaults to False. Will be passed to `custom_func` as a keyword argument if it's not None.
Each list of parameter values will broadcast to the number of columns and
each parameter value will be applied per column rather than per whole input.
Input shape must be known beforehand.
Each of the inputs, in-outputs, and parameters will be passed to `custom_func`
with the full shape. Expects the outputs to be of the same shape as the inputs.
keep_pd (bool): Whether to keep inputs as pandas objects, otherwise convert to NumPy arrays.
to_2d (bool): Whether to reshape inputs to 2-dim arrays, otherwise keep as-is.
pass_packed (bool): Whether to pass inputs and parameters to `custom_func` as lists.
If `custom_func` is Numba-compiled, passes tuples.
pass_input_shape (bool): Whether to pass `input_shape` to `custom_func` as keyword argument.
Defaults to True if `require_input_shape` is True, otherwise to False.
pass_wrapper (bool): Whether to pass the input wrapper to `custom_func` as keyword argument.
pass_param_index (bool): Whether to pass parameter index.
pass_final_index (bool): Whether to pass final index.
pass_single_comb (bool): Whether to pass whether there is only one parameter combination.
level_names (list of str): A list of column level names corresponding to each parameter.
Must have the same length as `params`.
hide_levels (list of int or str): A list of level names or indices of parameter levels to hide.
build_col_kwargs (dict): Keyword arguments passed to `build_columns`.
return_raw (bool or str): Whether to return raw outputs and hashed parameter tuples without
further post-processing.
Pass "outputs" to only return outputs.
use_raw (RawOutputT): Raw results from a previous run to be used instead of running `custom_func`.
wrapper_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.wrapping.ArrayWrapper`.
seed (int): Seed to make output deterministic.
**kwargs: Keyword arguments passed to the `custom_func`.
Some common arguments include `return_cache` to return cache and `use_cache` to use cache.
If `use_cache` is False, disables caching completely. These are only applicable to a `custom_func`
that supports them (a `custom_func` created using `IndicatorFactory.with_apply_func` supports them by default).
Returns:
Array wrapper, list of input arrays (`np.ndarray`), input mapper (`np.ndarray`), list of in-place
output arrays (`np.ndarray`), list of output arrays (`np.ndarray`), list of parameter lists,
list of parameter mappers (`pd.Index`), and a list of extra outputs outside of `num_ret_outputs`.
"""
pass_per_column = per_column is not None
if per_column is None:
per_column = False
if per_column and (params is None or len(params) == 0):
raise ValueError("per_column cannot be enabled without parameters")
if require_input_shape:
checks.assert_not_none(input_shape, arg_name="input_shape")
if pass_input_shape is None:
pass_input_shape = True
if pass_input_shape is None:
pass_input_shape = False
if input_index is not None:
input_index = indexes.to_any_index(input_index)
if input_columns is not None:
input_columns = indexes.to_any_index(input_columns)
if inputs is None:
inputs = {}
if not checks.is_mapping(inputs):
inputs = {"input_" + str(i): input for i, input in enumerate(inputs)}
input_names = list(inputs.keys())
input_list = list(inputs.values())
if in_outputs is None:
in_outputs = {}
if not checks.is_mapping(in_outputs):
in_outputs = {"in_output_" + str(i): in_output for i, in_output in enumerate(in_outputs)}
in_output_names = list(in_outputs.keys())
in_output_list = list(in_outputs.values())
if in_output_settings is None:
in_output_settings = {}
if checks.is_mapping(in_output_settings):
checks.assert_dict_valid(in_output_settings, [in_output_names, "dtype"])
in_output_settings = [in_output_settings.get(k, None) for k in in_output_names]
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if params is None:
params = {}
if not checks.is_mapping(params):
params = {"param_" + str(i): param for i, param in enumerate(params)}
param_names = list(params.keys())
param_list = list(params.values())
if param_settings is None:
param_settings = {}
if checks.is_mapping(param_settings):
checks.assert_dict_valid(
param_settings,
[
param_names,
[
"dtype",
"dtype_kwargs",
"is_tuple",
"is_array_like",
"template",
"min_one_dim",
"bc_to_input",
"broadcast_kwargs",
"per_column",
"post_index_func",
],
],
)
param_settings = [param_settings.get(k, None) for k in param_names]
if hide_levels is None:
hide_levels = []
if build_col_kwargs is None:
build_col_kwargs = {}
if wrapper_kwargs is None:
wrapper_kwargs = {}
if keep_pd and checks.is_numba_func(custom_func):
raise ValueError("Cannot pass pandas objects to a Numba-compiled custom_func. Set keep_pd to False.")
# Set seed
if seed is not None:
set_seed(seed)
if input_shape is not None:
input_shape = reshaping.to_tuple_shape(input_shape)
if len(inputs) > 0 or len(in_outputs) > 0 or len(broadcast_named_args) > 0:
# Broadcast inputs, in-outputs, and named args
# If input_shape is provided, will broadcast all inputs to this shape
broadcast_args = merge_dicts(inputs, in_outputs, broadcast_named_args)
broadcast_kwargs = merge_dicts(
dict(
to_shape=input_shape,
index_from=input_index,
columns_from=input_columns,
require_kwargs=dict(requirements="W"),
post_func=None if keep_pd else np.asarray,
to_pd=True,
),
broadcast_kwargs,
)
broadcast_args, wrapper = reshaping.broadcast(broadcast_args, return_wrapper=True, **broadcast_kwargs)
input_shape, input_index, input_columns = wrapper.shape, wrapper.index, wrapper.columns
if input_index is None:
input_index = pd.RangeIndex(start=0, step=1, stop=input_shape[0])
if input_columns is None:
input_columns = pd.RangeIndex(start=0, step=1, stop=input_shape[1] if len(input_shape) > 1 else 1)
input_list = [broadcast_args[input_name] for input_name in input_names]
in_output_list = [broadcast_args[in_output_name] for in_output_name in in_output_names]
broadcast_named_args = {arg_name: broadcast_args[arg_name] for arg_name in broadcast_named_args}
else:
wrapper = None
# Reshape input shape
input_shape_ready = input_shape
input_shape_2d = input_shape
if input_shape is not None:
input_shape_2d = input_shape if len(input_shape) > 1 else (input_shape[0], 1)
if to_2d:
if input_shape is not None:
input_shape_ready = input_shape_2d # ready for custom_func
if wrapper is not None:
wrapper_ready = wrapper
elif input_index is not None and input_columns is not None and input_shape_ready is not None:
wrapper_ready = ArrayWrapper(input_index, input_columns, len(input_shape_ready))
else:
wrapper_ready = None
# Prepare inputs
input_list_ready = []
for input in input_list:
new_input = input
if to_2d:
new_input = reshaping.to_2d(input)
if keep_pd and isinstance(new_input, np.ndarray):
# Keep as pandas object
new_input = ArrayWrapper(input_index, input_columns, new_input.ndim).wrap(new_input)
input_list_ready.append(new_input)
# Prepare parameters
# NOTE: input_shape instead of input_shape_ready since parameters should
# broadcast by the same rules as inputs
param_context = merge_dicts(
broadcast_named_args,
dict(
input_shape=input_shape_ready,
wrapper=wrapper_ready,
**dict(zip(input_names, input_list_ready)),
pre_sub_args=args,
pre_sub_kwargs=kwargs,
),
template_context,
)
param_list, single_comb = prepare_params(
param_list,
param_names,
param_settings,
input_shape=input_shape,
to_2d=to_2d,
context=param_context,
)
single_value = list(map(lambda x: len(x) == 1, param_list))
if len(param_list) > 1:
if level_names is not None:
# Check level names
checks.assert_len_equal(param_list, level_names)
# Columns should be free of the specified level names
if input_columns is not None:
for level_name in level_names:
if level_name is not None:
checks.assert_level_not_exists(input_columns, level_name)
if param_product:
# Make Cartesian product out of all params
param_list = create_param_product(param_list)
if len(param_list) > 0:
# Broadcast such that each array has the same length
if per_column:
# The number of parameters should match the number of columns before split
param_list = broadcast_params(param_list, to_n=input_shape_2d[1])
else:
param_list = broadcast_params(param_list)
if random_subset is not None:
# Pick combinations randomly
if per_column:
raise ValueError("Cannot select random subset when per_column=True")
random_indices = np.sort(np.random.permutation(np.arange(len(param_list[0])))[:random_subset])
param_list = [[params[i] for i in random_indices] for params in param_list]
n_param_values = len(param_list[0]) if len(param_list) > 0 else 1
use_run_unique = False
param_list_unique = param_list
if not per_column and run_unique:
try:
# Try to get all unique parameter combinations
param_tuples = list(zip(*param_list))
unique_param_tuples = list(OrderedDict.fromkeys(param_tuples).keys())
if len(unique_param_tuples) < len(param_tuples):
param_list_unique = list(map(list, zip(*unique_param_tuples)))
use_run_unique = True
except:
pass
if checks.is_numba_func(custom_func):
# Numba can't stand untyped lists
param_list_ready = [to_typed_list(params) for params in param_list_unique]
else:
param_list_ready = param_list_unique
n_unique_param_values = len(param_list_unique[0]) if len(param_list_unique) > 0 else 1
# Build column hierarchy for execution
if len(param_list) > 0 and input_columns is not None and (pass_param_index or pass_final_index):
# Build new column levels on top of input levels
build_columns_meta = build_columns(
param_list,
input_columns,
level_names=level_names,
hide_levels=hide_levels,
single_value=single_value,
param_settings=param_settings,
per_column=per_column,
**build_col_kwargs,
)
rep_param_indexes = build_columns_meta["rep_param_indexes"]
param_index = build_columns_meta["param_index"]
final_index = build_columns_meta["final_index"]
else:
# Some indicators don't have any params
rep_param_indexes = None
param_index = None
final_index = None
# Prepare in-place outputs
in_output_list_ready = []
for i in range(len(in_output_list)):
if input_shape_2d is None:
raise ValueError("input_shape is required when using in-place outputs")
if in_output_list[i] is not None:
# This in-place output has been already broadcast with inputs
in_output_wide = in_output_list[i]
if isinstance(in_output_list[i], np.ndarray):
in_output_wide = np.require(in_output_wide, requirements="W")
if not per_column:
# One per parameter combination
in_output_wide = reshaping.tile(in_output_wide, n_unique_param_values, axis=1)
else:
# This in-place output hasn't been provided, so create empty
_in_output_settings = resolve_dict(in_output_settings[i])
dtype = _in_output_settings.get("dtype", None)
if per_column:
in_output_shape = input_shape_ready
else:
in_output_shape = (input_shape_2d[0], input_shape_2d[1] * n_unique_param_values)
in_output_wide = np.empty(in_output_shape, dtype=dtype)
in_output_list[i] = in_output_wide
# Split each in-place output into chunks, each of input shape, and append to a list
in_outputs = []
if per_column:
in_outputs.append(in_output_wide)
else:
for p in range(n_unique_param_values):
if isinstance(in_output_wide, pd.DataFrame):
in_output = in_output_wide.iloc[:, p * input_shape_2d[1] : (p + 1) * input_shape_2d[1]]
if len(input_shape_ready) == 1:
in_output = in_output.iloc[:, 0]
else:
in_output = in_output_wide[:, p * input_shape_2d[1] : (p + 1) * input_shape_2d[1]]
if len(input_shape_ready) == 1:
in_output = in_output[:, 0]
if keep_pd and isinstance(in_output, np.ndarray):
in_output = ArrayWrapper(input_index, input_columns, in_output.ndim).wrap(in_output)
in_outputs.append(in_output)
in_output_list_ready.append(in_outputs)
if checks.is_numba_func(custom_func):
# Numba can't stand untyped lists
in_output_list_ready = [to_typed_list(in_outputs) for in_outputs in in_output_list_ready]
def _use_raw(_raw):
# Use raw results of previous run to build outputs
_output_list, _param_map, _n_input_cols, _other_list = _raw
idxs = np.array([_param_map.index(param_tuple) for param_tuple in zip(*param_list)])
_output_list = [
np.hstack([o[:, idx * _n_input_cols : (idx + 1) * _n_input_cols] for idx in idxs]) for o in _output_list
]
return _output_list, _param_map, _n_input_cols, _other_list
# Get raw results
if use_raw is not None:
# Use raw results of previous run to build outputs
output_list, param_map, n_input_cols, other_list = _use_raw(use_raw)
else:
# Prepare other arguments
func_args = args
func_kwargs = dict(kwargs)
if pass_input_shape:
func_kwargs["input_shape"] = input_shape_ready
if pass_wrapper:
func_kwargs["wrapper"] = wrapper_ready
if pass_param_index:
func_kwargs["param_index"] = param_index
if pass_final_index:
func_kwargs["final_index"] = final_index
if pass_single_comb:
func_kwargs["single_comb"] = single_comb
if pass_per_column:
func_kwargs["per_column"] = per_column
# Substitute templates
if has_templates(func_args) or has_templates(func_kwargs):
template_context = merge_dicts(
broadcast_named_args,
dict(
input_shape=input_shape_ready,
wrapper=wrapper_ready,
**dict(zip(input_names, input_list_ready)),
**dict(zip(in_output_names, in_output_list_ready)),
**dict(zip(param_names, param_list_ready)),
pre_sub_args=func_args,
pre_sub_kwargs=func_kwargs,
),
template_context,
)
func_args = substitute_templates(func_args, template_context, eval_id="custom_func_args")
func_kwargs = substitute_templates(func_kwargs, template_context, eval_id="custom_func_kwargs")
# Run the custom function
if checks.is_numba_func(custom_func):
func_args += tuple(func_kwargs.values())
func_kwargs = {}
if pass_packed:
outputs = custom_func(
tuple(input_list_ready),
tuple(in_output_list_ready),
tuple(param_list_ready),
*func_args,
**func_kwargs,
)
else:
outputs = custom_func(
*input_list_ready, *in_output_list_ready, *param_list_ready, *func_args, **func_kwargs
)
# Return outputs
if isinstance(return_raw, str):
if return_raw.lower() == "outputs":
if use_run_unique and not silence_warnings:
warn("Raw outputs are produced by unique parameter combinations when run_unique=True")
return outputs
else:
raise ValueError(f"Invalid return_raw: '{return_raw}'")
# Return cache
if kwargs.get("return_cache", False):
if use_run_unique and not silence_warnings:
warn("Cache is produced by unique parameter combinations when run_unique=True")
return outputs
# Post-process results
if outputs is None:
output_list = []
other_list = []
else:
if isinstance(outputs, (tuple, list, List)):
output_list = list(outputs)
else:
output_list = [outputs]
# Other outputs should be returned without post-processing (for example cache_dict)
if len(output_list) > num_ret_outputs:
other_list = output_list[num_ret_outputs:]
if use_run_unique and not silence_warnings:
warn(
"Additional output objects are produced by unique parameter combinations "
"when run_unique=True"
)
else:
other_list = []
# Process only the num_ret_outputs outputs
output_list = output_list[:num_ret_outputs]
if len(output_list) != num_ret_outputs:
raise ValueError("Number of returned outputs other than expected")
output_list = list(map(lambda x: reshaping.to_2d_array(x), output_list))
# In-place outputs are treated as outputs from here
output_list = in_output_list + output_list
# Prepare raw
param_map = list(zip(*param_list_unique)) # account for use_run_unique
output_shape = output_list[0].shape
for output in output_list:
if output.shape != output_shape:
raise ValueError("All outputs must have the same shape")
if per_column:
n_input_cols = output_shape[1]
else:
n_input_cols = output_shape[1] // n_unique_param_values
if input_shape_2d is not None:
if n_input_cols != input_shape_2d[1]:
if per_column:
raise ValueError(
"All outputs must have the same number of columns as inputs when per_column=True"
)
else:
raise ValueError(
"All outputs must have the same number of columns as there "
"are input columns times parameter combinations"
)
raw = output_list, param_map, n_input_cols, other_list
if return_raw:
if use_run_unique and not silence_warnings:
warn("Raw outputs are produced by unique parameter combinations when run_unique=True")
return raw
if use_run_unique:
output_list, param_map, n_input_cols, other_list = _use_raw(raw)
# Update shape and other meta if no inputs
if input_shape is None:
if n_input_cols == 1:
input_shape = (output_list[0].shape[0],)
else:
input_shape = (output_list[0].shape[0], n_input_cols)
if input_index is None:
input_index = pd.RangeIndex(start=0, step=1, stop=input_shape[0])
if input_columns is None:
input_columns = pd.RangeIndex(start=0, step=1, stop=input_shape[1] if len(input_shape) > 1 else 1)
# Build column hierarchy for indicator instance
if len(param_list) > 0:
if final_index is None:
# Build new column levels on top of input levels
build_columns_meta = build_columns(
param_list,
input_columns,
level_names=level_names,
hide_levels=hide_levels,
single_value=single_value,
param_settings=param_settings,
per_column=per_column,
**build_col_kwargs,
)
rep_param_indexes = build_columns_meta["rep_param_indexes"]
final_index = build_columns_meta["final_index"]
# Build a mapper that maps old columns in inputs to new columns
# Instead of tiling all inputs to the shape of outputs and wasting memory,
# we just keep a mapper and perform the tiling when needed
input_mapper = None
if len(input_list) > 0:
if per_column:
input_mapper = np.arange(len(input_columns))
else:
input_mapper = np.tile(np.arange(len(input_columns)), n_param_values)
# Build mappers to easily map between parameters and columns
mapper_list = [rep_param_indexes[i] for i in range(len(param_list))]
else:
# Some indicators don't have any params
final_index = input_columns
input_mapper = None
mapper_list = []
# Return artifacts: no pandas objects, just a wrapper and NumPy arrays
new_ndim = len(input_shape) if output_list[0].shape[1] == 1 else output_list[0].ndim
if new_ndim == 1 and not single_comb:
new_ndim = 2
wrapper = ArrayWrapper(input_index, final_index, new_ndim, **wrapper_kwargs)
return (
wrapper,
input_list,
input_mapper,
output_list[: len(in_output_list)],
output_list[len(in_output_list) :],
param_list,
mapper_list,
other_list,
)
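# Example (illustrative sketch; `run_pipeline` is normally invoked internally by
# `IndicatorFactory`, not by user code): a single rolling-mean output computed per window, with
# one block of output columns per parameter value, using the `price` frame from the module
# docstring above.
# >>> from vectorbtpro.generic.nb import rolling_mean_nb
# >>> def custom_func(ts, windows):
# ...     return np.column_stack([rolling_mean_nb(ts, w) for w in windows])
# >>> artifacts = IndicatorBase.run_pipeline(
# ...     1, custom_func, inputs=dict(ts=price), params=dict(window=[2, 3]),
# ...     level_names=["window"],
# ... )
# >>> wrapper, inputs, input_mapper, in_outputs, outputs, params, mappers, others = artifacts
# >>> outputs[0].shape  # 5 rows, 2 input columns x 2 windows
# (5, 4)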
@classmethod
def _run(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunOutputT:
"""Private run method."""
raise NotImplementedError
@classmethod
def run(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunOutputT:
"""Public run method."""
return cls._run(*args, **kwargs)
@classmethod
def _run_combs(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunCombsOutputT:
"""Private run combinations method."""
raise NotImplementedError
@classmethod
def run_combs(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunCombsOutputT:
"""Public run combinations method."""
return cls._run_combs(*args, **kwargs)
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[IndicatorBaseT],
*objs: tp.MaybeTuple[IndicatorBaseT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> IndicatorBaseT:
"""Stack multiple `IndicatorBase` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.
All objects to be merged must have the same columns x parameters."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, IndicatorBase):
raise TypeError("Each object to be merged must be an instance of IndicatorBase")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(
*[obj.wrapper for obj in objs], stack_columns=False, **wrapper_kwargs
)
if "input_list" not in kwargs:
new_input_list = []
for input_name in cls.input_names:
new_input_list.append(row_stack_arrays([getattr(obj, f"_{input_name}") for obj in objs]))
kwargs["input_list"] = new_input_list
if "in_output_list" not in kwargs:
new_in_output_list = []
for in_output_name in cls.in_output_names:
new_in_output_list.append(row_stack_arrays([getattr(obj, f"_{in_output_name}") for obj in objs]))
kwargs["in_output_list"] = new_in_output_list
if "output_list" not in kwargs:
new_output_list = []
for output_name in cls.output_names:
new_output_list.append(row_stack_arrays([getattr(obj, f"_{output_name}") for obj in objs]))
kwargs["output_list"] = new_output_list
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
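# A minimal sketch of row stacking (hypothetical `MyInd`, `price_2020`, and `price_2021`;
# only the semantics described in the docstring above are assumed):
#
#     >>> ind_2020 = MyInd.run(price_2020, [10, 20])
#     >>> ind_2021 = MyInd.run(price_2021, [10, 20])
#     >>> ind_full = MyInd.row_stack(ind_2020, ind_2021)  # same columns x parameters, concatenated index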
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[IndicatorBaseT],
*objs: tp.MaybeTuple[IndicatorBaseT],
wrapper_kwargs: tp.KwargsLike = None,
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> IndicatorBaseT:
"""Stack multiple `IndicatorBase` instances along columns x parameters.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.
All objects to be merged must have the same index."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, IndicatorBase):
raise TypeError("Each object to be merged must be an instance of IndicatorBase")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
if "input_mapper" not in kwargs:
stack_input_mapper_objs = True
for obj in objs:
if getattr(obj, "_input_mapper", None) is None:
stack_input_mapper_objs = False
break
if stack_input_mapper_objs:
kwargs["input_mapper"] = np.concatenate([getattr(obj, "_input_mapper") for obj in objs])
if "in_output_list" not in kwargs:
new_in_output_list = []
for in_output_name in cls.in_output_names:
new_in_output_list.append(column_stack_arrays([getattr(obj, f"_{in_output_name}") for obj in objs]))
kwargs["in_output_list"] = new_in_output_list
if "output_list" not in kwargs:
new_output_list = []
for output_name in cls.output_names:
new_output_list.append(column_stack_arrays([getattr(obj, f"_{output_name}") for obj in objs]))
kwargs["output_list"] = new_output_list
if "param_list" not in kwargs:
new_param_list = []
for param_name in cls.param_names:
param_objs = []
for obj in objs:
param_objs.extend(getattr(obj, f"_{param_name}_list"))
new_param_list.append(param_objs)
kwargs["param_list"] = new_param_list
if "mapper_list" not in kwargs:
new_mapper_list = []
for param_name in cls.param_names:
new_mapper = None
for obj in objs:
obj_mapper = getattr(obj, f"_{param_name}_mapper")
if new_mapper is None:
new_mapper = obj_mapper
else:
new_mapper = new_mapper.append(obj_mapper)
new_mapper_list.append(new_mapper)
kwargs["mapper_list"] = new_mapper_list
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
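# A minimal sketch of column stacking (hypothetical `MyInd` and `price`; both runs
# must share the same index, as noted above):
#
#     >>> ind_fast = MyInd.run(price, 10)
#     >>> ind_slow = MyInd.run(price, 50)
#     >>> ind_both = MyInd.column_stack(ind_fast, ind_slow)  # columns x parameters are appended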
def __init__(
self,
wrapper: ArrayWrapper,
input_list: InputListT,
input_mapper: InputMapperT,
in_output_list: InOutputListT,
output_list: OutputListT,
param_list: ParamListT,
mapper_list: MapperListT,
short_name: str,
**kwargs,
) -> None:
if input_mapper is not None:
checks.assert_equal(input_mapper.shape[0], wrapper.shape_2d[1])
for ts in input_list:
checks.assert_equal(ts.shape[0], wrapper.shape_2d[0])
for ts in in_output_list + output_list:
checks.assert_equal(ts.shape, wrapper.shape_2d)
for params in param_list:
checks.assert_len_equal(param_list[0], params)
for mapper in mapper_list:
checks.assert_equal(len(mapper), wrapper.shape_2d[1])
checks.assert_instance_of(short_name, str)
if "level_names" in kwargs:
del kwargs["level_names"] # deprecated
Analyzable.__init__(
self,
wrapper,
input_list=input_list,
input_mapper=input_mapper,
in_output_list=in_output_list,
output_list=output_list,
param_list=param_list,
mapper_list=mapper_list,
short_name=short_name,
**kwargs,
)
setattr(self, "_short_name", short_name)
for i, ts_name in enumerate(self.input_names):
setattr(self, f"_{ts_name}", input_list[i])
setattr(self, "_input_mapper", input_mapper)
for i, in_output_name in enumerate(self.in_output_names):
setattr(self, f"_{in_output_name}", in_output_list[i])
for i, output_name in enumerate(self.output_names):
setattr(self, f"_{output_name}", output_list[i])
for i, param_name in enumerate(self.param_names):
setattr(self, f"_{param_name}_list", param_list[i])
setattr(self, f"_{param_name}_mapper", mapper_list[i])
# Initialize indexers
mapper_sr_list = []
for i, m in enumerate(mapper_list):
mapper_sr_list.append(pd.Series(m, index=wrapper.columns))
tuple_mapper = self._tuple_mapper
if tuple_mapper is not None:
level_names = tuple(tuple_mapper.names)
mapper_sr_list.append(pd.Series(tuple_mapper.tolist(), index=wrapper.columns))
else:
level_names = ()
self._level_names = level_names
for base_cls in type(self).__bases__:
if base_cls.__name__ == "ParamIndexer":
base_cls.__init__(self, mapper_sr_list, level_names=[*level_names, level_names])
@property
def _tuple_mapper(self) -> tp.Optional[pd.MultiIndex]:
"""Mapper of multiple parameters."""
if len(self.param_names) <= 1:
return None
return pd.MultiIndex.from_arrays([getattr(self, f"_{name}_mapper") for name in self.param_names])
@property
def _param_mapper(self) -> tp.Optional[pd.Index]:
"""Mapper of all parameters."""
if len(self.param_names) == 0:
return None
if len(self.param_names) == 1:
return getattr(self, f"_{self.param_names[0]}_mapper")
return self._tuple_mapper
@property
def _visible_param_mapper(self) -> tp.Optional[pd.Index]:
"""Mapper of visible parameters."""
if len(self.param_names) == 0:
return None
if len(self.param_names) == 1:
mapper = getattr(self, f"_{self.param_names[0]}_mapper")
if mapper.name is None:
return None
if mapper.name not in self.wrapper.columns.names:
return None
return mapper
mapper = self._tuple_mapper
visible_indexes = []
for i, name in enumerate(mapper.names):
if name is not None and name in self.wrapper.columns.names:
visible_indexes.append(mapper.get_level_values(i))
if len(visible_indexes) == 0:
return None
if len(visible_indexes) == 1:
return visible_indexes[0]
return pd.MultiIndex.from_arrays(visible_indexes)
def indexing_func(self: IndicatorBaseT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> IndicatorBaseT:
"""Perform indexing on `IndicatorBase`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
row_idxs = wrapper_meta["row_idxs"]
col_idxs = wrapper_meta["col_idxs"]
rows_changed = wrapper_meta["rows_changed"]
columns_changed = wrapper_meta["columns_changed"]
if not isinstance(row_idxs, slice):
row_idxs = reshaping.to_1d_array(row_idxs)
if not isinstance(col_idxs, slice):
col_idxs = reshaping.to_1d_array(col_idxs)
input_mapper = getattr(self, "_input_mapper", None)
if input_mapper is not None:
if columns_changed:
input_mapper = input_mapper[col_idxs]
input_list = []
for input_name in self.input_names:
new_input = ArrayWrapper.select_from_flex_array(
getattr(self, f"_{input_name}"),
row_idxs=row_idxs,
col_idxs=col_idxs if input_mapper is None else None,
rows_changed=rows_changed,
columns_changed=columns_changed if input_mapper is None else False,
)
input_list.append(new_input)
in_output_list = []
for in_output_name in self.in_output_names:
new_in_output = ArrayWrapper.select_from_flex_array(
getattr(self, f"_{in_output_name}"),
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
in_output_list.append(new_in_output)
output_list = []
for output_name in self.output_names:
new_output = ArrayWrapper.select_from_flex_array(
getattr(self, f"_{output_name}"),
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
output_list.append(new_output)
param_list = []
for param_name in self.param_names:
param_list.append(getattr(self, f"_{param_name}_list"))
mapper_list = []
for param_name in self.param_names:
# Tuple mapper is a list because of its complex data type
mapper_list.append(getattr(self, f"_{param_name}_mapper")[col_idxs])
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
input_list=input_list,
input_mapper=input_mapper,
in_output_list=in_output_list,
output_list=output_list,
param_list=param_list,
mapper_list=mapper_list,
)
@class_property
def short_name(cls_or_self) -> str:
"""Name of the indicator."""
return cls_or_self._short_name
@class_property
def input_names(cls_or_self) -> tp.Tuple[str, ...]:
"""Names of the input arrays."""
return cls_or_self._input_names
@class_property
def param_names(cls_or_self) -> tp.Tuple[str, ...]:
"""Names of the parameters."""
return cls_or_self._param_names
@class_property
def in_output_names(cls_or_self) -> tp.Tuple[str, ...]:
"""Names of the in-place output arrays."""
return cls_or_self._in_output_names
@class_property
def output_names(cls_or_self) -> tp.Tuple[str, ...]:
"""Names of the regular output arrays."""
return cls_or_self._output_names
@class_property
def lazy_output_names(cls_or_self) -> tp.Tuple[str, ...]:
"""Names of the lazy output arrays."""
return cls_or_self._lazy_output_names
@class_property
def output_flags(cls_or_self) -> tp.Kwargs:
"""Dictionary of output flags."""
return cls_or_self._output_flags
@class_property
def param_defaults(cls_or_self) -> tp.Dict[str, tp.Any]:
"""Parameter defaults extracted from the signature of `IndicatorBase.run`."""
func_kwargs = get_func_kwargs(cls_or_self.run)
out = {}
for k, v in func_kwargs.items():
if k in cls_or_self.param_names:
if isinstance(v, Default):
out[k] = v.value
else:
out[k] = v
return out
@property
def level_names(self) -> tp.Tuple[str]:
"""Column level names corresponding to each parameter."""
return self._level_names
def unpack(self) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Return outputs, either one output or a tuple if there are multiple."""
out = tuple([getattr(self, name) for name in self.output_names])
if len(out) == 1:
out = out[0]
return out
def to_dict(self, include_all: bool = True) -> tp.Dict[str, tp.SeriesFrame]:
"""Return outputs as a dict."""
if include_all:
output_names = self.output_names + self.in_output_names + self.lazy_output_names
else:
output_names = self.output_names
return {name: getattr(self, name) for name in output_names}
def to_frame(self, include_all: bool = True) -> tp.Frame:
"""Return outputs as a DataFrame."""
out = self.to_dict(include_all=include_all)
return pd.concat(list(out.values()), axis=1, keys=pd.Index(list(out.keys()), name="output"))
def get(self, key: tp.Optional[tp.Hashable] = None) -> tp.Optional[tp.SeriesFrame]:
"""Get a time series."""
if key is None:
return self.main_output
return getattr(self, key)
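# A minimal sketch of extracting outputs (assuming a hypothetical instance `myind`
# with outputs `o1` and `o2`):
#
#     >>> myind.unpack()    # -> (o1, o2), or a single object if there is one output
#     >>> myind.to_dict()   # -> {'o1': ..., 'o2': ...}
#     >>> myind.to_frame()  # -> one DataFrame with an "output" column level
#     >>> myind.get('o1')   # -> same as myind.o1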
def dropna(self: IndicatorBaseT, include_all: bool = True, **kwargs) -> IndicatorBaseT:
"""Drop missing values.
Keyword arguments are passed to `pd.Series.dropna` or `pd.DataFrame.dropna`."""
df = self.to_frame(include_all=include_all)
new_df = df.dropna(**kwargs)
if new_df.index.equals(df.index):
return self
return self.loc[new_df.index]
def rename(self: IndicatorBaseT, short_name: str) -> IndicatorBaseT:
"""Replace the short name of the indicator."""
new_level_names = ()
for level_name in self.level_names:
if level_name.startswith(self.short_name + "_"):
level_name = level_name.replace(self.short_name, short_name, 1)
new_level_names += (level_name,)
new_mapper_list = []
for i, param_name in enumerate(self.param_names):
mapper = getattr(self, f"_{param_name}_mapper")
new_mapper_list.append(mapper.rename(new_level_names[i]))
new_columns = self.wrapper.columns
for i, name in enumerate(self.wrapper.columns.names):
if name in self.level_names:
new_columns = new_columns.rename({name: new_level_names[self.level_names.index(name)]})
new_wrapper = self.wrapper.replace(columns=new_columns)
return self.replace(wrapper=new_wrapper, mapper_list=new_mapper_list, short_name=short_name)
def rename_levels(
self: IndicatorBaseT,
mapper: tp.MaybeMappingSequence[tp.Level],
**kwargs,
) -> IndicatorBaseT:
new_self = Analyzable.rename_levels(self, mapper, **kwargs)
old_column_names = self.wrapper.columns.names
new_column_names = new_self.wrapper.columns.names
new_level_names = ()
for level_name in new_self.level_names:
if level_name in old_column_names:
level_name = new_column_names[old_column_names.index(level_name)]
new_level_names += (level_name,)
new_mapper_list = []
for i, param_name in enumerate(new_self.param_names):
mapper = getattr(new_self, f"_{param_name}_mapper")
new_mapper_list.append(mapper.rename(new_level_names[i]))
return new_self.replace(mapper_list=new_mapper_list, level_names=new_level_names)
# ############# Iteration ############# #
def items(
self,
group_by: tp.GroupByLike = "params",
apply_group_by: bool = False,
keep_2d: bool = False,
key_as_index: bool = False,
) -> tp.Items:
"""Iterate over columns (or groups if grouped and `Wrapping.group_select` is True).
Allows the following additional options for `group_by`: "all_params", "params"
(only those that aren't hidden), and parameter names."""
if isinstance(group_by, str):
if group_by not in self.wrapper.columns.names:
if group_by.lower() == "all_params":
group_by = self._param_mapper
elif group_by.lower() == "params":
group_by = self._visible_param_mapper
elif group_by in self.param_names:
group_by = getattr(self, f"_{group_by}_mapper")
elif isinstance(group_by, (tuple, list)):
new_group_by = []
for g in group_by:
if isinstance(g, str):
if g not in self.wrapper.columns.names:
if g in self.param_names:
g = getattr(self, f"_{g}_mapper")
new_group_by.append(g)
group_by = type(group_by)(new_group_by)
for k, v in Analyzable.items(
self,
group_by=group_by,
apply_group_by=apply_group_by,
keep_2d=keep_2d,
key_as_index=key_as_index,
):
yield k, v
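# A minimal sketch of iterating over parameter combinations (hypothetical instance
# `myind` with a parameter named `window`):
#
#     >>> for window, obj in myind.items(group_by="window"):
#     ...     ...  # obj holds the columns selected for this window value
#
# Passing "params" groups by all visible parameter levels, "all_params" by all of them.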
# ############# Documentation ############# #
@classmethod
def fix_docstrings(cls, __pdoc__: dict) -> None:
"""Fix docstrings."""
if hasattr(cls, "custom_func"):
if cls.__name__ + ".custom_func" not in __pdoc__:
__pdoc__[cls.__name__ + ".custom_func"] = "Custom function."
if hasattr(cls, "apply_func"):
if cls.__name__ + ".apply_func" not in __pdoc__:
__pdoc__[cls.__name__ + ".apply_func"] = "Apply function."
if hasattr(cls, "cache_func"):
if cls.__name__ + ".cache_func" not in __pdoc__:
__pdoc__[cls.__name__ + ".cache_func"] = "Cache function."
if hasattr(cls, "entry_place_func_nb"):
if cls.__name__ + ".entry_place_func_nb" not in __pdoc__:
__pdoc__[cls.__name__ + ".entry_place_func_nb"] = "Entry placement function."
if hasattr(cls, "exit_place_func_nb"):
if cls.__name__ + ".exit_place_func_nb" not in __pdoc__:
__pdoc__[cls.__name__ + ".exit_place_func_nb"] = "Exit placement function."
class IndicatorFactory(Configured):
"""A factory for creating new indicators.
Initialize `IndicatorFactory` to create a skeleton and then use a class method
such as `IndicatorFactory.with_custom_func` to bind a calculation function to the skeleton.
Args:
class_name (str): Name for the created indicator class.
class_docstring (str): Docstring for the created indicator class.
module_name (str): Name of the module the class originates from.
short_name (str): Short name of the indicator.
Defaults to lower-case `class_name`.
prepend_name (bool): Whether to prepend `short_name` to each parameter level.
input_names (list of str): List with input names.
param_names (list of str): List with parameter names.
in_output_names (list of str): List with in-output names.
An in-place output is an output that is not returned but modified in-place.
Some advantages of such outputs include:
1) they don't need to be returned,
2) they can be passed between functions as easily as inputs,
3) they can be provided with already allocated data to save memory,
4) if data or a default value is not provided, they are created empty to not occupy memory.
output_names (list of str): List with output names.
output_flags (dict): Dictionary of in-place and regular output flags.
lazy_outputs (dict): Dictionary with user-defined functions that will be
bound to the indicator class and wrapped with `property` if not already wrapped.
attr_settings (dict): Dictionary with attribute settings.
Attributes can be `input_names`, `in_output_names`, `output_names`, and `lazy_outputs`.
Following keys are accepted:
* `dtype`: Data type used to determine which methods to generate around this attribute.
Set to None to disable. Default is `float_`. Can be set to instance of
`collections.namedtuple` acting as enumerated type, or any other mapping;
It will then create a property with suffix `readable` that contains data in a string format.
* `enum_unkval`: Value to be considered as unknown. Applies to enumerated data types only.
* `make_cacheable`: Whether to make the property cacheable. Applies to inputs only.
metrics (dict): Metrics supported by `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`.
If dict, will be converted to `vectorbtpro.utils.config.Config`.
stats_defaults (callable or dict): Defaults for `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`.
If dict, will be converted into a property.
subplots (dict): Subplots supported by `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`.
If dict, will be converted to `vectorbtpro.utils.config.Config`.
plots_defaults (callable or dict): Defaults for `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`.
If dict, will be converted into a property.
**kwargs: Custom keyword arguments passed to the config.
!!! note
The `__init__` method is not used for running the indicator; use `run` for that.
The reason for this is indexing, which requires a clean `__init__` method for creating
a new indicator object with newly indexed attributes.
"""
def __init__(
self,
class_name: tp.Optional[str] = None,
class_docstring: tp.Optional[str] = None,
module_name: tp.Optional[str] = __name__,
short_name: tp.Optional[str] = None,
prepend_name: bool = True,
input_names: tp.Optional[tp.Sequence[str]] = None,
param_names: tp.Optional[tp.Sequence[str]] = None,
in_output_names: tp.Optional[tp.Sequence[str]] = None,
output_names: tp.Optional[tp.Sequence[str]] = None,
output_flags: tp.KwargsLike = None,
lazy_outputs: tp.KwargsLike = None,
attr_settings: tp.KwargsLike = None,
metrics: tp.Optional[tp.Kwargs] = None,
stats_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None,
subplots: tp.Optional[tp.Kwargs] = None,
plots_defaults: tp.Union[None, tp.Callable, tp.Kwargs] = None,
**kwargs,
) -> None:
Configured.__init__(
self,
class_name=class_name,
class_docstring=class_docstring,
module_name=module_name,
short_name=short_name,
prepend_name=prepend_name,
input_names=input_names,
param_names=param_names,
in_output_names=in_output_names,
output_names=output_names,
output_flags=output_flags,
lazy_outputs=lazy_outputs,
attr_settings=attr_settings,
metrics=metrics,
stats_defaults=stats_defaults,
subplots=subplots,
plots_defaults=plots_defaults,
**kwargs,
)
# Check parameters
if class_name is None:
class_name = "Indicator"
checks.assert_instance_of(class_name, str)
if class_docstring is None:
class_docstring = ""
checks.assert_instance_of(class_docstring, str)
if module_name is not None:
checks.assert_instance_of(module_name, str)
if short_name is None:
if class_name == "Indicator":
short_name = "custom"
else:
short_name = class_name.lower()
checks.assert_instance_of(short_name, str)
checks.assert_instance_of(prepend_name, bool)
if input_names is None:
input_names = []
else:
checks.assert_sequence(input_names)
input_names = list(input_names)
if param_names is None:
param_names = []
else:
checks.assert_sequence(param_names)
param_names = list(param_names)
if in_output_names is None:
in_output_names = []
else:
checks.assert_sequence(in_output_names)
in_output_names = list(in_output_names)
if output_names is None:
output_names = []
else:
checks.assert_sequence(output_names)
output_names = list(output_names)
all_output_names = in_output_names + output_names
if len(all_output_names) == 0:
raise ValueError("Must have at least one in-place or regular output")
if len(set.intersection(set(input_names), set(in_output_names), set(output_names))) > 0:
raise ValueError("Inputs, in-outputs, and parameters must all have unique names")
if output_flags is None:
output_flags = {}
checks.assert_instance_of(output_flags, dict)
if len(output_flags) > 0:
checks.assert_dict_valid(output_flags, all_output_names)
if lazy_outputs is None:
lazy_outputs = {}
checks.assert_instance_of(lazy_outputs, dict)
if attr_settings is None:
attr_settings = {}
checks.assert_instance_of(attr_settings, dict)
all_attr_names = input_names + all_output_names + list(lazy_outputs.keys())
if len(attr_settings) > 0:
checks.assert_dict_valid(
attr_settings,
[
all_attr_names,
["dtype", "enum_unkval", "make_cacheable"],
],
)
# Set up class
ParamIndexer = build_param_indexer(
param_names + (["tuple"] if len(param_names) > 1 else []),
module_name=module_name,
)
Indicator = type(class_name, (IndicatorBase, ParamIndexer), {})
Indicator.__doc__ = class_docstring
if module_name is not None:
Indicator.__module__ = module_name
# Create read-only properties
setattr(Indicator, "_short_name", short_name)
setattr(Indicator, "_input_names", tuple(input_names))
setattr(Indicator, "_param_names", tuple(param_names))
setattr(Indicator, "_in_output_names", tuple(in_output_names))
setattr(Indicator, "_output_names", tuple(output_names))
setattr(Indicator, "_lazy_output_names", tuple(lazy_outputs.keys()))
setattr(Indicator, "_output_flags", output_flags)
for param_name in param_names:
def param_list_prop(self, _param_name=param_name) -> tp.List[tp.ParamValue]:
return getattr(self, f"_{_param_name}_list")
param_list_prop.__doc__ = f"List of `{param_name}` values."
setattr(Indicator, f"{param_name}_list", property(param_list_prop))
for input_name in input_names:
_attr_settings = attr_settings.get(input_name, {})
make_cacheable = _attr_settings.get("make_cacheable", False)
def input_prop(self, _input_name: str = input_name) -> tp.SeriesFrame:
"""Input array."""
old_input = reshaping.to_2d_array(getattr(self, "_" + _input_name))
input_mapper = getattr(self, "_input_mapper")
if input_mapper is None:
return self.wrapper.wrap(old_input)
return self.wrapper.wrap(old_input[:, input_mapper])
input_prop.__name__ = input_name
input_prop.__module__ = Indicator.__module__
input_prop.__qualname__ = f"{Indicator.__name__}.{input_prop.__name__}"
if make_cacheable:
setattr(Indicator, input_name, cacheable_property(input_prop))
else:
setattr(Indicator, input_name, property(input_prop))
for output_name in all_output_names:
def output_prop(self, _output_name: str = output_name) -> tp.SeriesFrame:
return self.wrapper.wrap(getattr(self, "_" + _output_name))
if output_name in in_output_names:
output_prop.__doc__ = """In-place output array."""
else:
output_prop.__doc__ = """Output array."""
output_prop.__name__ = output_name
output_prop.__module__ = Indicator.__module__
output_prop.__qualname__ = f"{Indicator.__name__}.{output_prop.__name__}"
if output_name in output_flags:
_output_flags = output_flags[output_name]
if isinstance(_output_flags, (tuple, list)):
_output_flags = ", ".join(_output_flags)
output_prop.__doc__ += "\n\n" + _output_flags
setattr(Indicator, output_name, property(output_prop))
# Add user-defined outputs
for prop_name, prop in lazy_outputs.items():
prop.__name__ = prop_name
prop.__module__ = Indicator.__module__
prop.__qualname__ = f"{Indicator.__name__}.{prop.__name__}"
if prop.__doc__ is None:
prop.__doc__ = f"""Custom property."""
if not isinstance(prop, property):
prop = property(prop)
setattr(Indicator, prop_name, prop)
# Add comparison & combination methods for all inputs, outputs, and user-defined properties
def assign_combine_method(
func_name: str,
combine_func: tp.Callable,
def_kwargs: tp.Kwargs,
attr_name: str,
docstring: str,
) -> None:
def combine_method(
self: IndicatorBaseT,
other: tp.MaybeTupleList[tp.Union[IndicatorBaseT, tp.ArrayLike, BaseAccessor]],
level_name: tp.Optional[str] = None,
allow_multiple: bool = True,
_prepend_name: bool = prepend_name,
**kwargs,
) -> tp.SeriesFrame:
if allow_multiple and isinstance(other, (tuple, list)):
other = list(other)
for i in range(len(other)):
if isinstance(other[i], IndicatorBase):
other[i] = getattr(other[i], attr_name)
elif isinstance(other[i], str):
other[i] = getattr(self, other[i])
else:
if isinstance(other, IndicatorBase):
other = getattr(other, attr_name)
elif isinstance(other, str):
other = getattr(self, other)
if level_name is None:
if _prepend_name:
if attr_name == self.short_name:
level_name = f"{self.short_name}_{func_name}"
else:
level_name = f"{self.short_name}_{attr_name}_{func_name}"
else:
level_name = f"{attr_name}_{func_name}"
out = combine_objs(
getattr(self, attr_name),
other,
combine_func,
level_name=level_name,
allow_multiple=allow_multiple,
**merge_dicts(def_kwargs, kwargs),
)
return out
combine_method.__name__ = f"{attr_name}_{func_name}"
combine_method.__module__ = Indicator.__module__
combine_method.__qualname__ = f"{Indicator.__name__}.{combine_method.__name__}"
combine_method.__doc__ = docstring
setattr(Indicator, f"{attr_name}_{func_name}", combine_method)
for attr_name in all_attr_names:
_attr_settings = attr_settings.get(attr_name, {})
dtype = _attr_settings.get("dtype", float_)
enum_unkval = _attr_settings.get("enum_unkval", -1)
if checks.is_mapping_like(dtype):
def attr_readable(
self,
_attr_name: str = attr_name,
_mapping: tp.MappingLike = dtype,
_enum_unkval: tp.Any = enum_unkval,
) -> tp.SeriesFrame:
return getattr(self, _attr_name).vbt(mapping=_mapping).apply_mapping(enum_unkval=_enum_unkval)
attr_readable.__name__ = f"{attr_name}_readable"
attr_readable.__module__ = Indicator.__module__
attr_readable.__qualname__ = f"{Indicator.__name__}.{attr_readable.__name__}"
attr_readable.__doc__ = inspect.cleandoc(
"""`{attr_name}` in readable format based on the following mapping:
```python
{dtype}
```"""
).format(attr_name=attr_name, dtype=prettify(to_value_mapping(dtype, enum_unkval=enum_unkval)))
setattr(Indicator, f"{attr_name}_readable", property(attr_readable))
def attr_stats(
self,
*args,
_attr_name: str = attr_name,
_mapping: tp.MappingLike = dtype,
**kwargs,
) -> tp.SeriesFrame:
return getattr(self, _attr_name).vbt(mapping=_mapping).stats(*args, **kwargs)
attr_stats.__name__ = f"{attr_name}_stats"
attr_stats.__module__ = Indicator.__module__
attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
attr_stats.__doc__ = inspect.cleandoc(
"""Stats of `{attr_name}` based on the following mapping:
```python
{dtype}
```"""
).format(attr_name=attr_name, dtype=prettify(to_value_mapping(dtype)))
setattr(Indicator, f"{attr_name}_stats", attr_stats)
elif np.issubdtype(dtype, np.number):
func_info = [
("above", np.greater, dict()),
("below", np.less, dict()),
("equal", np.equal, dict()),
(
"crossed_above",
lambda x, y, wait=0, dropna=False: jit_reg.resolve(generic_nb.crossed_above_nb)(
x,
y,
wait=wait,
dropna=dropna,
),
dict(to_2d=True),
),
(
"crossed_below",
lambda x, y, wait=0, dropna=False: jit_reg.resolve(generic_nb.crossed_above_nb)(
y,
x,
wait=wait,
dropna=dropna,
),
dict(to_2d=True),
),
]
for func_name, np_func, def_kwargs in func_info:
method_docstring = f"""Return True for each element where `{attr_name}` is {func_name} `other`.
See `vectorbtpro.indicators.factory.combine_objs`."""
assign_combine_method(func_name, np_func, def_kwargs, attr_name, method_docstring)
def attr_stats(self, *args, _attr_name: str = attr_name, **kwargs) -> tp.SeriesFrame:
return getattr(self, _attr_name).vbt.stats(*args, **kwargs)
attr_stats.__name__ = f"{attr_name}_stats"
attr_stats.__module__ = Indicator.__module__
attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
attr_stats.__doc__ = f"""Stats of `{attr_name}` as generic."""
setattr(Indicator, f"{attr_name}_stats", attr_stats)
elif np.issubdtype(dtype, np.bool_):
func_info = [
("and", np.logical_and, dict()),
("or", np.logical_or, dict()),
("xor", np.logical_xor, dict()),
]
for func_name, np_func, def_kwargs in func_info:
method_docstring = f"""Return `{attr_name} {func_name.upper()} other`.
See `vectorbtpro.indicators.factory.combine_objs`."""
assign_combine_method(func_name, np_func, def_kwargs, attr_name, method_docstring)
def attr_stats(self, *args, _attr_name: str = attr_name, **kwargs) -> tp.SeriesFrame:
return getattr(self, _attr_name).vbt.signals.stats(*args, **kwargs)
attr_stats.__name__ = f"{attr_name}_stats"
attr_stats.__module__ = Indicator.__module__
attr_stats.__qualname__ = f"{Indicator.__name__}.{attr_stats.__name__}"
attr_stats.__doc__ = f"""Stats of `{attr_name}` as signals."""
setattr(Indicator, f"{attr_name}_stats", attr_stats)
# Prepare stats
if metrics is not None:
if not isinstance(metrics, Config):
metrics = Config(metrics, options_=dict(copy_kwargs=dict(copy_mode="deep")))
setattr(Indicator, "_metrics", metrics.copy())
if stats_defaults is not None:
if isinstance(stats_defaults, dict):
def stats_defaults_prop(self, _stats_defaults: tp.Kwargs = stats_defaults) -> tp.Kwargs:
return _stats_defaults
else:
def stats_defaults_prop(self, _stats_defaults: tp.Kwargs = stats_defaults) -> tp.Kwargs:
return stats_defaults(self)
stats_defaults_prop.__name__ = "stats_defaults"
stats_defaults_prop.__module__ = Indicator.__module__
stats_defaults_prop.__qualname__ = f"{Indicator.__name__}.{stats_defaults_prop.__name__}"
setattr(Indicator, "stats_defaults", property(stats_defaults_prop))
# Prepare plots
if subplots is not None:
if not isinstance(subplots, Config):
subplots = Config(subplots, options_=dict(copy_kwargs=dict(copy_mode="deep")))
setattr(Indicator, "_subplots", subplots.copy())
if plots_defaults is not None:
if isinstance(plots_defaults, dict):
def plots_defaults_prop(self, _plots_defaults: tp.Kwargs = plots_defaults) -> tp.Kwargs:
return _plots_defaults
else:
def plots_defaults_prop(self, _plots_defaults: tp.Kwargs = plots_defaults) -> tp.Kwargs:
return plots_defaults(self)
plots_defaults_prop.__name__ = "plots_defaults"
plots_defaults_prop.__module__ = Indicator.__module__
plots_defaults_prop.__qualname__ = f"{Indicator.__name__}.{plots_defaults_prop.__name__}"
setattr(Indicator, "plots_defaults", property(plots_defaults_prop))
# Store arguments
self._class_name = class_name
self._class_docstring = class_docstring
self._module_name = module_name
self._short_name = short_name
self._prepend_name = prepend_name
self._input_names = input_names
self._param_names = param_names
self._in_output_names = in_output_names
self._output_names = output_names
self._output_flags = output_flags
self._lazy_outputs = lazy_outputs
self._attr_settings = attr_settings
self._metrics = metrics
self._stats_defaults = stats_defaults
self._subplots = subplots
self._plots_defaults = plots_defaults
# Store indicator class
self._Indicator = Indicator
@property
def class_name(self) -> str:
"""Name for the created indicator class."""
return self._class_name
@property
def class_docstring(self) -> str:
"""Docstring for the created indicator class."""
return self._class_docstring
@property
def module_name(self) -> str:
"""Name of the module the class originates from."""
return self._module_name
@property
def short_name(self) -> str:
"""Short name of the indicator."""
return self._short_name
@property
def prepend_name(self) -> bool:
"""Whether to prepend `IndicatorFactory.short_name` to each parameter level."""
return self._prepend_name
@property
def input_names(self) -> tp.List[str]:
"""List with input names."""
return self._input_names
@property
def param_names(self) -> tp.List[str]:
"""List with parameter names."""
return self._param_names
@property
def in_output_names(self) -> tp.List[str]:
"""List with in-output names."""
return self._in_output_names
@property
def output_names(self) -> tp.List[str]:
"""List with output names."""
return self._output_names
@property
def output_flags(self) -> tp.Kwargs:
"""Dictionary of in-place and regular output flags."""
return self._output_flags
@property
def lazy_outputs(self) -> tp.Kwargs:
"""Dictionary with user-defined functions that will become properties."""
return self._lazy_outputs
@property
def attr_settings(self) -> tp.Kwargs:
"""Dictionary with attribute settings."""
return self._attr_settings
@property
def metrics(self) -> Config:
"""Metrics supported by `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`."""
return self._metrics
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats`."""
return self._stats_defaults
@property
def subplots(self) -> Config:
"""Subplots supported by `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`."""
return self._subplots
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`."""
return self._plots_defaults
@property
def Indicator(self) -> tp.Type[IndicatorBase]:
"""Built indicator class."""
return self._Indicator
# ############# Construction ############# #
def with_custom_func(
self,
custom_func: tp.Callable,
require_input_shape: bool = False,
param_settings: tp.KwargsLike = None,
in_output_settings: tp.KwargsLike = None,
hide_params: tp.Union[None, bool, tp.Sequence[str]] = None,
hide_default: bool = True,
var_args: bool = False,
keyword_only_args: bool = False,
**pipeline_kwargs,
) -> tp.Type[IndicatorBase]:
"""Build indicator class around a custom calculation function.
In contrast to `IndicatorFactory.with_apply_func`, this method offers full flexibility.
It's up to the user to handle caching and concatenate columns for each parameter (for example,
by using `vectorbtpro.base.combining.apply_and_concat`). Also, you must ensure that
each output array has an appropriate number of columns, which is the number of columns in
input arrays multiplied by the number of parameter combinations.
Args:
custom_func (callable): A function that takes broadcast arrays corresponding
to `input_names`, broadcast in-place output arrays corresponding to `in_output_names`,
broadcast parameter arrays corresponding to `param_names`, and other arguments and
keyword arguments, and returns outputs corresponding to `output_names` and other objects
that are then returned with the indicator instance.
Can be Numba-compiled.
!!! note
Shape of each output must be the same and match the shape of each input stacked
n times (= the number of parameter values) along the column axis.
require_input_shape (bool): Whether the input shape is required.
param_settings (dict): A dictionary of parameter settings keyed by name.
See `IndicatorBase.run_pipeline` for keys.
Can be overwritten by any run method.
in_output_settings (dict): A dictionary of in-place output settings keyed by name.
See `IndicatorBase.run_pipeline` for keys.
Can be overwritten by any run method.
hide_params (bool or list of str): Parameter names to hide column levels for,
or whether to hide all parameters.
Can be overwritten by any run method.
hide_default (bool): Whether to hide column levels of parameters with default value.
Can be overwritten by any run method.
var_args (bool): Whether run methods should accept variable arguments (`*args`).
Set to True if `custom_func` accepts positional arguments that are not listed in the config.
keyword_only_args (bool): Whether run methods should accept keyword-only arguments (`*`).
Set to True to force the user to use keyword arguments (e.g., to avoid misplacing arguments).
**pipeline_kwargs: Keyword arguments passed to `IndicatorBase.run_pipeline`.
Can be overwritten by any run method.
Can contain default values and also references to other arguments wrapped
with `vectorbtpro.base.reshaping.Ref`.
Returns:
`Indicator`, and optionally other objects that are returned by `custom_func`
and exceed `output_names`.
Usage:
* The following example produces the same indicator as the `IndicatorFactory.with_apply_func` example.
```pycon
>>> @njit
... def apply_func_nb(i, ts1, ts2, p1, p2, arg1, arg2):
... return ts1 * p1[i] + arg1, ts2 * p2[i] + arg2
>>> @njit
... def custom_func(ts1, ts2, p1, p2, arg1, arg2):
... return vbt.base.combining.apply_and_concat_multiple_nb(
... len(p1), apply_func_nb, ts1, ts2, p1, p2, arg1, arg2)
>>> MyInd = vbt.IF(
... input_names=['ts1', 'ts2'],
... param_names=['p1', 'p2'],
... output_names=['o1', 'o2']
... ).with_custom_func(custom_func, var_args=True, arg2=200)
>>> myInd = MyInd.run(price, price * 2, [1, 2], [3, 4], 100)
>>> myInd.o1
custom_p1 1 2
custom_p2 3 4
a b a b
2020-01-01 101.0 105.0 102.0 110.0
2020-01-02 102.0 104.0 104.0 108.0
2020-01-03 103.0 103.0 106.0 106.0
2020-01-04 104.0 102.0 108.0 104.0
2020-01-05 105.0 101.0 110.0 102.0
>>> myInd.o2
custom_p1 1 2
custom_p2 3 4
a b a b
2020-01-01 206.0 230.0 208.0 240.0
2020-01-02 212.0 224.0 216.0 232.0
2020-01-03 218.0 218.0 224.0 224.0
2020-01-04 224.0 212.0 232.0 216.0
2020-01-05 230.0 206.0 240.0 208.0
```
The difference between `apply_func_nb` here and in `IndicatorFactory.with_apply_func` is that
here it takes the index of the current parameter combination that can be used for parameter selection.
* You can also remove the entire `apply_func_nb` and define your logic in `custom_func`
(which doesn't necessarily have to be Numba-compiled):
```pycon
>>> @njit
... def custom_func(ts1, ts2, p1, p2, arg1, arg2):
... input_shape = ts1.shape
... n_params = len(p1)
... out1 = np.empty((input_shape[0], input_shape[1] * n_params), dtype=float_)
... out2 = np.empty((input_shape[0], input_shape[1] * n_params), dtype=float_)
... for k in range(n_params):
... for col in range(input_shape[1]):
... for i in range(input_shape[0]):
... out1[i, input_shape[1] * k + col] = ts1[i, col] * p1[k] + arg1
... out2[i, input_shape[1] * k + col] = ts2[i, col] * p2[k] + arg2
... return out1, out2
```
"""
Indicator = self.Indicator
short_name = self.short_name
prepend_name = self.prepend_name
input_names = self.input_names
param_names = self.param_names
in_output_names = self.in_output_names
output_names = self.output_names
all_input_names = input_names + param_names + in_output_names
setattr(Indicator, "custom_func", custom_func)
def _split_args(
args: tp.Sequence,
) -> tp.Tuple[tp.Dict[str, tp.ArrayLike], tp.Dict[str, tp.ArrayLike], tp.Dict[str, tp.ParamValues], tp.Args]:
inputs = dict(zip(input_names, args[: len(input_names)]))
checks.assert_len_equal(inputs, input_names)
args = args[len(input_names) :]
params = dict(zip(param_names, args[: len(param_names)]))
checks.assert_len_equal(params, param_names)
args = args[len(param_names) :]
in_outputs = dict(zip(in_output_names, args[: len(in_output_names)]))
checks.assert_len_equal(in_outputs, in_output_names)
args = args[len(in_output_names) :]
if not var_args and len(args) > 0:
raise TypeError(
"Variable length arguments are not supported by this function (var_args is set to False)"
)
return inputs, in_outputs, params, args
for k, v in pipeline_kwargs.items():
if k in param_names and not isinstance(v, Default):
pipeline_kwargs[k] = Default(v) # track default params
pipeline_kwargs = merge_dicts({k: None for k in in_output_names}, pipeline_kwargs)
# Display default parameters and in-place outputs in the signature
default_kwargs = {}
for k in list(pipeline_kwargs.keys()):
if k in input_names or k in param_names or k in in_output_names:
default_kwargs[k] = pipeline_kwargs.pop(k)
if var_args and keyword_only_args:
raise ValueError("var_args and keyword_only_args cannot be used together")
# Add private run method
def_run_kwargs = dict(
short_name=short_name,
hide_params=hide_params,
hide_default=hide_default,
**default_kwargs,
)
def _run(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunOutputT:
_short_name = kwargs.pop("short_name", def_run_kwargs["short_name"])
_hide_params = kwargs.pop("hide_params", def_run_kwargs["hide_params"])
_hide_default = kwargs.pop("hide_default", def_run_kwargs["hide_default"])
_param_settings = merge_dicts(param_settings, kwargs.pop("param_settings", {}))
_in_output_settings = merge_dicts(in_output_settings, kwargs.pop("in_output_settings", {}))
if isinstance(_hide_params, bool):
if not _hide_params:
_hide_params = None
else:
_hide_params = param_names
if _hide_params is None:
_hide_params = []
args = list(args)
# Split arguments
inputs, in_outputs, params, args = _split_args(args)
# Prepare column levels
level_names = []
hide_levels = []
for pname in param_names:
level_name = _short_name + "_" + pname if prepend_name else pname
level_names.append(level_name)
if pname in _hide_params or (_hide_default and isinstance(params[pname], Default)):
hide_levels.append(level_name)
for k, v in params.items():
if isinstance(v, Default):
params[k] = v.value
# Run the pipeline
results = Indicator.run_pipeline(
len(output_names), # number of returned outputs
custom_func,
*args,
require_input_shape=require_input_shape,
inputs=inputs,
in_outputs=in_outputs,
params=params,
level_names=level_names,
hide_levels=hide_levels,
param_settings=_param_settings,
in_output_settings=_in_output_settings,
**merge_dicts(pipeline_kwargs, kwargs),
)
# Return the raw result if any of the flags are set
if kwargs.get("return_raw", False) or kwargs.get("return_cache", False):
return results
# Unpack the result
(
wrapper,
new_input_list,
input_mapper,
in_output_list,
output_list,
new_param_list,
mapper_list,
other_list,
) = results
# Create a new instance
obj = cls(
wrapper,
new_input_list,
input_mapper,
in_output_list,
output_list,
new_param_list,
mapper_list,
short_name,
)
if len(other_list) > 0:
return (obj, *tuple(other_list))
return obj
setattr(Indicator, "_run", classmethod(_run))
# Add public run method
# Create function dynamically to provide user with a proper signature
def compile_run_function(func_name: str, docstring: str, _default_kwargs: tp.KwargsLike = None) -> tp.Callable:
pos_names = []
main_kw_names = []
other_kw_names = []
if _default_kwargs is None:
_default_kwargs = {}
for k in input_names + param_names:
if k in _default_kwargs:
main_kw_names.append(k)
else:
pos_names.append(k)
main_kw_names.extend(in_output_names) # in_output_names are keyword-only
for k, v in _default_kwargs.items():
if k not in pos_names and k not in main_kw_names:
other_kw_names.append(k)
_0 = func_name
_1 = "*, " if keyword_only_args else ""
_2 = []
if require_input_shape:
_2.append("input_shape")
_2.extend(pos_names)
_2 = ", ".join(_2) + ", " if len(_2) > 0 else ""
_3 = "*args, " if var_args else ""
_4 = ["{}={}".format(k, k) for k in main_kw_names + other_kw_names]
if require_input_shape:
_4 += ["input_index=None", "input_columns=None"]
_4 = ", ".join(_4) + ", " if len(_4) > 0 else ""
_5 = docstring
_6 = all_input_names
_6 = ", ".join(_6) + ", " if len(_6) > 0 else ""
_7 = []
if require_input_shape:
_7.append("input_shape")
_7.extend(other_kw_names)
_7 = ["{}={}".format(k, k) for k in _7]
if require_input_shape:
_7 += ["input_index=input_index", "input_columns=input_columns"]
_7 = ", ".join(_7) + ", " if len(_7) > 0 else ""
func_str = (
"@classmethod\n"
"def {0}(cls, {1}{2}{3}{4}**kwargs):\n"
' """{5}"""\n'
" return cls._{0}({6}{3}{7}**kwargs)".format(_0, _1, _2, _3, _4, _5, _6, _7)
)
scope = {**dict(Default=Default), **_default_kwargs}
filename = inspect.getfile(lambda: None)
code = compile(func_str, filename, "single")
exec(code, scope)
return scope[func_name]
_0 = self.class_name
_1 = ""
if len(self.input_names) > 0:
_1 += "\n* Inputs: " + ", ".join(map(lambda x: f"`{x}`", self.input_names))
if len(self.in_output_names) > 0:
_1 += "\n* In-place outputs: " + ", ".join(map(lambda x: f"`{x}`", self.in_output_names))
if len(self.param_names) > 0:
_1 += "\n* Parameters: " + ", ".join(map(lambda x: f"`{x}`", self.param_names))
if len(self.output_names) > 0:
_1 += "\n* Outputs: " + ", ".join(map(lambda x: f"`{x}`", self.output_names))
if len(self.lazy_outputs) > 0:
_1 += "\n* Lazy outputs: " + ", ".join(map(lambda x: f"`{x}`", list(self.lazy_outputs.keys())))
run_docstring = """Run `{0}` indicator.
{1}
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `{0}.run_pipeline`.""".format(
_0,
_1,
)
run = compile_run_function("run", run_docstring, def_run_kwargs)
run.__name__ = "run"
run.__module__ = Indicator.__module__
run.__qualname__ = f"{Indicator.__name__}.{run.__name__}"
setattr(Indicator, "run", run)
if len(param_names) > 0:
# Add private run_combs method
def_run_combs_kwargs = dict(
r=2,
param_product=False,
comb_func=itertools.combinations,
run_unique=True,
short_names=None,
hide_params=hide_params,
hide_default=hide_default,
**default_kwargs,
)
def _run_combs(cls: tp.Type[IndicatorBaseT], *args, **kwargs) -> RunCombsOutputT:
_r = kwargs.pop("r", def_run_combs_kwargs["r"])
_param_product = kwargs.pop("param_product", def_run_combs_kwargs["param_product"])
_comb_func = kwargs.pop("comb_func", def_run_combs_kwargs["comb_func"])
_run_unique = kwargs.pop("run_unique", def_run_combs_kwargs["run_unique"])
_short_names = kwargs.pop("short_names", def_run_combs_kwargs["short_names"])
_hide_params = kwargs.pop("hide_params", def_run_kwargs["hide_params"])
_hide_default = kwargs.pop("hide_default", def_run_kwargs["hide_default"])
_param_settings = merge_dicts(param_settings, kwargs.get("param_settings", {}))
if isinstance(_hide_params, bool):
if not _hide_params:
_hide_params = None
else:
_hide_params = param_names
if _hide_params is None:
_hide_params = []
if _short_names is None:
_short_names = [f"{short_name}_{str(i + 1)}" for i in range(_r)]
args = list(args)
# Split arguments
inputs, in_outputs, params, args = _split_args(args)
# Hide params
for pname in param_names:
if _hide_default and isinstance(params[pname], Default):
params[pname] = params[pname].value
if pname not in _hide_params:
_hide_params.append(pname)
checks.assert_len_equal(params, param_names)
# Bring argument to list format
input_list = list(inputs.values())
in_output_list = list(in_outputs.values())
param_list = list(params.values())
# Prepare params
for i, pname in enumerate(param_names):
is_tuple = _param_settings.get(pname, {}).get("is_tuple", False)
is_array_like = _param_settings.get(pname, {}).get("is_array_like", False)
param_list[i] = params_to_list(params[pname], is_tuple, is_array_like)
if _param_product:
param_list = create_param_product(param_list)
else:
param_list = broadcast_params(param_list)
# Speed up by pre-calculating raw outputs
if _run_unique:
raw_results = cls._run(
*input_list,
*param_list,
*in_output_list,
*args,
return_raw=True,
run_unique=False,
**kwargs,
)
kwargs["use_raw"] = raw_results # use them next time
# Generate indicator instances
instances = []
if _comb_func == itertools.product:
param_lists = zip(*_comb_func(zip(*param_list), repeat=_r))
else:
param_lists = zip(*_comb_func(zip(*param_list), _r))
for i, param_list in enumerate(param_lists):
instances.append(
cls._run(
*input_list,
*zip(*param_list),
*in_output_list,
*args,
short_name=_short_names[i],
hide_params=_hide_params,
hide_default=_hide_default,
run_unique=False,
**kwargs,
)
)
return tuple(instances)
setattr(Indicator, "_run_combs", classmethod(_run_combs))
# Add public run_combs method
_0 = self.class_name
_1 = ""
if len(self.input_names) > 0:
_1 += "\n* Inputs: " + ", ".join(map(lambda x: f"`{x}`", self.input_names))
if len(self.in_output_names) > 0:
_1 += "\n* In-place outputs: " + ", ".join(map(lambda x: f"`{x}`", self.in_output_names))
if len(self.param_names) > 0:
_1 += "\n* Parameters: " + ", ".join(map(lambda x: f"`{x}`", self.param_names))
if len(self.output_names) > 0:
_1 += "\n* Outputs: " + ", ".join(map(lambda x: f"`{x}`", self.output_names))
if len(self.lazy_outputs) > 0:
_1 += "\n* Lazy outputs: " + ", ".join(map(lambda x: f"`{x}`", list(self.lazy_outputs.keys())))
run_combs_docstring = """Create a combination of multiple `{0}` indicators using function `comb_func`.
{1}
`comb_func` must accept an iterable of parameter tuples and `r`.
Also accepts all combinatoric iterators from itertools such as `itertools.combinations`.
Pass `r` to specify how many indicators to run.
Pass `short_names` to specify the short name for each indicator.
Set `run_unique` to True to first compute raw outputs for all parameters,
and then use them to build each indicator (faster).
Other keyword arguments are passed to `{0}.run`.
!!! note
This method should only be used when multiple indicators are needed.
To test multiple parameters, pass them as lists to `{0}.run`.
""".format(
_0,
_1,
)
run_combs = compile_run_function("run_combs", run_combs_docstring, def_run_combs_kwargs)
run_combs.__name__ = "run_combs"
run_combs.__module__ = Indicator.__module__
run_combs.__qualname__ = f"{Indicator.__name__}.{run_combs.__name__}"
setattr(Indicator, "run_combs", run_combs)
return Indicator
def with_apply_func(
self,
apply_func: tp.Callable,
cache_func: tp.Optional[tp.Callable] = None,
takes_1d: bool = False,
select_params: bool = True,
pass_packed: bool = False,
cache_pass_packed: tp.Optional[bool] = None,
pass_per_column: bool = False,
cache_pass_per_column: tp.Optional[bool] = None,
forward_skipna: bool = False,
kwargs_as_args: tp.Optional[tp.Iterable[str]] = None,
jit_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Union[str, tp.Type[IndicatorBase]]:
"""Build indicator class around a custom apply function.
In contrast to `IndicatorFactory.with_custom_func`, this method handles a lot of things for you,
such as caching, parameter selection, and concatenation. Your part is writing a function `apply_func`
that accepts a selection of parameters (single values as opposed to multiple values in
`IndicatorFactory.with_custom_func`) and does the calculation. It then automatically concatenates
the resulting arrays into a single array per output.
While this approach is simpler, it's also less flexible, since we can only work with
one parameter selection at a time and can't view all parameters.
The execution and concatenation is performed using `vectorbtpro.base.combining.apply_and_concat_each`.
!!! note
If `apply_func` is a Numba-compiled function:
* All inputs are automatically converted to NumPy arrays
* Each argument in `*args` must be of a Numba-compatible type
* You cannot pass keyword arguments
* Your outputs must be arrays of the same shape, data type and data order
!!! note
Reserved arguments such as `per_column` (in this order) get passed as positional
arguments if `jitted_loop` is True, otherwise as keyword arguments.
Args:
apply_func (callable): A function that takes inputs, selection of parameters, and
other arguments, and does calculations to produce outputs.
Arguments are passed to `apply_func` in the following order:
* `i` (index of the parameter combination) if `select_params` is set to False
* `input_shape` if `pass_input_shape` is set to True and `input_shape` not in `kwargs_as_args`
* Input arrays corresponding to `input_names`. Passed as a tuple if `pass_packed`, otherwise unpacked.
If `select_params` is True, each argument is a list composed of multiple arrays -
one per parameter combination. When `per_column` is True, each of those arrays
corresponds to a column. Otherwise, they all refer to the same array. If `takes_1d`,
each array gets additionally split into multiple column arrays. Still passed
as a single array to the caching function.
* In-output arrays corresponding to `in_output_names`. Passed as a tuple if `pass_packed`, otherwise unpacked.
If `select_params` is True, each argument is a list composed of multiple arrays -
one per parameter combination. When `per_column` is True, each of those arrays
corresponds to a column. If `takes_1d`, each array gets additionally split into
multiple column arrays. Still passed as a single array to the caching function.
* Parameters corresponding to `param_names`. Passed as a tuple if `pass_packed`, otherwise unpacked.
If `select_params` is True, each argument is a list composed of multiple values -
one per parameter combination. When `per_column` is True, each of those values
corresponds to a column. If `takes_1d`, each value gets additionally repeated by
the number of columns in the input arrays.
* Variable arguments if `var_args` is set to True
* `per_column` if `pass_per_column` is set to True and `per_column` not in
`kwargs_as_args` and `jitted_loop` is set to True
* Arguments listed in `kwargs_as_args` passed as positional. Can include `takes_1d` and `per_column`.
* Other keyword arguments if `jitted_loop` is False. Also includes `takes_1d` and `per_column`
if they must be passed and not in `kwargs_as_args`.
Can be Numba-compiled (but doesn't have to).
!!! note
Shape of each output must be the same and match the shape of each input.
cache_func (callable): A caching function to preprocess data beforehand.
Takes the same arguments as `apply_func`. Must return a single object or a tuple of objects.
All returned objects will be passed unpacked as last arguments to `apply_func`.
Can be Numba-compiled (but doesn't have to).
takes_1d (bool): Whether to split 2-dim arrays into multiple 1-dim arrays along the column axis.
Gets applied on inputs and in-outputs, while parameters get repeated by the number of columns.
select_params (bool): Whether to automatically select in-outputs and parameters.
If False, prepends the current iteration index to the arguments.
pass_packed (bool): Whether to pass packed tuples for inputs, in-place outputs, and parameters.
cache_pass_packed (bool): Overrides `pass_packed` for the caching function.
pass_per_column (bool): Whether to pass `per_column`.
cache_pass_per_column (bool): Overrides `pass_per_column` for the caching function.
forward_skipna (bool): Whether to forward `skipna` to the apply function.
kwargs_as_args (iterable of str): Keyword arguments from `kwargs` dict to pass as
positional arguments to the apply function.
Should be used together with `jitted_loop` set to True since Numba doesn't support
variable keyword arguments.
Defaults to []. Order matters.
jit_kwargs (dict): Keyword arguments passed to `@njit` decorator of the parameter selection function.
By default, has `nogil` set to True.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_custom_func`, all the way down
to `vectorbtpro.base.combining.apply_and_concat_each`.
Returns:
Indicator
Usage:
* The following example produces the same indicator as the `IndicatorFactory.with_custom_func` example.
```pycon
>>> @njit
... def apply_func_nb(ts1, ts2, p1, p2, arg1, arg2):
... return ts1 * p1 + arg1, ts2 * p2 + arg2
>>> MyInd = vbt.IF(
... input_names=['ts1', 'ts2'],
... param_names=['p1', 'p2'],
... output_names=['out1', 'out2']
... ).with_apply_func(
... apply_func_nb, var_args=True,
... kwargs_as_args=['arg2'], arg2=200)
>>> myInd = MyInd.run(price, price * 2, [1, 2], [3, 4], 100)
>>> myInd.out1
custom_p1 1 2
custom_p2 3 4
a b a b
2020-01-01 101.0 105.0 102.0 110.0
2020-01-02 102.0 104.0 104.0 108.0
2020-01-03 103.0 103.0 106.0 106.0
2020-01-04 104.0 102.0 108.0 104.0
2020-01-05 105.0 101.0 110.0 102.0
>>> myInd.out2
custom_p1 1 2
custom_p2 3 4
a b a b
2020-01-01 206.0 230.0 208.0 240.0
2020-01-02 212.0 224.0 216.0 232.0
2020-01-03 218.0 218.0 224.0 224.0
2020-01-04 224.0 212.0 232.0 216.0
2020-01-05 230.0 206.0 240.0 208.0
```
* To change the execution engine or specify other engine-related arguments, use `execute_kwargs`:
```pycon
>>> import time
>>> def apply_func(ts, p):
... time.sleep(1)
... return ts * p
>>> MyInd = vbt.IF(
... input_names=['ts'],
... param_names=['p'],
... output_names=['out']
... ).with_apply_func(apply_func)
>>> %timeit MyInd.run(price, [1, 2, 3])
3.02 s ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit MyInd.run(price, [1, 2, 3], execute_kwargs=dict(engine='dask'))
1.02 s ± 2.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
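* A minimal sketch (not taken from the library docs) of how `takes_1d` feeds one column at a time
to the apply function, assuming the same `price` DataFrame as above:
```pycon
>>> def apply_func(ts, p):
...     # ts is a 1-dim array here, one column per call
...     return ts * p
>>> MyInd = vbt.IF(
...     input_names=['ts'],
...     param_names=['p'],
...     output_names=['out']
... ).with_apply_func(apply_func, takes_1d=True)
>>> myind = MyInd.run(price, [1, 2])
```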
"""
Indicator = self.Indicator
setattr(Indicator, "apply_func", apply_func)
setattr(Indicator, "cache_func", cache_func)
module_name = self.module_name
input_names = self.input_names
output_names = self.output_names
in_output_names = self.in_output_names
param_names = self.param_names
num_ret_outputs = len(output_names)
if kwargs_as_args is None:
kwargs_as_args = []
if checks.is_numba_func(apply_func):
# Build a function that selects a parameter tuple
# Do it here to avoid compilation with Numba every time custom_func is run
_0 = "i"
_0 += ", args_before"
if len(input_names) > 0:
_0 += ", " + ", ".join(input_names)
if len(in_output_names) > 0:
_0 += ", " + ", ".join(in_output_names)
if len(param_names) > 0:
_0 += ", " + ", ".join(param_names)
_0 += ", *args"
if select_params:
_1 = "*args_before"
else:
_1 = "i, *args_before"
if pass_packed:
if len(input_names) > 0:
_1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), input_names)) + ",)"
else:
_1 += ", ()"
if len(in_output_names) > 0:
_1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), in_output_names)) + ",)"
else:
_1 += ", ()"
if len(param_names) > 0:
_1 += ", (" + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), param_names)) + ",)"
else:
_1 += ", ()"
else:
if len(input_names) > 0:
_1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), input_names))
if len(in_output_names) > 0:
_1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), in_output_names))
if len(param_names) > 0:
_1 += ", " + ", ".join(map(lambda x: x + ("[i]" if select_params else ""), param_names))
_1 += ", *args"
func_str = "def param_select_func_nb({0}):\n return apply_func({1})".format(_0, _1)
scope = {"apply_func": apply_func}
filename = inspect.getfile(lambda: None)
code = compile(func_str, filename, "single")
exec(code, scope)
param_select_func_nb = scope["param_select_func_nb"]
param_select_func_nb.__doc__ = "Parameter selection function."
if module_name is not None:
param_select_func_nb.__module__ = module_name
jit_kwargs = merge_dicts(dict(nogil=True), jit_kwargs)
param_select_func_nb = njit(param_select_func_nb, **jit_kwargs)
setattr(Indicator, "param_select_func_nb", param_select_func_nb)
def custom_func(
input_tuple: tp.Tuple[tp.AnyArray, ...],
in_output_tuple: tp.Tuple[tp.List[tp.AnyArray], ...],
param_tuple: tp.Tuple[tp.List[tp.ParamValue], ...],
*_args,
input_shape: tp.Optional[tp.Shape] = None,
per_column: bool = False,
split_columns: bool = False,
skipna: bool = False,
return_cache: bool = False,
use_cache: tp.Union[bool, CacheOutputT] = True,
jitted_loop: bool = False,
jitted_warmup: bool = False,
param_index: tp.Optional[tp.Index] = None,
final_index: tp.Optional[tp.Index] = None,
single_comb: bool = False,
execute_kwargs: tp.KwargsLike = None,
**_kwargs,
) -> tp.Union[None, CacheOutputT, tp.Array2d, tp.List[tp.Array2d]]:
"""Custom function that forwards inputs and parameters to `apply_func`."""
if jitted_loop and not checks.is_numba_func(apply_func):
raise ValueError("Apply function must be Numba-compiled for jitted_loop=True")
if skipna and len(in_output_tuple) > 1:
raise ValueError("NaNs cannot be skipped for in-outputs")
if skipna and jitted_loop:
raise ValueError("NaNs cannot be skipped when jitted_loop=True")
if forward_skipna:
_kwargs["skipna"] = skipna
skipna = False
if execute_kwargs is None:
execute_kwargs = {}
else:
execute_kwargs = dict(execute_kwargs)
_cache_pass_packed = cache_pass_packed
_cache_pass_per_column = cache_pass_per_column
# Prepend positional arguments
args_before = ()
if input_shape is not None and "input_shape" not in kwargs_as_args:
if per_column or takes_1d:
args_before += ((input_shape[0],),)
elif split_columns and len(input_shape) == 2:
args_before += ((input_shape[0], 1),)
else:
args_before += (input_shape,)
# Append positional arguments
more_args = ()
for k in kwargs_as_args:
if k == "per_column":
value = per_column
elif k == "takes_1d":
value = takes_1d
elif k == "split_columns":
value = split_columns
else:
value = _kwargs.pop(k) # important: remove from kwargs
more_args += (value,)
# Resolve the number of parameters
if len(input_tuple) > 0:
if input_tuple[0].ndim == 1:
n_cols = 1
else:
n_cols = input_tuple[0].shape[1]
elif input_shape is not None:
if len(input_shape) == 1:
n_cols = 1
else:
n_cols = input_shape[1]
else:
n_cols = None
if per_column:
n_params = n_cols
else:
n_params = len(param_tuple[0]) if len(param_tuple) > 0 else 1
# Caching
cache = use_cache
if isinstance(cache, bool):
if cache and cache_func is not None:
_input_tuple = input_tuple
_in_output_tuple = ()
for in_outputs in in_output_tuple:
if checks.is_numba_func(cache_func):
_in_outputs = to_typed_list(in_outputs)
else:
_in_outputs = in_outputs
_in_output_tuple += (_in_outputs,)
_param_tuple = ()
for params in param_tuple:
if checks.is_numba_func(cache_func):
_params = to_typed_list(params)
else:
_params = params
_param_tuple += (_params,)
if _cache_pass_packed is None:
_cache_pass_packed = pass_packed
if _cache_pass_per_column is None and per_column:
_cache_pass_per_column = True
if _cache_pass_per_column is None:
_cache_pass_per_column = pass_per_column
cache_more_args = tuple(more_args)
cache_kwargs = dict(_kwargs)
if _cache_pass_per_column:
if "per_column" not in kwargs_as_args:
if jitted_loop:
cache_more_args += (per_column,)
else:
cache_kwargs["per_column"] = per_column
if _cache_pass_packed:
cache = cache_func(
*args_before,
_input_tuple,
_in_output_tuple,
_param_tuple,
*_args,
*cache_more_args,
**cache_kwargs,
)
else:
cache = cache_func(
*args_before,
*_input_tuple,
*_in_output_tuple,
*_param_tuple,
*_args,
*cache_more_args,
**cache_kwargs,
)
else:
cache = None
if return_cache:
return cache
if cache is None:
cache = ()
if not isinstance(cache, tuple):
cache = (cache,)
# Prepare inputs
def _expand_input(input: tp.MaybeList[tp.AnyArray], multiple: bool = False) -> tp.List[tp.AnyArray]:
if jitted_loop:
_inputs = List()
else:
_inputs = []
if per_column:
if multiple:
_input = input[0]
else:
_input = input
if _input.ndim == 2:
for i in range(_input.shape[1]):
if takes_1d:
if isinstance(_input, pd.DataFrame):
_inputs.append(_input.iloc[:, i])
else:
_inputs.append(_input[:, i])
else:
if isinstance(_input, pd.DataFrame):
_inputs.append(_input.iloc[:, i : i + 1])
else:
_inputs.append(_input[:, i : i + 1])
else:
_inputs.append(_input)
else:
for p in range(n_params):
if multiple:
_input = input[p]
else:
_input = input
if takes_1d or split_columns:
if isinstance(_input, pd.DataFrame):
for i in range(_input.shape[1]):
if takes_1d:
_inputs.append(_input.iloc[:, i])
else:
_inputs.append(_input.iloc[:, i : i + 1])
elif _input.ndim == 2:
for i in range(_input.shape[1]):
if takes_1d:
_inputs.append(_input[:, i])
else:
_inputs.append(_input[:, i : i + 1])
else:
_inputs.append(_input)
else:
_inputs.append(_input)
return _inputs
_input_tuple = ()
for input in input_tuple:
_inputs = _expand_input(input)
_input_tuple += (_inputs,)
if skipna and len(_input_tuple) > 0:
new_input_tuple = tuple([[] for _ in range(len(_input_tuple))])
nan_masks = []
any_nan = False
for i in range(len(_input_tuple[0])):
inputs = []
for k in range(len(_input_tuple)):
input = _input_tuple[k][i]
if input.ndim == 2 and input.shape[1] > 1:
raise ValueError(
"NaNs cannot be skipped for multi-columnar inputs. Use split_columns=True."
)
inputs.append(input)
nan_mask = build_nan_mask(*inputs)
nan_masks.append(nan_mask)
if not any_nan and nan_mask is not None and np.any(nan_mask):
any_nan = True
if any_nan:
inputs = squeeze_nan(*inputs, nan_mask=nan_mask)
for k in range(len(_input_tuple)):
new_input_tuple[k].append(inputs[k])
_input_tuple = new_input_tuple
if any_nan:
def post_execute_func(outputs, nan_masks):
new_outputs = []
for i, output in enumerate(outputs):
if isinstance(output, (tuple, list, List)):
output = unsqueeze_nan(*output, nan_mask=nan_masks[i])
else:
output = unsqueeze_nan(output, nan_mask=nan_masks[i])
new_outputs.append(output)
return new_outputs
if "post_execute_func" in execute_kwargs:
raise ValueError("Cannot use custom post_execute_func when skipna=True")
execute_kwargs["post_execute_func"] = post_execute_func
if "post_execute_kwargs" in execute_kwargs:
raise ValueError("Cannot use custom post_execute_kwargs when skipna=True")
execute_kwargs["post_execute_kwargs"] = dict(outputs=Rep("results"), nan_masks=nan_masks)
if "post_execute_on_sorted" in execute_kwargs:
raise ValueError("Cannot use custom post_execute_on_sorted when skipna=True")
execute_kwargs["post_execute_on_sorted"] = True
_in_output_tuple = ()
for in_outputs in in_output_tuple:
_in_outputs = _expand_input(in_outputs, multiple=True)
_in_output_tuple += (_in_outputs,)
_param_tuple = ()
for params in param_tuple:
if not per_column and (takes_1d or split_columns):
_params = [params[p] for p in range(len(params)) for _ in range(n_cols)]
else:
_params = params
if jitted_loop:
if len(_params) > 0 and np.isscalar(_params[0]):
_params = np.asarray(_params)
else:
_params = to_typed_list(_params)
_param_tuple += (_params,)
if not per_column and (takes_1d or split_columns):
_n_params = n_params * n_cols
keys = final_index
else:
_n_params = n_params
keys = param_index
execute_kwargs = merge_dicts(dict(show_progress=False if single_comb else None), execute_kwargs)
execute_kwargs["keys"] = keys
if pass_per_column:
if "per_column" not in kwargs_as_args:
if jitted_loop:
more_args += (per_column,)
else:
_kwargs["per_column"] = per_column
# Apply function and concatenate outputs
if jitted_loop:
return combining.apply_and_concat(
_n_params,
param_select_func_nb,
args_before,
*_input_tuple,
*_in_output_tuple,
*_param_tuple,
*_args,
*more_args,
*cache,
**_kwargs,
n_outputs=num_ret_outputs,
jitted_loop=True,
jitted_warmup=jitted_warmup,
execute_kwargs=execute_kwargs,
)
tasks = []
for i in range(_n_params):
if select_params:
_inputs = tuple(_inputs[i] for _inputs in _input_tuple)
_in_outputs = tuple(_in_outputs[i] for _in_outputs in _in_output_tuple)
_params = tuple(_params[i] for _params in _param_tuple)
else:
_inputs = _input_tuple
_in_outputs = _in_output_tuple
_params = _param_tuple
tasks.append(
Task(
apply_func,
*((i,) if not select_params else ()),
*args_before,
*((_inputs,) if pass_packed else _inputs),
*((_in_outputs,) if pass_packed else _in_outputs),
*((_params,) if pass_packed else _params),
*_args,
*more_args,
*cache,
**_kwargs,
)
)
return combining.apply_and_concat_each(
tasks,
n_outputs=num_ret_outputs,
execute_kwargs=execute_kwargs,
)
return self.with_custom_func(
custom_func,
pass_packed=True,
pass_param_index=True,
pass_final_index=True,
pass_single_comb=True,
**kwargs,
)
# ############# Exploration ############# #
_custom_indicators: tp.ClassVar[Config] = HybridConfig()
@class_property
def custom_indicators(cls) -> Config:
"""Custom indicators keyed by custom locations."""
return cls._custom_indicators
@classmethod
def list_custom_locations(cls) -> tp.List[str]:
"""List custom locations.
Appear in the order they were registered."""
return list(cls.custom_indicators.keys())
@classmethod
def list_builtin_locations(cls) -> tp.List[str]:
"""List built-in locations.
Appear in the order defined by the author."""
return [
"vbt",
"talib_func",
"talib",
"pandas_ta",
"ta",
"technical",
"techcon",
"smc",
"wqa101",
]
@classmethod
def list_locations(cls) -> tp.List[str]:
"""List all supported locations.
First come custom locations, then built-in locations."""
return [*cls.list_custom_locations(), *cls.list_builtin_locations()]
@classmethod
def match_location(cls, location: str) -> tp.Optional[str]:
"""Match location."""
for k in cls.list_locations():
if k.lower() == location.lower():
return k
return None
@classmethod
def split_indicator_name(cls, name: str) -> tp.Tuple[tp.Optional[str], tp.Optional[str]]:
"""Split an indicator name into location and actual name."""
locations = cls.list_locations()
matched_location = cls.match_location(name)
if matched_location is not None:
return matched_location, None
if ":" in name:
location = name.split(":")[0].strip()
name = name.split(":")[1].strip()
else:
location = None
found_location = False
if "_" in name:
for location in locations:
if name.lower().startswith(location.lower() + "_"):
found_location = True
break
if found_location:
name = name[len(location) + 1 :]
else:
location = None
return location, name
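# Illustrative examples (not exhaustive) of how names are split:
#     "talib:sma"  -> ("talib", "sma")
#     "talib_sma"  -> ("talib", "sma")
#     "sma"        -> (None, "sma")
#     "talib"      -> ("talib", None)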
@classmethod
def register_custom_indicator(
cls,
indicator: tp.Union[str, tp.Type[IndicatorBase]],
name: tp.Optional[str] = None,
location: tp.Optional[str] = None,
if_exists: str = "raise",
) -> None:
"""Register a custom indicator under a custom location.
Argument `if_exists` can be "raise", "skip", or "override"."""
if isinstance(indicator, str):
indicator = cls.get_indicator(indicator)
if name is None:
name = indicator.__name__
elif location is None:
location, name = cls.split_indicator_name(name)
if location is None:
location = "custom"
else:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
if not name.isidentifier():
raise ValueError(f"Custom name '{name}' must be a valid variable name")
if not location.isidentifier():
raise ValueError(f"Custom location '{location}' must be a valid variable name")
if location in cls.list_builtin_locations():
raise ValueError(f"Custom location '{location}' shadows a built-in location with the same name")
if location not in cls.custom_indicators:
cls.custom_indicators[location] = dict()
for k in cls.custom_indicators[location]:
if name.upper() == k.upper():
if if_exists.lower() == "raise":
raise ValueError(f"Indicator with name '{name}' already exists under location '{location}'")
if if_exists.lower() == "skip":
return None
if if_exists.lower() == "override":
break
raise ValueError(f"Invalid if_exists: '{if_exists}'")
cls.custom_indicators[location][name] = indicator
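# Illustrative sketch (MyInd is a hypothetical indicator class): register it under a
# custom location and list it afterwards:
#     >>> vbt.IF.register_custom_indicator(MyInd, name="rolling:MyInd")
#     >>> vbt.IF.list_custom_indicators(location="rolling")
#     ['MyInd']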
@classmethod
def deregister_custom_indicator(
cls,
name: tp.Optional[str] = None,
location: tp.Optional[str] = None,
remove_location: bool = True,
) -> None:
"""Deregister a custom indicator by its name and location.
If `location` is None, deregisters all indicators with the same name across all custom locations."""
if location is not None:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
if name is None:
if location is None:
for k in list(cls.custom_indicators.keys()):
del cls.custom_indicators[k]
else:
del cls.custom_indicators[location]
else:
if location is None:
location, name = cls.split_indicator_name(name)
if location is None:
for k, v in list(cls.custom_indicators.items()):
for k2 in list(cls.custom_indicators[k].keys()):
if name.upper() == k2.upper():
del cls.custom_indicators[k][k2]
if remove_location and len(cls.custom_indicators[k]) == 0:
del cls.custom_indicators[k]
else:
for k in list(cls.custom_indicators[location].keys()):
if name.upper() == k.upper():
del cls.custom_indicators[location][k]
if remove_location and len(cls.custom_indicators[location]) == 0:
del cls.custom_indicators[location]
@classmethod
def get_custom_indicator(
cls,
name: str,
location: tp.Optional[str] = None,
return_first: bool = False,
) -> tp.Type[IndicatorBase]:
"""Get a custom indicator."""
if location is None:
location, name = cls.split_indicator_name(name)
else:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
name = name.upper()
if location is None:
found_indicators = []
for k, v in cls.custom_indicators.items():
for k2, v2 in v.items():
k2 = k2.upper()
if k2 == name:
found_indicators.append(v2)
if len(found_indicators) == 1:
return found_indicators[0]
if len(found_indicators) > 1:
if return_first:
return found_indicators[0]
raise KeyError(f"Found multiple custom indicators with name '{name}'")
raise KeyError(f"Found no custom indicator with name '{name}'")
else:
for k, v in cls.custom_indicators[location].items():
k = k.upper()
if k == name:
return v
raise KeyError(f"Found no custom indicator with name '{name}' under location '{location}'")
@classmethod
def list_custom_indicators(
cls,
uppercase: bool = False,
location: tp.Optional[str] = None,
prepend_location: tp.Optional[bool] = None,
) -> tp.List[str]:
"""List custom indicators."""
if location is not None:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
locations_names = []
non_custom_location = False
for k, v in cls.custom_indicators.items():
if location is not None:
if k != location:
continue
for k2, v2 in v.items():
if uppercase:
k2 = k2.upper()
if not non_custom_location and k != "custom":
non_custom_location = True
locations_names.append((k, k2))
locations_names = sorted(locations_names, key=lambda x: (x[0].upper(), x[1]))
if prepend_location is None:
prepend_location = location is None and non_custom_location
if prepend_location:
return list(map(lambda x: x[0] + ":" + x[1], locations_names))
return list(map(lambda x: x[1], locations_names))
@classmethod
def list_vbt_indicators(cls) -> tp.List[str]:
"""List all vectorbt indicators."""
import vectorbtpro as vbt
return sorted(
[
attr
for attr in dir(vbt)
if not attr.startswith("_")
and isinstance(getattr(vbt, attr), type)
and getattr(vbt, attr) is not IndicatorBase
and issubclass(getattr(vbt, attr), IndicatorBase)
]
)
@classmethod
def list_indicators(
cls,
pattern: tp.Optional[str] = None,
case_sensitive: bool = False,
use_regex: bool = False,
location: tp.Optional[str] = None,
prepend_location: tp.Optional[bool] = None,
) -> tp.List[str]:
"""List indicators, optionally matching a pattern.
The pattern can also be a location, in which case all indicators from that location will be returned.
For supported locations, see `IndicatorFactory.list_locations`."""
if pattern is not None:
if not case_sensitive:
pattern = pattern.lower()
if location is None and cls.match_location(pattern) is not None:
location = pattern
pattern = None
if prepend_location is None:
if location is not None:
prepend_location = False
else:
prepend_location = True
with WarningsFiltered():
if location is not None:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
if location in cls.list_custom_locations():
all_indicators = cls.list_custom_indicators(location=location, prepend_location=prepend_location)
else:
all_indicators = map(
lambda x: location + ":" + x if prepend_location else x,
getattr(cls, f"list_{location}_indicators")(),
)
else:
from vectorbtpro.utils.module_ import check_installed
all_indicators = [
*cls.list_custom_indicators(prepend_location=prepend_location),
*map(lambda x: "vbt:" + x if prepend_location else x, cls.list_vbt_indicators()),
*map(
lambda x: "talib:" + x if prepend_location else x,
cls.list_talib_indicators() if check_installed("talib") else [],
),
*map(
lambda x: "pandas_ta:" + x if prepend_location else x,
cls.list_pandas_ta_indicators() if check_installed("pandas_ta") else [],
),
*map(
lambda x: "ta:" + x if prepend_location else x,
cls.list_ta_indicators() if check_installed("ta") else [],
),
*map(
lambda x: "technical:" + x if prepend_location else x,
cls.list_technical_indicators() if check_installed("technical") else [],
),
*map(
lambda x: "techcon:" + x if prepend_location else x,
cls.list_techcon_indicators() if check_installed("technical") else [],
),
*map(
lambda x: "smc:" + x if prepend_location else x,
cls.list_smc_indicators() if check_installed("smartmoneyconcepts") else [],
),
*map(lambda x: "wqa101:" + str(x) if prepend_location else str(x), range(1, 102)),
]
found_indicators = []
for indicator in all_indicators:
if prepend_location and location is not None:
indicator = location + ":" + indicator
if case_sensitive:
indicator_name = indicator
else:
indicator_name = indicator.lower()
if pattern is not None:
if use_regex:
if location is not None:
if not re.match(pattern, indicator_name):
continue
else:
if not re.match(pattern, indicator_name.split(":")[1]):
continue
else:
if location is not None:
if not re.match(fnmatch.translate(pattern), indicator_name):
continue
else:
if not re.match(fnmatch.translate(pattern), indicator_name.split(":")[1]):
continue
found_indicators.append(indicator)
return found_indicators
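# Illustrative sketch: vbt.IF.list_indicators("*rsi*") would list indicators whose names
# contain "rsi" across all installed locations (prefixed with their location), while
# vbt.IF.list_indicators(location="talib") would list all parseable TA-Lib indicators.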
@classmethod
def get_indicator(cls, name: str, location: tp.Optional[str] = None) -> tp.Type[IndicatorBase]:
"""Get the indicator class by its name.
The name can contain a location suffix followed by a colon. For example, "talib:sma"
or "talib_sma" will return the TA-Lib's SMA. Without a location, the indicator will be
searched throughout all indicators, including the vectorbt's ones."""
if location is None:
location, name = cls.split_indicator_name(name)
else:
matched_location = cls.match_location(location)
if matched_location is not None:
location = matched_location
if name is not None:
name = name.upper()
if location is not None:
if location in cls.list_custom_locations():
return cls.get_custom_indicator(name, location=location)
if location == "vbt":
import vectorbtpro as vbt
return getattr(vbt, name.upper())
if location == "talib":
return cls.from_talib(name)
if location == "pandas_ta":
return cls.from_pandas_ta(name)
if location == "ta":
return cls.from_ta(name)
if location == "technical":
return cls.from_technical(name)
if location == "techcon":
return cls.from_techcon(name)
if location == "smc":
return cls.from_smc(name)
if location == "wqa101":
return cls.from_wqa101(int(name))
raise ValueError(f"Location '{location}' not found")
else:
import vectorbtpro as vbt
from vectorbtpro.utils.module_ import check_installed
if name in cls.list_custom_indicators(uppercase=True, prepend_location=False):
return cls.get_custom_indicator(name, return_first=True)
if hasattr(vbt, name):
return getattr(vbt, name)
if str(name).isnumeric():
return cls.from_wqa101(int(name))
if check_installed("smc") and name in cls.list_smc_indicators():
return cls.from_smc(name)
if check_installed("technical") and name in cls.list_techcon_indicators():
return cls.from_techcon(name)
if check_installed("talib") and name in cls.list_talib_indicators():
return cls.from_talib(name)
if check_installed("ta") and name in cls.list_ta_indicators(uppercase=True):
return cls.from_ta(name)
if check_installed("pandas_ta") and name in cls.list_pandas_ta_indicators():
return cls.from_pandas_ta(name)
if check_installed("technical") and name in cls.list_technical_indicators():
return cls.from_technical(name)
raise ValueError(f"Indicator '{name}' not found")
# ############# Third party ############# #
@classmethod
def list_talib_indicators(cls) -> tp.List[str]:
"""List all parseable indicators in `talib`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("talib")
import talib
return sorted(talib.get_functions())
@classmethod
def from_talib(cls, func_name: str, factory_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a `talib` function.
Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed.
For input, parameter and output names, see [docs](https://github.com/mrjbq7/ta-lib/blob/master/docs/index.md).
Args:
func_name (str): Function name.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
Returns:
Indicator
Usage:
```pycon
>>> SMA = vbt.IF.from_talib('SMA')
>>> sma = SMA.run(price, timeperiod=[2, 3])
>>> sma.real
sma_timeperiod 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.5 4.5 NaN NaN
2020-01-03 2.5 3.5 2.0 4.0
2020-01-04 3.5 2.5 3.0 3.0
2020-01-05 4.5 1.5 4.0 2.0
```
* To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:
```pycon
>>> vbt.phelp(SMA.run)
SMA.run(
close,
timeperiod=Default(value=30),
timeframe=Default(value=None),
short_name='sma',
hide_params=None,
hide_default=True,
**kwargs
):
Run `SMA` indicator.
* Inputs: `close`
* Parameters: `timeperiod`, `timeframe`
* Outputs: `real`
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `SMA.run_pipeline`.
```
* To plot an indicator:
```pycon
>>> sma.plot(column=(2, 'a')).show()
```
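* The extra `timeframe` parameter (appended by vectorbt to the TA-Lib parameters above) can be
passed like any other parameter (a sketch, assuming a pandas-compatible frequency string):
```pycon
>>> sma = SMA.run(price, timeperiod=2, timeframe='2D')
```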
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("talib")
import talib
from talib import abstract
from vectorbtpro.indicators.talib_ import talib_func, talib_plot_func
func_name = func_name.upper()
info = abstract.Function(func_name).info
input_names = []
for in_names in info["input_names"].values():
if isinstance(in_names, (list, tuple)):
input_names.extend(list(in_names))
else:
input_names.append(in_names)
class_name = info["name"]
class_docstring = "{}, {}".format(info["display_name"], info["group"])
param_names = list(info["parameters"].keys()) + ["timeframe"]
output_names = info["output_names"]
output_flags = info["output_flags"]
_talib_func = talib_func(func_name)
_talib_plot_func = talib_plot_func(func_name)
def apply_func(
input_tuple: tp.Tuple[tp.Array2d, ...],
in_output_tuple: tp.Tuple[tp.Array2d, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
timeframe: tp.Optional[tp.FrequencyLike] = None,
**_kwargs,
) -> tp.MaybeTuple[tp.Array2d]:
if len(param_tuple) == len(param_names):
if timeframe is not None:
raise ValueError("Time frame is set both as a parameter and as a keyword argument")
timeframe = param_tuple[-1]
param_tuple = param_tuple[:-1]
elif len(param_tuple) > len(param_names):
raise ValueError("Provided more parameters than registered")
return _talib_func(
*input_tuple,
*param_tuple,
timeframe=timeframe,
**_kwargs,
)
apply_func.__doc__ = f"""Apply function based on `vbt.talib_func("{func_name}")`."""
kwargs = merge_dicts({k: Default(v) for k, v in info["parameters"].items()}, dict(timeframe=None), kwargs)
Indicator = cls(
**merge_dicts(
dict(
class_name=class_name,
class_docstring=class_docstring,
module_name=__name__ + ".talib",
input_names=input_names,
param_names=param_names,
output_names=output_names,
output_flags=output_flags,
),
factory_kwargs,
)
).with_apply_func(
apply_func,
pass_packed=True,
pass_wrapper=True,
forward_skipna=True,
**kwargs,
)
def plot(
self,
column: tp.Optional[tp.Label] = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
self_col = self.select_col(column=column)
return _talib_plot_func(
*[getattr(self_col, output_name) for output_name in output_names],
add_shape_kwargs=add_shape_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**kwargs,
)
plot.__doc__ = f"""Plot function based on `vbt.talib_plot_func("{func_name}")`."""
setattr(Indicator, "plot", plot)
return Indicator
@classmethod
def parse_pandas_ta_config(
cls,
func: tp.Callable,
test_input_names: tp.Optional[tp.Sequence[str]] = None,
test_index_len: int = 100,
silence_warnings: bool = False,
**kwargs,
) -> tp.Kwargs:
"""Parse the config of a `pandas_ta` indicator."""
if test_input_names is None:
test_input_names = {"open_", "open", "high", "low", "close", "adj_close", "volume", "dividends", "split"}
input_names = []
param_names = []
output_names = []
defaults = {}
# Parse the function signature of the indicator to get input names
sig = inspect.signature(func)
for k, v in sig.parameters.items():
if v.kind not in (v.VAR_POSITIONAL, v.VAR_KEYWORD):
if v.annotation != inspect.Parameter.empty and v.annotation == pd.Series:
input_names.append(k)
elif k in test_input_names:
input_names.append(k)
elif v.default == inspect.Parameter.empty:
# Any positional argument is considered input
input_names.append(k)
else:
param_names.append(k)
defaults[k] = v.default
# To get output names, we need to run the indicator
test_df = pd.DataFrame(
{c: np.random.uniform(1, 10, size=(test_index_len,)) for c in input_names},
index=pd.date_range("2020", periods=test_index_len),
)
new_args = merge_dicts({c: test_df[c] for c in input_names}, kwargs)
result = suppress_stdout(func)(**new_args)
# Concatenate Series/DataFrames if the result is a tuple
if isinstance(result, tuple):
results = []
for i, r in enumerate(result):
if len(r.index) != len(test_df.index):
if not silence_warnings:
warn(f"Couldn't parse the output at index {i}: mismatching index")
else:
results.append(r)
if len(results) > 1:
result = pd.concat(results, axis=1)
elif len(results) == 1:
result = results[0]
else:
raise ValueError("Couldn't parse the output")
# Test if the produced array has the same index length
if len(result.index) != len(test_df.index):
raise ValueError("Couldn't parse the output: mismatching index")
# Standardize output names: drop purely numeric parts, replace hyphens with underscores, and bring to lower case
output_cols = result.columns.tolist() if isinstance(result, pd.DataFrame) else [result.name]
new_output_cols = []
for i in range(len(output_cols)):
name_parts = []
for name_part in output_cols[i].split("_"):
try:
float(name_part)
continue
except ValueError:
name_parts.append(name_part.replace("-", "_").lower())
output_col = "_".join(name_parts)
new_output_cols.append(output_col)
# Add numbers to duplicates
for k, v in Counter(new_output_cols).items():
if v == 1:
output_names.append(k)
else:
for i in range(v):
output_names.append(k + str(i))
return dict(
class_name=func.__name__.upper(),
class_docstring=func.__doc__,
input_names=input_names,
param_names=param_names,
output_names=output_names,
defaults=defaults,
)
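# Illustrative sketch: for pandas_ta.sma, the parsed config would roughly be
#     dict(class_name='SMA', input_names=['close'], param_names=['length', 'talib', 'offset'],
#          output_names=['sma'], defaults={'length': None, 'talib': None, 'offset': None})
# which matches the `from_pandas_ta` usage example below.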
@classmethod
def list_pandas_ta_indicators(cls, silence_warnings: bool = True, **kwargs) -> tp.List[str]:
"""List all parseable indicators in `pandas_ta`.
!!! note
Returns only the indicators that have been successfully parsed."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pandas_ta")
import pandas_ta
indicators = set()
for func_name in [_k for k, v in pandas_ta.Category.items() for _k in v]:
try:
cls.parse_pandas_ta_config(getattr(pandas_ta, func_name), silence_warnings=silence_warnings, **kwargs)
indicators.add(func_name.upper())
except Exception as e:
if not silence_warnings:
warn(f"Function {func_name}: " + str(e))
return sorted(indicators)
@classmethod
def from_pandas_ta(
cls,
func_name: str,
parse_kwargs: tp.KwargsLike = None,
factory_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a `pandas_ta` function.
Requires [pandas-ta](https://github.com/twopirllc/pandas-ta) installed.
Args:
func_name (str): Function name.
parse_kwargs (dict): Keyword arguments passed to `IndicatorFactory.parse_pandas_ta_config`.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
Returns:
Indicator
Usage:
```pycon
>>> SMA = vbt.IF.from_pandas_ta('SMA')
>>> sma = SMA.run(price, length=[2, 3])
>>> sma.sma
sma_length 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.5 4.5 NaN NaN
2020-01-03 2.5 3.5 2.0 4.0
2020-01-04 3.5 2.5 3.0 3.0
2020-01-05 4.5 1.5 4.0 2.0
```
* To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:
```pycon
>>> vbt.phelp(SMA.run)
SMA.run(
close,
length=Default(value=None),
talib=Default(value=None),
offset=Default(value=None),
short_name='sma',
hide_params=None,
hide_default=True,
**kwargs
):
Run `SMA` indicator.
* Inputs: `close`
* Parameters: `length`, `talib`, `offset`
* Outputs: `sma`
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `SMA.run_pipeline`.
```
* To get the indicator docstring, use the `help` command or print the `__doc__` attribute:
```pycon
>>> print(SMA.__doc__)
Simple Moving Average (SMA)
The Simple Moving Average is the classic moving average that is the equally
weighted average over n periods.
Sources:
https://www.tradingtechnologies.com/help/x-study/technical-indicator-definitions/simple-moving-average-sma/
Calculation:
Default Inputs:
length=10
SMA = SUM(close, length) / length
Args:
close (pd.Series): Series of 'close's
length (int): It's period. Default: 10
offset (int): How many periods to offset the result. Default: 0
Kwargs:
adjust (bool): Default: True
presma (bool, optional): If True, uses SMA for initial value.
fillna (value, optional): pd.DataFrame.fillna(value)
fill_method (value, optional): Type of fill method
Returns:
pd.Series: New feature generated.
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pandas_ta")
import pandas_ta
func_name = func_name.lower()
func = getattr(pandas_ta, func_name)
if parse_kwargs is None:
parse_kwargs = {}
config = cls.parse_pandas_ta_config(func, **parse_kwargs)
def apply_func(
input_tuple: tp.Tuple[tp.AnyArray, ...],
in_output_tuple: tp.Tuple[tp.SeriesFrame, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
**_kwargs,
) -> tp.MaybeTuple[tp.Array2d]:
is_series = isinstance(input_tuple[0], pd.Series)
n_input_cols = 1 if is_series else len(input_tuple[0].columns)
outputs = []
for col in range(n_input_cols):
output = suppress_stdout(func)(
**{
name: input_tuple[i] if is_series else input_tuple[i].iloc[:, col]
for i, name in enumerate(config["input_names"])
},
**{name: param_tuple[i] for i, name in enumerate(config["param_names"])},
**_kwargs,
)
if isinstance(output, tuple):
_outputs = []
for o in output:
if len(input_tuple[0].index) == len(o.index):
_outputs.append(o)
if len(_outputs) > 1:
output = pd.concat(_outputs, axis=1)
elif len(_outputs) == 1:
output = _outputs[0]
else:
raise ValueError("No valid outputs were returned")
if isinstance(output, pd.DataFrame):
output = tuple([output.iloc[:, i] for i in range(len(output.columns))])
outputs.append(output)
if isinstance(outputs[0], tuple): # multiple outputs
outputs = list(zip(*outputs))
return tuple(map(column_stack_arrays, outputs))
return column_stack_arrays(outputs)
kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs)
Indicator = cls(
**merge_dicts(dict(module_name=__name__ + ".pandas_ta"), config, factory_kwargs),
).with_apply_func(apply_func, pass_packed=True, keep_pd=True, to_2d=False, **kwargs)
return Indicator
@classmethod
def list_ta_indicators(cls, uppercase: bool = False) -> tp.List[str]:
"""List all parseable indicators in `ta`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ta")
import ta
ta_module_names = [k for k in dir(ta) if isinstance(getattr(ta, k), ModuleType)]
indicators = set()
for module_name in ta_module_names:
module = getattr(ta, module_name)
for name in dir(module):
obj = getattr(module, name)
if (
isinstance(obj, type)
and obj != ta.utils.IndicatorMixin
and issubclass(obj, ta.utils.IndicatorMixin)
):
if uppercase:
indicators.add(obj.__name__.upper())
else:
indicators.add(obj.__name__)
return sorted(indicators)
@classmethod
def find_ta_indicator(cls, cls_name: str) -> IndicatorMixinT:
"""Get `ta` indicator class by its name."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ta")
import ta
ta_module_names = [k for k in dir(ta) if isinstance(getattr(ta, k), ModuleType)]
for module_name in ta_module_names:
module = getattr(ta, module_name)
for attr in dir(module):
if cls_name.upper() == attr.upper():
return getattr(module, attr)
raise AttributeError(f"Indicator '{cls_name}' not found")
@classmethod
def parse_ta_config(cls, ind_cls: IndicatorMixinT) -> tp.Kwargs:
"""Parse the config of a `ta` indicator."""
input_names = []
param_names = []
defaults = {}
output_names = []
# Parse the __init__ signature of the indicator class to get input names
sig = inspect.signature(ind_cls)
for k, v in sig.parameters.items():
if v.kind not in (v.VAR_POSITIONAL, v.VAR_KEYWORD):
if v.annotation == inspect.Parameter.empty:
raise ValueError(f'Argument "{k}" has no annotation')
if v.annotation == pd.Series:
input_names.append(k)
else:
param_names.append(k)
if v.default != inspect.Parameter.empty:
defaults[k] = v.default
# Get output names by looking into instance methods
for attr in dir(ind_cls):
if not attr.startswith("_"):
if inspect.signature(getattr(ind_cls, attr)).return_annotation == pd.Series:
output_names.append(attr)
elif "Returns:\n pandas.Series" in getattr(ind_cls, attr).__doc__:
output_names.append(attr)
return dict(
class_name=ind_cls.__name__,
class_docstring=ind_cls.__doc__,
input_names=input_names,
param_names=param_names,
output_names=output_names,
defaults=defaults,
)
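# Illustrative sketch: for ta.trend.SMAIndicator, the parsed config would roughly be
#     dict(class_name='SMAIndicator', input_names=['close'], param_names=['window', 'fillna'],
#          output_names=['sma_indicator'], defaults={'fillna': False})
# which matches the `from_ta` usage example below.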
@classmethod
def from_ta(cls, cls_name: str, factory_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a `ta` class.
Requires [ta](https://github.com/bukosabino/ta) installed.
Args:
cls_name (str): Class name.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
Returns:
Indicator
Usage:
```pycon
>>> SMAIndicator = vbt.IF.from_ta('SMAIndicator')
>>> sma = SMAIndicator.run(price, window=[2, 3])
>>> sma.sma_indicator
smaindicator_window 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.5 4.5 NaN NaN
2020-01-03 2.5 3.5 2.0 4.0
2020-01-04 3.5 2.5 3.0 3.0
2020-01-05 4.5 1.5 4.0 2.0
```
* To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:
```pycon
>>> vbt.phelp(SMAIndicator.run)
SMAIndicator.run(
close,
window,
fillna=Default(value=False),
short_name='smaindicator',
hide_params=None,
hide_default=True,
**kwargs
):
Run `SMAIndicator` indicator.
* Inputs: `close`
* Parameters: `window`, `fillna`
* Outputs: `sma_indicator`
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `SMAIndicator.run_pipeline`.
```
* To get the indicator docstring, use the `help` command or print the `__doc__` attribute:
```pycon
>>> print(SMAIndicator.__doc__)
SMA - Simple Moving Average
Args:
close(pandas.Series): dataset 'Close' column.
window(int): n period.
fillna(bool): if True, fill nan values.
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("ta")
ind_cls = cls.find_ta_indicator(cls_name)
config = cls.parse_ta_config(ind_cls)
def apply_func(
input_tuple: tp.Tuple[tp.AnyArray, ...],
in_output_tuple: tp.Tuple[tp.SeriesFrame, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
**_kwargs,
) -> tp.MaybeTuple[tp.Array2d]:
is_series = isinstance(input_tuple[0], pd.Series)
n_input_cols = 1 if is_series else len(input_tuple[0].columns)
outputs = []
for col in range(n_input_cols):
ind = ind_cls(
**{
name: input_tuple[i] if is_series else input_tuple[i].iloc[:, col]
for i, name in enumerate(config["input_names"])
},
**{name: param_tuple[i] for i, name in enumerate(config["param_names"])},
**_kwargs,
)
output = []
for output_name in config["output_names"]:
output.append(getattr(ind, output_name)())
if len(output) == 1:
output = output[0]
else:
output = tuple(output)
outputs.append(output)
if isinstance(outputs[0], tuple): # multiple outputs
outputs = list(zip(*outputs))
return tuple(map(column_stack_arrays, outputs))
return column_stack_arrays(outputs)
kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs)
Indicator = cls(**merge_dicts(dict(module_name=__name__ + ".ta"), config, factory_kwargs)).with_apply_func(
apply_func,
pass_packed=True,
keep_pd=True,
to_2d=False,
**kwargs,
)
return Indicator
@classmethod
def parse_technical_config(cls, func: tp.Callable, test_index_len: int = 100) -> tp.Kwargs:
"""Parse the config of a `technical` indicator."""
df = pd.DataFrame(
np.random.randint(1, 10, size=(test_index_len, 5)),
index=pd.date_range("2020", periods=test_index_len),
columns=["open", "high", "low", "close", "volume"],
)
func_arg_names = get_func_arg_names(func)
func_kwargs = get_func_kwargs(func)
args = ()
input_names = []
param_names = []
output_names = []
defaults = {}
for arg_name in func_arg_names:
if arg_name == "field":
continue
if arg_name in ("dataframe", "df", "bars"):
args += (df,)
if "field" in func_kwargs:
input_names.append(func_kwargs["field"])
else:
input_names.extend(["open", "high", "low", "close", "volume"])
elif arg_name in ("series", "sr"):
args += (df["close"],)
input_names.append("close")
elif arg_name in ("open", "high", "low", "close", "volume"):
args += (df["close"],)
input_names.append(arg_name)
else:
if arg_name not in func_kwargs:
args += (5,)
else:
defaults[arg_name] = func_kwargs[arg_name]
param_names.append(arg_name)
if len(input_names) == 0:
raise ValueError("Couldn't parse the output: unknown input arguments")
def _validate_series(sr, name: tp.Optional[str] = None):
if not isinstance(sr, pd.Series):
raise TypeError("Couldn't parse the output: wrong output type")
if len(sr.index) != len(df.index):
raise ValueError("Couldn't parse the output: mismatching index")
if np.issubdtype(sr.dtype, object):
raise ValueError("Couldn't parse the output: wrong output data type")
if name is None and sr.name is None:
raise ValueError("Couldn't parse the output: missing output name")
out = suppress_stdout(func)(*args)
if isinstance(out, list):
out = np.asarray(out)
if isinstance(out, np.ndarray):
out = pd.Series(out)
if isinstance(out, dict):
out = pd.DataFrame(out)
if isinstance(out, tuple):
out = pd.concat(out, axis=1)
if isinstance(out, (pd.Series, pd.DataFrame)):
if isinstance(out, pd.DataFrame):
for c in out.columns:
_validate_series(out[c], name=c)
output_names.append(c)
else:
if out.name is not None:
out_name = out.name
else:
out_name = func.__name__.lower()
_validate_series(out, name=out_name)
output_names.append(out_name)
else:
raise TypeError("Couldn't parse the output: wrong output type")
new_output_names = []
for name in output_names:
name = name.replace(" ", "").lower()
if len(output_names) == 1 and name == "close":
new_output_names.append(func.__name__.lower())
continue
if name in ("open", "high", "low", "close", "volume", "data"):
continue
new_output_names.append(name)
return dict(
class_name=func.__name__.upper(),
class_docstring=func.__doc__,
input_names=input_names,
param_names=param_names,
output_names=new_output_names,
defaults=defaults,
)
@classmethod
def list_technical_indicators(cls, silence_warnings: bool = True, **kwargs) -> tp.List[str]:
"""List all parseable indicators in `technical`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("technical")
import technical
match_func = lambda k, v: isinstance(v, FunctionType)
funcs = search_package(technical, match_func, blacklist=["technical.util"])
indicators = set()
for func_name, func in funcs.items():
try:
cls.parse_technical_config(func, **kwargs)
indicators.add(func_name.upper())
except Exception as e:
if not silence_warnings:
warn(f"Function {func_name}: " + str(e))
return sorted(indicators)
@classmethod
def find_technical_indicator(cls, func_name: str) -> IndicatorMixinT:
"""Get `technical` indicator function by its name."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("technical")
import technical
match_func = lambda k, v: isinstance(v, FunctionType)
funcs = search_package(technical, match_func, blacklist=["technical.util"])
for k, v in funcs.items():
if func_name.upper() == k.upper():
return v
raise AttributeError(f"Indicator '{func_name}' not found")
@classmethod
def from_technical(
cls,
func_name: str,
parse_kwargs: tp.KwargsLike = None,
factory_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a `technical` function.
Requires [technical](https://github.com/freqtrade/technical) installed.
Args:
func_name (str): Function name.
parse_kwargs (dict): Keyword arguments passed to `IndicatorFactory.parse_technical_config`.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
Returns:
Indicator
Usage:
```pycon
>>> ROLLING_MEAN = vbt.IF.from_technical("ROLLING_MEAN")
>>> rolling_mean = ROLLING_MEAN.run(price, window=[3, 4])
>>> rolling_mean.rolling_mean
rolling_mean_window 3 4
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 NaN NaN NaN NaN
2020-01-03 2.0 4.0 NaN NaN
2020-01-04 3.0 3.0 2.5 3.5
2020-01-05 4.0 2.0 3.5 2.5
```
* To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:
```pycon
>>> vbt.phelp(ROLLING_MEAN.run)
ROLLING_MEAN.run(
close,
window=Default(value=200),
min_periods=Default(value=None),
short_name='rolling_mean',
hide_params=None,
hide_default=True,
**kwargs
):
Run `ROLLING_MEAN` indicator.
* Inputs: `close`
* Parameters: `window`, `min_periods`
* Outputs: `rolling_mean`
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `ROLLING_MEAN.run_pipeline`.
```
"""
func = cls.find_technical_indicator(func_name)
func_arg_names = get_func_arg_names(func)
if parse_kwargs is None:
parse_kwargs = {}
config = cls.parse_technical_config(func, **parse_kwargs)
def apply_func(
input_tuple: tp.Tuple[tp.Series, ...],
in_output_tuple: tp.Tuple[tp.Series, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
*_args,
**_kwargs,
) -> tp.MaybeTuple[tp.Array1d]:
input_series = {name: input_tuple[i] for i, name in enumerate(config["input_names"])}
_kwargs = {**{name: param_tuple[i] for i, name in enumerate(config["param_names"])}, **_kwargs}
__args = ()
for arg_name in func_arg_names:
if arg_name in ("dataframe", "df", "bars"):
__args += (pd.DataFrame(input_series),)
elif arg_name in ("series", "sr"):
__args += (input_series["close"],)
elif arg_name in ("open", "high", "low", "close", "volume"):
__args += (input_series["close"],)
else:
break
out = suppress_stdout(func)(*__args, *_args, **_kwargs)
if isinstance(out, list):
out = np.asarray(out)
if isinstance(out, np.ndarray):
out = pd.Series(out)
if isinstance(out, dict):
out = pd.DataFrame(out)
if isinstance(out, tuple):
out = pd.concat(out, axis=1)
if isinstance(out, pd.DataFrame):
outputs = []
for c in out.columns:
if len(out.columns) == len(config["output_names"]):
outputs.append(out[c].values)
elif c.replace(" ", "").lower() not in ("open", "high", "low", "close", "volume", "data"):
outputs.append(out[c].values)
return tuple(outputs)
return out.values
kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs)
Indicator = cls(
**merge_dicts(dict(module_name=__name__ + ".technical"), config, factory_kwargs),
).with_apply_func(apply_func, pass_packed=True, keep_pd=True, takes_1d=True, **kwargs)
return Indicator
@classmethod
def from_custom_techcon(
cls,
consensus_cls: tp.Type[ConsensusT],
factory_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Type[IndicatorBase]:
"""Create an indicator based on a technical consensus class subclassing
`technical.consensus.consensus.Consensus`.
Requires [technical](https://github.com/freqtrade/technical) installed."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("technical")
from technical.consensus.consensus import Consensus
checks.assert_subclass_of(consensus_cls, Consensus)
def apply_func(
open: tp.Series,
high: tp.Series,
low: tp.Series,
close: tp.Series,
volume: tp.Series,
smooth: tp.Optional[int] = None,
_consensus_cls: tp.Type[ConsensusT] = consensus_cls,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Apply function for `technical.consensus.movingaverage.MovingAverageConsensus`."""
dataframe = pd.DataFrame(
{
"open": open,
"high": high,
"low": low,
"close": close,
"volume": volume,
}
)
consensus = _consensus_cls(dataframe)
score = consensus.score(smooth=smooth)
return (
score["buy"].values,
score["sell"].values,
score["buy_agreement"].values,
score["sell_agreement"].values,
score["buy_disagreement"].values,
score["sell_disagreement"].values,
)
if factory_kwargs is None:
factory_kwargs = {}
factory_kwargs = merge_dicts(
dict(
class_name="CON",
module_name=__name__ + ".custom_techcon",
short_name=None,
input_names=["open", "high", "low", "close", "volume"],
param_names=["smooth"],
output_names=[
"buy",
"sell",
"buy_agreement",
"sell_agreement",
"buy_disagreement",
"sell_disagreement",
],
),
factory_kwargs,
)
Indicator = cls(**factory_kwargs).with_apply_func(
apply_func,
takes_1d=True,
keep_pd=True,
smooth=None,
**kwargs,
)
def plot(
self,
column: tp.Optional[tp.Label] = None,
buy_trace_kwargs: tp.KwargsLike = None,
sell_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `MA.ma` against `MA.close`.
Args:
column (str): Name of the column to plot.
buy_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `buy`.
sell_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `sell`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if buy_trace_kwargs is None:
buy_trace_kwargs = {}
if sell_trace_kwargs is None:
sell_trace_kwargs = {}
buy_trace_kwargs = merge_dicts(
dict(name="Buy", line=dict(color=plotting_cfg["color_schema"]["green"])),
buy_trace_kwargs,
)
sell_trace_kwargs = merge_dicts(
dict(name="Sell", line=dict(color=plotting_cfg["color_schema"]["red"])),
sell_trace_kwargs,
)
fig = self_col.buy.vbt.lineplot(
trace_kwargs=buy_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.sell.vbt.lineplot(
trace_kwargs=sell_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
Indicator.plot = plot
return Indicator
@classmethod
def from_techcon(cls, cls_name: str, **kwargs) -> tp.Type[IndicatorBase]:
"""Create an indicator from a preset technical consensus.
Supported are case-insensitive values `MACON` (or `MovingAverageConsensus`),
`OSCCON` (or `OscillatorConsensus`), and `SUMCON` (or `SummaryConsensus`)."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("technical")
if cls_name.lower() in ("MACON".lower(), "MovingAverageConsensus".lower()):
from technical.consensus.movingaverage import MovingAverageConsensus
return cls.from_custom_techcon(
MovingAverageConsensus,
factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="MACON"),
**kwargs,
)
if cls_name.lower() in ("OSCCON".lower(), "OscillatorConsensus".lower()):
from technical.consensus.oscillator import OscillatorConsensus
return cls.from_custom_techcon(
OscillatorConsensus,
factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="OSCCON"),
**kwargs,
)
if cls_name.lower() in ("SUMCON".lower(), "SummaryConsensus".lower()):
from technical.consensus.summary import SummaryConsensus
return cls.from_custom_techcon(
SummaryConsensus,
factory_kwargs=dict(module_name=__name__ + ".techcon", class_name="SUMCON"),
**kwargs,
)
raise ValueError(f"Unknown technical consensus class '{cls_name}'")
@classmethod
def list_techcon_indicators(cls) -> tp.List[str]:
"""List all consensus indicators in `technical`."""
return sorted({"MACON", "OSCCON", "SUMCON"})
@classmethod
def find_smc_indicator(cls, func_name: str, raise_error: bool = True) -> tp.Optional[tp.Callable]:
"""Get `smartmoneyconcepts` indicator class by its name."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("smartmoneyconcepts")
from smartmoneyconcepts import smc
for k in dir(smc):
if not k.startswith("_"):
if camel_to_snake_case(func_name) == camel_to_snake_case(k):
return getattr(smc, k)
if raise_error:
raise AttributeError(f"Indicator '{func_name}' not found")
return None
@classmethod
def parse_smc_config(cls, func: tp.Callable, collapse: bool = True, snake_case: bool = True) -> tp.Kwargs:
"""Parse the config of a `smartmoneyconcepts` indicator."""
func_arg_names = get_func_arg_names(func)
input_names = []
param_names = []
defaults = {}
dep_input_names = {}
sig = inspect.signature(func)
for k in func_arg_names:
if k == "ohlc":
input_names.extend(["open", "high", "low", "close", "volume"])
else:
found_smc_indicator = cls.find_smc_indicator(k, raise_error=False)
if found_smc_indicator is not None:
dep_input_names[k] = []
k_func_config = cls.parse_smc_config(found_smc_indicator)
if collapse:
for input_name in k_func_config["input_names"]:
if input_name not in input_names:
input_names.append(input_name)
for param_name in k_func_config["param_names"]:
if param_name not in param_names:
param_names.append(param_name)
for k2, v2 in k_func_config["defaults"].items():
defaults[k2] = v2
else:
for output_name in k_func_config["output_names"]:
if output_name not in input_names:
input_names.append(output_name)
dep_input_names[k].append(output_name)
else:
v = sig.parameters[k]
if v.kind not in (v.VAR_POSITIONAL, v.VAR_KEYWORD):
if v.default == inspect.Parameter.empty and v.annotation == pd.DataFrame:
if k not in input_names:
input_names.append(k)
else:
if k not in param_names:
param_names.append(k)
if v.default != inspect.Parameter.empty:
defaults[k] = v.default
func_doc = inspect.getsource(func)
output_names = re.findall(r'name="([^"]+)"', func_doc)
output_names = [k.replace("%", "") for k in output_names]
if snake_case:
input_names = list(map(camel_to_snake_case, input_names))
param_names = list(map(camel_to_snake_case, param_names))
output_names = list(map(camel_to_snake_case, output_names))
return dict(
class_name=func.__name__.upper(),
class_docstring=func.__doc__,
input_names=input_names,
param_names=param_names,
output_names=output_names,
defaults=defaults,
dep_input_names=dep_input_names,
)
@classmethod
def list_smc_indicators(cls, silence_warnings: bool = True, **kwargs) -> tp.List[str]:
"""List all parseable indicators in `smartmoneyconcepts`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("smartmoneyconcepts")
from smartmoneyconcepts import smc
indicators = set()
for func_name in dir(smc):
if not func_name.startswith("_"):
try:
cls.parse_smc_config(getattr(smc, func_name), **kwargs)
indicators.add(func_name.upper())
except Exception as e:
if not silence_warnings:
warn(f"Function {func_name}: " + str(e))
return sorted(indicators)
@classmethod
def from_smc(
cls,
func_name: str,
collapse: bool = True,
parse_kwargs: tp.KwargsLike = None,
factory_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a `smartmoneyconcepts` function.
Requires [smart-money-concepts](https://github.com/joshyattridge/smart-money-concepts) installed.
Args:
func_name (str): Function name.
collapse (bool): Whether to collapse all nested indicators into a single one.
parse_kwargs (dict): Keyword arguments passed to `IndicatorFactory.parse_smc_config`.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
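Usage:
* A minimal sketch, assuming OHLCV data is available as pandas objects and that
`fvg` (fair value gap) is one of the parseable `smartmoneyconcepts` functions:
```pycon
>>> FVG = vbt.IF.from_smc("fvg")
>>> fvg = FVG.run(open, high, low, close, volume)
```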
"""
func = cls.find_smc_indicator(func_name)
func_arg_names = get_func_arg_names(func)
if parse_kwargs is None:
parse_kwargs = {}
collapsed_config = cls.parse_smc_config(func, collapse=True, snake_case=True, **parse_kwargs)
_ = collapsed_config.pop("dep_input_names")
expanded_config = cls.parse_smc_config(func, collapse=False, snake_case=True, **parse_kwargs)
dep_input_names = expanded_config.pop("dep_input_names")
if collapse:
config = collapsed_config
else:
config = expanded_config
def apply_func(
input_tuple: tp.Tuple[tp.Series, ...],
in_output_tuple: tp.Tuple[tp.Series, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
**_kwargs,
) -> tp.MaybeTuple[tp.Array1d]:
named_args = dict(_kwargs)
for i, input_name in enumerate(config["input_names"]):
named_args[input_name] = input_tuple[i]
for i, param_name in enumerate(config["param_names"]):
named_args[param_name] = param_tuple[i]
named_args["ohlc"] = pd.concat(
[
named_args["open"].rename("open"),
named_args["high"].rename("high"),
named_args["low"].rename("low"),
named_args["close"].rename("close"),
named_args["volume"].rename("volume"),
],
axis=1,
)
if collapse and len(dep_input_names) > 0:
for dep_func_name in dep_input_names:
dep_func = cls.find_smc_indicator(dep_func_name)
dep_func_arg_names = get_func_arg_names(dep_func)
dep_output = dep_func(*[named_args[camel_to_snake_case(k)] for k in dep_func_arg_names])
dep_output.index = input_tuple[0].index
named_args[dep_func_name] = dep_output
elif not collapse:
for dep_func_name in dep_input_names:
dep_func = cls.find_smc_indicator(dep_func_name)
dep_config = cls.parse_smc_config(dep_func, collapse=False, snake_case=False, **parse_kwargs)
named_args[dep_func_name] = pd.concat(
[named_args[input_name] for input_name in dep_input_names[dep_func_name]],
axis=1,
keys=dep_config["output_names"],
)
output = func(*[named_args[camel_to_snake_case(k)] for k in func_arg_names])
return tuple([output[c] for c in output.columns])
kwargs = merge_dicts({k: Default(v) for k, v in config.pop("defaults").items()}, kwargs)
Indicator = cls(
**merge_dicts(dict(module_name=__name__ + ".smc"), config, factory_kwargs),
).with_apply_func(apply_func, pass_packed=True, keep_pd=True, takes_1d=True, **kwargs)
return Indicator
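    # A minimal, hypothetical sketch of a `from_smc` invocation (not part of the library),
    # assuming `smartmoneyconcepts` is installed and `data` is an OHLCV `vbt.Data` instance:
    #
    #     >>> FVG = vbt.IF.from_smc("fvg")
    #     >>> fvg = FVG.run(data.open, data.high, data.low, data.close, data.volume)
    #
    # The exact inputs and parameters depend on what `parse_smc_config` detects for the
    # chosen function, so treat the above only as a sketch.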
# ############# Expressions ############# #
@hybrid_method
def from_expr(
cls_or_self,
expr: str,
parse_annotations: bool = True,
factory_kwargs: tp.KwargsLike = None,
magnet_inputs: tp.Iterable[str] = None,
magnet_in_outputs: tp.Iterable[str] = None,
magnet_params: tp.Iterable[str] = None,
func_mapping: tp.KwargsLike = None,
res_func_mapping: tp.KwargsLike = None,
use_pd_eval: tp.Optional[bool] = None,
pd_eval_kwargs: tp.KwargsLike = None,
return_clean_expr: bool = False,
**kwargs,
) -> tp.Union[str, tp.Type[IndicatorBase]]:
"""Build an indicator class from an indicator expression.
Args:
expr (str): Expression.
The expression must be a string containing valid Python code.
Both single-line and multi-line expressions are supported.
parse_annotations (bool): Whether to parse annotations starting with `@`.
factory_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
Only applied when calling the class method.
magnet_inputs (iterable of str): Names recognized as input names.
Defaults to `open`, `high`, `low`, `close`, and `volume`.
magnet_in_outputs (iterable of str): Names recognized as in-output names.
Defaults to an empty list.
magnet_params (iterable of str): Names recognized as parameter names.
Defaults to an empty list.
func_mapping (mapping): Mapping merged over `vectorbtpro.indicators.expr.expr_func_config`.
Each key must be a function name and each value must be a dict with
`func` and optionally `magnet_inputs`, `magnet_in_outputs`, and `magnet_params`.
res_func_mapping (mapping): Mapping merged over `vectorbtpro.indicators.expr.expr_res_func_config`.
Each key must be a function name and each value must be a dict with
`func` and optionally `magnet_inputs`, `magnet_in_outputs`, and `magnet_params`.
use_pd_eval (bool): Whether to use `pd.eval`.
Defaults to False, in which case `vectorbtpro.utils.eval_.evaluate` is used instead.
!!! hint
By default, operates on NumPy objects using NumExpr.
If you want to operate on Pandas objects, set `keep_pd` to True.
pd_eval_kwargs (dict): Keyword arguments passed to `pd.eval`.
return_clean_expr (bool): Whether to return a cleaned expression.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_apply_func`.
Returns:
Indicator
Searches each variable name parsed from `expr` in
* `vectorbtpro.indicators.expr.expr_res_func_config` (calls right away)
* `vectorbtpro.indicators.expr.expr_func_config`
* inputs, in-outputs, and params
* keyword arguments
* attributes of `np`
* attributes of `vectorbtpro.generic.nb` (with and without `_nb` suffix)
* attributes of `vbt`
`vectorbtpro.indicators.expr.expr_func_config` and `vectorbtpro.indicators.expr.expr_res_func_config`
can be overridden with `func_mapping` and `res_func_mapping` respectively.
!!! note
Each variable name is case-sensitive.
When using the class method, all names are parsed from the expression itself.
If any of `open`, `high`, `low`, `close`, and `volume` appear in the expression or
in `magnet_inputs` in either `vectorbtpro.indicators.expr.expr_func_config` or
`vectorbtpro.indicators.expr.expr_res_func_config`, they are automatically added to `input_names`.
Set `magnet_inputs` to an empty list to disable this logic.
If the expression begins with a valid variable name followed by a colon (`:`), that name
is used as the name of the generated class. To also set the indicator's short name,
put another variable name in square brackets between the class name and the colon.
If `parse_annotations` is True, variables that start with `@` have a special meaning:
* `@in_*`: input variable
* `@inout_*`: in-output variable
* `@p_*`: parameter variable
* `@out_*`: output variable
* `@out_*:`: indicates that the next part until a comma is an output
* `@talib_*`: name of a TA-Lib function. Uses the indicator's `apply_func`.
* `@res_*`: name of the indicator to resolve automatically. Input names can overlap with
those of other indicators, while all other names get prefixed with the indicator's short name.
* `@settings(*)`: settings to be merged with the current `IndicatorFactory.from_expr` settings.
Everything within the parentheses gets evaluated using Python's `eval` command
and must be a dictionary. Overrides the defaults but gets overridden by any argument
passed to this method. Arguments `expr` and `parse_annotations` cannot be overridden.
!!! note
The parsed names come in the same order they appear in the expression, not in the execution order,
apart from the magnet input names, which are added in the same order they appear in the list.
The number of outputs is derived from the number of commas outside of any bracket pair.
If there is only one output, it is named `out`; if there are more, they are named `out1`, `out2`, etc.
Any information can be overridden using `factory_kwargs`.
Usage:
```pycon
>>> WMA = vbt.IF(
... class_name='WMA',
... input_names=['close'],
... param_names=['window'],
... output_names=['wma']
... ).from_expr("wm_mean_nb(close, window)")
>>> wma = WMA.run(price, window=[2, 3])
>>> wma.wma
wma_window 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.666667 4.333333 NaN NaN
2020-01-03 2.666667 3.333333 2.333333 3.666667
2020-01-04 3.666667 2.333333 3.333333 2.666667
2020-01-05 4.666667 1.333333 4.333333 1.666667
```
* The same can be achieved by calling the class method and providing prefixes
to the variable names to indicate their type:
```pycon
>>> expr = "WMA: @out_wma:wm_mean_nb((@in_high + @in_low) / 2, @p_window)"
>>> WMA = vbt.IF.from_expr(expr)
>>> wma = WMA.run(price + 1, price, window=[2, 3])
>>> wma.wma
wma_window 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 2.166667 4.833333 NaN NaN
2020-01-03 3.166667 3.833333 2.833333 4.166667
2020-01-04 4.166667 2.833333 3.833333 3.166667
2020-01-05 5.166667 1.833333 4.833333 2.166667
```
* Magnet names are recognized automatically:
```pycon
>>> expr = "WMA: @out_wma:wm_mean_nb((high + low) / 2, @p_window)"
```
* Most settings of this method can be overridden from within the expression:
```pycon
>>> expr = \"\"\"
... @settings(dict(factory_kwargs=dict(class_name='WMA', param_names=['window'])))
... @out_wma:wm_mean_nb((high + low) / 2, window)
... \"\"\"
```
"""
def _clean_expr(expr: str) -> str:
# Clean the expression from redundant brackets and commas
expr = inspect.cleandoc(expr).strip()
if expr.endswith(","):
expr = expr[:-1]
if expr.startswith("(") and expr.endswith(")"):
n_open_brackets = 0
remove_brackets = True
for i, s in enumerate(expr):
if s == "(":
n_open_brackets += 1
elif s == ")":
n_open_brackets -= 1
if n_open_brackets == 0 and i < len(expr) - 1:
remove_brackets = False
break
if remove_brackets:
expr = expr[1:-1]
if expr.endswith(","):
expr = expr[:-1] # again
return expr
if isinstance(cls_or_self, type):
settings = dict(
factory_kwargs=dict(
class_name=None,
input_names=[],
in_output_names=[],
param_names=[],
output_names=[],
)
)
# Parse the class name
match = re.match(r"^\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:\[([a-zA-Z_][a-zA-Z0-9_]*)\])?\s*:\s*", expr)
if match:
settings["factory_kwargs"]["class_name"] = match.group(1)
if match.group(2):
settings["factory_kwargs"]["short_name"] = match.group(2)
expr = expr[len(match.group(0)) :]
# Parse the settings dictionary
if "@settings" in expr:
remove_chars = set()
for m in re.finditer("@settings", expr):
n_open_brackets = 0
from_i = None
to_i = None
for i in range(m.start(), m.end()):
remove_chars.add(i)
for i in range(m.end(), len(expr)):
remove_chars.add(i)
s = expr[i]
if s in "(":
if n_open_brackets == 0:
from_i = i + 1
n_open_brackets += 1
elif s in ")":
n_open_brackets -= 1
if n_open_brackets == 0:
to_i = i
break
if n_open_brackets != 0:
raise ValueError("Couldn't parse the settings: mismatching brackets")
settings = merge_dicts(settings, eval(_clean_expr(expr[from_i:to_i])))
expr = "".join([expr[i] for i in range(len(expr)) if i not in remove_chars])
expr = _clean_expr(expr)
# Merge info
parsed_factory_kwargs = settings.pop("factory_kwargs")
magnet_inputs = settings.pop("magnet_inputs", magnet_inputs)
magnet_in_outputs = settings.pop("magnet_in_outputs", magnet_in_outputs)
magnet_params = settings.pop("magnet_params", magnet_params)
func_mapping = merge_dicts(expr_func_config, settings.pop("func_mapping", None), func_mapping)
res_func_mapping = merge_dicts(
expr_res_func_config,
settings.pop("res_func_mapping", None),
res_func_mapping,
)
use_pd_eval = settings.pop("use_pd_eval", use_pd_eval)
pd_eval_kwargs = merge_dicts(settings.pop("pd_eval_kwargs", None), pd_eval_kwargs)
# Resolve defaults
if use_pd_eval is None:
use_pd_eval = False
if magnet_inputs is None:
magnet_inputs = ["open", "high", "low", "close", "volume"]
if magnet_in_outputs is None:
magnet_in_outputs = []
if magnet_params is None:
magnet_params = []
found_magnet_inputs = []
found_magnet_in_outputs = []
found_magnet_params = []
found_defaults = {}
remove_defaults = set()
# Parse annotated variables
if parse_annotations:
# Parse input, in-output, parameter, and TA-Lib function names
for var_name in re.findall(r"@[a-z]+_[a-zA-Z_][a-zA-Z0-9_]*", expr):
var_name = var_name.replace("@", "")
if var_name.startswith("in_"):
var_name = var_name[3:]
if var_name in magnet_inputs:
if var_name not in found_magnet_inputs:
found_magnet_inputs.append(var_name)
else:
if var_name not in parsed_factory_kwargs["input_names"]:
parsed_factory_kwargs["input_names"].append(var_name)
elif var_name.startswith("inout_"):
var_name = var_name[6:]
if var_name in magnet_in_outputs:
if var_name not in found_magnet_in_outputs:
found_magnet_in_outputs.append(var_name)
else:
if var_name not in parsed_factory_kwargs["in_output_names"]:
parsed_factory_kwargs["in_output_names"].append(var_name)
elif var_name.startswith("p_"):
var_name = var_name[2:]
if var_name in magnet_params:
if var_name not in found_magnet_params:
found_magnet_params.append(var_name)
else:
if var_name not in parsed_factory_kwargs["param_names"]:
parsed_factory_kwargs["param_names"].append(var_name)
elif var_name.startswith("res_"):
ind_name = var_name[4:]
if ind_name.startswith("talib_"):
ind_name = ind_name[6:]
I = cls_or_self.from_talib(ind_name)
else:
I = kwargs[ind_name]
if not issubclass(I, IndicatorBase):
raise TypeError(f"Indicator class '{ind_name}' must subclass IndicatorBase")
def _ind_func(context: tp.Kwargs, _I: IndicatorBase = I) -> tp.Any:
_args = ()
_kwargs = {}
signature = inspect.signature(_I.run)
for p in signature.parameters.values():
if p.name in _I.input_names:
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD):
_args += (context[p.name],)
else:
_kwargs[p.name] = context[p.name]
else:
ind_p_name = _I.short_name + "_" + p.name
if ind_p_name in context:
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD):
_args += (context[ind_p_name],)
elif p.kind == p.VAR_POSITIONAL:
_args += context[ind_p_name]
elif p.kind == p.VAR_KEYWORD:
for k, v in context[ind_p_name].items():
_kwargs[k] = v
else:
_kwargs[p.name] = context[ind_p_name]
return_raw = _kwargs.pop("return_raw", True)
ind = _I.run(*_args, return_raw=return_raw, **_kwargs)
if return_raw:
raw_outputs = ind[0]
if len(raw_outputs) == 1:
return raw_outputs[0]
return raw_outputs
return ind
res_func_mapping["__" + var_name] = dict(
func=_ind_func,
magnet_inputs=I.input_names,
magnet_in_outputs=[I.short_name + "_" + name for name in I.in_output_names],
magnet_params=[I.short_name + "_" + name for name in I.param_names],
)
run_kwargs = get_func_kwargs(I.run)
def _add_defaults(names, prefix=None):
for k in names:
if prefix is None:
k_prefixed = k
else:
k_prefixed = prefix + "_" + k
if k in run_kwargs:
if k_prefixed in found_defaults:
if not checks.is_deep_equal(found_defaults[k_prefixed], run_kwargs[k]):
remove_defaults.add(k_prefixed)
else:
found_defaults[k_prefixed] = run_kwargs[k]
_add_defaults(I.input_names)
_add_defaults(I.in_output_names, I.short_name)
_add_defaults(I.param_names, I.short_name)
expr = expr.replace("@in_", "__in_")
expr = expr.replace("@inout_", "__inout_")
expr = expr.replace("@p_", "__p_")
expr = expr.replace("@talib_", "__talib_")
expr = expr.replace("@res_", "__res_")
# Parse output names
to_replace = []
for var_name in re.findall(r"@out_[a-zA-Z_][a-zA-Z0-9_]*\s*:\s*", expr):
to_replace.append(var_name)
var_name = var_name.split(":")[0].strip()[5:]
if var_name not in parsed_factory_kwargs["output_names"]:
parsed_factory_kwargs["output_names"].append(var_name)
for s in to_replace:
expr = expr.replace(s, "")
for var_name in re.findall(r"@out_[a-zA-Z_][a-zA-Z0-9_]*", expr):
var_name = var_name.replace("@", "")
if var_name.startswith("out_"):
var_name = var_name[4:]
if var_name not in parsed_factory_kwargs["output_names"]:
parsed_factory_kwargs["output_names"].append(var_name)
expr = expr.replace("@out_", "__out_")
if len(parsed_factory_kwargs["output_names"]) == 0:
lines = expr.split("\n")
if len(lines) > 1:
last_line = _clean_expr(lines[-1])
valid_output_names = []
found_not_valid = False
for i, out in enumerate(last_line.split(",")):
out = out.strip()
if not out.startswith("__") and out.isidentifier():
valid_output_names.append(out)
else:
found_not_valid = True
break
if not found_not_valid:
parsed_factory_kwargs["output_names"] = valid_output_names
# Parse magnet names
var_names = get_expr_var_names(expr)
def _find_magnets(magnet_type, magnet_names, magnet_lst, found_magnet_lst):
for var_name in var_names:
if var_name in magnet_lst:
if var_name not in found_magnet_lst:
found_magnet_lst.append(var_name)
if var_name in func_mapping:
for magnet_name in func_mapping[var_name].get(magnet_type, []):
if magnet_name not in found_magnet_lst:
found_magnet_lst.append(magnet_name)
if var_name in res_func_mapping:
for magnet_name in res_func_mapping[var_name].get(magnet_type, []):
if magnet_name not in found_magnet_lst:
found_magnet_lst.append(magnet_name)
for magnet_name in magnet_lst:
if magnet_name in found_magnet_lst and magnet_name not in magnet_names:
magnet_names.append(magnet_name)
for magnet_name in found_magnet_lst:
if magnet_name not in magnet_names:
magnet_names.append(magnet_name)
_find_magnets("magnet_inputs", parsed_factory_kwargs["input_names"], magnet_inputs, found_magnet_inputs)
_find_magnets(
"magnet_in_outputs",
parsed_factory_kwargs["in_output_names"],
magnet_in_outputs,
found_magnet_in_outputs,
)
_find_magnets("magnet_params", parsed_factory_kwargs["param_names"], magnet_params, found_magnet_params)
# Prepare defaults
for k in remove_defaults:
found_defaults.pop(k, None)
def _sort_names(names_name):
new_names = []
for k in parsed_factory_kwargs[names_name]:
if k not in found_defaults:
new_names.append(k)
for k in parsed_factory_kwargs[names_name]:
if k in found_defaults:
new_names.append(k)
parsed_factory_kwargs[names_name] = new_names
_sort_names("input_names")
_sort_names("in_output_names")
_sort_names("param_names")
# Parse the number of outputs
if len(parsed_factory_kwargs["output_names"]) == 0:
lines = expr.split("\n")
last_line = _clean_expr(lines[-1])
n_open_brackets = 0
n_outputs = 1
for i, s in enumerate(last_line):
if s == "," and n_open_brackets == 0:
n_outputs += 1
elif s in "([{":
n_open_brackets += 1
elif s in ")]}":
n_open_brackets -= 1
if n_open_brackets != 0:
raise ValueError("Couldn't parse the number of outputs: mismatching brackets")
else:
if n_outputs == 1:
parsed_factory_kwargs["output_names"] = ["out"]
else:
parsed_factory_kwargs["output_names"] = ["out%d" % (i + 1) for i in range(n_outputs)]
factory = cls_or_self(**merge_dicts(parsed_factory_kwargs, factory_kwargs))
kwargs = merge_dicts(settings, found_defaults, kwargs)
else:
func_mapping = merge_dicts(expr_func_config, func_mapping)
res_func_mapping = merge_dicts(expr_res_func_config, res_func_mapping)
var_names = get_expr_var_names(expr)
factory = cls_or_self
if return_clean_expr:
# For debugging purposes
return expr
input_names = factory.input_names
in_output_names = factory.in_output_names
param_names = factory.param_names
def apply_func(
input_tuple: tp.Tuple[tp.AnyArray, ...],
in_output_tuple: tp.Tuple[tp.SeriesFrame, ...],
param_tuple: tp.Tuple[tp.ParamValue, ...],
**_kwargs,
) -> tp.MaybeTuple[tp.Array2d]:
import vectorbtpro as vbt
input_context = dict(np=np, pd=pd, vbt=vbt)
for i, input in enumerate(input_tuple):
input_context[input_names[i]] = input
for i, in_output in enumerate(in_output_tuple):
input_context[in_output_names[i]] = in_output
for i, param in enumerate(param_tuple):
input_context[param_names[i]] = param
merged_context = merge_dicts(input_context, _kwargs)
context = {}
# Resolve each variable in the expression
for var_name in var_names:
if var_name in context:
continue
if var_name.startswith("__in_"):
var = merged_context[var_name[5:]]
elif var_name.startswith("__inout_"):
var = merged_context[var_name[8:]]
elif var_name.startswith("__p_"):
var = merged_context[var_name[4:]]
elif var_name.startswith("__talib_"):
from vectorbtpro.indicators.talib_ import talib_func
talib_func_name = var_name[8:].upper()
_talib_func = talib_func(talib_func_name)
var = functools.partial(_talib_func, wrapper=_kwargs["wrapper"])
elif var_name in res_func_mapping:
var = res_func_mapping[var_name]["func"]
elif var_name in func_mapping:
var = func_mapping[var_name]["func"]
elif var_name in merged_context:
var = merged_context[var_name]
elif hasattr(np, var_name):
var = getattr(np, var_name)
elif hasattr(generic_nb, var_name):
var = getattr(generic_nb, var_name)
elif hasattr(generic_nb, var_name + "_nb"):
var = getattr(generic_nb, var_name + "_nb")
elif hasattr(vbt, var_name):
var = getattr(vbt, var_name)
else:
continue
try:
if callable(var) and "context" in get_func_arg_names(var):
var = functools.partial(var, context=merged_context)
except:
pass
if var_name in res_func_mapping:
var = var()
context[var_name] = var
# Evaluate the expression using resolved variables as a context
if use_pd_eval:
return pd.eval(expr, local_dict=context, **resolve_dict(pd_eval_kwargs))
return evaluate(expr, context=context)
return factory.with_apply_func(apply_func, pass_packed=True, pass_wrapper=True, **kwargs)
@classmethod
def from_wqa101(cls, alpha_idx: tp.Union[str, int], **kwargs) -> tp.Type[IndicatorBase]:
"""Build an indicator class from one of the WorldQuant's 101 alpha expressions.
See `vectorbtpro.indicators.expr.wqa101_expr_config`.
!!! note
Some expressions that utilize cross-sectional operations require columns to be
a multi-index with a level `sector`, `subindustry`, or `industry`.
Usage:
```pycon
>>> data = vbt.YFData.pull(['BTC-USD', 'ETH-USD'])
>>> WQA1 = vbt.IF.from_wqa101(1)
>>> wqa1 = WQA1.run(data.get('Close'))
>>> wqa1.out
symbol BTC-USD ETH-USD
Date
2014-09-17 00:00:00+00:00 0.25 0.25
2014-09-18 00:00:00+00:00 0.25 0.25
2014-09-19 00:00:00+00:00 0.25 0.25
2014-09-20 00:00:00+00:00 0.25 0.25
2014-09-21 00:00:00+00:00 0.25 0.25
... ... ...
2022-01-21 00:00:00+00:00 0.00 0.50
2022-01-22 00:00:00+00:00 0.00 0.50
2022-01-23 00:00:00+00:00 0.25 0.25
2022-01-24 00:00:00+00:00 0.50 0.00
2022-01-25 00:00:00+00:00 0.50 0.00
[2688 rows x 2 columns]
```
* To get help on running the indicator, use `vectorbtpro.utils.formatting.phelp`:
```pycon
>>> vbt.phelp(WQA1.run)
WQA1.run(
close,
short_name='wqa1',
hide_params=None,
hide_default=True,
**kwargs
):
Run `WQA1` indicator.
* Inputs: `close`
* Outputs: `out`
Pass a list of parameter names as `hide_params` to hide their column levels, or True to hide all.
Set `hide_default` to False to show the column levels of the parameters with a default value.
Other keyword arguments are passed to `WQA1.run_pipeline`.
```
"""
if isinstance(alpha_idx, str):
alpha_idx = int(alpha_idx.upper().replace("WQA", ""))
return cls.from_expr(
wqa101_expr_config[alpha_idx],
factory_kwargs=dict(class_name="WQA%d" % alpha_idx, module_name=__name__ + ".wqa101"),
**kwargs,
)
@classmethod
def list_wqa101_indicators(cls) -> tp.List[str]:
"""List all WorldQuant's 101 alpha indicators."""
return [str(i) for i in range(1, 102)]
IF = IndicatorFactory
"""Shortcut for `IndicatorFactory`."""
__pdoc__["IF"] = False
def indicator(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.get_indicator`."""
return IndicatorFactory.get_indicator(*args, **kwargs)
def talib(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_talib`."""
return IndicatorFactory.from_talib(*args, **kwargs)
def pandas_ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_pandas_ta`."""
return IndicatorFactory.from_pandas_ta(*args, **kwargs)
def ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_ta`."""
return IndicatorFactory.from_ta(*args, **kwargs)
def wqa101(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_wqa101`."""
return IndicatorFactory.from_wqa101(*args, **kwargs)
def technical(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_technical`."""
return IndicatorFactory.from_technical(*args, **kwargs)
def techcon(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_techcon`."""
return IndicatorFactory.from_techcon(*args, **kwargs)
def smc(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbtpro.indicators.factory.IndicatorFactory.from_smc`."""
return IndicatorFactory.from_smc(*args, **kwargs)
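# A minimal, hypothetical sketch (not part of the library) showing that the module-level
# shortcuts above simply delegate to the corresponding factory class methods; assumes
# TA-Lib is installed so that `from_talib` can build the indicator class.
def _shortcut_example() -> None:
    RSI_via_factory = IndicatorFactory.from_talib("RSI")
    RSI_via_shortcut = talib("RSI")
    # Both paths are expected to produce equivalent indicator classes
    assert RSI_via_factory.param_names == RSI_via_shortcut.param_names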
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for custom indicators.
Provides an arsenal of Numba-compiled functions that are used by indicator
classes. These only accept NumPy arrays and other Numba-compatible types."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators.enums import Pivot, SuperTrendAIS, SuperTrendAOS, HurstMethod
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
__all__ = []
# ############# MA ############# #
@register_jitted(cache=True)
def ma_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Moving average.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
return generic_nb.ma_1d_nb(close, window, wtype=wtype, minp=minp, adjust=adjust)
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ma_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `ma_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
ma = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
ma[:, col] = ma_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return ma
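# A minimal, hypothetical sketch (not part of the library): `window` and `wtype` are
# flexible per-column parameters, so a single call can apply a different setting to each column.
def _ma_nb_example() -> tp.Array2d:
    close = np.random.uniform(90.0, 110.0, size=(100, 2))
    # Column 0 uses a 10-bar simple MA, column 1 a 20-bar exponential MA
    return ma_nb(
        close,
        window=np.array([10, 20]),
        wtype=np.array([generic_enums.WType.Simple, generic_enums.WType.Exp]),
    )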
# ############# MSD ############# #
@register_jitted(cache=True)
def msd_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array1d:
"""Moving standard deviation.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
return generic_nb.msd_1d_nb(close, window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
ddof=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def msd_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array2d:
"""2-dim version of `msd_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
msd = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
msd[:, col] = msd_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
ddof=ddof,
)
return msd
# ############# BBANDS ############# #
@register_jitted(cache=True)
def bbands_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
alpha: float = 2.0,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
"""Bollinger Bands.
Returns the upper band, the middle band, and the lower band.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
ma = ma_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust)
msd = msd_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)
upper = ma + alpha * msd
middle = ma
lower = ma - alpha * msd
return upper, middle, lower
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
alpha=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
ddof=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bbands_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
alpha: tp.FlexArray1dLike = 2.0,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `bbands_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
alpha_ = to_1d_array_nb(np.asarray(alpha))
upper = np.empty(close.shape, dtype=float_)
middle = np.empty(close.shape, dtype=float_)
lower = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
upper[:, col], middle[:, col], lower[:, col] = bbands_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
alpha=flex_select_1d_nb(alpha_, col),
minp=minp,
adjust=adjust,
ddof=ddof,
)
return upper, middle, lower
@register_jitted(cache=True)
def bbands_percent_b_1d_nb(close: tp.Array1d, upper: tp.Array1d, lower: tp.Array1d) -> tp.Array1d:
"""Bollinger Bands %B."""
return (close - lower) / (upper - lower)
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
upper=ch.ArraySlicer(axis=1),
lower=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bbands_percent_b_nb(close: tp.Array2d, upper: tp.Array2d, lower: tp.Array2d) -> tp.Array2d:
"""2-dim version of `bbands_percent_b_1d_nb`."""
percent_b = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
percent_b[:, col] = bbands_percent_b_1d_nb(close[:, col], upper[:, col], lower[:, col])
return percent_b
@register_jitted(cache=True)
def bbands_bandwidth_1d_nb(upper: tp.Array1d, middle: tp.Array1d, lower: tp.Array1d) -> tp.Array1d:
"""Bollinger Bands Bandwidth."""
return (upper - lower) / middle
@register_chunkable(
size=ch.ArraySizer(arg_query="upper", axis=1),
arg_take_spec=dict(
upper=ch.ArraySlicer(axis=1),
middle=ch.ArraySlicer(axis=1),
lower=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bbands_bandwidth_nb(upper: tp.Array2d, middle: tp.Array2d, lower: tp.Array2d) -> tp.Array2d:
"""2-dim version of `bbands_bandwidth_1d_nb`."""
bandwidth = np.empty(upper.shape, dtype=float_)
for col in prange(upper.shape[1]):
bandwidth[:, col] = bbands_bandwidth_1d_nb(upper[:, col], middle[:, col], lower[:, col])
return bandwidth
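# A minimal, hypothetical sketch (not part of the library): the Bollinger helpers are meant
# to be chained, with %B and bandwidth derived from the three bands returned by `bbands_1d_nb`.
def _bbands_example(close: tp.Array1d) -> tp.Tuple[tp.Array1d, tp.Array1d]:
    upper, middle, lower = bbands_1d_nb(close, window=20, alpha=2.0)
    percent_b = bbands_percent_b_1d_nb(close, upper, lower)
    bandwidth = bbands_bandwidth_1d_nb(upper, middle, lower)
    return percent_b, bandwidth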
# ############# RSI ############# #
@register_jitted(cache=True)
def avg_gain_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Average gain."""
up_change = np.empty(close.shape, dtype=float_)
for i in range(close.shape[0]):
if i == 0:
up_change[i] = np.nan
else:
change = close[i] - close[i - 1]
if change < 0:
up_change[i] = 0.0
else:
up_change[i] = change
avg_gain = ma_1d_nb(up_change, window=window, wtype=wtype, minp=minp, adjust=adjust)
return avg_gain
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def avg_gain_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `avg_gain_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
avg_gain = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
avg_gain[:, col] = avg_gain_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return avg_gain
@register_jitted(cache=True)
def avg_loss_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Average loss."""
down_change = np.empty(close.shape, dtype=float_)
for i in range(close.shape[0]):
if i == 0:
down_change[i] = np.nan
else:
change = close[i] - close[i - 1]
if change < 0:
down_change[i] = abs(change)
else:
down_change[i] = 0.0
avg_loss = ma_1d_nb(down_change, window=window, wtype=wtype, minp=minp, adjust=adjust)
return avg_loss
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def avg_loss_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `avg_loss_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
avg_loss = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
avg_loss[:, col] = avg_loss_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return avg_loss
@register_jitted(cache=True)
def rsi_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""RSI.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
avg_gain = avg_gain_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust)
avg_loss = avg_loss_1d_nb(close, window=window, wtype=wtype, minp=minp, adjust=adjust)
return 100 * avg_gain / (avg_gain + avg_loss)
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rsi_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `rsi_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
rsi = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
rsi[:, col] = rsi_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return rsi
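# A minimal, hypothetical sketch (not part of the library): `rsi_1d_nb` computes
# 100 * avg_gain / (avg_gain + avg_loss), which is algebraically identical to the textbook
# form 100 - 100 / (1 + RS) with RS = avg_gain / avg_loss.
def _rsi_example(close: tp.Array1d) -> tp.Array1d:
    # Wilder smoothing over 14 bars, matching the 1-dim defaults
    return rsi_1d_nb(close, window=14, wtype=generic_enums.WType.Wilder)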
# ############# STOCH ############# #
@register_jitted(cache=True)
def stoch_k_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
window: int = 14,
minp: tp.Optional[int] = None,
) -> tp.Array1d:
"""Stochastic Oscillator %K."""
lowest_low = generic_nb.rolling_min_1d_nb(low, window, minp=minp)
highest_high = generic_nb.rolling_max_1d_nb(high, window, minp=minp)
stoch_k = 100 * (close - lowest_low) / (highest_high - lowest_low)
return stoch_k
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
minp=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def stoch_k_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
minp: tp.Optional[int] = None,
) -> tp.Array2d:
"""2-dim version of `stoch_k_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
stoch_k = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
stoch_k[:, col] = stoch_k_1d_nb(
high=high[:, col],
low=low[:, col],
close=close[:, col],
window=flex_select_1d_nb(window_, col),
minp=minp,
)
return stoch_k
@register_jitted(cache=True)
def stoch_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
fast_k_window: int = 14,
slow_k_window: int = 3,
slow_d_window: int = 3,
wtype: int = generic_enums.WType.Simple,
slow_k_wtype: tp.Optional[int] = None,
slow_d_wtype: tp.Optional[int] = None,
minp: tp.Optional[int] = None,
fast_k_minp: tp.Optional[int] = None,
slow_k_minp: tp.Optional[int] = None,
slow_d_minp: tp.Optional[int] = None,
adjust: bool = False,
slow_k_adjust: tp.Optional[bool] = None,
slow_d_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
"""Stochastic Oscillator.
Returns the fast %K, the slow %K, and the slow %D.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
if slow_k_wtype is not None:
slow_k_wtype_ = slow_k_wtype
else:
slow_k_wtype_ = wtype
if slow_d_wtype is not None:
slow_d_wtype_ = slow_d_wtype
else:
slow_d_wtype_ = wtype
if fast_k_minp is not None:
fast_k_minp_ = fast_k_minp
else:
fast_k_minp_ = minp
if slow_k_minp is not None:
slow_k_minp_ = slow_k_minp
else:
slow_k_minp_ = minp
if slow_d_minp is not None:
slow_d_minp_ = slow_d_minp
else:
slow_d_minp_ = minp
if slow_k_adjust is not None:
slow_k_adjust_ = slow_k_adjust
else:
slow_k_adjust_ = adjust
if slow_d_adjust is not None:
slow_d_adjust_ = slow_d_adjust
else:
slow_d_adjust_ = adjust
fast_k = stoch_k_1d_nb(high, low, close, window=fast_k_window, minp=fast_k_minp_)
slow_k = ma_1d_nb(fast_k, window=slow_k_window, wtype=slow_k_wtype_, minp=slow_k_minp_, adjust=slow_k_adjust_)
slow_d = ma_1d_nb(slow_k, window=slow_d_window, wtype=slow_d_wtype_, minp=slow_d_minp_, adjust=slow_d_adjust_)
return fast_k, slow_k, slow_d
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
fast_k_window=base_ch.FlexArraySlicer(),
slow_k_window=base_ch.FlexArraySlicer(),
slow_d_window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
slow_k_wtype=base_ch.FlexArraySlicer(),
slow_d_wtype=base_ch.FlexArraySlicer(),
minp=None,
fast_k_minp=None,
slow_k_minp=None,
slow_d_minp=None,
adjust=None,
slow_k_adjust=None,
slow_d_adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def stoch_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
fast_k_window: tp.FlexArray1dLike = 14,
slow_k_window: tp.FlexArray1dLike = 3,
slow_d_window: tp.FlexArray1dLike = 3,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
slow_k_wtype: tp.Optional[tp.FlexArray1dLike] = None,
slow_d_wtype: tp.Optional[tp.FlexArray1dLike] = None,
minp: tp.Optional[int] = None,
fast_k_minp: tp.Optional[int] = None,
slow_k_minp: tp.Optional[int] = None,
slow_d_minp: tp.Optional[int] = None,
adjust: bool = False,
slow_k_adjust: tp.Optional[bool] = None,
slow_d_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `stoch_1d_nb`."""
fast_k_window_ = to_1d_array_nb(np.asarray(fast_k_window))
slow_k_window_ = to_1d_array_nb(np.asarray(slow_k_window))
slow_d_window_ = to_1d_array_nb(np.asarray(slow_d_window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
if slow_k_wtype is not None:
slow_k_wtype_ = to_1d_array_nb(np.asarray(slow_k_wtype))
else:
slow_k_wtype_ = wtype_
if slow_d_wtype is not None:
slow_d_wtype_ = to_1d_array_nb(np.asarray(slow_d_wtype))
else:
slow_d_wtype_ = wtype_
fast_k = np.empty(close.shape, dtype=float_)
slow_k = np.empty(close.shape, dtype=float_)
slow_d = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
fast_k[:, col], slow_k[:, col], slow_d[:, col] = stoch_1d_nb(
high=high[:, col],
low=low[:, col],
close=close[:, col],
fast_k_window=flex_select_1d_nb(fast_k_window_, col),
slow_k_window=flex_select_1d_nb(slow_k_window_, col),
slow_d_window=flex_select_1d_nb(slow_d_window_, col),
wtype=flex_select_1d_nb(wtype_, col),
slow_k_wtype=flex_select_1d_nb(slow_k_wtype_, col),
slow_d_wtype=flex_select_1d_nb(slow_d_wtype_, col),
minp=minp,
fast_k_minp=fast_k_minp,
slow_k_minp=slow_k_minp,
slow_d_minp=slow_d_minp,
adjust=adjust,
slow_k_adjust=slow_k_adjust,
slow_d_adjust=slow_d_adjust,
)
return fast_k, slow_k, slow_d
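# A minimal, hypothetical sketch (not part of the library): `stoch_1d_nb` first computes the
# fast %K over `fast_k_window` and then smooths it twice; the `slow_k_*` and `slow_d_*`
# arguments override the shared `wtype`, `minp`, and `adjust` for each smoothing step.
def _stoch_example(high: tp.Array1d, low: tp.Array1d, close: tp.Array1d) -> tp.Array1d:
    fast_k, slow_k, slow_d = stoch_1d_nb(
        high,
        low,
        close,
        fast_k_window=14,
        slow_k_window=3,
        slow_d_window=3,
        slow_d_wtype=generic_enums.WType.Exp,  # EMA only for the final %D smoothing
    )
    return slow_d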
# ############# MACD ############# #
@register_jitted(cache=True)
def macd_1d_nb(
close: tp.Array1d,
fast_window: int = 12,
slow_window: int = 26,
signal_window: int = 9,
wtype: int = generic_enums.WType.Exp,
macd_wtype: tp.Optional[int] = None,
signal_wtype: tp.Optional[int] = None,
minp: tp.Optional[int] = None,
macd_minp: tp.Optional[int] = None,
signal_minp: tp.Optional[int] = None,
adjust: bool = False,
macd_adjust: tp.Optional[bool] = None,
signal_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""MACD.
Returns the MACD and the signal.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
if macd_wtype is not None:
macd_wtype_ = macd_wtype
else:
macd_wtype_ = wtype
if signal_wtype is not None:
signal_wtype_ = signal_wtype
else:
signal_wtype_ = wtype
if macd_minp is not None:
macd_minp_ = macd_minp
else:
macd_minp_ = minp
if signal_minp is not None:
signal_minp_ = signal_minp
else:
signal_minp_ = minp
if macd_adjust is not None:
macd_adjust_ = macd_adjust
else:
macd_adjust_ = adjust
if signal_adjust is not None:
signal_adjust_ = signal_adjust
else:
signal_adjust_ = adjust
fast_ma = ma_1d_nb(close, window=fast_window, wtype=macd_wtype_, minp=macd_minp_, adjust=macd_adjust_)
slow_ma = ma_1d_nb(close, window=slow_window, wtype=macd_wtype_, minp=macd_minp_, adjust=macd_adjust_)
macd = fast_ma - slow_ma
signal = ma_1d_nb(macd, window=signal_window, wtype=signal_wtype_, minp=signal_minp_, adjust=signal_adjust_)
return macd, signal
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
fast_window=base_ch.FlexArraySlicer(),
slow_window=base_ch.FlexArraySlicer(),
signal_window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
macd_wtype=base_ch.FlexArraySlicer(),
signal_wtype=base_ch.FlexArraySlicer(),
minp=None,
macd_minp=None,
signal_minp=None,
adjust=None,
macd_adjust=None,
signal_adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def macd_nb(
close: tp.Array2d,
fast_window: tp.FlexArray1dLike = 12,
slow_window: tp.FlexArray1dLike = 26,
signal_window: tp.FlexArray1dLike = 9,
wtype: tp.FlexArray1dLike = generic_enums.WType.Exp,
macd_wtype: tp.Optional[tp.FlexArray1dLike] = None,
signal_wtype: tp.Optional[tp.FlexArray1dLike] = None,
minp: tp.Optional[int] = None,
macd_minp: tp.Optional[int] = None,
signal_minp: tp.Optional[int] = None,
adjust: bool = False,
macd_adjust: tp.Optional[bool] = None,
signal_adjust: tp.Optional[bool] = None,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""2-dim version of `macd_1d_nb`."""
fast_window_ = to_1d_array_nb(np.asarray(fast_window))
slow_window_ = to_1d_array_nb(np.asarray(slow_window))
signal_window_ = to_1d_array_nb(np.asarray(signal_window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
if macd_wtype is not None:
macd_wtype_ = to_1d_array_nb(np.asarray(macd_wtype))
else:
macd_wtype_ = wtype_
if signal_wtype is not None:
signal_wtype_ = to_1d_array_nb(np.asarray(signal_wtype))
else:
signal_wtype_ = wtype_
macd = np.empty(close.shape, dtype=float_)
signal = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
macd[:, col], signal[:, col] = macd_1d_nb(
close=close[:, col],
fast_window=flex_select_1d_nb(fast_window_, col),
slow_window=flex_select_1d_nb(slow_window_, col),
signal_window=flex_select_1d_nb(signal_window_, col),
wtype=flex_select_1d_nb(wtype_, col),
macd_wtype=flex_select_1d_nb(macd_wtype_, col),
signal_wtype=flex_select_1d_nb(signal_wtype_, col),
minp=minp,
macd_minp=macd_minp,
signal_minp=signal_minp,
adjust=adjust,
macd_adjust=macd_adjust,
signal_adjust=signal_adjust,
)
return macd, signal
@register_jitted(cache=True)
def macd_hist_1d_nb(macd: tp.Array1d, signal: tp.Array1d) -> tp.Array1d:
"""MACD histogram."""
return macd - signal
@register_chunkable(
size=ch.ArraySizer(arg_query="macd", axis=1),
arg_take_spec=dict(
macd=ch.ArraySlicer(axis=1),
signal=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def macd_hist_nb(macd: tp.Array2d, signal: tp.Array2d) -> tp.Array2d:
"""2-dim version of `macd_hist_1d_nb`."""
macd_hist = np.empty(macd.shape, dtype=float_)
for col in prange(macd.shape[1]):
macd_hist[:, col] = macd_hist_1d_nb(macd[:, col], signal[:, col])
return macd_hist
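# A minimal, hypothetical sketch (not part of the library): the histogram is computed
# separately from the MACD line and the signal line returned by `macd_1d_nb`.
def _macd_example(close: tp.Array1d) -> tp.Array1d:
    macd, signal = macd_1d_nb(close, fast_window=12, slow_window=26, signal_window=9)
    return macd_hist_1d_nb(macd, signal)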
# ############# ATR ############# #
@register_jitted(cache=True)
def iter_tr_nb(high: float, low: float, prev_close: float) -> float:
"""True Range (TR) at one iteration."""
tr0 = abs(high - low)
tr1 = abs(high - prev_close)
tr2 = abs(low - prev_close)
if np.isnan(tr0) or np.isnan(tr1) or np.isnan(tr2):
tr = np.nan
else:
tr = max(tr0, tr1, tr2)
return tr
@register_jitted(cache=True)
def tr_1d_nb(high: tp.Array1d, low: tp.Array1d, close: tp.Array1d) -> tp.Array1d:
"""True Range (TR)."""
tr = np.empty(close.shape, dtype=float_)
for i in range(close.shape[0]):
tr[i] = iter_tr_nb(high[i], low[i], close[i - 1] if i > 0 else np.nan)
return tr
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def tr_nb(high: tp.Array2d, low: tp.Array2d, close: tp.Array2d) -> tp.Array2d:
"""2-dim version of `tr_1d_nb`."""
tr = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
tr[:, col] = tr_1d_nb(high[:, col], low[:, col], close[:, col])
return tr
@register_jitted(cache=True)
def atr_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Average True Range (ATR).
Returns TR and ATR.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
tr = tr_1d_nb(high, low, close)
atr = ma_1d_nb(tr, window, wtype=wtype, minp=minp, adjust=adjust)
return tr, atr
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def atr_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""2-dim version of `atr_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
tr = np.empty(close.shape, dtype=float_)
atr = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
tr[:, col], atr[:, col] = atr_1d_nb(
high[:, col],
low[:, col],
close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return tr, atr
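# A minimal, hypothetical sketch (not part of the library): `atr_1d_nb` returns both the raw
# True Range and its smoothed average, so TR does not need to be recomputed when both are needed.
def _atr_example(high: tp.Array1d, low: tp.Array1d, close: tp.Array1d) -> tp.Array1d:
    tr, atr = atr_1d_nb(high, low, close, window=14)
    return atr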
# ############# ADX ############# #
@register_jitted(cache=True)
def adx_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Average Directional Movement Index (ADX).
Returns +DI, -DI, DX, and ADX.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
_, atr = atr_1d_nb(
high,
low,
close,
window=window,
wtype=wtype,
minp=minp,
adjust=adjust,
)
dm_plus = np.empty(close.shape, dtype=float_)
dm_minus = np.empty(close.shape, dtype=float_)
for i in range(close.shape[0]):
up_change = np.nan if i == 0 else high[i] - high[i - 1]
down_change = np.nan if i == 0 else low[i - 1] - low[i]
if up_change > down_change and up_change > 0:
dm_plus[i] = up_change
else:
dm_plus[i] = 0.0
if down_change > up_change and down_change > 0:
dm_minus[i] = down_change
else:
dm_minus[i] = 0.0
dm_plus_smoothed = ma_1d_nb(dm_plus, window, wtype=wtype, minp=minp, adjust=adjust)
dm_minus_smoothed = ma_1d_nb(dm_minus, window, wtype=wtype, minp=minp, adjust=adjust)
plus_di = 100 * dm_plus_smoothed / atr
minus_di = 100 * dm_minus_smoothed / atr
dx = 100 * np.abs(plus_di - minus_di) / (plus_di + minus_di)
adx = ma_1d_nb(dx, window, wtype=wtype, minp=minp, adjust=adjust)
return plus_di, minus_di, dx, adx
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def adx_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Wilder,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `adx_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
plus_di = np.empty(close.shape, dtype=float_)
minus_di = np.empty(close.shape, dtype=float_)
dx = np.empty(close.shape, dtype=float_)
adx = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
plus_di[:, col], minus_di[:, col], dx[:, col], adx[:, col] = adx_1d_nb(
high[:, col],
low[:, col],
close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
minp=minp,
adjust=adjust,
)
return plus_di, minus_di, dx, adx
# ############# OBV ############# #
@register_jitted(cache=True)
def obv_1d_nb(close: tp.Array1d, volume: tp.Array1d) -> tp.Array1d:
"""On-Balance Volume (OBV)."""
obv = np.empty(close.shape, dtype=float_)
cumsum = 0.0
for i in range(close.shape[0]):
prev_close = close[i - 1] if i > 0 else np.nan
if close[i] < prev_close:
value = -volume[i]
else:
value = volume[i]
if not np.isnan(value):
cumsum += value
obv[i] = cumsum
return obv
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
volume=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def obv_nb(close: tp.Array2d, volume: tp.Array2d) -> tp.Array2d:
"""2-dim version of `obv_1d_nb`."""
obv = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
obv[:, col] = obv_1d_nb(close[:, col], volume[:, col])
return obv
# ############# OLS ############# #
@register_jitted(cache=True)
def ols_1d_nb(
x: tp.Array1d,
y: tp.Array1d,
window: int = 14,
norm_window: tp.Optional[int] = None,
minp: tp.Optional[int] = None,
ddof: int = 0,
with_zscore: bool = True,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
"""Rolling Ordinary Least Squares (OLS)."""
if norm_window is not None:
norm_window_ = norm_window
else:
norm_window_ = window
slope, intercept = generic_nb.rolling_ols_1d_nb(x, y, window, minp=minp)
if with_zscore:
pred = intercept + slope * x
error = y - pred
error_mean = generic_nb.rolling_mean_1d_nb(error, norm_window_, minp=minp)
error_std = generic_nb.rolling_std_1d_nb(error, norm_window_, minp=minp, ddof=ddof)
zscore = (error - error_mean) / error_std
else:
zscore = np.full(x.shape, np.nan, dtype=float_)
return slope, intercept, zscore
@register_chunkable(
size=ch.ArraySizer(arg_query="x", axis=1),
arg_take_spec=dict(
x=ch.ArraySlicer(axis=1),
y=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
norm_window=base_ch.FlexArraySlicer(),
minp=None,
ddof=None,
with_zscore=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ols_nb(
x: tp.Array2d,
y: tp.Array2d,
window: tp.FlexArray1dLike = 14,
norm_window: tp.Optional[tp.FlexArray1dLike] = None,
minp: tp.Optional[int] = None,
ddof: int = 0,
with_zscore: bool = True,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `ols_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
if norm_window is not None:
norm_window_ = to_1d_array_nb(np.asarray(norm_window))
else:
norm_window_ = window_
slope = np.empty(x.shape, dtype=float_)
intercept = np.empty(x.shape, dtype=float_)
zscore = np.empty(x.shape, dtype=float_)
for col in prange(x.shape[1]):
slope[:, col], intercept[:, col], zscore[:, col] = ols_1d_nb(
x[:, col],
y[:, col],
window=flex_select_1d_nb(window_, col),
norm_window=flex_select_1d_nb(norm_window_, col),
minp=minp,
ddof=ddof,
with_zscore=with_zscore,
)
return slope, intercept, zscore
@register_jitted(cache=True)
def ols_pred_1d_nb(x: tp.Array1d, slope: tp.Array1d, intercept: tp.Array1d) -> tp.Array1d:
"""OLS prediction."""
return intercept + slope * x
@register_chunkable(
size=ch.ArraySizer(arg_query="x", axis=1),
arg_take_spec=dict(
x=ch.ArraySlicer(axis=1),
slope=ch.ArraySlicer(axis=1),
intercept=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ols_pred_nb(x: tp.Array2d, slope: tp.Array2d, intercept: tp.Array2d) -> tp.Array2d:
"""2-dim version of `ols_pred_1d_nb`."""
pred = np.empty(x.shape, dtype=float_)
for col in prange(x.shape[1]):
pred[:, col] = ols_pred_1d_nb(x[:, col], slope[:, col], intercept[:, col])
return pred
@register_jitted(cache=True)
def ols_error_1d_nb(y: tp.Array1d, pred: tp.Array1d) -> tp.Array1d:
"""OLS error."""
return y - pred
@register_chunkable(
size=ch.ArraySizer(arg_query="y", axis=1),
arg_take_spec=dict(
y=ch.ArraySlicer(axis=1),
pred=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ols_error_nb(y: tp.Array2d, pred: tp.Array2d) -> tp.Array2d:
"""2-dim version of `ols_error_1d_nb`."""
error = np.empty(y.shape, dtype=float_)
for col in prange(y.shape[1]):
error[:, col] = ols_error_1d_nb(y[:, col], pred[:, col])
return error
@register_jitted(cache=True)
def ols_angle_1d_nb(slope: tp.Array1d) -> tp.Array1d:
"""OLS angle."""
return np.arctan(slope) * 180 / np.pi
@register_chunkable(
size=ch.ArraySizer(arg_query="slope", axis=1),
arg_take_spec=dict(
slope=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def ols_angle_nb(slope: tp.Array2d) -> tp.Array2d:
"""2-dim version of `ols_angle_1d_nb`."""
angle = np.empty(slope.shape, dtype=float_)
for col in prange(slope.shape[1]):
angle[:, col] = ols_angle_1d_nb(slope[:, col])
return angle
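# A minimal, hypothetical sketch (not part of the library): a common use of the rolling OLS
# helpers is a spread z-score, where `y` is regressed on `x` and the residual is normalized
# over `norm_window`.
def _ols_zscore_example(x: tp.Array1d, y: tp.Array1d) -> tp.Array1d:
    slope, intercept, zscore = ols_1d_nb(x, y, window=100, norm_window=50)
    return zscore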
# ############# VWAP ############# #
@register_jitted(cache=True)
def typical_price_1d_nb(high: tp.Array1d, low: tp.Array1d, close: tp.Array1d) -> tp.Array1d:
"""Typical price."""
return (high + low + close) / 3
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def typical_price_nb(high: tp.Array2d, low: tp.Array2d, close: tp.Array2d) -> tp.Array2d:
"""2-dim version of `typical_price_1d_nb`."""
typical_price = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
typical_price[:, col] = typical_price_1d_nb(high[:, col], low[:, col], close[:, col])
return typical_price
@register_jitted(cache=True)
def vwap_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
volume: tp.Array1d,
group_lens: tp.GroupLens,
) -> tp.Array1d:
"""Volume-Weighted Average Price (VWAP)."""
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
out = np.full(volume.shape, np.nan, dtype=float_)
typical_price = typical_price_1d_nb(high, low, close)
for group in range(len(group_lens)):
from_i = group_start_idxs[group]
to_i = group_end_idxs[group]
nom_cumsum = 0
denum_cumsum = 0
for i in range(from_i, to_i):
nom_cumsum += volume[i] * typical_price[i]
denum_cumsum += volume[i]
if denum_cumsum == 0:
out[i] = np.nan
else:
out[i] = nom_cumsum / denum_cumsum
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
volume=ch.ArraySlicer(axis=1),
group_lens=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def vwap_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
volume: tp.Array2d,
group_lens: tp.GroupLens,
) -> tp.Array2d:
"""2-dim version of `vwap_1d_nb`."""
vwap = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
vwap[:, col] = vwap_1d_nb(
high[:, col],
low[:, col],
close[:, col],
volume[:, col],
group_lens,
)
return vwap
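# A minimal, hypothetical sketch (not part of the library): `vwap_1d_nb` resets its cumulative
# sums at every group boundary, so `group_lens` should hold the number of consecutive rows per
# anchor group (e.g. rows per trading day). Assumes `index` is a sorted pandas DatetimeIndex
# aligned with the price arrays.
def _vwap_example(index, high, low, close, volume):
    import pandas as pd
    group_lens = pd.Series(1, index=index).groupby(index.normalize()).sum().to_numpy()
    return vwap_1d_nb(high, low, close, volume, group_lens)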
# ############# PIVOTINFO ############# #
@register_jitted(cache=True)
def pivot_info_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
up_th: tp.FlexArray1dLike,
down_th: tp.FlexArray1dLike,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Pivot information."""
up_th_ = to_1d_array_nb(np.asarray(up_th))
down_th_ = to_1d_array_nb(np.asarray(down_th))
conf_pivot = np.empty(high.shape, dtype=int_)
conf_idx = np.empty(high.shape, dtype=int_)
last_pivot = np.empty(high.shape, dtype=int_)
last_idx = np.empty(high.shape, dtype=int_)
_conf_pivot = 0
_conf_idx = -1
_conf_value = np.nan
_last_pivot = 0
_last_idx = -1
_last_value = np.nan
first_valid_idx = -1
for i in range(high.shape[0]):
if not np.isnan(high[i]) and not np.isnan(low[i]):
if first_valid_idx == -1:
_up_th = 1 + abs(flex_select_1d_nb(up_th_, i))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, i))
if np.isnan(_up_th) or np.isnan(_down_th):
conf_pivot[i] = _conf_pivot
conf_idx[i] = _conf_idx
last_pivot[i] = _last_pivot
last_idx[i] = _last_idx
continue
first_valid_idx = i
if _last_idx == -1:
_up_th = 1 + abs(flex_select_1d_nb(up_th_, first_valid_idx))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, first_valid_idx))
if not np.isnan(_up_th) and high[i] >= low[first_valid_idx] * _up_th:
if not np.isnan(_down_th) and low[i] <= high[first_valid_idx] * _down_th:
pass # wait
else:
_conf_pivot = Pivot.Valley
_conf_idx = first_valid_idx
_conf_value = low[first_valid_idx]
_last_pivot = Pivot.Peak
_last_idx = i
_last_value = high[i]
if not np.isnan(_down_th) and low[i] <= high[first_valid_idx] * _down_th:
if not np.isnan(_up_th) and high[i] >= low[first_valid_idx] * _up_th:
pass # wait
else:
_conf_pivot = Pivot.Peak
_conf_idx = first_valid_idx
_conf_value = high[first_valid_idx]
_last_pivot = Pivot.Valley
_last_idx = i
_last_value = low[i]
else:
_up_th = 1 + abs(flex_select_1d_nb(up_th_, _last_idx))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, _last_idx))
if _last_pivot == Pivot.Valley:
if not np.isnan(_last_value) and not np.isnan(_up_th) and high[i] >= _last_value * _up_th:
_conf_pivot = _last_pivot
_conf_idx = _last_idx
_conf_value = _last_value
_last_pivot = Pivot.Peak
_last_idx = i
_last_value = high[i]
elif np.isnan(_last_value) or low[i] < _last_value:
_last_idx = i
_last_value = low[i]
elif _last_pivot == Pivot.Peak:
if not np.isnan(_last_value) and not np.isnan(_down_th) and low[i] <= _last_value * _down_th:
_conf_pivot = _last_pivot
_conf_idx = _last_idx
_conf_value = _last_value
_last_pivot = Pivot.Valley
_last_idx = i
_last_value = low[i]
elif np.isnan(_last_value) or high[i] > _last_value:
_last_idx = i
_last_value = high[i]
conf_pivot[i] = _conf_pivot
conf_idx[i] = _conf_idx
last_pivot[i] = _last_pivot
last_idx[i] = _last_idx
return conf_pivot, conf_idx, last_pivot, last_idx
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
up_th=base_ch.FlexArraySlicer(axis=1),
down_th=base_ch.FlexArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pivot_info_nb(
high: tp.Array2d,
low: tp.Array2d,
up_th: tp.FlexArray2dLike,
down_th: tp.FlexArray2dLike,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `pivot_info_1d_nb`."""
up_th_ = to_2d_array_nb(np.asarray(up_th))
down_th_ = to_2d_array_nb(np.asarray(down_th))
conf_pivot = np.empty(high.shape, dtype=int_)
conf_idx = np.empty(high.shape, dtype=int_)
last_pivot = np.empty(high.shape, dtype=int_)
last_idx = np.empty(high.shape, dtype=int_)
for col in prange(high.shape[1]):
conf_pivot[:, col], conf_idx[:, col], last_pivot[:, col], last_idx[:, col] = pivot_info_1d_nb(
high[:, col],
low[:, col],
flex_select_col_nb(up_th_, col),
flex_select_col_nb(down_th_, col),
)
return conf_pivot, conf_idx, last_pivot, last_idx
@register_jitted(cache=True)
def pivot_value_1d_nb(high: tp.Array1d, low: tp.Array1d, last_pivot: tp.Array1d, last_idx: tp.Array1d) -> tp.Array1d:
"""Pivot value."""
pivot_value = np.empty(high.shape, dtype=float_)
for i in range(high.shape[0]):
if last_pivot[i] == Pivot.Peak:
pivot_value[i] = high[last_idx[i]]
elif last_pivot[i] == Pivot.Valley:
pivot_value[i] = low[last_idx[i]]
else:
pivot_value[i] = np.nan
return pivot_value
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
last_pivot=ch.ArraySlicer(axis=1),
last_idx=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pivot_value_nb(high: tp.Array2d, low: tp.Array2d, last_pivot: tp.Array2d, last_idx: tp.Array2d) -> tp.Array2d:
"""2-dim version of `pivot_value_1d_nb`."""
pivot_value = np.empty(high.shape, dtype=float_)
for col in prange(high.shape[1]):
pivot_value[:, col] = pivot_value_1d_nb(high[:, col], low[:, col], last_pivot[:, col], last_idx[:, col])
return pivot_value
@register_jitted(cache=True)
def pivots_1d_nb(conf_pivot: tp.Array1d, conf_idx: tp.Array1d, last_pivot: tp.Array1d) -> tp.Array1d:
"""Pivots.
!!! warning
To be used in plotting. Do not use it as an indicator!"""
pivots = np.zeros(conf_pivot.shape, dtype=int_)
for i in range(conf_pivot.shape[0] - 1):
pivots[conf_idx[i]] = conf_pivot[i]
pivots[-1] = last_pivot[-1]
return pivots
@register_chunkable(
size=ch.ArraySizer(arg_query="conf_pivot", axis=1),
arg_take_spec=dict(
conf_pivot=ch.ArraySlicer(axis=1),
conf_idx=ch.ArraySlicer(axis=1),
last_pivot=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pivots_nb(conf_pivot: tp.Array2d, conf_idx: tp.Array2d, last_pivot: tp.Array2d) -> tp.Array2d:
"""2-dim version of `pivots_1d_nb`."""
pivots = np.empty(conf_pivot.shape, dtype=int_)
for col in prange(conf_pivot.shape[1]):
pivots[:, col] = pivots_1d_nb(conf_pivot[:, col], conf_idx[:, col], last_pivot[:, col])
return pivots
@register_jitted(cache=True)
def modes_1d_nb(pivots: tp.Array1d) -> tp.Array1d:
"""Modes.
!!! warning
To be used in plotting. Do not use it as an indicator!"""
modes = np.empty(pivots.shape, dtype=int_)
mode = 0
for i in range(pivots.shape[0]):
if pivots[i] != 0:
mode = -pivots[i]
modes[i] = mode
return modes
@register_chunkable(
size=ch.ArraySizer(arg_query="pivots", axis=1),
arg_take_spec=dict(
pivots=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def modes_nb(pivots: tp.Array2d) -> tp.Array2d:
"""2-dim version of `modes_1d_nb`."""
modes = np.empty(pivots.shape, dtype=int_)
for col in prange(pivots.shape[1]):
modes[:, col] = modes_1d_nb(pivots[:, col])
return modes
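# Illustrative usage sketch (not part of the original module): chaining
# `pivot_info_1d_nb`, `pivots_1d_nb`, and `modes_1d_nb` (all defined above) with
# symmetric 20% thresholds. `up_th` and `down_th` are flexible arrays, so
# one-element arrays broadcast against all bars. As the warnings above note,
# `pivots` and `modes` look ahead and are meant for plotting only.
def _example_pivots_sketch() -> None:
    high = np.array([10.0, 11.0, 13.0, 12.0, 9.0, 10.0, 14.0])
    low = np.array([9.0, 10.0, 12.0, 11.0, 8.0, 9.0, 13.0])
    conf_pivot, conf_idx, last_pivot, last_idx = pivot_info_1d_nb(
        high,
        low,
        np.array([0.2]),  # a move of +20% above the last valley confirms it
        np.array([0.2]),  # a move of -20% below the last peak confirms it
    )
    pivots = pivots_1d_nb(conf_pivot, conf_idx, last_pivot)
    modes = modes_1d_nb(pivots)
    print(pivots, modes)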
# ############# SUPERTREND ############# #
@register_jitted(cache=True)
def iter_med_price_nb(high: float, low: float) -> float:
"""Median price at one iteration."""
return (high + low) / 2
@register_jitted(cache=True)
def iter_basic_bands_nb(high: float, low: float, atr: float, multiplier: float) -> tp.Tuple[float, float]:
"""Upper and lower bands at one iteration."""
med_price = iter_med_price_nb(high, low)
matr = multiplier * atr
upper = med_price + matr
lower = med_price - matr
return upper, lower
@register_jitted(cache=True)
def final_basic_bands_nb(
close: float,
upper: float,
lower: float,
prev_upper: float,
prev_lower: float,
prev_direction: int,
) -> tp.Tuple[float, float, float, int, float, float]:
"""Final bands at one iteration."""
if close > prev_upper:
direction = 1
elif close < prev_lower:
direction = -1
else:
direction = prev_direction
if direction > 0 and lower < prev_lower:
lower = prev_lower
if direction < 0 and upper > prev_upper:
upper = prev_upper
if direction > 0:
trend = long = lower
short = np.nan
else:
trend = short = upper
long = np.nan
return upper, lower, trend, direction, long, short
@register_jitted(cache=True)
def supertrend_acc_nb(in_state: SuperTrendAIS) -> SuperTrendAOS:
"""Accumulator of `supertrend_nb`.
Takes a state of type `vectorbtpro.indicators.enums.SuperTrendAIS` and returns
a state of type `vectorbtpro.indicators.enums.SuperTrendAOS`."""
i = in_state.i
high = in_state.high
low = in_state.low
close = in_state.close
prev_close = in_state.prev_close
prev_upper = in_state.prev_upper
prev_lower = in_state.prev_lower
prev_direction = in_state.prev_direction
nobs = in_state.nobs
weighted_avg = in_state.weighted_avg
old_wt = in_state.old_wt
period = in_state.period
multiplier = in_state.multiplier
tr = iter_tr_nb(high, low, prev_close)
alpha = generic_nb.alpha_from_wilder_nb(period)
ewm_mean_in_state = generic_enums.EWMMeanAIS(
i=i,
value=tr,
old_wt=old_wt,
weighted_avg=weighted_avg,
nobs=nobs,
alpha=alpha,
minp=period,
adjust=False,
)
ewm_mean_out_state = generic_nb.ewm_mean_acc_nb(ewm_mean_in_state)
atr = ewm_mean_out_state.value
upper, lower = iter_basic_bands_nb(high, low, atr, multiplier)
if i == 0:
trend, direction, long, short = np.nan, 1, np.nan, np.nan
else:
upper, lower, trend, direction, long, short = final_basic_bands_nb(
close,
upper,
lower,
prev_upper,
prev_lower,
prev_direction,
)
return SuperTrendAOS(
nobs=ewm_mean_out_state.nobs,
weighted_avg=ewm_mean_out_state.weighted_avg,
old_wt=ewm_mean_out_state.old_wt,
upper=upper,
lower=lower,
trend=trend,
direction=direction,
long=long,
short=short,
)
@register_jitted(cache=True)
def supertrend_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
close: tp.Array1d,
period: int = 7,
multiplier: float = 3.0,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Supertrend."""
trend = np.empty(close.shape, dtype=float_)
direction = np.empty(close.shape, dtype=int_)
long = np.empty(close.shape, dtype=float_)
short = np.empty(close.shape, dtype=float_)
if close.shape[0] == 0:
return trend, direction, long, short
nobs = 0
old_wt = 1.0
weighted_avg = np.nan
prev_upper = np.nan
prev_lower = np.nan
for i in range(close.shape[0]):
in_state = SuperTrendAIS(
i=i,
high=high[i],
low=low[i],
close=close[i],
prev_close=close[i - 1] if i > 0 else np.nan,
prev_upper=prev_upper,
prev_lower=prev_lower,
prev_direction=direction[i - 1] if i > 0 else 1,
nobs=nobs,
weighted_avg=weighted_avg,
old_wt=old_wt,
period=period,
multiplier=multiplier,
)
out_state = supertrend_acc_nb(in_state)
nobs = out_state.nobs
weighted_avg = out_state.weighted_avg
old_wt = out_state.old_wt
prev_upper = out_state.upper
prev_lower = out_state.lower
trend[i] = out_state.trend
direction[i] = out_state.direction
long[i] = out_state.long
short[i] = out_state.short
return trend, direction, long, short
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
close=ch.ArraySlicer(axis=1),
period=base_ch.FlexArraySlicer(),
multiplier=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def supertrend_nb(
high: tp.Array2d,
low: tp.Array2d,
close: tp.Array2d,
period: tp.FlexArray1dLike = 7,
multiplier: tp.FlexArray1dLike = 3.0,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `supertrend_1d_nb`."""
period_ = to_1d_array_nb(np.asarray(period))
multiplier_ = to_1d_array_nb(np.asarray(multiplier))
trend = np.empty(close.shape, dtype=float_)
direction = np.empty(close.shape, dtype=int_)
long = np.empty(close.shape, dtype=float_)
short = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
trend[:, col], direction[:, col], long[:, col], short[:, col] = supertrend_1d_nb(
high[:, col],
low[:, col],
close[:, col],
period=flex_select_1d_nb(period_, col),
multiplier=flex_select_1d_nb(multiplier_, col),
)
return trend, direction, long, short
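# Illustrative usage sketch (not part of the original module): Supertrend on a
# single synthetic price series via `supertrend_1d_nb`, defined above.
# `direction` is +1 in an uptrend and -1 in a downtrend, while `long` and
# `short` hold the active stop line of the corresponding side (NaN otherwise).
def _example_supertrend_sketch() -> None:
    rng = np.random.default_rng(42)
    close = 100.0 + np.cumsum(rng.normal(0.0, 1.0, 200))
    high = close + rng.uniform(0.0, 1.0, 200)
    low = close - rng.uniform(0.0, 1.0, 200)
    trend, direction, long_, short_ = supertrend_1d_nb(high, low, close, period=7, multiplier=3.0)
    print(direction[-5:])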
# ############# SIGDET ############# #
@register_jitted(cache=True)
def signal_detection_1d_nb(
close: tp.Array1d,
lag: int = 14,
factor: tp.FlexArray1dLike = 1.0,
influence: tp.FlexArray1dLike = 1.0,
up_factor: tp.Optional[tp.FlexArray1dLike] = None,
down_factor: tp.Optional[tp.FlexArray1dLike] = None,
mean_influence: tp.Optional[tp.FlexArray1dLike] = None,
std_influence: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d]:
"""Signal detection."""
factor_ = to_1d_array_nb(np.asarray(factor))
influence_ = to_1d_array_nb(np.asarray(influence))
if up_factor is not None:
up_factor_ = to_1d_array_nb(np.asarray(up_factor))
else:
up_factor_ = factor_
if down_factor is not None:
down_factor_ = to_1d_array_nb(np.asarray(down_factor))
else:
down_factor_ = factor_
if mean_influence is not None:
mean_influence_ = to_1d_array_nb(np.asarray(mean_influence))
else:
mean_influence_ = influence_
if std_influence is not None:
std_influence_ = to_1d_array_nb(np.asarray(std_influence))
else:
std_influence_ = influence_
signal = np.full(close.shape, 0, dtype=int_)
close_mean_filter = close.astype(float_)
close_std_filter = close.astype(float_)
mean_filter = np.full(close.shape, np.nan, dtype=float_)
std_filter = np.full(close.shape, np.nan, dtype=float_)
upper_band = np.full(close.shape, np.nan, dtype=float_)
lower_band = np.full(close.shape, np.nan, dtype=float_)
if lag == 0:
raise ValueError("Lag cannot be zero")
if lag - 1 >= close.shape[0]:
raise ValueError("Lag must be smaller than close")
mean_filter[lag - 1] = np.nanmean(close[:lag])
std_filter[lag - 1] = np.nanstd(close[:lag])
for i in range(lag, close.shape[0]):
_up_factor = abs(flex_select_1d_nb(up_factor_, i))
_down_factor = abs(flex_select_1d_nb(down_factor_, i))
_mean_influence = abs(flex_select_1d_nb(mean_influence_, i))
_std_influence = abs(flex_select_1d_nb(std_influence_, i))
up_crossed = close[i] - mean_filter[i - 1] >= _up_factor * std_filter[i - 1]
down_crossed = close[i] - mean_filter[i - 1] <= -_down_factor * std_filter[i - 1]
if up_crossed or down_crossed:
if up_crossed:
signal[i] = 1
else:
signal[i] = -1
close_mean_filter[i] = _mean_influence * close[i] + (1 - _mean_influence) * close_mean_filter[i - 1]
close_std_filter[i] = _std_influence * close[i] + (1 - _std_influence) * close_std_filter[i - 1]
else:
signal[i] = 0
close_mean_filter[i] = close[i]
close_std_filter[i] = close[i]
mean_filter[i] = np.nanmean(close_mean_filter[(i - lag + 1) : i + 1])
std_filter[i] = np.nanstd(close_std_filter[(i - lag + 1) : i + 1])
upper_band[i] = mean_filter[i] + _up_factor * std_filter[i - 1]
lower_band[i] = mean_filter[i] - _down_factor * std_filter[i - 1]
return signal, upper_band, lower_band
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
lag=base_ch.FlexArraySlicer(),
factor=base_ch.FlexArraySlicer(axis=1),
influence=base_ch.FlexArraySlicer(axis=1),
up_factor=base_ch.FlexArraySlicer(axis=1),
down_factor=base_ch.FlexArraySlicer(axis=1),
mean_influence=base_ch.FlexArraySlicer(axis=1),
std_influence=base_ch.FlexArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def signal_detection_nb(
close: tp.Array2d,
lag: tp.FlexArray1dLike = 14,
factor: tp.FlexArray2dLike = 1.0,
influence: tp.FlexArray2dLike = 1.0,
up_factor: tp.Optional[tp.FlexArray2dLike] = None,
down_factor: tp.Optional[tp.FlexArray2dLike] = None,
mean_influence: tp.Optional[tp.FlexArray2dLike] = None,
std_influence: tp.Optional[tp.FlexArray2dLike] = None,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `signal_detection_1d_nb`."""
lag_ = to_1d_array_nb(np.asarray(lag))
factor_ = to_2d_array_nb(np.asarray(factor))
influence_ = to_2d_array_nb(np.asarray(influence))
if up_factor is not None:
up_factor_ = to_2d_array_nb(np.asarray(up_factor))
else:
up_factor_ = factor_
if down_factor is not None:
down_factor_ = to_2d_array_nb(np.asarray(down_factor))
else:
down_factor_ = factor_
if mean_influence is not None:
mean_influence_ = to_2d_array_nb(np.asarray(mean_influence))
else:
mean_influence_ = influence_
if std_influence is not None:
std_influence_ = to_2d_array_nb(np.asarray(std_influence))
else:
std_influence_ = influence_
signal = np.empty(close.shape, dtype=int_)
upper_band = np.empty(close.shape, dtype=float_)
lower_band = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
signal[:, col], upper_band[:, col], lower_band[:, col] = signal_detection_1d_nb(
close[:, col],
lag=flex_select_1d_nb(lag_, col),
factor=flex_select_col_nb(factor_, col),
influence=flex_select_col_nb(influence_, col),
up_factor=flex_select_col_nb(up_factor_, col),
down_factor=flex_select_col_nb(down_factor_, col),
mean_influence=flex_select_col_nb(mean_influence_, col),
std_influence=flex_select_col_nb(std_influence_, col),
)
return signal, upper_band, lower_band
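# Illustrative usage sketch (not part of the original module): burst detection
# via `signal_detection_1d_nb`, defined above. A bar is flagged +1/-1 when it
# deviates from the rolling mean of the last `lag` bars by more than `factor`
# rolling standard deviations; `influence` controls how much flagged bars feed
# back into the filters. The series below is synthetic.
def _example_signal_detection_sketch() -> None:
    rng = np.random.default_rng(0)
    close = np.concatenate([
        rng.normal(0.0, 1.0, 50),
        rng.normal(5.0, 1.0, 10),  # an artificial burst that should be flagged
        rng.normal(0.0, 1.0, 50),
    ])
    signal, upper_band, lower_band = signal_detection_1d_nb(
        close,
        lag=14,
        factor=np.array([3.0]),
        influence=np.array([0.5]),
    )
    print(np.flatnonzero(signal == 1))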
# ############# HURST ############# #
@register_jitted(cache=True)
def get_standard_hurst_nb(
close: tp.Array1d,
max_lag: int = 20,
stabilize: bool = False,
) -> float:
"""Estimate the Hurst exponent using standard method."""
if max_lag is None:
lags = np.arange(2, len(close) - 1)
else:
lags = np.arange(2, min(max_lag, len(close) - 1))
tau = np.empty(len(lags), dtype=float_)
for i, lag in enumerate(lags):
tau[i] = np.var(np.subtract(close[lag:], close[:-lag]))
coef = generic_nb.polyfit_1d_nb(np.log(lags), np.log(tau), 1, stabilize=stabilize)
return coef[0] / 2
@register_jitted(cache=True)
def get_rs_nb(close: tp.Array1d) -> float:
"""Get rescaled range (R/S) for Hurst exponent estimation."""
incs = close[1:] / close[:-1] - 1.0
mean_inc = np.sum(incs) / len(incs)
deviations = incs - mean_inc
Z = np.cumsum(deviations)
R = np.max(Z) - np.min(Z)
S = generic_nb.nanstd_1d_nb(incs, ddof=1)
if R == 0 or S == 0:
return 0
return R / S
@register_jitted(cache=True)
def get_log_rs_hurst_nb(
close: tp.Array1d,
min_log: int = 1,
max_log: int = 2,
    log_step: float = 0.25,
) -> float:
"""Estimate the Hurst exponent using R/S method.
Windows are log-distributed."""
max_log = min(max_log, np.log10(len(close) - 1))
log_range = np.arange(min_log, max_log, log_step)
windows = np.empty(len(log_range) + 1, dtype=int_)
windows[: len(log_range)] = 10**log_range
windows[-1] = len(close)
RS = np.empty(len(windows), dtype=float_)
W = np.empty(len(windows), dtype=int_)
k = 0
for i, w in enumerate(windows):
rs_sum = 0.0
rs_count = 0
for start in range(0, len(close), w):
if (start + w) > len(close):
break
rs = get_rs_nb(close[start : start + w])
if rs != 0:
rs_sum += rs
rs_count += 1
if rs_count != 0:
RS[k] = rs_sum / rs_count
W[k] = w
k += 1
if k == 0:
return np.nan
A = np.vstack((np.log10(W[:k]), np.ones(len(RS[:k])))).T
H, c = np.linalg.lstsq(A, np.log10(RS[:k]), rcond=-1)[0]
return H
@register_jitted(cache=True)
def get_rs_hurst_nb(
close: tp.Array1d,
min_chunk: int = 8,
max_chunk: int = 100,
num_chunks: int = 5,
) -> float:
"""Estimate the Hurst exponent using R/S method.
Windows are linearly distributed."""
diff = close[1:] - close[:-1]
N = len(diff)
max_chunk += 1
max_chunk = min(max_chunk, len(diff) - 1)
rs_tmp = np.empty(N, dtype=float_)
chunk_size_range = np.linspace(min_chunk, max_chunk, num_chunks).astype(int_)
chunk_size_list = np.empty(len(chunk_size_range), dtype=int_)
rs_values_list = np.empty(len(chunk_size_range), dtype=float_)
k = 0
for chunk_size in chunk_size_range:
number_of_chunks = int(len(diff) / chunk_size)
for idx in range(number_of_chunks):
ini = idx * chunk_size
end = ini + chunk_size
chunk = diff[ini:end]
z = np.cumsum(chunk - np.mean(chunk))
rs_tmp[idx] = np.divide(np.max(z) - np.min(z), np.nanstd(chunk))
rs = np.nanmean(rs_tmp[: idx + 1])
if not np.isnan(rs) and rs != 0:
chunk_size_list[k] = chunk_size
rs_values_list[k] = rs
k += 1
H, c = np.linalg.lstsq(
a=np.vstack((np.log(chunk_size_list[:k]), np.ones(len(chunk_size_list[:k])))).T,
b=np.log(rs_values_list[:k]),
rcond=-1,
)[0]
return H
@register_jitted(cache=True)
def get_dma_hurst_nb(
close: tp.Array1d,
min_chunk: int = 8,
max_chunk: int = 100,
num_chunks: int = 5,
) -> float:
"""Estimate the Hurst exponent using DMA method.
Windows are linearly distributed."""
max_chunk += 1
max_chunk = min(max_chunk, len(close) - 1)
N = len(close)
n_range = np.linspace(min_chunk, max_chunk, num_chunks).astype(int_)
n_list = np.empty(len(n_range), dtype=int_)
dma_list = np.empty(len(n_range), dtype=float_)
k = 0
factor = 1 / (N - max_chunk)
for i, n in enumerate(n_range):
x1 = np.full(n, -1, int_)
x1[0] = n - 1
b = np.divide(x1, n) # do the same as: y - y_ma_n
noise = np.power(generic_nb.fir_filter_1d_nb(b, close)[max_chunk:], 2)
dma = np.sqrt(factor * np.sum(noise))
if not np.isnan(dma) and dma != 0:
n_list[k] = n
dma_list[k] = dma
k += 1
if k == 0:
return np.nan
H, const = np.linalg.lstsq(
a=np.vstack((np.log10(n_list[:k]), np.ones(len(n_list[:k])))).T,
b=np.log10(dma_list[:k]),
rcond=-1,
)[0]
return H
@register_jitted(cache=True)
def get_dsod_hurst_nb(close: tp.Array1d) -> float:
"""Estimate the Hurst exponent using discrete second order derivative."""
diff = close[1:] - close[:-1]
y = np.cumsum(diff)
b1 = [1, -2, 1]
y1 = generic_nb.fir_filter_1d_nb(b1, y)
y1 = y1[len(b1) - 1 :]
b2 = [1, 0, -2, 0, 1]
y2 = generic_nb.fir_filter_1d_nb(b2, y)
y2 = y2[len(b2) - 1 :]
s1 = np.mean(y1**2)
s2 = np.mean(y2**2)
return 0.5 * np.log2(s2 / s1)
@register_jitted(cache=True)
def get_hurst_nb(
close: tp.Array1d,
method: int = HurstMethod.Standard,
max_lag: int = 20,
min_log: int = 1,
max_log: int = 2,
    log_step: float = 0.25,
min_chunk: int = 8,
max_chunk: int = 100,
num_chunks: int = 5,
stabilize: bool = False,
) -> float:
"""Estimate the Hurst exponent using various methods.
Uses the following methods:
* `HurstMethod.Standard`: `vectorbtpro.indicators.nb.get_standard_hurst_nb`
* `HurstMethod.LogRS`: `vectorbtpro.indicators.nb.get_log_rs_hurst_nb`
* `HurstMethod.RS`: `vectorbtpro.indicators.nb.get_rs_hurst_nb`
* `HurstMethod.DMA`: `vectorbtpro.indicators.nb.get_dma_hurst_nb`
* `HurstMethod.DSOD`: `vectorbtpro.indicators.nb.get_dsod_hurst_nb`
"""
if method == HurstMethod.Standard:
return get_standard_hurst_nb(close, max_lag=max_lag, stabilize=stabilize)
if method == HurstMethod.LogRS:
return get_log_rs_hurst_nb(close, min_log=min_log, max_log=max_log, log_step=log_step)
if method == HurstMethod.RS:
return get_rs_hurst_nb(close, min_chunk=min_chunk, max_chunk=max_chunk, num_chunks=num_chunks)
if method == HurstMethod.DMA:
return get_dma_hurst_nb(close, min_chunk=min_chunk, max_chunk=max_chunk, num_chunks=num_chunks)
if method == HurstMethod.DSOD:
return get_dsod_hurst_nb(close)
raise ValueError("Invalid HurstMethod option")
@register_jitted(cache=True)
def rolling_hurst_1d_nb(
close: tp.Array1d,
window: int,
method: int = HurstMethod.Standard,
max_lag: int = 20,
min_log: int = 1,
max_log: int = 2,
    log_step: float = 0.25,
min_chunk: int = 8,
max_chunk: int = 100,
num_chunks: int = 5,
minp: tp.Optional[int] = None,
stabilize: bool = False,
) -> tp.Array1d:
"""Rolling version of `get_hurst_nb`.
For `method`, see `vectorbtpro.indicators.enums.HurstMethod`."""
if minp is None:
minp = window
if minp > window:
raise ValueError("minp must be <= window")
out = np.empty_like(close, dtype=float_)
nancnt = 0
for i in range(close.shape[0]):
if np.isnan(close[i]):
nancnt = nancnt + 1
if i < window:
valid_cnt = i + 1 - nancnt
else:
if np.isnan(close[i - window]):
nancnt = nancnt - 1
valid_cnt = window - nancnt
if valid_cnt < minp:
out[i] = np.nan
else:
from_i = max(0, i + 1 - window)
to_i = i + 1
close_window = close[from_i:to_i]
out[i] = get_hurst_nb(
close_window,
method=method,
max_lag=max_lag,
min_log=min_log,
max_log=max_log,
log_step=log_step,
min_chunk=min_chunk,
max_chunk=max_chunk,
num_chunks=num_chunks,
stabilize=stabilize,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=None,
method=None,
max_lag=None,
min_log=None,
max_log=None,
log_step=None,
min_chunk=None,
max_chunk=None,
num_chunks=None,
minp=None,
stabilize=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_hurst_nb(
close: tp.Array2d,
window: int,
method: int = HurstMethod.Standard,
max_lag: int = 20,
min_log: int = 1,
max_log: int = 2,
    log_step: float = 0.25,
min_chunk: int = 8,
max_chunk: int = 100,
num_chunks: int = 5,
minp: tp.Optional[int] = None,
stabilize: bool = False,
) -> tp.Array2d:
"""2-dim version of `rolling_hurst_1d_nb`."""
out = np.empty_like(close, dtype=float_)
for col in prange(close.shape[1]):
out[:, col] = rolling_hurst_1d_nb(
close[:, col],
window,
method=method,
max_lag=max_lag,
min_log=min_log,
max_log=max_log,
log_step=log_step,
min_chunk=min_chunk,
max_chunk=max_chunk,
num_chunks=num_chunks,
minp=minp,
stabilize=stabilize,
)
return out
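# Illustrative usage sketch (not part of the original module): a rolling Hurst
# estimate over a 200-bar window using the R/S method. Bars with fewer than
# `minp` valid observations (here the full window) produce NaN.
def _example_rolling_hurst_sketch() -> None:
    rng = np.random.default_rng(7)
    close = 1000.0 + np.cumsum(rng.normal(0.0, 1.0, 1000))
    rolling_h = rolling_hurst_1d_nb(close, 200, method=HurstMethod.RS)
    print(rolling_h[-5:])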
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Helper functions for TA-Lib."""
import inspect
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.merging import column_stack_arrays
from vectorbtpro.base.reshaping import to_pd_array, broadcast_arrays, broadcast
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.accessors import GenericAccessor
from vectorbtpro.utils.array_ import build_nan_mask, squeeze_nan, unsqueeze_nan
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import merge_dicts, resolve_dict
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"talib_func",
"talib_plot_func",
]
def talib_func(func_name: str) -> tp.Callable:
"""Get the TA-Lib indicator function."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("talib")
import talib
from talib import abstract
func_name = func_name.upper()
talib_func = getattr(talib, func_name)
info = abstract.Function(func_name).info
input_names = []
for in_names in info["input_names"].values():
if isinstance(in_names, (list, tuple)):
input_names.extend(list(in_names))
else:
input_names.append(in_names)
output_names = info["output_names"]
one_output = len(output_names) == 1
param_names = list(info["parameters"].keys())
def run_talib_func(
*args,
timeframe: tp.Optional[tp.FrequencyLike] = None,
resample_map: tp.KwargsLike = None,
resample_kwargs: tp.KwargsLikeSequence = None,
realign_kwargs: tp.KwargsLikeSequence = None,
wrapper: tp.Optional[ArrayWrapper] = None,
skipna: bool = False,
silence_warnings: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
wrap: tp.Optional[bool] = None,
unpack_to: tp.Optional[str] = None,
**kwargs,
) -> tp.Union[tp.MaybeTuple[tp.AnyArray], tp.Dict[str, tp.AnyArray]]:
if broadcast_kwargs is None:
broadcast_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
inputs = []
other_args = []
for k in range(len(args)):
if k < len(input_names) and len(inputs) < len(input_names):
inputs.append(args[k])
else:
other_args.append(args[k])
if len(inputs) < len(input_names):
for k in input_names:
if k in kwargs:
inputs.append(kwargs.pop(k))
is_pandas = False
common_type = None
common_shape = None
broadcasting_needed = False
new_inputs = []
for input in inputs:
if isinstance(input, (pd.Series, pd.DataFrame)):
is_pandas = True
elif not isinstance(input, np.ndarray):
input = np.asarray(input)
if common_type is None:
common_type = type(input)
elif type(input) != common_type:
broadcasting_needed = True
if common_shape is None:
common_shape = input.shape
elif input.shape != common_shape:
broadcasting_needed = True
new_inputs.append(input)
inputs = new_inputs
if broadcasting_needed:
if is_pandas:
if wrapper is None:
inputs, wrapper = broadcast(
dict(zip(input_names, inputs)),
return_wrapper=True,
**broadcast_kwargs,
)
else:
inputs = broadcast(dict(zip(input_names, inputs)), **broadcast_kwargs)
inputs = [inputs[k].values for k in input_names]
else:
inputs = broadcast_arrays(*inputs)
else:
if is_pandas:
if wrapper is None:
wrapper = ArrayWrapper.from_obj(inputs[0])
inputs = [input.values for input in inputs]
input_shape = inputs[0].shape
def _run_talib_func(inputs, *_args, **_kwargs):
target_index = None
if timeframe is not None:
if wrapper is None:
raise ValueError("Resampling requires a wrapper")
if wrapper.freq is None:
if not silence_warnings:
warn(
"Couldn't parse the frequency of index. "
"Set freq in wrapper_kwargs via broadcast_kwargs, or globally."
)
new_inputs = ()
_resample_map = merge_dicts(
resample_map,
{
"open": "first",
"high": "max",
"low": "min",
"close": "last",
"volume": "sum",
},
)
source_wrapper = ArrayWrapper(index=wrapper.index, freq=wrapper.freq)
for i, input in enumerate(inputs):
_resample_kwargs = resolve_dict(resample_kwargs, i=i)
new_input = GenericAccessor(source_wrapper, input).resample_apply(
timeframe,
_resample_map[input_names[i]],
**_resample_kwargs,
)
target_index = new_input.index
new_inputs += (new_input.values,)
inputs = new_inputs
def _build_nan_outputs():
nan_outputs = []
for i in range(len(output_names)):
nan_outputs.append(np.full(input_shape, np.nan, dtype=np.double))
if len(nan_outputs) == 1:
return nan_outputs[0]
return nan_outputs
all_nan = False
if skipna:
nan_mask = build_nan_mask(*inputs)
if nan_mask.all():
all_nan = True
else:
inputs = squeeze_nan(*inputs, nan_mask=nan_mask)
else:
nan_mask = None
if all_nan:
outputs = _build_nan_outputs()
else:
inputs = tuple([arr.astype(np.double) for arr in inputs])
try:
outputs = talib_func(*inputs, *_args, **_kwargs)
except Exception as e:
if "inputs are all NaN" in str(e):
outputs = _build_nan_outputs()
all_nan = True
else:
raise e
if not all_nan:
if one_output:
outputs = unsqueeze_nan(outputs, nan_mask=nan_mask)
else:
outputs = unsqueeze_nan(*outputs, nan_mask=nan_mask)
if timeframe is not None:
new_outputs = ()
target_wrapper = ArrayWrapper(index=target_index)
for i, output in enumerate(outputs):
_realign_kwargs = merge_dicts(
dict(
source_rbound=True,
target_rbound=True,
nan_value=np.nan,
ffill=True,
silence_warnings=True,
),
resolve_dict(realign_kwargs, i=i),
)
new_output = GenericAccessor(target_wrapper, output).realign(
wrapper.index,
freq=wrapper.freq,
**_realign_kwargs,
)
new_outputs += (new_output.values,)
outputs = new_outputs
return outputs
if inputs[0].ndim == 1:
outputs = _run_talib_func(inputs, *other_args, **kwargs)
else:
outputs = []
for col in range(inputs[0].shape[1]):
col_inputs = [input[:, col] for input in inputs]
col_outputs = _run_talib_func(col_inputs, *other_args, **kwargs)
outputs.append(col_outputs)
outputs = list(zip(*outputs))
outputs = tuple(map(column_stack_arrays, outputs))
if wrap is None:
wrap = is_pandas
if wrap:
outputs = [wrapper.wrap(output, **wrap_kwargs) for output in outputs]
if unpack_to is not None:
if unpack_to.lower() in ("dict", "frame"):
dct = {name: outputs[i] for i, name in enumerate(output_names)}
if unpack_to.lower() == "dict":
return dct
return pd.concat(list(dct.values()), axis=1, keys=pd.Index(list(dct.keys()), name="output"))
raise ValueError(f"Invalid unpack_to: '{unpack_to}'")
if one_output:
return outputs[0]
return outputs
signature = inspect.signature(run_talib_func)
new_parameters = list(signature.parameters.values())[1:]
k = 0
for input_name in input_names:
new_parameters.insert(
k,
inspect.Parameter(
input_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=tp.ArrayLike,
),
)
k += 1
for param_name in param_names:
new_parameters.insert(
k,
inspect.Parameter(
param_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=info["parameters"][param_name],
annotation=tp.Scalar,
),
)
k += 1
run_talib_func.__signature__ = signature.replace(parameters=new_parameters)
run_talib_func.__name__ = "run_" + func_name.lower()
run_talib_func.__qualname__ = run_talib_func.__name__
run_talib_func.__doc__ = f"""Run `talib.{func_name}` on NumPy arrays, Series, and DataFrames.
Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed.
Set `timeframe` to a frequency to resample the input arrays to this frequency, run the function,
and then resample the output arrays back to the original frequency. Optionally, provide `resample_map`
as a dictionary that maps input names to resample-apply function names. Keyword arguments
`resample_kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.resample_apply`
while `realign_kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.realign`.
    Both can also be provided as sequences of dictionaries - one dictionary per input and output respectively.
Set `skipna` to True to run the TA-Lib function on non-NA values only.
Broadcasts the input arrays if they have different types or shapes.
If one of the input arrays is a Series/DataFrame, wraps the output arrays into a Pandas format.
To enable or disable wrapping, set `wrap` to True and False respectively."""
return run_talib_func
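# Illustrative usage sketch (not part of the original module): building and
# calling a wrapper for a TA-Lib indicator with `talib_func`, defined above.
# Requires the `talib` package; the indicator name "sma" and its `timeperiod`
# parameter come from TA-Lib itself.
def _example_talib_func_sketch() -> None:
    run_sma = talib_func("sma")
    close = pd.Series(
        np.arange(1.0, 101.0),
        index=pd.date_range("2020-01-01", periods=100),
    )
    sma = run_sma(close, timeperiod=14)  # returned as a Series since the input is a Series
    print(sma.tail())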
def talib_plot_func(func_name: str) -> tp.Callable:
"""Get the TA-Lib indicator plotting function."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("talib")
from talib import abstract
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
func_name = func_name.upper()
info = abstract.Function(func_name).info
output_names = info["output_names"]
output_flags = info["output_flags"]
def run_talib_plot_func(
*outputs,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
column: tp.Optional[tp.Label] = None,
limits: tp.Optional[tp.Tuple[float, float]] = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
if wrap_kwargs is None:
wrap_kwargs = {}
new_outputs = []
for output in outputs:
if not isinstance(output, (pd.Series, pd.DataFrame)):
if wrapper is not None:
output = wrapper.wrap(output, **wrap_kwargs)
else:
output = to_pd_array(output)
if wrapper is not None:
output = Wrapping.select_col_from_obj(
output,
column=column,
wrapper=wrapper,
)
else:
output = Wrapping.select_col_from_obj(
output,
column=column,
wrapper=ArrayWrapper.from_obj(output),
)
new_outputs.append(output)
outputs = dict(zip(output_names, new_outputs))
output_trace_kwargs = {}
for output_name in output_names:
output_trace_kwargs[output_name] = kwargs.pop(output_name + "_trace_kwargs", {})
priority_outputs = []
other_outputs = []
for output_name in output_names:
flags = set(output_flags.get(output_name))
found_priority = False
if abstract.TA_OUTPUT_FLAGS[2048] in flags:
priority_outputs = priority_outputs + [output_name]
found_priority = True
if abstract.TA_OUTPUT_FLAGS[4096] in flags:
priority_outputs = [output_name] + priority_outputs
found_priority = True
if not found_priority:
other_outputs.append(output_name)
for output_name in priority_outputs + other_outputs:
output = outputs[output_name].rename(output_name)
flags = set(output_flags.get(output_name))
trace_kwargs = {}
plot_func_name = "lineplot"
if abstract.TA_OUTPUT_FLAGS[2] in flags:
# Dotted Line
if "line" not in trace_kwargs:
trace_kwargs["line"] = dict()
trace_kwargs["line"]["dash"] = "dashdot"
if abstract.TA_OUTPUT_FLAGS[4] in flags:
# Dashed Line
if "line" not in trace_kwargs:
trace_kwargs["line"] = dict()
trace_kwargs["line"]["dash"] = "dash"
if abstract.TA_OUTPUT_FLAGS[8] in flags:
# Dot
if "line" not in trace_kwargs:
trace_kwargs["line"] = dict()
trace_kwargs["line"]["dash"] = "dot"
if abstract.TA_OUTPUT_FLAGS[16] in flags:
# Histogram
hist = np.asarray(output)
hist_diff = generic_nb.diff_1d_nb(hist)
marker_colors = np.full(hist.shape, adjust_opacity("silver", 0.75), dtype=object)
marker_colors[(hist > 0) & (hist_diff > 0)] = adjust_opacity("green", 0.75)
marker_colors[(hist > 0) & (hist_diff <= 0)] = adjust_opacity("lightgreen", 0.75)
marker_colors[(hist < 0) & (hist_diff < 0)] = adjust_opacity("red", 0.75)
marker_colors[(hist < 0) & (hist_diff >= 0)] = adjust_opacity("lightcoral", 0.75)
if "marker" not in trace_kwargs:
trace_kwargs["marker"] = {}
trace_kwargs["marker"]["color"] = marker_colors
if "line" not in trace_kwargs["marker"]:
trace_kwargs["marker"]["line"] = {}
trace_kwargs["marker"]["line"]["width"] = 0
kwargs["bargap"] = 0
plot_func_name = "barplot"
if abstract.TA_OUTPUT_FLAGS[2048] in flags:
# Values represent an upper limit
if "line" not in trace_kwargs:
trace_kwargs["line"] = {}
trace_kwargs["line"]["color"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75)
trace_kwargs["fill"] = "tonexty"
trace_kwargs["fillcolor"] = "rgba(128, 128, 128, 0.2)"
if abstract.TA_OUTPUT_FLAGS[4096] in flags:
# Values represent a lower limit
if "line" not in trace_kwargs:
trace_kwargs["line"] = {}
trace_kwargs["line"]["color"] = adjust_opacity(plotting_cfg["color_schema"]["gray"], 0.75)
trace_kwargs = merge_dicts(trace_kwargs, output_trace_kwargs[output_name])
plot_func = getattr(output.vbt, plot_func_name)
fig = plot_func(trace_kwargs=trace_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig, **kwargs)
if limits is not None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
add_shape_kwargs = merge_dicts(
dict(
type="rect",
xref=xref,
yref=yref,
x0=outputs[output_names[0]].index[0],
y0=limits[0],
x1=outputs[output_names[0]].index[-1],
y1=limits[1],
fillcolor="mediumslateblue",
opacity=0.2,
layer="below",
line_width=0,
),
add_shape_kwargs,
)
fig.add_shape(**add_shape_kwargs)
return fig
signature = inspect.signature(run_talib_plot_func)
new_parameters = list(signature.parameters.values())[1:-1]
k = 0
for output_name in output_names:
new_parameters.insert(
k,
inspect.Parameter(
output_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=tp.ArrayLike,
),
)
k += 1
for output_name in output_names:
new_parameters.insert(
-3,
inspect.Parameter(
output_name + "_trace_kwargs",
inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=tp.KwargsLike,
),
)
new_parameters.append(inspect.Parameter("layout_kwargs", inspect.Parameter.VAR_KEYWORD))
run_talib_plot_func.__signature__ = signature.replace(parameters=new_parameters)
output_trace_kwargs_docstring = "\n ".join(
[
f"{output_name}_trace_kwargs (dict): Keyword arguments passed to the trace of `{output_name}`."
for output_name in output_names
]
)
run_talib_plot_func.__name__ = "plot_" + func_name.lower()
run_talib_plot_func.__qualname__ = run_talib_plot_func.__name__
run_talib_plot_func.__doc__ = f"""Plot output arrays of `talib.{func_name}`.
Args:
column (str): Name of the column to plot.
limits (tuple of float): Tuple of the lower and upper limit.
{output_trace_kwargs_docstring}
add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` when adding the range between both limits.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add the traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`."""
return run_talib_plot_func
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Basic look-ahead indicators and label generators.
You can access all the indicators by `vbt.*`.
Run the following to get the data used in the examples:
```pycon
>>> ohlcv = vbt.YFData.pull(
... "BTC-USD",
... start="2019-03-01",
... end="2019-09-01"
... ).get()
```
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.labels.generators.bolb import *
from vectorbtpro.labels.generators.fixlb import *
from vectorbtpro.labels.generators.fmax import *
from vectorbtpro.labels.generators.fmean import *
from vectorbtpro.labels.generators.fmin import *
from vectorbtpro.labels.generators.fstd import *
from vectorbtpro.labels.generators.pivotlb import *
from vectorbtpro.labels.generators.meanlb import *
from vectorbtpro.labels.generators.trendlb import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `BOLB`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
__all__ = [
"BOLB",
]
__pdoc__ = {}
BOLB = IndicatorFactory(
class_name="BOLB",
module_name=__name__,
input_names=["high", "low"],
param_names=["window", "up_th", "down_th", "wait"],
output_names=["labels"],
).with_apply_func(
nb.breakout_labels_nb,
param_settings=dict(
up_th=flex_elem_param_config,
down_th=flex_elem_param_config,
),
window=14,
up_th=np.inf,
down_th=np.inf,
wait=1,
)
class _BOLB(BOLB):
"""Label generator based on `vectorbtpro.labels.nb.breakout_labels_nb`."""
def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
"""Plot the median of `BOLB.high` and `BOLB.low`, and overlay it with the heatmap of `BOLB.labels`.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.
Usage:
```pycon
>>> vbt.BOLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
```
"""
self_col = self.select_col(column=column, group_by=False)
median = (self_col.high + self_col.low) / 2
return median.rename("Median").vbt.overlay_with_heatmap(self_col.labels.rename("Labels"), **kwargs)
setattr(BOLB, "__doc__", _BOLB.__doc__)
setattr(BOLB, "plot", _BOLB.plot)
BOLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FIXLB`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
__all__ = [
"FIXLB",
]
__pdoc__ = {}
FIXLB = IndicatorFactory(
class_name="FIXLB",
module_name=__name__,
input_names=["close"],
param_names=["n"],
output_names=["labels"],
).with_apply_func(
nb.fixed_labels_nb,
n=1,
)
class _FIXLB(FIXLB):
"""Label generator based on `vectorbtpro.labels.nb.fixed_labels_nb`."""
def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
"""Plot `FIXLB.close` and overlay it with the heatmap of `FIXLB.labels`.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.
Usage:
```pycon
>>> vbt.FIXLB.run(ohlcv['Close']).plot().show()
```
"""
self_col = self.select_col(column=column, group_by=False)
return self_col.close.rename("Close").vbt.overlay_with_heatmap(self_col.labels.rename("Labels"), **kwargs)
setattr(FIXLB, "__doc__", _FIXLB.__doc__)
setattr(FIXLB, "plot", _FIXLB.plot)
FIXLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FMAX`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FMAX",
]
__pdoc__ = {}
FMAX = IndicatorFactory(
class_name="FMAX",
module_name=__name__,
input_names=["close"],
param_names=["window", "wait"],
output_names=["fmax"],
).with_apply_func(
nb.future_max_nb,
window=14,
wait=1,
)
class _FMAX(FMAX):
"""Look-ahead indicator based on `vectorbtpro.labels.nb.future_max_nb`."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
fmax_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `FMAX.fmax` against `FMAX.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `FMAX.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMAX.close`.
fmax_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMAX.fmax`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.FMAX.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if fmax_trace_kwargs is None:
fmax_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
fmax_trace_kwargs = merge_dicts(
dict(name="Future max", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
fmax_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.fmax.vbt.lineplot(
trace_kwargs=fmax_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(FMAX, "__doc__", _FMAX.__doc__)
setattr(FMAX, "plot", _FMAX.plot)
FMAX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FMEAN`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FMEAN",
]
__pdoc__ = {}
FMEAN = IndicatorFactory(
class_name="FMEAN",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype", "wait"],
output_names=["fmean"],
).with_apply_func(
nb.future_mean_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
wait=1,
minp=None,
adjust=False,
)
class _FMEAN(FMEAN):
"""Look-ahead indicator based on `vectorbtpro.labels.nb.future_mean_nb`."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
fmean_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `FMEAN.fmean` against `FMEAN.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `FMEAN.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMEAN.close`.
fmean_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMEAN.fmean`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.FMEAN.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if fmean_trace_kwargs is None:
fmean_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
fmean_trace_kwargs = merge_dicts(
dict(name="Future mean", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
fmean_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.fmean.vbt.lineplot(
trace_kwargs=fmean_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(FMEAN, "__doc__", _FMEAN.__doc__)
setattr(FMEAN, "plot", _FMEAN.plot)
FMEAN.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FMIN`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FMIN",
]
__pdoc__ = {}
FMIN = IndicatorFactory(
class_name="FMIN",
module_name=__name__,
input_names=["close"],
param_names=["window", "wait"],
output_names=["fmin"],
).with_apply_func(
nb.future_min_nb,
window=14,
wait=1,
)
class _FMIN(FMIN):
"""Look-ahead indicator based on `vectorbtpro.labels.nb.future_min_nb`."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_close: bool = True,
close_trace_kwargs: tp.KwargsLike = None,
fmin_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `FMIN.fmin` against `FMIN.close`.
Args:
column (str): Name of the column to plot.
plot_close (bool): Whether to plot `FMIN.close`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMIN.close`.
fmin_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FMIN.fmin`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.FMIN.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if close_trace_kwargs is None:
close_trace_kwargs = {}
if fmin_trace_kwargs is None:
fmin_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(name="Close", line=dict(color=plotting_cfg["color_schema"]["blue"])),
close_trace_kwargs,
)
fmin_trace_kwargs = merge_dicts(
dict(name="Future min", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
fmin_trace_kwargs,
)
if plot_close:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
fig = self_col.fmin.vbt.lineplot(
trace_kwargs=fmin_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(FMIN, "__doc__", _FMIN.__doc__)
setattr(FMIN, "plot", _FMIN.plot)
FMIN.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `FSTD`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.utils.config import merge_dicts
__all__ = [
"FSTD",
]
__pdoc__ = {}
FSTD = IndicatorFactory(
class_name="FSTD",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype", "wait"],
output_names=["fstd"],
).with_apply_func(
nb.future_std_nb,
kwargs_as_args=["minp", "adjust", "ddof"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
wait=1,
minp=None,
adjust=False,
ddof=0,
)
class _FSTD(FSTD):
"""Look-ahead indicator based on `vectorbtpro.labels.nb.future_std_nb`."""
def plot(
self,
column: tp.Optional[tp.Label] = None,
fstd_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot `FSTD.fstd`.
Args:
column (str): Name of the column to plot.
fstd_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `FSTD.fstd`.
add_trace_kwargs (dict): Keyword arguments passed to `fig.add_trace` when adding each trace.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments passed to `fig.update_layout`.
Usage:
```pycon
>>> vbt.FSTD.run(ohlcv['Close']).plot().show()
```
"""
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if fstd_trace_kwargs is None:
fstd_trace_kwargs = {}
fstd_trace_kwargs = merge_dicts(
dict(name="Future STD", line=dict(color=plotting_cfg["color_schema"]["lightblue"])),
fstd_trace_kwargs,
)
fig = self_col.fstd.vbt.lineplot(
trace_kwargs=fstd_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
return fig
setattr(FSTD, "__doc__", _FSTD.__doc__)
setattr(FSTD, "plot", _FSTD.plot)
FSTD.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `MEANLB`."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic import enums as generic_enums
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
__all__ = [
"MEANLB",
]
__pdoc__ = {}
MEANLB = IndicatorFactory(
class_name="MEANLB",
module_name=__name__,
input_names=["close"],
param_names=["window", "wtype", "wait"],
output_names=["labels"],
).with_apply_func(
nb.mean_labels_nb,
kwargs_as_args=["minp", "adjust"],
param_settings=dict(
wtype=dict(
dtype=generic_enums.WType,
post_index_func=lambda index: index.str.lower(),
)
),
window=14,
wtype="simple",
wait=1,
minp=None,
adjust=False,
)
class _MEANLB(MEANLB):
"""Label generator based on `vectorbtpro.labels.nb.mean_labels_nb`."""
def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
"""Plot `close` and overlay it with the heatmap of `MEANLB.labels`.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.
Usage:
```pycon
>>> vbt.MEANLB.run(ohlcv['Close']).plot().show()
```
"""
self_col = self.select_col(column=column, group_by=False)
return self_col.close.rename("Close").vbt.overlay_with_heatmap(self_col.labels.rename("Labels"), **kwargs)
setattr(MEANLB, "__doc__", _MEANLB.__doc__)
setattr(MEANLB, "plot", _MEANLB.plot)
MEANLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `PIVOTLB`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
__all__ = [
"PIVOTLB",
]
__pdoc__ = {}
PIVOTLB = IndicatorFactory(
class_name="PIVOTLB",
module_name=__name__,
input_names=["high", "low"],
param_names=["up_th", "down_th"],
output_names=["labels"],
).with_apply_func(
nb.pivots_nb,
param_settings=dict(
up_th=flex_elem_param_config,
down_th=flex_elem_param_config,
),
)
class _PIVOTLB(PIVOTLB):
"""Label generator based on `vectorbtpro.labels.nb.pivots_nb`."""
def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
"""Plot the median of `PIVOTLB.high` and `PIVOTLB.low`, and overlay it with the heatmap of `PIVOTLB.labels`.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.
Usage:
```pycon
>>> vbt.PIVOTLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
```
"""
self_col = self.select_col(column=column, group_by=False)
median = (self_col.high + self_col.low) / 2
return median.rename("Median").vbt.overlay_with_heatmap(self_col.labels.rename("Labels"), **kwargs)
setattr(PIVOTLB, "__doc__", _PIVOTLB.__doc__)
setattr(PIVOTLB, "plot", _PIVOTLB.plot)
PIVOTLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `TRENDLB`."""
from vectorbtpro import _typing as tp
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.indicators.factory import IndicatorFactory
from vectorbtpro.labels import nb
from vectorbtpro.labels.enums import TrendLabelMode
__all__ = [
"TRENDLB",
]
__pdoc__ = {}
TRENDLB = IndicatorFactory(
class_name="TRENDLB",
module_name=__name__,
input_names=["high", "low"],
param_names=["up_th", "down_th", "mode"],
output_names=["labels"],
).with_apply_func(
nb.trend_labels_nb,
param_settings=dict(
up_th=flex_elem_param_config,
down_th=flex_elem_param_config,
mode=dict(
dtype=TrendLabelMode,
post_index_func=lambda index: index.str.lower(),
),
),
mode=TrendLabelMode.Binary,
)
class _TRENDLB(TRENDLB):
"""Label generator based on `vectorbtpro.labels.nb.trend_labels_nb`."""
def plot(self, column: tp.Optional[tp.Label] = None, **kwargs) -> tp.BaseFigure:
"""Plot the median of `TRENDLB.high` and `TRENDLB.low`, and overlay it with the heatmap of `TRENDLB.labels`.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.overlay_with_heatmap`.
Usage:
```pycon
>>> vbt.TRENDLB.run(ohlcv['High'], ohlcv['Low'], up_th=0.2, down_th=0.2).plot().show()
```
"""
self_col = self.select_col(column=column, group_by=False)
median = (self_col.high + self_col.low) / 2
return median.rename("Median").vbt.overlay_with_heatmap(self_col.labels.rename("Labels"), **kwargs)
setattr(TRENDLB, "__doc__", _TRENDLB.__doc__)
setattr(TRENDLB, "plot", _TRENDLB.plot)
TRENDLB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for building and running look-ahead indicators and label generators."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.labels.generators import *
from vectorbtpro.labels.nb import *
__exclude_from__all__ = [
"enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for label generation.
Defines enums and other schemas for `vectorbtpro.labels`."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = ["TrendLabelMode"]
__pdoc__ = {}
class TrendLabelModeT(tp.NamedTuple):
Binary: int = 0
BinaryCont: int = 1
BinaryContSat: int = 2
PctChange: int = 3
PctChangeNorm: int = 4
TrendLabelMode = TrendLabelModeT()
"""_"""
__pdoc__[
"TrendLabelMode"
] = f"""Trend label mode.
```python
{prettify(TrendLabelMode)}
```
Attributes:
Binary: See `vectorbtpro.labels.nb.bin_trend_labels_nb`.
BinaryCont: See `vectorbtpro.labels.nb.binc_trend_labels_nb`.
BinaryContSat: See `vectorbtpro.labels.nb.bincs_trend_labels_nb`.
PctChange: See `vectorbtpro.labels.nb.pct_trend_labels_nb` with `normalize` set to False.
PctChangeNorm: See `vectorbtpro.labels.nb.pct_trend_labels_nb` with `normalize` set to True.
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for label generation.
!!! note
Set `wait` to 1 to exclude the current value from the calculation of future values.
!!! warning
Do not use these functions for building predictor variables, as they may introduce
look-ahead bias into your model; use them only for building target variables."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_nb, flex_select_col_nb
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.indicators.enums import Pivot
from vectorbtpro.labels.enums import TrendLabelMode
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
__all__ = []
# ############# FMEAN ############# #
@register_jitted(cache=True)
def future_mean_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
wait: int = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Rolling average over future values.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
future_mean = generic_nb.ma_1d_nb(close[::-1], window, wtype=wtype, minp=minp, adjust=adjust)[::-1]
if wait > 0:
return generic_nb.bshift_1d_nb(future_mean, wait)
return future_mean
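# Editor's sketch (not part of the library): reversing, taking a trailing rolling mean, and
# reversing back makes each row see the mean of the *next* `window` values; `wait=1` then
# shifts that view one more step into the future. A pure-NumPy reference for a toy series
# (assuming `minp=None` requires a full window):
_example_close = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
_example_ref = np.full(_example_close.shape, np.nan)
for _example_i in range(_example_close.shape[0]):
    if _example_i + 4 <= _example_close.shape[0]:  # wait=1 plus window=3 future values needed
        _example_ref[_example_i] = _example_close[_example_i + 1 : _example_i + 4].mean()
# _example_ref -> [3., 4., 5., nan, nan, nan]; future_mean_1d_nb(_example_close, window=3)
# is expected to produce the same values.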
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
wait=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_mean_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
wait: tp.FlexArray1dLike = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `future_mean_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
wait_ = to_1d_array_nb(np.asarray(wait))
future_mean = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
future_mean[:, col] = future_mean_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
wait=flex_select_1d_nb(wait_, col),
minp=minp,
adjust=adjust,
)
return future_mean
# ############# FSTD ############# #
@register_jitted(cache=True)
def future_std_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
wait: int = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array1d:
"""Rolling standard deviation over future values.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
future_std = generic_nb.msd_1d_nb(close[::-1], window, wtype=wtype, minp=minp, adjust=adjust, ddof=ddof)[::-1]
if wait > 0:
return generic_nb.bshift_1d_nb(future_std, wait)
return future_std
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
wait=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
ddof=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_std_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
wait: tp.FlexArray1dLike = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
ddof: int = 0,
) -> tp.Array2d:
"""2-dim version of `future_std_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
wait_ = to_1d_array_nb(np.asarray(wait))
future_std = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
future_std[:, col] = future_std_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
wait=flex_select_1d_nb(wait_, col),
minp=minp,
adjust=adjust,
ddof=ddof,
)
return future_std
# ############# FMIN ############# #
@register_jitted(cache=True)
def future_min_1d_nb(
close: tp.Array1d,
window: int = 14,
wait: int = 1,
minp: tp.Optional[int] = None,
) -> tp.Array1d:
"""Rolling minimum over future values."""
future_min = generic_nb.rolling_min_1d_nb(close[::-1], window, minp=minp)[::-1]
if wait > 0:
return generic_nb.bshift_1d_nb(future_min, wait)
return future_min
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wait=base_ch.FlexArraySlicer(),
minp=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_min_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wait: tp.FlexArray1dLike = 1,
minp: tp.Optional[int] = None,
) -> tp.Array2d:
"""2-dim version of `future_min_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wait_ = to_1d_array_nb(np.asarray(wait))
future_min = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
future_min[:, col] = future_min_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wait=flex_select_1d_nb(wait_, col),
minp=minp,
)
return future_min
# ############# FMAX ############# #
@register_jitted(cache=True)
def future_max_1d_nb(
close: tp.Array1d,
window: int = 14,
wait: int = 1,
minp: tp.Optional[int] = None,
) -> tp.Array1d:
"""Rolling maximum over future values."""
future_max = generic_nb.rolling_max_1d_nb(close[::-1], window, minp=minp)[::-1]
if wait > 0:
return generic_nb.bshift_1d_nb(future_max, wait)
return future_max
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wait=base_ch.FlexArraySlicer(),
minp=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def future_max_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wait: tp.FlexArray1dLike = 1,
minp: tp.Optional[int] = None,
) -> tp.Array2d:
"""2-dim version of `future_max_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wait_ = to_1d_array_nb(np.asarray(wait))
future_max = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
future_max[:, col] = future_max_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wait=flex_select_1d_nb(wait_, col),
minp=minp,
)
return future_max
# ############# FIXLB ############# #
@register_jitted(cache=True)
def fixed_labels_1d_nb(
close: tp.Array1d,
n: int = 1,
) -> tp.Array1d:
"""Percentage change of the current value relative to a future value."""
return (generic_nb.bshift_1d_nb(close, n) - close) / close
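# Editor's sketch (not part of the library): the label at row i is the percentage change from
# close[i] to close[i + n], i.e. a plain forward shift:
_example_close = np.array([100.0, 110.0, 121.0])
_example_future = np.append(_example_close[1:], np.nan)  # n=1 step into the future
_example_labels = (_example_future - _example_close) / _example_close  # -> [0.1, 0.1, nan]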
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
n=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def fixed_labels_nb(
close: tp.Array2d,
n: tp.FlexArray1dLike = 1,
) -> tp.Array2d:
"""2-dim version of `fixed_labels_1d_nb`."""
n_ = to_1d_array_nb(np.asarray(n))
fixed_labels = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
fixed_labels[:, col] = fixed_labels_1d_nb(
close=close[:, col],
n=flex_select_1d_nb(n_, col),
)
return fixed_labels
# ############# MEANLB ############# #
@register_jitted(cache=True)
def mean_labels_1d_nb(
close: tp.Array1d,
window: int = 14,
wtype: int = generic_enums.WType.Simple,
wait: int = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array1d:
"""Percentage change of the current value relative to the average of a future period.
For `wtype`, see `vectorbtpro.generic.enums.WType`."""
future_mean = future_mean_1d_nb(close, window=window, wtype=wtype, wait=wait, minp=minp, adjust=adjust)
return (future_mean - close) / close
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
wtype=base_ch.FlexArraySlicer(),
wait=base_ch.FlexArraySlicer(),
minp=None,
adjust=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mean_labels_nb(
close: tp.Array2d,
window: tp.FlexArray1dLike = 14,
wtype: tp.FlexArray1dLike = generic_enums.WType.Simple,
wait: tp.FlexArray1dLike = 1,
minp: tp.Optional[int] = None,
adjust: bool = False,
) -> tp.Array2d:
"""2-dim version of `mean_labels_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
wtype_ = to_1d_array_nb(np.asarray(wtype))
wait_ = to_1d_array_nb(np.asarray(wait))
mean_labels = np.empty(close.shape, dtype=float_)
for col in prange(close.shape[1]):
mean_labels[:, col] = mean_labels_1d_nb(
close=close[:, col],
window=flex_select_1d_nb(window_, col),
wtype=flex_select_1d_nb(wtype_, col),
wait=flex_select_1d_nb(wait_, col),
minp=minp,
adjust=adjust,
)
return mean_labels
# ############# PIVOTLB ############# #
@register_jitted(cache=True)
def iter_symmetric_up_th_nb(down_th: float) -> float:
"""Positive upper threshold that is symmetric to a negative one at one iteration.
For example, 50% down requires 100% to go up to the initial level."""
return down_th / (1 - down_th)
@register_jitted(cache=True)
def iter_symmetric_down_th_nb(up_th: float) -> float:
"""Negative upper threshold that is symmetric to a positive one at one iteration."""
return up_th / (1 + up_th)
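# Editor's note (not part of the library): the two helpers are inverses of each other. A 50%
# drop needs a +100% move to return to the initial level, and a +100% move is undone by a
# 50% drop:
_example_down = 0.5
_example_up = _example_down / (1 - _example_down)  # -> 1.0, i.e. +100%
assert _example_up / (1 + _example_up) == _example_down  # mapping back recovers the drop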
@register_jitted(cache=True)
def pivots_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
up_th: tp.FlexArray1dLike,
down_th: tp.FlexArray1dLike,
) -> tp.Array1d:
"""Pivots denoted by 1 (peak), 0 (no pivot) or -1 (valley).
Two adjacent peak and valley points should exceed the given threshold parameters.
If any threshold is given element-wise, it will be applied per new/updated pivot."""
up_th_ = to_1d_array_nb(np.asarray(up_th))
down_th_ = to_1d_array_nb(np.asarray(down_th))
pivots = np.full(high.shape, 0, dtype=int_)
last_pivot = 0
last_i = -1
last_value = np.nan
first_valid_i = -1
for i in range(high.shape[0]):
if not np.isnan(high[i]) and not np.isnan(low[i]):
if first_valid_i == -1:
first_valid_i = i
if last_i == -1:
_up_th = 1 + abs(flex_select_1d_nb(up_th_, first_valid_i))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, first_valid_i))
if not np.isnan(_up_th) and high[i] >= low[first_valid_i] * _up_th:
if not np.isnan(_down_th) and low[i] <= high[first_valid_i] * _down_th:
pass # wait
else:
pivots[first_valid_i] = Pivot.Valley
last_i = i
last_value = high[i]
last_pivot = Pivot.Peak
if not np.isnan(_down_th) and low[i] <= high[first_valid_i] * _down_th:
if not np.isnan(_up_th) and high[i] >= low[first_valid_i] * _up_th:
pass # wait
else:
pivots[first_valid_i] = Pivot.Peak
last_i = i
last_value = low[i]
last_pivot = Pivot.Valley
else:
_up_th = 1 + abs(flex_select_1d_nb(up_th_, last_i))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, last_i))
if last_pivot == Pivot.Valley:
if not np.isnan(last_value) and not np.isnan(_up_th) and high[i] >= last_value * _up_th:
pivots[last_i] = last_pivot
last_i = i
last_value = high[i]
last_pivot = Pivot.Peak
elif np.isnan(last_value) or low[i] < last_value:
last_i = i
last_value = low[i]
elif last_pivot == Pivot.Peak:
if not np.isnan(last_value) and not np.isnan(_down_th) and low[i] <= last_value * _down_th:
pivots[last_i] = last_pivot
last_i = i
last_value = low[i]
last_pivot = Pivot.Valley
elif np.isnan(last_value) or high[i] > last_value:
last_i = i
last_value = high[i]
if last_i != -1 and i == high.shape[0] - 1:
pivots[last_i] = last_pivot
return pivots
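# Editor's sketch (not part of the library): a short trace of the confirmation logic with 20%
# thresholds. The first bar is marked as a valley once a +20% move confirms it, the peak at
# index 2 is confirmed by the later -20% move, and the last bar receives the pending pivot:
_example_high = np.array([10.0, 11.0, 13.0, 12.0, 9.0])
_example_low = np.array([9.0, 10.0, 12.0, 11.0, 8.0])
# pivots_1d_nb(_example_high, _example_low, 0.2, 0.2) is expected to yield [-1, 0, 1, 0, -1]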
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
up_th=base_ch.FlexArraySlicer(axis=1),
down_th=base_ch.FlexArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pivots_nb(
high: tp.Array2d,
low: tp.Array2d,
up_th: tp.FlexArray2dLike,
down_th: tp.FlexArray2dLike,
) -> tp.Array2d:
"""2-dim version of `pivots_1d_nb`."""
up_th_ = to_2d_array_nb(np.asarray(up_th))
down_th_ = to_2d_array_nb(np.asarray(down_th))
pivots = np.empty(high.shape, dtype=int_)
for col in prange(high.shape[1]):
pivots[:, col] = pivots_1d_nb(
high[:, col],
low[:, col],
flex_select_col_nb(up_th_, col),
flex_select_col_nb(down_th_, col),
)
return pivots
# ############# TRENDLB ############# #
@register_jitted(cache=True)
def bin_trend_labels_1d_nb(pivots: tp.Array1d) -> tp.Array1d:
"""Values classified into 0 (downtrend) and 1 (uptrend)."""
bin_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
idxs = np.flatnonzero(pivots)
if idxs.shape[0] == 0:
return bin_trend_labels
for k in range(1, idxs.shape[0]):
prev_i = idxs[k - 1]
next_i = idxs[k]
for i in range(prev_i, next_i):
if pivots[next_i] == Pivot.Peak:
bin_trend_labels[i] = 1
else:
bin_trend_labels[i] = 0
return bin_trend_labels
@register_chunkable(
size=ch.ArraySizer(arg_query="pivots", axis=1),
arg_take_spec=dict(
pivots=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bin_trend_labels_nb(pivots: tp.Array2d) -> tp.Array2d:
"""2-dim version of `bin_trend_labels_1d_nb`."""
bin_trend_labels = np.empty(pivots.shape, dtype=float_)
for col in prange(pivots.shape[1]):
bin_trend_labels[:, col] = bin_trend_labels_1d_nb(pivots[:, col])
return bin_trend_labels
@register_jitted(cache=True)
def binc_trend_labels_1d_nb(high: tp.Array1d, low: tp.Array1d, pivots: tp.Array1d) -> tp.Array1d:
"""Median values normalized between 0 (downtrend) and 1 (uptrend)."""
binc_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
idxs = np.flatnonzero(pivots[:])
if idxs.shape[0] == 0:
return binc_trend_labels
for k in range(1, idxs.shape[0]):
prev_i = idxs[k - 1]
next_i = idxs[k]
_min = np.nanmin(low[prev_i : next_i + 1])
_max = np.nanmax(high[prev_i : next_i + 1])
for i in range(prev_i, next_i):
_med = (high[i] + low[i]) / 2
binc_trend_labels[i] = 1 - (_med - _min) / (_max - _min)
return binc_trend_labels
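# Editor's note (not part of the library): within each pivot-to-pivot segment the label is
# 1 - (median - segment_min) / (segment_max - segment_min), so a bar near the segment low maps
# to ~1 (mostly room to rise) and a bar near the segment high maps to ~0 (mostly room to fall):
_example_label = 1 - (105.0 - 100.0) / (120.0 - 100.0)  # -> 0.75 for a bar close to the low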
@register_chunkable(
size=ch.ArraySizer(arg_query="pivots", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
pivots=ch.ArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def binc_trend_labels_nb(high: tp.Array2d, low: tp.Array2d, pivots: tp.Array2d) -> tp.Array2d:
"""2-dim version of `binc_trend_labels_1d_nb`."""
binc_trend_labels = np.empty(pivots.shape, dtype=float_)
for col in prange(pivots.shape[1]):
binc_trend_labels[:, col] = binc_trend_labels_1d_nb(high[:, col], low[:, col], pivots[:, col])
return binc_trend_labels
@register_jitted(cache=True)
def bincs_trend_labels_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
pivots: tp.Array1d,
up_th: tp.FlexArray1dLike,
down_th: tp.FlexArray1dLike,
) -> tp.Array1d:
"""Median values normalized between 0 (downtrend) and 1 (uptrend) but capped once
the threshold defined at the beginning of the trend is exceeded."""
up_th_ = to_1d_array_nb(np.asarray(up_th))
down_th_ = to_1d_array_nb(np.asarray(down_th))
bincs_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
idxs = np.flatnonzero(pivots)
if idxs.shape[0] == 0:
return bincs_trend_labels
for k in range(1, idxs.shape[0]):
prev_i = idxs[k - 1]
next_i = idxs[k]
_up_th = 1 + abs(flex_select_1d_nb(up_th_, prev_i))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, prev_i))
_min = np.min(low[prev_i : next_i + 1])
_max = np.max(high[prev_i : next_i + 1])
for i in range(prev_i, next_i):
if not np.isnan(high[i]) and not np.isnan(low[i]):
_med = (high[i] + low[i]) / 2
if pivots[next_i] == Pivot.Peak:
if not np.isnan(_up_th):
_start = _max / _up_th
_end = _min * _up_th
if _max >= _end and _med <= _start:
bincs_trend_labels[i] = 1
else:
bincs_trend_labels[i] = 1 - (_med - _start) / (_max - _start)
else:
if not np.isnan(_down_th):
_start = _min / _down_th
_end = _max * _down_th
if _min <= _end and _med >= _start:
bincs_trend_labels[i] = 0
else:
bincs_trend_labels[i] = 1 - (_med - _min) / (_start - _min)
return bincs_trend_labels
@register_chunkable(
size=ch.ArraySizer(arg_query="pivots", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
pivots=ch.ArraySlicer(axis=1),
up_th=base_ch.FlexArraySlicer(axis=1),
down_th=base_ch.FlexArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bincs_trend_labels_nb(
high: tp.Array2d,
low: tp.Array2d,
pivots: tp.Array2d,
up_th: tp.FlexArray2dLike,
down_th: tp.FlexArray2dLike,
) -> tp.Array2d:
"""2-dim version of `bincs_trend_labels_1d_nb`."""
up_th_ = to_2d_array_nb(np.asarray(up_th))
down_th_ = to_2d_array_nb(np.asarray(down_th))
bincs_trend_labels = np.empty(pivots.shape, dtype=float_)
for col in prange(pivots.shape[1]):
bincs_trend_labels[:, col] = bincs_trend_labels_1d_nb(
high[:, col],
low[:, col],
pivots[:, col],
flex_select_col_nb(up_th_, col),
flex_select_col_nb(down_th_, col),
)
return bincs_trend_labels
@register_jitted(cache=True)
def pct_trend_labels_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
pivots: tp.Array1d,
normalize: bool = False,
) -> tp.Array1d:
"""Percentage change of median values relative to the next pivot."""
pct_trend_labels = np.full(pivots.shape, np.nan, dtype=float_)
idxs = np.flatnonzero(pivots)
if idxs.shape[0] == 0:
return pct_trend_labels
for k in range(1, idxs.shape[0]):
prev_i = idxs[k - 1]
next_i = idxs[k]
for i in range(prev_i, next_i):
_med = (high[i] + low[i]) / 2
if pivots[next_i] == Pivot.Peak:
if normalize:
pct_trend_labels[i] = (high[next_i] - _med) / high[next_i]
else:
pct_trend_labels[i] = (high[next_i] - _med) / _med
else:
if normalize:
pct_trend_labels[i] = (low[next_i] - _med) / _med
else:
pct_trend_labels[i] = (low[next_i] - _med) / low[next_i]
return pct_trend_labels
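# Editor's note (not part of the library): the denominators are chosen so that opposite moves
# receive labels of equal magnitude. For a rise from a median of 100 to a peak of 200 and a
# fall from a median of 200 to a valley of 100:
#   normalize=False: (200 - 100) / 100 = +1.0 and (100 - 200) / 100 = -1.0
#   normalize=True:  (200 - 100) / 200 = +0.5 and (100 - 200) / 200 = -0.5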
@register_chunkable(
size=ch.ArraySizer(arg_query="pivots", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
pivots=ch.ArraySlicer(axis=1),
normalize=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def pct_trend_labels_nb(
high: tp.Array2d,
low: tp.Array2d,
pivots: tp.Array2d,
normalize: bool = False,
) -> tp.Array2d:
"""2-dim version of `pct_trend_labels_1d_nb`."""
pct_trend_labels = np.empty(pivots.shape, dtype=float_)
for col in prange(pivots.shape[1]):
pct_trend_labels[:, col] = pct_trend_labels_1d_nb(
high[:, col],
low[:, col],
pivots[:, col],
normalize=normalize,
)
return pct_trend_labels
@register_jitted(cache=True)
def trend_labels_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
up_th: tp.FlexArray1dLike,
down_th: tp.FlexArray1dLike,
mode: int = TrendLabelMode.Binary,
) -> tp.Array2d:
"""Trend labels based on `vectorbtpro.labels.enums.TrendLabelMode`."""
pivots = pivots_1d_nb(high, low, up_th, down_th)
if mode == TrendLabelMode.Binary:
return bin_trend_labels_1d_nb(pivots)
if mode == TrendLabelMode.BinaryCont:
return binc_trend_labels_1d_nb(high, low, pivots)
if mode == TrendLabelMode.BinaryContSat:
return bincs_trend_labels_1d_nb(high, low, pivots, up_th, down_th)
if mode == TrendLabelMode.PctChange:
return pct_trend_labels_1d_nb(high, low, pivots, normalize=False)
if mode == TrendLabelMode.PctChangeNorm:
return pct_trend_labels_1d_nb(high, low, pivots, normalize=True)
raise ValueError("Invalid trend mode")
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
up_th=base_ch.FlexArraySlicer(axis=1),
down_th=base_ch.FlexArraySlicer(axis=1),
mode=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def trend_labels_nb(
high: tp.Array2d,
low: tp.Array2d,
up_th: tp.FlexArray2dLike,
down_th: tp.FlexArray2dLike,
mode: tp.FlexArray1dLike = TrendLabelMode.Binary,
) -> tp.Array2d:
"""2-dim version of `trend_labels_1d_nb`."""
up_th_ = to_2d_array_nb(np.asarray(up_th))
down_th_ = to_2d_array_nb(np.asarray(down_th))
mode_ = to_1d_array_nb(np.asarray(mode))
trend_labels = np.empty(high.shape, dtype=float_)
for col in prange(high.shape[1]):
trend_labels[:, col] = trend_labels_1d_nb(
high[:, col],
low[:, col],
flex_select_col_nb(up_th_, col),
flex_select_col_nb(down_th_, col),
mode=flex_select_1d_nb(mode_, col),
)
return trend_labels
# ############# BOLB ############# #
@register_jitted(cache=True)
def breakout_labels_1d_nb(
high: tp.Array1d,
low: tp.Array1d,
window: int = 14,
up_th: tp.FlexArray1dLike = np.inf,
down_th: tp.FlexArray1dLike = np.inf,
wait: int = 1,
) -> tp.Array1d:
"""For each value, return 1 if any value in the next period is greater than the
positive threshold (in %), -1 if less than the negative threshold, and 0 otherwise.
First hit wins. Continue search if both thresholds were hit at the same time."""
up_th_ = to_1d_array_nb(np.asarray(up_th))
down_th_ = to_1d_array_nb(np.asarray(down_th))
breakout_labels = np.full(high.shape, 0, dtype=float_)
for i in range(high.shape[0]):
if not np.isnan(high[i]) and not np.isnan(low[i]):
_up_th = 1 + abs(flex_select_1d_nb(up_th_, i))
_down_th = 1 - abs(flex_select_1d_nb(down_th_, i))
for j in range(i + wait, min(i + window + wait, high.shape[0])):
if not np.isnan(high[j]) and not np.isnan(low[j]):
if not np.isnan(_up_th) and high[j] >= low[i] * _up_th:
breakout_labels[i] = 1
break
if not np.isnan(_down_th) and low[j] <= high[i] * _down_th:
if breakout_labels[i] == 1:
breakout_labels[i] = 0
continue
breakout_labels[i] = -1
break
return breakout_labels
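# Editor's sketch (not part of the library): with window=2, wait=1 and 10% thresholds, each
# bar looks at the next two bars and takes the first threshold that is crossed:
_example_high = np.array([10.0, 10.5, 12.0, 9.0, 9.0])
_example_low = np.array([9.5, 10.0, 11.0, 8.0, 8.5])
# breakout_labels_1d_nb(_example_high, _example_low, window=2, up_th=0.1, down_th=0.1)
# is expected to yield [1., 1., -1., 1., 0.] (the last bar has no future bars left to check).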
@register_chunkable(
size=ch.ArraySizer(arg_query="high", axis=1),
arg_take_spec=dict(
high=ch.ArraySlicer(axis=1),
low=ch.ArraySlicer(axis=1),
window=base_ch.FlexArraySlicer(),
up_th=base_ch.FlexArraySlicer(axis=1),
down_th=base_ch.FlexArraySlicer(axis=1),
wait=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def breakout_labels_nb(
high: tp.Array2d,
low: tp.Array2d,
window: tp.FlexArray1dLike = 14,
up_th: tp.FlexArray2dLike = np.inf,
down_th: tp.FlexArray2dLike = np.inf,
wait: tp.FlexArray1dLike = 1,
) -> tp.Array2d:
"""2-dim version of `breakout_labels_1d_nb`."""
window_ = to_1d_array_nb(np.asarray(window))
up_th_ = to_2d_array_nb(np.asarray(up_th))
down_th_ = to_2d_array_nb(np.asarray(down_th))
wait_ = to_1d_array_nb(np.asarray(wait))
breakout_labels = np.empty(high.shape, dtype=float_)
for col in prange(high.shape[1]):
breakout_labels[:, col] = breakout_labels_1d_nb(
high[:, col],
low[:, col],
window=flex_select_1d_nb(window_, col),
up_th=flex_select_col_nb(up_th_, col),
down_th=flex_select_col_nb(down_th_, col),
wait=flex_select_1d_nb(wait_, col),
)
return breakout_labels
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with OHLC(V) data."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.ohlcv.accessors import *
from vectorbtpro.ohlcv.nb import *
__exclude_from__all__ = [
"enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for OHLC(V) data.
Methods can be accessed as follows:
* `OHLCVDFAccessor` -> `pd.DataFrame.vbt.ohlcv.*`
The accessors inherit `vectorbtpro.generic.accessors`.
!!! note
Accessors do not utilize caching.
## Column names
By default, vectorbt searches for columns with names 'open', 'high', 'low', 'close', and 'volume'
(case doesn't matter). You can change the naming either using `feature_map` in
`vectorbtpro._settings.ohlcv`, or by providing `feature_map` directly to the accessor.
```pycon
>>> from vectorbtpro import *
>>> df = pd.DataFrame({
... 'my_open1': [2, 3, 4, 3.5, 2.5],
... 'my_high2': [3, 4, 4.5, 4, 3],
... 'my_low3': [1.5, 2.5, 3.5, 2.5, 1.5],
... 'my_close4': [2.5, 3.5, 4, 3, 2],
... 'my_volume5': [10, 11, 10, 9, 10]
... })
>>> df.vbt.ohlcv.get_feature('open')
None
>>> my_feature_map = {
... "my_open1": "Open",
... "my_high2": "High",
... "my_low3": "Low",
... "my_close4": "Close",
... "my_volume5": "Volume",
... }
>>> ohlcv_acc = df.vbt.ohlcv(freq='d', feature_map=my_feature_map)
>>> ohlcv_acc.get_feature('open')
0 2.0
1 3.0
2 4.0
3 3.5
4 2.5
Name: my_open1, dtype: float64
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `OHLCVDFAccessor.metrics`.
```pycon
>>> ohlcv_acc.stats()
Start 0
End 4
Period 5 days 00:00:00
First Price 2.0
Lowest Price 1.5
Highest Price 4.5
Last Price 2.0
First Volume 10
Lowest Volume 9
Highest Volume 11
Last Volume 10
Name: agg_stats, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `OHLCVDFAccessor.subplots`.
`OHLCVDFAccessor` class has a single subplot based on `OHLCVDFAccessor.plot` (without volume):
```pycon
>>> ohlcv_acc.plots(settings=dict(ohlc_type='candlestick')).show()
```
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.accessors import register_df_vbt_accessor
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array
from vectorbtpro.data.base import OHLCDataMixin
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.accessors import GenericAccessor, GenericDFAccessor
from vectorbtpro.ohlcv import nb, enums
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import hybrid_property
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
if tp.TYPE_CHECKING:
from vectorbtpro.data.base import Data as DataT
else:
DataT = "Data"
__all__ = [
"OHLCVDFAccessor",
]
__pdoc__ = {}
OHLCVDFAccessorT = tp.TypeVar("OHLCVDFAccessorT", bound="OHLCVDFAccessor")
@register_df_vbt_accessor("ohlcv")
class OHLCVDFAccessor(OHLCDataMixin, GenericDFAccessor):
"""Accessor on top of OHLCV data. For DataFrames only.
Accessible via `pd.DataFrame.vbt.ohlcv`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
feature_map: tp.KwargsLike = None,
**kwargs,
) -> None:
GenericDFAccessor.__init__(self, wrapper, obj=obj, feature_map=feature_map, **kwargs)
self._feature_map = feature_map
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["OHLCVDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return OHLCVDFAccessor
@property
def feature_map(self) -> tp.Kwargs:
"""Column names."""
from vectorbtpro._settings import settings
ohlcv_cfg = settings["ohlcv"]
return merge_dicts(ohlcv_cfg["feature_map"], self._feature_map)
@property
def feature_wrapper(self) -> ArrayWrapper:
new_columns = self.wrapper.columns.map(lambda x: self.feature_map[x] if x in self.feature_map else x)
return self.wrapper.replace(columns=new_columns)
@property
def symbol_wrapper(self) -> ArrayWrapper:
return ArrayWrapper(self.wrapper.index, [None], 1)
def select_feature_idxs(self: OHLCVDFAccessorT, idxs: tp.MaybeSequence[int], **kwargs) -> OHLCVDFAccessorT:
return self.iloc[:, idxs]
def select_symbol_idxs(self: OHLCVDFAccessorT, idxs: tp.MaybeSequence[int], **kwargs) -> OHLCVDFAccessorT:
raise NotImplementedError
def get(
self,
features: tp.Union[None, tp.MaybeFeatures] = None,
symbols: tp.Union[None, tp.MaybeSymbols] = None,
feature: tp.Optional[tp.Feature] = None,
symbol: tp.Optional[tp.Symbol] = None,
**kwargs,
) -> tp.SeriesFrame:
if features is not None and feature is not None:
raise ValueError("Must provide either features or feature, not both")
if symbols is not None or symbol is not None:
raise ValueError("Cannot provide symbols")
if feature is not None:
features = feature
single_feature = True
else:
if features is None:
return self.obj
single_feature = not self.has_multiple_keys(features)
if not single_feature:
feature_idxs = [self.get_feature_idx(k, raise_error=True) for k in features]
else:
feature_idxs = self.get_feature_idx(features, raise_error=True)
return self.obj.iloc[:, feature_idxs]
# ############# Conversion ############# #
def to_data(self, data_cls: tp.Optional[tp.Type[DataT]] = None, **kwargs) -> DataT:
"""Convert to a `vectorbtpro.data.base.Data` instance."""
if data_cls is None:
from vectorbtpro.data.base import Data
data_cls = Data
return data_cls.from_data(self.obj, columns_are_symbols=False, **kwargs)
# ############# Transforming ############# #
def mirror_ohlc(
self: OHLCVDFAccessorT,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
start_value: tp.ArrayLike = np.nan,
ref_feature: tp.ArrayLike = -1,
) -> tp.Frame:
"""Mirror OHLC features."""
if isinstance(ref_feature, str):
ref_feature = map_enum_fields(ref_feature, enums.PriceFeature)
open_idx = self.get_feature_idx("Open")
high_idx = self.get_feature_idx("High")
low_idx = self.get_feature_idx("Low")
close_idx = self.get_feature_idx("Close")
func = jit_reg.resolve_option(nb.mirror_ohlc_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_open, new_high, new_low, new_close = func(
self.symbol_wrapper.shape_2d,
open=to_2d_array(self.obj.iloc[:, open_idx]) if open_idx is not None else None,
high=to_2d_array(self.obj.iloc[:, high_idx]) if high_idx is not None else None,
low=to_2d_array(self.obj.iloc[:, low_idx]) if low_idx is not None else None,
close=to_2d_array(self.obj.iloc[:, close_idx]) if close_idx is not None else None,
start_value=to_1d_array(start_value),
ref_feature=to_1d_array(ref_feature),
)
df = self.obj.copy()
if open_idx is not None:
df.iloc[:, open_idx] = new_open[:, 0]
if high_idx is not None:
df.iloc[:, high_idx] = new_high[:, 0]
if low_idx is not None:
df.iloc[:, low_idx] = new_low[:, 0]
if close_idx is not None:
df.iloc[:, close_idx] = new_close[:, 0]
return df
def resample(self: OHLCVDFAccessorT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> OHLCVDFAccessorT:
"""Perform resampling on `OHLCVDFAccessor`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
sr_dct = {}
for feature in self.feature_wrapper.columns:
if isinstance(feature, str) and feature.lower() == "open":
sr_dct[feature] = self.obj[feature].vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.first_reduce_nb,
)
elif isinstance(feature, str) and feature.lower() == "high":
sr_dct[feature] = self.obj[feature].vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.max_reduce_nb,
)
elif isinstance(feature, str) and feature.lower() == "low":
sr_dct[feature] = self.obj[feature].vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.min_reduce_nb,
)
elif isinstance(feature, str) and feature.lower() == "close":
sr_dct[feature] = self.obj[feature].vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.last_reduce_nb,
)
elif isinstance(feature, str) and feature.lower() == "volume":
sr_dct[feature] = self.obj[feature].vbt.resample_apply(
wrapper_meta["resampler"],
generic_nb.sum_reduce_nb,
)
else:
raise ValueError(f"Cannot match feature '{feature}'")
new_obj = pd.DataFrame(sr_dct)
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
obj=new_obj,
)
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `OHLCVDFAccessor.stats`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.stats_defaults` and
`stats` from `vectorbtpro._settings.ohlcv`."""
from vectorbtpro._settings import settings
ohlcv_stats_cfg = settings["ohlcv"]["stats"]
return merge_dicts(GenericAccessor.stats_defaults.__get__(self), ohlcv_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
first_price=dict(
title="First Price",
calc_func=lambda ohlc: generic_nb.bfill_1d_nb(ohlc.values.flatten())[0],
resolve_ohlc=True,
tags=["ohlcv", "ohlc"],
),
lowest_price=dict(
title="Lowest Price",
calc_func=lambda ohlc: ohlc.values.min(),
resolve_ohlc=True,
tags=["ohlcv", "ohlc"],
),
highest_price=dict(
title="Highest Price",
calc_func=lambda ohlc: ohlc.values.max(),
resolve_ohlc=True,
tags=["ohlcv", "ohlc"],
),
last_price=dict(
title="Last Price",
calc_func=lambda ohlc: generic_nb.ffill_1d_nb(ohlc.values.flatten())[-1],
resolve_ohlc=True,
tags=["ohlcv", "ohlc"],
),
first_volume=dict(
title="First Volume",
calc_func=lambda volume: generic_nb.bfill_1d_nb(volume.values)[0],
resolve_volume=True,
tags=["ohlcv", "volume"],
),
lowest_volume=dict(
title="Lowest Volume",
calc_func=lambda volume: volume.values.min(),
resolve_volume=True,
tags=["ohlcv", "volume"],
),
highest_volume=dict(
title="Highest Volume",
calc_func=lambda volume: volume.values.max(),
resolve_volume=True,
tags=["ohlcv", "volume"],
),
last_volume=dict(
title="Last Volume",
calc_func=lambda volume: generic_nb.ffill_1d_nb(volume.values)[-1],
resolve_volume=True,
tags=["ohlcv", "volume"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot_ohlc(
self,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot OHLC data.
Args:
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
ohlcv_cfg = settings["ohlcv"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
# Set up figure
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if ohlc_type is None:
ohlc_type = ohlcv_cfg["ohlc_type"]
if isinstance(ohlc_type, str):
if ohlc_type.lower() == "ohlc":
plot_obj = go.Ohlc
elif ohlc_type.lower() == "candlestick":
plot_obj = go.Candlestick
else:
raise ValueError("Plot type can be either 'OHLC' or 'Candlestick'")
else:
plot_obj = ohlc_type
def_trace_kwargs = dict(
x=self.wrapper.index,
open=self.open,
high=self.high,
low=self.low,
close=self.close,
name="OHLC",
increasing=dict(
fillcolor=plotting_cfg["color_schema"]["increasing"],
line=dict(color=plotting_cfg["color_schema"]["increasing"]),
),
decreasing=dict(
fillcolor=plotting_cfg["color_schema"]["decreasing"],
line=dict(color=plotting_cfg["color_schema"]["decreasing"]),
),
opacity=0.75,
)
if plot_obj is go.Ohlc:
del def_trace_kwargs["increasing"]["fillcolor"]
del def_trace_kwargs["decreasing"]["fillcolor"]
_trace_kwargs = merge_dicts(def_trace_kwargs, trace_kwargs)
ohlc = plot_obj(**_trace_kwargs)
fig.add_trace(ohlc, **add_trace_kwargs)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
if "rangeslider_visible" not in layout_kwargs.get(xaxis, {}):
fig.update_layout({xaxis: dict(rangeslider_visible=False)})
return fig
def plot_volume(
self,
trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot volume data.
Args:
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if trace_kwargs is None:
trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
# Set up figure
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
marker_colors = np.empty(self.volume.shape, dtype=object)
mask_greater = (self.close.values - self.open.values) > 0
mask_less = (self.close.values - self.open.values) < 0
marker_colors[mask_greater] = plotting_cfg["color_schema"]["increasing"]
marker_colors[mask_less] = plotting_cfg["color_schema"]["decreasing"]
marker_colors[~(mask_greater | mask_less)] = plotting_cfg["color_schema"]["gray"]
_trace_kwargs = merge_dicts(
dict(
x=self.wrapper.index,
y=self.volume,
marker=dict(color=marker_colors, line_width=0),
opacity=0.5,
name="Volume",
),
trace_kwargs,
)
volume_bar = go.Bar(**_trace_kwargs)
fig.add_trace(volume_bar, **add_trace_kwargs)
return fig
def plot(
self,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
plot_volume: tp.Optional[bool] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
volume_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
volume_add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot OHLC(V) data.
Args:
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
plot_volume (bool): Whether to plot volume beneath.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
volume_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Bar`.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for OHLC.
volume_add_trace_kwargs (dict): Keyword arguments passed to `add_trace` for volume.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> vbt.YFData.pull("BTC-USD").get().vbt.ohlcv.plot().show()
```
[=100% "100%"]{: .candystripe .candystripe-animate }
{: .iimg loading=lazy }
{: .iimg loading=lazy }
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure, make_subplots
if plot_volume is None:
plot_volume = self.volume is not None
if plot_volume:
add_trace_kwargs = merge_dicts(dict(row=1, col=1), add_trace_kwargs)
volume_add_trace_kwargs = merge_dicts(dict(row=2, col=1), volume_add_trace_kwargs)
# Set up figure
if fig is None:
if plot_volume:
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0,
row_heights=[0.7, 0.3],
)
else:
fig = make_figure()
fig.update_layout(
showlegend=True,
xaxis=dict(showgrid=True),
yaxis=dict(showgrid=True),
)
if plot_volume:
fig.update_layout(
xaxis2=dict(showgrid=True),
yaxis2=dict(showgrid=True),
)
fig.update_layout(**layout_kwargs)
fig = self.plot_ohlc(
ohlc_type=ohlc_type,
trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if plot_volume:
fig = self.plot_volume(
trace_kwargs=volume_trace_kwargs,
add_trace_kwargs=volume_add_trace_kwargs,
fig=fig,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `OHLCVDFAccessor.plots`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults` and
`plots` from `vectorbtpro._settings.ohlcv`."""
from vectorbtpro._settings import settings
ohlcv_plots_cfg = settings["ohlcv"]["plots"]
return merge_dicts(GenericAccessor.plots_defaults.__get__(self), ohlcv_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="OHLC",
xaxis_kwargs=dict(showgrid=True, rangeslider_visible=False),
yaxis_kwargs=dict(showgrid=True),
check_is_not_grouped=True,
plot_func="plot",
plot_volume=False,
tags="ohlcv",
)
)
)
@property
def subplots(self) -> Config:
return self._subplots
OHLCVDFAccessor.override_metrics_doc(__pdoc__)
OHLCVDFAccessor.override_subplots_doc(__pdoc__)
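# Editor's sketch (not part of the library): with standard column names the accessor works out
# of the box; resampling aggregates open=first, high=max, low=min, close=last and volume=sum.
# The "1d" rule string below is assumed to be accepted the same way as by pandas resampling:
#
# import pandas as pd
# import vectorbtpro as vbt
#
# index = pd.date_range("2024-01-01", periods=4, freq="12h")
# df = pd.DataFrame({
#     "Open": [1.0, 2.0, 3.0, 4.0],
#     "High": [2.0, 3.0, 4.0, 5.0],
#     "Low": [0.5, 1.5, 2.5, 3.5],
#     "Close": [1.5, 2.5, 3.5, 4.5],
#     "Volume": [10, 10, 10, 10],
# }, index=index)
# daily = df.vbt.ohlcv.resample("1d").obj  # two daily bars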
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for OHLC(V) data.
Defines enums and other schemas for `vectorbtpro.ohlcv`."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = [
"PriceFeature",
]
__pdoc__ = {}
# ############# Enums ############# #
class PriceFeatureT(tp.NamedTuple):
Open: int = 0
High: int = 1
Low: int = 2
Close: int = 3
PriceFeature = PriceFeatureT()
"""_"""
__pdoc__[
"PriceFeature"
] = f"""Price feature.
```python
{prettify(PriceFeature)}
```
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for OHLCV.
!!! note
vectorbt treats matrices as first-class citizens and expects input arrays to be
2-dim, unless a function has the suffix `_1d` or is meant to be used as input to another function.
Data is processed along index (axis 0)."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_pr_nb, flex_select_1d_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
from vectorbtpro.ohlcv.enums import PriceFeature
__all__ = []
@register_jitted(cache=True)
def ohlc_every_1d_nb(price: tp.Array1d, n: tp.FlexArray1dLike) -> tp.Array2d:
"""Aggregate every `n` price points into an OHLC point."""
n_ = to_1d_array_nb(np.asarray(n))
out = np.empty((price.shape[0], 4), dtype=float_)
vmin = np.inf
vmax = -np.inf
k = 0
start_i = 0
for i in range(price.shape[0]):
_n = flex_select_1d_pr_nb(n_, k)
if _n <= 0:
out[k, 0] = np.nan
out[k, 1] = np.nan
out[k, 2] = np.nan
out[k, 3] = np.nan
vmin = np.inf
vmax = -np.inf
if i < price.shape[0] - 1:
k = k + 1
continue
if price[i] < vmin:
vmin = price[i]
if price[i] > vmax:
vmax = price[i]
if i == start_i:
out[k, 0] = price[i]
if i == start_i + _n - 1 or i == price.shape[0] - 1:
out[k, 1] = vmax
out[k, 2] = vmin
out[k, 3] = price[i]
vmin = np.inf
vmax = -np.inf
if i < price.shape[0] - 1:
k = k + 1
start_i = start_i + _n
return out[: k + 1]
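# Editor's sketch (not part of the library): with n=2, every two consecutive price points are
# collapsed into one OHLC bar (columns: open, high, low, close):
_example_price = np.array([1.0, 3.0, 2.0, 4.0, 5.0, 0.0])
# ohlc_every_1d_nb(_example_price, 2) is expected to yield
# [[1., 3., 1., 3.], [2., 4., 2., 4.], [5., 5., 0., 0.]]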
@register_jitted(cache=True)
def mirror_ohlc_1d_nb(
n_rows: int,
open: tp.Optional[tp.Array1d] = None,
high: tp.Optional[tp.Array1d] = None,
low: tp.Optional[tp.Array1d] = None,
close: tp.Optional[tp.Array1d] = None,
start_value: float = np.nan,
ref_feature: int = -1,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Mirror OHLC."""
new_open = np.empty(n_rows, dtype=float_)
new_high = np.empty(n_rows, dtype=float_)
new_low = np.empty(n_rows, dtype=float_)
new_close = np.empty(n_rows, dtype=float_)
cumsum = 0.0
first_idx = -1
factor = 1.0
for i in range(n_rows):
_open = open[i] if open is not None else np.nan
_high = high[i] if high is not None else np.nan
_low = low[i] if low is not None else np.nan
_close = close[i] if close is not None else np.nan
if ref_feature == PriceFeature.Open or (ref_feature == -1 and not np.isnan(_open)):
if first_idx == -1:
first_idx = i
if not np.isnan(start_value):
new_open[i] = start_value
else:
new_open[i] = _open
factor = new_open[i] / _open
new_high[i] = _high * factor if not np.isnan(_high) else np.nan
new_low[i] = _low * factor if not np.isnan(_low) else np.nan
new_close[i] = _close * factor if not np.isnan(_close) else np.nan
else:
prev_open = open[i - 1] if open is not None else np.nan
cumsum += -np.log(_open / prev_open)
new_open[i] = open[first_idx] * np.exp(cumsum) * factor
new_high[i] = (_open / _low) * new_open[i] if not np.isnan(_low) else np.nan
new_low[i] = (_open / _high) * new_open[i] if not np.isnan(_high) else np.nan
new_close[i] = (_open / _close) * new_open[i] if not np.isnan(_close) else np.nan
elif ref_feature == PriceFeature.Close or (ref_feature == -1 and not np.isnan(_close)):
if first_idx == -1:
first_idx = i
if not np.isnan(start_value):
new_close[i] = start_value
else:
new_close[i] = _close
factor = new_close[i] / _close
new_open[i] = _open * factor if not np.isnan(_open) else np.nan
new_high[i] = _high * factor if not np.isnan(_high) else np.nan
new_low[i] = _low * factor if not np.isnan(_low) else np.nan
else:
prev_close = close[i - 1] if close is not None else np.nan
cumsum += -np.log(_close / prev_close)
new_close[i] = close[first_idx] * np.exp(cumsum) * factor
new_open[i] = (_close / _open) * new_close[i] if not np.isnan(_open) else np.nan
new_high[i] = (_close / _low) * new_close[i] if not np.isnan(_low) else np.nan
new_low[i] = (_close / _high) * new_close[i] if not np.isnan(_high) else np.nan
elif ref_feature == PriceFeature.High or (ref_feature == -1 and not np.isnan(_high)):
if first_idx == -1:
first_idx = i
if not np.isnan(start_value):
new_high[i] = start_value
else:
new_high[i] = _high
factor = new_high[i] / _high
new_open[i] = _open * factor if not np.isnan(_open) else np.nan
new_low[i] = _low * factor if not np.isnan(_low) else np.nan
new_close[i] = _close * factor if not np.isnan(_close) else np.nan
else:
prev_high = high[i - 1] if high is not None else np.nan
cumsum += -np.log(_high / prev_high)
new_high[i] = high[first_idx] * np.exp(cumsum) * factor
new_open[i] = (_high / _open) * new_high[i] if not np.isnan(_open) else np.nan
new_low[i] = (_high / _low) * new_high[i] if not np.isnan(_low) else np.nan
new_close[i] = (_high / _close) * new_high[i] if not np.isnan(_close) else np.nan
elif ref_feature == PriceFeature.Low or (ref_feature == -1 and not np.isnan(_low)):
if first_idx == -1:
first_idx = i
if not np.isnan(start_value):
new_low[i] = start_value
else:
new_low[i] = _low
factor = new_low[i] / _low
new_open[i] = _open * factor if not np.isnan(_open) else np.nan
new_high[i] = _high * factor if not np.isnan(_high) else np.nan
new_close[i] = _close * factor if not np.isnan(_close) else np.nan
else:
prev_low = low[i - 1] if low is not None else np.nan
cumsum += -np.log(_low / prev_low)
new_low[i] = low[first_idx] * np.exp(cumsum) * factor
new_open[i] = (_low / _open) * new_low[i] if not np.isnan(_open) else np.nan
new_high[i] = (_low / _high) * new_low[i] if not np.isnan(_high) else np.nan
new_close[i] = (_low / _close) * new_low[i] if not np.isnan(_close) else np.nan
else:
new_open[i] = np.nan
new_high[i] = np.nan
new_low[i] = np.nan
new_close[i] = np.nan
return new_open, new_high, new_low, new_close
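# Editor's note (not part of the library): the reference feature is mirrored by negating its
# log returns and compounding them from the first value (or `start_value`); the remaining
# features are then rebuilt from their per-bar ratios to the reference. For example:
#   close = [100, 110, 99]  ->  mirrored close = [100, 100 / 1.1, 100 / 0.99] ≈ [100, 90.91, 101.01]
# i.e. a +10% move followed by a -10% move becomes roughly -9.1% followed by +11.1%.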
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
open=base_ch.ArraySlicer(axis=1),
high=base_ch.ArraySlicer(axis=1),
low=base_ch.ArraySlicer(axis=1),
close=base_ch.ArraySlicer(axis=1),
start_value=base_ch.FlexArraySlicer(),
ref_feature=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mirror_ohlc_nb(
target_shape: tp.Shape,
open: tp.Optional[tp.Array2d] = None,
high: tp.Optional[tp.Array2d] = None,
low: tp.Optional[tp.Array2d] = None,
close: tp.Optional[tp.Array2d] = None,
start_value: tp.FlexArray1dLike = np.nan,
ref_feature: tp.FlexArray1dLike = -1,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
"""2-dim version of `mirror_ohlc_1d_nb`."""
start_value_ = to_1d_array_nb(np.asarray(start_value))
ref_feature_ = to_1d_array_nb(np.asarray(ref_feature))
new_open = np.empty(target_shape, dtype=float_)
new_high = np.empty(target_shape, dtype=float_)
new_low = np.empty(target_shape, dtype=float_)
new_close = np.empty(target_shape, dtype=float_)
for col in prange(target_shape[1]):
new_open[:, col], new_high[:, col], new_low[:, col], new_close[:, col] = mirror_ohlc_1d_nb(
target_shape[0],
open[:, col] if open is not None else None,
high[:, col] if high is not None else None,
low[:, col] if low is not None else None,
close[:, col] if close is not None else None,
start_value=flex_select_1d_nb(start_value_, col),
ref_feature=flex_select_1d_nb(ref_feature_, col),
)
return new_open, new_high, new_low, new_close
"""Numba-compiled functions for working with portfolio.
Provides an arsenal of Numba-compiled functions that are used for portfolio
simulation, such as generating and filling orders. These only accept NumPy arrays and
other Numba-compatible types.
!!! note
    vectorbt treats matrices as first-class citizens and expects input arrays to be
    2-dim, unless a function has the suffix `_1d` or is meant to be used as input to another function.
All functions passed as argument must be Numba-compiled.
Records must retain the order they were created in.
!!! warning
Accumulation of roundoff error possible.
See [here](https://en.wikipedia.org/wiki/Round-off_error#Accumulation_of_roundoff_error) for explanation.
Rounding errors can cause trades and positions to not close properly:
```pycon
>>> print('%.50f' % 0.1) # has positive error
0.10000000000000000555111512312578270211815834045410
>>> # many buy transactions with positive error -> cannot close position
>>> sum([0.1 for _ in range(1000000)]) - 100000
1.3328826753422618e-06
>>> print('%.50f' % 0.3) # has negative error
0.29999999999999998889776975374843459576368331909180
>>> # many sell transactions with negative error -> cannot close position
>>> 300000 - sum([0.3 for _ in range(1000000)])
5.657668225467205e-06
```
While vectorbt has implemented tolerance checks when comparing floats for equality,
    adding/subtracting small amounts a large number of times may still introduce a noticeable
    error that cannot be corrected post factum.
    To mitigate this issue, avoid accumulating many micro-transactions of the same sign.
For example, reduce by `np.inf` or `position_now` to close a long/short position.
See `vectorbtpro.utils.math_` for current tolerance values.
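    As a rough illustration (exact results depend on the configured tolerances), the
    tolerance-aware comparison accepts small float drift that `==` rejects:
    ```pycon
    >>> from vectorbtpro.utils.math_ import is_close_nb
    >>> 0.1 + 0.2 == 0.3
    False
    >>> is_close_nb(0.1 + 0.2, 0.3)
    True
    ```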
!!! warning
Make sure to use `parallel=True` only if your columns are independent.
"""
from vectorbtpro.portfolio.nb.analysis import *
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.ctx_helpers import *
from vectorbtpro.portfolio.nb.from_order_func import *
from vectorbtpro.portfolio.nb.from_orders import *
from vectorbtpro.portfolio.nb.from_signals import *
from vectorbtpro.portfolio.nb.iter_ import *
from vectorbtpro.portfolio.nb.records import *
__all__ = []
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio analysis."""
from numba import prange
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio import chunking as portfolio_ch
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.returns import nb as returns_nb_
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.math_ import is_close_nb, add_nb
from vectorbtpro.utils.template import RepFunc
# ############# Assets ############# #
@register_jitted(cache=True)
def get_long_size_nb(position_before: float, position_now: float) -> float:
"""Get long size."""
if position_before <= 0 and position_now <= 0:
return 0.0
if position_before >= 0 and position_now < 0:
return -position_before
if position_before < 0 and position_now >= 0:
return position_now
return add_nb(position_now, -position_before)
@register_jitted(cache=True)
def get_short_size_nb(position_before: float, position_now: float) -> float:
"""Get short size."""
if position_before >= 0 and position_now >= 0:
return 0.0
if position_before >= 0 and position_now < 0:
return -position_now
if position_before < 0 and position_now >= 0:
return position_before
return add_nb(position_before, -position_now)
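# Worked example (illustration only): flipping a position from -1 (short) to +2 (long)
# decomposes into a long leg of +2 (the long that gets opened) and a short leg of -1
# (the short that gets closed):
#
#     >>> get_long_size_nb(-1.0, 2.0)   # 2.0
#     >>> get_short_size_nb(-1.0, 2.0)  # -1.0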
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
direction=None,
init_position=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_flow_nb(
target_shape: tp.Shape,
order_records: tp.RecordArray,
col_map: tp.GroupMap,
direction: int = Direction.Both,
init_position: tp.FlexArray1dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset flow series per column.
Returns the total transacted amount of assets at each time step."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
out = np.full(target_shape, np.nan, dtype=float_)
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
out[_sim_start:_sim_end, col] = 0.0
if _sim_start >= _sim_end:
continue
col_len = col_lens[col]
if col_len == 0:
continue
last_id = -1
position_now = flex_select_1d_pc_nb(init_position_, col)
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
i = order_record["idx"]
side = order_record["side"]
size = order_record["size"]
if side == OrderSide.Sell:
size *= -1
new_position_now = add_nb(position_now, size)
if direction == Direction.LongOnly:
asset_flow = get_long_size_nb(position_now, new_position_now)
elif direction == Direction.ShortOnly:
asset_flow = get_short_size_nb(position_now, new_position_now)
else:
asset_flow = size
out[i, col] = add_nb(out[i, col], asset_flow)
position_now = new_position_now
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="asset_flow", axis=1),
arg_take_spec=dict(
asset_flow=ch.ArraySlicer(axis=1),
direction=None,
init_position=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def assets_nb(
asset_flow: tp.Array2d,
direction: int = Direction.Both,
init_position: tp.FlexArray1dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset series per column.
Returns the current position at each time step."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
out = np.full(asset_flow.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_flow.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(asset_flow.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
position_now = flex_select_1d_pc_nb(init_position_, col)
for i in range(_sim_start, _sim_end):
flow_value = asset_flow[i, col]
position_now = add_nb(position_now, flow_value)
if direction == Direction.Both:
out[i, col] = position_now
elif direction == Direction.LongOnly and position_now > 0:
out[i, col] = position_now
elif direction == Direction.ShortOnly and position_now < 0:
out[i, col] = -position_now
else:
out[i, col] = 0.0
return out
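# Rough equivalence (illustration only): for `Direction.Both` and a full simulation range,
# the loop in `assets_nb` above is simply a running sum of the flows on top of the initial
# position, i.e. approximately `init_position + np.cumsum(asset_flow, axis=0)` per column.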
@register_chunkable(
size=ch.ArraySizer(arg_query="assets", axis=1),
arg_take_spec=dict(
assets=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_mask_nb(
assets: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get position mask per column."""
out = np.full(assets.shape, False, dtype=np.bool_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=assets.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(assets.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if assets[i, col] != 0:
out[i, col] = True
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
assets=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_mask_grouped_nb(
assets: tp.Array2d,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get position mask per group."""
out = np.full((assets.shape[0], len(group_lens)), False, dtype=np.bool_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=assets.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if not np.isnan(assets[i, col]) and assets[i, col] != 0:
out[i, group] = True
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="assets", axis=1),
arg_take_spec=dict(
assets=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_coverage_nb(
assets: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get position mask per column."""
out = np.full(assets.shape[1], 0.0, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=assets.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(assets.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
hit_elements = 0
for i in range(_sim_start, _sim_end):
if assets[i, col] != 0:
hit_elements += 1
out[col] = hit_elements / (_sim_end - _sim_start)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
assets=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
granular_groups=None,
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def position_coverage_grouped_nb(
assets: tp.Array2d,
group_lens: tp.GroupLens,
granular_groups: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get position coverage per group."""
out = np.full(len(group_lens), 0.0, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=assets.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
n_elements = 0
hit_elements = 0
if granular_groups:
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
n_elements += _sim_end - _sim_start
for i in range(_sim_start, _sim_end):
if not np.isnan(assets[i, col]) and assets[i, col] != 0:
hit_elements += 1
else:
min_sim_start = assets.shape[0]
max_sim_end = 0
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if _sim_start < min_sim_start:
min_sim_start = _sim_start
if _sim_end > max_sim_end:
max_sim_end = _sim_end
if min_sim_start >= max_sim_end:
continue
n_elements = max_sim_end - min_sim_start
for i in range(min_sim_start, max_sim_end):
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if not np.isnan(assets[i, col]) and assets[i, col] != 0:
hit_elements += 1
break
if n_elements == 0:
out[group] = np.nan
else:
out[group] = hit_elements / n_elements
return out
# ############# Cash ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
cash_deposits_raw=RepFunc(portfolio_ch.get_cash_deposits_slicer),
split_shared=None,
weights=base_ch.flex_1d_array_gl_slicer,
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_deposits_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
cash_deposits_raw: tp.FlexArray2dLike = 0.0,
split_shared: bool = False,
weights: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get cash deposit series per column."""
cash_deposits_raw_ = to_2d_array_nb(np.asarray(cash_deposits_raw))
if weights is None:
weights_ = np.full(target_shape[1], np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
out = np.full(target_shape, np.nan, dtype=float_)
if not cash_sharing:
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(target_shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_weights = flex_select_1d_pc_nb(weights_, col)
for i in range(_sim_start, _sim_end):
_cash_deposits = flex_select_nb(cash_deposits_raw_, i, col)
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, col] = _weights * _cash_deposits
else:
out[i, col] = _cash_deposits
else:
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_weights = flex_select_1d_pc_nb(weights_, col)
for i in range(_sim_start, _sim_end):
_cash_deposits = flex_select_nb(cash_deposits_raw_, i, group)
if split_shared:
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, col] = _weights * _cash_deposits / (to_col - from_col)
else:
out[i, col] = _cash_deposits / (to_col - from_col)
else:
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, col] = _weights * _cash_deposits
else:
out[i, col] = _cash_deposits
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
cash_deposits_raw=RepFunc(portfolio_ch.get_cash_deposits_slicer),
weights=base_ch.flex_1d_array_gl_slicer,
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_deposits_grouped_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
cash_deposits_raw: tp.FlexArray2dLike = 0.0,
weights: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get cash deposit series per group."""
cash_deposits_raw_ = to_2d_array_nb(np.asarray(cash_deposits_raw))
if weights is None:
weights_ = np.full(target_shape[1], np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
out = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
if cash_sharing:
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_grouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
if _sim_start >= _sim_end:
continue
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for i in range(_sim_start, _sim_end):
_cash_deposits = flex_select_nb(cash_deposits_raw_, i, group)
if np.isnan(_cash_deposits) or _cash_deposits == 0:
out[i, group] = _cash_deposits
continue
group_weight = 0.0
for col in range(from_col, to_col):
_weights = flex_select_1d_pc_nb(weights_, col)
if not np.isnan(group_weight) and not np.isnan(_weights):
group_weight += _weights
else:
group_weight = np.nan
break
if not np.isnan(group_weight):
group_weight /= group_lens[group]
if not np.isnan(group_weight) and not is_close_nb(group_weight, 1.0):
out[i, group] = group_weight * _cash_deposits
else:
out[i, group] = _cash_deposits
else:
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_weights = flex_select_1d_pc_nb(weights_, col)
for i in range(_sim_start, _sim_end):
_cash_deposits = flex_select_nb(cash_deposits_raw_, i, col)
if np.isnan(out[i, group]):
out[i, group] = 0.0
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, group] += _weights * _cash_deposits
else:
out[i, group] += _cash_deposits
return out
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
cash_earnings_raw=base_ch.FlexArraySlicer(axis=1),
weights=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_earnings_nb(
target_shape: tp.Shape,
cash_earnings_raw: tp.FlexArray2dLike = 0.0,
weights: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get cash earning series per column."""
cash_earnings_raw_ = to_2d_array_nb(np.asarray(cash_earnings_raw))
if weights is None:
weights_ = np.full(target_shape[1], np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
out = np.full(target_shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(target_shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_weights = flex_select_1d_pc_nb(weights_, col)
for i in range(_sim_start, _sim_end):
_cash_earnings = flex_select_nb(cash_earnings_raw_, i, col)
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, col] = _weights * _cash_earnings
else:
out[i, col] = _cash_earnings
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_earnings_raw=base_ch.flex_array_gl_slicer,
weights=base_ch.flex_1d_array_gl_slicer,
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_earnings_grouped_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_earnings_raw: tp.FlexArray2dLike = 0.0,
weights: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get cash earning series per group."""
cash_earnings_raw_ = to_2d_array_nb(np.asarray(cash_earnings_raw))
if weights is None:
weights_ = np.full(target_shape[1], np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
out = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_weights = flex_select_1d_pc_nb(weights_, col)
for i in range(_sim_start, _sim_end):
_cash_earnings = flex_select_nb(cash_earnings_raw_, i, col)
if np.isnan(out[i, group]):
out[i, group] = 0.0
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[i, group] += _weights * _cash_earnings
else:
out[i, group] += _cash_earnings
return out
@register_jitted(cache=True)
def get_free_cash_diff_nb(
position_before: float,
position_now: float,
debt_now: float,
price: float,
fees: float,
) -> tp.Tuple[float, float]:
"""Get updated debt and free cash flow."""
size = add_nb(position_now, -position_before)
final_cash = -size * price - fees
if is_close_nb(size, 0):
new_debt = debt_now
free_cash_diff = 0.0
elif size > 0:
if position_before < 0:
if position_now < 0:
short_size = abs(size)
else:
short_size = abs(position_before)
avg_entry_price = debt_now / abs(position_before)
debt_diff = short_size * avg_entry_price
new_debt = add_nb(debt_now, -debt_diff)
free_cash_diff = add_nb(2 * debt_diff, final_cash)
else:
new_debt = debt_now
free_cash_diff = final_cash
else:
if position_now < 0:
if position_before < 0:
short_size = abs(size)
else:
short_size = abs(position_now)
short_value = short_size * price
new_debt = debt_now + short_value
free_cash_diff = add_nb(final_cash, -2 * short_value)
else:
new_debt = debt_now
free_cash_diff = final_cash
return new_debt, free_cash_diff
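# Worked example (illustration only): opening a short of 1 unit at price 10 with no fees.
# The sale yields +10 in cash, but twice the short value is withheld from free cash as
# collateral, so free cash drops by 10 while debt grows by the short value:
#
#     >>> get_free_cash_diff_nb(0.0, -1.0, 0.0, 10.0, 0.0)  # (10.0, -10.0)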
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
free=None,
cash_earnings=base_ch.FlexArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_flow_nb(
target_shape: tp.Shape,
order_records: tp.RecordArray,
col_map: tp.GroupMap,
free: bool = False,
cash_earnings: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get (free) cash flow series per column."""
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
out = np.full(target_shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in range(target_shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
out[i, col] = flex_select_nb(cash_earnings_, i, col)
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
col_len = col_lens[col]
if col_len == 0:
continue
last_id = -1
position_now = 0.0
debt_now = 0.0
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
i = order_record["idx"]
side = order_record["side"]
size = order_record["size"]
price = order_record["price"]
fees = order_record["fees"]
if side == OrderSide.Sell:
size *= -1
position_before = position_now
position_now = add_nb(position_now, size)
if free:
debt_now, cash_flow = get_free_cash_diff_nb(
position_before=position_before,
position_now=position_now,
debt_now=debt_now,
price=price,
fees=fees,
)
else:
cash_flow = -size * price - fees
out[i, col] = add_nb(out[i, col], cash_flow)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
cash_flow=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_flow_grouped_nb(
cash_flow: tp.Array2d,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get (free) cash flow series per group."""
out = np.full((cash_flow.shape[0], len(group_lens)), np.nan, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=cash_flow.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if np.isnan(out[i, group]):
out[i, group] = 0.0
out[i, group] += cash_flow[i, col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="free_cash_flow", axis=1),
arg_take_spec=dict(
init_cash_raw=None,
free_cash_flow=ch.ArraySlicer(axis=1),
cash_deposits=base_ch.FlexArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def align_init_cash_nb(
init_cash_raw: int,
free_cash_flow: tp.Array2d,
cash_deposits: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Align initial cash to the maximum negative free cash flow per column or group."""
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
out = np.full(free_cash_flow.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=free_cash_flow.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(free_cash_flow.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
free_cash = 0.0
min_req_cash = np.inf
for i in range(_sim_start, _sim_end):
free_cash = add_nb(free_cash, free_cash_flow[i, col])
free_cash = add_nb(free_cash, flex_select_nb(cash_deposits_, i, col))
if free_cash < min_req_cash:
min_req_cash = free_cash
if min_req_cash < 0:
out[col] = np.abs(min_req_cash)
else:
out[col] = 1.0
if init_cash_raw == InitCashMode.AutoAlign:
out = np.full(out.shape, np.max(out))
return out
@register_jitted(cache=True)
def init_cash_nb(
init_cash_raw: tp.FlexArray1d,
group_lens: tp.GroupLens,
cash_sharing: bool,
split_shared: bool = False,
weights: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get initial cash per column."""
out = np.empty(np.sum(group_lens), dtype=float_)
if weights is None:
weights_ = np.full(group_lens.sum(), np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
if not cash_sharing:
for col in range(out.shape[0]):
_init_cash = flex_select_1d_pc_nb(init_cash_raw, col)
_weights = flex_select_1d_pc_nb(weights_, col)
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[col] = _weights * _init_cash
else:
out[col] = _init_cash
else:
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
group_len = to_col - from_col
_init_cash = flex_select_1d_pc_nb(init_cash_raw, group)
for col in range(from_col, to_col):
_weights = flex_select_1d_pc_nb(weights_, col)
if split_shared:
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[col] = _weights * _init_cash / group_len
else:
out[col] = _init_cash / group_len
else:
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
out[col] = _weights * _init_cash
else:
out[col] = _init_cash
from_col = to_col
return out
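# Illustration (not part of the library): in `init_cash_nb` above, with `cash_sharing=True`
# and `init_cash=100` for a group of 4 columns, `split_shared=True` assigns 25 to each column,
# while `split_shared=False` assigns the full 100 to each column (the group still shares a
# single pot of 100).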
@register_jitted(cache=True)
def init_cash_grouped_nb(
init_cash_raw: tp.FlexArray1d,
group_lens: tp.GroupLens,
cash_sharing: bool,
weights: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get initial cash per group."""
out = np.empty(group_lens.shape, dtype=float_)
if weights is None:
weights_ = np.full(group_lens.sum(), np.nan, dtype=float_)
else:
weights_ = to_1d_array_nb(np.asarray(weights).astype(float_))
if cash_sharing:
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
_init_cash = flex_select_1d_pc_nb(init_cash_raw, group)
group_weight = 0.0
for col in range(from_col, to_col):
_weights = flex_select_1d_pc_nb(weights_, col)
if not np.isnan(group_weight) and not np.isnan(_weights):
group_weight += _weights
else:
group_weight = np.nan
break
if not np.isnan(group_weight):
group_weight /= group_lens[group]
if not np.isnan(group_weight) and not is_close_nb(group_weight, 1.0):
out[group] = group_weight * _init_cash
else:
out[group] = _init_cash
from_col = to_col
else:
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
cash_sum = 0.0
for col in range(from_col, to_col):
_init_cash = flex_select_1d_pc_nb(init_cash_raw, col)
_weights = flex_select_1d_pc_nb(weights_, col)
if not np.isnan(_weights) and not is_close_nb(_weights, 1.0):
cash_sum += _weights * _init_cash
else:
cash_sum += _init_cash
out[group] = cash_sum
from_col = to_col
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="cash_flow", axis=1),
arg_take_spec=dict(
cash_flow=ch.ArraySlicer(axis=1),
init_cash=base_ch.FlexArraySlicer(),
cash_deposits=base_ch.FlexArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cash_nb(
cash_flow: tp.Array2d,
init_cash: tp.FlexArray1d,
cash_deposits: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get cash series per column or group."""
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
out = np.full(cash_flow.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=cash_flow.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(cash_flow.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
cash_now = flex_select_1d_pc_nb(init_cash, col)
for i in range(_sim_start, _sim_end):
cash_now = add_nb(cash_now, flex_select_nb(cash_deposits_, i, col))
cash_now = add_nb(cash_now, cash_flow[i, col])
out[i, col] = cash_now
return out
# ############# Value ############# #
@register_jitted(cache=True)
def init_position_value_nb(
n_cols: int,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get initial position value per column."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
out = np.empty(n_cols, dtype=float_)
for col in range(n_cols):
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position == 0:
out[col] = 0.0
else:
out[col] = _init_position * _init_price
return out
@register_jitted(cache=True)
def init_position_value_grouped_nb(
group_lens: tp.GroupLens,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get initial position value per group."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
out = np.full(len(group_lens), 0.0, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
out[group] += _init_position * _init_price
return out
@register_jitted(cache=True)
def init_value_nb(init_position_value: tp.Array1d, init_cash: tp.FlexArray1d) -> tp.Array1d:
"""Get initial value per column or group."""
out = np.empty(len(init_position_value), dtype=float_)
for col in range(len(init_position_value)):
_init_cash = flex_select_1d_pc_nb(init_cash, col)
out[col] = _init_cash + init_position_value[col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
assets=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_value_nb(
close: tp.Array2d,
assets: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset value series per column."""
out = np.full(close.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=close.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(close.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if assets[i, col] == 0:
out[i, col] = 0.0
else:
out[i, col] = close[i, col] * assets[i, col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
asset_value=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_value_grouped_nb(
asset_value: tp.Array2d,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset value series per group."""
out = np.full((asset_value.shape[0], len(group_lens)), np.nan, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if np.isnan(out[i, group]):
out[i, group] = 0.0
out[i, group] += asset_value[i, col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="cash", axis=1),
arg_take_spec=dict(
cash=ch.ArraySlicer(axis=1),
asset_value=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def value_nb(
cash: tp.Array2d,
asset_value: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get value series per column or group."""
out = np.full(cash.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=cash.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(cash.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
out[i, col] = cash[i, col] + asset_value[i, col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="asset_value", axis=1),
arg_take_spec=dict(
asset_value=ch.ArraySlicer(axis=1),
value=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def gross_exposure_nb(
asset_value: tp.Array2d,
value: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get gross exposure series per column."""
out = np.full(asset_value.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(asset_value.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if value[i, col] == 0:
out[i, col] = np.nan
else:
out[i, col] = abs(asset_value[i, col] / value[i, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="long_exposure", axis=1),
arg_take_spec=dict(
long_exposure=ch.ArraySlicer(axis=1),
short_exposure=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def net_exposure_nb(
long_exposure: tp.Array2d,
short_exposure: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get net exposure series per column."""
out = np.full(long_exposure.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=long_exposure.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(long_exposure.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
out[i, col] = long_exposure[i, col] - short_exposure[i, col]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
asset_value=base_ch.array_gl_slicer,
value=ch.ArraySlicer(axis=1),
group_lens=ch.ArraySlicer(axis=0),
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def allocations_nb(
asset_value: tp.Array2d,
value: tp.Array2d,
group_lens: tp.GroupLens,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get allocations per column."""
out = np.full(asset_value.shape, np.nan, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
if value[i, group] == 0:
out[i, col] = np.nan
else:
out[i, col] = asset_value[i, col] / value[i, group]
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
close=ch.ArraySlicer(axis=1),
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
cash_earnings=base_ch.FlexArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def total_profit_nb(
target_shape: tp.Shape,
close: tp.Array2d,
order_records: tp.RecordArray,
col_map: tp.GroupMap,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_earnings: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get total profit per column.
A much faster version than the one based on `value_nb`."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
assets = np.full(target_shape[1], 0.0, dtype=float_)
cash = np.full(target_shape[1], 0.0, dtype=float_)
total_profit = np.full(target_shape[1], np.nan, dtype=float_)
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(target_shape[1]):
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
assets[col] = _init_position
cash[col] = -_init_position * _init_price
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
for i in range(_sim_start, _sim_end):
cash[col] += flex_select_nb(cash_earnings_, i, col)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
col_len = col_lens[col]
if col_len == 0:
if assets[col] == 0 and cash[col] == 0:
total_profit[col] = 0.0
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
# Fill assets
if order_record["side"] == OrderSide.Buy:
order_size = order_record["size"]
assets[col] = add_nb(assets[col], order_size)
else:
order_size = order_record["size"]
assets[col] = add_nb(assets[col], -order_size)
# Fill cash balance
if order_record["side"] == OrderSide.Buy:
order_cash = order_record["size"] * order_record["price"] + order_record["fees"]
cash[col] = add_nb(cash[col], -order_cash)
else:
order_cash = order_record["size"] * order_record["price"] - order_record["fees"]
cash[col] = add_nb(cash[col], order_cash)
total_profit[col] = cash[col] + assets[col] * close[_sim_end - 1, col]
return total_profit
@register_jitted(cache=True)
def total_profit_grouped_nb(total_profit: tp.Array1d, group_lens: tp.GroupLens) -> tp.Array1d:
"""Get total profit per group."""
out = np.empty(len(group_lens), dtype=float_)
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
out[group] = np.sum(total_profit[from_col:to_col])
from_col = to_col
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="value", axis=1),
arg_take_spec=dict(
value=ch.ArraySlicer(axis=1),
init_value=base_ch.FlexArraySlicer(),
cash_deposits=base_ch.FlexArraySlicer(axis=1),
cash_deposits_as_input=None,
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def returns_nb(
value: tp.Array2d,
init_value: tp.FlexArray1d,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_deposits_as_input: bool = False,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get return series per column or group."""
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
out = np.full(value.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(value.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
input_value = flex_select_1d_pc_nb(init_value, col)
for i in range(_sim_start, _sim_end):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
output_value = value[i, col]
if cash_deposits_as_input:
adj_input_value = input_value + _cash_deposits
out[i, col] = returns_nb_.get_return_nb(adj_input_value, output_value, log_returns=log_returns)
else:
adj_output_value = output_value - _cash_deposits
out[i, col] = returns_nb_.get_return_nb(input_value, adj_output_value, log_returns=log_returns)
input_value = output_value
return out
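# Worked example (illustration only) for `returns_nb` above: if value goes from 100 to 115
# while 10 is deposited during the bar, the default return is (115 - 10) / 100 - 1 = 5%,
# whereas with `cash_deposits_as_input=True` it becomes 115 / (100 + 10) - 1 ≈ 4.5%.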
@register_jitted(cache=True)
def get_asset_pnl_nb(
input_asset_value: float,
output_asset_value: float,
cash_flow: float,
) -> float:
"""Get asset PnL from the input and output asset value, and the cash flow."""
return output_asset_value + cash_flow - input_asset_value
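# Worked example (illustration only): a position worth 100 at the previous bar is fully sold
# for 105 during this bar, so the asset value drops to 0 and the cash flow is +105:
#
#     >>> get_asset_pnl_nb(100.0, 0.0, 105.0)  # 5.0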
@register_chunkable(
size=ch.ArraySizer(arg_query="asset_value", axis=1),
arg_take_spec=dict(
asset_value=ch.ArraySlicer(axis=1),
cash_flow=ch.ArraySlicer(axis=1),
init_position_value=base_ch.FlexArraySlicer(axis=0),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_pnl_nb(
asset_value: tp.Array2d,
cash_flow: tp.Array2d,
init_position_value: tp.FlexArray1dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset (realized and unrealized) PnL series per column or group."""
init_position_value_ = to_1d_array_nb(np.asarray(init_position_value))
out = np.full(asset_value.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(asset_value.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position_value = flex_select_1d_pc_nb(init_position_value_, col)
for i in range(_sim_start, _sim_end):
if i == _sim_start:
input_asset_value = _init_position_value
else:
input_asset_value = asset_value[i - 1, col]
out[i, col] = get_asset_pnl_nb(
input_asset_value,
asset_value[i, col],
cash_flow[i, col],
)
return out
@register_jitted(cache=True)
def get_asset_return_nb(
input_asset_value: float,
output_asset_value: float,
cash_flow: float,
log_returns: bool = False,
) -> float:
"""Get asset return from the input and output asset value, and the cash flow."""
if is_close_nb(input_asset_value, 0):
input_value = -output_asset_value
output_value = cash_flow
else:
input_value = input_asset_value
output_value = output_asset_value + cash_flow
if input_value < 0 and output_value < 0:
return_value = -returns_nb_.get_return_nb(-input_value, -output_value, log_returns=False)
else:
return_value = returns_nb_.get_return_nb(input_value, output_value, log_returns=False)
if log_returns:
return np.log1p(return_value)
return return_value
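# Worked example (illustration only): with no cash flow, the asset return reduces to the
# plain percentage change of the asset value:
#
#     >>> get_asset_return_nb(100.0, 110.0, 0.0)  # 0.1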
@register_chunkable(
size=ch.ArraySizer(arg_query="asset_value", axis=1),
arg_take_spec=dict(
asset_value=ch.ArraySlicer(axis=1),
cash_flow=ch.ArraySlicer(axis=1),
init_position_value=base_ch.FlexArraySlicer(axis=0),
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def asset_returns_nb(
asset_value: tp.Array2d,
cash_flow: tp.Array2d,
init_position_value: tp.FlexArray1dLike = 0.0,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get asset return series per column or group."""
init_position_value_ = to_1d_array_nb(np.asarray(init_position_value))
out = np.full(asset_value.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=asset_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(asset_value.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position_value = flex_select_1d_pc_nb(init_position_value_, col)
for i in range(_sim_start, _sim_end):
if i == _sim_start:
input_asset_value = _init_position_value
else:
input_asset_value = asset_value[i - 1, col]
out[i, col] = get_asset_return_nb(
input_asset_value,
asset_value[i, col],
cash_flow[i, col],
log_returns=log_returns,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="close", axis=1),
arg_take_spec=dict(
close=ch.ArraySlicer(axis=1),
init_value=base_ch.FlexArraySlicer(),
cash_deposits=base_ch.FlexArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def market_value_nb(
close: tp.Array2d,
init_value: tp.FlexArray1d,
cash_deposits: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get market value per column."""
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
out = np.full(close.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=close.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(close.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
curr_value = flex_select_1d_pc_nb(init_value, col)
for i in range(_sim_start, _sim_end):
if i > _sim_start:
curr_value *= close[i, col] / close[i - 1, col]
curr_value += flex_select_nb(cash_deposits_, i, col)
out[i, col] = curr_value
return out
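# Worked example (illustration only) for `market_value_nb` above: with no cash deposits, the
# market value just compounds the initial value by close-to-close returns, e.g. `init_value=100`
# over close prices [10, 11, 12.1] yields [100, 110, 121] for that column.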
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
close=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
init_value=base_ch.FlexArraySlicer(mapper=base_ch.group_lens_mapper),
cash_deposits=base_ch.FlexArraySlicer(axis=1, mapper=base_ch.group_lens_mapper),
sim_start=base_ch.flex_1d_array_gl_slicer,
sim_end=base_ch.flex_1d_array_gl_slicer,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def market_value_grouped_nb(
close: tp.Array2d,
group_lens: tp.GroupLens,
init_value: tp.FlexArray1d,
cash_deposits: tp.FlexArray2dLike = 0.0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get market value per group."""
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
out = np.full((close.shape[0], len(group_lens)), np.nan, dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=close.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
for col in range(from_col, to_col):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
curr_value = prev_value = flex_select_1d_pc_nb(init_value, col)
for i in range(_sim_start, _sim_end):
if i > _sim_start:
if not np.isnan(close[i - 1, col]):
prev_close = close[i - 1, col]
prev_value = prev_close
else:
prev_close = prev_value
if not np.isnan(close[i, col]):
curr_close = close[i, col]
prev_value = curr_close
else:
curr_close = prev_value
curr_value *= curr_close / prev_close
curr_value += flex_select_nb(cash_deposits_, i, col)
if np.isnan(out[i, group]):
out[i, group] = 0.0
out[i, group] += curr_value
return out
@register_jitted(cache=True)
def total_market_return_nb(
market_value: tp.Array2d,
input_value: tp.FlexArray1d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""Get total market return per column or group."""
out = np.full(market_value.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=market_value.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in range(market_value.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_input_value = flex_select_1d_pc_nb(input_value, col)
if _input_value != 0:
out[col] = (market_value[_sim_end - 1, col] - _input_value) / _input_value
return out
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Core Numba-compiled functions for portfolio simulation."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb, flex_select_nb
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.portfolio.enums import *
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils.math_ import is_close_nb, is_close_or_less_nb, is_close_or_greater_nb, is_less_nb, add_nb
@register_jitted(cache=True)
def order_not_filled_nb(status: int, status_info: int) -> OrderResult:
"""Return `OrderResult` for order that hasn't been filled."""
return OrderResult(size=np.nan, price=np.nan, fees=np.nan, side=-1, status=status, status_info=status_info)
@register_jitted(cache=True)
def check_adj_price_nb(
adj_price: float,
price_area: PriceArea,
is_closing_price: bool,
price_area_vio_mode: int,
) -> float:
"""Check whether adjusted price is within price boundaries."""
if price_area_vio_mode == PriceAreaVioMode.Ignore:
return adj_price
if adj_price > price_area.high:
if price_area_vio_mode == PriceAreaVioMode.Error:
raise ValueError("Adjusted order price is above the highest price")
elif price_area_vio_mode == PriceAreaVioMode.Cap:
adj_price = price_area.high
if adj_price < price_area.low:
if price_area_vio_mode == PriceAreaVioMode.Error:
raise ValueError("Adjusted order price is below the lowest price")
elif price_area_vio_mode == PriceAreaVioMode.Cap:
adj_price = price_area.low
if is_closing_price and adj_price != price_area.close:
if price_area_vio_mode == PriceAreaVioMode.Error:
raise ValueError("Adjusted order price is beyond the closing price")
elif price_area_vio_mode == PriceAreaVioMode.Cap:
adj_price = price_area.close
return adj_price
@register_jitted(cache=True)
def approx_long_buy_value_nb(val_price: float, size: float) -> float:
"""Approximate value of a long-buy operation.
    Positive value means spending cash (a sign convention used for sorting)."""
if size == 0:
return 0.0
order_value = abs(size) * val_price
add_free_cash = -order_value
return -add_free_cash
@register_jitted(cache=True)
def should_apply_size_granularity_nb(size: float, size_granularity: float) -> bool:
"""Whether to apply a size granularity to a size."""
if np.isnan(size_granularity):
return False
if size_granularity % 1 == 0:
return True
adj_size = size // size_granularity * size_granularity
return not is_close_nb(size, adj_size) and not is_close_nb(size, adj_size + size_granularity)
@register_jitted(cache=True)
def apply_size_granularity_nb(size: float, size_granularity: float) -> float:
"""Apply a size granularity to a size."""
return size // size_granularity * size_granularity
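# Worked example (illustration only): with a granularity of 0.01, a requested size of 1.2345
# is not on the grid and gets floored onto it:
#
#     >>> should_apply_size_granularity_nb(1.2345, 0.01)  # True
#     >>> apply_size_granularity_nb(1.2345, 0.01)         # ~1.23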
@register_jitted(cache=True)
def cast_account_state_nb(account_state: AccountState) -> AccountState:
"""Cast account state to float."""
return AccountState(
cash=float(account_state.cash),
position=float(account_state.position),
debt=float(account_state.debt),
locked_cash=float(account_state.locked_cash),
free_cash=float(account_state.free_cash),
)
@register_jitted(cache=True)
def long_buy_nb(
account_state: AccountState,
size: float,
price: float,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
leverage_mode: int = LeverageMode.Lazy,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Open or increase a long position."""
_account_state = cast_account_state_nb(account_state)
# Get cash limit
cash_limit = _account_state.free_cash
if not np.isnan(percent):
cash_limit = cash_limit * percent
if cash_limit <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoCash), _account_state
cash_limit = cash_limit * leverage
# Adjust for granularity
if should_apply_size_granularity_nb(size, size_granularity):
size = apply_size_granularity_nb(size, size_granularity)
# Adjust for max size
if not np.isnan(max_size) and size > max_size:
if not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state
size = max_size
if np.isinf(size) and np.isinf(cash_limit):
raise ValueError("Attempt to go in long direction infinitely")
# Get price adjusted with slippage
adj_price = price * (1 + slippage)
adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode)
# Get cash required to complete this order
if np.isinf(size):
req_cash = np.inf
req_fees = np.inf
else:
order_value = size * adj_price
req_fees = order_value * fees + fixed_fees
req_cash = order_value + req_fees
if is_close_or_less_nb(req_cash, cash_limit):
# Sufficient amount of cash
final_size = size
fees_paid = req_fees
else:
# Insufficient amount of cash, size will be less than requested
# With fees of 10% and $1 per transaction, you can buy for $90 (max_req_cash)
# to spend $100 (cash_limit) in total
max_req_cash = add_nb(cash_limit, -fixed_fees) / (1 + fees)
if max_req_cash <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
max_acq_size = max_req_cash / adj_price
# Adjust for granularity
if should_apply_size_granularity_nb(max_acq_size, size_granularity):
final_size = apply_size_granularity_nb(max_acq_size, size_granularity)
new_order_value = final_size * adj_price
fees_paid = new_order_value * fees + fixed_fees
req_cash = new_order_value + fees_paid
else:
final_size = max_acq_size
fees_paid = cash_limit - max_req_cash
req_cash = cash_limit
# Check against size of zero
if is_close_nb(final_size, 0):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state
# Check against minimum size
if not np.isnan(min_size) and is_less_nb(final_size, min_size):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state
# Check against partial fill (np.inf doesn't count)
if np.isfinite(size) and is_less_nb(final_size, size) and not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state
# Create a filled order
order_result = OrderResult(
float(final_size),
float(adj_price),
float(fees_paid),
OrderSide.Buy,
OrderStatus.Filled,
-1,
)
# Update the current account state
new_cash = add_nb(_account_state.cash, -req_cash)
new_position = add_nb(_account_state.position, final_size)
if leverage_mode == LeverageMode.Lazy:
debt_diff = max(add_nb(req_cash, -_account_state.free_cash), 0.0)
if debt_diff > 0:
new_debt = _account_state.debt + debt_diff
new_locked_cash = _account_state.locked_cash + _account_state.free_cash
new_free_cash = 0.0
else:
new_debt = _account_state.debt
new_locked_cash = _account_state.locked_cash
new_free_cash = add_nb(_account_state.free_cash, -req_cash)
else:
if leverage > 1:
if np.isinf(leverage):
raise ValueError("Leverage must be finite for LeverageMode.Eager")
order_value = final_size * adj_price
new_debt = _account_state.debt + order_value * (leverage - 1) / leverage
new_locked_cash = _account_state.locked_cash + order_value / leverage
new_free_cash = add_nb(_account_state.free_cash, -order_value / leverage - fees_paid)
else:
new_debt = _account_state.debt
new_locked_cash = _account_state.locked_cash
new_free_cash = add_nb(_account_state.free_cash, -req_cash)
new_account_state = AccountState(
cash=float(new_cash),
position=float(new_position),
debt=float(new_debt),
locked_cash=float(new_locked_cash),
free_cash=float(new_free_cash),
)
return order_result, new_account_state
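# Usage sketch (illustrative, not part of the original module; assumes the plain AccountState
# namedtuple from vectorbtpro.portfolio.enums; values approximate up to floating-point error):
# buying 2 units at 10.0 with 1% fees from a 100.0 cash account.
#
#   state = AccountState(cash=100.0, position=0.0, debt=0.0, locked_cash=0.0, free_cash=100.0)
#   result, new_state = long_buy_nb(state, size=2.0, price=10.0, fees=0.01)
#   # result: size=2.0, price=10.0, fees=0.2, side=OrderSide.Buy, status=OrderStatus.Filled
#   # new_state: cash=79.8, position=2.0, debt=0.0, locked_cash=0.0, free_cash=79.8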
@register_jitted(cache=True)
def approx_long_sell_value_nb(position: float, debt: float, val_price: float, size: float) -> float:
"""Approximate value of a long-sell operation.
Positive value means spending (for sorting reasons)."""
if size == 0 or position == 0:
return 0.0
size_limit = min(abs(size), position)
order_value = size_limit * val_price
size_fraction = size_limit / position
released_debt = size_fraction * debt
add_free_cash = order_value - released_debt
return -add_free_cash
@register_jitted(cache=True)
def long_sell_nb(
account_state: AccountState,
size: float,
price: float,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Decrease or close a long position."""
_account_state = cast_account_state_nb(account_state)
# Check for open position
if _account_state.position == 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoOpenPosition), _account_state
# Get size limit
size_limit = min(size, _account_state.position)
if not np.isnan(percent):
size_limit = size_limit * percent
# Adjust for granularity
if should_apply_size_granularity_nb(size_limit, size_granularity):
size = apply_size_granularity_nb(size, size_granularity)
size_limit = apply_size_granularity_nb(size_limit, size_granularity)
# Adjust for max size
if not np.isnan(max_size) and size_limit > max_size:
if not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state
size_limit = max_size
# Check against size of zero
if is_close_nb(size_limit, 0):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state
# Check against minimum size
if not np.isnan(min_size) and is_less_nb(size_limit, min_size):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state
# Check against partial fill
if np.isfinite(size) and is_less_nb(size_limit, size) and not allow_partial: # np.inf doesn't count
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state
# Get price adjusted with slippage
adj_price = price * (1 - slippage)
adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode)
# Get acquired cash
acq_cash = size_limit * adj_price
# Update fees
fees_paid = acq_cash * fees + fixed_fees
# Get final cash by subtracting costs
final_acq_cash = add_nb(acq_cash, -fees_paid)
if final_acq_cash < 0 and is_less_nb(_account_state.free_cash, -final_acq_cash):
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
# Create a filled order
order_result = OrderResult(
float(size_limit),
float(adj_price),
float(fees_paid),
OrderSide.Sell,
OrderStatus.Filled,
-1,
)
# Update the current account state
new_cash = _account_state.cash + final_acq_cash
new_position = add_nb(_account_state.position, -size_limit)
new_pos_fraction = abs(new_position) / abs(_account_state.position)
new_debt = new_pos_fraction * _account_state.debt
new_locked_cash = new_pos_fraction * _account_state.locked_cash
size_fraction = size_limit / _account_state.position
released_debt = size_fraction * _account_state.debt
new_free_cash = add_nb(_account_state.free_cash, final_acq_cash - released_debt)
new_account_state = AccountState(
cash=float(new_cash),
position=float(new_position),
debt=float(new_debt),
locked_cash=float(new_locked_cash),
free_cash=float(new_free_cash),
)
return order_result, new_account_state
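# Usage sketch (illustrative, not part of the original module): passing np.inf as size sells
# the whole long position; fees are deducted from the proceeds.
#
#   state = AccountState(cash=79.8, position=2.0, debt=0.0, locked_cash=0.0, free_cash=79.8)
#   result, new_state = long_sell_nb(state, size=np.inf, price=11.0, fees=0.01)
#   # result: size=2.0, price=11.0, fees=0.22, side=OrderSide.Sell, status=OrderStatus.Filled
#   # new_state: cash=101.58, position=0.0, free_cash=101.58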
@register_jitted(cache=True)
def approx_short_sell_value_nb(val_price: float, size: float) -> float:
"""Approximate value of a short-sell operation.
Positive value means spending (for sorting reasons)."""
if size == 0:
return 0.0
order_value = abs(size) * val_price
add_free_cash = -order_value
return -add_free_cash
@register_jitted(cache=True)
def short_sell_nb(
account_state: AccountState,
size: float,
price: float,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Open or increase a short position."""
_account_state = cast_account_state_nb(account_state)
# Get cash limit
cash_limit = _account_state.free_cash
if not np.isnan(percent):
cash_limit = cash_limit * percent
if cash_limit <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoCash), _account_state
cash_limit = cash_limit * leverage
# Get price adjusted with slippage
adj_price = price * (1 - slippage)
adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode)
# Get size limit
fees_adj_price = adj_price * (1 + fees)
if fees_adj_price == 0:
max_size_limit = np.inf
else:
max_size_limit = add_nb(cash_limit, -fixed_fees) / (adj_price * (1 + fees))
size_limit = min(size, max_size_limit)
if size_limit <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
# Adjust for granularity
if should_apply_size_granularity_nb(size_limit, size_granularity):
size = apply_size_granularity_nb(size, size_granularity)
size_limit = apply_size_granularity_nb(size_limit, size_granularity)
# Adjust for max size
if not np.isnan(max_size) and size_limit > max_size:
if not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state
size_limit = max_size
if np.isinf(size_limit):
raise ValueError("Attempt to go in short direction infinitely")
# Check against size of zero
if is_close_nb(size_limit, 0):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state
# Check against minimum size
if not np.isnan(min_size) and is_less_nb(size_limit, min_size):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state
# Check against partial fill
if np.isfinite(size) and is_less_nb(size_limit, size) and not allow_partial: # np.inf doesn't count
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state
# Get acquired cash
order_value = size_limit * adj_price
# Update fees
fees_paid = order_value * fees + fixed_fees
# Get final cash by subtracting costs
final_acq_cash = add_nb(order_value, -fees_paid)
if final_acq_cash < 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
# Create a filled order
order_result = OrderResult(
float(size_limit),
float(adj_price),
float(fees_paid),
OrderSide.Sell,
OrderStatus.Filled,
-1,
)
# Update the current account state
new_cash = _account_state.cash + final_acq_cash
new_position = _account_state.position - size_limit
new_debt = _account_state.debt + order_value
if np.isinf(leverage):
if np.isinf(_account_state.free_cash):
raise ValueError("Leverage must be finite when _account_state.free_cash is infinite")
if is_close_or_less_nb(_account_state.free_cash, fees_paid):
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
leverage_ = order_value / (_account_state.free_cash - fees_paid)
else:
leverage_ = float(leverage)
new_locked_cash = _account_state.locked_cash + order_value / leverage_
new_free_cash = add_nb(_account_state.free_cash, -order_value / leverage_ - fees_paid)
new_account_state = AccountState(
cash=float(new_cash),
position=float(new_position),
debt=float(new_debt),
locked_cash=float(new_locked_cash),
free_cash=float(new_free_cash),
)
return order_result, new_account_state
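# Usage sketch (illustrative, not part of the original module): shorting 2 units at 10.0 with
# a leverage of 1 records the short proceeds as debt and locks an equal amount of cash as
# collateral.
#
#   state = AccountState(cash=100.0, position=0.0, debt=0.0, locked_cash=0.0, free_cash=100.0)
#   result, new_state = short_sell_nb(state, size=2.0, price=10.0)
#   # result: size=2.0, price=10.0, fees=0.0, side=OrderSide.Sell, status=OrderStatus.Filled
#   # new_state: cash=120.0, position=-2.0, debt=20.0, locked_cash=20.0, free_cash=80.0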
@register_jitted(cache=True)
def approx_short_buy_value_nb(position: float, debt: float, locked_cash: float, val_price: float, size: float) -> float:
"""Approximate value of a short-buy operation.
Positive value means spending (for sorting reasons)."""
if size == 0 or position == 0:
return 0.0
size_limit = min(abs(size), abs(position))
order_value = size_limit * val_price
size_fraction = size_limit / abs(position)
released_debt = size_fraction * debt
released_cash = size_fraction * locked_cash
add_free_cash = released_cash + released_debt - order_value
return -add_free_cash
@register_jitted(cache=True)
def short_buy_nb(
account_state: AccountState,
size: float,
price: float,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Decrease or close a short position."""
_account_state = cast_account_state_nb(account_state)
# Check for open position
if _account_state.position == 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoOpenPosition), _account_state
# Get cash limit
cash_limit = _account_state.free_cash + _account_state.debt + _account_state.locked_cash
if cash_limit <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.NoCash), _account_state
# Get size limit
size_limit = min(size, abs(_account_state.position))
if not np.isnan(percent):
size_limit = size_limit * percent
# Adjust for granularity
if should_apply_size_granularity_nb(size_limit, size_granularity):
size_limit = apply_size_granularity_nb(size_limit, size_granularity)
# Adjust for max size
if not np.isnan(max_size) and size_limit > max_size:
if not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.MaxSizeExceeded), _account_state
size_limit = max_size
# Get price adjusted with slippage
adj_price = price * (1 + slippage)
adj_price = check_adj_price_nb(adj_price, price_area, is_closing_price, price_area_vio_mode)
# Get cash required to complete this order
if np.isinf(size_limit):
req_cash = np.inf
req_fees = np.inf
else:
order_value = size_limit * adj_price
req_fees = order_value * fees + fixed_fees
req_cash = order_value + req_fees
if is_close_or_less_nb(req_cash, cash_limit):
# Sufficient amount of cash
final_size = size_limit
fees_paid = req_fees
else:
# Insufficient amount of cash, size will be less than requested
# With fees of 10% and $1 per transaction, you can buy for $90 (max_req_cash)
# to spend $100 (cash_limit) in total
max_req_cash = add_nb(cash_limit, -fixed_fees) / (1 + fees)
if max_req_cash <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.CantCoverFees), _account_state
max_acq_size = max_req_cash / adj_price
# Adjust for granularity
if should_apply_size_granularity_nb(max_acq_size, size_granularity):
final_size = apply_size_granularity_nb(max_acq_size, size_granularity)
new_order_value = final_size * adj_price
fees_paid = new_order_value * fees + fixed_fees
req_cash = new_order_value + fees_paid
else:
final_size = max_acq_size
fees_paid = cash_limit - max_req_cash
req_cash = cash_limit
# Check size of zero
if is_close_nb(final_size, 0):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeZero), _account_state
# Check against minimum size
if not np.isnan(min_size) and is_less_nb(final_size, min_size):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), _account_state
# Check against partial fill (np.inf doesn't count)
if np.isfinite(size_limit) and is_less_nb(final_size, size_limit) and not allow_partial:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.PartialFill), _account_state
# Create a filled order
order_result = OrderResult(
float(final_size),
float(adj_price),
float(fees_paid),
OrderSide.Buy,
OrderStatus.Filled,
-1,
)
# Update the current account state
new_cash = add_nb(_account_state.cash, -req_cash)
new_position = add_nb(_account_state.position, final_size)
new_pos_fraction = abs(new_position) / abs(_account_state.position)
new_debt = new_pos_fraction * _account_state.debt
new_locked_cash = new_pos_fraction * _account_state.locked_cash
size_fraction = final_size / abs(_account_state.position)
released_debt = size_fraction * _account_state.debt
released_cash = size_fraction * _account_state.locked_cash
new_free_cash = add_nb(_account_state.free_cash, released_cash + released_debt - req_cash)
new_account_state = AccountState(
cash=float(new_cash),
position=float(new_position),
debt=float(new_debt),
locked_cash=float(new_locked_cash),
free_cash=float(new_free_cash),
)
return order_result, new_account_state
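# Usage sketch (illustrative, not part of the original module): buying back the short from the
# example above at 9.0 releases the debt and the locked collateral, realizing a profit of 2.0
# on the position.
#
#   state = AccountState(cash=120.0, position=-2.0, debt=20.0, locked_cash=20.0, free_cash=80.0)
#   result, new_state = short_buy_nb(state, size=np.inf, price=9.0)
#   # result: size=2.0, price=9.0, fees=0.0, side=OrderSide.Buy, status=OrderStatus.Filled
#   # new_state: cash=102.0, position=0.0, debt=0.0, locked_cash=0.0, free_cash=102.0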
@register_jitted(cache=True)
def approx_buy_value_nb(
position: float,
debt: float,
locked_cash: float,
val_price: float,
size: float,
direction: int,
) -> float:
"""Approximate value of a buy operation.
Positive value means spending (for sorting reasons)."""
if position <= 0 and direction == Direction.ShortOnly:
return approx_short_buy_value_nb(position, debt, locked_cash, val_price, size)
if position >= 0:
return approx_long_buy_value_nb(val_price, size)
value1 = approx_short_buy_value_nb(position, debt, locked_cash, val_price, size)
new_size = add_nb(size, -abs(position))
if new_size <= 0:
return value1
value2 = approx_long_buy_value_nb(val_price, new_size)
return value1 + value2
@register_jitted(cache=True)
def buy_nb(
account_state: AccountState,
size: float,
price: float,
direction: int = Direction.Both,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
leverage_mode: int = LeverageMode.Lazy,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Buy."""
_account_state = cast_account_state_nb(account_state)
if _account_state.position <= 0 and direction == Direction.ShortOnly:
return short_buy_nb(
account_state=_account_state,
size=size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size,
max_size=max_size,
size_granularity=size_granularity,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
if _account_state.position >= 0:
return long_buy_nb(
account_state=_account_state,
size=size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size,
max_size=max_size,
size_granularity=size_granularity,
leverage=leverage,
leverage_mode=leverage_mode,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
short_size = min(size, abs(_account_state.position))
if not np.isnan(min_size):
min_size1 = min(min_size, abs(_account_state.position))
else:
min_size1 = np.nan
if not np.isnan(max_size):
max_size1 = min(max_size, abs(_account_state.position))
else:
max_size1 = np.nan
new_order_result1, new_account_state1 = short_buy_nb(
account_state=_account_state,
size=short_size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size1,
max_size=max_size1,
size_granularity=size_granularity,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=np.nan,
price_area=price_area,
is_closing_price=is_closing_price,
)
if new_order_result1.status != OrderStatus.Filled:
return new_order_result1, _account_state
if new_account_state1.position != 0:
return new_order_result1, new_account_state1
new_size = add_nb(size, -abs(_account_state.position))
if new_size <= 0:
return new_order_result1, new_account_state1
if not np.isnan(min_size):
min_size2 = max(min_size - abs(_account_state.position), 0.0)
else:
min_size2 = np.nan
if not np.isnan(max_size):
max_size2 = max(max_size - abs(_account_state.position), 0.0)
else:
max_size2 = np.nan
new_order_result2, new_account_state2 = long_buy_nb(
account_state=new_account_state1,
size=new_size,
price=price,
fees=fees,
fixed_fees=0.0,
slippage=slippage,
min_size=min_size2,
max_size=max_size2,
size_granularity=size_granularity,
leverage=leverage,
leverage_mode=leverage_mode,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
if new_order_result2.status != OrderStatus.Filled:
if allow_partial or np.isinf(new_size):
if new_order_result2.status_info == OrderStatusInfo.SizeZero:
return new_order_result1, new_account_state1
if new_order_result2.status_info == OrderStatusInfo.NoCash:
return new_order_result1, new_account_state1
return new_order_result2, _account_state
new_order_result = OrderResult(
new_order_result1.size + new_order_result2.size,
new_order_result2.price,
new_order_result1.fees + new_order_result2.fees,
new_order_result2.side,
new_order_result2.status,
new_order_result2.status_info,
)
return new_order_result, new_account_state2
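# Usage sketch (illustrative, not part of the original module): with Direction.Both and an open
# short position, buy_nb first covers the short via short_buy_nb and then opens a long via
# long_buy_nb, merging both fills into a single OrderResult.
#
#   state = AccountState(cash=120.0, position=-2.0, debt=20.0, locked_cash=20.0, free_cash=80.0)
#   result, new_state = buy_nb(state, size=5.0, price=10.0, direction=Direction.Both)
#   # result: size=5.0 (2.0 to cover + 3.0 new long), price=10.0, side=OrderSide.Buy
#   # new_state: cash=70.0, position=3.0, debt=0.0, locked_cash=0.0, free_cash=70.0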
@register_jitted(cache=True)
def approx_sell_value_nb(
position: float,
debt: float,
val_price: float,
size: float,
direction: int,
) -> float:
"""Approximate value of a sell operation.
Positive value means spending (for sorting reasons)."""
if position >= 0 and direction == Direction.LongOnly:
return approx_long_sell_value_nb(position, debt, val_price, size)
if position <= 0:
return approx_short_sell_value_nb(val_price, size)
value1 = approx_long_sell_value_nb(position, debt, val_price, size)
new_size = add_nb(size, -abs(position))
if new_size <= 0:
return value1
value2 = approx_short_sell_value_nb(val_price, new_size)
return value1 + value2
@register_jitted(cache=True)
def sell_nb(
account_state: AccountState,
size: float,
price: float,
direction: int = Direction.Both,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
percent: float = np.nan,
price_area: PriceArea = NoPriceArea,
is_closing_price: bool = False,
) -> tp.Tuple[OrderResult, AccountState]:
"""Sell."""
_account_state = cast_account_state_nb(account_state)
if _account_state.position >= 0 and direction == Direction.LongOnly:
return long_sell_nb(
account_state=_account_state,
size=size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size,
max_size=max_size,
size_granularity=size_granularity,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
if _account_state.position <= 0:
return short_sell_nb(
account_state=_account_state,
size=size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size,
max_size=max_size,
size_granularity=size_granularity,
leverage=leverage,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
long_size = min(size, _account_state.position)
if not np.isnan(min_size):
min_size1 = min(min_size, _account_state.position)
else:
min_size1 = np.nan
if not np.isnan(max_size):
max_size1 = min(max_size, _account_state.position)
else:
max_size1 = np.nan
new_order_result1, new_account_state1 = long_sell_nb(
account_state=_account_state,
size=long_size,
price=price,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size1,
max_size=max_size1,
size_granularity=size_granularity,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=np.nan,
price_area=price_area,
is_closing_price=is_closing_price,
)
if new_order_result1.status != OrderStatus.Filled:
return new_order_result1, _account_state
if new_account_state1.position != 0:
return new_order_result1, new_account_state1
new_size = add_nb(size, -abs(_account_state.position))
if new_size <= 0:
return new_order_result1, new_account_state1
if not np.isnan(min_size):
min_size2 = max(min_size - _account_state.position, 0.0)
else:
min_size2 = np.nan
if not np.isnan(max_size):
max_size2 = max(max_size - _account_state.position, 0.0)
else:
max_size2 = np.nan
new_order_result2, new_account_state2 = short_sell_nb(
account_state=new_account_state1,
size=new_size,
price=price,
fees=fees,
fixed_fees=0.0,
slippage=slippage,
min_size=min_size2,
max_size=max_size2,
size_granularity=size_granularity,
leverage=leverage,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
if new_order_result2.status != OrderStatus.Filled:
if allow_partial or np.isinf(new_size):
if new_order_result2.status_info == OrderStatusInfo.SizeZero:
return new_order_result1, new_account_state1
if new_order_result2.status_info == OrderStatusInfo.NoCash:
return new_order_result1, new_account_state1
return new_order_result2, _account_state
new_order_result = OrderResult(
new_order_result1.size + new_order_result2.size,
new_order_result2.price,
new_order_result1.fees + new_order_result2.fees,
new_order_result2.side,
new_order_result2.status,
new_order_result2.status_info,
)
return new_order_result, new_account_state2
@register_jitted(cache=True)
def update_value_nb(
cash_before: float,
cash_now: float,
position_before: float,
position_now: float,
val_price_before: float,
val_price_now: float,
value_before: float,
) -> float:
"""Update valuation price and value."""
cash_flow = cash_now - cash_before
if position_before != 0:
asset_value_before = position_before * val_price_before
else:
asset_value_before = 0.0
if position_now != 0:
asset_value_now = position_now * val_price_now
else:
asset_value_now = 0.0
asset_value_diff = asset_value_now - asset_value_before
value_now = value_before + cash_flow + asset_value_diff
return value_now
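# Usage sketch (illustrative, not part of the original module): after the long buy above, the
# cash outflow of 20.2 (including 0.2 of fees) and the new asset value of 20.0 net out to a
# drop in portfolio value equal to the fees paid.
#
#   update_value_nb(100.0, 79.8, 0.0, 2.0, 10.0, 10.0, 100.0)  # -> 99.8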
@register_jitted(cache=True)
def get_diraware_size_nb(size: float, direction: int) -> float:
"""Get direction-aware size."""
if direction == Direction.ShortOnly:
return size * -1
return size
@register_jitted(cache=True)
def resolve_size_nb(
size: float,
size_type: int,
position: float,
val_price: float,
value: float,
target_size_type: int = SizeType.Amount,
as_requirement: bool = False,
) -> tp.Tuple[float, float]:
"""Resolve size into an absolute amount of assets and a percentage of resources.
The percentage is only set if `SizeType.Percent` or `SizeType.Percent100` is used."""
percent = np.nan
if size_type == target_size_type:
return float(size), percent
if size_type == SizeType.ValuePercent100:
if size_type == target_size_type:
return float(size), percent
size /= 100
size_type = SizeType.ValuePercent
if size_type == SizeType.TargetPercent100:
if size_type == target_size_type:
return float(size), percent
size /= 100
size_type = SizeType.TargetPercent
if size_type == SizeType.ValuePercent or size_type == SizeType.TargetPercent:
if size_type == target_size_type:
return float(size), percent
size *= value
if size_type == SizeType.ValuePercent:
size_type = SizeType.Value
else:
size_type = SizeType.TargetValue
if size_type == SizeType.Value or size_type == SizeType.TargetValue:
if size_type == target_size_type:
return float(size), percent
size /= val_price
if size_type == SizeType.Value:
size_type = SizeType.Amount
else:
size_type = SizeType.TargetAmount
if size_type == SizeType.TargetAmount:
if size_type == target_size_type:
return float(size), percent
if not as_requirement:
size -= position
size_type = SizeType.Amount
if size_type == SizeType.Percent100:
if size_type == target_size_type:
return float(size), percent
size /= 100
size_type = SizeType.Percent
if size_type == SizeType.Percent:
if size_type == target_size_type:
return float(size), percent
percent = abs(size)
if as_requirement:
size = np.nan
else:
size = np.sign(size) * np.inf
size_type = SizeType.Amount
if size_type != target_size_type:
raise ValueError("Cannot convert size to target size type")
if as_requirement:
size = abs(size)
return float(size), percent
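# Usage sketch (illustrative, not part of the original module): a target of 50% of a 1000.0
# group value at a valuation price of 10.0 and no existing position resolves to an absolute
# order of 50 units (the percentage slot stays NaN, since it is reserved for
# SizeType.Percent/Percent100).
#
#   resolve_size_nb(size=0.5, size_type=SizeType.TargetPercent, position=0.0,
#                   val_price=10.0, value=1000.0)
#   # -> (50.0, nan)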
@register_jitted(cache=True)
def approx_order_value_nb(
exec_state: ExecState,
size: float,
size_type: int = SizeType.Amount,
direction: int = Direction.Both,
) -> float:
"""Approximate the value of an order.
Assumes that cash is infinite.
Positive value means spending (for sorting reasons)."""
size = get_diraware_size_nb(float(size), direction)
amount_size, _ = resolve_size_nb(
size=size,
size_type=size_type,
position=exec_state.position,
val_price=exec_state.val_price,
value=exec_state.value,
)
if amount_size >= 0:
order_value = approx_buy_value_nb(
position=exec_state.position,
debt=exec_state.debt,
locked_cash=exec_state.locked_cash,
val_price=exec_state.val_price,
size=abs(amount_size),
direction=direction,
)
else:
order_value = approx_sell_value_nb(
position=exec_state.position,
debt=exec_state.debt,
val_price=exec_state.val_price,
size=abs(amount_size),
direction=direction,
)
return order_value
@register_jitted(cache=True)
def execute_order_nb(
exec_state: ExecState,
order: Order,
price_area: PriceArea = NoPriceArea,
update_value: bool = False,
) -> tp.Tuple[OrderResult, ExecState]:
"""Execute an order given the current state.
Args:
exec_state (ExecState): See `vectorbtpro.portfolio.enums.ExecState`.
order (Order): See `vectorbtpro.portfolio.enums.Order`.
price_area (PriceArea): See `vectorbtpro.portfolio.enums.PriceArea`.
update_value (bool): Whether to update the value.
An error is raised if an input has a value that is not expected.
An order is ignored if its execution has no effect on the current balance.
An order is rejected if an input goes over a limit or against a restriction.
"""
# numerical stability
cash = float(exec_state.cash)
if is_close_nb(cash, 0):
cash = 0.0
position = float(exec_state.position)
if is_close_nb(position, 0):
position = 0.0
debt = float(exec_state.debt)
if is_close_nb(debt, 0):
debt = 0.0
locked_cash = float(exec_state.locked_cash)
if is_close_nb(locked_cash, 0):
locked_cash = 0.0
free_cash = float(exec_state.free_cash)
if is_close_nb(free_cash, 0):
free_cash = 0.0
val_price = float(exec_state.val_price)
if is_close_nb(val_price, 0):
val_price = 0.0
value = float(exec_state.value)
if is_close_nb(value, 0):
value = 0.0
# Pre-fill account state
account_state = AccountState(
cash=cash,
position=position,
debt=debt,
locked_cash=locked_cash,
free_cash=free_cash,
)
# Check price area
if np.isinf(price_area.open) or price_area.open < 0:
raise ValueError("price_area.open must be either NaN, or finite and 0 or greater")
if np.isinf(price_area.high) or price_area.high < 0:
raise ValueError("price_area.high must be either NaN, or finite and 0 or greater")
if np.isinf(price_area.low) or price_area.low < 0:
raise ValueError("price_area.low must be either NaN, or finite and 0 or greater")
if np.isinf(price_area.close) or price_area.close < 0:
raise ValueError("price_area.close must be either NaN, or finite and 0 or greater")
# Resolve price
order_price = order.price
is_closing_price = False
if np.isinf(order_price):
if order_price > 0:
order_price = price_area.close
is_closing_price = True
else:
order_price = price_area.open
elif order_price == PriceType.NextOpen:
raise ValueError("Next open must be handled higher in the stack")
elif order_price == PriceType.NextClose:
raise ValueError("Next close must be handled higher in the stack")
# Ignore order if size or price is nan
if np.isnan(order.size):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.SizeNaN), exec_state
if np.isnan(order_price):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.PriceNaN), exec_state
# Check account state
if np.isnan(cash):
raise ValueError("exec_state.cash cannot be NaN")
if not np.isfinite(position):
raise ValueError("exec_state.position must be finite")
if not np.isfinite(debt) or debt < 0:
raise ValueError("exec_state.debt must be finite and 0 or greater")
if not np.isfinite(locked_cash) or locked_cash < 0:
raise ValueError("exec_state.locked_cash must be finite and 0 or greater")
if np.isnan(free_cash):
raise ValueError("exec_state.free_cash cannot be NaN")
# Check order
if not np.isfinite(order_price) or order_price < 0:
raise ValueError("order.price must be finite and 0 or greater")
if order.size_type < 0 or order.size_type >= len(SizeType):
raise ValueError("order.size_type is invalid")
if order.direction < 0 or order.direction >= len(Direction):
raise ValueError("order.direction is invalid")
if not np.isfinite(order.fees):
raise ValueError("order.fees must be finite")
if not np.isfinite(order.fixed_fees):
raise ValueError("order.fixed_fees must be finite")
if not np.isfinite(order.slippage) or order.slippage < 0:
raise ValueError("order.slippage must be finite and 0 or greater")
if np.isinf(order.min_size) or order.min_size < 0:
raise ValueError("order.min_size must be either NaN, 0, or greater")
if order.max_size <= 0:
raise ValueError("order.max_size must be either NaN or greater than 0")
if np.isinf(order.size_granularity) or order.size_granularity <= 0:
raise ValueError("order.size_granularity must be either NaN, or finite and greater than 0")
if np.isnan(order.leverage) or order.leverage <= 0:
raise ValueError("order.leverage must be greater than 0")
if order.leverage_mode < 0 or order.leverage_mode >= len(LeverageMode):
raise ValueError("order.leverage_mode is invalid")
if not np.isfinite(order.reject_prob) or order.reject_prob < 0 or order.reject_prob > 1:
raise ValueError("order.reject_prob must be between 0 and 1")
# Positive/negative size in short direction should be treated as negative/positive
order_size = get_diraware_size_nb(order.size, order.direction)
min_order_size = order.min_size
max_order_size = order.max_size
order_size_type = order.size_type
if (
order_size_type == SizeType.ValuePercent100
or order_size_type == SizeType.ValuePercent
or order_size_type == SizeType.TargetPercent100
or order_size_type == SizeType.TargetPercent
or order_size_type == SizeType.Value
or order_size_type == SizeType.TargetValue
):
if np.isinf(val_price) or val_price <= 0:
raise ValueError("val_price_now must be finite and greater than 0")
if np.isnan(val_price):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.ValPriceNaN), exec_state
if (
order_size_type == SizeType.ValuePercent100
or order_size_type == SizeType.ValuePercent
or order_size_type == SizeType.TargetPercent100
or order_size_type == SizeType.TargetPercent
):
if np.isnan(value):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.ValueNaN), exec_state
if value <= 0:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.ValueZeroNeg), exec_state
order_size, percent = resolve_size_nb(
size=order_size,
size_type=order_size_type,
position=position,
val_price=val_price,
value=value,
)
if not np.isnan(min_order_size):
min_order_size, min_percent = resolve_size_nb(
size=min_order_size,
size_type=order_size_type,
position=position,
val_price=val_price,
value=value,
as_requirement=True,
)
if not np.isnan(percent) and not np.isnan(min_percent) and is_less_nb(percent, min_percent):
return order_not_filled_nb(OrderStatus.Ignored, OrderStatusInfo.MinSizeNotReached), exec_state
if not np.isnan(max_order_size):
max_order_size, max_percent = resolve_size_nb(
size=max_order_size,
size_type=order_size_type,
position=position,
val_price=val_price,
value=value,
as_requirement=True,
)
if not np.isnan(percent) and not np.isnan(max_percent) and is_less_nb(max_percent, percent):
percent = max_percent
if order_size >= 0:
order_result, new_account_state = buy_nb(
account_state=account_state,
size=order_size,
price=order_price,
direction=order.direction,
fees=order.fees,
fixed_fees=order.fixed_fees,
slippage=order.slippage,
min_size=min_order_size,
max_size=max_order_size,
size_granularity=order.size_granularity,
leverage=order.leverage,
leverage_mode=order.leverage_mode,
price_area_vio_mode=order.price_area_vio_mode,
allow_partial=order.allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
else:
order_result, new_account_state = sell_nb(
account_state=account_state,
size=-order_size,
price=order_price,
direction=order.direction,
fees=order.fees,
fixed_fees=order.fixed_fees,
slippage=order.slippage,
min_size=min_order_size,
max_size=max_order_size,
size_granularity=order.size_granularity,
leverage=order.leverage,
price_area_vio_mode=order.price_area_vio_mode,
allow_partial=order.allow_partial,
percent=percent,
price_area=price_area,
is_closing_price=is_closing_price,
)
if order.reject_prob > 0:
if np.random.uniform(0, 1) < order.reject_prob:
return order_not_filled_nb(OrderStatus.Rejected, OrderStatusInfo.RandomEvent), exec_state
if order_result.status == OrderStatus.Rejected and order.raise_reject:
raise_rejected_order_nb(order_result)
is_filled = order_result.status == OrderStatus.Filled
if is_filled and update_value:
new_val_price = order_result.price
new_value = update_value_nb(
cash,
new_account_state.cash,
position,
new_account_state.position,
val_price,
order_result.price,
value,
)
else:
new_val_price = val_price
new_value = value
new_exec_state = ExecState(
cash=new_account_state.cash,
position=new_account_state.position,
debt=new_account_state.debt,
locked_cash=new_account_state.locked_cash,
free_cash=new_account_state.free_cash,
val_price=new_val_price,
value=new_value,
)
return order_result, new_exec_state
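# Usage sketch (illustrative, not part of the original module; assumes the plain ExecState,
# PriceArea, and Order namedtuples from vectorbtpro.portfolio.enums): targeting 50% of a 100.0
# portfolio at a fixed price of 10.0 resolves to buying 5 units.
#
#   exec_state = ExecState(cash=100.0, position=0.0, debt=0.0, locked_cash=0.0,
#                          free_cash=100.0, val_price=10.0, value=100.0)
#   order = order_nb(size=0.5, size_type=SizeType.TargetPercent, price=10.0)
#   area = PriceArea(open=9.5, high=10.5, low=9.0, close=10.0)
#   result, new_exec_state = execute_order_nb(exec_state, order, price_area=area)
#   # result: size=5.0, price=10.0, fees=0.0, side=OrderSide.Buy, status=OrderStatus.Filled
#   # new_exec_state: cash=50.0, position=5.0, free_cash=50.0, val_price=10.0, value=100.0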
@register_jitted(cache=True)
def fill_log_record_nb(
records: tp.RecordArray2d,
r: int,
group: int,
col: int,
i: int,
price_area: PriceArea,
exec_state: ExecState,
order: Order,
order_result: OrderResult,
new_exec_state: ExecState,
order_id: int,
) -> None:
"""Fill a log record."""
records["id"][r, col] = r
records["group"][r, col] = group
records["col"][r, col] = col
records["idx"][r, col] = i
records["price_area_open"][r, col] = price_area.open
records["price_area_high"][r, col] = price_area.high
records["price_area_low"][r, col] = price_area.low
records["price_area_close"][r, col] = price_area.close
records["st0_cash"][r, col] = exec_state.cash
records["st0_position"][r, col] = exec_state.position
records["st0_debt"][r, col] = exec_state.debt
records["st0_locked_cash"][r, col] = exec_state.locked_cash
records["st0_free_cash"][r, col] = exec_state.free_cash
records["st0_val_price"][r, col] = exec_state.val_price
records["st0_value"][r, col] = exec_state.value
records["req_size"][r, col] = order.size
records["req_price"][r, col] = order.price
records["req_size_type"][r, col] = order.size_type
records["req_direction"][r, col] = order.direction
records["req_fees"][r, col] = order.fees
records["req_fixed_fees"][r, col] = order.fixed_fees
records["req_slippage"][r, col] = order.slippage
records["req_min_size"][r, col] = order.min_size
records["req_max_size"][r, col] = order.max_size
records["req_size_granularity"][r, col] = order.size_granularity
records["req_leverage"][r, col] = order.leverage
records["req_leverage_mode"][r, col] = order.leverage_mode
records["req_reject_prob"][r, col] = order.reject_prob
records["req_price_area_vio_mode"][r, col] = order.price_area_vio_mode
records["req_allow_partial"][r, col] = order.allow_partial
records["req_raise_reject"][r, col] = order.raise_reject
records["req_log"][r, col] = order.log
records["res_size"][r, col] = order_result.size
records["res_price"][r, col] = order_result.price
records["res_fees"][r, col] = order_result.fees
records["res_side"][r, col] = order_result.side
records["res_status"][r, col] = order_result.status
records["res_status_info"][r, col] = order_result.status_info
records["st1_cash"][r, col] = new_exec_state.cash
records["st1_position"][r, col] = new_exec_state.position
records["st1_debt"][r, col] = new_exec_state.debt
records["st1_locked_cash"][r, col] = new_exec_state.locked_cash
records["st1_free_cash"][r, col] = new_exec_state.free_cash
records["st1_val_price"][r, col] = new_exec_state.val_price
records["st1_value"][r, col] = new_exec_state.value
records["order_id"][r, col] = order_id
@register_jitted(cache=True)
def fill_order_record_nb(records: tp.RecordArray2d, r: int, col: int, i: int, order_result: OrderResult) -> None:
"""Fill an order record."""
records["id"][r, col] = r
records["col"][r, col] = col
records["idx"][r, col] = i
records["size"][r, col] = order_result.size
records["price"][r, col] = order_result.price
records["fees"][r, col] = order_result.fees
records["side"][r, col] = order_result.side
@register_jitted(cache=True)
def raise_rejected_order_nb(order_result: OrderResult) -> None:
"""Raise a `vectorbtpro.portfolio.enums.RejectedOrderError`."""
if order_result.status_info == OrderStatusInfo.SizeNaN:
raise RejectedOrderError("Size is NaN")
if order_result.status_info == OrderStatusInfo.PriceNaN:
raise RejectedOrderError("Price is NaN")
if order_result.status_info == OrderStatusInfo.ValPriceNaN:
raise RejectedOrderError("Asset valuation price is NaN")
if order_result.status_info == OrderStatusInfo.ValueNaN:
raise RejectedOrderError("Asset/group value is NaN")
if order_result.status_info == OrderStatusInfo.ValueZeroNeg:
raise RejectedOrderError("Asset/group value is zero or negative")
if order_result.status_info == OrderStatusInfo.SizeZero:
raise RejectedOrderError("Size is zero")
if order_result.status_info == OrderStatusInfo.NoCash:
raise RejectedOrderError("Not enough cash")
if order_result.status_info == OrderStatusInfo.NoOpenPosition:
raise RejectedOrderError("No open position to reduce/close")
if order_result.status_info == OrderStatusInfo.MaxSizeExceeded:
raise RejectedOrderError("Size is greater than maximum allowed")
if order_result.status_info == OrderStatusInfo.RandomEvent:
raise RejectedOrderError("Random event happened")
if order_result.status_info == OrderStatusInfo.CantCoverFees:
raise RejectedOrderError("Not enough cash to cover fees")
if order_result.status_info == OrderStatusInfo.MinSizeNotReached:
raise RejectedOrderError("Final size is less than minimum allowed")
if order_result.status_info == OrderStatusInfo.PartialFill:
raise RejectedOrderError("Final size is less than requested")
raise RejectedOrderError
@register_jitted(cache=True)
def process_order_nb(
group: int,
col: int,
i: int,
exec_state: ExecState,
order: Order,
price_area: PriceArea = NoPriceArea,
update_value: bool = False,
order_records: tp.Optional[tp.RecordArray2d] = None,
order_counts: tp.Optional[tp.Array1d] = None,
log_records: tp.Optional[tp.RecordArray2d] = None,
log_counts: tp.Optional[tp.Array1d] = None,
) -> tp.Tuple[OrderResult, ExecState]:
"""Process an order by executing it, saving relevant information to the logs, and returning a new state."""
# Execute the order
order_result, new_exec_state = execute_order_nb(
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
)
is_filled = order_result.status == OrderStatus.Filled
if order_records is not None and order_counts is not None:
if is_filled and order_records.shape[0] > 0:
# Fill order record
if order_counts[col] >= order_records.shape[0]:
raise IndexError("order_records index out of range. Set a higher max_order_records.")
fill_order_record_nb(order_records, order_counts[col], col, i, order_result)
order_counts[col] += 1
if log_records is not None and log_counts is not None:
if order.log and log_records.shape[0] > 0:
# Fill log record
if log_counts[col] >= log_records.shape[0]:
raise IndexError("log_records index out of range. Set a higher max_log_records.")
fill_log_record_nb(
log_records,
log_counts[col],
group,
col,
i,
price_area,
exec_state,
order,
order_result,
new_exec_state,
order_counts[col] - 1 if order_counts is not None and is_filled else -1,
)
log_counts[col] += 1
return order_result, new_exec_state
@register_jitted(cache=True)
def order_nb(
size: float = np.inf,
price: float = np.inf,
size_type: int = SizeType.Amount,
direction: int = Direction.Both,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
leverage_mode: int = LeverageMode.Lazy,
reject_prob: float = 0.0,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
raise_reject: bool = False,
log: bool = False,
) -> Order:
"""Create an order.
See `vectorbtpro.portfolio.enums.Order` for details on arguments."""
return Order(
size=float(size),
price=float(price),
size_type=int(size_type),
direction=int(direction),
fees=float(fees),
fixed_fees=float(fixed_fees),
slippage=float(slippage),
min_size=float(min_size),
max_size=float(max_size),
size_granularity=float(size_granularity),
leverage=float(leverage),
leverage_mode=int(leverage_mode),
reject_prob=float(reject_prob),
price_area_vio_mode=int(price_area_vio_mode),
allow_partial=bool(allow_partial),
raise_reject=bool(raise_reject),
log=bool(log),
)
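# Usage sketch (illustrative, not part of the original module): the defaults describe a market
# order that spends as much cash as possible (size=np.inf) at the closing price (price=np.inf);
# any field can be overridden per order.
#
#   order = order_nb(size=10.0, price=100.0, direction=Direction.LongOnly, fees=0.001)
#   # order.size == 10.0, order.fees == 0.001, order.allow_partial is True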
@register_jitted(cache=True)
def close_position_nb(
price: float = np.inf,
fees: float = 0.0,
fixed_fees: float = 0.0,
slippage: float = 0.0,
min_size: float = np.nan,
max_size: float = np.nan,
size_granularity: float = np.nan,
leverage: float = 1.0,
leverage_mode: int = LeverageMode.Lazy,
reject_prob: float = 0.0,
price_area_vio_mode: int = PriceAreaVioMode.Ignore,
allow_partial: bool = True,
raise_reject: bool = False,
log: bool = False,
) -> Order:
"""Close the current position."""
return order_nb(
size=0.0,
price=price,
size_type=SizeType.TargetAmount,
direction=Direction.Both,
fees=fees,
fixed_fees=fixed_fees,
slippage=slippage,
min_size=min_size,
max_size=max_size,
size_granularity=size_granularity,
leverage=leverage,
leverage_mode=leverage_mode,
reject_prob=reject_prob,
price_area_vio_mode=price_area_vio_mode,
allow_partial=allow_partial,
raise_reject=raise_reject,
log=log,
)
@register_jitted(cache=True)
def order_nothing_nb() -> Order:
"""Convenience function to order nothing."""
return NoOrder
@register_jitted(cache=True)
def check_group_lens_nb(group_lens: tp.GroupLens, n_cols: int) -> None:
"""Check `group_lens`."""
if np.sum(group_lens) != n_cols:
raise ValueError("group_lens has incorrect total number of columns")
@register_jitted(cache=True)
def is_grouped_nb(group_lens: tp.GroupLens) -> bool:
"""Check whether columns are grouped, that is, whether there is more than one column per group."""
return np.any(group_lens > 1)
@register_jitted(cache=True)
def prepare_records_nb(
target_shape: tp.Shape,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
) -> tp.Tuple[tp.RecordArray2d, tp.RecordArray2d]:
"""Prepare records."""
if max_order_records is None:
order_records = np.empty((target_shape[0], target_shape[1]), dtype=order_dt)
else:
order_records = np.empty((max_order_records, target_shape[1]), dtype=order_dt)
if max_log_records is None:
log_records = np.empty((target_shape[0], target_shape[1]), dtype=log_dt)
else:
log_records = np.empty((max_log_records, target_shape[1]), dtype=log_dt)
return order_records, log_records
@register_jitted(cache=True)
def prepare_last_cash_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
init_cash: tp.FlexArray1d,
) -> tp.Array1d:
"""Prepare `last_cash`."""
if cash_sharing:
last_cash = np.empty(len(group_lens), dtype=float_)
for group in range(len(group_lens)):
last_cash[group] = float(flex_select_1d_pc_nb(init_cash, group))
else:
last_cash = np.empty(target_shape[1], dtype=float_)
for col in range(target_shape[1]):
last_cash[col] = float(flex_select_1d_pc_nb(init_cash, col))
return last_cash
@register_jitted(cache=True)
def prepare_last_position_nb(target_shape: tp.Shape, init_position: tp.FlexArray1d) -> tp.Array1d:
"""Prepare `last_position`."""
last_position = np.empty(target_shape[1], dtype=float_)
for col in range(target_shape[1]):
last_position[col] = float(flex_select_1d_pc_nb(init_position, col))
return last_position
@register_jitted(cache=True)
def prepare_last_value_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
init_cash: tp.FlexArray1d,
init_position: tp.FlexArray1d,
init_price: tp.FlexArray1d,
) -> tp.Array1d:
"""Prepare `last_value`."""
if cash_sharing:
last_value = np.empty(len(group_lens), dtype=float_)
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
_init_cash = float(flex_select_1d_pc_nb(init_cash, group))
last_value[group] = _init_cash
for col in range(from_col, to_col):
_init_position = float(flex_select_1d_pc_nb(init_position, col))
_init_price = float(flex_select_1d_pc_nb(init_price, col))
if _init_position != 0:
last_value[group] += _init_position * _init_price
from_col = to_col
else:
last_value = np.empty(target_shape[1], dtype=float_)
for col in range(target_shape[1]):
_init_cash = float(flex_select_1d_pc_nb(init_cash, col))
_init_position = float(flex_select_1d_pc_nb(init_position, col))
_init_price = float(flex_select_1d_pc_nb(init_price, col))
if _init_position == 0:
last_value[col] = _init_cash
else:
last_value[col] = _init_cash + _init_position * _init_price
return last_value
@register_jitted(cache=True)
def prepare_last_pos_info_nb(
target_shape: tp.Shape,
init_position: tp.FlexArray1d,
init_price: tp.FlexArray1d,
fill_pos_info: bool = True,
) -> tp.RecordArray:
"""Prepare `last_pos_info`."""
if fill_pos_info:
last_pos_info = np.empty(target_shape[1], dtype=trade_dt)
last_pos_info["id"][:] = -1
last_pos_info["col"][:] = -1
last_pos_info["size"][:] = np.nan
last_pos_info["entry_order_id"][:] = -1
last_pos_info["entry_idx"][:] = -1
last_pos_info["entry_price"][:] = np.nan
last_pos_info["entry_fees"][:] = np.nan
last_pos_info["exit_order_id"][:] = -1
last_pos_info["exit_idx"][:] = -1
last_pos_info["exit_price"][:] = np.nan
last_pos_info["exit_fees"][:] = np.nan
last_pos_info["pnl"][:] = np.nan
last_pos_info["return"][:] = np.nan
last_pos_info["direction"][:] = -1
last_pos_info["status"][:] = -1
last_pos_info["parent_id"][:] = -1
for col in range(target_shape[1]):
_init_position = float(flex_select_1d_pc_nb(init_position, col))
_init_price = float(flex_select_1d_pc_nb(init_price, col))
if _init_position != 0:
fill_init_pos_info_nb(last_pos_info[col], col, _init_position, _init_price)
else:
last_pos_info = np.empty(0, dtype=trade_dt)
return last_pos_info
@register_jitted
def prepare_sim_out_nb(
order_records: tp.RecordArray2d,
order_counts: tp.Array1d,
log_records: tp.RecordArray2d,
log_counts: tp.Array1d,
cash_deposits: tp.Array2d,
cash_earnings: tp.Array2d,
call_seq: tp.Optional[tp.Array2d] = None,
in_outputs: tp.Optional[tp.NamedTuple] = None,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
) -> SimulationOutput:
"""Prepare simulation output."""
order_records_flat = generic_nb.repartition_nb(order_records, order_counts)
log_records_flat = generic_nb.repartition_nb(log_records, log_counts)
return SimulationOutput(
order_records=order_records_flat,
log_records=log_records_flat,
cash_deposits=cash_deposits,
cash_earnings=cash_earnings,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start,
sim_end=sim_end,
)
@register_jitted(cache=True)
def get_trade_stats_nb(
size: float,
entry_price: float,
entry_fees: float,
exit_price: float,
exit_fees: float,
direction: int,
) -> tp.Tuple[float, float]:
"""Get trade statistics."""
entry_val = size * entry_price
exit_val = size * exit_price
val_diff = add_nb(exit_val, -entry_val)
if val_diff != 0 and direction == TradeDirection.Short:
val_diff *= -1
pnl = val_diff - entry_fees - exit_fees
if is_close_nb(entry_val, 0):
ret = np.nan
else:
ret = pnl / entry_val
return pnl, ret
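# Usage sketch (illustrative, not part of the original module; values approximate): PnL and
# return of a long trade of 2 units entered at 10.0 and exited at 11.0, with entry/exit fees
# deducted.
#
#   get_trade_stats_nb(size=2.0, entry_price=10.0, entry_fees=0.2,
#                      exit_price=11.0, exit_fees=0.22, direction=TradeDirection.Long)
#   # -> (1.58, 0.079)   # pnl = (22.0 - 20.0) - 0.2 - 0.22; return = pnl / 20.0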
@register_jitted(cache=True)
def update_open_pos_info_stats_nb(record: tp.Record, position_now: float, price: float) -> None:
"""Update statistics of an open position record using custom price."""
if record["id"] >= 0 and record["status"] == TradeStatus.Open:
if np.isnan(record["exit_price"]):
exit_price = price
else:
exit_size_sum = record["size"] - abs(position_now)
exit_gross_sum = exit_size_sum * record["exit_price"]
exit_gross_sum += abs(position_now) * price
exit_price = exit_gross_sum / record["size"]
pnl, ret = get_trade_stats_nb(
record["size"],
record["entry_price"],
record["entry_fees"],
exit_price,
record["exit_fees"],
record["direction"],
)
record["pnl"] = pnl
record["return"] = ret
@register_jitted(cache=True)
def fill_init_pos_info_nb(record: tp.Record, col: int, position_now: float, price: float) -> None:
"""Fill position record for an initial position."""
record["id"] = 0
record["col"] = col
record["size"] = abs(position_now)
record["entry_order_id"] = -1
record["entry_idx"] = -1
record["entry_price"] = price
record["entry_fees"] = 0.0
record["exit_order_id"] = -1
record["exit_idx"] = -1
record["exit_price"] = np.nan
record["exit_fees"] = 0.0
if position_now >= 0:
record["direction"] = TradeDirection.Long
else:
record["direction"] = TradeDirection.Short
record["status"] = TradeStatus.Open
record["parent_id"] = record["id"]
# Update open position stats
update_open_pos_info_stats_nb(record, position_now, np.nan)
@register_jitted(cache=True)
def update_pos_info_nb(
record: tp.Record,
i: int,
col: int,
position_before: float,
position_now: float,
order_result: OrderResult,
order_id: int,
) -> None:
"""Update position record after filling an order."""
if order_result.status == OrderStatus.Filled:
if position_before == 0 and position_now != 0:
# New position opened
record["id"] += 1
record["col"] = col
record["size"] = order_result.size
record["entry_order_id"] = order_id
record["entry_idx"] = i
record["entry_price"] = order_result.price
record["entry_fees"] = order_result.fees
record["exit_order_id"] = -1
record["exit_idx"] = -1
record["exit_price"] = np.nan
record["exit_fees"] = 0.0
if order_result.side == OrderSide.Buy:
record["direction"] = TradeDirection.Long
else:
record["direction"] = TradeDirection.Short
record["status"] = TradeStatus.Open
record["parent_id"] = record["id"]
elif position_before != 0 and position_now == 0:
# Position closed
record["exit_order_id"] = order_id
record["exit_idx"] = i
if np.isnan(record["exit_price"]):
exit_price = order_result.price
else:
exit_size_sum = record["size"] - abs(position_before)
exit_gross_sum = exit_size_sum * record["exit_price"]
exit_gross_sum += abs(position_before) * order_result.price
exit_price = exit_gross_sum / record["size"]
record["exit_price"] = exit_price
record["exit_fees"] += order_result.fees
pnl, ret = get_trade_stats_nb(
record["size"],
record["entry_price"],
record["entry_fees"],
record["exit_price"],
record["exit_fees"],
record["direction"],
)
record["pnl"] = pnl
record["return"] = ret
record["status"] = TradeStatus.Closed
elif np.sign(position_before) != np.sign(position_now):
# Position reversed
record["id"] += 1
record["size"] = abs(position_now)
record["entry_order_id"] = order_id
record["entry_idx"] = i
record["entry_price"] = order_result.price
new_pos_fraction = abs(position_now) / abs(position_now - position_before)
record["entry_fees"] = new_pos_fraction * order_result.fees
record["exit_order_id"] = -1
record["exit_idx"] = -1
record["exit_price"] = np.nan
record["exit_fees"] = 0.0
if order_result.side == OrderSide.Buy:
record["direction"] = TradeDirection.Long
else:
record["direction"] = TradeDirection.Short
record["status"] = TradeStatus.Open
record["parent_id"] = record["id"]
else:
# Position changed
if abs(position_before) <= abs(position_now):
# Position increased
entry_gross_sum = record["size"] * record["entry_price"]
entry_gross_sum += order_result.size * order_result.price
entry_price = entry_gross_sum / (record["size"] + order_result.size)
record["entry_price"] = entry_price
record["entry_fees"] += order_result.fees
record["size"] += order_result.size
else:
# Position decreased
record["exit_order_id"] = order_id
if np.isnan(record["exit_price"]):
exit_price = order_result.price
else:
exit_size_sum = record["size"] - abs(position_before)
exit_gross_sum = exit_size_sum * record["exit_price"]
exit_gross_sum += order_result.size * order_result.price
exit_price = exit_gross_sum / (exit_size_sum + order_result.size)
record["exit_price"] = exit_price
record["exit_fees"] += order_result.fees
# Update open position stats
update_open_pos_info_stats_nb(record, position_now, order_result.price)
@register_jitted(cache=True)
def resolve_hl_nb(open, high, low, close):
"""Resolve the current high and low."""
if np.isnan(high):
if np.isnan(open):
high = close
elif np.isnan(close):
high = open
else:
high = max(open, close)
if np.isnan(low):
if np.isnan(open):
low = close
elif np.isnan(close):
low = open
else:
low = min(open, close)
return high, low
@register_jitted(cache=True)
def check_price_hit_nb(
open: float,
high: float,
low: float,
close: float,
price: float,
hit_below: bool = True,
can_use_ohlc: bool = True,
check_open: bool = True,
hard_price: bool = False,
) -> tp.Tuple[float, bool, bool]:
"""Check whether a target price was hit.
If `hard_price` is False, both `can_use_ohlc` and `check_open` are True, and the target price
is hit by the open, the open price is returned. Otherwise, the actual target price is returned.
Returns the stop price, whether it was hit by the open, and whether it was hit during this bar."""
high, low = resolve_hl_nb(
open=open,
high=high,
low=low,
close=close,
)
if hit_below:
if can_use_ohlc and check_open and is_close_or_less_nb(open, price):
if hard_price:
return price, True, True
return open, True, True
if is_close_or_less_nb(close, price) or (can_use_ohlc and is_close_or_less_nb(low, price)):
return price, False, True
return price, False, False
if can_use_ohlc and check_open and is_close_or_greater_nb(open, price):
if hard_price:
return price, True, True
return open, True, True
if is_close_or_greater_nb(close, price) or (can_use_ohlc and is_close_or_greater_nb(high, price)):
return price, False, True
return price, False, False
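# Illustrative sketch, not part of the original source: checking whether a level
# below the market (e.g. the stop-loss of a long position) was touched within one
# OHLC bar. The helper name is hypothetical; the numbers are invented.
def _example_check_price_hit():
    stop_price, hit_on_open, hit = check_price_hit_nb(
        open=100.0,
        high=103.0,
        low=97.0,
        close=101.0,
        price=98.0,
        hit_below=True,
    )
    # The bar's low of 97.0 crosses 98.0 intrabar but the open does not, so
    # hit is True, hit_on_open is False, and stop_price stays at 98.0.
    return stop_price, hit_on_open, hit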
@register_jitted(cache=True)
def resolve_stop_exit_price_nb(
stop_price: float,
close: float,
stop_exit_price: float,
) -> float:
"""Resolve the exit price of a stop order."""
if stop_exit_price == StopExitPrice.Stop or stop_exit_price == StopExitPrice.HardStop:
return float(stop_price)
elif stop_exit_price == StopExitPrice.Close:
return float(close)
elif stop_exit_price < 0:
raise ValueError("Invalid StopExitPrice option")
return float(stop_exit_price)
@register_jitted(cache=True)
def is_limit_active_nb(init_idx: int, init_price: float) -> bool:
"""Check whether a limit order is active."""
return init_idx != -1 and not np.isnan(init_price)
@register_jitted(cache=True)
def is_stop_active_nb(init_idx: int, stop: float) -> bool:
"""Check whether a stop order is active."""
return init_idx != -1 and not np.isnan(stop)
@register_jitted(cache=True)
def is_time_stop_active_nb(init_idx: int, stop: int) -> bool:
"""Check whether a time stop order is active."""
return init_idx != -1 and stop != -1
@register_jitted(cache=True)
def should_update_stop_nb(new_stop: float, upon_stop_update: int) -> bool:
"""Whether to update stop."""
if upon_stop_update == StopUpdateMode.Keep:
return False
if upon_stop_update == StopUpdateMode.Override or upon_stop_update == StopUpdateMode.OverrideNaN:
if not np.isnan(new_stop) or upon_stop_update == StopUpdateMode.OverrideNaN:
return True
return False
raise ValueError("Invalid StopUpdateMode option")
@register_jitted(cache=True)
def should_update_time_stop_nb(new_stop: int, upon_stop_update: int) -> bool:
"""Whether to update time stop."""
if upon_stop_update == StopUpdateMode.Keep:
return False
if upon_stop_update == StopUpdateMode.Override or upon_stop_update == StopUpdateMode.OverrideNaN:
if new_stop != -1 or upon_stop_update == StopUpdateMode.OverrideNaN:
return True
return False
raise ValueError("Invalid StopUpdateMode option")
@register_jitted(cache=True)
def check_limit_expired_nb(
creation_idx: int,
i: int,
tif: int = -1,
expiry: int = -1,
time_delta_format: int = TimeDeltaFormat.Index,
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
) -> tp.Tuple[bool, bool]:
"""Check whether limit is expired by comparing the current index with the creation index.
Returns whether the limit expires already on open, and whether the limit expires during this bar."""
if tif == -1 and expiry == -1:
return False, False
if time_delta_format == TimeDeltaFormat.Rows:
is_expired_on_open = False
is_expired = False
if tif != -1:
if creation_idx + tif <= i:
is_expired_on_open = True
is_expired = True
elif i < creation_idx + tif < i + 1:
is_expired = True
if expiry != -1:
if expiry <= i:
is_expired_on_open = True
is_expired = True
elif i < expiry < i + 1:
is_expired = True
return is_expired_on_open, is_expired
elif time_delta_format == TimeDeltaFormat.Index:
if index is None:
raise ValueError("Must provide index for TimeDeltaFormat.Index")
if freq is None:
raise ValueError("Must provide frequency for TimeDeltaFormat.Index")
is_expired_on_open = False
is_expired = False
if tif != -1:
if index[creation_idx] + tif <= index[i]:
is_expired_on_open = True
is_expired = True
elif index[i] < index[creation_idx] + tif < index[i] + freq:
is_expired = True
if expiry != -1:
if expiry <= index[i]:
is_expired_on_open = True
is_expired = True
elif index[i] < expiry < index[i] + freq:
is_expired = True
return is_expired_on_open, is_expired
else:
raise ValueError("Invalid TimeDeltaFormat option")
@register_jitted(cache=True)
def resolve_limit_price_nb(
init_price: float,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
hit_below: bool = True,
) -> float:
"""Resolve the limit price."""
if delta_format == DeltaFormat.Percent100:
limit_delta /= 100
delta_format = DeltaFormat.Percent
if not np.isnan(limit_delta):
if hit_below:
if np.isinf(limit_delta) and delta_format != DeltaFormat.Target:
if limit_delta > 0:
limit_price = -np.inf
else:
limit_price = np.inf
else:
if delta_format == DeltaFormat.Absolute:
limit_price = init_price - limit_delta
elif delta_format == DeltaFormat.Percent:
limit_price = init_price * (1 - limit_delta)
elif delta_format == DeltaFormat.Target:
limit_price = limit_delta
else:
raise ValueError("Invalid DeltaFormat option")
else:
if np.isinf(limit_delta) and delta_format != DeltaFormat.Target:
if limit_delta < 0:
limit_price = -np.inf
else:
limit_price = np.inf
else:
if delta_format == DeltaFormat.Absolute:
limit_price = init_price + limit_delta
elif delta_format == DeltaFormat.Percent:
limit_price = init_price * (1 + limit_delta)
elif delta_format == DeltaFormat.Target:
limit_price = limit_delta
else:
raise ValueError("Invalid DeltaFormat option")
else:
limit_price = init_price
return limit_price
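# Illustrative sketch, not part of the original source: a 1% limit delta resolves
# below the initial price when buying (hit_below=True) and above it when selling
# (hit_below=False). The helper name is hypothetical.
def _example_resolve_limit_price():
    buy_limit = resolve_limit_price_nb(
        init_price=100.0, limit_delta=0.01, delta_format=DeltaFormat.Percent, hit_below=True
    )  # 99.0
    sell_limit = resolve_limit_price_nb(
        init_price=100.0, limit_delta=0.01, delta_format=DeltaFormat.Percent, hit_below=False
    )  # 101.0
    return buy_limit, sell_limit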
@register_jitted(cache=True)
def check_limit_hit_nb(
open: float,
high: float,
low: float,
close: float,
price: float,
size: float,
direction: int = Direction.Both,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
limit_reverse: bool = False,
can_use_ohlc: bool = True,
check_open: bool = True,
hard_limit: bool = False,
) -> tp.Tuple[float, bool, bool]:
"""Resolve the limit price using `resolve_limit_price_nb` and check whether it was hit.
Returns the limit price, whether it was hit before open, and whether it was hit during this bar.
If `can_use_ohlc` and `check_open` is True and the stop is hit before open, returns open."""
if size == 0:
raise ValueError("Limit order size cannot be zero")
_size = get_diraware_size_nb(size, direction)
hit_below = (_size > 0 and not limit_reverse) or (_size < 0 and limit_reverse)
limit_price = resolve_limit_price_nb(
init_price=price,
limit_delta=limit_delta,
delta_format=delta_format,
hit_below=hit_below,
)
hit_on_open = False
if can_use_ohlc:
high, low = resolve_hl_nb(
open=open,
high=high,
low=low,
close=close,
)
if hit_below:
if check_open and is_close_or_less_nb(open, limit_price):
hit_on_open = True
hit = True
if not hard_limit:
limit_price = open
else:
hit = is_close_or_less_nb(low, limit_price)
if hit and np.isinf(limit_price):
limit_price = low
else:
if check_open and is_close_or_greater_nb(open, limit_price):
hit_on_open = True
hit = True
if not hard_limit:
limit_price = open
else:
hit = is_close_or_greater_nb(high, limit_price)
if hit and np.isinf(limit_price):
limit_price = high
else:
if hit_below:
hit = is_close_or_less_nb(close, limit_price)
else:
hit = is_close_or_greater_nb(close, limit_price)
if hit and np.isinf(limit_price):
limit_price = close
return limit_price, hit_on_open, hit
@register_jitted(cache=True)
def resolve_limit_order_price_nb(
limit_price: float,
close: float,
limit_order_price: float,
) -> float:
"""Resolve the limit order price of a limit order."""
if limit_order_price == LimitOrderPrice.Limit or limit_order_price == LimitOrderPrice.HardLimit:
return float(limit_price)
elif limit_order_price == LimitOrderPrice.Close:
return float(close)
elif limit_order_price < 0:
raise ValueError("Invalid LimitOrderPrice option")
return float(limit_order_price)
@register_jitted(cache=True)
def resolve_stop_price_nb(
init_price: float,
stop: float,
delta_format: int = DeltaFormat.Percent,
hit_below: bool = True,
) -> float:
"""Resolve the stop price."""
if delta_format == DeltaFormat.Percent100:
stop /= 100
delta_format = DeltaFormat.Percent
if hit_below:
if delta_format == DeltaFormat.Absolute:
stop_price = init_price - abs(stop)
elif delta_format == DeltaFormat.Percent:
stop_price = init_price * (1 - abs(stop))
elif delta_format == DeltaFormat.Target:
stop_price = stop
else:
raise ValueError("Invalid DeltaFormat option")
else:
if delta_format == DeltaFormat.Absolute:
stop_price = init_price + abs(stop)
elif delta_format == DeltaFormat.Percent:
stop_price = init_price * (1 + abs(stop))
elif delta_format == DeltaFormat.Target:
stop_price = stop
else:
raise ValueError("Invalid DeltaFormat option")
return stop_price
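# Illustrative sketch, not part of the original source: the same 5% stop resolves
# below the entry price when protecting a long position (hit_below=True) and above
# it when protecting a short one (hit_below=False). The helper name is hypothetical.
def _example_resolve_stop_price():
    long_sl = resolve_stop_price_nb(
        init_price=100.0, stop=0.05, delta_format=DeltaFormat.Percent, hit_below=True
    )  # 95.0
    short_sl = resolve_stop_price_nb(
        init_price=100.0, stop=0.05, delta_format=DeltaFormat.Percent, hit_below=False
    )  # 105.0
    return long_sl, short_sl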
@register_jitted(cache=True)
def check_stop_hit_nb(
open: float,
high: float,
low: float,
close: float,
is_position_long: bool,
init_price: float,
stop: float,
delta_format: int = DeltaFormat.Percent,
hit_below: bool = True,
can_use_ohlc: bool = True,
check_open: bool = True,
hard_stop: bool = False,
) -> tp.Tuple[float, bool, bool]:
"""Resolve the stop price using `resolve_stop_price_nb` and check whether it was hit.
See `check_price_hit_nb`."""
hit_below = (is_position_long and hit_below) or (not is_position_long and not hit_below)
stop_price = resolve_stop_price_nb(
init_price=init_price,
stop=stop,
delta_format=delta_format,
hit_below=hit_below,
)
return check_price_hit_nb(
open=open,
high=high,
low=low,
close=close,
price=stop_price,
hit_below=hit_below,
can_use_ohlc=can_use_ohlc,
check_open=check_open,
hard_price=hard_stop,
)
@register_jitted(cache=True)
def check_td_stop_hit_nb(
init_idx: int,
i: int,
stop: int = -1,
time_delta_format: int = TimeDeltaFormat.Index,
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
) -> tp.Tuple[bool, bool]:
"""Check whether TD stop was hit by comparing the current index with the initial index.
Returns whether the stop was hit already on open, and whether the stop was hit during this bar."""
if stop == -1:
return False, False
if time_delta_format == TimeDeltaFormat.Rows:
is_hit_on_open = False
is_hit = False
if stop != -1:
if init_idx + stop <= i:
is_hit_on_open = True
is_hit = True
elif i < init_idx + stop < i + 1:
is_hit = True
return is_hit_on_open, is_hit
elif time_delta_format == TimeDeltaFormat.Index:
if index is None:
raise ValueError("Must provide index for TimeDeltaFormat.Index")
if freq is None:
raise ValueError("Must provide frequency for TimeDeltaFormat.Index")
is_hit_on_open = False
is_hit = False
if stop != -1:
if index[init_idx] + stop <= index[i]:
is_hit_on_open = True
is_hit = True
elif index[i] < index[init_idx] + stop < index[i] + freq:
is_hit = True
return is_hit_on_open, is_hit
else:
raise ValueError("Invalid TimeDeltaFormat option")
@register_jitted(cache=True)
def check_dt_stop_hit_nb(
i: int,
stop: int = -1,
time_delta_format: int = TimeDeltaFormat.Index,
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
) -> tp.Tuple[bool, bool]:
"""Check whether DT stop was hit by comparing the current index with the initial index.
Returns whether the stop was hit already on open, and whether the stop was hit during this bar."""
if stop == -1:
return False, False
if time_delta_format == TimeDeltaFormat.Rows:
is_hit_on_open = False
is_hit = False
if stop != -1:
if stop <= i:
is_hit_on_open = True
is_hit = True
elif i < stop < i + 1:
is_hit = True
return is_hit_on_open, is_hit
elif time_delta_format == TimeDeltaFormat.Index:
if index is None:
raise ValueError("Must provide index for TimeDeltaFormat.Index")
if freq is None:
raise ValueError("Must provide frequency for TimeDeltaFormat.Index")
is_hit_on_open = False
is_hit = False
if stop != -1:
if stop <= index[i]:
is_hit_on_open = True
is_hit = True
elif index[i] < stop < index[i] + freq:
is_hit = True
return is_hit_on_open, is_hit
else:
raise ValueError("Invalid TimeDeltaFormat option")
@register_jitted(cache=True)
def check_tsl_th_hit_nb(
is_position_long: bool,
init_price: float,
peak_price: float,
threshold: float,
delta_format: int = DeltaFormat.Percent,
) -> bool:
"""Resolve the TSL threshold price using `resolve_stop_price_nb` and check whether it was hit."""
hit_below = not is_position_long
tsl_th_price = resolve_stop_price_nb(
init_price=init_price,
stop=threshold,
delta_format=delta_format,
hit_below=hit_below,
)
if hit_below:
return is_close_or_less_nb(peak_price, tsl_th_price)
return is_close_or_greater_nb(peak_price, tsl_th_price)
@register_jitted(cache=True)
def resolve_dyn_limit_price_nb(val_price: float, price: float, limit_price: float) -> float:
"""Resolve price dynamically.
Uses the valuation price as the left bound and order price as the right bound."""
if np.isinf(limit_price):
if limit_price < 0:
return float(val_price)
return float(price)
return float(limit_price)
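# Illustrative sketch, not part of the original source: -inf maps to the valuation
# price, +inf to the order price, and any finite value is passed through unchanged.
# The helper name is hypothetical.
def _example_resolve_dyn_limit_price():
    lower = resolve_dyn_limit_price_nb(val_price=99.0, price=100.0, limit_price=-np.inf)  # 99.0
    upper = resolve_dyn_limit_price_nb(val_price=99.0, price=100.0, limit_price=np.inf)   # 100.0
    fixed = resolve_dyn_limit_price_nb(val_price=99.0, price=100.0, limit_price=98.5)     # 98.5
    return lower, upper, fixed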
@register_jitted(cache=True)
def resolve_dyn_stop_entry_price_nb(val_price: float, price: float, stop_entry_price: float) -> float:
"""Resolve stop entry price dynamically.
Uses the valuation/open price as the left bound and order price as the right bound."""
if np.isinf(stop_entry_price):
if stop_entry_price < 0:
return float(val_price)
return float(price)
if stop_entry_price < 0:
if stop_entry_price == StopEntryPrice.ValPrice:
return float(val_price)
if stop_entry_price == StopEntryPrice.Price:
return float(price)
raise ValueError("Only valuation and order price are supported when setting stop entry price dynamically")
return float(stop_entry_price)
@register_jitted(cache=True)
def get_stop_ladder_exit_size_nb(
stop_: tp.FlexArray2d,
step: int,
col: int,
init_price: float,
init_position: float,
position_now: float,
ladder: int = StopLadderMode.Disabled,
delta_format: int = DeltaFormat.Percent,
hit_below: bool = True,
) -> float:
"""Get the exit size corresponding to the current step in the ladder."""
if ladder == StopLadderMode.Disabled:
raise ValueError("Stop ladder must be enabled to select exit size")
if ladder == StopLadderMode.Dynamic:
raise ValueError("Stop ladder must be static to select exit size")
stop = flex_select_nb(stop_, step, col)
if np.isnan(stop):
return np.nan
last_step = -1
for i in range(step, stop_.shape[0]):
if not np.isnan(flex_select_nb(stop_, i, col)):
last_step = i
else:
break
if last_step == -1:
return np.nan
if step == last_step:
return abs(position_now)
if ladder == StopLadderMode.Uniform:
exit_fraction = 1 / (last_step + 1)
return exit_fraction * abs(init_position)
if ladder == StopLadderMode.AdaptUniform:
exit_fraction = 1 / (last_step + 1 - step)
return exit_fraction * abs(position_now)
hit_below = (init_position >= 0 and hit_below) or (init_position < 0 and not hit_below)
price = resolve_stop_price_nb(
init_price=init_price,
stop=stop,
delta_format=delta_format,
hit_below=hit_below,
)
last_stop = flex_select_nb(stop_, last_step, col)
last_price = resolve_stop_price_nb(
init_price=init_price,
stop=last_stop,
delta_format=delta_format,
hit_below=hit_below,
)
if step == 0:
prev_price = init_price
else:
prev_stop = flex_select_nb(stop_, step - 1, col)
prev_price = resolve_stop_price_nb(
init_price=init_price,
stop=prev_stop,
delta_format=delta_format,
hit_below=hit_below,
)
if ladder == StopLadderMode.Weighted:
exit_fraction = (price - prev_price) / (last_price - init_price)
return exit_fraction * abs(init_position)
if ladder == StopLadderMode.AdaptWeighted:
exit_fraction = (price - prev_price) / (last_price - prev_price)
return exit_fraction * abs(position_now)
raise ValueError("Invalid StopLadderMode option")
@register_jitted(cache=True)
def get_time_stop_ladder_exit_size_nb(
stop_: tp.FlexArray2d,
step: int,
col: int,
init_idx: int,
init_position: float,
position_now: float,
ladder: int = StopLadderMode.Disabled,
time_delta_format: int = TimeDeltaFormat.Index,
index: tp.Optional[tp.Array1d] = None,
) -> float:
"""Get the exit size corresponding to the current step in the ladder."""
if ladder == StopLadderMode.Disabled:
raise ValueError("Stop ladder must be enabled to select exit size")
if ladder == StopLadderMode.Dynamic:
raise ValueError("Stop ladder must be static to select exit size")
if init_idx == -1:
raise ValueError("Initial index of the ladder must be known")
if time_delta_format == TimeDeltaFormat.Index:
if index is None:
raise ValueError("Must provide index for TimeDeltaFormat.Index")
init_idx = index[init_idx]
idx = flex_select_nb(stop_, step, col)
if idx == -1:
return np.nan
last_step = -1
for i in range(step, stop_.shape[0]):
if flex_select_nb(stop_, i, col) != -1:
last_step = i
else:
break
if last_step == -1:
return np.nan
if step == last_step:
return abs(position_now)
if ladder == StopLadderMode.Uniform:
exit_fraction = 1 / (last_step + 1)
return exit_fraction * abs(init_position)
if ladder == StopLadderMode.AdaptUniform:
exit_fraction = 1 / (last_step + 1 - step)
return exit_fraction * abs(position_now)
last_idx = flex_select_nb(stop_, last_step, col)
if step == 0:
prev_idx = init_idx
else:
prev_idx = flex_select_nb(stop_, step - 1, col)
if ladder == StopLadderMode.Weighted:
exit_fraction = (idx - prev_idx) / (last_idx - init_idx)
return exit_fraction * abs(init_position)
if ladder == StopLadderMode.AdaptWeighted:
exit_fraction = (idx - prev_idx) / (last_idx - prev_idx)
return exit_fraction * abs(position_now)
raise ValueError("Invalid StopLadderMode option")
@register_jitted(cache=True)
def is_limit_info_active_nb(limit_info: tp.Record) -> bool:
"""Check whether information record for a limit order is active."""
return is_limit_active_nb(limit_info["init_idx"], limit_info["init_price"])
@register_jitted(cache=True)
def is_stop_info_active_nb(stop_info: tp.Record) -> bool:
"""Check whether information record for a stop order is active."""
return is_stop_active_nb(stop_info["init_idx"], stop_info["stop"])
@register_jitted(cache=True)
def is_time_stop_info_active_nb(time_stop_info: tp.Record) -> bool:
"""Check whether information record for a time stop order is active."""
return is_time_stop_active_nb(time_stop_info["init_idx"], time_stop_info["stop"])
@register_jitted(cache=True)
def is_stop_info_ladder_active_nb(info: tp.Record) -> bool:
"""Check whether information record for a stop ladder is active."""
return info["step"] != -1
@register_jitted(cache=True)
def set_limit_info_nb(
limit_info: tp.Record,
signal_idx: int,
creation_idx: tp.Optional[int] = None,
init_idx: tp.Optional[int] = None,
init_price: float = -np.inf,
init_size: float = np.inf,
init_size_type: int = SizeType.Amount,
init_direction: int = Direction.Both,
init_stop_type: int = -1,
delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
tif: int = -1,
expiry: int = -1,
time_delta_format: int = TimeDeltaFormat.Index,
reverse: bool = False,
order_price: int = LimitOrderPrice.Limit,
) -> None:
"""Set limit order information.
See `vectorbtpro.portfolio.enums.limit_info_dt`."""
limit_info["signal_idx"] = signal_idx
limit_info["creation_idx"] = creation_idx if creation_idx is not None else signal_idx
limit_info["init_idx"] = init_idx if init_idx is not None else signal_idx
limit_info["init_price"] = init_price
limit_info["init_size"] = init_size
limit_info["init_size_type"] = init_size_type
limit_info["init_direction"] = init_direction
limit_info["init_stop_type"] = init_stop_type
limit_info["delta"] = delta
limit_info["delta_format"] = delta_format
limit_info["tif"] = tif
limit_info["expiry"] = expiry
limit_info["time_delta_format"] = time_delta_format
limit_info["reverse"] = reverse
limit_info["order_price"] = order_price
@register_jitted(cache=True)
def clear_limit_info_nb(limit_info: tp.Record) -> None:
"""Clear limit order information."""
limit_info["signal_idx"] = -1
limit_info["creation_idx"] = -1
limit_info["init_idx"] = -1
limit_info["init_price"] = np.nan
limit_info["init_size"] = np.nan
limit_info["init_size_type"] = -1
limit_info["init_direction"] = -1
limit_info["init_stop_type"] = -1
limit_info["delta"] = np.nan
limit_info["delta_format"] = -1
limit_info["tif"] = -1
limit_info["expiry"] = -1
limit_info["time_delta_format"] = -1
limit_info["reverse"] = False
limit_info["order_price"] = np.nan
@register_jitted(cache=True)
def set_sl_info_nb(
sl_info: tp.Record,
init_idx: int,
init_price: float = -np.inf,
init_position: float = np.nan,
stop: float = np.nan,
exit_price: float = StopExitPrice.Stop,
exit_size: float = np.nan,
exit_size_type: int = -1,
exit_type: int = StopExitType.Close,
order_type: int = OrderType.Market,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
ladder: int = StopLadderMode.Disabled,
step: int = -1,
step_idx: int = -1,
) -> None:
"""Set SL order information.
See `vectorbtpro.portfolio.enums.sl_info_dt`."""
sl_info["init_idx"] = init_idx
sl_info["init_price"] = init_price
sl_info["init_position"] = init_position
sl_info["stop"] = stop
sl_info["exit_price"] = exit_price
sl_info["exit_size"] = exit_size
sl_info["exit_size_type"] = exit_size_type
sl_info["exit_type"] = exit_type
sl_info["order_type"] = order_type
sl_info["limit_delta"] = limit_delta
sl_info["delta_format"] = delta_format
sl_info["ladder"] = ladder
sl_info["step"] = step
sl_info["step_idx"] = step_idx
@register_jitted(cache=True)
def clear_sl_info_nb(sl_info: tp.Record) -> None:
"""Clear SL order information."""
sl_info["init_idx"] = -1
sl_info["init_price"] = np.nan
sl_info["init_position"] = np.nan
sl_info["stop"] = np.nan
sl_info["exit_price"] = -1
sl_info["exit_size"] = np.nan
sl_info["exit_size_type"] = -1
sl_info["exit_type"] = -1
sl_info["order_type"] = -1
sl_info["limit_delta"] = np.nan
sl_info["delta_format"] = -1
sl_info["ladder"] = -1
sl_info["step"] = -1
sl_info["step_idx"] = -1
@register_jitted(cache=True)
def set_tsl_info_nb(
tsl_info: tp.Record,
init_idx: int,
init_price: float = -np.inf,
init_position: float = np.nan,
peak_idx: tp.Optional[int] = None,
peak_price: tp.Optional[float] = None,
stop: float = np.nan,
th: float = np.nan,
exit_price: float = StopExitPrice.Stop,
exit_size: float = np.nan,
exit_size_type: int = -1,
exit_type: int = StopExitType.Close,
order_type: int = OrderType.Market,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
ladder: int = StopLadderMode.Disabled,
step: int = -1,
step_idx: int = -1,
) -> None:
"""Set TSL/TTP order information.
See `vectorbtpro.portfolio.enums.tsl_info_dt`."""
tsl_info["init_idx"] = init_idx
tsl_info["init_price"] = init_price
tsl_info["init_position"] = init_position
tsl_info["peak_idx"] = peak_idx if peak_idx is not None else init_idx
tsl_info["peak_price"] = peak_price if peak_price is not None else init_price
tsl_info["stop"] = stop
tsl_info["th"] = th
tsl_info["exit_price"] = exit_price
tsl_info["exit_size"] = exit_size
tsl_info["exit_size_type"] = exit_size_type
tsl_info["exit_type"] = exit_type
tsl_info["order_type"] = order_type
tsl_info["limit_delta"] = limit_delta
tsl_info["delta_format"] = delta_format
tsl_info["ladder"] = ladder
tsl_info["step"] = step
tsl_info["step_idx"] = step_idx
@register_jitted(cache=True)
def clear_tsl_info_nb(tsl_info: tp.Record) -> None:
"""Clear TSL/TTP order information."""
tsl_info["init_idx"] = -1
tsl_info["init_price"] = np.nan
tsl_info["init_position"] = np.nan
tsl_info["peak_idx"] = -1
tsl_info["peak_price"] = np.nan
tsl_info["stop"] = np.nan
tsl_info["th"] = np.nan
tsl_info["exit_price"] = -1
tsl_info["exit_size"] = np.nan
tsl_info["exit_size_type"] = -1
tsl_info["exit_type"] = -1
tsl_info["order_type"] = -1
tsl_info["limit_delta"] = np.nan
tsl_info["delta_format"] = -1
tsl_info["ladder"] = -1
tsl_info["step"] = -1
tsl_info["step_idx"] = -1
@register_jitted(cache=True)
def set_tp_info_nb(
tp_info: tp.Record,
init_idx: int,
init_price: float = -np.inf,
init_position: float = np.nan,
stop: float = np.nan,
exit_price: float = StopExitPrice.Stop,
exit_size: float = np.nan,
exit_size_type: int = -1,
exit_type: int = StopExitType.Close,
order_type: int = OrderType.Market,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
ladder: int = StopLadderMode.Disabled,
step: int = -1,
step_idx: int = -1,
) -> None:
"""Set TP order information.
See `vectorbtpro.portfolio.enums.tp_info_dt`."""
tp_info["init_idx"] = init_idx
tp_info["init_price"] = init_price
tp_info["init_position"] = init_position
tp_info["stop"] = stop
tp_info["exit_price"] = exit_price
tp_info["exit_size"] = exit_size
tp_info["exit_size_type"] = exit_size_type
tp_info["exit_type"] = exit_type
tp_info["order_type"] = order_type
tp_info["limit_delta"] = limit_delta
tp_info["delta_format"] = delta_format
tp_info["ladder"] = ladder
tp_info["step"] = step
tp_info["step_idx"] = step_idx
@register_jitted(cache=True)
def clear_tp_info_nb(tp_info: tp.Record) -> None:
"""Clear TP order information."""
tp_info["init_idx"] = -1
tp_info["init_price"] = np.nan
tp_info["init_position"] = np.nan
tp_info["stop"] = np.nan
tp_info["exit_price"] = -1
tp_info["exit_size"] = np.nan
tp_info["exit_size_type"] = -1
tp_info["exit_type"] = -1
tp_info["order_type"] = -1
tp_info["limit_delta"] = np.nan
tp_info["delta_format"] = -1
tp_info["ladder"] = -1
tp_info["step"] = -1
tp_info["step_idx"] = -1
@register_jitted(cache=True)
def set_time_info_nb(
time_info: tp.Record,
init_idx: int,
init_position: float = np.nan,
stop: int = -1,
exit_price: float = StopExitPrice.Stop,
exit_size: float = np.nan,
exit_size_type: int = -1,
exit_type: int = StopExitType.Close,
order_type: int = OrderType.Market,
limit_delta: float = np.nan,
delta_format: int = DeltaFormat.Percent,
time_delta_format: int = TimeDeltaFormat.Index,
ladder: int = StopLadderMode.Disabled,
step: int = -1,
step_idx: int = -1,
) -> None:
"""Set time order information.
See `vectorbtpro.portfolio.enums.time_info_dt`."""
time_info["init_idx"] = init_idx
time_info["init_position"] = init_position
time_info["stop"] = stop
time_info["exit_price"] = exit_price
time_info["exit_size"] = exit_size
time_info["exit_size_type"] = exit_size_type
time_info["exit_type"] = exit_type
time_info["order_type"] = order_type
time_info["limit_delta"] = limit_delta
time_info["delta_format"] = delta_format
time_info["time_delta_format"] = time_delta_format
time_info["ladder"] = ladder
time_info["step"] = step
time_info["step_idx"] = step_idx
@register_jitted(cache=True)
def clear_time_info_nb(time_info: tp.Record) -> None:
"""Clear time order information."""
time_info["init_idx"] = -1
time_info["init_position"] = np.nan
time_info["stop"] = -1
time_info["exit_price"] = -1
time_info["exit_size"] = np.nan
time_info["exit_size_type"] = -1
time_info["exit_type"] = -1
time_info["order_type"] = -1
time_info["limit_delta"] = np.nan
time_info["delta_format"] = -1
time_info["time_delta_format"] = -1
time_info["ladder"] = -1
time_info["step"] = -1
time_info["step_idx"] = -1
@register_jitted(cache=True)
def get_limit_info_target_price_nb(limit_info: tp.Record) -> float:
"""Get target price from limit order information."""
if not is_limit_info_active_nb(limit_info):
return np.nan
if limit_info["init_size"] == 0:
raise ValueError("Limit order size cannot be zero")
size = get_diraware_size_nb(limit_info["init_size"], limit_info["init_direction"])
hit_below = (size > 0 and not limit_info["reverse"]) or (size < 0 and limit_info["reverse"])
return resolve_limit_price_nb(
init_price=limit_info["init_price"],
limit_delta=limit_info["delta"],
delta_format=limit_info["delta_format"],
hit_below=hit_below,
)
@register_jitted
def get_sl_info_target_price_nb(sl_info: tp.Record, position_now: float) -> float:
"""Get target price from SL order information."""
if not is_stop_info_active_nb(sl_info):
return np.nan
hit_below = position_now > 0
return resolve_stop_price_nb(
init_price=sl_info["init_price"],
stop=sl_info["stop"],
delta_format=sl_info["delta_format"],
hit_below=hit_below,
)
@register_jitted
def get_tsl_info_target_price_nb(tsl_info: tp.Record, position_now: float) -> float:
"""Get target price from TSL/TTP order information."""
if not is_stop_info_active_nb(tsl_info):
return np.nan
hit_below = position_now > 0
return resolve_stop_price_nb(
init_price=tsl_info["peak_price"],
stop=tsl_info["stop"],
delta_format=tsl_info["delta_format"],
hit_below=hit_below,
)
@register_jitted
def get_tp_info_target_price_nb(tp_info: tp.Record, position_now: float) -> float:
"""Get target price from TP order information."""
if not is_stop_info_active_nb(tp_info):
return np.nan
hit_below = position_now < 0
return resolve_stop_price_nb(
init_price=tp_info["init_price"],
stop=tp_info["stop"],
delta_format=tp_info["delta_format"],
hit_below=hit_below,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled context helper functions for portfolio simulation."""
from vectorbtpro.base.flex_indexing import flex_select_col_nb
from vectorbtpro.portfolio.nb import records as pf_records_nb
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.iter_ import select_nb
from vectorbtpro.records import nb as records_nb
# ############# Position ############# #
@register_jitted
def get_col_position_nb(c: tp.NamedTuple, col: int) -> float:
"""Get position of a column."""
return c.last_position[col]
@register_jitted
def get_position_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get position of the current column."""
return get_col_position_nb(c, c.col)
@register_jitted
def col_in_position_nb(c: tp.NamedTuple, col: int) -> bool:
"""Check whether a column is in a position."""
position = get_col_position_nb(c, col)
return position != 0
@register_jitted
def in_position_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> bool:
"""Check whether the current column is in a position."""
return col_in_position_nb(c, c.col)
@register_jitted
def col_in_long_position_nb(c: tp.NamedTuple, col: int) -> bool:
"""Check whether a column is in a long position."""
position = get_col_position_nb(c, col)
return position > 0
@register_jitted
def in_long_position_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> bool:
"""Check whether the current column is in a long position."""
return col_in_long_position_nb(c, c.col)
@register_jitted
def col_in_short_position_nb(c: tp.NamedTuple, col: int) -> bool:
"""Check whether a column is in a short position."""
position = get_col_position_nb(c, col)
return position < 0
@register_jitted
def in_short_position_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> bool:
"""Check whether the current column is in a short position."""
return col_in_short_position_nb(c, c.col)
@register_jitted
def get_n_active_positions_nb(
c: tp.Union[
GroupContext,
SegmentContext,
OrderContext,
PostOrderContext,
FlexOrderContext,
SignalSegmentContext,
SignalContext,
PostSignalContext,
],
all_groups: bool = False,
) -> int:
"""Get the number of active positions in the current group (regardless of cash sharing).
To calculate across all groups, set `all_groups` to True."""
n_active_positions = 0
if all_groups:
for col in range(c.target_shape[1]):
if c.last_position[col] != 0:
n_active_positions += 1
else:
for col in range(c.from_col, c.to_col):
if c.last_position[col] != 0:
n_active_positions += 1
return n_active_positions
# ############# Cash ############# #
@register_jitted
def get_col_cash_nb(c: tp.NamedTuple, col: int) -> float:
"""Get cash of a column."""
if c.cash_sharing:
raise ValueError("Cannot get cash of a single column from a group with cash sharing. " "Use get_group_cash_nb.")
return c.last_cash[col]
@register_jitted
def get_group_cash_nb(c: tp.NamedTuple, group: int) -> float:
"""Get cash of a group."""
if c.cash_sharing:
return c.last_cash[group]
cash = 0.0
from_col = 0
for g in range(len(c.group_lens)):
to_col = from_col + c.group_lens[g]
if g == group:
for col in range(from_col, to_col):
cash += c.last_cash[col]
break
from_col = to_col
return cash
@register_jitted
def get_cash_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get cash of the current column or group with cash sharing."""
if c.cash_sharing:
return get_group_cash_nb(c, c.group)
return get_col_cash_nb(c, c.col)
# ############# Debt ############# #
@register_jitted
def get_col_debt_nb(c: tp.NamedTuple, col: int) -> float:
"""Get debt of a column."""
return c.last_debt[col]
@register_jitted
def get_debt_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get debt of the current column."""
return get_col_debt_nb(c, c.col)
# ############# Locked cash ############# #
@register_jitted
def get_col_locked_cash_nb(c: tp.NamedTuple, col: int) -> float:
"""Get locked cash of a column."""
return c.last_locked_cash[col]
@register_jitted
def get_locked_cash_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get locked cash of the current column."""
return get_col_locked_cash_nb(c, c.col)
# ############# Free cash ############# #
@register_jitted
def get_col_free_cash_nb(c: tp.NamedTuple, col: int) -> float:
"""Get free cash of a column."""
if c.cash_sharing:
raise ValueError(
"Cannot get free cash of a single column from a group with cash sharing. " "Use get_group_free_cash_nb."
)
return c.last_free_cash[col]
@register_jitted
def get_group_free_cash_nb(c: tp.NamedTuple, group: int) -> float:
"""Get free cash of a group."""
if c.cash_sharing:
return c.last_free_cash[group]
free_cash = 0.0
from_col = 0
for g in range(len(c.group_lens)):
to_col = from_col + c.group_lens[g]
if g == group:
for col in range(from_col, to_col):
free_cash += c.last_free_cash[col]
break
from_col = to_col
return free_cash
@register_jitted
def get_free_cash_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get free cash of the current column or group with cash sharing."""
if c.cash_sharing:
return get_group_free_cash_nb(c, c.group)
return get_col_free_cash_nb(c, c.col)
@register_jitted
def col_has_free_cash_nb(c: tp.NamedTuple, col: int) -> bool:
"""Check whether a column has free cash."""
return get_col_free_cash_nb(c, col) > 0
@register_jitted
def group_has_free_cash_nb(c: tp.NamedTuple, group: int) -> bool:
"""Check whether a group has free cash."""
return get_group_free_cash_nb(c, group) > 0
@register_jitted
def has_free_cash_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> bool:
"""Check whether the current column or group with cash sharing has free cash."""
if c.cash_sharing:
return group_has_free_cash_nb(c, c.group)
return col_has_free_cash_nb(c, c.col)
# ############# Valuation price ############# #
@register_jitted
def get_col_val_price_nb(c: tp.NamedTuple, col: int) -> float:
"""Get valuation price of a column."""
return c.last_val_price[col]
@register_jitted
def get_val_price_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get valuation price of the current column."""
return get_col_val_price_nb(c, c.col)
# ############# Value ############# #
@register_jitted
def get_col_value_nb(c: tp.NamedTuple, col: int) -> float:
"""Get value of a column."""
if c.cash_sharing:
raise ValueError(
"Cannot get value of a single column from a group with cash sharing. " "Use get_group_value_nb."
)
return c.last_value[col]
@register_jitted
def get_group_value_nb(c: tp.NamedTuple, group: int) -> float:
"""Get value of a group."""
if c.cash_sharing:
return c.last_value[group]
value = 0.0
from_col = 0
for g in range(len(c.group_lens)):
to_col = from_col + c.group_lens[g]
if g == group:
for col in range(from_col, to_col):
value += c.last_value[col]
break
from_col = to_col
return value
@register_jitted
def get_value_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get value of the current column or group with cash sharing."""
if c.cash_sharing:
return get_group_value_nb(c, c.group)
return get_col_value_nb(c, c.col)
# ############# Leverage ############# #
@register_jitted
def get_col_leverage_nb(c: tp.NamedTuple, col: int) -> float:
"""Get leverage of a column."""
position = get_col_position_nb(c, col)
debt = get_col_debt_nb(c, col)
locked_cash = get_col_locked_cash_nb(c, col)
if locked_cash == 0:
return np.nan
leverage = debt / locked_cash
if position > 0:
leverage += 1
return leverage
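# Worked example of the leverage arithmetic above (an illustrative sketch, not part
# of the original source): with debt of 50.0 against locked cash of 100.0, the ratio
# is 0.5; a long position adds one unit of own exposure, giving 1.5. The helper name
# is hypothetical and only restates the formula on plain floats.
def _example_leverage_arithmetic():
    debt, locked_cash = 50.0, 100.0
    short_leverage = debt / locked_cash         # 0.5
    long_leverage = debt / locked_cash + 1.0    # 1.5
    return long_leverage, short_leverage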
@register_jitted
def get_leverage_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get leverage of the current column."""
return get_col_leverage_nb(c, c.col)
# ############# Allocation ############# #
@register_jitted
def get_col_position_value_nb(c: tp.NamedTuple, col: int) -> float:
"""Get position value of a column."""
position = get_col_position_nb(c, col)
val_price = get_col_val_price_nb(c, col)
if position == 0:
return 0.0
return position * val_price
@register_jitted
def get_group_position_value_nb(c: tp.NamedTuple, group: int) -> float:
"""Get position value of a group."""
value = 0.0
from_col = 0
for g in range(len(c.group_lens)):
to_col = from_col + c.group_lens[g]
if g == group:
for col in range(from_col, to_col):
value += get_col_position_value_nb(c, col)
break
from_col = to_col
return value
@register_jitted
def get_position_value_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get position value of the current column."""
return get_col_position_value_nb(c, c.col)
@register_jitted
def get_col_allocation_nb(c: tp.NamedTuple, col: int, group: tp.Optional[int] = None) -> float:
"""Get allocation of a column in its group."""
position_value = get_col_position_value_nb(c, col)
if group is None:
from_col = 0
found = False
for _group in range(len(c.group_lens)):
to_col = from_col + c.group_lens[_group]
if from_col <= col < to_col:
found = True
break
from_col = to_col
if not found:
raise ValueError("Column out of bounds")
else:
_group = group
value = get_group_value_nb(c, _group)
if position_value == 0:
return 0.0
if value <= 0:
return np.nan
return position_value / value
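# Worked example of the allocation arithmetic above (an illustrative sketch, not part
# of the original source): a position valued at 2,500.0 inside a group valued at
# 10,000.0 corresponds to an allocation of 0.25; a zero position maps to 0.0 and a
# non-positive group value yields NaN. The helper name is hypothetical.
def _example_allocation_arithmetic():
    position_value, group_value = 2500.0, 10000.0
    return position_value / group_value  # 0.25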
@register_jitted
def get_allocation_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> float:
"""Get allocation of the current column in the current group."""
return get_col_allocation_nb(c, c.col, group=c.group)
# ############# Orders ############# #
@register_jitted
def get_col_order_count_nb(c: tp.NamedTuple, col: int) -> int:
"""Get number of order records for a column."""
return c.order_counts[col]
@register_jitted
def get_order_count_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> int:
"""Get number of order records for the current column."""
return get_col_order_count_nb(c, c.col)
@register_jitted
def get_col_order_records_nb(c: tp.NamedTuple, col: int) -> tp.RecordArray:
"""Get order records for a column."""
order_count = get_col_order_count_nb(c, col)
return c.order_records[:order_count, col]
@register_jitted
def get_order_records_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> tp.RecordArray:
"""Get order records for the current column."""
return get_col_order_records_nb(c, c.col)
@register_jitted
def col_has_orders_nb(c: tp.NamedTuple, col: int) -> bool:
"""Check whether there is any order in a column."""
return get_col_order_count_nb(c, col) > 0
@register_jitted
def has_orders_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> bool:
"""Check whether there is any order in the current column."""
return col_has_orders_nb(c, c.col)
@register_jitted
def get_col_last_order_nb(c: tp.NamedTuple, col: int) -> tp.Record:
"""Get the last order in a column."""
if not col_has_orders_nb(c, col):
raise ValueError("There are no orders. Check for any orders first.")
return get_col_order_records_nb(c, col)[-1]
@register_jitted
def get_last_order_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
) -> tp.Record:
"""Get the last order in the current column."""
return get_col_last_order_nb(c, c.col)
# ############# Order result ############# #
@register_jitted
def order_filled_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order was filled."""
return c.order_result.status == OrderStatus.Filled
@register_jitted
def order_opened_position_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order has opened a new position."""
position_now = get_position_nb(c)
return order_reversed_position_nb(c) or (c.position_before == 0 and position_now != 0)
@register_jitted
def order_increased_position_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order has opened or increased an existing position."""
position_now = get_position_nb(c)
return order_opened_position_nb(c) or (
np.sign(position_now) == np.sign(c.position_before) and abs(position_now) > abs(c.position_before)
)
@register_jitted
def order_decreased_position_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order has decreased or closed an existing position."""
position_now = get_position_nb(c)
return (
order_closed_position_nb(c)
or order_reversed_position_nb(c)
or (np.sign(position_now) == np.sign(c.position_before) and abs(position_now) < abs(c.position_before))
)
@register_jitted
def order_closed_position_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order has closed out an existing position."""
position_now = get_position_nb(c)
return c.position_before != 0 and position_now == 0
@register_jitted
def order_reversed_position_nb(
c: tp.Union[
PostOrderContext,
PostSignalContext,
]
) -> bool:
"""Check whether the order has reversed an existing position."""
position_now = get_position_nb(c)
return c.position_before != 0 and position_now != 0 and np.sign(c.position_before) != np.sign(position_now)
# ############# Limit orders ############# #
@register_jitted
def get_col_limit_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
"""Get limit order information of a column."""
return c.last_limit_info[col]
@register_jitted
def get_limit_info_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> tp.Record:
"""Get limit order information of the current column."""
return get_col_limit_info_nb(c, c.col)
@register_jitted
def get_col_limit_target_price_nb(c: tp.NamedTuple, col: int) -> float:
"""Get target price of limit order in a column."""
if not col_in_position_nb(c, col):
return np.nan
limit_info = get_col_limit_info_nb(c, col)
return get_limit_info_target_price_nb(limit_info)
@register_jitted
def get_limit_target_price_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> float:
"""Get target price of limit order in the current column."""
return get_col_limit_target_price_nb(c, c.col)
# ############# Stop orders ############# #
@register_jitted
def get_col_sl_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
"""Get SL order information of a column."""
return c.last_sl_info[col]
@register_jitted
def get_sl_info_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> tp.Record:
"""Get SL order information of the current column."""
return get_col_sl_info_nb(c, c.col)
@register_jitted
def get_col_sl_target_price_nb(c: tp.NamedTuple, col: int) -> float:
"""Get target price of SL order in a column."""
if not col_in_position_nb(c, col):
return np.nan
position = get_col_position_nb(c, col)
sl_info = get_col_sl_info_nb(c, col)
return get_sl_info_target_price_nb(sl_info, position)
@register_jitted
def get_sl_target_price_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> float:
"""Get target price of SL order in the current column."""
return get_col_sl_target_price_nb(c, c.col)
@register_jitted
def get_col_tsl_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
"""Get TSL/TTP order information of a column."""
return c.last_tsl_info[col]
@register_jitted
def get_tsl_info_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> tp.Record:
"""Get TSL/TTP order information of the current column."""
return get_col_tsl_info_nb(c, c.col)
@register_jitted
def get_col_tsl_target_price_nb(c: tp.NamedTuple, col: int) -> float:
"""Get target price of TSL/TTP order in a column."""
if not col_in_position_nb(c, col):
return np.nan
position = get_col_position_nb(c, col)
tsl_info = get_col_tsl_info_nb(c, col)
return get_tsl_info_target_price_nb(tsl_info, position)
@register_jitted
def get_tsl_target_price_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> float:
"""Get target price of TSL/TTP order in the current column."""
return get_col_tsl_target_price_nb(c, c.col)
@register_jitted
def get_col_tp_info_nb(c: tp.NamedTuple, col: int) -> tp.Record:
"""Get TP order information of a column."""
return c.last_tp_info[col]
@register_jitted
def get_tp_info_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> tp.Record:
"""Get TP order information of the current column."""
return get_col_tp_info_nb(c, c.col)
@register_jitted
def get_col_tp_target_price_nb(c: tp.NamedTuple, col: int) -> float:
"""Get target price of TP order in a column."""
if not col_in_position_nb(c, col):
return np.nan
position = get_col_position_nb(c, col)
tp_info = get_col_tp_info_nb(c, col)
return get_tp_info_target_price_nb(tp_info, position)
@register_jitted
def get_tp_target_price_nb(
c: tp.Union[
SignalContext,
PostSignalContext,
],
) -> float:
"""Get target price of TP order in the current column."""
return get_col_tp_target_price_nb(c, c.col)
# ############# Trades ############# #
@register_jitted
def get_col_entry_trade_records_nb(
c: tp.NamedTuple,
col: int,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get entry trade records of a column up to this point."""
order_records = get_col_order_records_nb(c, col)
col_map = records_nb.col_map_nb(order_records["col"], c.target_shape[1])
close = flex_select_col_nb(c.close, col)
entry_trades = pf_records_nb.get_entry_trades_nb(
order_records,
close[: c.i + 1],
col_map,
init_position=init_position,
init_price=init_price,
)
return entry_trades
@register_jitted
def get_entry_trade_records_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get entry trade records of the current column up to this point."""
return get_col_entry_trade_records_nb(c, c.col, init_position=init_position, init_price=init_price)
@register_jitted
def get_col_exit_trade_records_nb(
c: tp.NamedTuple,
col: int,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get exit trade records of a column up to this point."""
order_records = get_col_order_records_nb(c, col)
col_map = records_nb.col_map_nb(order_records["col"], c.target_shape[1])
close = flex_select_col_nb(c.close, col)
exit_trades = pf_records_nb.get_exit_trades_nb(
order_records,
close[: c.i + 1],
col_map,
init_position=init_position,
init_price=init_price,
)
return exit_trades
@register_jitted
def get_exit_trade_records_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get exit trade records of the current column up to this point."""
return get_col_exit_trade_records_nb(c, c.col, init_position=init_position, init_price=init_price)
@register_jitted
def get_col_position_records_nb(
c: tp.NamedTuple,
col: int,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get position records of a column up to this point."""
exit_trade_records = get_col_exit_trade_records_nb(c, col, init_position=init_position, init_price=init_price)
col_map = records_nb.col_map_nb(exit_trade_records["col"], c.target_shape[1])
position_records = pf_records_nb.get_positions_nb(exit_trade_records, col_map)
return position_records
@register_jitted
def get_position_records_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
) -> tp.Array1d:
"""Get position records of the current column up to this point."""
return get_col_position_records_nb(c, c.col, init_position=init_position, init_price=init_price)
# ############# Simulation ############# #
@register_jitted
def stop_group_sim_nb(c: tp.NamedTuple, group: int) -> None:
"""Stop the simulation of a group."""
c.sim_end[group] = c.i + 1
@register_jitted
def stop_sim_nb(
c: tp.Union[
SegmentContext,
OrderContext,
PostOrderContext,
FlexOrderContext,
SignalSegmentContext,
SignalContext,
PostSignalContext,
],
) -> None:
"""Stop the simulation of the current group."""
stop_group_sim_nb(c, c.group)
# ############# Ordering ############# #
@register_jitted
def get_exec_state_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
    val_price: tp.Optional[float] = None,
) -> ExecState:
"""Get execution state."""
if val_price is not None:
_val_price = float(val_price)
value = float(
update_value_nb(
cash_before=get_cash_nb(c),
cash_now=get_cash_nb(c),
position_before=get_position_nb(c),
position_now=get_position_nb(c),
val_price_before=get_val_price_nb(c),
val_price_now=_val_price,
value_before=get_value_nb(c),
)
)
else:
_val_price = float(get_val_price_nb(c))
value = float(get_value_nb(c))
return ExecState(
cash=get_cash_nb(c),
position=get_position_nb(c),
debt=get_debt_nb(c),
locked_cash=get_locked_cash_nb(c),
free_cash=get_free_cash_nb(c),
val_price=_val_price,
value=value,
)
@register_jitted
def get_price_area_nb(c: tp.NamedTuple) -> PriceArea:
"""Get price area."""
return PriceArea(
open=select_nb(c, c.open, i=c.i),
high=select_nb(c, c.high, i=c.i),
low=select_nb(c, c.low, i=c.i),
close=select_nb(c, c.close, i=c.i),
)
@register_jitted
def get_order_size_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
size: float,
size_type: int = SizeType.Amount,
    val_price: tp.Optional[float] = None,
) -> float:
"""Get order size."""
exec_state = get_exec_state_nb(c, val_price=val_price)
if size_type == SizeType.Percent100 or size_type == SizeType.Percent:
raise ValueError("Size type Percent(100) is not supported")
return resolve_size_nb(
size=size,
size_type=size_type,
position=get_position_nb(c),
val_price=exec_state.val_price,
value=exec_state.value,
)[0]
@register_jitted
def get_order_value_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
size: float,
size_type: int = SizeType.Amount,
direction: int = Direction.Both,
    val_price: tp.Optional[float] = None,
) -> float:
"""Get (approximate) order value."""
exec_state = get_exec_state_nb(c, val_price=val_price)
return approx_order_value_nb(
exec_state,
size=size,
size_type=size_type,
direction=direction,
)
@register_jitted
def get_order_result_nb(
c: tp.Union[
OrderContext,
PostOrderContext,
SignalContext,
PostSignalContext,
],
order: Order,
val_price: tp.Optional[float] = None,
update_value: bool = False,
) -> tp.Tuple[OrderResult, ExecState]:
"""Get order result and new execution state.
Doesn't have any effect on the simulation state."""
exec_state = get_exec_state_nb(c, val_price=val_price)
price_area = get_price_area_nb(c)
return execute_order_nb(
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio simulation based on an order function."""
from numba import prange
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio import chunking as portfolio_ch
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.iter_ import *
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.returns import nb as returns_nb_
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import insert_argsort_nb
from vectorbtpro.utils.template import RepFunc
@register_jitted(cache=True)
def calc_group_value_nb(
from_col: int,
to_col: int,
cash_now: float,
last_position: tp.Array1d,
last_val_price: tp.Array1d,
) -> float:
"""Calculate group value."""
group_value = cash_now
group_len = to_col - from_col
for k in range(group_len):
col = from_col + k
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
return group_value
@register_jitted
def calc_ctx_group_value_nb(seg_ctx: SegmentContext) -> float:
"""Calculate group value from context.
Accepts `vectorbtpro.portfolio.enums.SegmentContext`.
Best called once from `pre_segment_func_nb`. To set the valuation price, change `last_val_price`
of the context in-place.
!!! note
Cash sharing must be enabled."""
if not seg_ctx.cash_sharing:
raise ValueError("Cash sharing must be enabled")
return calc_group_value_nb(
seg_ctx.from_col,
seg_ctx.to_col,
seg_ctx.last_cash[seg_ctx.group],
seg_ctx.last_position,
seg_ctx.last_val_price,
)
@register_jitted
def sort_call_seq_out_1d_nb(
ctx: SegmentContext,
size: tp.FlexArray1d,
size_type: tp.FlexArray1d,
direction: tp.FlexArray1d,
order_value_out: tp.Array1d,
call_seq_out: tp.Array1d,
) -> None:
"""Sort call sequence `call_seq_out` based on the value of each potential order.
Accepts `vectorbtpro.portfolio.enums.SegmentContext` and other arguments, sorts `call_seq_out` in place,
and returns nothing.
Arrays `size`, `size_type`, and `direction` utilize flexible indexing; they must be 1-dim arrays
that broadcast to `group_len`.
    The lengths of `order_value_out` and `call_seq_out` must match the number of columns in the group.
    Array `order_value_out` can be uninitialized; after execution it holds the order values in sorted order.
    Array `call_seq_out` must be filled with consecutive integers from 0 to the number of columns
    in the group minus one (in this exact order), that is, follow `CallSeqType.Default`.
Best called once from `pre_segment_func_nb`.
!!! note
Cash sharing must be enabled and `call_seq_out` must follow `CallSeqType.Default`."""
if not ctx.cash_sharing:
raise ValueError("Cash sharing must be enabled")
group_value_now = calc_ctx_group_value_nb(ctx)
group_len = ctx.to_col - ctx.from_col
for c in range(group_len):
if call_seq_out[c] != c:
raise ValueError("call_seq_out must follow CallSeqType.Default")
col = ctx.from_col + c
_size = flex_select_1d_pc_nb(size, c)
_size_type = flex_select_1d_pc_nb(size_type, c)
_direction = flex_select_1d_pc_nb(direction, c)
if ctx.cash_sharing:
cash_now = ctx.last_cash[ctx.group]
free_cash_now = ctx.last_free_cash[ctx.group]
else:
cash_now = ctx.last_cash[col]
free_cash_now = ctx.last_free_cash[col]
exec_state = ExecState(
cash=cash_now,
position=ctx.last_position[col],
debt=ctx.last_debt[col],
locked_cash=ctx.last_locked_cash[col],
free_cash=free_cash_now,
val_price=ctx.last_val_price[col],
value=group_value_now,
)
order_value_out[c] = approx_order_value_nb(
exec_state,
_size,
_size_type,
_direction,
)
# Sort by order value
insert_argsort_nb(order_value_out, call_seq_out)
@register_jitted
def sort_call_seq_1d_nb(
ctx: SegmentContext,
size: tp.FlexArray1d,
size_type: tp.FlexArray1d,
direction: tp.FlexArray1d,
order_value_out: tp.Array1d,
) -> None:
"""Sort call sequence attached to `vectorbtpro.portfolio.enums.SegmentContext`.
See `sort_call_seq_out_1d_nb`.
!!! note
Can only be used in non-flexible simulation functions."""
if ctx.call_seq_now is None:
raise ValueError("Call sequence array is None. Use sort_call_seq_out_1d_nb to sort a custom array.")
sort_call_seq_out_1d_nb(ctx, size, size_type, direction, order_value_out, ctx.call_seq_now)
@register_jitted
def sort_call_seq_out_nb(
ctx: SegmentContext,
size: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
order_value_out: tp.Array1d,
call_seq_out: tp.Array1d,
) -> None:
"""Same as `sort_call_seq_out_1d_nb` but with `size`, `size_type`, and `direction` being 2-dim arrays."""
if not ctx.cash_sharing:
raise ValueError("Cash sharing must be enabled")
group_value_now = calc_ctx_group_value_nb(ctx)
group_len = ctx.to_col - ctx.from_col
for c in range(group_len):
if call_seq_out[c] != c:
raise ValueError("call_seq_out must follow CallSeqType.Default")
col = ctx.from_col + c
_size = select_from_col_nb(ctx, col, size)
_size_type = select_from_col_nb(ctx, col, size_type)
_direction = select_from_col_nb(ctx, col, direction)
if ctx.cash_sharing:
cash_now = ctx.last_cash[ctx.group]
free_cash_now = ctx.last_free_cash[ctx.group]
else:
cash_now = ctx.last_cash[col]
free_cash_now = ctx.last_free_cash[col]
exec_state = ExecState(
cash=cash_now,
position=ctx.last_position[col],
debt=ctx.last_debt[col],
locked_cash=ctx.last_locked_cash[col],
free_cash=free_cash_now,
val_price=ctx.last_val_price[col],
value=group_value_now,
)
order_value_out[c] = approx_order_value_nb(
exec_state,
_size,
_size_type,
_direction,
)
# Sort by order value
insert_argsort_nb(order_value_out, call_seq_out)
@register_jitted
def sort_call_seq_nb(
ctx: SegmentContext,
size: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
order_value_out: tp.Array1d,
) -> None:
"""Sort call sequence attached to `vectorbtpro.portfolio.enums.SegmentContext`.
See `sort_call_seq_out_nb`.
!!! note
Can only be used in non-flexible simulation functions."""
if ctx.call_seq_now is None:
raise ValueError("Call sequence array is None. Use sort_call_seq_out_1d_nb to sort a custom array.")
sort_call_seq_out_nb(ctx, size, size_type, direction, order_value_out, ctx.call_seq_now)
@register_jitted
def try_order_nb(ctx: OrderContext, order: Order) -> tp.Tuple[OrderResult, ExecState]:
"""Execute an order without persistence."""
exec_state = ExecState(
cash=ctx.cash_now,
position=ctx.position_now,
debt=ctx.debt_now,
locked_cash=ctx.locked_cash_now,
free_cash=ctx.free_cash_now,
val_price=ctx.val_price_now,
value=ctx.value_now,
)
price_area = PriceArea(
open=flex_select_nb(ctx.open, ctx.i, ctx.col),
high=flex_select_nb(ctx.high, ctx.i, ctx.col),
low=flex_select_nb(ctx.low, ctx.i, ctx.col),
close=flex_select_nb(ctx.close, ctx.i, ctx.col),
)
return execute_order_nb(exec_state=exec_state, order=order, price_area=price_area)
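# A hedged sketch (not part of the library) of dry-running an order with `try_order_nb`
# inside a custom order function: the order is first executed against a throwaway copy of
# the current state, and only returned for real if that simulated execution fills.
# The 50% target allocation is purely illustrative.
#
# @njit
# def order_func_nb(c):
#     order = order_nb(
#         size=0.5,
#         size_type=SizeType.TargetPercent,
#         price=flex_select_nb(c.close, c.i, c.col),
#     )
#     order_result, _ = try_order_nb(c, order)  # nothing is persisted here
#     if order_result.status != OrderStatus.Filled:
#         return NoOrder  # the real order would not fill either, so skip this element
#     return order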
@register_jitted
def no_pre_func_nb(c: tp.NamedTuple, *args) -> tp.Args:
"""Placeholder preprocessing function that forwards received arguments down the stack."""
return args
@register_jitted
def no_order_func_nb(c: OrderContext, *args) -> Order:
"""Placeholder order function that returns no order."""
return NoOrder
@register_jitted
def no_post_func_nb(c: tp.NamedTuple, *args) -> None:
"""Placeholder postprocessing function that returns nothing."""
return None
PreSimFuncT = tp.Callable[[SimulationContext, tp.VarArg()], tp.Args]
PostSimFuncT = tp.Callable[[SimulationContext, tp.VarArg()], None]
PreGroupFuncT = tp.Callable[[GroupContext, tp.VarArg()], tp.Args]
PostGroupFuncT = tp.Callable[[GroupContext, tp.VarArg()], None]
PreSegmentFuncT = tp.Callable[[SegmentContext, tp.VarArg()], tp.Args]
PostSegmentFuncT = tp.Callable[[SegmentContext, tp.VarArg()], None]
OrderFuncT = tp.Callable[[OrderContext, tp.VarArg()], Order]
PostOrderFuncT = tp.Callable[[PostOrderContext, tp.VarArg()], None]
# %
# %
# %
# @register_jitted
# def pre_sim_func_nb(
# c: SimulationContext,
# *args,
# ) -> tp.Args:
# """Custom simulation pre-processing function."""
# return args
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_sim_func_nb(
# c: SimulationContext,
# *args,
# ) -> None:
# """Custom simulation post-processing function."""
# return None
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def pre_group_func_nb(
# c: GroupContext,
# *args,
# ) -> tp.Args:
# """Custom group pre-processing function."""
# return args
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_group_func_nb(
# c: GroupContext,
# *args,
# ) -> None:
# """Custom group post-processing function."""
# return None
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def pre_segment_func_nb(
# c: SegmentContext,
# *args,
# ) -> tp.Args:
# """Custom segment pre-processing function."""
# return args
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_segment_func_nb(
# c: SegmentContext,
# *args,
# ) -> None:
# """Custom segment post-processing function."""
# return None
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def order_func_nb(
# c: OrderContext,
# *args,
# ) -> Order:
# """Custom order function."""
# return NoOrder
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_order_func_nb(
# c: PostOrderContext,
# *args,
# ) -> None:
# """Custom order post-processing function."""
# return None
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_group_func_nb_block]
# % blocks["pre_group_func_nb"]
# %? blocks[post_group_func_nb_block]
# % blocks["post_group_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[order_func_nb_block]
# % blocks["order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
call_seq=base_ch.array_gl_slicer,
init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
cash_earnings=base_ch.flex_array_gl_slicer,
segment_mask=base_ch.FlexArraySlicer(axis=1),
call_pre_segment=None,
call_post_segment=None,
pre_sim_func_nb=None, # % None
pre_sim_args=ch.ArgsTaker(),
post_sim_func_nb=None, # % None
post_sim_args=ch.ArgsTaker(),
pre_group_func_nb=None, # % None
pre_group_args=ch.ArgsTaker(),
post_group_func_nb=None, # % None
post_group_args=ch.ArgsTaker(),
pre_segment_func_nb=None, # % None
pre_segment_args=ch.ArgsTaker(),
post_segment_func_nb=None, # % None
post_segment_args=ch.ArgsTaker(),
order_func_nb=None, # % None
order_args=ch.ArgsTaker(),
post_order_func_nb=None, # % None
post_order_args=ch.ArgsTaker(),
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
bm_close=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
ffill_val_price=None,
update_value=None,
fill_pos_info=None,
track_value=None,
max_order_records=None,
max_log_records=None,
in_outputs=ch.ArgsTaker(),
),
**portfolio_ch.merge_sim_outs_config,
setup_id=None, # %? line.replace("None", task_id)
)
@register_jitted(
tags={"can_parallel"},
cache=False, # % line.replace("False", "True")
task_id_or_func=None, # %? line.replace("None", task_id)
)
def from_order_func_nb( # %? line.replace("from_order_func_nb", new_func_name)
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
call_seq: tp.Optional[tp.Array2d] = None,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
segment_mask: tp.FlexArray2dLike = True,
call_pre_segment: bool = False,
call_post_segment: bool = False,
pre_sim_func_nb: PreSimFuncT = no_pre_func_nb, # % None
pre_sim_args: tp.Args = (),
post_sim_func_nb: PostSimFuncT = no_post_func_nb, # % None
post_sim_args: tp.Args = (),
pre_group_func_nb: PreGroupFuncT = no_pre_func_nb, # % None
pre_group_args: tp.Args = (),
post_group_func_nb: PostGroupFuncT = no_post_func_nb, # % None
post_group_args: tp.Args = (),
pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb, # % None
pre_segment_args: tp.Args = (),
post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None
post_segment_args: tp.Args = (),
order_func_nb: OrderFuncT = no_order_func_nb, # % None
order_args: tp.Args = (),
post_order_func_nb: PostOrderFuncT = no_post_func_nb, # % None
post_order_args: tp.Args = (),
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
bm_close: tp.FlexArray2dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
ffill_val_price: bool = True,
update_value: bool = False,
fill_pos_info: bool = True,
track_value: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
"""Fill order and log records by iterating over a shape and calling a range of user-defined functions.
Starting with initial cash `init_cash`, iterates over each group and column in `target_shape`,
and for each data point, generates an order using `order_func_nb`. It then tries to fulfill that
order. Upon success, it updates the current state, including the cash balance and the position.
Returns `vectorbtpro.portfolio.enums.SimulationOutput`.
As opposed to `from_order_func_rw_nb`, order processing happens in column-major order.
Column-major order means processing the entire column/group with all rows before moving to the next one.
See [Row- and column-major order](https://en.wikipedia.org/wiki/Row-_and_column-major_order).
Args:
target_shape (tuple): See `vectorbtpro.portfolio.enums.SimulationContext.target_shape`.
group_lens (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.group_lens`.
cash_sharing (bool): See `vectorbtpro.portfolio.enums.SimulationContext.cash_sharing`.
call_seq (array_like of int): See `vectorbtpro.portfolio.enums.SimulationContext.call_seq`.
init_cash (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_cash`.
init_position (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_position`.
init_price (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.init_price`.
cash_deposits (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.cash_deposits`.
cash_earnings (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.cash_earnings`.
segment_mask (array_like of bool): See `vectorbtpro.portfolio.enums.SimulationContext.segment_mask`.
call_pre_segment (bool): See `vectorbtpro.portfolio.enums.SimulationContext.call_pre_segment`.
call_post_segment (bool): See `vectorbtpro.portfolio.enums.SimulationContext.call_post_segment`.
pre_sim_func_nb (callable): Function called before simulation.
Can be used for creation of global arrays and setting the seed.
Must accept `vectorbtpro.portfolio.enums.SimulationContext` and `*pre_sim_args`.
Must return a tuple of any content, which is then passed to `pre_group_func_nb` and
`post_group_func_nb`.
pre_sim_args (tuple): Packed arguments passed to `pre_sim_func_nb`.
post_sim_func_nb (callable): Function called after simulation.
Must accept `vectorbtpro.portfolio.enums.SimulationContext` and `*post_sim_args`.
Must return nothing.
post_sim_args (tuple): Packed arguments passed to `post_sim_func_nb`.
pre_group_func_nb (callable): Function called before each group.
Must accept `vectorbtpro.portfolio.enums.GroupContext`, unpacked tuple from `pre_sim_func_nb`,
and `*pre_group_args`. Must return a tuple of any content, which is then passed to
`pre_segment_func_nb` and `post_segment_func_nb`.
pre_group_args (tuple): Packed arguments passed to `pre_group_func_nb`.
post_group_func_nb (callable): Function called after each group.
Must accept `vectorbtpro.portfolio.enums.GroupContext`, unpacked tuple from `pre_sim_func_nb`,
and `*post_group_args`. Must return nothing.
post_group_args (tuple): Packed arguments passed to `post_group_func_nb`.
pre_segment_func_nb (callable): Function called before each segment.
Called if `segment_mask` or `call_pre_segment` is True.
Must accept `vectorbtpro.portfolio.enums.SegmentContext`, unpacked tuple from `pre_group_func_nb`,
and `*pre_segment_args`. Must return a tuple of any content, which is then passed to
`order_func_nb` and `post_order_func_nb`.
This is the right place to change call sequence and set the valuation price.
Group re-valuation and the update of open position stats happen right after this function,
regardless of whether it has been called.
!!! note
To change the call sequence of a segment, access
`vectorbtpro.portfolio.enums.SegmentContext.call_seq_now` and change it in-place.
Make sure not to generate any new arrays, as doing so may negatively impact performance.
Replacing `vectorbtpro.portfolio.enums.SegmentContext.call_seq_now` like any other context
(named tuple) field is not supported; modify it in-place instead. See `vectorbtpro.portfolio.enums.SegmentContext.call_seq_now`.
!!! note
You can override elements of `last_val_price` to manipulate group valuation.
See `vectorbtpro.portfolio.enums.SimulationContext.last_val_price`.
pre_segment_args (tuple): Packed arguments passed to `pre_segment_func_nb`.
post_segment_func_nb (callable): Function called after each segment.
Called if `segment_mask` or `call_post_segment` is True.
The addition of `cash_earnings`, the final group re-valuation, and the final update of the open
position stats happen right before this function, regardless of whether it has been called.
The passed context represents the final state of each segment, so make sure
to apply any changes before this function is called.
Must accept `vectorbtpro.portfolio.enums.SegmentContext`, unpacked tuple from `pre_group_func_nb`,
and `*post_segment_args`. Must return nothing.
post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`.
order_func_nb (callable): Order generation function.
Used for either generating an order or skipping.
Must accept `vectorbtpro.portfolio.enums.OrderContext`, unpacked tuple from `pre_segment_func_nb`,
and `*order_args`. Must return `vectorbtpro.portfolio.enums.Order`.
!!! note
If the returned order has been rejected, there is no way of issuing a new order.
You should make sure that the order passes, for example, by using `try_order_nb`.
For greater freedom in order management, use `from_flex_order_func_nb`.
order_args (tuple): Arguments passed to `order_func_nb`.
post_order_func_nb (callable): Callback that is called after the order has been processed.
Used for checking the order status and doing some post-processing.
Must accept `vectorbtpro.portfolio.enums.PostOrderContext`, unpacked tuple from
`pre_segment_func_nb`, and `*post_order_args`. Must return nothing.
post_order_args (tuple): Arguments passed to `post_order_func_nb`.
index (array): See `vectorbtpro.portfolio.enums.SimulationContext.index`.
freq (int): See `vectorbtpro.portfolio.enums.SimulationContext.freq`.
open (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.open`.
high (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.high`.
low (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.low`.
close (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.close`.
bm_close (array_like of float): See `vectorbtpro.portfolio.enums.SimulationContext.bm_close`.
ffill_val_price (bool): See `vectorbtpro.portfolio.enums.SimulationContext.ffill_val_price`.
update_value (bool): See `vectorbtpro.portfolio.enums.SimulationContext.update_value`.
fill_pos_info (bool): See `vectorbtpro.portfolio.enums.SimulationContext.fill_pos_info`.
track_value (bool): See `vectorbtpro.portfolio.enums.SimulationContext.track_value`.
max_order_records (int): The max number of order records expected to be filled at each column.
max_log_records (int): The max number of log records expected to be filled at each column.
in_outputs (namedtuple): See `vectorbtpro.portfolio.enums.SimulationContext.in_outputs`.
!!! note
Remember that indexing of 2-dim arrays in vectorbt follows that of pandas: `a[i, col]`.
!!! warning
You can only safely access data of columns that are to the left of the current group and
rows that are to the top of the current row within the same group. Other data points have
not been processed yet and are thus empty. Accessing them will not trigger any errors or warnings,
but will provide you with arbitrary data (see [np.empty](https://numpy.org/doc/stable/reference/generated/numpy.empty.html)).
Call hierarchy:
Like most things in the vectorbt universe, simulation is also done by iterating over an (imaginary) frame.
This frame consists of two dimensions: time (rows) and assets/features (columns).
Each element of this frame is a potential order, which gets generated by calling an order function.
The question is: how do we move across this frame to simulate trading? There are two movement patterns:
column-major (as done by `from_order_func_nb`) and row-major order (as done by `from_order_func_rw_nb`).
In each of these patterns, we are always moving from top to bottom (time axis) and from left to right
(asset/feature axis); the only difference between them is across which axis we are moving faster:
do we want to process each column first (thus assuming that columns are independent) or each row?
Choosing between them is mostly a matter of preference, but it also determines which data is
already available when generating an order.
The frame is further divided into "blocks": columns, groups, rows, segments, and elements.
For example, columns can be grouped into groups that may or may not share the same capital.
Regardless of capital sharing, each collection of elements within a group and a time step is called
a segment, which simply defines a single context (such as shared capital) for one or multiple orders.
Each segment can also define a custom sequence (a so-called call sequence) in which orders are executed.
You can imagine each of these blocks as a rectangle drawn over different parts of the frame,
each having its own context and pre/post-processing function. The pre-processing function is a
simple callback that is called before entering the block, and can be provided by the user to, for example,
prepare arrays or do some custom calculations. It must return a tuple (can be empty) that is then unpacked and
passed as arguments to the pre- and postprocessing function coming next in the call hierarchy.
The postprocessing function can be used, for example, to write user-defined arrays such as returns.
```plaintext
1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
2. pre_group_out = pre_group_func_nb(GroupContext, *pre_sim_out, *pre_group_args)
3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_group_out, *pre_segment_args)
4. if segment_mask: order = order_func_nb(OrderContext, *pre_segment_out, *order_args)
5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
...
6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_group_out, *post_segment_args)
...
7. post_group_func_nb(GroupContext, *pre_sim_out, *post_group_args)
...
8. post_sim_func_nb(SimulationContext, *post_sim_args)
```
Let's demonstrate a frame with one group of two columns and one group of one column, and the
following call sequence:
```plaintext
array([[0, 1, 0],
[1, 0, 0]])
```
(image omitted: illustration of this frame and its call hierarchy)
And here is the context information available at each step:
(image omitted: context information available at each step)
Usage:
* Create a group of three assets sharing $100 and simulate an equal-weighted portfolio
that rebalances every second tick, all without leaving Numba:
```pycon
>>> from vectorbtpro import *
>>> @njit
... def pre_sim_func_nb(c):
... print('before simulation')
... # Create a temporary array and pass it down the stack
... order_value_out = np.empty(c.target_shape[1], dtype=float_)
... return (order_value_out,)
>>> @njit
... def pre_group_func_nb(c, order_value_out):
... print('\\tbefore group', c.group)
... # Forward down the stack (you can omit pre_group_func_nb entirely)
... return (order_value_out,)
>>> @njit
... def pre_segment_func_nb(c, order_value_out, size, price, size_type, direction):
... print('\\t\\tbefore segment', c.i)
... for col in range(c.from_col, c.to_col):
... # Here we use order price for group valuation
... c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price)
...
... # Reorder call sequence of this segment such that selling orders come first and buying last
... # Rearranges c.call_seq_now based on order value (size, size_type, direction, and val_price)
... # Utilizes flexible indexing using select_from_col_nb (as we did above)
... vbt.pf_nb.sort_call_seq_nb(
... c,
... size,
... size_type,
... direction,
... order_value_out[c.from_col:c.to_col]
... )
... # Forward nothing
... return ()
>>> @njit
... def order_func_nb(c, size, price, size_type, direction, fees, fixed_fees, slippage):
... print('\\t\\t\\tcreating order', c.call_idx, 'at column', c.col)
... # Create and return an order
... return vbt.pf_nb.order_nb(
... size=vbt.pf_nb.select_nb(c, size),
... price=vbt.pf_nb.select_nb(c, price),
... size_type=vbt.pf_nb.select_nb(c, size_type),
... direction=vbt.pf_nb.select_nb(c, direction),
... fees=vbt.pf_nb.select_nb(c, fees),
... fixed_fees=vbt.pf_nb.select_nb(c, fixed_fees),
... slippage=vbt.pf_nb.select_nb(c, slippage)
... )
>>> @njit
... def post_order_func_nb(c):
... print('\\t\\t\\t\\torder status:', c.order_result.status)
... return None
>>> @njit
... def post_segment_func_nb(c, order_value_out):
... print('\\t\\tafter segment', c.i)
... return None
>>> @njit
... def post_group_func_nb(c, order_value_out):
... print('\\tafter group', c.group)
... return None
>>> @njit
... def post_sim_func_nb(c):
... print('after simulation')
... return None
>>> target_shape = (5, 3)
>>> np.random.seed(42)
>>> group_lens = np.array([3]) # one group of three columns
>>> cash_sharing = True
>>> segment_mask = np.array([True, False, True, False, True])[:, None]
>>> price = close = np.random.uniform(1, 10, size=target_shape)
>>> size = np.array([[1 / target_shape[1]]]) # custom flexible arrays must be 2-dim
>>> size_type = np.array([[vbt.pf_enums.SizeType.TargetPercent]])
>>> direction = np.array([[vbt.pf_enums.Direction.LongOnly]])
>>> fees = np.array([[0.001]])
>>> fixed_fees = np.array([[1.]])
>>> slippage = np.array([[0.001]])
>>> sim_out = vbt.pf_nb.from_order_func_nb(
... target_shape,
... group_lens,
... cash_sharing,
... segment_mask=segment_mask,
... pre_sim_func_nb=pre_sim_func_nb,
... post_sim_func_nb=post_sim_func_nb,
... pre_group_func_nb=pre_group_func_nb,
... post_group_func_nb=post_group_func_nb,
... pre_segment_func_nb=pre_segment_func_nb,
... pre_segment_args=(size, price, size_type, direction),
... post_segment_func_nb=post_segment_func_nb,
... order_func_nb=order_func_nb,
... order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
... post_order_func_nb=post_order_func_nb
... )
before simulation
before group 0
before segment 0
creating order 0 at column 0
order status: 0
creating order 1 at column 1
order status: 0
creating order 2 at column 2
order status: 0
after segment 0
before segment 2
creating order 0 at column 1
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 0
order status: 0
after segment 2
before segment 4
creating order 0 at column 0
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 1
order status: 0
after segment 4
after group 0
after simulation
>>> pd.DataFrame.from_records(sim_out.order_records)
id col idx size price fees side
0 0 0 0 7.626262 4.375232 1.033367 0
1 1 0 2 5.210115 1.524275 1.007942 0
2 2 0 4 7.899568 8.483492 1.067016 1
3 0 1 0 3.488053 9.565985 1.033367 0
4 1 1 2 0.920352 8.786790 1.008087 1
5 2 1 4 10.713236 2.913963 1.031218 0
6 0 2 0 3.972040 7.595533 1.030170 0
7 1 2 2 0.448747 6.403625 1.002874 1
8 2 2 4 12.378281 2.639061 1.032667 0
>>> col_map = vbt.rec_nb.col_map_nb(sim_out.order_records['col'], target_shape[1])
>>> asset_flow = vbt.pf_nb.asset_flow_nb(target_shape, sim_out.order_records, col_map)
>>> assets = vbt.pf_nb.assets_nb(asset_flow)
>>> asset_value = vbt.pf_nb.asset_value_nb(close, assets)
>>> vbt.Scatter(data=asset_value).fig.show()
```
(images omitted: resulting asset value plots)
Note that the last order in a group with cash sharing is always at a disadvantage,
as it has slightly less funds than the previous orders due to costs, which are not
included when valuing the group.
"""
check_group_lens_nb(group_lens, target_shape[1])
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
bm_close_ = to_2d_array_nb(np.asarray(bm_close))
order_records, log_records = prepare_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_pos_info = prepare_last_pos_info_nb(
target_shape,
init_position=init_position_,
init_price=init_price_,
fill_pos_info=fill_pos_info,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full_like(last_position, 0.0)
last_locked_cash = np.full_like(last_position, 0.0)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
# Call function before the simulation
pre_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
# Call function before the group
pre_group_ctx = GroupContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
)
pre_group_out = pre_group_func_nb(pre_group_ctx, *pre_sim_out, *pre_group_args)
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if track_value:
# Update valuation price using current open
for col in range(from_col, to_col):
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if call_pre_segment or is_segment_active:
# Call function before the segment
pre_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
)
pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_group_out, *pre_segment_args)
# Add cash
if cash_sharing:
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
else:
for col in range(from_col, to_col):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
last_cash[col] += _cash_deposits
last_free_cash[col] += _cash_deposits
last_cash_deposits[col] = _cash_deposits
if track_value:
# Update value and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if is_segment_active:
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
# Get current values
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
pos_info_now = last_pos_info[col]
if cash_sharing:
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
cash_deposits_now = last_cash_deposits[group]
else:
cash_now = last_cash[col]
free_cash_now = last_free_cash[col]
value_now = last_value[col]
return_now = last_return[col]
cash_deposits_now = last_cash_deposits[col]
# Generate the next order
order_ctx = OrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
col=col,
call_idx=k,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
order = order_func_nb(order_ctx, *pre_segment_out, *order_args)
if not track_value:
if (
order.size_type == SizeType.Value
or order.size_type == SizeType.TargetValue
or order.size_type == SizeType.TargetPercent
):
raise ValueError("Cannot use size type that depends on not tracked value")
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
if track_value:
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
if cash_sharing:
return_now = returns_nb_.get_return_nb(
prev_close_value[group],
value_now - cash_deposits_now,
)
else:
return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if cash_sharing:
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
else:
last_cash[col] = cash_now
last_free_cash[col] = free_cash_now
if track_value:
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_value[group] = value_now
last_return[group] = return_now
else:
last_value[col] = value_now
last_return[col] = return_now
# Update position record
if fill_pos_info:
if order_result.status == OrderStatus.Filled:
if order_counts[col] > 0:
order_id = order_records["id"][order_counts[col] - 1, col]
else:
order_id = -1
update_pos_info_nb(
pos_info_now,
i,
col,
exec_state.position,
position_now,
order_result,
order_id,
)
# Post-order callback
post_order_ctx = PostOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
col=col,
call_idx=k,
cash_before=exec_state.cash,
position_before=exec_state.position,
debt_before=exec_state.debt,
locked_cash_before=exec_state.locked_cash,
free_cash_before=exec_state.free_cash,
val_price_before=exec_state.val_price,
value_before=exec_state.value,
order_result=order_result,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)
# NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
# Add earnings in cash
for col in range(from_col, to_col):
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
if cash_sharing:
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
else:
last_cash[col] += _cash_earnings
last_free_cash[col] += _cash_earnings
if track_value:
# Update valuation price using current close
for col in range(from_col, to_col):
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
prev_close_value[group] = last_value[group]
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
prev_close_value[col] = last_value[col]
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if call_post_segment or is_segment_active:
# Call function before the segment
post_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
)
post_segment_func_nb(post_seg_ctx, *pre_group_out, *post_segment_args)
if i >= sim_end_[group] - 1:
break
# Call function after the group
post_group_ctx = GroupContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
)
post_group_func_nb(post_group_ctx, *pre_sim_out, *post_group_args)
# Call function after the simulation
post_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
post_sim_func_nb(post_sim_ctx, *post_sim_args)
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# %
PreRowFuncT = tp.Callable[[RowContext, tp.VarArg()], tp.Args]
PostRowFuncT = tp.Callable[[RowContext, tp.VarArg()], None]
# %
# %
# %
# @register_jitted
# def pre_row_func_nb(
# c: RowContext,
# *args,
# ) -> tp.Args:
# """Custom row pre-processing function."""
# return args
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_row_func_nb(
# c: RowContext,
# *args,
# ) -> None:
# """Custom row post-processing function."""
# return None
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_row_func_nb_block]
# % blocks["pre_row_func_nb"]
# %? blocks[post_row_func_nb_block]
# % blocks["post_row_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[order_func_nb_block]
# % blocks["order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
call_seq=base_ch.array_gl_slicer,
init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
cash_earnings=base_ch.flex_array_gl_slicer,
segment_mask=base_ch.FlexArraySlicer(axis=1),
call_pre_segment=None,
call_post_segment=None,
pre_sim_func_nb=None, # % None
pre_sim_args=ch.ArgsTaker(),
post_sim_func_nb=None, # % None
post_sim_args=ch.ArgsTaker(),
pre_row_func_nb=None, # % None
pre_row_args=ch.ArgsTaker(),
post_row_func_nb=None, # % None
post_row_args=ch.ArgsTaker(),
pre_segment_func_nb=None, # % None
pre_segment_args=ch.ArgsTaker(),
post_segment_func_nb=None, # % None
post_segment_args=ch.ArgsTaker(),
order_func_nb=None, # % None
order_args=ch.ArgsTaker(),
post_order_func_nb=None, # % None
post_order_args=ch.ArgsTaker(),
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
bm_close=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
ffill_val_price=None,
update_value=None,
fill_pos_info=None,
track_value=None,
max_order_records=None,
max_log_records=None,
in_outputs=ch.ArgsTaker(),
),
**portfolio_ch.merge_sim_outs_config,
setup_id=None, # %? line.replace("None", task_id)
)
@register_jitted(
tags={"can_parallel"},
cache=False, # % line.replace("False", "True")
task_id_or_func=None, # %? line.replace("None", task_id)
)
def from_order_func_rw_nb( # %? line.replace("from_order_func_rw_nb", new_func_name)
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
call_seq: tp.Optional[tp.Array2d] = None,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
segment_mask: tp.FlexArray2dLike = True,
call_pre_segment: bool = False,
call_post_segment: bool = False,
pre_sim_func_nb: PreSimFuncT = no_pre_func_nb, # % None
pre_sim_args: tp.Args = (),
post_sim_func_nb: PostSimFuncT = no_post_func_nb, # % None
post_sim_args: tp.Args = (),
pre_row_func_nb: PreRowFuncT = no_pre_func_nb, # % None
pre_row_args: tp.Args = (),
post_row_func_nb: PostRowFuncT = no_post_func_nb, # % None
post_row_args: tp.Args = (),
pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb, # % None
pre_segment_args: tp.Args = (),
post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None
post_segment_args: tp.Args = (),
order_func_nb: OrderFuncT = no_order_func_nb, # % None
order_args: tp.Args = (),
post_order_func_nb: PostOrderFuncT = no_post_func_nb, # % None
post_order_args: tp.Args = (),
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
bm_close: tp.FlexArray2dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
ffill_val_price: bool = True,
update_value: bool = False,
fill_pos_info: bool = True,
track_value: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
"""Same as `from_order_func_nb`, but iterates in row-major order.
Row-major order means processing the entire row with all groups/columns before moving to the next one.
The main difference is that instead of `pre_group_func_nb` it now exposes `pre_row_func_nb`,
which is executed once per row. It must accept `vectorbtpro.portfolio.enums.RowContext`.
!!! note
Function `pre_row_func_nb` is only called if there is at least one active segment in
the row. Functions `pre_segment_func_nb` and `order_func_nb` are only called if their
segment is active. If the main task of `pre_row_func_nb` is to activate/deactivate segments,
all segments must be activated by default to allow `pre_row_func_nb` to be called.
!!! warning
You can only safely access data points that are to the left of the current group and
rows that are to the top of the current row.
Call hierarchy:
```plaintext
1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
2. pre_row_out = pre_row_func_nb(RowContext, *pre_sim_out, *pre_row_args)
3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_row_out, *pre_segment_args)
4. if segment_mask: order = order_func_nb(OrderContext, *pre_segment_out, *order_args)
5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
...
6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_row_out, *post_segment_args)
...
7. post_row_func_nb(RowContext, *pre_sim_out, *post_row_args)
...
8. post_sim_func_nb(SimulationContext, *post_sim_args)
```
Let's illustrate the same example as in `from_order_func_nb` but adapted for this function:
(image omitted: illustration of the row-major call hierarchy)
Usage:
* Running the same example as in `from_order_func_nb` but adapted for this function:
```pycon
>>> @njit
... def pre_row_func_nb(c, order_value_out):
... print('\\tbefore row', c.i)
... # Forward down the stack
... return (order_value_out,)
>>> @njit
... def post_row_func_nb(c, order_value_out):
... print('\\tafter row', c.i)
... return None
>>> sim_out = vbt.pf_nb.from_order_func_rw_nb(
... target_shape,
... group_lens,
... cash_sharing,
... segment_mask=segment_mask,
... pre_sim_func_nb=pre_sim_func_nb,
... post_sim_func_nb=post_sim_func_nb,
... pre_row_func_nb=pre_row_func_nb,
... post_row_func_nb=post_row_func_nb,
... pre_segment_func_nb=pre_segment_func_nb,
... pre_segment_args=(size, price, size_type, direction),
... post_segment_func_nb=post_segment_func_nb,
... order_func_nb=order_func_nb,
... order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
... post_order_func_nb=post_order_func_nb
... )
before simulation
before row 0
before segment 0
creating order 0 at column 0
order status: 0
creating order 1 at column 1
order status: 0
creating order 2 at column 2
order status: 0
after segment 0
after row 0
before row 1
after row 1
before row 2
before segment 2
creating order 0 at column 1
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 0
order status: 0
after segment 2
after row 2
before row 3
after row 3
before row 4
before segment 4
creating order 0 at column 0
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 1
order status: 0
after segment 4
after row 4
after simulation
```
"""
check_group_lens_nb(group_lens, target_shape[1])
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
bm_close_ = to_2d_array_nb(np.asarray(bm_close))
order_records, log_records = prepare_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_pos_info = prepare_last_pos_info_nb(
target_shape,
init_position=init_position_,
init_price=init_price_,
fill_pos_info=fill_pos_info,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full_like(last_position, 0.0)
last_locked_cash = np.full_like(last_position, 0.0)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
# Call function before the simulation
pre_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)
_sim_start = sim_start_.min()
_sim_end = sim_end_.max()
for i in range(_sim_start, _sim_end):
# Call function before the row
pre_row_ctx = RowContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
i=i,
)
pre_row_out = pre_row_func_nb(pre_row_ctx, *pre_sim_out, *pre_row_args)
for group in range(len(group_lens)):
if i < sim_start_[group] or i >= sim_end_[group]:
continue
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if track_value:
# Update valuation price using current open
for col in range(from_col, to_col):
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if call_pre_segment or is_segment_active:
# Call function before the segment
pre_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
)
pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_row_out, *pre_segment_args)
# Add cash
if cash_sharing:
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
else:
for col in range(from_col, to_col):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
last_cash[col] += _cash_deposits
last_free_cash[col] += _cash_deposits
last_cash_deposits[col] = _cash_deposits
if track_value:
# Update value and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if is_segment_active:
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
# Get current values
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
pos_info_now = last_pos_info[col]
if cash_sharing:
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
cash_deposits_now = last_cash_deposits[group]
else:
cash_now = last_cash[col]
free_cash_now = last_free_cash[col]
value_now = last_value[col]
return_now = last_return[col]
cash_deposits_now = last_cash_deposits[col]
# Generate the next order
order_ctx = OrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
col=col,
call_idx=k,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
order = order_func_nb(order_ctx, *pre_segment_out, *order_args)
if not track_value:
if (
order.size_type == SizeType.Value
or order.size_type == SizeType.TargetValue
or order.size_type == SizeType.TargetPercent
):
raise ValueError("Cannot use size type that depends on not tracked value")
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
if track_value:
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
if cash_sharing:
return_now = returns_nb_.get_return_nb(
prev_close_value[group],
value_now - cash_deposits_now,
)
else:
return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if cash_sharing:
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
else:
last_cash[col] = cash_now
last_free_cash[col] = free_cash_now
if track_value:
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_value[group] = value_now
last_return[group] = return_now
else:
last_value[col] = value_now
last_return[col] = return_now
# Update position record
if fill_pos_info:
if order_result.status == OrderStatus.Filled:
if order_counts[col] > 0:
order_id = order_records["id"][order_counts[col] - 1, col]
else:
order_id = -1
update_pos_info_nb(
pos_info_now,
i,
col,
exec_state.position,
position_now,
order_result,
order_id,
)
# Post-order callback
post_order_ctx = PostOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
col=col,
call_idx=k,
cash_before=exec_state.cash,
position_before=exec_state.position,
debt_before=exec_state.debt,
locked_cash_before=exec_state.locked_cash,
free_cash_before=exec_state.free_cash,
val_price_before=exec_state.val_price,
value_before=exec_state.value,
order_result=order_result,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)
# NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
# Add earnings in cash
for col in range(from_col, to_col):
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
if cash_sharing:
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
else:
last_cash[col] += _cash_earnings
last_free_cash[col] += _cash_earnings
if track_value:
# Update valuation price using current close
for col in range(from_col, to_col):
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
prev_close_value[group] = last_value[group]
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
prev_close_value[col] = last_value[col]
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if call_post_segment or is_segment_active:
# Call function after the segment
post_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=call_seq_now,
)
post_segment_func_nb(post_seg_ctx, *pre_row_out, *post_segment_args)
# Call function after the row
post_row_ctx = RowContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
i=i,
)
post_row_func_nb(post_row_ctx, *pre_sim_out, *post_row_args)
sim_end_reached = True
for group in range(len(group_lens)):
if i < sim_end_[group] - 1:
sim_end_reached = False
break
if sim_end_reached:
break
# Call function after the simulation
post_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=call_seq,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
post_sim_func_nb(post_sim_ctx, *post_sim_args)
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# %
@register_jitted
def no_flex_order_func_nb(c: FlexOrderContext, *args) -> tp.Tuple[int, Order]:
"""Placeholder flexible order function that returns "break" column and no order."""
return -1, NoOrder
FlexOrderFuncT = tp.Callable[[FlexOrderContext, tp.VarArg()], tp.Tuple[int, Order]]
# %
# %
# %
# @register_jitted
# def flex_order_func_nb(
# c: FlexOrderContext,
# *args,
# ) -> tp.Tuple[int, Order]:
# """Custom flexible order function."""
# return -1, NoOrder
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_group_func_nb_block]
# % blocks["pre_group_func_nb"]
# %? blocks[post_group_func_nb_block]
# % blocks["post_group_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[flex_order_func_nb_block]
# % blocks["flex_order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
cash_earnings=base_ch.flex_array_gl_slicer,
segment_mask=base_ch.FlexArraySlicer(axis=1),
call_pre_segment=None,
call_post_segment=None,
pre_sim_func_nb=None, # % None
pre_sim_args=ch.ArgsTaker(),
post_sim_func_nb=None, # % None
post_sim_args=ch.ArgsTaker(),
pre_group_func_nb=None, # % None
pre_group_args=ch.ArgsTaker(),
post_group_func_nb=None, # % None
post_group_args=ch.ArgsTaker(),
pre_segment_func_nb=None, # % None
pre_segment_args=ch.ArgsTaker(),
post_segment_func_nb=None, # % None
post_segment_args=ch.ArgsTaker(),
flex_order_func_nb=None, # % None
flex_order_args=ch.ArgsTaker(),
post_order_func_nb=None, # % None
post_order_args=ch.ArgsTaker(),
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
bm_close=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
ffill_val_price=None,
update_value=None,
fill_pos_info=None,
track_value=None,
max_order_records=None,
max_log_records=None,
in_outputs=ch.ArgsTaker(),
),
**portfolio_ch.merge_sim_outs_config,
setup_id=None, # %? line.replace("None", task_id)
)
@register_jitted(
tags={"can_parallel"},
cache=False, # % line.replace("False", "True")
task_id_or_func=None, # %? line.replace("None", task_id)
)
def from_flex_order_func_nb( # %? line.replace("from_flex_order_func_nb", new_func_name)
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
segment_mask: tp.FlexArray2dLike = True,
call_pre_segment: bool = False,
call_post_segment: bool = False,
pre_sim_func_nb: PreSimFuncT = no_pre_func_nb, # % None
pre_sim_args: tp.Args = (),
post_sim_func_nb: PostSimFuncT = no_post_func_nb, # % None
post_sim_args: tp.Args = (),
pre_group_func_nb: PreGroupFuncT = no_pre_func_nb, # % None
pre_group_args: tp.Args = (),
post_group_func_nb: PostGroupFuncT = no_post_func_nb, # % None
post_group_args: tp.Args = (),
pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb, # % None
pre_segment_args: tp.Args = (),
post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None
post_segment_args: tp.Args = (),
flex_order_func_nb: FlexOrderFuncT = no_flex_order_func_nb, # % None
flex_order_args: tp.Args = (),
post_order_func_nb: PostOrderFuncT = no_post_func_nb, # % None
post_order_args: tp.Args = (),
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
bm_close: tp.FlexArray2dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
ffill_val_price: bool = True,
update_value: bool = False,
fill_pos_info: bool = True,
track_value: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
"""Same as `from_order_func_nb`, but with no predefined call sequence.
In contrast to `order_func_nb` in `from_order_func_nb`, `flex_order_func_nb` is a segment-level order function
that returns a column along with the order, and gets called repeatedly until some condition is met.
This allows multiple orders to be issued within a single element and in an arbitrary order.
The order function must accept `vectorbtpro.portfolio.enums.FlexOrderContext`, the unpacked tuple from
`pre_segment_func_nb`, and `*flex_order_args`. It must return a column and `vectorbtpro.portfolio.enums.Order`.
To break out of the loop, return a column of -1.
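As a minimal sketch of this contract (hypothetical code, not part of the library; `one_order_per_col_nb` is an
assumed name, and `order_nb`/`order_nothing_nb` are the same helpers used in the usage example below), a flexible
order function that submits one order per column of the group and then breaks could look like this:
```pycon
>>> from vectorbtpro import *
>>> @njit
... def one_order_per_col_nb(c):
...     if c.call_idx < c.group_len:
...         # Submit an order of one unit for the next column of the group
...         col = c.from_col + c.call_idx
...         return col, vbt.pf_nb.order_nb(size=1.0)
...     # All columns processed -> return -1 to break the loop
...     return -1, vbt.pf_nb.order_nothing_nb()
```
Each element of the group then receives exactly one order per active segment; issuing more than one order per
element is what may require a larger `max_order_records`, as discussed in the note below.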
!!! note
Since one element can now accommodate multiple orders, you may run into an "order_records index out of range"
exception. In this case, you must increase `max_order_records`. This cannot be done automatically and
dynamically, as that would degrade performance.
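As a rough illustration (an assumption-based sketch: it presumes the cap counts records per column, which is
consistent with how `order_counts` is indexed per column in the implementation), a strategy that issues at most
two orders per element could reserve two records per row:
```pycon
>>> target_shape = (5, 3)
>>> max_order_records = 2 * target_shape[0]  # assumed: at most 2 orders per element per column
>>> max_order_records
10
```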
Call hierarchy:
```plaintext
1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
2. pre_group_out = pre_group_func_nb(GroupContext, *pre_sim_out, *pre_group_args)
3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_group_out, *pre_segment_args)
while col != -1:
4. if segment_mask: col, order = flex_order_func_nb(FlexOrderContext, *pre_segment_out, *flex_order_args)
5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
...
6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_group_out, *post_segment_args)
...
7. post_group_func_nb(GroupContext, *pre_sim_out, *post_group_args)
...
8. post_sim_func_nb(SimulationContext, *post_sim_args)
```
The same example as in `from_order_func_nb`, adapted for this function, is illustrated by a diagram in the original documentation (image not included in this packed representation).
Usage:
* The same example as in `from_order_func_nb`:
```pycon
>>> from vectorbtpro import *
>>> @njit
... def pre_sim_func_nb(c):
... # Create temporary arrays and pass them down the stack
... print('before simulation')
... order_value_out = np.empty(c.target_shape[1], dtype=float_)
... call_seq_out = np.empty(c.target_shape[1], dtype=int_)
... return (order_value_out, call_seq_out)
>>> @njit
... def pre_group_func_nb(c, order_value_out, call_seq_out):
... print('\\tbefore group', c.group)
... return (order_value_out, call_seq_out)
>>> @njit
... def pre_segment_func_nb(c, order_value_out, call_seq_out, size, price, size_type, direction):
... print('\\t\\tbefore segment', c.i)
... for col in range(c.from_col, c.to_col):
... # Here we use order price for group valuation
... c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price)
...
... # Same as for from_order_func_nb, but since we don't have a predefined c.call_seq_now anymore,
... # we need to store our new call sequence somewhere else
... call_seq_out[:] = np.arange(c.group_len)
... vbt.pf_nb.sort_call_seq_out_nb(
... c,
... size,
... size_type,
... direction,
... order_value_out[c.from_col:c.to_col],
... call_seq_out[c.from_col:c.to_col]
... )
...
... # Forward the sorted call sequence
... return (call_seq_out,)
>>> @njit
... def flex_order_func_nb(c, call_seq_out, size, price, size_type, direction, fees, fixed_fees, slippage):
... if c.call_idx < c.group_len:
... col = c.from_col + call_seq_out[c.call_idx]
... print('\\t\\t\\tcreating order', c.call_idx, 'at column', col)
... # Create and return an order
... return col, vbt.pf_nb.order_nb(
... size=vbt.pf_nb.select_from_col_nb(c, col, size),
... price=vbt.pf_nb.select_from_col_nb(c, col, price),
... size_type=vbt.pf_nb.select_from_col_nb(c, col, size_type),
... direction=vbt.pf_nb.select_from_col_nb(c, col, direction),
... fees=vbt.pf_nb.select_from_col_nb(c, col, fees),
... fixed_fees=vbt.pf_nb.select_from_col_nb(c, col, fixed_fees),
... slippage=vbt.pf_nb.select_from_col_nb(c, col, slippage)
... )
... # All columns already processed -> break the loop
... print('\\t\\t\\tbreaking out of the loop')
... return -1, vbt.pf_nb.order_nothing_nb()
>>> @njit
... def post_order_func_nb(c, call_seq_out):
... print('\\t\\t\\t\\torder status:', c.order_result.status)
... return None
>>> @njit
... def post_segment_func_nb(c, order_value_out, call_seq_out):
... print('\\t\\tafter segment', c.i)
... return None
>>> @njit
... def post_group_func_nb(c, order_value_out, call_seq_out):
... print('\\tafter group', c.group)
... return None
>>> @njit
... def post_sim_func_nb(c):
... print('after simulation')
... return None
>>> target_shape = (5, 3)
>>> np.random.seed(42)
>>> group_lens = np.array([3]) # one group of three columns
>>> cash_sharing = True
>>> segment_mask = np.array([True, False, True, False, True])[:, None]
>>> price = close = np.random.uniform(1, 10, size=target_shape)
>>> size = np.array([[1 / target_shape[1]]]) # custom flexible arrays must be 2-dim
>>> size_type = np.array([[vbt.pf_enums.SizeType.TargetPercent]])
>>> direction = np.array([[vbt.pf_enums.Direction.LongOnly]])
>>> fees = np.array([[0.001]])
>>> fixed_fees = np.array([[1.]])
>>> slippage = np.array([[0.001]])
>>> sim_out = vbt.pf_nb.from_flex_order_func_nb(
... target_shape,
... group_lens,
... cash_sharing,
... segment_mask=segment_mask,
... pre_sim_func_nb=pre_sim_func_nb,
... post_sim_func_nb=post_sim_func_nb,
... pre_group_func_nb=pre_group_func_nb,
... post_group_func_nb=post_group_func_nb,
... pre_segment_func_nb=pre_segment_func_nb,
... pre_segment_args=(size, price, size_type, direction),
... post_segment_func_nb=post_segment_func_nb,
... flex_order_func_nb=flex_order_func_nb,
... flex_order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
... post_order_func_nb=post_order_func_nb
... )
before simulation
before group 0
before segment 0
creating order 0 at column 0
order status: 0
creating order 1 at column 1
order status: 0
creating order 2 at column 2
order status: 0
breaking out of the loop
after segment 0
before segment 2
creating order 0 at column 1
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 0
order status: 0
breaking out of the loop
after segment 2
before segment 4
creating order 0 at column 0
order status: 0
creating order 1 at column 2
order status: 0
creating order 2 at column 1
order status: 0
breaking out of the loop
after segment 4
after group 0
after simulation
```
"""
check_group_lens_nb(group_lens, target_shape[1])
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
bm_close_ = to_2d_array_nb(np.asarray(bm_close))
order_records, log_records = prepare_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_pos_info = prepare_last_pos_info_nb(
target_shape,
init_position=init_position_,
init_price=init_price_,
fill_pos_info=fill_pos_info,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full_like(last_position, 0.0)
last_locked_cash = np.full_like(last_position, 0.0)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
# Call function before the simulation
pre_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
# Call function before the group
pre_group_ctx = GroupContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
)
pre_group_out = pre_group_func_nb(pre_group_ctx, *pre_sim_out, *pre_group_args)
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
if track_value:
# Update valuation price using current open
for col in range(from_col, to_col):
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if call_pre_segment or is_segment_active:
# Call function before the segment
pre_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
)
pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_group_out, *pre_segment_args)
# Add cash
if cash_sharing:
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
else:
for col in range(from_col, to_col):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
last_cash[col] += _cash_deposits
last_free_cash[col] += _cash_deposits
last_cash_deposits[col] = _cash_deposits
if track_value:
# Update value and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if is_segment_active:
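# Keep calling the flexible order function until it returns a column of -1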
call_idx = -1
while True:
call_idx += 1
# Generate the next order
flex_order_ctx = FlexOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
call_idx=call_idx,
)
col, order = flex_order_func_nb(flex_order_ctx, *pre_segment_out, *flex_order_args)
if col == -1:
break
if col < from_col or col >= to_col:
raise ValueError("Column out of bounds of the group")
if not track_value:
if (
order.size_type == SizeType.Value
or order.size_type == SizeType.TargetValue
or order.size_type == SizeType.TargetPercent
):
raise ValueError("Cannot use size type that depends on not tracked value")
# Get current values
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
pos_info_now = last_pos_info[col]
if cash_sharing:
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
cash_deposits_now = last_cash_deposits[group]
else:
cash_now = last_cash[col]
free_cash_now = last_free_cash[col]
value_now = last_value[col]
return_now = last_return[col]
cash_deposits_now = last_cash_deposits[col]
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
if track_value:
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
if cash_sharing:
return_now = returns_nb_.get_return_nb(
prev_close_value[group],
value_now - cash_deposits_now,
)
else:
return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if cash_sharing:
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
else:
last_cash[col] = cash_now
last_free_cash[col] = free_cash_now
if track_value:
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_value[group] = value_now
last_return[group] = return_now
else:
last_value[col] = value_now
last_return[col] = return_now
# Update position record
if fill_pos_info:
if order_result.status == OrderStatus.Filled:
if order_counts[col] > 0:
order_id = order_records["id"][order_counts[col] - 1, col]
else:
order_id = -1
update_pos_info_nb(
pos_info_now,
i,
col,
exec_state.position,
position_now,
order_result,
order_id,
)
# Post-order callback
post_order_ctx = PostOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
col=col,
call_idx=call_idx,
cash_before=exec_state.cash,
position_before=exec_state.position,
debt_before=exec_state.debt,
locked_cash_before=exec_state.locked_cash,
free_cash_before=exec_state.free_cash,
val_price_before=exec_state.val_price,
value_before=exec_state.value,
order_result=order_result,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)
# NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
# Add earnings in cash
for col in range(from_col, to_col):
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
if cash_sharing:
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
else:
last_cash[col] += _cash_earnings
last_free_cash[col] += _cash_earnings
if track_value:
# Update valuation price using current close
for col in range(from_col, to_col):
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
prev_close_value[group] = last_value[group]
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
prev_close_value[col] = last_value[col]
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if call_post_segment or is_segment_active:
# Call function after the segment
post_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
)
post_segment_func_nb(post_seg_ctx, *pre_group_out, *post_segment_args)
if i >= sim_end_[group] - 1:
break
# Call function after the group
post_group_ctx = GroupContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
)
post_group_func_nb(post_group_ctx, *pre_sim_out, *post_group_args)
# Call function after the simulation
post_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
post_sim_func_nb(post_sim_ctx, *post_sim_args)
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
call_seq=None,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_order_func import *
# %? import_lines
#
#
# %
# %? blocks[pre_sim_func_nb_block]
# % blocks["pre_sim_func_nb"]
# %? blocks[post_sim_func_nb_block]
# % blocks["post_sim_func_nb"]
# %? blocks[pre_row_func_nb_block]
# % blocks["pre_row_func_nb"]
# %? blocks[post_row_func_nb_block]
# % blocks["post_row_func_nb"]
# %? blocks[pre_segment_func_nb_block]
# % blocks["pre_segment_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
# %? blocks[flex_order_func_nb_block]
# % blocks["flex_order_func_nb"]
# %? blocks[post_order_func_nb_block]
# % blocks["post_order_func_nb"]
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
cash_earnings=base_ch.flex_array_gl_slicer,
segment_mask=base_ch.FlexArraySlicer(axis=1),
call_pre_segment=None,
call_post_segment=None,
pre_sim_func_nb=None, # % None
pre_sim_args=ch.ArgsTaker(),
post_sim_func_nb=None, # % None
post_sim_args=ch.ArgsTaker(),
pre_row_func_nb=None, # % None
pre_row_args=ch.ArgsTaker(),
post_row_func_nb=None, # % None
post_row_args=ch.ArgsTaker(),
pre_segment_func_nb=None, # % None
pre_segment_args=ch.ArgsTaker(),
post_segment_func_nb=None, # % None
post_segment_args=ch.ArgsTaker(),
flex_order_func_nb=None, # % None
flex_order_args=ch.ArgsTaker(),
post_order_func_nb=None, # % None
post_order_args=ch.ArgsTaker(),
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
bm_close=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
ffill_val_price=None,
update_value=None,
fill_pos_info=None,
track_value=None,
max_order_records=None,
max_log_records=None,
in_outputs=ch.ArgsTaker(),
),
**portfolio_ch.merge_sim_outs_config,
setup_id=None, # %? line.replace("None", task_id)
)
@register_jitted(
tags={"can_parallel"},
cache=False, # % line.replace("False", "True")
task_id_or_func=None, # %? line.replace("None", task_id)
)
def from_flex_order_func_rw_nb( # %? line.replace("from_flex_order_func_rw_nb", new_func_name)
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
segment_mask: tp.FlexArray2dLike = True,
call_pre_segment: bool = False,
call_post_segment: bool = False,
pre_sim_func_nb: PreSimFuncT = no_pre_func_nb, # % None
pre_sim_args: tp.Args = (),
post_sim_func_nb: PostSimFuncT = no_post_func_nb, # % None
post_sim_args: tp.Args = (),
pre_row_func_nb: PreRowFuncT = no_pre_func_nb, # % None
pre_row_args: tp.Args = (),
post_row_func_nb: PostRowFuncT = no_post_func_nb, # % None
post_row_args: tp.Args = (),
pre_segment_func_nb: PreSegmentFuncT = no_pre_func_nb, # % None
pre_segment_args: tp.Args = (),
post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None
post_segment_args: tp.Args = (),
flex_order_func_nb: FlexOrderFuncT = no_flex_order_func_nb, # % None
flex_order_args: tp.Args = (),
post_order_func_nb: PostOrderFuncT = no_post_func_nb, # % None
post_order_args: tp.Args = (),
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
bm_close: tp.FlexArray2dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
ffill_val_price: bool = True,
update_value: bool = False,
fill_pos_info: bool = True,
track_value: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
"""Same as `from_flex_order_func_nb`, but iterates using row-major order, with the rows
changing fastest, and the columns/groups changing slowest.
Call hierarchy:
```plaintext
1. pre_sim_out = pre_sim_func_nb(SimulationContext, *pre_sim_args)
2. pre_row_out = pre_row_func_nb(RowContext, *pre_sim_out, *pre_row_args)
3. if call_pre_segment or segment_mask: pre_segment_out = pre_segment_func_nb(SegmentContext, *pre_row_out, *pre_segment_args)
while col != -1:
4. if segment_mask: col, order = flex_order_func_nb(FlexOrderContext, *pre_segment_out, *flex_order_args)
5. if order: post_order_func_nb(PostOrderContext, *pre_segment_out, *post_order_args)
...
6. if call_post_segment or segment_mask: post_segment_func_nb(SegmentContext, *pre_row_out, *post_segment_args)
...
7. post_row_func_nb(RowContext, *pre_sim_out, *post_row_args)
...
8. post_sim_func_nb(SimulationContext, *post_sim_args)
```
Let's illustrate the same example as in `from_order_func_nb` but adapted for this function:
```pycon
>>> @njit
... def pre_row_func_nb(c, order_value_out, call_seq_out):
... print('\\tbefore row', c.i)
... return (order_value_out, call_seq_out)
>>> @njit
... def post_row_func_nb(c, order_value_out, call_seq_out):
... print('\\tafter row', c.i)
... return None
>>> sim_out = vbt.pf_nb.from_flex_order_func_rw_nb(
... target_shape,
... group_lens,
... cash_sharing,
... segment_mask=segment_mask,
... pre_sim_func_nb=pre_sim_func_nb,
... post_sim_func_nb=post_sim_func_nb,
... pre_row_func_nb=pre_row_func_nb,
... post_row_func_nb=post_row_func_nb,
... pre_segment_func_nb=pre_segment_func_nb,
... pre_segment_args=(size, price, size_type, direction),
... post_segment_func_nb=post_segment_func_nb,
... flex_order_func_nb=flex_order_func_nb,
... flex_order_args=(size, price, size_type, direction, fees, fixed_fees, slippage),
... post_order_func_nb=post_order_func_nb
... )
```
"""
check_group_lens_nb(group_lens, target_shape[1])
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
segment_mask_ = to_2d_array_nb(np.asarray(segment_mask))
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
bm_close_ = to_2d_array_nb(np.asarray(bm_close))
order_records, log_records = prepare_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_pos_info = prepare_last_pos_info_nb(
target_shape,
init_position=init_position_,
init_price=init_price_,
fill_pos_info=fill_pos_info,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full_like(last_position, 0.0)
last_locked_cash = np.full_like(last_position, 0.0)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
# Call function before the simulation
pre_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
pre_sim_out = pre_sim_func_nb(pre_sim_ctx, *pre_sim_args)
_sim_start = sim_start_.min()
_sim_end = sim_end_.max()
for i in range(_sim_start, _sim_end):
# Call function before the row
pre_row_ctx = RowContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
i=i,
)
pre_row_out = pre_row_func_nb(pre_row_ctx, *pre_sim_out, *pre_row_args)
for group in range(len(group_lens)):
if i < sim_start_[group] or i >= sim_end_[group]:
continue
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
if track_value:
# Update valuation price using current open
for col in range(from_col, to_col):
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(prev_close_value[group], last_value[group])
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(prev_close_value[col], last_value[col])
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if call_pre_segment or is_segment_active:
# Call function before the segment
pre_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
)
pre_segment_out = pre_segment_func_nb(pre_seg_ctx, *pre_row_out, *pre_segment_args)
# Add cash
if cash_sharing:
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
else:
for col in range(from_col, to_col):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
last_cash[col] += _cash_deposits
last_free_cash[col] += _cash_deposits
last_cash_deposits[col] = _cash_deposits
if track_value:
# Update value and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
is_segment_active = flex_select_nb(segment_mask_, i, group)
if is_segment_active:
call_idx = -1
while True:
call_idx += 1
# Generate the next order
flex_order_ctx = FlexOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
call_idx=call_idx,
)
col, order = flex_order_func_nb(flex_order_ctx, *pre_segment_out, *flex_order_args)
if col == -1:
break
if col < from_col or col >= to_col:
raise ValueError("Column out of bounds of the group")
if not track_value:
if (
order.size_type == SizeType.Value
or order.size_type == SizeType.TargetValue
or order.size_type == SizeType.TargetPercent
):
raise ValueError("Cannot use size type that depends on not tracked value")
# Get current values
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
pos_info_now = last_pos_info[col]
if cash_sharing:
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
cash_deposits_now = last_cash_deposits[group]
else:
cash_now = last_cash[col]
free_cash_now = last_free_cash[col]
value_now = last_value[col]
return_now = last_return[col]
cash_deposits_now = last_cash_deposits[col]
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
if track_value:
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
if cash_sharing:
return_now = returns_nb_.get_return_nb(
prev_close_value[group],
value_now - cash_deposits_now,
)
else:
return_now = returns_nb_.get_return_nb(prev_close_value[col], value_now - cash_deposits_now)
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
last_value[group] = value_now
last_return[group] = return_now
else:
last_cash[col] = cash_now
last_free_cash[col] = free_cash_now
last_value[col] = value_now
last_return[col] = return_now
if track_value:
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_value[group] = value_now
last_return[group] = return_now
else:
last_value[col] = value_now
last_return[col] = return_now
# Update position record
if fill_pos_info:
if order_result.status == OrderStatus.Filled:
if order_counts[col] > 0:
order_id = order_records["id"][order_counts[col] - 1, col]
else:
order_id = -1
update_pos_info_nb(
pos_info_now,
i,
col,
exec_state.position,
position_now,
order_result,
order_id,
)
# Post-order callback
post_order_ctx = PostOrderContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
col=col,
call_idx=call_idx,
cash_before=exec_state.cash,
position_before=exec_state.position,
debt_before=exec_state.debt,
locked_cash_before=exec_state.locked_cash,
free_cash_before=exec_state.free_cash,
val_price_before=exec_state.val_price,
value_before=exec_state.value,
order_result=order_result,
cash_now=cash_now,
position_now=position_now,
debt_now=debt_now,
locked_cash_now=locked_cash_now,
free_cash_now=free_cash_now,
val_price_now=val_price_now,
value_now=value_now,
return_now=return_now,
pos_info_now=pos_info_now,
)
post_order_func_nb(post_order_ctx, *pre_segment_out, *post_order_args)
# NOTE: Regardless of segment_mask, we still need to update stats to be accessed by future rows
# Add earnings in cash
for col in range(from_col, to_col):
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
if cash_sharing:
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
else:
last_cash[col] += _cash_earnings
last_free_cash[col] += _cash_earnings
if track_value:
# Update valuation price using current close
for col in range(from_col, to_col):
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
# Update previous value, current value, and return
if cash_sharing:
last_value[group] = calc_group_value_nb(
from_col,
to_col,
last_cash[group],
last_position,
last_val_price,
)
last_return[group] = returns_nb_.get_return_nb(
prev_close_value[group],
last_value[group] - last_cash_deposits[group],
)
prev_close_value[group] = last_value[group]
else:
for col in range(from_col, to_col):
if last_position[col] == 0:
last_value[col] = last_cash[col]
else:
last_value[col] = last_cash[col] + last_position[col] * last_val_price[col]
last_return[col] = returns_nb_.get_return_nb(
prev_close_value[col],
last_value[col] - last_cash_deposits[col],
)
prev_close_value[col] = last_value[col]
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Is this segment active?
if call_post_segment or is_segment_active:
# Call function after the segment
post_seg_ctx = SegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
call_seq_now=None,
)
post_segment_func_nb(post_seg_ctx, *pre_row_out, *post_segment_args)
# Call function after the row
post_row_ctx = RowContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
i=i,
)
post_row_func_nb(post_row_ctx, *pre_sim_out, *post_row_args)
sim_end_reached = True
for group in range(len(group_lens)):
if i < sim_end_[group] - 1:
sim_end_reached = False
break
if sim_end_reached:
break
# Call function after the simulation
post_sim_ctx = SimulationContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
call_seq=None,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
segment_mask=segment_mask_,
call_pre_segment=call_pre_segment,
call_post_segment=call_post_segment,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
bm_close=bm_close_,
ffill_val_price=ffill_val_price,
update_value=update_value,
fill_pos_info=fill_pos_info,
track_value=track_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
sim_start=sim_start_,
sim_end=sim_end_,
)
post_sim_func_nb(post_sim_ctx, *post_sim_args)
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_,
cash_earnings=cash_earnings_,
call_seq=None,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# %
@register_jitted
def set_val_price_nb(c: SegmentContext, val_price: tp.FlexArray2d, price: tp.FlexArray2d) -> None:
"""Override valuation price in a context.
Allows specifying a valuation price of positive infinity (takes the current price)
and negative infinity (takes the latest valuation price)."""
for col in range(c.from_col, c.to_col):
_val_price = select_from_col_nb(c, col, val_price)
if np.isinf(_val_price):
if _val_price > 0:
_price = select_from_col_nb(c, col, price)
if np.isinf(_price):
if _price > 0:
_price = select_from_col_nb(c, col, c.close)
else:
_price = select_from_col_nb(c, col, c.open)
_val_price = _price
else:
_val_price = c.last_val_price[col]
if not np.isnan(_val_price) or not c.ffill_val_price:
c.last_val_price[col] = _val_price
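# A minimal, hypothetical sketch (plain Python with made-up numbers, not calling the
# jitted function above, which requires a full SegmentContext) of the infinity-resolution
# rule applied per column: +inf resolves to the order price (which itself may be +inf
# for the close or -inf for the open), while -inf resolves to the last valuation price.
_example_val_price, _example_price = np.inf, np.inf
_example_open, _example_close, _example_last_val_price = 10.0, 11.0, 9.5
if np.isinf(_example_val_price):
    if _example_val_price > 0:
        if np.isinf(_example_price):
            _example_price = _example_close if _example_price > 0 else _example_open
        _example_val_price = _example_price
    else:
        _example_val_price = _example_last_val_price
assert _example_val_price == 11.0  # +inf val_price with +inf price resolves to the close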
# %
@register_jitted
def def_pre_segment_func_nb( # % line.replace("def_pre_segment_func_nb", "pre_segment_func_nb")
c: SegmentContext,
val_price: tp.FlexArray2d,
price: tp.FlexArray2d,
size: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
auto_call_seq: bool,
) -> tp.Args:
"""Pre-segment function that overrides the valuation price and optionally sorts the call sequence."""
set_val_price_nb(c, val_price, price)
if auto_call_seq:
order_value_out = np.empty(c.group_len, dtype=float_)
sort_call_seq_nb(c, size, size_type, direction, order_value_out)
return ()
# %
# %
@register_jitted
def def_order_func_nb( # % line.replace("def_order_func_nb", "order_func_nb")
c: OrderContext,
size: tp.FlexArray2d,
price: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
fees: tp.FlexArray2d,
fixed_fees: tp.FlexArray2d,
slippage: tp.FlexArray2d,
min_size: tp.FlexArray2d,
max_size: tp.FlexArray2d,
size_granularity: tp.FlexArray2d,
leverage: tp.FlexArray2d,
leverage_mode: tp.FlexArray2d,
reject_prob: tp.FlexArray2d,
price_area_vio_mode: tp.FlexArray2d,
allow_partial: tp.FlexArray2d,
raise_reject: tp.FlexArray2d,
log: tp.FlexArray2d,
) -> tp.Tuple[int, Order]:
"""Order function that creates an order based on default information."""
return order_nb(
size=select_nb(c, size),
price=select_nb(c, price),
size_type=select_nb(c, size_type),
direction=select_nb(c, direction),
fees=select_nb(c, fees),
fixed_fees=select_nb(c, fixed_fees),
slippage=select_nb(c, slippage),
min_size=select_nb(c, min_size),
max_size=select_nb(c, max_size),
size_granularity=select_nb(c, size_granularity),
leverage=select_nb(c, leverage),
leverage_mode=select_nb(c, leverage_mode),
reject_prob=select_nb(c, reject_prob),
price_area_vio_mode=select_nb(c, price_area_vio_mode),
allow_partial=select_nb(c, allow_partial),
raise_reject=select_nb(c, raise_reject),
log=select_nb(c, log),
)
# %
# %
@register_jitted
def def_flex_pre_segment_func_nb( # % line.replace("def_flex_pre_segment_func_nb", "pre_segment_func_nb")
c: SegmentContext,
val_price: tp.FlexArray2d,
price: tp.FlexArray2d,
size: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
auto_call_seq: bool,
) -> tp.Args:
"""Flexible pre-segment function that overrides the valuation price and optionally sorts the call sequence."""
set_val_price_nb(c, val_price, price)
call_seq_out = np.arange(c.group_len)
if auto_call_seq:
order_value_out = np.empty(c.group_len, dtype=float_)
sort_call_seq_out_nb(c, size, size_type, direction, order_value_out, call_seq_out)
return (call_seq_out,)
# %
# %
@register_jitted
def def_flex_order_func_nb( # % line.replace("def_flex_order_func_nb", "flex_order_func_nb")
c: FlexOrderContext,
call_seq_now: tp.Array1d,
size: tp.FlexArray2d,
price: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
fees: tp.FlexArray2d,
fixed_fees: tp.FlexArray2d,
slippage: tp.FlexArray2d,
min_size: tp.FlexArray2d,
max_size: tp.FlexArray2d,
size_granularity: tp.FlexArray2d,
leverage: tp.FlexArray2d,
leverage_mode: tp.FlexArray2d,
reject_prob: tp.FlexArray2d,
price_area_vio_mode: tp.FlexArray2d,
allow_partial: tp.FlexArray2d,
raise_reject: tp.FlexArray2d,
log: tp.FlexArray2d,
) -> tp.Tuple[int, Order]:
"""Flexible order function that creates an order based on default information."""
if c.call_idx < c.group_len:
col = c.from_col + call_seq_now[c.call_idx]
order = order_nb(
size=select_from_col_nb(c, col, size),
price=select_from_col_nb(c, col, price),
size_type=select_from_col_nb(c, col, size_type),
direction=select_from_col_nb(c, col, direction),
fees=select_from_col_nb(c, col, fees),
fixed_fees=select_from_col_nb(c, col, fixed_fees),
slippage=select_from_col_nb(c, col, slippage),
min_size=select_from_col_nb(c, col, min_size),
max_size=select_from_col_nb(c, col, max_size),
size_granularity=select_from_col_nb(c, col, size_granularity),
leverage=select_from_col_nb(c, col, leverage),
leverage_mode=select_from_col_nb(c, col, leverage_mode),
reject_prob=select_from_col_nb(c, col, reject_prob),
price_area_vio_mode=select_from_col_nb(c, col, price_area_vio_mode),
allow_partial=select_from_col_nb(c, col, allow_partial),
raise_reject=select_from_col_nb(c, col, raise_reject),
log=select_from_col_nb(c, col, log),
)
return col, order
return -1, order_nothing_nb()
# %
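# A minimal, hypothetical sketch (kept as a comment; not part of the library) of the
# flexible-order protocol consumed by the simulator above: the function is called
# repeatedly within an active segment and returns column -1 to finish the segment.
# It assumes order_nb accepts size and price with defaults for the remaining fields.
#
# @register_jitted
# def one_order_per_col_flex_order_func_nb(c: FlexOrderContext) -> tp.Tuple[int, Order]:
#     if c.call_idx < c.group_len:
#         col = c.from_col + c.call_idx
#         return col, order_nb(size=1.0, price=flex_select_nb(c.close, c.i, col))
#     return -1, order_nothing_nb()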
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio simulation based on orders."""
from numba import prange
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio import chunking as portfolio_ch
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.returns.nb import get_return_nb
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import insert_argsort_nb
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
init_cash=base_ch.FlexArraySlicer(),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=base_ch.FlexArraySlicer(axis=1),
cash_earnings=base_ch.flex_array_gl_slicer,
cash_dividends=base_ch.flex_array_gl_slicer,
size=base_ch.flex_array_gl_slicer,
price=base_ch.flex_array_gl_slicer,
size_type=base_ch.flex_array_gl_slicer,
direction=base_ch.flex_array_gl_slicer,
fees=base_ch.flex_array_gl_slicer,
fixed_fees=base_ch.flex_array_gl_slicer,
slippage=base_ch.flex_array_gl_slicer,
min_size=base_ch.flex_array_gl_slicer,
max_size=base_ch.flex_array_gl_slicer,
size_granularity=base_ch.flex_array_gl_slicer,
leverage=base_ch.flex_array_gl_slicer,
leverage_mode=base_ch.flex_array_gl_slicer,
reject_prob=base_ch.flex_array_gl_slicer,
price_area_vio_mode=base_ch.flex_array_gl_slicer,
allow_partial=base_ch.flex_array_gl_slicer,
raise_reject=base_ch.flex_array_gl_slicer,
log=base_ch.flex_array_gl_slicer,
val_price=base_ch.flex_array_gl_slicer,
from_ago=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
call_seq=base_ch.array_gl_slicer,
auto_call_seq=None,
ffill_val_price=None,
update_value=None,
save_state=None,
save_value=None,
save_returns=None,
skip_empty=None,
max_order_records=None,
max_log_records=None,
),
**portfolio_ch.merge_sim_outs_config,
)
@register_jitted(cache=True, tags={"can_parallel"})
def from_orders_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
cash_dividends: tp.FlexArray2dLike = 0.0,
size: tp.FlexArray2dLike = np.inf,
price: tp.FlexArray2dLike = np.inf,
size_type: tp.FlexArray2dLike = SizeType.Amount,
direction: tp.FlexArray2dLike = Direction.Both,
fees: tp.FlexArray2dLike = 0.0,
fixed_fees: tp.FlexArray2dLike = 0.0,
slippage: tp.FlexArray2dLike = 0.0,
min_size: tp.FlexArray2dLike = np.nan,
max_size: tp.FlexArray2dLike = np.nan,
size_granularity: tp.FlexArray2dLike = np.nan,
leverage: tp.FlexArray2dLike = 1.0,
leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy,
reject_prob: tp.FlexArray2dLike = 0.0,
price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore,
allow_partial: tp.FlexArray2dLike = True,
raise_reject: tp.FlexArray2dLike = False,
log: tp.FlexArray2dLike = False,
val_price: tp.FlexArray2dLike = np.inf,
from_ago: tp.FlexArray2dLike = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
call_seq: tp.Optional[tp.Array2d] = None,
auto_call_seq: bool = False,
ffill_val_price: bool = True,
update_value: bool = False,
save_state: bool = False,
save_value: bool = False,
save_returns: bool = False,
skip_empty: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
) -> SimulationOutput:
"""Creates on order out of each element.
Iterates in the column-major order. Utilizes flexible broadcasting.
!!! note
Should be only grouped if cash sharing is enabled.
Single value must be passed as a 0-dim array (for example, by using `np.asarray(value)`).
Usage:
* Buy and hold using all cash and closing price (default):
```pycon
>>> from vectorbtpro import *
>>> from vectorbtpro.records.nb import col_map_nb
>>> from vectorbtpro.portfolio.nb import from_orders_nb, asset_flow_nb
>>> close = np.array([1, 2, 3, 4, 5])[:, None]
>>> sim_out = from_orders_nb(
... target_shape=close.shape,
... group_lens=np.array([1]),
... call_seq=np.full(close.shape, 0),
... close=close
... )
>>> col_map = col_map_nb(sim_out.order_records['col'], close.shape[1])
>>> asset_flow = asset_flow_nb(close.shape, sim_out.order_records, col_map)
>>> asset_flow
array([[100.],
[ 0.],
[ 0.],
[ 0.],
[ 0.]])
```
"""
check_group_lens_nb(group_lens, target_shape[1])
cash_sharing = is_grouped_nb(group_lens)
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends))
size_ = to_2d_array_nb(np.asarray(size))
price_ = to_2d_array_nb(np.asarray(price))
size_type_ = to_2d_array_nb(np.asarray(size_type))
direction_ = to_2d_array_nb(np.asarray(direction))
fees_ = to_2d_array_nb(np.asarray(fees))
fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees))
slippage_ = to_2d_array_nb(np.asarray(slippage))
min_size_ = to_2d_array_nb(np.asarray(min_size))
max_size_ = to_2d_array_nb(np.asarray(max_size))
size_granularity_ = to_2d_array_nb(np.asarray(size_granularity))
leverage_ = to_2d_array_nb(np.asarray(leverage))
leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode))
reject_prob_ = to_2d_array_nb(np.asarray(reject_prob))
price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode))
allow_partial_ = to_2d_array_nb(np.asarray(allow_partial))
raise_reject_ = to_2d_array_nb(np.asarray(raise_reject))
log_ = to_2d_array_nb(np.asarray(log))
val_price_ = to_2d_array_nb(np.asarray(val_price))
from_ago_ = to_2d_array_nb(np.asarray(from_ago))
order_records, log_records = prepare_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full(target_shape[1], 0.0, dtype=float_)
last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1
if track_cash_deposits:
cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_)
else:
cash_deposits_out = np.full((1, 1), 0.0, dtype=float_)
track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1
track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1
track_cash_earnings = track_cash_earnings or track_cash_dividends
if track_cash_earnings:
cash_earnings_out = np.full(target_shape, 0.0, dtype=float_)
else:
cash_earnings_out = np.full((1, 1), 0.0, dtype=float_)
if save_state:
cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
position = np.full(target_shape, np.nan, dtype=float_)
debt = np.full(target_shape, np.nan, dtype=float_)
locked_cash = np.full(target_shape, np.nan, dtype=float_)
free_cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
cash = np.full((0, 0), np.nan, dtype=float_)
position = np.full((0, 0), np.nan, dtype=float_)
debt = np.full((0, 0), np.nan, dtype=float_)
locked_cash = np.full((0, 0), np.nan, dtype=float_)
free_cash = np.full((0, 0), np.nan, dtype=float_)
if save_value:
value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
value = np.full((0, 0), np.nan, dtype=float_)
if save_returns:
returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
returns = np.full((0, 0), np.nan, dtype=float_)
in_outputs = FOInOutputs(
cash=cash,
position=position,
debt=debt,
locked_cash=locked_cash,
free_cash=free_cash,
value=value,
returns=returns,
)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
temp_order_value = np.empty(target_shape[1], dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
# Add cash
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
if _cash_deposits < 0:
_cash_deposits = max(_cash_deposits, -last_cash[group])
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
if track_cash_deposits:
cash_deposits_out[i, group] += _cash_deposits
skip = skip_empty
if skip:
for c in range(group_len):
col = from_col + c
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
continue
if flex_select_nb(log_, i, col):
skip = False
break
if not np.isnan(flex_select_nb(size_, _i, col)):
if not np.isnan(flex_select_nb(price_, _i, col)):
skip = False
break
if not skip or ffill_val_price:
for c in range(group_len):
col = from_col + c
# Update valuation price using current open
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Resolve valuation price
_val_price = flex_select_nb(val_price_, i, col)
if np.isinf(_val_price):
if _val_price > 0:
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
_price = np.nan
else:
_price = flex_select_nb(price_, _i, col)
if np.isinf(_price):
if _price > 0:
_price = flex_select_nb(close_, i, col)
else:
_price = _open
_val_price = _price
else:
_val_price = last_val_price[col]
if not np.isnan(_val_price) or not ffill_val_price:
last_val_price[col] = _val_price
if not skip:
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
if cash_sharing:
# Dynamically sort by order value -> selling comes first to release funds early
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if auto_call_seq:
# Same as sort_by_order_value_ctx_nb but with flexible indexing
for c in range(group_len):
col = from_col + c
exec_state = ExecState(
cash=last_cash[group] if cash_sharing else last_cash[col],
position=last_position[col],
debt=last_debt[col],
locked_cash=last_locked_cash[col],
free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col],
val_price=last_val_price[col],
value=last_value[group] if cash_sharing else last_value[col],
)
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
temp_order_value[c] = 0.0
else:
temp_order_value[c] = approx_order_value_nb(
exec_state,
flex_select_nb(size_, _i, col),
flex_select_nb(size_type_, _i, col),
flex_select_nb(direction_, _i, col),
)
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
# Sort by order value
insert_argsort_nb(temp_order_value[:group_len], call_seq_now)
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
# Get current values per column
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
# Generate the next order
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
continue
order = order_nb(
size=flex_select_nb(size_, _i, col),
price=flex_select_nb(price_, _i, col),
size_type=flex_select_nb(size_type_, _i, col),
direction=flex_select_nb(direction_, _i, col),
fees=flex_select_nb(fees_, _i, col),
fixed_fees=flex_select_nb(fixed_fees_, _i, col),
slippage=flex_select_nb(slippage_, _i, col),
min_size=flex_select_nb(min_size_, _i, col),
max_size=flex_select_nb(max_size_, _i, col),
size_granularity=flex_select_nb(size_granularity_, _i, col),
leverage=flex_select_nb(leverage_, _i, col),
leverage_mode=flex_select_nb(leverage_mode_, _i, col),
reject_prob=flex_select_nb(reject_prob_, _i, col),
price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col),
allow_partial=flex_select_nb(allow_partial_, _i, col),
raise_reject=flex_select_nb(raise_reject_, _i, col),
log=flex_select_nb(log_, _i, col),
)
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
last_value[group] = value_now
last_return[group] = return_now
for col in range(from_col, to_col):
# Update valuation price using current close
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
_cash_dividends = flex_select_nb(cash_dividends_, i, col)
_cash_earnings += _cash_dividends * last_position[col]
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
if track_cash_earnings:
cash_earnings_out[i, col] += _cash_earnings
if save_state:
position[i, col] = last_position[col]
debt[i, col] = last_debt[col]
locked_cash[i, col] = last_locked_cash[col]
cash[i, group] = last_cash[group]
free_cash[i, group] = last_free_cash[group]
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
prev_close_value[group] = last_value[group]
if save_value:
in_outputs.value[i, group] = last_value[group]
if save_returns:
in_outputs.returns[i, group] = last_return[group]
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_out,
cash_earnings=cash_earnings_out,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio simulation based on signals."""
from numba import prange
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.generic.enums import BarZone
from vectorbtpro.portfolio import chunking as portfolio_ch
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.portfolio.nb.from_order_func import no_post_func_nb
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.returns.nb import get_return_nb
from vectorbtpro.signals.enums import StopType
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import insert_argsort_nb
from vectorbtpro.utils.math_ import is_less_nb
from vectorbtpro.utils.template import RepFunc
@register_jitted(cache=True)
def resolve_pending_conflict_nb(
is_pending_long: bool,
is_user_long: bool,
upon_adj_conflict: int,
upon_opp_conflict: int,
) -> tp.Tuple[bool, bool]:
"""Resolve any conflict between a pending signal and a user-defined signal.
    Returns whether to keep the pending signal and whether to execute the user-defined signal."""
if (is_pending_long and is_user_long) or (not is_pending_long and not is_user_long):
if upon_adj_conflict == PendingConflictMode.KeepIgnore:
return True, False
if upon_adj_conflict == PendingConflictMode.KeepExecute:
return True, True
if upon_adj_conflict == PendingConflictMode.CancelIgnore:
return False, False
if upon_adj_conflict == PendingConflictMode.CancelExecute:
return False, True
raise ValueError("Invalid PendingConflictMode option")
if upon_opp_conflict == PendingConflictMode.KeepIgnore:
return True, False
if upon_opp_conflict == PendingConflictMode.KeepExecute:
return True, True
if upon_opp_conflict == PendingConflictMode.CancelIgnore:
return False, False
if upon_opp_conflict == PendingConflictMode.CancelExecute:
return False, True
raise ValueError("Invalid PendingConflictMode option")
@register_jitted(cache=True)
def generate_stop_signal_nb(
position_now: float,
stop_exit_type: int,
accumulate: int = False,
) -> tp.Tuple[bool, bool, bool, bool, int]:
"""Generate stop signal and change accumulation if needed."""
is_long_entry = False
is_long_exit = False
is_short_entry = False
is_short_exit = False
if position_now > 0:
if stop_exit_type == StopExitType.Close:
is_long_exit = True
accumulate = AccumulationMode.Disabled
elif stop_exit_type == StopExitType.CloseReduce:
is_long_exit = True
elif stop_exit_type == StopExitType.Reverse:
is_short_entry = True
accumulate = AccumulationMode.Disabled
elif stop_exit_type == StopExitType.ReverseReduce:
is_short_entry = True
else:
raise ValueError("Invalid StopExitType option")
elif position_now < 0:
if stop_exit_type == StopExitType.Close:
is_short_exit = True
accumulate = AccumulationMode.Disabled
elif stop_exit_type == StopExitType.CloseReduce:
is_short_exit = True
elif stop_exit_type == StopExitType.Reverse:
is_long_entry = True
accumulate = AccumulationMode.Disabled
elif stop_exit_type == StopExitType.ReverseReduce:
is_long_entry = True
else:
raise ValueError("Invalid StopExitType option")
return is_long_entry, is_long_exit, is_short_entry, is_short_exit, accumulate
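# Illustrative sketch (hypothetical values): a stop fires while 100 units long and
# StopExitType.Reverse turns it into a short entry, disabling accumulation for that order.
long_entry, long_exit, short_entry, short_exit, new_accumulate = generate_stop_signal_nb(
    100.0,  # position_now (long)
    StopExitType.Reverse,
    AccumulationMode.Both,
)
assert (long_entry, long_exit, short_entry, short_exit) == (False, False, True, False)
assert new_accumulate == AccumulationMode.Disabled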
@register_jitted(cache=True)
def resolve_signal_conflict_nb(
position_now: float,
is_entry: bool,
is_exit: bool,
direction: int,
conflict_mode: int,
) -> tp.Tuple[bool, bool]:
"""Resolve any conflict between an entry and an exit."""
if is_entry and is_exit:
# Conflict
if conflict_mode == ConflictMode.Entry:
# Ignore exit signal
is_exit = False
elif conflict_mode == ConflictMode.Exit:
# Ignore entry signal
is_entry = False
elif conflict_mode == ConflictMode.Adjacent:
# Take the signal adjacent to the position we are in
if position_now == 0:
# Cannot decide -> ignore
is_entry = False
is_exit = False
else:
if direction == Direction.Both:
if position_now > 0:
is_exit = False
elif position_now < 0:
is_entry = False
else:
is_exit = False
elif conflict_mode == ConflictMode.Opposite:
# Take the signal opposite to the position we are in
if position_now == 0:
if direction == Direction.Both:
# Cannot decide -> ignore
is_entry = False
is_exit = False
else:
is_exit = False
else:
if direction == Direction.Both:
if position_now > 0:
is_entry = False
elif position_now < 0:
is_exit = False
else:
is_entry = False
elif conflict_mode == ConflictMode.Ignore:
is_entry = False
is_exit = False
else:
raise ValueError("Invalid ConflictMode option")
return is_entry, is_exit
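# Illustrative sketch (hypothetical values): an entry and an exit arrive on the same bar
# while 100 units long; ConflictMode.Adjacent keeps the signal adjacent to the open position.
is_entry, is_exit = resolve_signal_conflict_nb(
    100.0,  # position_now (long)
    True,  # is_entry
    True,  # is_exit
    Direction.Both,
    ConflictMode.Adjacent,
)
assert (is_entry, is_exit) == (True, False)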
@register_jitted(cache=True)
def resolve_dir_conflict_nb(
position_now: float,
is_long_entry: bool,
is_short_entry: bool,
upon_dir_conflict: int,
) -> tp.Tuple[bool, bool]:
"""Resolve any direction conflict between a long entry and a short entry."""
if is_long_entry and is_short_entry:
if upon_dir_conflict == DirectionConflictMode.Long:
is_short_entry = False
elif upon_dir_conflict == DirectionConflictMode.Short:
is_long_entry = False
elif upon_dir_conflict == DirectionConflictMode.Adjacent:
if position_now > 0:
is_short_entry = False
elif position_now < 0:
is_long_entry = False
else:
is_long_entry = False
is_short_entry = False
elif upon_dir_conflict == DirectionConflictMode.Opposite:
if position_now > 0:
is_long_entry = False
elif position_now < 0:
is_short_entry = False
else:
is_long_entry = False
is_short_entry = False
elif upon_dir_conflict == DirectionConflictMode.Ignore:
is_long_entry = False
is_short_entry = False
else:
raise ValueError("Invalid DirectionConflictMode option")
return is_long_entry, is_short_entry
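# Illustrative sketch (hypothetical values): long and short entries collide while 100 units
# long; DirectionConflictMode.Opposite keeps the entry opposite to the open position.
is_long_entry, is_short_entry = resolve_dir_conflict_nb(
    100.0,  # position_now (long)
    True,  # is_long_entry
    True,  # is_short_entry
    DirectionConflictMode.Opposite,
)
assert (is_long_entry, is_short_entry) == (False, True)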
@register_jitted(cache=True)
def resolve_opposite_entry_nb(
position_now: float,
is_long_entry: bool,
is_long_exit: bool,
is_short_entry: bool,
is_short_exit: bool,
upon_opposite_entry: int,
accumulate: int,
) -> tp.Tuple[bool, bool, bool, bool, int]:
"""Resolve opposite entry."""
if position_now > 0 and is_short_entry:
if upon_opposite_entry == OppositeEntryMode.Ignore:
is_short_entry = False
elif upon_opposite_entry == OppositeEntryMode.Close:
is_short_entry = False
is_long_exit = True
accumulate = AccumulationMode.Disabled
elif upon_opposite_entry == OppositeEntryMode.CloseReduce:
is_short_entry = False
is_long_exit = True
elif upon_opposite_entry == OppositeEntryMode.Reverse:
accumulate = AccumulationMode.Disabled
elif upon_opposite_entry == OppositeEntryMode.ReverseReduce:
pass
else:
raise ValueError("Invalid OppositeEntryMode option")
if position_now < 0 and is_long_entry:
if upon_opposite_entry == OppositeEntryMode.Ignore:
is_long_entry = False
elif upon_opposite_entry == OppositeEntryMode.Close:
is_long_entry = False
is_short_exit = True
accumulate = AccumulationMode.Disabled
elif upon_opposite_entry == OppositeEntryMode.CloseReduce:
is_long_entry = False
is_short_exit = True
elif upon_opposite_entry == OppositeEntryMode.Reverse:
accumulate = AccumulationMode.Disabled
elif upon_opposite_entry == OppositeEntryMode.ReverseReduce:
pass
else:
raise ValueError("Invalid OppositeEntryMode option")
return is_long_entry, is_long_exit, is_short_entry, is_short_exit, accumulate
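# Illustrative sketch (hypothetical values): a short entry arrives while 100 units long;
# OppositeEntryMode.Close turns it into a long exit and disables accumulation for that order.
long_entry, long_exit, short_entry, short_exit, new_accumulate = resolve_opposite_entry_nb(
    100.0,  # position_now (long)
    False,  # is_long_entry
    False,  # is_long_exit
    True,  # is_short_entry
    False,  # is_short_exit
    OppositeEntryMode.Close,
    AccumulationMode.Both,
)
assert (long_entry, long_exit, short_entry, short_exit) == (False, True, False, False)
assert new_accumulate == AccumulationMode.Disabled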
@register_jitted(cache=True)
def signal_to_size_nb(
position_now: float,
val_price_now: float,
value_now: float,
is_long_entry: bool,
is_long_exit: bool,
is_short_entry: bool,
is_short_exit: bool,
size: float,
size_type: int,
accumulate: int,
) -> tp.Tuple[float, int, int]:
"""Translate direction-aware signals into size, size type, and direction."""
if (
accumulate != AccumulationMode.Disabled
and accumulate != AccumulationMode.Both
and accumulate != AccumulationMode.AddOnly
and accumulate != AccumulationMode.RemoveOnly
):
raise ValueError("Invalid AccumulationMode option")
def _check_size_type(_size_type):
if (
_size_type == SizeType.TargetAmount
or _size_type == SizeType.TargetValue
or _size_type == SizeType.TargetPercent
or _size_type == SizeType.TargetPercent100
):
raise ValueError("Target size types are not supported")
if is_less_nb(size, 0):
raise ValueError("Negative size is not allowed. Please express direction using signals.")
if size_type == SizeType.Percent100:
size /= 100
size_type = SizeType.Percent
if size_type == SizeType.ValuePercent100:
size /= 100
size_type = SizeType.ValuePercent
if size_type == SizeType.ValuePercent:
size *= value_now
size_type = SizeType.Value
order_size = np.nan
direction = Direction.Both
abs_position_now = abs(position_now)
if position_now > 0:
# We're in a long position
if is_short_entry:
_check_size_type(size_type)
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly:
# Decrease the position
order_size = -size
else:
# Reverse the position
if not np.isnan(size):
if size_type == SizeType.Percent:
order_size = -size
else:
order_size = -abs_position_now
if size_type == SizeType.Value:
order_size -= size / val_price_now
else:
order_size -= size
size_type = SizeType.Amount
elif is_long_exit:
direction = Direction.LongOnly
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly:
# Decrease the position
_check_size_type(size_type)
order_size = -size
else:
# Close the position
order_size = -abs_position_now
size_type = SizeType.Amount
elif is_long_entry:
_check_size_type(size_type)
direction = Direction.LongOnly
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.AddOnly:
# Increase the position
order_size = size
elif position_now < 0:
# We're in a short position
if is_long_entry:
_check_size_type(size_type)
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly:
# Decrease the position
order_size = size
else:
# Reverse the position
if not np.isnan(size):
if size_type == SizeType.Percent:
order_size = size
else:
order_size = abs_position_now
if size_type == SizeType.Value:
order_size += size / val_price_now
else:
order_size += size
size_type = SizeType.Amount
elif is_short_exit:
direction = Direction.ShortOnly
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.RemoveOnly:
# Decrease the position
_check_size_type(size_type)
order_size = size
else:
# Close the position
order_size = abs_position_now
size_type = SizeType.Amount
elif is_short_entry:
_check_size_type(size_type)
direction = Direction.ShortOnly
if accumulate == AccumulationMode.Both or accumulate == AccumulationMode.AddOnly:
# Increase the position
order_size = -size
else:
_check_size_type(size_type)
if is_long_entry:
# Open long position
order_size = size
elif is_short_entry:
# Open short position
order_size = -size
if direction == Direction.ShortOnly:
order_size = -order_size
return order_size, size_type, direction
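# Illustrative sketch (hypothetical values): a long exit while 100 units long with
# accumulation disabled closes the entire position with an amount-sized, long-only order.
order_size, order_size_type, order_direction = signal_to_size_nb(
    100.0,  # position_now
    10.0,  # val_price_now
    1000.0,  # value_now
    False,  # is_long_entry
    True,  # is_long_exit
    False,  # is_short_entry
    False,  # is_short_exit
    np.inf,  # size (irrelevant here: closing uses the full position)
    SizeType.Amount,
    AccumulationMode.Disabled,
)
assert (order_size, order_size_type, order_direction) == (-100.0, SizeType.Amount, Direction.LongOnly)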
@register_jitted(cache=True)
def prepare_fs_records_nb(
target_shape: tp.Shape,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
) -> tp.Tuple[tp.RecordArray2d, tp.RecordArray2d]:
"""Prepare from-signals records."""
if max_order_records is None:
order_records = np.empty((target_shape[0], target_shape[1]), dtype=fs_order_dt)
else:
order_records = np.empty((max_order_records, target_shape[1]), dtype=fs_order_dt)
if max_log_records is None:
log_records = np.empty((target_shape[0], target_shape[1]), dtype=log_dt)
else:
log_records = np.empty((max_log_records, target_shape[1]), dtype=log_dt)
return order_records, log_records
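# Illustrative sketch (hypothetical shape): max_order_records=None reserves one order
# record slot per bar and column, while max_log_records=0 disables log records entirely.
example_order_records, example_log_records = prepare_fs_records_nb((5, 2), None, 0)
assert example_order_records.shape == (5, 2) and example_order_records.dtype == fs_order_dt
assert example_log_records.shape == (0, 2) and example_log_records.dtype == log_dt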
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
init_cash=base_ch.FlexArraySlicer(),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=base_ch.FlexArraySlicer(axis=1),
cash_earnings=base_ch.flex_array_gl_slicer,
cash_dividends=base_ch.flex_array_gl_slicer,
long_entries=base_ch.flex_array_gl_slicer,
long_exits=base_ch.flex_array_gl_slicer,
short_entries=base_ch.flex_array_gl_slicer,
short_exits=base_ch.flex_array_gl_slicer,
size=base_ch.flex_array_gl_slicer,
price=base_ch.flex_array_gl_slicer,
size_type=base_ch.flex_array_gl_slicer,
fees=base_ch.flex_array_gl_slicer,
fixed_fees=base_ch.flex_array_gl_slicer,
slippage=base_ch.flex_array_gl_slicer,
min_size=base_ch.flex_array_gl_slicer,
max_size=base_ch.flex_array_gl_slicer,
size_granularity=base_ch.flex_array_gl_slicer,
leverage=base_ch.flex_array_gl_slicer,
leverage_mode=base_ch.flex_array_gl_slicer,
reject_prob=base_ch.flex_array_gl_slicer,
price_area_vio_mode=base_ch.flex_array_gl_slicer,
allow_partial=base_ch.flex_array_gl_slicer,
raise_reject=base_ch.flex_array_gl_slicer,
log=base_ch.flex_array_gl_slicer,
val_price=base_ch.flex_array_gl_slicer,
accumulate=base_ch.flex_array_gl_slicer,
upon_long_conflict=base_ch.flex_array_gl_slicer,
upon_short_conflict=base_ch.flex_array_gl_slicer,
upon_dir_conflict=base_ch.flex_array_gl_slicer,
upon_opposite_entry=base_ch.flex_array_gl_slicer,
from_ago=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
call_seq=base_ch.array_gl_slicer,
auto_call_seq=None,
ffill_val_price=None,
update_value=None,
save_state=None,
save_value=None,
save_returns=None,
skip_empty=None,
max_order_records=None,
max_log_records=None,
),
**portfolio_ch.merge_sim_outs_config,
)
@register_jitted(cache=True, tags={"can_parallel"})
def from_basic_signals_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
cash_dividends: tp.FlexArray2dLike = 0.0,
long_entries: tp.FlexArray2dLike = False,
long_exits: tp.FlexArray2dLike = False,
short_entries: tp.FlexArray2dLike = False,
short_exits: tp.FlexArray2dLike = False,
size: tp.FlexArray2dLike = np.inf,
price: tp.FlexArray2dLike = np.inf,
size_type: tp.FlexArray2dLike = SizeType.Amount,
fees: tp.FlexArray2dLike = 0.0,
fixed_fees: tp.FlexArray2dLike = 0.0,
slippage: tp.FlexArray2dLike = 0.0,
min_size: tp.FlexArray2dLike = np.nan,
max_size: tp.FlexArray2dLike = np.nan,
size_granularity: tp.FlexArray2dLike = np.nan,
leverage: tp.FlexArray2dLike = 1.0,
leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy,
reject_prob: tp.FlexArray2dLike = 0.0,
price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore,
allow_partial: tp.FlexArray2dLike = True,
raise_reject: tp.FlexArray2dLike = False,
log: tp.FlexArray2dLike = False,
val_price: tp.FlexArray2dLike = np.inf,
accumulate: tp.FlexArray2dLike = AccumulationMode.Disabled,
upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore,
upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce,
from_ago: tp.FlexArray2dLike = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
call_seq: tp.Optional[tp.Array2d] = None,
auto_call_seq: bool = False,
ffill_val_price: bool = True,
update_value: bool = False,
save_state: bool = False,
save_value: bool = False,
save_returns: bool = False,
skip_empty: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
) -> SimulationOutput:
"""Simulate given basic signals (no limit or stop orders).
    Iterates in column-major order. Utilizes flexible broadcasting.
    !!! note
        Should only be grouped if cash sharing is enabled.
"""
check_group_lens_nb(group_lens, target_shape[1])
cash_sharing = is_grouped_nb(group_lens)
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends))
long_entries_ = to_2d_array_nb(np.asarray(long_entries))
long_exits_ = to_2d_array_nb(np.asarray(long_exits))
short_entries_ = to_2d_array_nb(np.asarray(short_entries))
short_exits_ = to_2d_array_nb(np.asarray(short_exits))
size_ = to_2d_array_nb(np.asarray(size))
price_ = to_2d_array_nb(np.asarray(price))
size_type_ = to_2d_array_nb(np.asarray(size_type))
fees_ = to_2d_array_nb(np.asarray(fees))
fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees))
slippage_ = to_2d_array_nb(np.asarray(slippage))
min_size_ = to_2d_array_nb(np.asarray(min_size))
max_size_ = to_2d_array_nb(np.asarray(max_size))
size_granularity_ = to_2d_array_nb(np.asarray(size_granularity))
leverage_ = to_2d_array_nb(np.asarray(leverage))
leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode))
reject_prob_ = to_2d_array_nb(np.asarray(reject_prob))
price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode))
allow_partial_ = to_2d_array_nb(np.asarray(allow_partial))
raise_reject_ = to_2d_array_nb(np.asarray(raise_reject))
log_ = to_2d_array_nb(np.asarray(log))
val_price_ = to_2d_array_nb(np.asarray(val_price))
accumulate_ = to_2d_array_nb(np.asarray(accumulate))
upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict))
upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict))
upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict))
upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry))
from_ago_ = to_2d_array_nb(np.asarray(from_ago))
order_records, log_records = prepare_fs_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full(target_shape[1], 0.0, dtype=float_)
last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1
if track_cash_deposits:
cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_)
else:
cash_deposits_out = np.full((1, 1), 0.0, dtype=float_)
track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1
track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1
track_cash_earnings = track_cash_earnings or track_cash_dividends
if track_cash_earnings:
cash_earnings_out = np.full(target_shape, 0.0, dtype=float_)
else:
cash_earnings_out = np.full((1, 1), 0.0, dtype=float_)
if save_state:
cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
position = np.full(target_shape, np.nan, dtype=float_)
debt = np.full(target_shape, np.nan, dtype=float_)
locked_cash = np.full(target_shape, np.nan, dtype=float_)
free_cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
cash = np.full((0, 0), np.nan, dtype=float_)
position = np.full((0, 0), np.nan, dtype=float_)
debt = np.full((0, 0), np.nan, dtype=float_)
locked_cash = np.full((0, 0), np.nan, dtype=float_)
free_cash = np.full((0, 0), np.nan, dtype=float_)
if save_value:
value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
value = np.full((0, 0), np.nan, dtype=float_)
if save_returns:
returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
returns = np.full((0, 0), np.nan, dtype=float_)
in_outputs = FSInOutputs(
cash=cash,
position=position,
debt=debt,
locked_cash=locked_cash,
free_cash=free_cash,
value=value,
returns=returns,
)
last_signal = np.empty(target_shape[1], dtype=int_)
main_info = np.empty(target_shape[1], dtype=main_info_dt)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
temp_sort_by = np.empty(target_shape[1], dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
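# Each group occupies the column range [group_start_idxs[g], group_end_idxs[g]); groups are
# independent of each other, so the outer loop uses prange and can be processed in parallel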
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
# Add cash
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
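# A negative deposit acts as a withdrawal and is capped at the cash currently held by the group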
if _cash_deposits < 0:
_cash_deposits = max(_cash_deposits, -last_cash[group])
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
if track_cash_deposits:
cash_deposits_out[i, group] += _cash_deposits
for c in range(group_len):
col = from_col + c
# Update valuation price using current open
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Resolve valuation price
_val_price = flex_select_nb(val_price_, i, col)
if np.isinf(_val_price):
if _val_price > 0:
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
_price = np.nan
else:
_price = flex_select_nb(price_, _i, col)
if np.isinf(_price):
if _price > 0:
_price = flex_select_nb(close_, i, col)
else:
_price = _open
_val_price = _price
else:
_val_price = last_val_price[col]
if not np.isnan(_val_price) or not ffill_val_price:
last_val_price[col] = _val_price
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
# Get signals
skip = skip_empty
for c in range(group_len):
col = from_col + c
if flex_select_nb(log_, i, col):
skip = False
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
is_long_entry = False
is_long_exit = False
is_short_entry = False
is_short_exit = False
else:
is_long_entry = flex_select_nb(long_entries_, _i, col)
is_long_exit = flex_select_nb(long_exits_, _i, col)
is_short_entry = flex_select_nb(short_entries_, _i, col)
is_short_exit = flex_select_nb(short_exits_, _i, col)
# Pack signals into a single integer
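# Bit layout: bit 4 = long entry, bit 3 = long exit, bit 2 = short entry, bit 1 = short exit;
# e.g. a lone long entry packs to 0b10000 = 16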
last_signal[col] = (
(is_long_entry << 4) | (is_long_exit << 3) | (is_short_entry << 2) | (is_short_exit << 1)
)
if last_signal[col] > 0:
skip = False
if not skip:
# Get size and value of each order
for c in range(group_len):
col = from_col + c
# Set defaults
main_info["bar_zone"][col] = -1
main_info["signal_idx"][col] = -1
main_info["creation_idx"][col] = -1
main_info["idx"][col] = i
main_info["val_price"][col] = np.nan
main_info["price"][col] = np.nan
main_info["size"][col] = np.nan
main_info["size_type"][col] = -1
main_info["direction"][col] = -1
main_info["type"][col] = OrderType.Market
main_info["stop_type"][col] = -1
temp_sort_by[col] = 0.0
# Unpack a single integer into signals
is_long_entry = (last_signal[col] >> 4) & 1
is_long_exit = (last_signal[col] >> 3) & 1
is_short_entry = (last_signal[col] >> 2) & 1
is_short_exit = (last_signal[col] >> 1) & 1
# Resolve the current bar
_i = i - abs(flex_select_nb(from_ago_, i, col))
_open = flex_select_nb(open_, i, col)
_high = flex_select_nb(high_, i, col)
_low = flex_select_nb(low_, i, col)
_close = flex_select_nb(close_, i, col)
_high, _low = resolve_hl_nb(
open=_open,
high=_high,
low=_low,
close=_close,
)
# Process user signal
if _i >= 0:
_accumulate = flex_select_nb(accumulate_, _i, col)
if is_long_entry or is_short_entry:
# Resolve any single-direction conflicts
_upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col)
is_long_entry, is_long_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_long_entry,
is_exit=is_long_exit,
direction=Direction.LongOnly,
conflict_mode=_upon_long_conflict,
)
_upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col)
is_short_entry, is_short_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_short_entry,
is_exit=is_short_exit,
direction=Direction.ShortOnly,
conflict_mode=_upon_short_conflict,
)
# Resolve any multi-direction conflicts
_upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col)
is_long_entry, is_short_entry = resolve_dir_conflict_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_short_entry=is_short_entry,
upon_dir_conflict=_upon_dir_conflict,
)
# Resolve an opposite entry
_upon_opposite_entry = flex_select_nb(upon_opposite_entry_, _i, col)
(
is_long_entry,
is_long_exit,
is_short_entry,
is_short_exit,
_accumulate,
) = resolve_opposite_entry_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
upon_opposite_entry=_upon_opposite_entry,
accumulate=_accumulate,
)
# Resolve the price
_price = flex_select_nb(price_, _i, col)
# Convert the signals into a direction-aware size, size type, and direction
_size, _size_type, _direction = signal_to_size_nb(
position_now=last_position[col],
val_price_now=last_val_price[col],
value_now=last_value[group],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
size=flex_select_nb(size_, _i, col),
size_type=flex_select_nb(size_type_, _i, col),
accumulate=_accumulate,
)
# Execute user signal
if np.isinf(_price):
if _price > 0:
main_info["bar_zone"][col] = BarZone.Close
else:
main_info["bar_zone"][col] = BarZone.Open
else:
main_info["bar_zone"][col] = BarZone.Middle
main_info["signal_idx"][col] = _i
main_info["creation_idx"][col] = i
main_info["idx"][col] = _i
main_info["val_price"][col] = last_val_price[col]
main_info["price"][col] = _price
main_info["size"][col] = _size
main_info["size_type"][col] = _size_type
main_info["direction"][col] = _direction
skip = skip_empty
if skip:
for col in range(from_col, to_col):
if flex_select_nb(log_, i, col):
skip = False
break
if not np.isnan(main_info["size"][col]):
skip = False
break
if not skip:
# Check bar zone and update valuation price
bar_zone = -1
same_bar_zone = True
same_timing = True
for c in range(group_len):
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
if bar_zone == -1:
bar_zone = main_info["bar_zone"][col]
if main_info["bar_zone"][col] != bar_zone:
same_bar_zone = False
same_timing = False
if main_info["bar_zone"][col] == BarZone.Middle:
same_timing = False
_val_price = main_info["val_price"][col]
if not np.isnan(_val_price) or not ffill_val_price:
last_val_price[col] = _val_price
if cash_sharing:
# Dynamically sort by order value -> selling comes first to release funds early
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if auto_call_seq:
# Sort by order value
if not same_timing:
raise ValueError("Cannot sort orders by value if they are executed at different times")
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
# Approximate order value
exec_state = ExecState(
cash=last_cash[group] if cash_sharing else last_cash[col],
position=last_position[col],
debt=last_debt[col],
locked_cash=last_locked_cash[col],
free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col],
val_price=last_val_price[col],
value=last_value[group] if cash_sharing else last_value[col],
)
temp_sort_by[c] = approx_order_value_nb(
exec_state=exec_state,
size=main_info["size"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
)
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
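# Sorting by approximate order value puts the most negative (selling) orders first,
# so cash released by sells is available to buys later in the same bar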
else:
if not same_bar_zone:
# Sort by bar zone
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
temp_sort_by[c] = main_info["bar_zone"][col]
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
if skip_empty and np.isnan(main_info["size"][col]): # shortcut
continue
# Get current values per column
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
# Generate the next order
_i = main_info["idx"][col]
if main_info["type"][col] == OrderType.Limit:
_slippage = 0.0
else:
_slippage = float(flex_select_nb(slippage_, _i, col))
_min_size = flex_select_nb(min_size_, _i, col)
_max_size = flex_select_nb(max_size_, _i, col)
_size_type = flex_select_nb(size_type_, _i, col)
if _size_type != main_info["size_type"][col]:
if not np.isnan(_min_size):
_min_size, _ = resolve_size_nb(
size=_min_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
if not np.isnan(_max_size):
_max_size, _ = resolve_size_nb(
size=_max_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
order = order_nb(
size=main_info["size"][col],
price=main_info["price"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
fees=flex_select_nb(fees_, _i, col),
fixed_fees=flex_select_nb(fixed_fees_, _i, col),
slippage=_slippage,
min_size=_min_size,
max_size=_max_size,
size_granularity=flex_select_nb(size_granularity_, _i, col),
leverage=flex_select_nb(leverage_, _i, col),
leverage_mode=flex_select_nb(leverage_mode_, _i, col),
reject_prob=flex_select_nb(reject_prob_, _i, col),
price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col),
allow_partial=flex_select_nb(allow_partial_, _i, col),
raise_reject=flex_select_nb(raise_reject_, _i, col),
log=flex_select_nb(log_, _i, col),
)
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Append more order information
if order_result.status == OrderStatus.Filled and order_counts[col] >= 1:
order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col]
order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col]
order_records["type"][order_counts[col] - 1, col] = main_info["type"][col]
order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col]
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
last_value[group] = value_now
last_return[group] = return_now
for col in range(from_col, to_col):
# Update valuation price using current close
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
_cash_dividends = flex_select_nb(cash_dividends_, i, col)
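# Dividends are converted into cash earnings in proportion to the current position size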
_cash_earnings += _cash_dividends * last_position[col]
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
if track_cash_earnings:
cash_earnings_out[i, col] += _cash_earnings
if save_state:
position[i, col] = last_position[col]
debt[i, col] = last_debt[col]
locked_cash[i, col] = last_locked_cash[col]
cash[i, group] = last_cash[group]
free_cash[i, group] = last_free_cash[group]
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
prev_close_value[group] = last_value[group]
if save_value:
in_outputs.value[i, group] = last_value[group]
if save_returns:
in_outputs.returns[i, group] = last_return[group]
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_out,
cash_earnings=cash_earnings_out,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
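# The chunking spec below slices per-group columns out of every flexible argument so that
# from_signals_nb can be split into chunks along the group axis and the partial simulation
# outputs merged back together.
# A minimal usage sketch of from_signals_nb (illustrative only; in practice these Numba
# functions are typically driven by a higher-level preparation/broadcasting layer rather
# than called directly):
#
#     close = np.array([[10.0], [11.0], [12.0], [11.5], [13.0]])
#     long_entries = np.array([[True], [False], [False], [False], [False]])
#     long_exits = np.array([[False], [False], [False], [True], [False]])
#     sim_out = from_signals_nb(
#         target_shape=close.shape,
#         group_lens=np.array([1]),
#         close=close,
#         long_entries=long_entries,
#         long_exits=long_exits,
#     )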
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
init_cash=base_ch.FlexArraySlicer(),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=base_ch.FlexArraySlicer(axis=1),
cash_earnings=base_ch.flex_array_gl_slicer,
cash_dividends=base_ch.flex_array_gl_slicer,
long_entries=base_ch.flex_array_gl_slicer,
long_exits=base_ch.flex_array_gl_slicer,
short_entries=base_ch.flex_array_gl_slicer,
short_exits=base_ch.flex_array_gl_slicer,
size=base_ch.flex_array_gl_slicer,
price=base_ch.flex_array_gl_slicer,
size_type=base_ch.flex_array_gl_slicer,
fees=base_ch.flex_array_gl_slicer,
fixed_fees=base_ch.flex_array_gl_slicer,
slippage=base_ch.flex_array_gl_slicer,
min_size=base_ch.flex_array_gl_slicer,
max_size=base_ch.flex_array_gl_slicer,
size_granularity=base_ch.flex_array_gl_slicer,
leverage=base_ch.flex_array_gl_slicer,
leverage_mode=base_ch.flex_array_gl_slicer,
reject_prob=base_ch.flex_array_gl_slicer,
price_area_vio_mode=base_ch.flex_array_gl_slicer,
allow_partial=base_ch.flex_array_gl_slicer,
raise_reject=base_ch.flex_array_gl_slicer,
log=base_ch.flex_array_gl_slicer,
val_price=base_ch.flex_array_gl_slicer,
accumulate=base_ch.flex_array_gl_slicer,
upon_long_conflict=base_ch.flex_array_gl_slicer,
upon_short_conflict=base_ch.flex_array_gl_slicer,
upon_dir_conflict=base_ch.flex_array_gl_slicer,
upon_opposite_entry=base_ch.flex_array_gl_slicer,
order_type=base_ch.flex_array_gl_slicer,
limit_delta=base_ch.flex_array_gl_slicer,
limit_tif=base_ch.flex_array_gl_slicer,
limit_expiry=base_ch.flex_array_gl_slicer,
limit_reverse=base_ch.flex_array_gl_slicer,
limit_order_price=base_ch.flex_array_gl_slicer,
upon_adj_limit_conflict=base_ch.flex_array_gl_slicer,
upon_opp_limit_conflict=base_ch.flex_array_gl_slicer,
use_stops=None,
stop_ladder=None,
sl_stop=base_ch.flex_array_gl_slicer,
tsl_stop=base_ch.flex_array_gl_slicer,
tsl_th=base_ch.flex_array_gl_slicer,
tp_stop=base_ch.flex_array_gl_slicer,
td_stop=base_ch.flex_array_gl_slicer,
dt_stop=base_ch.flex_array_gl_slicer,
stop_entry_price=base_ch.flex_array_gl_slicer,
stop_exit_price=base_ch.flex_array_gl_slicer,
stop_exit_type=base_ch.flex_array_gl_slicer,
stop_order_type=base_ch.flex_array_gl_slicer,
stop_limit_delta=base_ch.flex_array_gl_slicer,
upon_stop_update=base_ch.flex_array_gl_slicer,
upon_adj_stop_conflict=base_ch.flex_array_gl_slicer,
upon_opp_stop_conflict=base_ch.flex_array_gl_slicer,
delta_format=base_ch.flex_array_gl_slicer,
time_delta_format=base_ch.flex_array_gl_slicer,
from_ago=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
call_seq=base_ch.array_gl_slicer,
auto_call_seq=None,
ffill_val_price=None,
update_value=None,
save_state=None,
save_value=None,
save_returns=None,
skip_empty=None,
max_order_records=None,
max_log_records=None,
),
**portfolio_ch.merge_sim_outs_config,
)
@register_jitted(cache=True, tags={"can_parallel"})
def from_signals_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
cash_dividends: tp.FlexArray2dLike = 0.0,
long_entries: tp.FlexArray2dLike = False,
long_exits: tp.FlexArray2dLike = False,
short_entries: tp.FlexArray2dLike = False,
short_exits: tp.FlexArray2dLike = False,
size: tp.FlexArray2dLike = np.inf,
price: tp.FlexArray2dLike = np.inf,
size_type: tp.FlexArray2dLike = SizeType.Amount,
fees: tp.FlexArray2dLike = 0.0,
fixed_fees: tp.FlexArray2dLike = 0.0,
slippage: tp.FlexArray2dLike = 0.0,
min_size: tp.FlexArray2dLike = np.nan,
max_size: tp.FlexArray2dLike = np.nan,
size_granularity: tp.FlexArray2dLike = np.nan,
leverage: tp.FlexArray2dLike = 1.0,
leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy,
reject_prob: tp.FlexArray2dLike = 0.0,
price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore,
allow_partial: tp.FlexArray2dLike = True,
raise_reject: tp.FlexArray2dLike = False,
log: tp.FlexArray2dLike = False,
val_price: tp.FlexArray2dLike = np.inf,
accumulate: tp.FlexArray2dLike = AccumulationMode.Disabled,
upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore,
upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce,
order_type: tp.FlexArray2dLike = OrderType.Market,
limit_delta: tp.FlexArray2dLike = np.nan,
limit_tif: tp.FlexArray2dLike = -1,
limit_expiry: tp.FlexArray2dLike = -1,
limit_reverse: tp.FlexArray2dLike = False,
limit_order_price: tp.FlexArray2dLike = LimitOrderPrice.Limit,
upon_adj_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepIgnore,
upon_opp_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.CancelExecute,
use_stops: bool = True,
stop_ladder: int = StopLadderMode.Disabled,
sl_stop: tp.FlexArray2dLike = np.nan,
tsl_stop: tp.FlexArray2dLike = np.nan,
tsl_th: tp.FlexArray2dLike = np.nan,
tp_stop: tp.FlexArray2dLike = np.nan,
td_stop: tp.FlexArray2dLike = -1,
dt_stop: tp.FlexArray2dLike = -1,
stop_entry_price: tp.FlexArray2dLike = StopEntryPrice.Close,
stop_exit_price: tp.FlexArray2dLike = StopExitPrice.Stop,
stop_exit_type: tp.FlexArray2dLike = StopExitType.Close,
stop_order_type: tp.FlexArray2dLike = OrderType.Market,
stop_limit_delta: tp.FlexArray2dLike = np.nan,
upon_stop_update: tp.FlexArray2dLike = StopUpdateMode.Keep,
upon_adj_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute,
upon_opp_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute,
delta_format: tp.FlexArray2dLike = DeltaFormat.Percent,
time_delta_format: tp.FlexArray2dLike = TimeDeltaFormat.Index,
from_ago: tp.FlexArray2dLike = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
call_seq: tp.Optional[tp.Array2d] = None,
auto_call_seq: bool = False,
ffill_val_price: bool = True,
update_value: bool = False,
save_state: bool = False,
save_value: bool = False,
save_returns: bool = False,
skip_empty: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
) -> SimulationOutput:
"""Simulate given signals.
Iterates in column-major order and utilizes flexible broadcasting.
!!! note
Should only be grouped if cash sharing is enabled.
"""
check_group_lens_nb(group_lens, target_shape[1])
cash_sharing = is_grouped_nb(group_lens)
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends))
long_entries_ = to_2d_array_nb(np.asarray(long_entries))
long_exits_ = to_2d_array_nb(np.asarray(long_exits))
short_entries_ = to_2d_array_nb(np.asarray(short_entries))
short_exits_ = to_2d_array_nb(np.asarray(short_exits))
size_ = to_2d_array_nb(np.asarray(size))
price_ = to_2d_array_nb(np.asarray(price))
size_type_ = to_2d_array_nb(np.asarray(size_type))
fees_ = to_2d_array_nb(np.asarray(fees))
fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees))
slippage_ = to_2d_array_nb(np.asarray(slippage))
min_size_ = to_2d_array_nb(np.asarray(min_size))
max_size_ = to_2d_array_nb(np.asarray(max_size))
size_granularity_ = to_2d_array_nb(np.asarray(size_granularity))
leverage_ = to_2d_array_nb(np.asarray(leverage))
leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode))
reject_prob_ = to_2d_array_nb(np.asarray(reject_prob))
price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode))
allow_partial_ = to_2d_array_nb(np.asarray(allow_partial))
raise_reject_ = to_2d_array_nb(np.asarray(raise_reject))
log_ = to_2d_array_nb(np.asarray(log))
val_price_ = to_2d_array_nb(np.asarray(val_price))
accumulate_ = to_2d_array_nb(np.asarray(accumulate))
upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict))
upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict))
upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict))
upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry))
order_type_ = to_2d_array_nb(np.asarray(order_type))
limit_delta_ = to_2d_array_nb(np.asarray(limit_delta))
limit_tif_ = to_2d_array_nb(np.asarray(limit_tif))
limit_expiry_ = to_2d_array_nb(np.asarray(limit_expiry))
limit_reverse_ = to_2d_array_nb(np.asarray(limit_reverse))
limit_order_price_ = to_2d_array_nb(np.asarray(limit_order_price))
upon_adj_limit_conflict_ = to_2d_array_nb(np.asarray(upon_adj_limit_conflict))
upon_opp_limit_conflict_ = to_2d_array_nb(np.asarray(upon_opp_limit_conflict))
sl_stop_ = to_2d_array_nb(np.asarray(sl_stop))
tsl_stop_ = to_2d_array_nb(np.asarray(tsl_stop))
tsl_th_ = to_2d_array_nb(np.asarray(tsl_th))
tp_stop_ = to_2d_array_nb(np.asarray(tp_stop))
td_stop_ = to_2d_array_nb(np.asarray(td_stop))
dt_stop_ = to_2d_array_nb(np.asarray(dt_stop))
stop_entry_price_ = to_2d_array_nb(np.asarray(stop_entry_price))
stop_exit_price_ = to_2d_array_nb(np.asarray(stop_exit_price))
stop_exit_type_ = to_2d_array_nb(np.asarray(stop_exit_type))
stop_order_type_ = to_2d_array_nb(np.asarray(stop_order_type))
stop_limit_delta_ = to_2d_array_nb(np.asarray(stop_limit_delta))
upon_stop_update_ = to_2d_array_nb(np.asarray(upon_stop_update))
upon_adj_stop_conflict_ = to_2d_array_nb(np.asarray(upon_adj_stop_conflict))
upon_opp_stop_conflict_ = to_2d_array_nb(np.asarray(upon_opp_stop_conflict))
delta_format_ = to_2d_array_nb(np.asarray(delta_format))
time_delta_format_ = to_2d_array_nb(np.asarray(time_delta_format))
from_ago_ = to_2d_array_nb(np.asarray(from_ago))
n_sl_steps = sl_stop_.shape[0]
n_tsl_steps = tsl_stop_.shape[0]
n_tp_steps = tp_stop_.shape[0]
n_td_steps = td_stop_.shape[0]
n_dt_steps = dt_stop_.shape[0]
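# The number of rows in each stop array doubles as the number of ladder steps when stop laddering is enabled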
order_records, log_records = prepare_fs_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full(target_shape[1], 0.0, dtype=float_)
last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
last_return = np.full_like(last_cash, np.nan)
track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1
if track_cash_deposits:
cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_)
else:
cash_deposits_out = np.full((1, 1), 0.0, dtype=float_)
track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1
track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1
track_cash_earnings = track_cash_earnings or track_cash_dividends
if track_cash_earnings:
cash_earnings_out = np.full(target_shape, 0.0, dtype=float_)
else:
cash_earnings_out = np.full((1, 1), 0.0, dtype=float_)
if save_state:
cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
position = np.full(target_shape, np.nan, dtype=float_)
debt = np.full(target_shape, np.nan, dtype=float_)
locked_cash = np.full(target_shape, np.nan, dtype=float_)
free_cash = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
cash = np.full((0, 0), np.nan, dtype=float_)
position = np.full((0, 0), np.nan, dtype=float_)
debt = np.full((0, 0), np.nan, dtype=float_)
locked_cash = np.full((0, 0), np.nan, dtype=float_)
free_cash = np.full((0, 0), np.nan, dtype=float_)
if save_value:
value = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
value = np.full((0, 0), np.nan, dtype=float_)
if save_returns:
returns = np.full((target_shape[0], len(group_lens)), np.nan, dtype=float_)
else:
returns = np.full((0, 0), np.nan, dtype=float_)
in_outputs = FSInOutputs(
cash=cash,
position=position,
debt=debt,
locked_cash=locked_cash,
free_cash=free_cash,
value=value,
returns=returns,
)
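# Per-column state of the pending limit order; -1 and NaN mark an empty slot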
last_limit_info = np.empty(target_shape[1], dtype=limit_info_dt)
last_limit_info["signal_idx"][:] = -1
last_limit_info["creation_idx"][:] = -1
last_limit_info["init_idx"][:] = -1
last_limit_info["init_price"][:] = np.nan
last_limit_info["init_size"][:] = np.nan
last_limit_info["init_size_type"][:] = -1
last_limit_info["init_direction"][:] = -1
last_limit_info["init_stop_type"][:] = -1
last_limit_info["delta"][:] = np.nan
last_limit_info["delta_format"][:] = -1
last_limit_info["tif"][:] = -1
last_limit_info["expiry"][:] = -1
last_limit_info["time_delta_format"][:] = -1
last_limit_info["reverse"][:] = False
last_limit_info["order_price"][:] = np.nan
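# Per-column state of pending stop orders (SL, TSL/TTP, TP, TD, DT); allocated only when stops are enabled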
if use_stops:
last_sl_info = np.empty(target_shape[1], dtype=sl_info_dt)
last_sl_info["init_idx"][:] = -1
last_sl_info["init_price"][:] = np.nan
last_sl_info["init_position"][:] = np.nan
last_sl_info["stop"][:] = np.nan
last_sl_info["exit_price"][:] = -1
last_sl_info["exit_size"][:] = np.nan
last_sl_info["exit_size_type"][:] = -1
last_sl_info["exit_type"][:] = -1
last_sl_info["order_type"][:] = -1
last_sl_info["limit_delta"][:] = np.nan
last_sl_info["delta_format"][:] = -1
last_sl_info["ladder"][:] = -1
last_sl_info["step"][:] = -1
last_sl_info["step_idx"][:] = -1
last_tsl_info = np.empty(target_shape[1], dtype=tsl_info_dt)
last_tsl_info["init_idx"][:] = -1
last_tsl_info["init_price"][:] = np.nan
last_tsl_info["init_position"][:] = np.nan
last_tsl_info["peak_idx"][:] = -1
last_tsl_info["peak_price"][:] = np.nan
last_tsl_info["stop"][:] = np.nan
last_tsl_info["th"][:] = np.nan
last_tsl_info["exit_price"][:] = -1
last_tsl_info["exit_size"][:] = np.nan
last_tsl_info["exit_size_type"][:] = -1
last_tsl_info["exit_type"][:] = -1
last_tsl_info["order_type"][:] = -1
last_tsl_info["limit_delta"][:] = np.nan
last_tsl_info["delta_format"][:] = -1
last_tsl_info["ladder"][:] = -1
last_tsl_info["step"][:] = -1
last_tsl_info["step_idx"][:] = -1
last_tp_info = np.empty(target_shape[1], dtype=tp_info_dt)
last_tp_info["init_idx"][:] = -1
last_tp_info["init_price"][:] = np.nan
last_tp_info["init_position"][:] = np.nan
last_tp_info["stop"][:] = np.nan
last_tp_info["exit_price"][:] = -1
last_tp_info["exit_size"][:] = np.nan
last_tp_info["exit_size_type"][:] = -1
last_tp_info["exit_type"][:] = -1
last_tp_info["order_type"][:] = -1
last_tp_info["limit_delta"][:] = np.nan
last_tp_info["delta_format"][:] = -1
last_tp_info["ladder"][:] = -1
last_tp_info["step"][:] = -1
last_tp_info["step_idx"][:] = -1
last_td_info = np.empty(target_shape[1], dtype=time_info_dt)
last_td_info["init_idx"][:] = -1
last_td_info["init_position"][:] = np.nan
last_td_info["stop"][:] = -1
last_td_info["exit_price"][:] = -1
last_td_info["exit_size"][:] = np.nan
last_td_info["exit_size_type"][:] = -1
last_td_info["exit_type"][:] = -1
last_td_info["order_type"][:] = -1
last_td_info["limit_delta"][:] = np.nan
last_td_info["delta_format"][:] = -1
last_td_info["time_delta_format"][:] = -1
last_td_info["ladder"][:] = -1
last_td_info["step"][:] = -1
last_td_info["step_idx"][:] = -1
last_dt_info = np.empty(target_shape[1], dtype=time_info_dt)
last_dt_info["init_idx"][:] = -1
last_dt_info["init_position"][:] = np.nan
last_dt_info["stop"][:] = -1
last_dt_info["exit_price"][:] = -1
last_dt_info["exit_size"][:] = np.nan
last_dt_info["exit_size_type"][:] = -1
last_dt_info["exit_type"][:] = -1
last_dt_info["order_type"][:] = -1
last_dt_info["limit_delta"][:] = np.nan
last_dt_info["delta_format"][:] = -1
last_dt_info["time_delta_format"][:] = -1
last_dt_info["ladder"][:] = -1
last_dt_info["step"][:] = -1
last_dt_info["step_idx"][:] = -1
else:
last_sl_info = np.empty(0, dtype=sl_info_dt)
last_tsl_info = np.empty(0, dtype=tsl_info_dt)
last_tp_info = np.empty(0, dtype=tp_info_dt)
last_td_info = np.empty(0, dtype=time_info_dt)
last_dt_info = np.empty(0, dtype=time_info_dt)
last_signal = np.empty(target_shape[1], dtype=int_)
main_info = np.empty(target_shape[1], dtype=main_info_dt)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
temp_sort_by = np.empty(target_shape[1], dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
# Add cash
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
if _cash_deposits < 0:
_cash_deposits = max(_cash_deposits, -last_cash[group])
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
if track_cash_deposits:
cash_deposits_out[i, group] += _cash_deposits
for c in range(group_len):
col = from_col + c
# Update valuation price using current open
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
# Get signals
skip = skip_empty
for c in range(group_len):
col = from_col + c
if flex_select_nb(log_, i, col):
skip = False
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
is_long_entry = False
is_long_exit = False
is_short_entry = False
is_short_exit = False
else:
is_long_entry = flex_select_nb(long_entries_, _i, col)
is_long_exit = flex_select_nb(long_exits_, _i, col)
is_short_entry = flex_select_nb(short_entries_, _i, col)
is_short_exit = flex_select_nb(short_exits_, _i, col)
limit_signal = is_limit_active_nb(
init_idx=last_limit_info["init_idx"][col],
init_price=last_limit_info["init_price"][col],
)
if not use_stops:
sl_stop_signal = False
tsl_stop_signal = False
tp_stop_signal = False
td_stop_signal = False
dt_stop_signal = False
else:
sl_stop_signal = is_stop_active_nb(
init_idx=last_sl_info["init_idx"][col],
stop=last_sl_info["stop"][col],
)
tsl_stop_signal = is_stop_active_nb(
init_idx=last_tsl_info["init_idx"][col],
stop=last_tsl_info["stop"][col],
)
tp_stop_signal = is_stop_active_nb(
init_idx=last_tp_info["init_idx"][col],
stop=last_tp_info["stop"][col],
)
td_stop_signal = is_time_stop_active_nb(
init_idx=last_td_info["init_idx"][col],
stop=last_td_info["stop"][col],
)
dt_stop_signal = is_time_stop_active_nb(
init_idx=last_dt_info["init_idx"][col],
stop=last_dt_info["stop"][col],
)
# Pack signals into a single integer
last_signal[col] = (
(is_long_entry << 10)
| (is_long_exit << 9)
| (is_short_entry << 8)
| (is_short_exit << 7)
| (limit_signal << 6)
| (sl_stop_signal << 5)
| (tsl_stop_signal << 4)
| (tp_stop_signal << 3)
| (td_stop_signal << 2)
| (dt_stop_signal << 1)
)
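# Extended bit layout: bits 10-7 = long entry, long exit, short entry, short exit;
# bit 6 = pending limit; bits 5-1 = pending SL, TSL/TTP, TP, TD, and DT stops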
if last_signal[col] > 0:
skip = False
if not skip:
# Get size and value of each order
for c in range(group_len):
col = from_col + c
# Set defaults
main_info["bar_zone"][col] = -1
main_info["signal_idx"][col] = -1
main_info["creation_idx"][col] = -1
main_info["idx"][col] = i
main_info["val_price"][col] = np.nan
main_info["price"][col] = np.nan
main_info["size"][col] = np.nan
main_info["size_type"][col] = -1
main_info["direction"][col] = -1
main_info["type"][col] = -1
main_info["stop_type"][col] = -1
temp_sort_by[col] = 0.0
# Unpack a single integer into signals
is_long_entry = (last_signal[col] >> 10) & 1
is_long_exit = (last_signal[col] >> 9) & 1
is_short_entry = (last_signal[col] >> 8) & 1
is_short_exit = (last_signal[col] >> 7) & 1
limit_signal = (last_signal[col] >> 6) & 1
sl_stop_signal = (last_signal[col] >> 5) & 1
tsl_stop_signal = (last_signal[col] >> 4) & 1
tp_stop_signal = (last_signal[col] >> 3) & 1
td_stop_signal = (last_signal[col] >> 2) & 1
dt_stop_signal = (last_signal[col] >> 1) & 1
any_user_signal = is_long_entry or is_long_exit or is_short_entry or is_short_exit
any_limit_signal = limit_signal
any_stop_signal = (
sl_stop_signal or tsl_stop_signal or tp_stop_signal or td_stop_signal or dt_stop_signal
)
# Set initial info
exec_limit_set = False
exec_limit_set_on_open = False
exec_limit_set_on_close = False
exec_limit_signal_i = -1
exec_limit_creation_i = -1
exec_limit_init_i = -1
exec_limit_val_price = np.nan
exec_limit_price = np.nan
exec_limit_size = np.nan
exec_limit_size_type = -1
exec_limit_direction = -1
exec_limit_stop_type = -1
exec_limit_bar_zone = -1
exec_stop_set = False
exec_stop_set_on_open = False
exec_stop_set_on_close = False
exec_stop_init_i = -1
exec_stop_val_price = np.nan
exec_stop_price = np.nan
exec_stop_size = np.nan
exec_stop_size_type = -1
exec_stop_direction = -1
exec_stop_type = -1
exec_stop_stop_type = -1
exec_stop_delta = np.nan
exec_stop_delta_format = -1
exec_stop_make_limit = False
exec_stop_bar_zone = -1
user_on_open = False
user_on_close = False
exec_user_set = False
exec_user_val_price = np.nan
exec_user_price = np.nan
exec_user_size = np.nan
exec_user_size_type = -1
exec_user_direction = -1
exec_user_type = -1
exec_user_stop_type = -1
exec_user_make_limit = False
exec_user_bar_zone = -1
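# Up to three candidate orders can be staged for this column and bar: a pending limit fill,
# a triggered stop exit, and a user signal; which one executes is decided below by bar-zone
# priority (limit -> stop -> user)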
# Resolve the current bar
_i = i - abs(flex_select_nb(from_ago_, i, col))
_open = flex_select_nb(open_, i, col)
_high = flex_select_nb(high_, i, col)
_low = flex_select_nb(low_, i, col)
_close = flex_select_nb(close_, i, col)
_high, _low = resolve_hl_nb(
open=_open,
high=_high,
low=_low,
close=_close,
)
# Process the limit signal
if any_limit_signal:
# Check whether the limit price was hit
_signal_i = last_limit_info["signal_idx"][col]
_creation_i = last_limit_info["creation_idx"][col]
_init_i = last_limit_info["init_idx"][col]
_price = last_limit_info["init_price"][col]
_size = last_limit_info["init_size"][col]
_size_type = last_limit_info["init_size_type"][col]
_direction = last_limit_info["init_direction"][col]
_stop_type = last_limit_info["init_stop_type"][col]
_delta = last_limit_info["delta"][col]
_delta_format = last_limit_info["delta_format"][col]
_tif = last_limit_info["tif"][col]
_expiry = last_limit_info["expiry"][col]
_time_delta_format = last_limit_info["time_delta_format"][col]
_reverse = last_limit_info["reverse"][col]
_order_price = last_limit_info["order_price"][col]
limit_expired_on_open, limit_expired = check_limit_expired_nb(
creation_idx=_creation_i,
i=i,
tif=_tif,
expiry=_expiry,
time_delta_format=_time_delta_format,
index=index,
freq=freq,
)
limit_price, limit_hit_on_open, limit_hit = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_delta,
delta_format=_delta_format,
limit_reverse=_reverse,
can_use_ohlc=True,
check_open=True,
hard_limit=_order_price == LimitOrderPrice.HardLimit,
)
# Resolve the price
limit_price = resolve_limit_order_price_nb(
limit_price=limit_price,
close=_close,
limit_order_price=_order_price,
)
if limit_expired_on_open or (not limit_hit_on_open and limit_expired):
# Expired limit signal
any_limit_signal = False
last_limit_info["signal_idx"][col] = -1
last_limit_info["creation_idx"][col] = -1
last_limit_info["init_idx"][col] = -1
last_limit_info["init_price"][col] = np.nan
last_limit_info["init_size"][col] = np.nan
last_limit_info["init_size_type"][col] = -1
last_limit_info["init_direction"][col] = -1
last_limit_info["delta"][col] = np.nan
last_limit_info["delta_format"][col] = -1
last_limit_info["tif"][col] = -1
last_limit_info["expiry"][col] = -1
last_limit_info["time_delta_format"][col] = -1
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = np.nan
else:
# Save info
if limit_hit:
# Executable limit signal
exec_limit_set = True
exec_limit_set_on_open = limit_hit_on_open
exec_limit_set_on_close = _order_price == LimitOrderPrice.Close
exec_limit_signal_i = _signal_i
exec_limit_creation_i = _creation_i
exec_limit_init_i = _init_i
if np.isinf(limit_price) and limit_price > 0:
exec_limit_val_price = _close
elif np.isinf(limit_price) and limit_price < 0:
exec_limit_val_price = _open
else:
exec_limit_val_price = limit_price
exec_limit_price = limit_price
exec_limit_size = _size
exec_limit_size_type = _size_type
exec_limit_direction = _direction
exec_limit_stop_type = _stop_type
# Process the stop signal
if any_stop_signal:
# Check SL
sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = np.nan, False, False
if sl_stop_signal:
# Check against high and low
sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_sl_info["init_price"][col],
stop=last_sl_info["stop"][col],
delta_format=last_sl_info["delta_format"][col],
hit_below=True,
hard_stop=last_sl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TSL and TTP
tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = np.nan, False, False
if tsl_stop_signal:
# Update peak price using open
if last_position[col] > 0:
if _open > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _open
else:
if _open < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _open
# Check threshold against previous bars and open
if np.isnan(last_tsl_info["th"][col]):
th_hit = True
else:
th_hit = check_tsl_th_hit_nb(
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["init_price"][col],
peak_price=last_tsl_info["peak_price"][col],
threshold=last_tsl_info["th"][col],
delta_format=last_tsl_info["delta_format"][col],
)
if th_hit:
tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["peak_price"][col],
stop=last_tsl_info["stop"][col],
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Update peak price using full bar
if last_position[col] > 0:
if _high > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _high
else:
if _low < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _low
if not tsl_stop_hit:
# Check threshold against full bar
if not th_hit:
if np.isnan(last_tsl_info["th"][col]):
th_hit = True
else:
th_hit = check_tsl_th_hit_nb(
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["init_price"][col],
peak_price=last_tsl_info["peak_price"][col],
threshold=last_tsl_info["th"][col],
delta_format=last_tsl_info["delta_format"][col],
)
if th_hit:
# Check stop against close
tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["peak_price"][col],
stop=last_tsl_info["stop"][col],
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
can_use_ohlc=False,
hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TP
tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = np.nan, False, False
if tp_stop_signal:
tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tp_info["init_price"][col],
stop=last_tp_info["stop"][col],
delta_format=last_tp_info["delta_format"][col],
hit_below=False,
hard_stop=last_tp_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TD
td_stop_price, td_stop_hit_on_open, td_stop_hit = np.nan, False, False
if td_stop_signal:
td_stop_hit_on_open, td_stop_hit = check_td_stop_hit_nb(
init_idx=last_td_info["init_idx"][col],
i=i,
stop=last_td_info["stop"][col],
time_delta_format=last_td_info["time_delta_format"][col],
index=index,
freq=freq,
)
if np.isnan(_open):
td_stop_hit_on_open = False
if td_stop_hit_on_open:
td_stop_price = _open
else:
td_stop_price = _close
# Check DT
dt_stop_price, dt_stop_hit_on_open, dt_stop_hit = np.nan, False, False
if dt_stop_signal:
dt_stop_hit_on_open, dt_stop_hit = check_dt_stop_hit_nb(
i=i,
stop=last_dt_info["stop"][col],
time_delta_format=last_dt_info["time_delta_format"][col],
index=index,
freq=freq,
)
if np.isnan(_open):
dt_stop_hit_on_open = False
if dt_stop_hit_on_open:
dt_stop_price = _open
else:
dt_stop_price = _close
# Resolve the stop signal
sl_hit = False
tsl_hit = False
tp_hit = False
td_hit = False
dt_hit = False
if sl_stop_hit_on_open:
sl_hit = True
elif tsl_stop_hit_on_open:
tsl_hit = True
elif tp_stop_hit_on_open:
tp_hit = True
elif td_stop_hit_on_open:
td_hit = True
elif dt_stop_hit_on_open:
dt_hit = True
elif sl_stop_hit:
sl_hit = True
elif tsl_stop_hit:
tsl_hit = True
elif tp_stop_hit:
tp_hit = True
elif td_stop_hit:
td_hit = True
elif dt_stop_hit:
dt_hit = True
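# Stops hit on the bar's open take precedence over stops hit later in the bar; among stops
# hit at the same time, the chain prefers SL, then TSL/TTP, TP, TD, and DT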
if sl_hit:
stop_price, stop_hit_on_open, stop_hit = sl_stop_price, sl_stop_hit_on_open, sl_stop_hit
_stop_type = StopType.SL
_init_i = last_sl_info["init_idx"][col]
_stop_exit_price = last_sl_info["exit_price"][col]
_stop_exit_size = last_sl_info["exit_size"][col]
_stop_exit_size_type = last_sl_info["exit_size_type"][col]
_stop_exit_type = last_sl_info["exit_type"][col]
_stop_order_type = last_sl_info["order_type"][col]
_limit_delta = last_sl_info["limit_delta"][col]
_delta_format = last_sl_info["delta_format"][col]
_ladder = last_sl_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_sl_info["step"][col]
if step < n_sl_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=sl_stop_,
step=step,
col=col,
init_price=last_sl_info["init_price"][col],
init_position=last_sl_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_sl_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif tsl_hit:
stop_price, stop_hit_on_open, stop_hit = (
tsl_stop_price,
tsl_stop_hit_on_open,
tsl_stop_hit,
)
if np.isnan(last_tsl_info["th"][col]):
_stop_type = StopType.TSL
else:
_stop_type = StopType.TTP
_init_i = last_tsl_info["init_idx"][col]
_stop_exit_price = last_tsl_info["exit_price"][col]
_stop_exit_size = last_tsl_info["exit_size"][col]
_stop_exit_size_type = last_tsl_info["exit_size_type"][col]
_stop_exit_type = last_tsl_info["exit_type"][col]
_stop_order_type = last_tsl_info["order_type"][col]
_limit_delta = last_tsl_info["limit_delta"][col]
_delta_format = last_tsl_info["delta_format"][col]
_ladder = last_tsl_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_tsl_info["step"][col]
if step < n_tsl_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=tsl_stop_,
step=step,
col=col,
init_price=last_tsl_info["init_price"][col],
init_position=last_tsl_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif tp_hit:
stop_price, stop_hit_on_open, stop_hit = tp_stop_price, tp_stop_hit_on_open, tp_stop_hit
_stop_type = StopType.TP
_init_i = last_tp_info["init_idx"][col]
_stop_exit_price = last_tp_info["exit_price"][col]
_stop_exit_size = last_tp_info["exit_size"][col]
_stop_exit_size_type = last_tp_info["exit_size_type"][col]
_stop_exit_type = last_tp_info["exit_type"][col]
_stop_order_type = last_tp_info["order_type"][col]
_limit_delta = last_tp_info["limit_delta"][col]
_delta_format = last_tp_info["delta_format"][col]
_ladder = last_tp_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_tp_info["step"][col]
if step < n_tp_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=tp_stop_,
step=step,
col=col,
init_price=last_tp_info["init_price"][col],
init_position=last_tp_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_tp_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif td_hit:
stop_price, stop_hit_on_open, stop_hit = td_stop_price, td_stop_hit_on_open, td_stop_hit
_stop_type = StopType.TD
_init_i = last_td_info["init_idx"][col]
_stop_exit_price = last_td_info["exit_price"][col]
_stop_exit_size = last_td_info["exit_size"][col]
_stop_exit_size_type = last_td_info["exit_size_type"][col]
_stop_exit_type = last_td_info["exit_type"][col]
_stop_order_type = last_td_info["order_type"][col]
_limit_delta = last_td_info["limit_delta"][col]
_delta_format = last_td_info["delta_format"][col]
_ladder = last_td_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_td_info["step"][col]
if step < n_td_steps:
_stop_exit_size = get_time_stop_ladder_exit_size_nb(
stop_=td_stop_,
step=step,
col=col,
init_idx=last_td_info["init_idx"][col],
init_position=last_td_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
time_delta_format=last_td_info["time_delta_format"][col],
index=index,
)
_stop_exit_size_type = SizeType.Amount
elif dt_hit:
stop_price, stop_hit_on_open, stop_hit = dt_stop_price, dt_stop_hit_on_open, dt_stop_hit
_stop_type = StopType.DT
_init_i = last_dt_info["init_idx"][col]
_stop_exit_price = last_dt_info["exit_price"][col]
_stop_exit_size = last_dt_info["exit_size"][col]
_stop_exit_size_type = last_dt_info["exit_size_type"][col]
_stop_exit_type = last_dt_info["exit_type"][col]
_stop_order_type = last_dt_info["order_type"][col]
_limit_delta = last_dt_info["limit_delta"][col]
_delta_format = last_dt_info["delta_format"][col]
_ladder = last_dt_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_dt_info["step"][col]
if step < n_dt_steps:
_stop_exit_size = get_time_stop_ladder_exit_size_nb(
stop_=dt_stop_,
step=step,
col=col,
init_idx=last_dt_info["init_idx"][col],
init_position=last_dt_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
time_delta_format=last_dt_info["time_delta_format"][col],
index=index,
)
_stop_exit_size_type = SizeType.Amount
else:
stop_price, stop_hit_on_open, stop_hit = np.nan, False, False
if stop_hit:
# Stop price was hit
# Resolve the final stop signal
_accumulate = flex_select_nb(accumulate_, i, col)
_size = flex_select_nb(size_, i, col)
_size_type = flex_select_nb(size_type_, i, col)
if not np.isnan(_stop_exit_size):
_accumulate = True
if _stop_exit_type == StopExitType.Close:
_stop_exit_type = StopExitType.CloseReduce
_size = _stop_exit_size
if _stop_exit_size_type != -1:
_size_type = _stop_exit_size_type
(
stop_is_long_entry,
stop_is_long_exit,
stop_is_short_entry,
stop_is_short_exit,
_accumulate,
) = generate_stop_signal_nb(
position_now=last_position[col],
stop_exit_type=_stop_exit_type,
accumulate=_accumulate,
)
# Resolve the price
_price = resolve_stop_exit_price_nb(
stop_price=stop_price,
close=_close,
stop_exit_price=_stop_exit_price,
)
# Convert the signals into a direction-aware size, size type, and direction
_size, _size_type, _direction = signal_to_size_nb(
position_now=last_position[col],
val_price_now=_price,
value_now=last_value[group],
is_long_entry=stop_is_long_entry,
is_long_exit=stop_is_long_exit,
is_short_entry=stop_is_short_entry,
is_short_exit=stop_is_short_exit,
size=_size,
size_type=_size_type,
accumulate=_accumulate,
)
if not np.isnan(_size):
# Executable stop signal
can_execute = True
if _stop_order_type == OrderType.Limit:
# Use close to check whether the limit price was hit
if _stop_exit_price == StopExitPrice.Close:
# Cannot place a limit order at the close price and execute right away
can_execute = False
if can_execute:
limit_price, _, can_execute = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_limit_delta,
delta_format=_delta_format,
limit_reverse=False,
can_use_ohlc=stop_hit_on_open,
check_open=False,
hard_limit=False,
)
if can_execute:
_price = limit_price
# Save info
exec_stop_set = True
exec_stop_set_on_open = stop_hit_on_open
exec_stop_set_on_close = _stop_exit_price == StopExitPrice.Close
exec_stop_init_i = _init_i
if np.isinf(_price) and _price > 0:
exec_stop_val_price = _close
elif np.isinf(_price) and _price < 0:
exec_stop_val_price = _open
else:
exec_stop_val_price = _price
exec_stop_price = _price
exec_stop_size = _size
exec_stop_size_type = _size_type
exec_stop_direction = _direction
exec_stop_type = _stop_order_type
exec_stop_stop_type = _stop_type
exec_stop_delta = _limit_delta
exec_stop_delta_format = _delta_format
exec_stop_make_limit = not can_execute
# Process user signal
if any_user_signal:
if _i < 0:
_price = np.nan
_size = np.nan
_size_type = -1
_direction = -1
else:
_accumulate = flex_select_nb(accumulate_, _i, col)
if is_long_entry or is_short_entry:
# Resolve any single-direction conflicts
_upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col)
is_long_entry, is_long_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_long_entry,
is_exit=is_long_exit,
direction=Direction.LongOnly,
conflict_mode=_upon_long_conflict,
)
_upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col)
is_short_entry, is_short_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_short_entry,
is_exit=is_short_exit,
direction=Direction.ShortOnly,
conflict_mode=_upon_short_conflict,
)
# Resolve any multi-direction conflicts
_upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col)
is_long_entry, is_short_entry = resolve_dir_conflict_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_short_entry=is_short_entry,
upon_dir_conflict=_upon_dir_conflict,
)
# Resolve an opposite entry
_upon_opposite_entry = flex_select_nb(upon_opposite_entry_, _i, col)
(
is_long_entry,
is_long_exit,
is_short_entry,
is_short_exit,
_accumulate,
) = resolve_opposite_entry_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
upon_opposite_entry=_upon_opposite_entry,
accumulate=_accumulate,
)
# Resolve the price
_price = flex_select_nb(price_, _i, col)
# Convert the signals into a direction-aware size, size type, and direction
_val_price = flex_select_nb(val_price_, i, col)
if np.isinf(_val_price) and _val_price > 0:
if np.isinf(_price) and _price > 0:
_val_price = _close
elif np.isinf(_price) and _price < 0:
_val_price = _open
else:
_val_price = _price
elif np.isnan(_val_price) or (np.isinf(_val_price) and _val_price < 0):
_val_price = last_val_price[col]
_size, _size_type, _direction = signal_to_size_nb(
position_now=last_position[col],
val_price_now=_val_price,
value_now=last_value[group],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
size=flex_select_nb(size_, _i, col),
size_type=flex_select_nb(size_type_, _i, col),
accumulate=_accumulate,
)
if np.isinf(_price):
if _price > 0:
user_on_close = True
else:
user_on_open = True
if not np.isnan(_size):
# Executable user signal
can_execute = True
_order_type = flex_select_nb(order_type_, _i, col)
if _order_type == OrderType.Limit:
# Use close to check whether the limit price was hit
can_use_ohlc = False
if np.isinf(_price):
if _price > 0:
# Cannot place a limit order at the close price and execute right away
_price = _close
can_execute = False
else:
can_use_ohlc = True
_price = _open
if can_execute:
_limit_delta = flex_select_nb(limit_delta_, _i, col)
_delta_format = flex_select_nb(delta_format_, _i, col)
_limit_reverse = flex_select_nb(limit_reverse_, _i, col)
limit_price, _, can_execute = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_limit_delta,
delta_format=_delta_format,
limit_reverse=_limit_reverse,
can_use_ohlc=can_use_ohlc,
check_open=False,
hard_limit=False,
)
if can_execute:
_price = limit_price
# Save info
exec_user_set = True
exec_user_val_price = _val_price
exec_user_price = _price
exec_user_size = _size
exec_user_size_type = _size_type
exec_user_direction = _direction
exec_user_type = _order_type
exec_user_stop_type = -1
exec_user_make_limit = not can_execute
if (
exec_limit_set
or exec_stop_set
or exec_user_set
or ((any_limit_signal or any_stop_signal) and any_user_signal)
):
# Choose the main executable signal
# Priority: limit -> stop -> user
# Check whether the main signal comes on open
keep_limit = True
keep_stop = True
execute_limit = False
execute_stop = False
execute_user = False
if exec_limit_set_on_open:
keep_limit = False
keep_stop = False
execute_limit = True
if exec_limit_set_on_close:
exec_limit_bar_zone = BarZone.Close
else:
exec_limit_bar_zone = BarZone.Open
elif exec_stop_set_on_open:
keep_limit = False
keep_stop = _ladder
execute_stop = True
if exec_stop_set_on_close:
exec_stop_bar_zone = BarZone.Close
else:
exec_stop_bar_zone = BarZone.Open
elif any_user_signal and user_on_open:
execute_user = True
if any_limit_signal and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Open
if not execute_limit and not execute_stop and not execute_user:
# Check whether the main signal comes in the middle of the bar
if exec_limit_set and not exec_limit_set_on_open and keep_limit:
keep_limit = False
keep_stop = False
execute_limit = True
exec_limit_bar_zone = BarZone.Middle
elif (
exec_stop_set and not exec_stop_set_on_open and not exec_stop_set_on_close and keep_stop
):
keep_limit = False
keep_stop = _ladder
execute_stop = True
exec_stop_bar_zone = BarZone.Middle
elif any_user_signal and not user_on_open and not user_on_close:
execute_user = True
if any_limit_signal and keep_limit and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and keep_stop and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Middle
if not execute_limit and not execute_stop and not execute_user:
# Check whether the main signal comes on close
if exec_stop_set_on_close and keep_stop:
keep_limit = False
keep_stop = _ladder
execute_stop = True
exec_stop_bar_zone = BarZone.Close
elif any_user_signal and user_on_close:
execute_user = True
if any_limit_signal and keep_limit and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and keep_stop and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Close
# Process the limit signal
if execute_limit:
# Execute the signal
main_info["bar_zone"][col] = exec_limit_bar_zone
main_info["signal_idx"][col] = exec_limit_signal_i
main_info["creation_idx"][col] = exec_limit_creation_i
main_info["idx"][col] = exec_limit_init_i
main_info["val_price"][col] = exec_limit_val_price
main_info["price"][col] = exec_limit_price
main_info["size"][col] = exec_limit_size
main_info["size_type"][col] = exec_limit_size_type
main_info["direction"][col] = exec_limit_direction
main_info["type"][col] = OrderType.Limit
main_info["stop_type"][col] = exec_limit_stop_type
if execute_limit or (any_limit_signal and not keep_limit):
# Clear the pending info
any_limit_signal = False
last_limit_info["signal_idx"][col] = -1
last_limit_info["creation_idx"][col] = -1
last_limit_info["init_idx"][col] = -1
last_limit_info["init_price"][col] = np.nan
last_limit_info["init_size"][col] = np.nan
last_limit_info["init_size_type"][col] = -1
last_limit_info["init_direction"][col] = -1
last_limit_info["init_stop_type"][col] = -1
last_limit_info["delta"][col] = np.nan
last_limit_info["delta_format"][col] = -1
last_limit_info["tif"][col] = -1
last_limit_info["expiry"][col] = -1
last_limit_info["time_delta_format"][col] = -1
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = np.nan
# Process the stop signal
if execute_stop:
# Execute the signal
if exec_stop_make_limit:
if any_limit_signal:
raise ValueError("Only one active limit signal is allowed at a time")
_limit_tif = flex_select_nb(limit_tif_, i, col)
_limit_expiry = flex_select_nb(limit_expiry_, i, col)
_time_delta_format = flex_select_nb(time_delta_format_, i, col)
_limit_order_price = flex_select_nb(limit_order_price_, i, col)
last_limit_info["signal_idx"][col] = exec_stop_init_i
last_limit_info["creation_idx"][col] = i
last_limit_info["init_idx"][col] = i
last_limit_info["init_price"][col] = exec_stop_price
last_limit_info["init_size"][col] = exec_stop_size
last_limit_info["init_size_type"][col] = exec_stop_size_type
last_limit_info["init_direction"][col] = exec_stop_direction
last_limit_info["init_stop_type"][col] = exec_stop_stop_type
last_limit_info["delta"][col] = exec_stop_delta
last_limit_info["delta_format"][col] = exec_stop_delta_format
last_limit_info["tif"][col] = _limit_tif
last_limit_info["expiry"][col] = _limit_expiry
last_limit_info["time_delta_format"][col] = _time_delta_format
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = _limit_order_price
else:
main_info["bar_zone"][col] = exec_stop_bar_zone
main_info["signal_idx"][col] = exec_stop_init_i
main_info["creation_idx"][col] = i
main_info["idx"][col] = i
main_info["val_price"][col] = exec_stop_val_price
main_info["price"][col] = exec_stop_price
main_info["size"][col] = exec_stop_size
main_info["size_type"][col] = exec_stop_size_type
main_info["direction"][col] = exec_stop_direction
main_info["type"][col] = exec_stop_type
main_info["stop_type"][col] = exec_stop_stop_type
if any_stop_signal and not keep_stop:
# Clear the pending info
any_stop_signal = False
last_sl_info["init_idx"][col] = -1
last_sl_info["init_price"][col] = np.nan
last_sl_info["init_position"][col] = np.nan
last_sl_info["stop"][col] = np.nan
last_sl_info["exit_price"][col] = -1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = -1
last_sl_info["order_type"][col] = -1
last_sl_info["limit_delta"][col] = np.nan
last_sl_info["delta_format"][col] = -1
last_sl_info["ladder"][col] = -1
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
last_tsl_info["init_idx"][col] = -1
last_tsl_info["init_price"][col] = np.nan
last_tsl_info["init_position"][col] = np.nan
last_tsl_info["peak_idx"][col] = -1
last_tsl_info["peak_price"][col] = np.nan
last_tsl_info["stop"][col] = np.nan
last_tsl_info["th"][col] = np.nan
last_tsl_info["exit_price"][col] = -1
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = -1
last_tsl_info["order_type"][col] = -1
last_tsl_info["limit_delta"][col] = np.nan
last_tsl_info["delta_format"][col] = -1
last_tsl_info["ladder"][col] = -1
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
last_tp_info["init_idx"][col] = -1
last_tp_info["init_price"][col] = np.nan
last_tp_info["init_position"][col] = np.nan
last_tp_info["stop"][col] = np.nan
last_tp_info["exit_price"][col] = -1
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = -1
last_tp_info["order_type"][col] = -1
last_tp_info["limit_delta"][col] = np.nan
last_tp_info["delta_format"][col] = -1
last_tp_info["ladder"][col] = -1
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
last_td_info["init_idx"][col] = -1
last_td_info["init_position"][col] = np.nan
last_td_info["stop"][col] = -1
last_td_info["exit_price"][col] = -1
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = -1
last_td_info["order_type"][col] = -1
last_td_info["limit_delta"][col] = np.nan
last_td_info["delta_format"][col] = -1
last_td_info["time_delta_format"][col] = -1
last_td_info["ladder"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
last_dt_info["init_idx"][col] = -1
last_dt_info["init_position"][col] = np.nan
last_dt_info["stop"][col] = -1
last_dt_info["exit_price"][col] = -1
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = -1
last_dt_info["order_type"][col] = -1
last_dt_info["limit_delta"][col] = np.nan
last_dt_info["delta_format"][col] = -1
last_dt_info["time_delta_format"][col] = -1
last_dt_info["ladder"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
# Process the user signal
if execute_user:
# Execute the signal
if _i >= 0:
if exec_user_make_limit:
if any_limit_signal:
raise ValueError("Only one active limit signal is allowed at a time")
_limit_delta = flex_select_nb(limit_delta_, _i, col)
_delta_format = flex_select_nb(delta_format_, _i, col)
_limit_tif = flex_select_nb(limit_tif_, _i, col)
_limit_expiry = flex_select_nb(limit_expiry_, _i, col)
_time_delta_format = flex_select_nb(time_delta_format_, _i, col)
_limit_reverse = flex_select_nb(limit_reverse_, _i, col)
_limit_order_price = flex_select_nb(limit_order_price_, _i, col)
last_limit_info["signal_idx"][col] = _i
last_limit_info["creation_idx"][col] = i
last_limit_info["init_idx"][col] = _i
last_limit_info["init_price"][col] = exec_user_price
last_limit_info["init_size"][col] = exec_user_size
last_limit_info["init_size_type"][col] = exec_user_size_type
last_limit_info["init_direction"][col] = exec_user_direction
last_limit_info["init_stop_type"][col] = -1
last_limit_info["delta"][col] = _limit_delta
last_limit_info["delta_format"][col] = _delta_format
last_limit_info["tif"][col] = _limit_tif
last_limit_info["expiry"][col] = _limit_expiry
last_limit_info["time_delta_format"][col] = _time_delta_format
last_limit_info["reverse"][col] = _limit_reverse
last_limit_info["order_price"][col] = _limit_order_price
else:
main_info["bar_zone"][col] = exec_user_bar_zone
main_info["signal_idx"][col] = _i
main_info["creation_idx"][col] = i
main_info["idx"][col] = _i
main_info["val_price"][col] = exec_user_val_price
main_info["price"][col] = exec_user_price
main_info["size"][col] = exec_user_size
main_info["size_type"][col] = exec_user_size_type
main_info["direction"][col] = exec_user_direction
main_info["type"][col] = exec_user_type
main_info["stop_type"][col] = exec_user_stop_type
skip = skip_empty
if skip:
for col in range(from_col, to_col):
if flex_select_nb(log_, i, col):
skip = False
break
if not np.isnan(main_info["size"][col]):
skip = False
break
if not skip:
# Check bar zone and update valuation price
bar_zone = -1
same_bar_zone = True
same_timing = True
for c in range(group_len):
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
if bar_zone == -1:
bar_zone = main_info["bar_zone"][col]
if main_info["bar_zone"][col] != bar_zone:
same_bar_zone = False
same_timing = False
if main_info["bar_zone"][col] == BarZone.Middle:
same_timing = False
_val_price = main_info["val_price"][col]
if not np.isnan(_val_price) or not ffill_val_price:
last_val_price[col] = _val_price
if cash_sharing:
# Dynamically sort by order value -> selling comes first to release funds early
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if auto_call_seq:
# Sort by order value
if not same_timing:
raise ValueError("Cannot sort orders by value if they are executed at different times")
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
# Approximate order value
exec_state = ExecState(
cash=last_cash[group] if cash_sharing else last_cash[col],
position=last_position[col],
debt=last_debt[col],
locked_cash=last_locked_cash[col],
free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col],
val_price=last_val_price[col],
value=last_value[group] if cash_sharing else last_value[col],
)
temp_sort_by[c] = approx_order_value_nb(
exec_state=exec_state,
size=main_info["size"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
)
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
else:
if not same_bar_zone:
# Sort by bar zone
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
temp_sort_by[c] = main_info["bar_zone"][col]
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
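# Process the group's columns in call-sequence order and generate orders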
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
if skip_empty and np.isnan(main_info["size"][col]): # shortcut
continue
# Get current values per column
position_now = last_position[col]
debt_now = last_debt[col]
locked_cash_now = last_locked_cash[col]
val_price_now = last_val_price[col]
cash_now = last_cash[group]
free_cash_now = last_free_cash[group]
value_now = last_value[group]
return_now = last_return[group]
# Generate the next order
_i = main_info["idx"][col]
if main_info["type"][col] == OrderType.Limit:
_slippage = 0.0
else:
_slippage = float(flex_select_nb(slippage_, _i, col))
_min_size = flex_select_nb(min_size_, _i, col)
_max_size = flex_select_nb(max_size_, _i, col)
_size_type = flex_select_nb(size_type_, _i, col)
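# Convert the min/max size requirements to the order's size type so they remain comparable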
if _size_type != main_info["size_type"][col]:
if not np.isnan(_min_size):
_min_size, _ = resolve_size_nb(
size=_min_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
if not np.isnan(_max_size):
_max_size, _ = resolve_size_nb(
size=_max_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
order = order_nb(
size=main_info["size"][col],
price=main_info["price"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
fees=flex_select_nb(fees_, _i, col),
fixed_fees=flex_select_nb(fixed_fees_, _i, col),
slippage=_slippage,
min_size=_min_size,
max_size=_max_size,
size_granularity=flex_select_nb(size_granularity_, _i, col),
leverage=flex_select_nb(leverage_, _i, col),
leverage_mode=flex_select_nb(leverage_mode_, _i, col),
reject_prob=flex_select_nb(reject_prob_, _i, col),
price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col),
allow_partial=flex_select_nb(allow_partial_, _i, col),
raise_reject=flex_select_nb(raise_reject_, _i, col),
log=flex_select_nb(log_, _i, col),
)
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Append more order information
if order_result.status == OrderStatus.Filled and order_counts[col] >= 1:
order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col]
order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col]
order_records["type"][order_counts[col] - 1, col] = main_info["type"][col]
order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col]
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
if use_stops:
# Update stop price
if position_now == 0:
# Not in position anymore -> clear stops (irrespective of order success)
last_sl_info["init_idx"][col] = -1
last_sl_info["init_price"][col] = np.nan
last_sl_info["init_position"][col] = np.nan
last_sl_info["stop"][col] = np.nan
last_sl_info["exit_price"][col] = -1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = -1
last_sl_info["order_type"][col] = -1
last_sl_info["limit_delta"][col] = np.nan
last_sl_info["delta_format"][col] = -1
last_sl_info["ladder"][col] = -1
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
last_tsl_info["init_idx"][col] = -1
last_tsl_info["init_price"][col] = np.nan
last_tsl_info["init_position"][col] = np.nan
last_tsl_info["peak_idx"][col] = -1
last_tsl_info["peak_price"][col] = np.nan
last_tsl_info["stop"][col] = np.nan
last_tsl_info["th"][col] = np.nan
last_tsl_info["exit_price"][col] = -1
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = -1
last_tsl_info["order_type"][col] = -1
last_tsl_info["limit_delta"][col] = np.nan
last_tsl_info["delta_format"][col] = -1
last_tsl_info["ladder"][col] = -1
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
last_tp_info["init_idx"][col] = -1
last_tp_info["init_price"][col] = np.nan
last_tp_info["init_position"][col] = np.nan
last_tp_info["stop"][col] = np.nan
last_tp_info["exit_price"][col] = -1
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = -1
last_tp_info["order_type"][col] = -1
last_tp_info["limit_delta"][col] = np.nan
last_tp_info["delta_format"][col] = -1
last_tp_info["ladder"][col] = -1
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
last_td_info["init_idx"][col] = -1
last_td_info["init_position"][col] = np.nan
last_td_info["stop"][col] = -1
last_td_info["exit_price"][col] = -1
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = -1
last_td_info["order_type"][col] = -1
last_td_info["limit_delta"][col] = np.nan
last_td_info["delta_format"][col] = -1
last_td_info["time_delta_format"][col] = -1
last_td_info["ladder"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
last_dt_info["init_idx"][col] = -1
last_dt_info["init_position"][col] = np.nan
last_dt_info["stop"][col] = -1
last_dt_info["exit_price"][col] = -1
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = -1
last_dt_info["order_type"][col] = -1
last_dt_info["limit_delta"][col] = np.nan
last_dt_info["delta_format"][col] = -1
last_dt_info["time_delta_format"][col] = -1
last_dt_info["ladder"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
else:
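# Still in position -> advance the ladder of the stop type that was just executed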
if main_info["stop_type"][col] == StopType.SL:
if last_sl_info["ladder"][col]:
step = last_sl_info["step"][col] + 1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
if stop_ladder and last_sl_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_sl_steps:
last_sl_info["stop"][col] = flex_select_nb(sl_stop_, step, col)
last_sl_info["step"][col] = step
last_sl_info["step_idx"][col] = i
else:
last_sl_info["stop"][col] = np.nan
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
else:
last_sl_info["stop"][col] = np.nan
last_sl_info["step"][col] = step
last_sl_info["step_idx"][col] = i
elif (
main_info["stop_type"][col] == StopType.TSL
or main_info["stop_type"][col] == StopType.TTP
):
if last_tsl_info["ladder"][col]:
step = last_tsl_info["step"][col] + 1
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
if stop_ladder and last_tsl_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_tsl_steps:
last_tsl_info["stop"][col] = flex_select_nb(tsl_stop_, step, col)
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
else:
last_tsl_info["stop"][col] = np.nan
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
else:
last_tsl_info["stop"][col] = np.nan
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.TP:
if last_tp_info["ladder"][col]:
step = last_tp_info["step"][col] + 1
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
if stop_ladder and last_tp_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_tp_steps:
last_tp_info["stop"][col] = flex_select_nb(tp_stop_, step, col)
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
else:
last_tp_info["stop"][col] = np.nan
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
else:
last_tp_info["stop"][col] = np.nan
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.TD:
if last_td_info["ladder"][col]:
step = last_td_info["step"][col] + 1
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
if stop_ladder and last_td_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_td_steps:
last_td_info["stop"][col] = flex_select_nb(td_stop_, step, col)
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
else:
last_td_info["stop"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
else:
last_td_info["stop"][col] = -1
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.DT:
if last_dt_info["ladder"][col]:
step = last_dt_info["step"][col] + 1
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
if stop_ladder and last_dt_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_dt_steps:
last_dt_info["stop"][col] = flex_select_nb(dt_stop_, step, col)
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
else:
last_dt_info["stop"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
else:
last_dt_info["stop"][col] = -1
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
if order_result.status == OrderStatus.Filled and position_now != 0:
# Order filled and in position -> possibly set stops
_price = main_info["price"][col]
_stop_entry_price = flex_select_nb(stop_entry_price_, i, col)
if _stop_entry_price < 0:
if _stop_entry_price == StopEntryPrice.ValPrice:
new_init_price = val_price_now
can_use_ohlc = False
elif _stop_entry_price == StopEntryPrice.Price:
new_init_price = order.price
can_use_ohlc = np.isinf(_price) and _price < 0
if np.isinf(new_init_price):
if new_init_price > 0:
new_init_price = flex_select_nb(close_, i, col)
else:
new_init_price = flex_select_nb(open_, i, col)
elif _stop_entry_price == StopEntryPrice.FillPrice:
new_init_price = order_result.price
can_use_ohlc = np.isinf(_price) and _price < 0
elif _stop_entry_price == StopEntryPrice.Open:
new_init_price = flex_select_nb(open_, i, col)
can_use_ohlc = True
elif _stop_entry_price == StopEntryPrice.Close:
new_init_price = flex_select_nb(close_, i, col)
can_use_ohlc = False
else:
raise ValueError("Invalid StopEntryPrice option")
else:
new_init_price = _stop_entry_price
can_use_ohlc = False
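# In ladder mode the first row of each stop array holds step 0; otherwise take the value at the current row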
if stop_ladder:
_sl_stop = flex_select_nb(sl_stop_, 0, col)
_tsl_stop = flex_select_nb(tsl_stop_, 0, col)
_tp_stop = flex_select_nb(tp_stop_, 0, col)
_td_stop = flex_select_nb(td_stop_, 0, col)
_dt_stop = flex_select_nb(dt_stop_, 0, col)
else:
_sl_stop = flex_select_nb(sl_stop_, i, col)
_tsl_stop = flex_select_nb(tsl_stop_, i, col)
_tp_stop = flex_select_nb(tp_stop_, i, col)
_td_stop = flex_select_nb(td_stop_, i, col)
_dt_stop = flex_select_nb(dt_stop_, i, col)
_tsl_th = flex_select_nb(tsl_th_, i, col)
_stop_exit_price = flex_select_nb(stop_exit_price_, i, col)
_stop_exit_type = flex_select_nb(stop_exit_type_, i, col)
_stop_order_type = flex_select_nb(stop_order_type_, i, col)
_stop_limit_delta = flex_select_nb(stop_limit_delta_, i, col)
_delta_format = flex_select_nb(delta_format_, i, col)
_time_delta_format = flex_select_nb(time_delta_format_, i, col)
tsl_updated = False
if exec_state.position == 0 or np.sign(position_now) != np.sign(exec_state.position):
# Position opened/reversed -> set stops
last_sl_info["init_idx"][col] = i
last_sl_info["init_price"][col] = new_init_price
last_sl_info["init_position"][col] = position_now
last_sl_info["stop"][col] = _sl_stop
last_sl_info["exit_price"][col] = _stop_exit_price
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = _stop_exit_type
last_sl_info["order_type"][col] = _stop_order_type
last_sl_info["limit_delta"][col] = _stop_limit_delta
last_sl_info["delta_format"][col] = _delta_format
last_sl_info["ladder"][col] = stop_ladder
last_sl_info["step"][col] = 0
last_sl_info["step_idx"][col] = i
tsl_updated = True
last_tsl_info["init_idx"][col] = i
last_tsl_info["init_price"][col] = new_init_price
last_tsl_info["init_position"][col] = position_now
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = new_init_price
last_tsl_info["stop"][col] = _tsl_stop
last_tsl_info["th"][col] = _tsl_th
last_tsl_info["exit_price"][col] = _stop_exit_price
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = _stop_exit_type
last_tsl_info["order_type"][col] = _stop_order_type
last_tsl_info["limit_delta"][col] = _stop_limit_delta
last_tsl_info["delta_format"][col] = _delta_format
last_tsl_info["ladder"][col] = stop_ladder
last_tsl_info["step"][col] = 0
last_tsl_info["step_idx"][col] = i
last_tp_info["init_idx"][col] = i
last_tp_info["init_price"][col] = new_init_price
last_tp_info["init_position"][col] = position_now
last_tp_info["stop"][col] = _tp_stop
last_tp_info["exit_price"][col] = _stop_exit_price
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = _stop_exit_type
last_tp_info["order_type"][col] = _stop_order_type
last_tp_info["limit_delta"][col] = _stop_limit_delta
last_tp_info["delta_format"][col] = _delta_format
last_tp_info["ladder"][col] = stop_ladder
last_tp_info["step"][col] = 0
last_tp_info["step_idx"][col] = i
last_td_info["init_idx"][col] = i
last_td_info["init_position"][col] = position_now
last_td_info["stop"][col] = _td_stop
last_td_info["exit_price"][col] = _stop_exit_price
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = _stop_exit_type
last_td_info["order_type"][col] = _stop_order_type
last_td_info["limit_delta"][col] = _stop_limit_delta
last_td_info["delta_format"][col] = _delta_format
last_td_info["time_delta_format"][col] = _time_delta_format
last_td_info["ladder"][col] = stop_ladder
last_td_info["step"][col] = 0
last_td_info["step_idx"][col] = i
last_dt_info["init_idx"][col] = i
last_dt_info["init_position"][col] = position_now
last_dt_info["stop"][col] = _dt_stop
last_dt_info["exit_price"][col] = _stop_exit_price
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = _stop_exit_type
last_dt_info["order_type"][col] = _stop_order_type
last_dt_info["limit_delta"][col] = _stop_limit_delta
last_dt_info["delta_format"][col] = _delta_format
last_dt_info["time_delta_format"][col] = _time_delta_format
last_dt_info["ladder"][col] = stop_ladder
last_dt_info["step"][col] = 0
last_dt_info["step_idx"][col] = i
elif abs(position_now) > abs(exec_state.position):
# Position increased -> keep/override stops
_upon_stop_update = flex_select_nb(upon_stop_update_, i, col)
if should_update_stop_nb(new_stop=_sl_stop, upon_stop_update=_upon_stop_update):
last_sl_info["init_idx"][col] = i
last_sl_info["init_price"][col] = new_init_price
last_sl_info["init_position"][col] = position_now
last_sl_info["stop"][col] = _sl_stop
last_sl_info["exit_price"][col] = _stop_exit_price
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = _stop_exit_type
last_sl_info["order_type"][col] = _stop_order_type
last_sl_info["limit_delta"][col] = _stop_limit_delta
last_sl_info["delta_format"][col] = _delta_format
last_sl_info["ladder"][col] = stop_ladder
last_sl_info["step"][col] = 0
last_sl_info["step_idx"][col] = i
if should_update_stop_nb(new_stop=_tsl_stop, upon_stop_update=_upon_stop_update):
tsl_updated = True
last_tsl_info["init_idx"][col] = i
last_tsl_info["init_price"][col] = new_init_price
last_tsl_info["init_position"][col] = position_now
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = new_init_price
last_tsl_info["stop"][col] = _tsl_stop
last_tsl_info["th"][col] = _tsl_th
last_tsl_info["exit_price"][col] = _stop_exit_price
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = _stop_exit_type
last_tsl_info["order_type"][col] = _stop_order_type
last_tsl_info["limit_delta"][col] = _stop_limit_delta
last_tsl_info["delta_format"][col] = _delta_format
last_tsl_info["ladder"][col] = stop_ladder
last_tsl_info["step"][col] = 0
last_tsl_info["step_idx"][col] = i
if should_update_stop_nb(new_stop=_tp_stop, upon_stop_update=_upon_stop_update):
last_tp_info["init_idx"][col] = i
last_tp_info["init_price"][col] = new_init_price
last_tp_info["init_position"][col] = position_now
last_tp_info["stop"][col] = _tp_stop
last_tp_info["exit_price"][col] = _stop_exit_price
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = _stop_exit_type
last_tp_info["order_type"][col] = _stop_order_type
last_tp_info["limit_delta"][col] = _stop_limit_delta
last_tp_info["delta_format"][col] = _delta_format
last_tp_info["ladder"][col] = stop_ladder
last_tp_info["step"][col] = 0
last_tp_info["step_idx"][col] = i
if should_update_time_stop_nb(
new_stop=_td_stop, upon_stop_update=_upon_stop_update
):
last_td_info["init_idx"][col] = i
last_td_info["init_position"][col] = position_now
last_td_info["stop"][col] = _td_stop
last_td_info["exit_price"][col] = _stop_exit_price
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = _stop_exit_type
last_td_info["order_type"][col] = _stop_order_type
last_td_info["limit_delta"][col] = _stop_limit_delta
last_td_info["delta_format"][col] = _delta_format
last_td_info["time_delta_format"][col] = _time_delta_format
last_td_info["ladder"][col] = stop_ladder
last_td_info["step"][col] = 0
last_td_info["step_idx"][col] = i
if should_update_time_stop_nb(
new_stop=_dt_stop, upon_stop_update=_upon_stop_update
):
last_dt_info["init_idx"][col] = i
last_dt_info["init_position"][col] = position_now
last_dt_info["stop"][col] = _dt_stop
last_dt_info["exit_price"][col] = _stop_exit_price
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = _stop_exit_type
last_dt_info["order_type"][col] = _stop_order_type
last_dt_info["limit_delta"][col] = _stop_limit_delta
last_dt_info["delta_format"][col] = _delta_format
last_dt_info["time_delta_format"][col] = _time_delta_format
last_dt_info["ladder"][col] = stop_ladder
last_dt_info["step"][col] = 0
last_dt_info["step_idx"][col] = i
if tsl_updated:
# Update highest/lowest price
if can_use_ohlc:
_open = flex_select_nb(open_, i, col)
_high = flex_select_nb(high_, i, col)
_low = flex_select_nb(low_, i, col)
_close = flex_select_nb(close_, i, col)
_high, _low = resolve_hl_nb(
open=_open,
high=_high,
low=_low,
close=_close,
)
else:
_open = np.nan
_high = _low = _close = flex_select_nb(close_, i, col)
if tsl_updated:
if position_now > 0:
if _high > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col]
+ _high
- last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _high
elif position_now < 0:
if _low < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col]
+ _low
- last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _low
# The current ("now") values become the last known ("last") state
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
last_value[group] = value_now
last_return[group] = return_now
for col in range(from_col, to_col):
# Update valuation price using current close
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
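# Add cash earnings and dividends (dividends are scaled by the current position)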
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
_cash_dividends = flex_select_nb(cash_dividends_, i, col)
_cash_earnings += _cash_dividends * last_position[col]
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
if track_cash_earnings:
cash_earnings_out[i, col] += _cash_earnings
if save_state:
position[i, col] = last_position[col]
debt[i, col] = last_debt[col]
locked_cash[i, col] = last_locked_cash[col]
cash[i, group] = last_cash[group]
free_cash[i, group] = last_free_cash[group]
# Update value and return
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - _cash_deposits,
)
prev_close_value[group] = last_value[group]
if save_value:
in_outputs.value[i, group] = last_value[group]
if save_returns:
in_outputs.returns[i, group] = last_return[group]
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_out,
cash_earnings=cash_earnings_out,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
@register_jitted(cache=True)
def init_FSInOutputs_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool = False,
save_state: bool = True,
save_value: bool = True,
save_returns: bool = True,
):
"""Initialize `vectorbtpro.portfolio.enums.FSInOutputs`."""
if save_state:
position = np.full(target_shape, np.nan)
debt = np.full(target_shape, np.nan)
locked_cash = np.full(target_shape, np.nan)
if cash_sharing:
cash = np.full((target_shape[0], len(group_lens)), np.nan)
free_cash = np.full((target_shape[0], len(group_lens)), np.nan)
else:
cash = np.full(target_shape, np.nan)
free_cash = np.full(target_shape, np.nan)
else:
position = np.full((0, 0), np.nan)
debt = np.full((0, 0), np.nan)
locked_cash = np.full((0, 0), np.nan)
cash = np.full((0, 0), np.nan)
free_cash = np.full((0, 0), np.nan)
if save_value:
if cash_sharing:
value = np.full((target_shape[0], len(group_lens)), np.nan)
else:
value = np.full(target_shape, np.nan)
else:
value = np.full((0, 0), np.nan)
if save_returns:
if cash_sharing:
returns = np.full((target_shape[0], len(group_lens)), np.nan)
else:
returns = np.full(target_shape, np.nan)
else:
returns = np.full((0, 0), np.nan)
return FSInOutputs(
position=position,
debt=debt,
locked_cash=locked_cash,
cash=cash,
free_cash=free_cash,
value=value,
returns=returns,
)
@register_jitted
def no_signal_func_nb(c: SignalContext, *args) -> tp.Tuple[bool, bool, bool, bool]:
"""Placeholder `signal_func_nb` that returns no signal."""
return False, False, False, False
SignalFuncT = tp.Callable[[SignalContext, tp.VarArg()], tp.Tuple[bool, bool, bool, bool]]
PostSignalFuncT = tp.Callable[[PostSignalContext, tp.VarArg()], None]
PostSegmentFuncT = tp.Callable[[SignalSegmentContext, tp.VarArg()], None]
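# Editorial sketch (not part of vectorbtpro): a minimal custom signal function matching
# `SignalFuncT`. It reads the close price from the context, goes long when the close rises
# bar over bar, exits when it falls, and never shorts. The rule itself is purely illustrative;
# only the context fields used here (`close`, `i`, `col`) and the 4-tuple return contract
# come from `SignalContext` as it is constructed in this module.
@register_jitted
def example_momentum_signal_func_nb(c: SignalContext, *args) -> tp.Tuple[bool, bool, bool, bool]:
    """Sketch of a `signal_func_nb`: long entry on a rising close, long exit on a falling close."""
    if c.i == 0:
        return False, False, False, False
    prev_close = flex_select_nb(c.close, c.i - 1, c.col)
    curr_close = flex_select_nb(c.close, c.i, c.col)
    if np.isnan(prev_close) or np.isnan(curr_close):
        return False, False, False, False
    return bool(curr_close > prev_close), bool(curr_close < prev_close), False, False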
# %
# %
# %
# @register_jitted
# def signal_func_nb(
# c: SignalContext,
# ) -> tp.Tuple[bool, bool, bool, bool]:
# """Custom signal function."""
# return False, False, False, False
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_signal_func_nb(
# c: PostSignalContext,
# ) -> None:
# """Custom post-signal function."""
# return None
#
#
# %
# %
# %
# %
# %
# %
# @register_jitted
# def post_segment_func_nb(
# c: SignalSegmentContext,
# ) -> None:
# """Custom post-segment function."""
# return None
#
#
# %
# %
# %
# %
# %
# import vectorbtpro as vbt
# from vectorbtpro.portfolio.nb.from_signals import *
# %? import_lines
#
#
# %
# %? blocks[signal_func_nb_block]
# % blocks["signal_func_nb"]
# %? blocks[post_signal_func_nb_block]
# % blocks["post_signal_func_nb"]
# %? blocks[post_segment_func_nb_block]
# % blocks["post_segment_func_nb"]
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
target_shape=base_ch.shape_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
cash_sharing=None,
index=None,
freq=None,
open=base_ch.flex_array_gl_slicer,
high=base_ch.flex_array_gl_slicer,
low=base_ch.flex_array_gl_slicer,
close=base_ch.flex_array_gl_slicer,
init_cash=RepFunc(portfolio_ch.get_init_cash_slicer),
init_position=base_ch.flex_1d_array_gl_slicer,
init_price=base_ch.flex_1d_array_gl_slicer,
cash_deposits=RepFunc(portfolio_ch.get_cash_deposits_slicer),
cash_earnings=base_ch.flex_array_gl_slicer,
cash_dividends=base_ch.flex_array_gl_slicer,
signal_func_nb=None, # % None
signal_args=ch.ArgsTaker(),
post_signal_func_nb=None, # % None
post_signal_args=ch.ArgsTaker(),
post_segment_func_nb=None, # % None
post_segment_args=ch.ArgsTaker(),
size=base_ch.flex_array_gl_slicer,
price=base_ch.flex_array_gl_slicer,
size_type=base_ch.flex_array_gl_slicer,
fees=base_ch.flex_array_gl_slicer,
fixed_fees=base_ch.flex_array_gl_slicer,
slippage=base_ch.flex_array_gl_slicer,
min_size=base_ch.flex_array_gl_slicer,
max_size=base_ch.flex_array_gl_slicer,
size_granularity=base_ch.flex_array_gl_slicer,
leverage=base_ch.flex_array_gl_slicer,
leverage_mode=base_ch.flex_array_gl_slicer,
reject_prob=base_ch.flex_array_gl_slicer,
price_area_vio_mode=base_ch.flex_array_gl_slicer,
allow_partial=base_ch.flex_array_gl_slicer,
raise_reject=base_ch.flex_array_gl_slicer,
log=base_ch.flex_array_gl_slicer,
val_price=base_ch.flex_array_gl_slicer,
accumulate=base_ch.flex_array_gl_slicer,
upon_long_conflict=base_ch.flex_array_gl_slicer,
upon_short_conflict=base_ch.flex_array_gl_slicer,
upon_dir_conflict=base_ch.flex_array_gl_slicer,
upon_opposite_entry=base_ch.flex_array_gl_slicer,
order_type=base_ch.flex_array_gl_slicer,
limit_delta=base_ch.flex_array_gl_slicer,
limit_tif=base_ch.flex_array_gl_slicer,
limit_expiry=base_ch.flex_array_gl_slicer,
limit_reverse=base_ch.flex_array_gl_slicer,
limit_order_price=base_ch.flex_array_gl_slicer,
upon_adj_limit_conflict=base_ch.flex_array_gl_slicer,
upon_opp_limit_conflict=base_ch.flex_array_gl_slicer,
use_stops=None,
stop_ladder=None,
sl_stop=base_ch.flex_array_gl_slicer,
tsl_stop=base_ch.flex_array_gl_slicer,
tsl_th=base_ch.flex_array_gl_slicer,
tp_stop=base_ch.flex_array_gl_slicer,
td_stop=base_ch.flex_array_gl_slicer,
dt_stop=base_ch.flex_array_gl_slicer,
stop_entry_price=base_ch.flex_array_gl_slicer,
stop_exit_price=base_ch.flex_array_gl_slicer,
stop_exit_type=base_ch.flex_array_gl_slicer,
stop_order_type=base_ch.flex_array_gl_slicer,
stop_limit_delta=base_ch.flex_array_gl_slicer,
upon_stop_update=base_ch.flex_array_gl_slicer,
upon_adj_stop_conflict=base_ch.flex_array_gl_slicer,
upon_opp_stop_conflict=base_ch.flex_array_gl_slicer,
delta_format=base_ch.flex_array_gl_slicer,
time_delta_format=base_ch.flex_array_gl_slicer,
from_ago=base_ch.flex_array_gl_slicer,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
call_seq=base_ch.array_gl_slicer,
auto_call_seq=None,
ffill_val_price=None,
update_value=None,
fill_pos_info=None,
skip_empty=None,
max_order_records=None,
max_log_records=None,
in_outputs=ch.ArgsTaker(),
),
**portfolio_ch.merge_sim_outs_config,
setup_id=None, # %? line.replace("None", task_id)
)
@register_jitted(
tags={"can_parallel"},
cache=False, # % line.replace("False", "True")
task_id_or_func=None, # %? line.replace("None", task_id)
)
def from_signal_func_nb( # %? line.replace("from_signal_func_nb", new_func_name)
target_shape: tp.Shape,
group_lens: tp.GroupLens,
cash_sharing: bool,
index: tp.Optional[tp.Array1d] = None,
freq: tp.Optional[int] = None,
open: tp.FlexArray2dLike = np.nan,
high: tp.FlexArray2dLike = np.nan,
low: tp.FlexArray2dLike = np.nan,
close: tp.FlexArray2dLike = np.nan,
init_cash: tp.FlexArray1dLike = 100.0,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
cash_deposits: tp.FlexArray2dLike = 0.0,
cash_earnings: tp.FlexArray2dLike = 0.0,
cash_dividends: tp.FlexArray2dLike = 0.0,
signal_func_nb: SignalFuncT = no_signal_func_nb, # % None
signal_args: tp.ArgsLike = (),
post_signal_func_nb: PostSignalFuncT = no_post_func_nb, # % None
post_signal_args: tp.ArgsLike = (),
post_segment_func_nb: PostSegmentFuncT = no_post_func_nb, # % None
post_segment_args: tp.ArgsLike = (),
size: tp.FlexArray2dLike = np.inf,
price: tp.FlexArray2dLike = np.inf,
size_type: tp.FlexArray2dLike = SizeType.Amount,
fees: tp.FlexArray2dLike = 0.0,
fixed_fees: tp.FlexArray2dLike = 0.0,
slippage: tp.FlexArray2dLike = 0.0,
min_size: tp.FlexArray2dLike = np.nan,
max_size: tp.FlexArray2dLike = np.nan,
size_granularity: tp.FlexArray2dLike = np.nan,
leverage: tp.FlexArray2dLike = 1.0,
leverage_mode: tp.FlexArray2dLike = LeverageMode.Lazy,
reject_prob: tp.FlexArray2dLike = 0.0,
price_area_vio_mode: tp.FlexArray2dLike = PriceAreaVioMode.Ignore,
allow_partial: tp.FlexArray2dLike = True,
raise_reject: tp.FlexArray2dLike = False,
log: tp.FlexArray2dLike = False,
val_price: tp.FlexArray2dLike = np.inf,
accumulate: tp.FlexArray2dLike = AccumulationMode.Disabled,
upon_long_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_short_conflict: tp.FlexArray2dLike = ConflictMode.Ignore,
upon_dir_conflict: tp.FlexArray2dLike = DirectionConflictMode.Ignore,
upon_opposite_entry: tp.FlexArray2dLike = OppositeEntryMode.ReverseReduce,
order_type: tp.FlexArray2dLike = OrderType.Market,
limit_delta: tp.FlexArray2dLike = np.nan,
limit_tif: tp.FlexArray2dLike = -1,
limit_expiry: tp.FlexArray2dLike = -1,
limit_reverse: tp.FlexArray2dLike = False,
limit_order_price: tp.FlexArray2dLike = LimitOrderPrice.Limit,
upon_adj_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepIgnore,
upon_opp_limit_conflict: tp.FlexArray2dLike = PendingConflictMode.CancelExecute,
use_stops: bool = True,
stop_ladder: int = StopLadderMode.Disabled,
sl_stop: tp.FlexArray2dLike = np.nan,
tsl_stop: tp.FlexArray2dLike = np.nan,
tsl_th: tp.FlexArray2dLike = np.nan,
tp_stop: tp.FlexArray2dLike = np.nan,
td_stop: tp.FlexArray2dLike = -1,
dt_stop: tp.FlexArray2dLike = -1,
stop_entry_price: tp.FlexArray2dLike = StopEntryPrice.Close,
stop_exit_price: tp.FlexArray2dLike = StopExitPrice.Stop,
stop_exit_type: tp.FlexArray2dLike = StopExitType.Close,
stop_order_type: tp.FlexArray2dLike = OrderType.Market,
stop_limit_delta: tp.FlexArray2dLike = np.nan,
upon_stop_update: tp.FlexArray2dLike = StopUpdateMode.Keep,
upon_adj_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute,
upon_opp_stop_conflict: tp.FlexArray2dLike = PendingConflictMode.KeepExecute,
delta_format: tp.FlexArray2dLike = DeltaFormat.Percent,
time_delta_format: tp.FlexArray2dLike = TimeDeltaFormat.Index,
from_ago: tp.FlexArray2dLike = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
call_seq: tp.Optional[tp.Array2d] = None,
auto_call_seq: bool = False,
ffill_val_price: bool = True,
update_value: bool = False,
fill_pos_info: bool = True,
skip_empty: bool = True,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = 0,
in_outputs: tp.Optional[tp.NamedTuple] = None,
) -> SimulationOutput:
"""Simulate given a signal function.
Iterates in column-major order. Utilizes flexible broadcasting.
`signal_func_nb` is a user-defined signal generation function that is called at each row and column
(= element). It must accept the context of the type `vectorbtpro.portfolio.enums.SignalContext`
and return 4 signals: long entry, long exit, short entry, and short exit.
`post_signal_func_nb` is a user-defined post-signal function that is called after an order has been processed.
It must accept the context of the type `vectorbtpro.portfolio.enums.PostSignalContext` and return nothing.
`post_segment_func_nb` is a user-defined post-segment function that is called after each row and group
(= segment). It must accept the context of the type `vectorbtpro.portfolio.enums.SignalSegmentContext`
and return nothing.
"""
check_group_lens_nb(group_lens, target_shape[1])
open_ = to_2d_array_nb(np.asarray(open))
high_ = to_2d_array_nb(np.asarray(high))
low_ = to_2d_array_nb(np.asarray(low))
close_ = to_2d_array_nb(np.asarray(close))
init_cash_ = to_1d_array_nb(np.asarray(init_cash))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
cash_deposits_ = to_2d_array_nb(np.asarray(cash_deposits))
cash_earnings_ = to_2d_array_nb(np.asarray(cash_earnings))
cash_dividends_ = to_2d_array_nb(np.asarray(cash_dividends))
size_ = to_2d_array_nb(np.asarray(size))
price_ = to_2d_array_nb(np.asarray(price))
size_type_ = to_2d_array_nb(np.asarray(size_type))
fees_ = to_2d_array_nb(np.asarray(fees))
fixed_fees_ = to_2d_array_nb(np.asarray(fixed_fees))
slippage_ = to_2d_array_nb(np.asarray(slippage))
min_size_ = to_2d_array_nb(np.asarray(min_size))
max_size_ = to_2d_array_nb(np.asarray(max_size))
size_granularity_ = to_2d_array_nb(np.asarray(size_granularity))
leverage_ = to_2d_array_nb(np.asarray(leverage))
leverage_mode_ = to_2d_array_nb(np.asarray(leverage_mode))
reject_prob_ = to_2d_array_nb(np.asarray(reject_prob))
price_area_vio_mode_ = to_2d_array_nb(np.asarray(price_area_vio_mode))
allow_partial_ = to_2d_array_nb(np.asarray(allow_partial))
raise_reject_ = to_2d_array_nb(np.asarray(raise_reject))
log_ = to_2d_array_nb(np.asarray(log))
val_price_ = to_2d_array_nb(np.asarray(val_price))
accumulate_ = to_2d_array_nb(np.asarray(accumulate))
upon_long_conflict_ = to_2d_array_nb(np.asarray(upon_long_conflict))
upon_short_conflict_ = to_2d_array_nb(np.asarray(upon_short_conflict))
upon_dir_conflict_ = to_2d_array_nb(np.asarray(upon_dir_conflict))
upon_opposite_entry_ = to_2d_array_nb(np.asarray(upon_opposite_entry))
order_type_ = to_2d_array_nb(np.asarray(order_type))
limit_delta_ = to_2d_array_nb(np.asarray(limit_delta))
limit_tif_ = to_2d_array_nb(np.asarray(limit_tif))
limit_expiry_ = to_2d_array_nb(np.asarray(limit_expiry))
limit_reverse_ = to_2d_array_nb(np.asarray(limit_reverse))
limit_order_price_ = to_2d_array_nb(np.asarray(limit_order_price))
upon_adj_limit_conflict_ = to_2d_array_nb(np.asarray(upon_adj_limit_conflict))
upon_opp_limit_conflict_ = to_2d_array_nb(np.asarray(upon_opp_limit_conflict))
sl_stop_ = to_2d_array_nb(np.asarray(sl_stop))
tsl_stop_ = to_2d_array_nb(np.asarray(tsl_stop))
tsl_th_ = to_2d_array_nb(np.asarray(tsl_th))
tp_stop_ = to_2d_array_nb(np.asarray(tp_stop))
td_stop_ = to_2d_array_nb(np.asarray(td_stop))
dt_stop_ = to_2d_array_nb(np.asarray(dt_stop))
stop_entry_price_ = to_2d_array_nb(np.asarray(stop_entry_price))
stop_exit_price_ = to_2d_array_nb(np.asarray(stop_exit_price))
stop_exit_type_ = to_2d_array_nb(np.asarray(stop_exit_type))
stop_order_type_ = to_2d_array_nb(np.asarray(stop_order_type))
stop_limit_delta_ = to_2d_array_nb(np.asarray(stop_limit_delta))
upon_stop_update_ = to_2d_array_nb(np.asarray(upon_stop_update))
upon_adj_stop_conflict_ = to_2d_array_nb(np.asarray(upon_adj_stop_conflict))
upon_opp_stop_conflict_ = to_2d_array_nb(np.asarray(upon_opp_stop_conflict))
delta_format_ = to_2d_array_nb(np.asarray(delta_format))
time_delta_format_ = to_2d_array_nb(np.asarray(time_delta_format))
from_ago_ = to_2d_array_nb(np.asarray(from_ago))
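# Row counts of the stop arrays define the number of ladder steps per stop type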
n_sl_steps = sl_stop_.shape[0]
n_tsl_steps = tsl_stop_.shape[0]
n_tp_steps = tp_stop_.shape[0]
n_td_steps = td_stop_.shape[0]
n_dt_steps = dt_stop_.shape[0]
order_records, log_records = prepare_fs_records_nb(
target_shape=target_shape,
max_order_records=max_order_records,
max_log_records=max_log_records,
)
order_counts = np.full(target_shape[1], 0, dtype=int_)
log_counts = np.full(target_shape[1], 0, dtype=int_)
last_cash = prepare_last_cash_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
)
last_position = prepare_last_position_nb(
target_shape=target_shape,
init_position=init_position_,
)
last_value = prepare_last_value_nb(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
)
last_pos_info = prepare_last_pos_info_nb(
target_shape,
init_position=init_position_,
init_price=init_price_,
fill_pos_info=fill_pos_info,
)
last_cash_deposits = np.full_like(last_cash, 0.0)
last_val_price = np.full_like(last_position, np.nan)
last_debt = np.full(target_shape[1], 0.0, dtype=float_)
last_locked_cash = np.full(target_shape[1], 0.0, dtype=float_)
last_free_cash = last_cash.copy()
prev_close_value = last_value.copy()
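# Value as of the previous bar's close; serves as the input value for the per-bar return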
last_return = np.full_like(last_cash, np.nan)
track_cash_deposits = (cash_deposits_.size == 1 and cash_deposits_[0, 0] != 0) or cash_deposits_.size > 1
if track_cash_deposits:
cash_deposits_out = np.full((target_shape[0], len(group_lens)), 0.0, dtype=float_)
else:
cash_deposits_out = np.full((1, 1), 0.0, dtype=float_)
track_cash_earnings = (cash_earnings_.size == 1 and cash_earnings_[0, 0] != 0) or cash_earnings_.size > 1
track_cash_dividends = (cash_dividends_.size == 1 and cash_dividends_[0, 0] != 0) or cash_dividends_.size > 1
track_cash_earnings = track_cash_earnings or track_cash_dividends
if track_cash_earnings:
cash_earnings_out = np.full(target_shape, 0.0, dtype=float_)
else:
cash_earnings_out = np.full((1, 1), 0.0, dtype=float_)
last_limit_info = np.empty(target_shape[1], dtype=limit_info_dt)
last_limit_info["signal_idx"][:] = -1
last_limit_info["creation_idx"][:] = -1
last_limit_info["init_idx"][:] = -1
last_limit_info["init_price"][:] = np.nan
last_limit_info["init_size"][:] = np.nan
last_limit_info["init_size_type"][:] = -1
last_limit_info["init_direction"][:] = -1
last_limit_info["init_stop_type"][:] = -1
last_limit_info["delta"][:] = np.nan
last_limit_info["delta_format"][:] = -1
last_limit_info["tif"][:] = -1
last_limit_info["expiry"][:] = -1
last_limit_info["time_delta_format"][:] = -1
last_limit_info["reverse"][:] = False
last_limit_info["order_price"][:] = np.nan
if use_stops:
last_sl_info = np.empty(target_shape[1], dtype=sl_info_dt)
last_sl_info["init_idx"][:] = -1
last_sl_info["init_price"][:] = np.nan
last_sl_info["init_position"][:] = np.nan
last_sl_info["stop"][:] = np.nan
last_sl_info["exit_price"][:] = -1
last_sl_info["exit_size"][:] = np.nan
last_sl_info["exit_size_type"][:] = -1
last_sl_info["exit_type"][:] = -1
last_sl_info["order_type"][:] = -1
last_sl_info["limit_delta"][:] = np.nan
last_sl_info["delta_format"][:] = -1
last_sl_info["ladder"][:] = -1
last_sl_info["step"][:] = -1
last_sl_info["step_idx"][:] = -1
last_tsl_info = np.empty(target_shape[1], dtype=tsl_info_dt)
last_tsl_info["init_idx"][:] = -1
last_tsl_info["init_price"][:] = np.nan
last_tsl_info["init_position"][:] = np.nan
last_tsl_info["peak_idx"][:] = -1
last_tsl_info["peak_price"][:] = np.nan
last_tsl_info["stop"][:] = np.nan
last_tsl_info["th"][:] = np.nan
last_tsl_info["exit_price"][:] = -1
last_tsl_info["exit_size"][:] = np.nan
last_tsl_info["exit_size_type"][:] = -1
last_tsl_info["exit_type"][:] = -1
last_tsl_info["order_type"][:] = -1
last_tsl_info["limit_delta"][:] = np.nan
last_tsl_info["delta_format"][:] = -1
last_tsl_info["ladder"][:] = -1
last_tsl_info["step"][:] = -1
last_tsl_info["step_idx"][:] = -1
last_tp_info = np.empty(target_shape[1], dtype=tp_info_dt)
last_tp_info["init_idx"][:] = -1
last_tp_info["init_price"][:] = np.nan
last_tp_info["init_position"][:] = np.nan
last_tp_info["stop"][:] = np.nan
last_tp_info["exit_price"][:] = -1
last_tp_info["exit_size"][:] = np.nan
last_tp_info["exit_size_type"][:] = -1
last_tp_info["exit_type"][:] = -1
last_tp_info["order_type"][:] = -1
last_tp_info["limit_delta"][:] = np.nan
last_tp_info["delta_format"][:] = -1
last_tp_info["ladder"][:] = -1
last_tp_info["step"][:] = -1
last_tp_info["step_idx"][:] = -1
last_td_info = np.empty(target_shape[1], dtype=time_info_dt)
last_td_info["init_idx"][:] = -1
last_td_info["init_position"][:] = np.nan
last_td_info["stop"][:] = -1
last_td_info["exit_price"][:] = -1
last_td_info["exit_size"][:] = np.nan
last_td_info["exit_size_type"][:] = -1
last_td_info["exit_type"][:] = -1
last_td_info["order_type"][:] = -1
last_td_info["limit_delta"][:] = np.nan
last_td_info["delta_format"][:] = -1
last_td_info["time_delta_format"][:] = -1
last_td_info["ladder"][:] = -1
last_td_info["step"][:] = -1
last_td_info["step_idx"][:] = -1
last_dt_info = np.empty(target_shape[1], dtype=time_info_dt)
last_dt_info["init_idx"][:] = -1
last_dt_info["init_position"][:] = np.nan
last_dt_info["stop"][:] = -1
last_dt_info["exit_price"][:] = -1
last_dt_info["exit_size"][:] = np.nan
last_dt_info["exit_size_type"][:] = -1
last_dt_info["exit_type"][:] = -1
last_dt_info["order_type"][:] = -1
last_dt_info["limit_delta"][:] = np.nan
last_dt_info["delta_format"][:] = -1
last_dt_info["time_delta_format"][:] = -1
last_dt_info["ladder"][:] = -1
last_dt_info["step"][:] = -1
last_dt_info["step_idx"][:] = -1
else:
last_sl_info = np.empty(0, dtype=sl_info_dt)
last_tsl_info = np.empty(0, dtype=tsl_info_dt)
last_tp_info = np.empty(0, dtype=tp_info_dt)
last_td_info = np.empty(0, dtype=time_info_dt)
last_dt_info = np.empty(0, dtype=time_info_dt)
last_signal = np.empty(target_shape[1], dtype=int_)
main_info = np.empty(target_shape[1], dtype=main_info_dt)
temp_call_seq = np.empty(target_shape[1], dtype=int_)
temp_sort_by = np.empty(target_shape[1], dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(target_shape[0], len(group_lens)),
sim_start=sim_start,
sim_end=sim_end,
)
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
group_len = to_col - from_col
_sim_start = sim_start_[group]
_sim_end = sim_end_[group]
for i in range(_sim_start, _sim_end):
# Add cash
if cash_sharing:
_cash_deposits = flex_select_nb(cash_deposits_, i, group)
if _cash_deposits < 0:
_cash_deposits = max(_cash_deposits, -last_cash[group])
last_cash[group] += _cash_deposits
last_free_cash[group] += _cash_deposits
last_cash_deposits[group] = _cash_deposits
if track_cash_deposits:
cash_deposits_out[i, group] += _cash_deposits
else:
for col in range(from_col, to_col):
_cash_deposits = flex_select_nb(cash_deposits_, i, col)
if _cash_deposits < 0:
_cash_deposits = max(_cash_deposits, -last_cash[col])
last_cash[col] += _cash_deposits
last_free_cash[col] += _cash_deposits
last_cash_deposits[col] = _cash_deposits
if track_cash_deposits:
cash_deposits_out[i, col] += _cash_deposits
# Update valuation price using current open
for c in range(group_len):
col = from_col + c
_open = flex_select_nb(open_, i, col)
if not np.isnan(_open) or not ffill_val_price:
last_val_price[col] = _open
# Update value and return
if cash_sharing:
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
group_value = last_cash[col]
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[col] = group_value
last_return[col] = get_return_nb(
input_value=prev_close_value[col],
output_value=last_value[col] - last_cash_deposits[col],
)
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Get signals
skip = skip_empty
for c in range(group_len):
col = from_col + c
signal_ctx = SignalContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
track_cash_deposits=track_cash_deposits,
cash_deposits_out=cash_deposits_out,
track_cash_earnings=track_cash_earnings,
cash_earnings_out=cash_earnings_out,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
last_limit_info=last_limit_info,
last_sl_info=last_sl_info,
last_tsl_info=last_tsl_info,
last_tp_info=last_tp_info,
last_td_info=last_td_info,
last_dt_info=last_dt_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
col=col,
)
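# The signal function receives the full per-column context for this bar and must
# return four booleans: long entry, long exit, short entry, and short exit.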
is_long_entry, is_long_exit, is_short_entry, is_short_exit = signal_func_nb(signal_ctx, *signal_args)
# Update limit and stop prices
_i = i - abs(flex_select_nb(from_ago_, i, col))
if _i < 0:
_price = np.nan
else:
_price = flex_select_nb(price_, _i, col)
last_limit_info["init_price"][col] = resolve_dyn_limit_price_nb(
val_price=last_val_price[col],
price=_price,
limit_price=last_limit_info["init_price"][col],
)
last_sl_info["init_price"][col] = resolve_dyn_stop_entry_price_nb(
val_price=last_val_price[col],
price=_price,
stop_entry_price=last_sl_info["init_price"][col],
)
last_tsl_info["init_price"][col] = resolve_dyn_stop_entry_price_nb(
val_price=last_val_price[col],
price=_price,
stop_entry_price=last_tsl_info["init_price"][col],
)
last_tsl_info["peak_price"][col] = resolve_dyn_stop_entry_price_nb(
val_price=last_val_price[col],
price=_price,
stop_entry_price=last_tsl_info["peak_price"][col],
)
last_tp_info["init_price"][col] = resolve_dyn_stop_entry_price_nb(
val_price=last_val_price[col],
price=_price,
stop_entry_price=last_tp_info["init_price"][col],
)
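# The resolve_dyn_* helpers presumably replace placeholder prices that a dynamic
# signal function may have stored (e.g. -inf/+inf for open/close, or NaN) with
# concrete prices derived from the current valuation price and the bar price above.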
limit_signal = is_limit_active_nb(
init_idx=last_limit_info["init_idx"][col],
init_price=last_limit_info["init_price"][col],
)
if not use_stops:
sl_stop_signal = False
tsl_stop_signal = False
tp_stop_signal = False
td_stop_signal = False
dt_stop_signal = False
else:
sl_stop_signal = is_stop_active_nb(
init_idx=last_sl_info["init_idx"][col],
stop=last_sl_info["stop"][col],
)
tsl_stop_signal = is_stop_active_nb(
init_idx=last_tsl_info["init_idx"][col],
stop=last_tsl_info["stop"][col],
)
tp_stop_signal = is_stop_active_nb(
init_idx=last_tp_info["init_idx"][col],
stop=last_tp_info["stop"][col],
)
td_stop_signal = is_time_stop_active_nb(
init_idx=last_td_info["init_idx"][col],
stop=last_td_info["stop"][col],
)
dt_stop_signal = is_time_stop_active_nb(
init_idx=last_dt_info["init_idx"][col],
stop=last_dt_info["stop"][col],
)
# Pack signals into a single integer
last_signal[col] = (
(is_long_entry << 10)
| (is_long_exit << 9)
| (is_short_entry << 8)
| (is_short_exit << 7)
| (limit_signal << 6)
| (sl_stop_signal << 5)
| (tsl_stop_signal << 4)
| (tp_stop_signal << 3)
| (td_stop_signal << 2)
| (dt_stop_signal << 1)
)
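# Bit layout of last_signal[col], following the shifts above:
#   bit 10: long entry, 9: long exit, 8: short entry, 7: short exit,
#   bit 6: pending limit, 5: SL, 4: TSL/TTP, 3: TP, 2: TD, 1: DT.
# For example, a bare long entry with nothing pending packs to 1 << 10 = 1024.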
if last_signal[col] > 0:
skip = False
if not skip:
# Update value and return
if cash_sharing:
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - last_cash_deposits[group],
)
else:
for col in range(from_col, to_col):
group_value = last_cash[col]
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[col] = group_value
last_return[col] = get_return_nb(
input_value=prev_close_value[col],
output_value=last_value[col] - last_cash_deposits[col],
)
# Get size and value of each order
for c in range(group_len):
col = from_col + c
# Set defaults
main_info["bar_zone"][col] = -1
main_info["signal_idx"][col] = -1
main_info["creation_idx"][col] = -1
main_info["idx"][col] = i
main_info["val_price"][col] = np.nan
main_info["price"][col] = np.nan
main_info["size"][col] = np.nan
main_info["size_type"][col] = -1
main_info["direction"][col] = -1
main_info["type"][col] = -1
main_info["stop_type"][col] = -1
temp_sort_by[col] = 0.0
# Unpack a single integer into signals
is_long_entry = (last_signal[col] >> 10) & 1
is_long_exit = (last_signal[col] >> 9) & 1
is_short_entry = (last_signal[col] >> 8) & 1
is_short_exit = (last_signal[col] >> 7) & 1
limit_signal = (last_signal[col] >> 6) & 1
sl_stop_signal = (last_signal[col] >> 5) & 1
tsl_stop_signal = (last_signal[col] >> 4) & 1
tp_stop_signal = (last_signal[col] >> 3) & 1
td_stop_signal = (last_signal[col] >> 2) & 1
dt_stop_signal = (last_signal[col] >> 1) & 1
any_user_signal = is_long_entry or is_long_exit or is_short_entry or is_short_exit
any_limit_signal = limit_signal
any_stop_signal = (
sl_stop_signal or tsl_stop_signal or tp_stop_signal or td_stop_signal or dt_stop_signal
)
# Set initial info
exec_limit_set = False
exec_limit_set_on_open = False
exec_limit_set_on_close = False
exec_limit_signal_i = -1
exec_limit_creation_i = -1
exec_limit_init_i = -1
exec_limit_val_price = np.nan
exec_limit_price = np.nan
exec_limit_size = np.nan
exec_limit_size_type = -1
exec_limit_direction = -1
exec_limit_stop_type = -1
exec_limit_bar_zone = -1
exec_stop_set = False
exec_stop_set_on_open = False
exec_stop_set_on_close = False
exec_stop_init_i = -1
exec_stop_val_price = np.nan
exec_stop_price = np.nan
exec_stop_size = np.nan
exec_stop_size_type = -1
exec_stop_direction = -1
exec_stop_type = -1
exec_stop_stop_type = -1
exec_stop_delta = np.nan
exec_stop_delta_format = -1
exec_stop_make_limit = False
exec_stop_bar_zone = -1
user_on_open = False
user_on_close = False
exec_user_set = False
exec_user_val_price = np.nan
exec_user_price = np.nan
exec_user_size = np.nan
exec_user_size_type = -1
exec_user_direction = -1
exec_user_type = -1
exec_user_stop_type = -1
exec_user_make_limit = False
exec_user_bar_zone = -1
# Resolve the current bar
_i = i - abs(flex_select_nb(from_ago_, i, col))
_open = flex_select_nb(open_, i, col)
_high = flex_select_nb(high_, i, col)
_low = flex_select_nb(low_, i, col)
_close = flex_select_nb(close_, i, col)
_high, _low = resolve_hl_nb(
open=_open,
high=_high,
low=_low,
close=_close,
)
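# resolve_hl_nb presumably repairs the bar envelope (e.g. filling a missing high or
# low from open/close) so the hit checks below operate on consistent values.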
# Process the limit signal
if any_limit_signal:
# Check whether the limit price was hit
_signal_i = last_limit_info["signal_idx"][col]
_creation_i = last_limit_info["creation_idx"][col]
_init_i = last_limit_info["init_idx"][col]
_price = last_limit_info["init_price"][col]
_size = last_limit_info["init_size"][col]
_size_type = last_limit_info["init_size_type"][col]
_direction = last_limit_info["init_direction"][col]
_stop_type = last_limit_info["init_stop_type"][col]
_delta = last_limit_info["delta"][col]
_delta_format = last_limit_info["delta_format"][col]
_tif = last_limit_info["tif"][col]
_expiry = last_limit_info["expiry"][col]
_time_delta_format = last_limit_info["time_delta_format"][col]
_reverse = last_limit_info["reverse"][col]
_order_price = last_limit_info["order_price"][col]
limit_expired_on_open, limit_expired = check_limit_expired_nb(
creation_idx=_creation_i,
i=i,
tif=_tif,
expiry=_expiry,
time_delta_format=_time_delta_format,
index=index,
freq=freq,
)
limit_price, limit_hit_on_open, limit_hit = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_delta,
delta_format=_delta_format,
limit_reverse=_reverse,
can_use_ohlc=True,
check_open=True,
hard_limit=_order_price == LimitOrderPrice.HardLimit,
)
# Resolve the price
limit_price = resolve_limit_order_price_nb(
limit_price=limit_price,
close=_close,
limit_order_price=_order_price,
)
if limit_expired_on_open or (not limit_hit_on_open and limit_expired):
# Expired limit signal
any_limit_signal = False
last_limit_info["signal_idx"][col] = -1
last_limit_info["creation_idx"][col] = -1
last_limit_info["init_idx"][col] = -1
last_limit_info["init_price"][col] = np.nan
last_limit_info["init_size"][col] = np.nan
last_limit_info["init_size_type"][col] = -1
last_limit_info["init_direction"][col] = -1
last_limit_info["delta"][col] = np.nan
last_limit_info["delta_format"][col] = -1
last_limit_info["tif"][col] = -1
last_limit_info["expiry"][col] = -1
last_limit_info["time_delta_format"][col] = -1
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = np.nan
else:
# Save info
if limit_hit:
# Executable limit signal
exec_limit_set = True
exec_limit_set_on_open = limit_hit_on_open
exec_limit_set_on_close = _order_price == LimitOrderPrice.Close
exec_limit_signal_i = _signal_i
exec_limit_creation_i = _creation_i
exec_limit_init_i = _init_i
if np.isinf(limit_price) and limit_price > 0:
exec_limit_val_price = _close
elif np.isinf(limit_price) and limit_price < 0:
exec_limit_val_price = _open
else:
exec_limit_val_price = limit_price
exec_limit_price = limit_price
exec_limit_size = _size
exec_limit_size_type = _size_type
exec_limit_direction = _direction
exec_limit_stop_type = _stop_type
# Process the stop signal
if any_stop_signal:
# Check SL
sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = np.nan, False, False
if sl_stop_signal:
# Check against high and low
sl_stop_price, sl_stop_hit_on_open, sl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_sl_info["init_price"][col],
stop=last_sl_info["stop"][col],
delta_format=last_sl_info["delta_format"][col],
hit_below=True,
hard_stop=last_sl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TSL and TTP
tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = np.nan, False, False
if tsl_stop_signal:
# Update peak price using open
if last_position[col] > 0:
if _open > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _open
else:
if _open < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _open - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _open
# Check threshold against previous bars and open
if np.isnan(last_tsl_info["th"][col]):
th_hit = True
else:
th_hit = check_tsl_th_hit_nb(
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["init_price"][col],
peak_price=last_tsl_info["peak_price"][col],
threshold=last_tsl_info["th"][col],
delta_format=last_tsl_info["delta_format"][col],
)
if th_hit:
tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["peak_price"][col],
stop=last_tsl_info["stop"][col],
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Update peak price using full bar
if last_position[col] > 0:
if _high > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _high - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _high
else:
if _low < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col] + _low - last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _low
if not tsl_stop_hit:
# Check threshold against full bar
if not th_hit:
if np.isnan(last_tsl_info["th"][col]):
th_hit = True
else:
th_hit = check_tsl_th_hit_nb(
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["init_price"][col],
peak_price=last_tsl_info["peak_price"][col],
threshold=last_tsl_info["th"][col],
delta_format=last_tsl_info["delta_format"][col],
)
if th_hit:
# Check stop against close

tsl_stop_price, tsl_stop_hit_on_open, tsl_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tsl_info["peak_price"][col],
stop=last_tsl_info["stop"][col],
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
can_use_ohlc=False,
hard_stop=last_tsl_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TP
tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = np.nan, False, False
if tp_stop_signal:
tp_stop_price, tp_stop_hit_on_open, tp_stop_hit = check_stop_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
is_position_long=last_position[col] > 0,
init_price=last_tp_info["init_price"][col],
stop=last_tp_info["stop"][col],
delta_format=last_tp_info["delta_format"][col],
hit_below=False,
hard_stop=last_tp_info["exit_price"][col] == StopExitPrice.HardStop,
)
# Check TD
td_stop_price, td_stop_hit_on_open, td_stop_hit = np.nan, False, False
if td_stop_signal:
td_stop_hit_on_open, td_stop_hit = check_td_stop_hit_nb(
init_idx=last_td_info["init_idx"][col],
i=i,
stop=last_td_info["stop"][col],
time_delta_format=last_td_info["time_delta_format"][col],
index=index,
freq=freq,
)
if np.isnan(_open):
td_stop_hit_on_open = False
if td_stop_hit_on_open:
td_stop_price = _open
else:
td_stop_price = _close
# Check DT
dt_stop_price, dt_stop_hit_on_open, dt_stop_hit = np.nan, False, False
if dt_stop_signal:
dt_stop_hit_on_open, dt_stop_hit = check_dt_stop_hit_nb(
i=i,
stop=last_dt_info["stop"][col],
time_delta_format=last_dt_info["time_delta_format"][col],
index=index,
freq=freq,
)
if np.isnan(_open):
dt_stop_hit_on_open = False
if dt_stop_hit_on_open:
dt_stop_price = _open
else:
dt_stop_price = _close
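# Judging by the arguments, TD is a relative time stop measured from the entry bar
# (init_idx), while DT compares the current bar against an absolute datetime target.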
# Resolve the stop signal
sl_hit = False
tsl_hit = False
tp_hit = False
td_hit = False
dt_hit = False
if sl_stop_hit_on_open:
sl_hit = True
elif tsl_stop_hit_on_open:
tsl_hit = True
elif tp_stop_hit_on_open:
tp_hit = True
elif td_stop_hit_on_open:
td_hit = True
elif dt_stop_hit_on_open:
dt_hit = True
elif sl_stop_hit:
sl_hit = True
elif tsl_stop_hit:
tsl_hit = True
elif tp_stop_hit:
tp_hit = True
elif td_stop_hit:
td_hit = True
elif dt_stop_hit:
dt_hit = True
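# Only one stop is acted on per bar: hits on the open take precedence over intrabar
# and close hits, and within each tier the priority is SL, TSL/TTP, TP, TD, DT.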
if sl_hit:
stop_price, stop_hit_on_open, stop_hit = sl_stop_price, sl_stop_hit_on_open, sl_stop_hit
_stop_type = StopType.SL
_init_i = last_sl_info["init_idx"][col]
_stop_exit_price = last_sl_info["exit_price"][col]
_stop_exit_size = last_sl_info["exit_size"][col]
_stop_exit_size_type = last_sl_info["exit_size_type"][col]
_stop_exit_type = last_sl_info["exit_type"][col]
_stop_order_type = last_sl_info["order_type"][col]
_limit_delta = last_sl_info["limit_delta"][col]
_delta_format = last_sl_info["delta_format"][col]
_ladder = last_sl_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_sl_info["step"][col]
if step < n_sl_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=sl_stop_,
step=step,
col=col,
init_price=last_sl_info["init_price"][col],
init_position=last_sl_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_sl_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif tsl_hit:
stop_price, stop_hit_on_open, stop_hit = (
tsl_stop_price,
tsl_stop_hit_on_open,
tsl_stop_hit,
)
if np.isnan(last_tsl_info["th"][col]):
_stop_type = StopType.TSL
else:
_stop_type = StopType.TTP
_init_i = last_tsl_info["init_idx"][col]
_stop_exit_price = last_tsl_info["exit_price"][col]
_stop_exit_size = last_tsl_info["exit_size"][col]
_stop_exit_size_type = last_tsl_info["exit_size_type"][col]
_stop_exit_type = last_tsl_info["exit_type"][col]
_stop_order_type = last_tsl_info["order_type"][col]
_limit_delta = last_tsl_info["limit_delta"][col]
_delta_format = last_tsl_info["delta_format"][col]
_ladder = last_tsl_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_tsl_info["step"][col]
if step < n_tsl_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=tsl_stop_,
step=step,
col=col,
init_price=last_tsl_info["init_price"][col],
init_position=last_tsl_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_tsl_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif tp_hit:
stop_price, stop_hit_on_open, stop_hit = tp_stop_price, tp_stop_hit_on_open, tp_stop_hit
_stop_type = StopType.TP
_init_i = last_tp_info["init_idx"][col]
_stop_exit_price = last_tp_info["exit_price"][col]
_stop_exit_size = last_tp_info["exit_size"][col]
_stop_exit_size_type = last_tp_info["exit_size_type"][col]
_stop_exit_type = last_tp_info["exit_type"][col]
_stop_order_type = last_tp_info["order_type"][col]
_limit_delta = last_tp_info["limit_delta"][col]
_delta_format = last_tp_info["delta_format"][col]
_ladder = last_tp_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_tp_info["step"][col]
if step < n_tp_steps:
_stop_exit_size = get_stop_ladder_exit_size_nb(
stop_=tp_stop_,
step=step,
col=col,
init_price=last_tp_info["init_price"][col],
init_position=last_tp_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
delta_format=last_tp_info["delta_format"][col],
hit_below=True,
)
_stop_exit_size_type = SizeType.Amount
elif td_hit:
stop_price, stop_hit_on_open, stop_hit = td_stop_price, td_stop_hit_on_open, td_stop_hit
_stop_type = StopType.TD
_init_i = last_td_info["init_idx"][col]
_stop_exit_price = last_td_info["exit_price"][col]
_stop_exit_size = last_td_info["exit_size"][col]
_stop_exit_size_type = last_td_info["exit_size_type"][col]
_stop_exit_type = last_td_info["exit_type"][col]
_stop_order_type = last_td_info["order_type"][col]
_limit_delta = last_td_info["limit_delta"][col]
_delta_format = last_td_info["delta_format"][col]
_ladder = last_td_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_td_info["step"][col]
if step < n_td_steps:
_stop_exit_size = get_time_stop_ladder_exit_size_nb(
stop_=td_stop_,
step=step,
col=col,
init_idx=last_td_info["init_idx"][col],
init_position=last_td_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
time_delta_format=last_td_info["time_delta_format"][col],
index=index,
)
_stop_exit_size_type = SizeType.Amount
elif dt_hit:
stop_price, stop_hit_on_open, stop_hit = dt_stop_price, dt_stop_hit_on_open, dt_stop_hit
_stop_type = StopType.DT
_init_i = last_dt_info["init_idx"][col]
_stop_exit_price = last_dt_info["exit_price"][col]
_stop_exit_size = last_dt_info["exit_size"][col]
_stop_exit_size_type = last_dt_info["exit_size_type"][col]
_stop_exit_type = last_dt_info["exit_type"][col]
_stop_order_type = last_dt_info["order_type"][col]
_limit_delta = last_dt_info["limit_delta"][col]
_delta_format = last_dt_info["delta_format"][col]
_ladder = last_dt_info["ladder"][col]
if np.isnan(_stop_exit_size):
if stop_ladder and _ladder and _ladder != StopLadderMode.Dynamic:
step = last_dt_info["step"][col]
if step < n_dt_steps:
_stop_exit_size = get_time_stop_ladder_exit_size_nb(
stop_=dt_stop_,
step=step,
col=col,
init_idx=last_dt_info["init_idx"][col],
init_position=last_dt_info["init_position"][col],
position_now=last_position[col],
ladder=_ladder,
time_delta_format=last_dt_info["time_delta_format"][col],
index=index,
)
_stop_exit_size_type = SizeType.Amount
else:
stop_price, stop_hit_on_open, stop_hit = np.nan, False, False
if stop_hit:
# Stop price was hit
# Resolve the final stop signal
_accumulate = flex_select_nb(accumulate_, i, col)
_size = flex_select_nb(size_, i, col)
_size_type = flex_select_nb(size_type_, i, col)
if not np.isnan(_stop_exit_size):
_accumulate = True
if _stop_exit_type == StopExitType.Close:
_stop_exit_type = StopExitType.CloseReduce
_size = _stop_exit_size
if _stop_exit_size_type != -1:
_size_type = _stop_exit_size_type
(
stop_is_long_entry,
stop_is_long_exit,
stop_is_short_entry,
stop_is_short_exit,
_accumulate,
) = generate_stop_signal_nb(
position_now=last_position[col],
stop_exit_type=_stop_exit_type,
accumulate=_accumulate,
)
# Resolve the price
_price = resolve_stop_exit_price_nb(
stop_price=stop_price,
close=_close,
stop_exit_price=_stop_exit_price,
)
# Convert both signals to size (direction-aware), size type, and direction
_size, _size_type, _direction = signal_to_size_nb(
position_now=last_position[col],
val_price_now=_price,
value_now=last_value[group],
is_long_entry=stop_is_long_entry,
is_long_exit=stop_is_long_exit,
is_short_entry=stop_is_short_entry,
is_short_exit=stop_is_short_exit,
size=_size,
size_type=_size_type,
accumulate=_accumulate,
)
if not np.isnan(_size):
# Executable stop signal
can_execute = True
if _stop_order_type == OrderType.Limit:
# Use close to check whether the limit price was hit
if _stop_exit_price == StopExitPrice.Close:
# Cannot place a limit order at the close price and execute right away
can_execute = False
if can_execute:
limit_price, _, can_execute = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_limit_delta,
delta_format=_delta_format,
limit_reverse=False,
can_use_ohlc=stop_hit_on_open,
check_open=False,
hard_limit=False,
)
if can_execute:
_price = limit_price
# Save info
exec_stop_set = True
exec_stop_set_on_open = stop_hit_on_open
exec_stop_set_on_close = _stop_exit_price == StopExitPrice.Close
exec_stop_init_i = _init_i
if np.isinf(_price) and _price > 0:
exec_stop_val_price = _close
elif np.isinf(_price) and _price < 0:
exec_stop_val_price = _open
else:
exec_stop_val_price = _price
exec_stop_price = _price
exec_stop_size = _size
exec_stop_size_type = _size_type
exec_stop_direction = _direction
exec_stop_type = _stop_order_type
exec_stop_stop_type = _stop_type
exec_stop_delta = _limit_delta
exec_stop_delta_format = _delta_format
exec_stop_make_limit = not can_execute
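# If the stop's limit order could not be filled within this bar (can_execute is
# False), the flag above marks it to be registered as a pending limit order further
# below instead of being executed immediately.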
# Process user signal
if any_user_signal:
if _i < 0:
_price = np.nan
_size = np.nan
_size_type = -1
_direction = -1
else:
_accumulate = flex_select_nb(accumulate_, _i, col)
if is_long_entry or is_short_entry:
# Resolve any single-direction conflicts
_upon_long_conflict = flex_select_nb(upon_long_conflict_, _i, col)
is_long_entry, is_long_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_long_entry,
is_exit=is_long_exit,
direction=Direction.LongOnly,
conflict_mode=_upon_long_conflict,
)
_upon_short_conflict = flex_select_nb(upon_short_conflict_, _i, col)
is_short_entry, is_short_exit = resolve_signal_conflict_nb(
position_now=last_position[col],
is_entry=is_short_entry,
is_exit=is_short_exit,
direction=Direction.ShortOnly,
conflict_mode=_upon_short_conflict,
)
# Resolve any multi-direction conflicts
_upon_dir_conflict = flex_select_nb(upon_dir_conflict_, _i, col)
is_long_entry, is_short_entry = resolve_dir_conflict_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_short_entry=is_short_entry,
upon_dir_conflict=_upon_dir_conflict,
)
# Resolve an opposite entry
_upon_opposite_entry = flex_select_nb(upon_opposite_entry_, _i, col)
(
is_long_entry,
is_long_exit,
is_short_entry,
is_short_exit,
_accumulate,
) = resolve_opposite_entry_nb(
position_now=last_position[col],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
upon_opposite_entry=_upon_opposite_entry,
accumulate=_accumulate,
)
# Resolve the price
_price = flex_select_nb(price_, _i, col)
# Convert both signals to size (direction-aware), size type, and direction
_val_price = flex_select_nb(val_price_, i, col)
if np.isinf(_val_price) and _val_price > 0:
if np.isinf(_price) and _price > 0:
_val_price = _close
elif np.isinf(_price) and _price < 0:
_val_price = _open
else:
_val_price = _price
elif np.isnan(_val_price) or (np.isinf(_val_price) and _val_price < 0):
_val_price = last_val_price[col]
_size, _size_type, _direction = signal_to_size_nb(
position_now=last_position[col],
val_price_now=_val_price,
value_now=last_value[group],
is_long_entry=is_long_entry,
is_long_exit=is_long_exit,
is_short_entry=is_short_entry,
is_short_exit=is_short_exit,
size=flex_select_nb(size_, _i, col),
size_type=flex_select_nb(size_type_, _i, col),
accumulate=_accumulate,
)
if np.isinf(_price):
if _price > 0:
user_on_close = True
else:
user_on_open = True
if not np.isnan(_size):
# Executable user signal
can_execute = True
_order_type = flex_select_nb(order_type_, _i, col)
if _order_type == OrderType.Limit:
# Use close to check whether the limit price was hit
can_use_ohlc = False
if np.isinf(_price):
if _price > 0:
# Cannot place a limit order at the close price and execute right away
_price = _close
can_execute = False
else:
can_use_ohlc = True
_price = _open
if can_execute:
_limit_delta = flex_select_nb(limit_delta_, _i, col)
_delta_format = flex_select_nb(delta_format_, _i, col)
_limit_reverse = flex_select_nb(limit_reverse_, _i, col)
limit_price, _, can_execute = check_limit_hit_nb(
open=_open,
high=_high,
low=_low,
close=_close,
price=_price,
size=_size,
direction=_direction,
limit_delta=_limit_delta,
delta_format=_delta_format,
limit_reverse=_limit_reverse,
can_use_ohlc=can_use_ohlc,
check_open=False,
hard_limit=False,
)
if can_execute:
_price = limit_price
# Save info
exec_user_set = True
exec_user_val_price = _val_price
exec_user_price = _price
exec_user_size = _size
exec_user_size_type = _size_type
exec_user_direction = _direction
exec_user_type = _order_type
exec_user_stop_type = -1
exec_user_make_limit = not can_execute
if (
exec_limit_set
or exec_stop_set
or exec_user_set
or ((any_limit_signal or any_stop_signal) and any_user_signal)
):
# Choose the main executable signal
# Priority: limit -> stop -> user
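# The decision proceeds in three passes over bar zones: signals triggering on the
# open are considered first, then intrabar (middle), then on the close. In each
# pass, pending limit/stop signals that conflict with the user signal are resolved
# via resolve_pending_conflict_nb before the user signal may execute.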
# Check whether the main signal comes on open
keep_limit = True
keep_stop = True
execute_limit = False
execute_stop = False
execute_user = False
if exec_limit_set_on_open:
keep_limit = False
keep_stop = False
execute_limit = True
if exec_limit_set_on_close:
exec_limit_bar_zone = BarZone.Close
else:
exec_limit_bar_zone = BarZone.Open
elif exec_stop_set_on_open:
keep_limit = False
keep_stop = _ladder
execute_stop = True
if exec_stop_set_on_close:
exec_stop_bar_zone = BarZone.Close
else:
exec_stop_bar_zone = BarZone.Open
elif any_user_signal and user_on_open:
execute_user = True
if any_limit_signal and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Open
if not execute_limit and not execute_stop and not execute_user:
# Check whether the main signal comes in the middle of the bar
if exec_limit_set and not exec_limit_set_on_open and keep_limit:
keep_limit = False
keep_stop = False
execute_limit = True
exec_limit_bar_zone = BarZone.Middle
elif (
exec_stop_set and not exec_stop_set_on_open and not exec_stop_set_on_close and keep_stop
):
keep_limit = False
keep_stop = _ladder
execute_stop = True
exec_stop_bar_zone = BarZone.Middle
elif any_user_signal and not user_on_open and not user_on_close:
execute_user = True
if any_limit_signal and keep_limit and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and keep_stop and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Middle
if not execute_limit and not execute_stop and not execute_user:
# Check whether the main signal comes on close
if exec_stop_set_on_close and keep_stop:
keep_limit = False
keep_stop = _ladder
execute_stop = True
exec_stop_bar_zone = BarZone.Close
elif any_user_signal and user_on_close:
execute_user = True
if any_limit_signal and keep_limit and (execute_user or not exec_user_set):
stop_size = get_diraware_size_nb(
size=last_limit_info["init_size"][col],
direction=last_limit_info["init_direction"][col],
)
keep_limit, execute_user = resolve_pending_conflict_nb(
is_pending_long=stop_size >= 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_limit_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_limit_conflict_, i, col),
)
if any_stop_signal and keep_stop and (execute_user or not exec_user_set):
keep_stop, execute_user = resolve_pending_conflict_nb(
is_pending_long=last_position[col] < 0,
is_user_long=is_long_entry or is_short_exit,
upon_adj_conflict=flex_select_nb(upon_adj_stop_conflict_, i, col),
upon_opp_conflict=flex_select_nb(upon_opp_stop_conflict_, i, col),
)
if not exec_user_set:
execute_user = False
if execute_user:
exec_user_bar_zone = BarZone.Close
# Process the limit signal
if execute_limit:
# Execute the signal
main_info["bar_zone"][col] = exec_limit_bar_zone
main_info["signal_idx"][col] = exec_limit_signal_i
main_info["creation_idx"][col] = exec_limit_creation_i
main_info["idx"][col] = exec_limit_init_i
main_info["val_price"][col] = exec_limit_val_price
main_info["price"][col] = exec_limit_price
main_info["size"][col] = exec_limit_size
main_info["size_type"][col] = exec_limit_size_type
main_info["direction"][col] = exec_limit_direction
main_info["type"][col] = OrderType.Limit
main_info["stop_type"][col] = exec_limit_stop_type
if execute_limit or (any_limit_signal and not keep_limit):
# Clear the pending info
any_limit_signal = False
last_limit_info["signal_idx"][col] = -1
last_limit_info["creation_idx"][col] = -1
last_limit_info["init_idx"][col] = -1
last_limit_info["init_price"][col] = np.nan
last_limit_info["init_size"][col] = np.nan
last_limit_info["init_size_type"][col] = -1
last_limit_info["init_direction"][col] = -1
last_limit_info["init_stop_type"][col] = -1
last_limit_info["delta"][col] = np.nan
last_limit_info["delta_format"][col] = -1
last_limit_info["tif"][col] = -1
last_limit_info["expiry"][col] = -1
last_limit_info["time_delta_format"][col] = -1
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = np.nan
# Process the stop signal
if execute_stop:
# Execute the signal
if exec_stop_make_limit:
if any_limit_signal:
raise ValueError("Only one active limit signal is allowed at a time")
_limit_tif = flex_select_nb(limit_tif_, i, col)
_limit_expiry = flex_select_nb(limit_expiry_, i, col)
_time_delta_format = flex_select_nb(time_delta_format_, i, col)
_limit_order_price = flex_select_nb(limit_order_price_, i, col)
last_limit_info["signal_idx"][col] = exec_stop_init_i
last_limit_info["creation_idx"][col] = i
last_limit_info["init_idx"][col] = i
last_limit_info["init_price"][col] = exec_stop_price
last_limit_info["init_size"][col] = exec_stop_size
last_limit_info["init_size_type"][col] = exec_stop_size_type
last_limit_info["init_direction"][col] = exec_stop_direction
last_limit_info["init_stop_type"][col] = exec_stop_stop_type
last_limit_info["delta"][col] = exec_stop_delta
last_limit_info["delta_format"][col] = exec_stop_delta_format
last_limit_info["tif"][col] = _limit_tif
last_limit_info["expiry"][col] = _limit_expiry
last_limit_info["time_delta_format"][col] = _time_delta_format
last_limit_info["reverse"][col] = False
last_limit_info["order_price"][col] = _limit_order_price
else:
main_info["bar_zone"][col] = exec_stop_bar_zone
main_info["signal_idx"][col] = exec_stop_init_i
main_info["creation_idx"][col] = i
main_info["idx"][col] = i
main_info["val_price"][col] = exec_stop_val_price
main_info["price"][col] = exec_stop_price
main_info["size"][col] = exec_stop_size
main_info["size_type"][col] = exec_stop_size_type
main_info["direction"][col] = exec_stop_direction
main_info["type"][col] = exec_stop_type
main_info["stop_type"][col] = exec_stop_stop_type
if any_stop_signal and not keep_stop:
# Clear the pending info
any_stop_signal = False
last_sl_info["init_idx"][col] = -1
last_sl_info["init_price"][col] = np.nan
last_sl_info["init_position"][col] = np.nan
last_sl_info["stop"][col] = np.nan
last_sl_info["exit_price"][col] = -1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = -1
last_sl_info["order_type"][col] = -1
last_sl_info["limit_delta"][col] = np.nan
last_sl_info["delta_format"][col] = -1
last_sl_info["ladder"][col] = -1
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
last_tsl_info["init_idx"][col] = -1
last_tsl_info["init_price"][col] = np.nan
last_tsl_info["init_position"][col] = np.nan
last_tsl_info["peak_idx"][col] = -1
last_tsl_info["peak_price"][col] = np.nan
last_tsl_info["stop"][col] = np.nan
last_tsl_info["th"][col] = np.nan
last_tsl_info["exit_price"][col] = -1
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = -1
last_tsl_info["order_type"][col] = -1
last_tsl_info["limit_delta"][col] = np.nan
last_tsl_info["delta_format"][col] = -1
last_tsl_info["ladder"][col] = -1
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
last_tp_info["init_idx"][col] = -1
last_tp_info["init_price"][col] = np.nan
last_tp_info["init_position"][col] = np.nan
last_tp_info["stop"][col] = np.nan
last_tp_info["exit_price"][col] = -1
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = -1
last_tp_info["order_type"][col] = -1
last_tp_info["limit_delta"][col] = np.nan
last_tp_info["delta_format"][col] = -1
last_tp_info["ladder"][col] = -1
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
last_td_info["init_idx"][col] = -1
last_td_info["init_position"][col] = np.nan
last_td_info["stop"][col] = -1
last_td_info["exit_price"][col] = -1
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = -1
last_td_info["order_type"][col] = -1
last_td_info["limit_delta"][col] = np.nan
last_td_info["delta_format"][col] = -1
last_td_info["time_delta_format"][col] = -1
last_td_info["ladder"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
last_dt_info["init_idx"][col] = -1
last_dt_info["init_position"][col] = np.nan
last_dt_info["stop"][col] = -1
last_dt_info["exit_price"][col] = -1
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = -1
last_dt_info["order_type"][col] = -1
last_dt_info["limit_delta"][col] = np.nan
last_dt_info["delta_format"][col] = -1
last_dt_info["time_delta_format"][col] = -1
last_dt_info["ladder"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
# Process the user signal
if execute_user:
# Execute the signal
if _i >= 0:
if exec_user_make_limit:
if any_limit_signal:
raise ValueError("Only one active limit signal is allowed at a time")
_limit_delta = flex_select_nb(limit_delta_, _i, col)
_delta_format = flex_select_nb(delta_format_, _i, col)
_limit_tif = flex_select_nb(limit_tif_, _i, col)
_limit_expiry = flex_select_nb(limit_expiry_, _i, col)
_time_delta_format = flex_select_nb(time_delta_format_, _i, col)
_limit_reverse = flex_select_nb(limit_reverse_, _i, col)
_limit_order_price = flex_select_nb(limit_order_price_, _i, col)
last_limit_info["signal_idx"][col] = _i
last_limit_info["creation_idx"][col] = i
last_limit_info["init_idx"][col] = _i
last_limit_info["init_price"][col] = exec_user_price
last_limit_info["init_size"][col] = exec_user_size
last_limit_info["init_size_type"][col] = exec_user_size_type
last_limit_info["init_direction"][col] = exec_user_direction
last_limit_info["init_stop_type"][col] = -1
last_limit_info["delta"][col] = _limit_delta
last_limit_info["delta_format"][col] = _delta_format
last_limit_info["tif"][col] = _limit_tif
last_limit_info["expiry"][col] = _limit_expiry
last_limit_info["time_delta_format"][col] = _time_delta_format
last_limit_info["reverse"][col] = _limit_reverse
last_limit_info["order_price"][col] = _limit_order_price
else:
main_info["bar_zone"][col] = exec_user_bar_zone
main_info["signal_idx"][col] = _i
main_info["creation_idx"][col] = i
main_info["idx"][col] = _i
main_info["val_price"][col] = exec_user_val_price
main_info["price"][col] = exec_user_price
main_info["size"][col] = exec_user_size
main_info["size_type"][col] = exec_user_size_type
main_info["direction"][col] = exec_user_direction
main_info["type"][col] = exec_user_type
main_info["stop_type"][col] = exec_user_stop_type
skip = skip_empty
if skip:
for col in range(from_col, to_col):
if flex_select_nb(log_, i, col):
skip = False
break
if not np.isnan(main_info["size"][col]):
skip = False
break
if not skip:
# Check bar zone and update valuation price
bar_zone = -1
same_bar_zone = True
same_timing = True
for c in range(group_len):
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
if bar_zone == -1:
bar_zone = main_info["bar_zone"][col]
if main_info["bar_zone"][col] != bar_zone:
same_bar_zone = False
same_timing = False
if main_info["bar_zone"][col] == BarZone.Middle:
same_timing = False
_val_price = main_info["val_price"][col]
if not np.isnan(_val_price) or not ffill_val_price:
last_val_price[col] = _val_price
if cash_sharing:
# Dynamically sort by order value -> selling comes first to release funds early
if call_seq is None:
for c in range(group_len):
temp_call_seq[c] = c
call_seq_now = temp_call_seq[:group_len]
else:
call_seq_now = call_seq[i, from_col:to_col]
if auto_call_seq:
# Sort by order value
if not same_timing:
raise ValueError("Cannot sort orders by value if they are executed at different times")
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
# Approximate order value
exec_state = ExecState(
cash=last_cash[group] if cash_sharing else last_cash[col],
position=last_position[col],
debt=last_debt[col],
locked_cash=last_locked_cash[col],
free_cash=last_free_cash[group] if cash_sharing else last_free_cash[col],
val_price=last_val_price[col],
value=last_value[group] if cash_sharing else last_value[col],
)
temp_sort_by[c] = approx_order_value_nb(
exec_state=exec_state,
size=main_info["size"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
)
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
else:
if not same_bar_zone:
# Sort by bar zone
for c in range(group_len):
if call_seq_now[c] != c:
raise ValueError("Call sequence must follow CallSeqType.Default")
col = from_col + c
if np.isnan(main_info["size"][col]):
continue
temp_sort_by[c] = main_info["bar_zone"][col]
insert_argsort_nb(temp_sort_by[:group_len], call_seq_now)
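# call_seq_now now holds the order in which the group's columns are processed below:
# with auto_call_seq, ascending approximate order value puts sells (negative values)
# first to free up cash; otherwise, columns are ordered by bar zone when zones differ
# (presumably open before middle before close).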
for k in range(group_len):
if cash_sharing:
c = call_seq_now[k]
if c >= group_len:
raise ValueError("Call index out of bounds of the group")
else:
c = k
col = from_col + c
if skip_empty and np.isnan(main_info["size"][col]): # shortcut
continue
# Get current values per column
position_before = position_now = last_position[col]
debt_before = debt_now = last_debt[col]
locked_cash_before = locked_cash_now = last_locked_cash[col]
val_price_before = val_price_now = last_val_price[col]
cash_before = cash_now = last_cash[group] if cash_sharing else last_cash[col]
free_cash_before = free_cash_now = (
last_free_cash[group] if cash_sharing else last_free_cash[col]
)
value_before = value_now = last_value[group] if cash_sharing else last_value[col]
return_before = return_now = last_return[group] if cash_sharing else last_return[col]
# Generate the next order
_i = main_info["idx"][col]
if main_info["type"][col] == OrderType.Limit:
_slippage = 0.0
else:
_slippage = float(flex_select_nb(slippage_, _i, col))
_min_size = flex_select_nb(min_size_, _i, col)
_max_size = flex_select_nb(max_size_, _i, col)
_size_type = flex_select_nb(size_type_, _i, col)
if _size_type != main_info["size_type"][col]:
if not np.isnan(_min_size):
_min_size, _ = resolve_size_nb(
size=_min_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
if not np.isnan(_max_size):
_max_size, _ = resolve_size_nb(
size=_max_size,
size_type=_size_type,
position=position_now,
val_price=val_price_now,
value=value_now,
target_size_type=main_info["size_type"][col],
as_requirement=True,
)
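# The min/max size bounds above are specified in the originally requested size type;
# when that differs from the generated order's size type, they are converted (as
# requirements) so they can constrain the order built below.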
order = order_nb(
size=main_info["size"][col],
price=main_info["price"][col],
size_type=main_info["size_type"][col],
direction=main_info["direction"][col],
fees=flex_select_nb(fees_, _i, col),
fixed_fees=flex_select_nb(fixed_fees_, _i, col),
slippage=_slippage,
min_size=_min_size,
max_size=_max_size,
size_granularity=flex_select_nb(size_granularity_, _i, col),
leverage=flex_select_nb(leverage_, _i, col),
leverage_mode=flex_select_nb(leverage_mode_, _i, col),
reject_prob=flex_select_nb(reject_prob_, _i, col),
price_area_vio_mode=flex_select_nb(price_area_vio_mode_, _i, col),
allow_partial=flex_select_nb(allow_partial_, _i, col),
raise_reject=flex_select_nb(raise_reject_, _i, col),
log=flex_select_nb(log_, _i, col),
)
# Process the order
price_area = PriceArea(
open=flex_select_nb(open_, i, col),
high=flex_select_nb(high_, i, col),
low=flex_select_nb(low_, i, col),
close=flex_select_nb(close_, i, col),
)
exec_state = ExecState(
cash=cash_now,
position=position_now,
debt=debt_now,
locked_cash=locked_cash_now,
free_cash=free_cash_now,
val_price=val_price_now,
value=value_now,
)
order_result, new_exec_state = process_order_nb(
group=group,
col=col,
i=i,
exec_state=exec_state,
order=order,
price_area=price_area,
update_value=update_value,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
)
# Append more order information
if order_result.status == OrderStatus.Filled and order_counts[col] >= 1:
order_records["signal_idx"][order_counts[col] - 1, col] = main_info["signal_idx"][col]
order_records["creation_idx"][order_counts[col] - 1, col] = main_info["creation_idx"][col]
order_records["type"][order_counts[col] - 1, col] = main_info["type"][col]
order_records["stop_type"][order_counts[col] - 1, col] = main_info["stop_type"][col]
# Update execution state
cash_now = new_exec_state.cash
position_now = new_exec_state.position
debt_now = new_exec_state.debt
locked_cash_now = new_exec_state.locked_cash
free_cash_now = new_exec_state.free_cash
val_price_now = new_exec_state.val_price
value_now = new_exec_state.value
# Update position record
if fill_pos_info:
if order_result.status == OrderStatus.Filled:
if order_counts[col] > 0:
order_id = order_records["id"][order_counts[col] - 1, col]
else:
order_id = -1
update_pos_info_nb(
last_pos_info[col],
i,
col,
exec_state.position,
position_now,
order_result,
order_id,
)
if use_stops:
# Update stop price
if position_now == 0:
# Not in position anymore -> clear stops (irrespective of order success)
last_sl_info["init_idx"][col] = -1
last_sl_info["init_price"][col] = np.nan
last_sl_info["init_position"][col] = np.nan
last_sl_info["stop"][col] = np.nan
last_sl_info["exit_price"][col] = -1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = -1
last_sl_info["order_type"][col] = -1
last_sl_info["limit_delta"][col] = np.nan
last_sl_info["delta_format"][col] = -1
last_sl_info["ladder"][col] = -1
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
last_tsl_info["init_idx"][col] = -1
last_tsl_info["init_price"][col] = np.nan
last_tsl_info["init_position"][col] = np.nan
last_tsl_info["peak_idx"][col] = -1
last_tsl_info["peak_price"][col] = np.nan
last_tsl_info["stop"][col] = np.nan
last_tsl_info["th"][col] = np.nan
last_tsl_info["exit_price"][col] = -1
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = -1
last_tsl_info["order_type"][col] = -1
last_tsl_info["limit_delta"][col] = np.nan
last_tsl_info["delta_format"][col] = -1
last_tsl_info["ladder"][col] = -1
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
last_tp_info["init_idx"][col] = -1
last_tp_info["init_price"][col] = np.nan
last_tp_info["init_position"][col] = np.nan
last_tp_info["stop"][col] = np.nan
last_tp_info["exit_price"][col] = -1
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = -1
last_tp_info["order_type"][col] = -1
last_tp_info["limit_delta"][col] = np.nan
last_tp_info["delta_format"][col] = -1
last_tp_info["ladder"][col] = -1
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
last_td_info["init_idx"][col] = -1
last_td_info["init_position"][col] = np.nan
last_td_info["stop"][col] = -1
last_td_info["exit_price"][col] = -1
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = -1
last_td_info["order_type"][col] = -1
last_td_info["limit_delta"][col] = np.nan
last_td_info["delta_format"][col] = -1
last_td_info["time_delta_format"][col] = -1
last_td_info["ladder"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
last_dt_info["init_idx"][col] = -1
last_dt_info["init_position"][col] = np.nan
last_dt_info["stop"][col] = -1
last_dt_info["exit_price"][col] = -1
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = -1
last_dt_info["order_type"][col] = -1
last_dt_info["limit_delta"][col] = np.nan
last_dt_info["delta_format"][col] = -1
last_dt_info["time_delta_format"][col] = -1
last_dt_info["ladder"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
else:
if main_info["stop_type"][col] == StopType.SL:
if last_sl_info["ladder"][col]:
step = last_sl_info["step"][col] + 1
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
if stop_ladder and last_sl_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_sl_steps:
last_sl_info["stop"][col] = flex_select_nb(sl_stop_, step, col)
last_sl_info["step"][col] = step
last_sl_info["step_idx"][col] = i
else:
last_sl_info["stop"][col] = np.nan
last_sl_info["step"][col] = -1
last_sl_info["step_idx"][col] = -1
else:
last_sl_info["stop"][col] = np.nan
last_sl_info["step"][col] = step
last_sl_info["step_idx"][col] = i
elif (
main_info["stop_type"][col] == StopType.TSL
or main_info["stop_type"][col] == StopType.TTP
):
if last_tsl_info["ladder"][col]:
step = last_tsl_info["step"][col] + 1
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
if stop_ladder and last_tsl_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_tsl_steps:
last_tsl_info["stop"][col] = flex_select_nb(tsl_stop_, step, col)
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
else:
last_tsl_info["stop"][col] = np.nan
last_tsl_info["step"][col] = -1
last_tsl_info["step_idx"][col] = -1
else:
last_tsl_info["stop"][col] = np.nan
last_tsl_info["step"][col] = step
last_tsl_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.TP:
if last_tp_info["ladder"][col]:
step = last_tp_info["step"][col] + 1
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
if stop_ladder and last_tp_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_tp_steps:
last_tp_info["stop"][col] = flex_select_nb(tp_stop_, step, col)
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
else:
last_tp_info["stop"][col] = np.nan
last_tp_info["step"][col] = -1
last_tp_info["step_idx"][col] = -1
else:
last_tp_info["stop"][col] = np.nan
last_tp_info["step"][col] = step
last_tp_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.TD:
if last_td_info["ladder"][col]:
step = last_td_info["step"][col] + 1
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
if stop_ladder and last_td_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_td_steps:
last_td_info["stop"][col] = flex_select_nb(td_stop_, step, col)
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
else:
last_td_info["stop"][col] = -1
last_td_info["step"][col] = -1
last_td_info["step_idx"][col] = -1
else:
last_td_info["stop"][col] = -1
last_td_info["step"][col] = step
last_td_info["step_idx"][col] = i
elif main_info["stop_type"][col] == StopType.DT:
if last_dt_info["ladder"][col]:
step = last_dt_info["step"][col] + 1
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
if stop_ladder and last_dt_info["ladder"][col] != StopLadderMode.Dynamic:
if step < n_dt_steps:
last_dt_info["stop"][col] = flex_select_nb(dt_stop_, step, col)
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
else:
last_dt_info["stop"][col] = -1
last_dt_info["step"][col] = -1
last_dt_info["step_idx"][col] = -1
else:
last_dt_info["stop"][col] = -1
last_dt_info["step"][col] = step
last_dt_info["step_idx"][col] = i
if order_result.status == OrderStatus.Filled and position_now != 0:
# Order filled and in position -> possibly set stops
_price = main_info["price"][col]
_stop_entry_price = flex_select_nb(stop_entry_price_, i, col)
if _stop_entry_price < 0:
if _stop_entry_price == StopEntryPrice.ValPrice:
new_init_price = val_price_now
can_use_ohlc = False
elif _stop_entry_price == StopEntryPrice.Price:
new_init_price = order.price
can_use_ohlc = np.isinf(_price) and _price < 0
if np.isinf(new_init_price):
if new_init_price > 0:
new_init_price = flex_select_nb(close_, i, col)
else:
new_init_price = flex_select_nb(open_, i, col)
elif _stop_entry_price == StopEntryPrice.FillPrice:
new_init_price = order_result.price
can_use_ohlc = np.isinf(_price) and _price < 0
elif _stop_entry_price == StopEntryPrice.Open:
new_init_price = flex_select_nb(open_, i, col)
can_use_ohlc = True
elif _stop_entry_price == StopEntryPrice.Close:
new_init_price = flex_select_nb(close_, i, col)
can_use_ohlc = False
else:
raise ValueError("Invalid StopEntryPrice option")
else:
new_init_price = _stop_entry_price
can_use_ohlc = False
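# can_use_ohlc, set above, presumably records whether the freshly set stops may be
# evaluated against the remainder of this bar: it is True only when the effective
# entry price is known at or before the open (StopEntryPrice.Open, or an order
# priced at -inf, i.e. at the open).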
if stop_ladder:
_sl_stop = flex_select_nb(sl_stop_, 0, col)
_tsl_stop = flex_select_nb(tsl_stop_, 0, col)
_tp_stop = flex_select_nb(tp_stop_, 0, col)
_td_stop = flex_select_nb(td_stop_, 0, col)
_dt_stop = flex_select_nb(dt_stop_, 0, col)
else:
_sl_stop = flex_select_nb(sl_stop_, i, col)
_tsl_stop = flex_select_nb(tsl_stop_, i, col)
_tp_stop = flex_select_nb(tp_stop_, i, col)
_td_stop = flex_select_nb(td_stop_, i, col)
_dt_stop = flex_select_nb(dt_stop_, i, col)
_tsl_th = flex_select_nb(tsl_th_, i, col)
_stop_exit_price = flex_select_nb(stop_exit_price_, i, col)
_stop_exit_type = flex_select_nb(stop_exit_type_, i, col)
_stop_order_type = flex_select_nb(stop_order_type_, i, col)
_stop_limit_delta = flex_select_nb(stop_limit_delta_, i, col)
_delta_format = flex_select_nb(delta_format_, i, col)
_time_delta_format = flex_select_nb(time_delta_format_, i, col)
tsl_updated = False
if exec_state.position == 0 or np.sign(position_now) != np.sign(exec_state.position):
# Position opened/reversed -> set stops
last_sl_info["init_idx"][col] = i
last_sl_info["init_price"][col] = new_init_price
last_sl_info["init_position"][col] = position_now
last_sl_info["stop"][col] = _sl_stop
last_sl_info["exit_price"][col] = _stop_exit_price
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = _stop_exit_type
last_sl_info["order_type"][col] = _stop_order_type
last_sl_info["limit_delta"][col] = _stop_limit_delta
last_sl_info["delta_format"][col] = _delta_format
last_sl_info["ladder"][col] = stop_ladder
last_sl_info["step"][col] = 0
last_sl_info["step_idx"][col] = i
tsl_updated = True
last_tsl_info["init_idx"][col] = i
last_tsl_info["init_price"][col] = new_init_price
last_tsl_info["init_position"][col] = position_now
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = new_init_price
last_tsl_info["stop"][col] = _tsl_stop
last_tsl_info["th"][col] = _tsl_th
last_tsl_info["exit_price"][col] = _stop_exit_price
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = _stop_exit_type
last_tsl_info["order_type"][col] = _stop_order_type
last_tsl_info["limit_delta"][col] = _stop_limit_delta
last_tsl_info["delta_format"][col] = _delta_format
last_tsl_info["ladder"][col] = stop_ladder
last_tsl_info["step"][col] = 0
last_tsl_info["step_idx"][col] = i
last_tp_info["init_idx"][col] = i
last_tp_info["init_price"][col] = new_init_price
last_tp_info["init_position"][col] = position_now
last_tp_info["stop"][col] = _tp_stop
last_tp_info["exit_price"][col] = _stop_exit_price
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = _stop_exit_type
last_tp_info["order_type"][col] = _stop_order_type
last_tp_info["limit_delta"][col] = _stop_limit_delta
last_tp_info["delta_format"][col] = _delta_format
last_tp_info["ladder"][col] = stop_ladder
last_tp_info["step"][col] = 0
last_tp_info["step_idx"][col] = i
last_td_info["init_idx"][col] = i
last_td_info["init_position"][col] = position_now
last_td_info["stop"][col] = _td_stop
last_td_info["exit_price"][col] = _stop_exit_price
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = _stop_exit_type
last_td_info["order_type"][col] = _stop_order_type
last_td_info["limit_delta"][col] = _stop_limit_delta
last_td_info["delta_format"][col] = _delta_format
last_td_info["time_delta_format"][col] = _time_delta_format
last_td_info["ladder"][col] = stop_ladder
last_td_info["step"][col] = 0
last_td_info["step_idx"][col] = i
last_dt_info["init_idx"][col] = i
last_dt_info["init_position"][col] = position_now
last_dt_info["stop"][col] = _dt_stop
last_dt_info["exit_price"][col] = _stop_exit_price
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = _stop_exit_type
last_dt_info["order_type"][col] = _stop_order_type
last_dt_info["limit_delta"][col] = _stop_limit_delta
last_dt_info["delta_format"][col] = _delta_format
last_dt_info["time_delta_format"][col] = _time_delta_format
last_dt_info["ladder"][col] = stop_ladder
last_dt_info["step"][col] = 0
last_dt_info["step_idx"][col] = i
elif abs(position_now) > abs(exec_state.position):
# Position increased -> keep/override stops
_upon_stop_update = flex_select_nb(upon_stop_update_, i, col)
if should_update_stop_nb(new_stop=_sl_stop, upon_stop_update=_upon_stop_update):
last_sl_info["init_idx"][col] = i
last_sl_info["init_price"][col] = new_init_price
last_sl_info["init_position"][col] = position_now
last_sl_info["stop"][col] = _sl_stop
last_sl_info["exit_price"][col] = _stop_exit_price
last_sl_info["exit_size"][col] = np.nan
last_sl_info["exit_size_type"][col] = -1
last_sl_info["exit_type"][col] = _stop_exit_type
last_sl_info["order_type"][col] = _stop_order_type
last_sl_info["limit_delta"][col] = _stop_limit_delta
last_sl_info["delta_format"][col] = _delta_format
last_sl_info["ladder"][col] = stop_ladder
last_sl_info["step"][col] = 0
last_sl_info["step_idx"][col] = i
if should_update_stop_nb(new_stop=_tsl_stop, upon_stop_update=_upon_stop_update):
tsl_updated = True
last_tsl_info["init_idx"][col] = i
last_tsl_info["init_price"][col] = new_init_price
last_tsl_info["init_position"][col] = position_now
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = new_init_price
last_tsl_info["stop"][col] = _tsl_stop
last_tsl_info["th"][col] = _tsl_th
last_tsl_info["exit_price"][col] = _stop_exit_price
last_tsl_info["exit_size"][col] = np.nan
last_tsl_info["exit_size_type"][col] = -1
last_tsl_info["exit_type"][col] = _stop_exit_type
last_tsl_info["order_type"][col] = _stop_order_type
last_tsl_info["limit_delta"][col] = _stop_limit_delta
last_tsl_info["delta_format"][col] = _delta_format
last_tsl_info["ladder"][col] = stop_ladder
last_tsl_info["step"][col] = 0
last_tsl_info["step_idx"][col] = i
if should_update_stop_nb(new_stop=_tp_stop, upon_stop_update=_upon_stop_update):
last_tp_info["init_idx"][col] = i
last_tp_info["init_price"][col] = new_init_price
last_tp_info["init_position"][col] = position_now
last_tp_info["stop"][col] = _tp_stop
last_tp_info["exit_price"][col] = _stop_exit_price
last_tp_info["exit_size"][col] = np.nan
last_tp_info["exit_size_type"][col] = -1
last_tp_info["exit_type"][col] = _stop_exit_type
last_tp_info["order_type"][col] = _stop_order_type
last_tp_info["limit_delta"][col] = _stop_limit_delta
last_tp_info["delta_format"][col] = _delta_format
last_tp_info["ladder"][col] = stop_ladder
last_tp_info["step"][col] = 0
last_tp_info["step_idx"][col] = i
if should_update_time_stop_nb(
new_stop=_td_stop, upon_stop_update=_upon_stop_update
):
last_td_info["init_idx"][col] = i
last_td_info["init_position"][col] = position_now
last_td_info["stop"][col] = _td_stop
last_td_info["exit_price"][col] = _stop_exit_price
last_td_info["exit_size"][col] = np.nan
last_td_info["exit_size_type"][col] = -1
last_td_info["exit_type"][col] = _stop_exit_type
last_td_info["order_type"][col] = _stop_order_type
last_td_info["limit_delta"][col] = _stop_limit_delta
last_td_info["delta_format"][col] = _delta_format
last_td_info["time_delta_format"][col] = _time_delta_format
last_td_info["ladder"][col] = stop_ladder
last_td_info["step"][col] = 0
last_td_info["step_idx"][col] = i
if should_update_time_stop_nb(
new_stop=_dt_stop, upon_stop_update=_upon_stop_update
):
last_dt_info["init_idx"][col] = i
last_dt_info["init_position"][col] = position_now
last_dt_info["stop"][col] = _dt_stop
last_dt_info["exit_price"][col] = _stop_exit_price
last_dt_info["exit_size"][col] = np.nan
last_dt_info["exit_size_type"][col] = -1
last_dt_info["exit_type"][col] = _stop_exit_type
last_dt_info["order_type"][col] = _stop_order_type
last_dt_info["limit_delta"][col] = _stop_limit_delta
last_dt_info["delta_format"][col] = _delta_format
last_dt_info["time_delta_format"][col] = _time_delta_format
last_dt_info["ladder"][col] = stop_ladder
last_dt_info["step"][col] = 0
last_dt_info["step_idx"][col] = i
if tsl_updated:
# Update highest/lowest price
if can_use_ohlc:
_open = flex_select_nb(open_, i, col)
_high = flex_select_nb(high_, i, col)
_low = flex_select_nb(low_, i, col)
_close = flex_select_nb(close_, i, col)
_high, _low = resolve_hl_nb(
open=_open,
high=_high,
low=_low,
close=_close,
)
else:
_open = np.nan
_high = _low = _close = flex_select_nb(close_, i, col)
if tsl_updated:
if position_now > 0:
if _high > last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col]
+ _high
- last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _high
elif position_now < 0:
if _low < last_tsl_info["peak_price"][col]:
if last_tsl_info["delta_format"][col] == DeltaFormat.Target:
last_tsl_info["stop"][col] = (
last_tsl_info["stop"][col]
+ _low
- last_tsl_info["peak_price"][col]
)
last_tsl_info["peak_idx"][col] = i
last_tsl_info["peak_price"][col] = _low
# Now becomes last
last_position[col] = position_now
last_debt[col] = debt_now
last_locked_cash[col] = locked_cash_now
if not np.isnan(val_price_now) or not ffill_val_price:
last_val_price[col] = val_price_now
if cash_sharing:
last_cash[group] = cash_now
last_free_cash[group] = free_cash_now
last_value[group] = value_now
last_return[group] = return_now
else:
last_cash[col] = cash_now
last_free_cash[col] = free_cash_now
last_value[col] = value_now
last_return[col] = return_now
# Call post-signal function
post_signal_ctx = PostSignalContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
track_cash_deposits=track_cash_deposits,
cash_deposits_out=cash_deposits_out,
track_cash_earnings=track_cash_earnings,
cash_earnings_out=cash_earnings_out,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
last_limit_info=last_limit_info,
last_sl_info=last_sl_info,
last_tsl_info=last_tsl_info,
last_tp_info=last_tp_info,
last_td_info=last_td_info,
last_dt_info=last_dt_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
col=col,
cash_before=cash_before,
position_before=position_before,
debt_before=debt_before,
locked_cash_before=locked_cash_before,
free_cash_before=free_cash_before,
val_price_before=val_price_before,
value_before=value_before,
order_result=order_result,
)
post_signal_func_nb(post_signal_ctx, *post_signal_args)
for col in range(from_col, to_col):
# Update valuation price using current close
_close = flex_select_nb(close_, i, col)
if not np.isnan(_close) or not ffill_val_price:
last_val_price[col] = _close
_cash_earnings = flex_select_nb(cash_earnings_, i, col)
_cash_dividends = flex_select_nb(cash_dividends_, i, col)
_cash_earnings += _cash_dividends * last_position[col]
if cash_sharing:
last_cash[group] += _cash_earnings
last_free_cash[group] += _cash_earnings
else:
last_cash[col] += _cash_earnings
last_free_cash[col] += _cash_earnings
if track_cash_earnings:
cash_earnings_out[i, col] += _cash_earnings
# Update value and return
if cash_sharing:
group_value = last_cash[group]
for col in range(from_col, to_col):
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[group] = group_value
last_return[group] = get_return_nb(
input_value=prev_close_value[group],
output_value=last_value[group] - last_cash_deposits[group],
)
prev_close_value[group] = last_value[group]
else:
for col in range(from_col, to_col):
group_value = last_cash[col]
if last_position[col] != 0:
group_value += last_position[col] * last_val_price[col]
last_value[col] = group_value
last_return[col] = get_return_nb(
input_value=prev_close_value[col],
output_value=last_value[col] - last_cash_deposits[col],
)
prev_close_value[col] = last_value[col]
# Update open position stats
if fill_pos_info:
for col in range(from_col, to_col):
update_open_pos_info_stats_nb(last_pos_info[col], last_position[col], last_val_price[col])
# Call post-segment function
post_segment_ctx = SignalSegmentContext(
target_shape=target_shape,
group_lens=group_lens,
cash_sharing=cash_sharing,
index=index,
freq=freq,
open=open_,
high=high_,
low=low_,
close=close_,
init_cash=init_cash_,
init_position=init_position_,
init_price=init_price_,
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
track_cash_deposits=track_cash_deposits,
cash_deposits_out=cash_deposits_out,
track_cash_earnings=track_cash_earnings,
cash_earnings_out=cash_earnings_out,
in_outputs=in_outputs,
last_cash=last_cash,
last_position=last_position,
last_debt=last_debt,
last_locked_cash=last_locked_cash,
last_free_cash=last_free_cash,
last_val_price=last_val_price,
last_value=last_value,
last_return=last_return,
last_pos_info=last_pos_info,
last_limit_info=last_limit_info,
last_sl_info=last_sl_info,
last_tsl_info=last_tsl_info,
last_tp_info=last_tp_info,
last_td_info=last_td_info,
last_dt_info=last_dt_info,
sim_start=sim_start_,
sim_end=sim_end_,
group=group,
group_len=group_len,
from_col=from_col,
to_col=to_col,
i=i,
)
post_segment_func_nb(post_segment_ctx, *post_segment_args)
if i >= sim_end_[group] - 1:
break
sim_start_out, sim_end_out = generic_nb.resolve_ungrouped_sim_range_nb(
target_shape=target_shape,
group_lens=group_lens,
sim_start=sim_start_,
sim_end=sim_end_,
allow_none=True,
)
return prepare_sim_out_nb(
order_records=order_records,
order_counts=order_counts,
log_records=log_records,
log_counts=log_counts,
cash_deposits=cash_deposits_out,
cash_earnings=cash_earnings_out,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start_out,
sim_end=sim_end_out,
)
# %
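# Note on the trailing-stop update in the simulation loop above (comments only): when
# `delta_format` is `DeltaFormat.Target`, the stop is stored as an absolute target price, so it
# is shifted by the change in the peak price; for example, a peak moving from 100 to 105 moves a
# target stop of 95 to 100. For relative delta formats only the peak index and peak price are
# updated and the stored stop value stays unchanged.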
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
entries=base_ch.FlexArraySlicer(axis=1),
exits=base_ch.FlexArraySlicer(axis=1),
direction=base_ch.FlexArraySlicer(axis=1),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def dir_to_ls_signals_nb(
target_shape: tp.Shape,
entries: tp.FlexArray2d,
exits: tp.FlexArray2d,
direction: tp.FlexArray2d,
) -> tp.Tuple[tp.Array2d, tp.Array2d, tp.Array2d, tp.Array2d]:
"""Convert direction-unaware to direction-aware signals."""
long_entries_out = np.empty(target_shape, dtype=np.bool_)
long_exits_out = np.empty(target_shape, dtype=np.bool_)
short_entries_out = np.empty(target_shape, dtype=np.bool_)
short_exits_out = np.empty(target_shape, dtype=np.bool_)
for col in prange(target_shape[1]):
for i in range(target_shape[0]):
is_entry = flex_select_nb(entries, i, col)
is_exit = flex_select_nb(exits, i, col)
_direction = flex_select_nb(direction, i, col)
if _direction == Direction.LongOnly:
long_entries_out[i, col] = is_entry
long_exits_out[i, col] = is_exit
short_entries_out[i, col] = False
short_exits_out[i, col] = False
elif _direction == Direction.ShortOnly:
long_entries_out[i, col] = False
long_exits_out[i, col] = False
short_entries_out[i, col] = is_entry
short_exits_out[i, col] = is_exit
else:
long_entries_out[i, col] = is_entry
long_exits_out[i, col] = False
short_entries_out[i, col] = is_exit
short_exits_out[i, col] = False
return long_entries_out, long_exits_out, short_entries_out, short_exits_out
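# Usage sketch for `dir_to_ls_signals_nb` (not part of the original module). It assumes the
# function is exposed as `vbt.pf_nb.dir_to_ls_signals_nb`, as in the usage examples elsewhere in
# this package; treat that namespace as an assumption.
# >>> target_shape = (3, 1)
# >>> entries = np.array([[True], [False], [True]])
# >>> exits = np.array([[False], [True], [False]])
# >>> direction = np.array([[Direction.LongOnly]])
# >>> long_en, long_ex, short_en, short_ex = vbt.pf_nb.dir_to_ls_signals_nb(
# ...     target_shape, entries, exits, direction)
# With `Direction.LongOnly`, `long_en`/`long_ex` mirror `entries`/`exits` and the short arrays
# stay False; with `Direction.Both`, entries become long entries and exits become short entries.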
AdjustFuncT = tp.Callable[[SignalContext, tp.VarArg()], None]
@register_jitted
def no_adjust_func_nb(c: SignalContext, *args) -> None:
"""Placeholder `adjust_func_nb` that does nothing."""
return None
# %
# %
# %
# @register_jitted
# def adjust_func_nb(
# c: SignalContext,
# ) -> None:
# """Custom adjustment function."""
# return None
#
#
# %
# %
# %
# %
# % blocks["adjust_func_nb"]
@register_jitted
def holding_enex_signal_func_nb( # % line.replace("holding_enex_signal_func_nb", "signal_func_nb")
c: SignalContext,
direction: int,
close_at_end: bool,
adjust_func_nb: AdjustFuncT = no_adjust_func_nb, # % None
adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
"""`signal_func_nb` that returns direction-aware signals from holding."""
adjust_func_nb(c, *adjust_args)
if c.last_position[c.col] == 0:
if c.order_counts[c.col] == 0:
if direction == Direction.ShortOnly:
return False, False, True, False
return True, False, False, False
elif close_at_end and c.i == c.target_shape[0] - 1:
if c.last_position[c.col] < 0:
return False, False, False, True
return False, True, False, False
return False, False, False, False
# %
# %
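# Worked trace of `holding_enex_signal_func_nb` (comments only, derived from the logic above):
# while the column is flat and no order has been filled yet, it returns a single entry in the
# requested direction (a long entry for `Direction.LongOnly` or `Direction.Both`, a short entry
# for `Direction.ShortOnly`); while a position is open it returns no signals; and if
# `close_at_end` is True, it returns the matching exit at the last row of the target shape.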
# % blocks["adjust_func_nb"]
@register_jitted
def dir_signal_func_nb( # % line.replace("dir_signal_func_nb", "signal_func_nb")
c: SignalContext,
entries: tp.FlexArray2d,
exits: tp.FlexArray2d,
direction: tp.FlexArray2d,
from_ago: tp.FlexArray2d,
adjust_func_nb: AdjustFuncT = no_adjust_func_nb, # % None
adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
"""`signal_func_nb` that converts entries, exits, and direction into direction-aware signals.
The direction of each pair of signals is taken from `direction` argument:
* True, True, `Direction.LongOnly` -> True, True, False, False
* True, True, `Direction.ShortOnly` -> False, False, True, True
* True, True, `Direction.Both` -> True, False, True, False
Best to use when the direction doesn't change throughout time.
Prior to returning the signals, calls user-defined `adjust_func_nb`, which can be used to adjust
stop values in the context. Must accept `vectorbtpro.portfolio.enums.SignalContext` and `*adjust_args`,
and return nothing."""
adjust_func_nb(c, *adjust_args)
_i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
if _i < 0:
return False, False, False, False
is_entry = flex_select_nb(entries, _i, c.col)
is_exit = flex_select_nb(exits, _i, c.col)
_direction = flex_select_nb(direction, _i, c.col)
if _direction == Direction.LongOnly:
return is_entry, is_exit, False, False
if _direction == Direction.ShortOnly:
return False, False, is_entry, is_exit
return is_entry, False, is_exit, False
# %
# %
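# Example `adjust_func_nb` (hypothetical, not part of the original module): it tightens the
# stop-loss of the current column before the signals above are evaluated. Field access mirrors
# the `last_sl_info` record array used throughout this module.
# @register_jitted
# def tighten_sl_adjust_func_nb(c: SignalContext, new_sl_stop: float) -> None:
#     """Lower the stop-loss of the current column if the new stop is tighter."""
#     curr_stop = c.last_sl_info["stop"][c.col]
#     if np.isnan(curr_stop) or new_sl_stop < curr_stop:
#         c.last_sl_info["stop"][c.col] = new_sl_stop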
# % blocks["adjust_func_nb"]
@register_jitted
def ls_signal_func_nb( # % line.replace("ls_signal_func_nb", "signal_func_nb")
c: SignalContext,
long_entries: tp.FlexArray2d,
long_exits: tp.FlexArray2d,
short_entries: tp.FlexArray2d,
short_exits: tp.FlexArray2d,
from_ago: tp.FlexArray2d,
adjust_func_nb: AdjustFuncT = no_adjust_func_nb, # % None
adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
"""`signal_func_nb` that gets an element of direction-aware signals.
The direction is already built into the arrays. Best to use when the direction changes frequently
(for example, if you have one indicator providing long signals and one providing short signals).
Prior to returning the signals, calls user-defined `adjust_func_nb`, which can be used to adjust
stop values in the context. Must accept `vectorbtpro.portfolio.enums.SignalContext` and `*adjust_args`,
and return nothing."""
adjust_func_nb(c, *adjust_args)
_i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
if _i < 0:
return False, False, False, False
is_long_entry = flex_select_nb(long_entries, _i, c.col)
is_long_exit = flex_select_nb(long_exits, _i, c.col)
is_short_entry = flex_select_nb(short_entries, _i, c.col)
is_short_exit = flex_select_nb(short_exits, _i, c.col)
return is_long_entry, is_long_exit, is_short_entry, is_short_exit
# %
# %
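# Note on the `from_ago` shift used by the signal functions above (comments only): the signals
# for row `c.i` are read from row `c.i - abs(from_ago[i, col])`, so with `c.i = 5` and a
# `from_ago` of 1 the element at row 4 is used, and rows whose shifted index is negative produce
# no signals.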
# % blocks["adjust_func_nb"]
@register_jitted
def order_signal_func_nb( # % line.replace("order_signal_func_nb", "signal_func_nb")
c: SignalContext,
size: tp.FlexArray2d,
price: tp.FlexArray2d,
size_type: tp.FlexArray2d,
direction: tp.FlexArray2d,
min_size: tp.FlexArray2d,
max_size: tp.FlexArray2d,
val_price: tp.FlexArray2d,
from_ago: tp.FlexArray2d,
adjust_func_nb: AdjustFuncT = no_adjust_func_nb, # % None
adjust_args: tp.Args = (),
) -> tp.Tuple[bool, bool, bool, bool]:
"""`signal_func_nb` that converts orders into direction-aware signals.
You must ensure that `size`, `size_type`, `min_size`, and `max_size` are writeable non-flexible arrays
and accumulation is enabled."""
adjust_func_nb(c, *adjust_args)
_i = c.i - abs(flex_select_nb(from_ago, c.i, c.col))
if _i < 0:
return False, False, False, False
order_size = float(flex_select_nb(size, _i, c.col))
if np.isnan(order_size):
return False, False, False, False
order_size_type = int(flex_select_nb(size_type, _i, c.col))
order_direction = int(flex_select_nb(direction, _i, c.col))
min_order_size = float(flex_select_nb(min_size, _i, c.col))
max_order_size = float(flex_select_nb(max_size, _i, c.col))
order_size = get_diraware_size_nb(order_size, order_direction)
if (
order_size_type == SizeType.TargetAmount
or order_size_type == SizeType.TargetValue
or order_size_type == SizeType.TargetPercent
or order_size_type == SizeType.TargetPercent100
):
order_price = flex_select_nb(price, _i, c.col)
order_val_price = flex_select_nb(val_price, c.i, c.col)
if np.isinf(order_val_price) and order_val_price > 0:
if np.isinf(order_price) and order_price > 0:
order_val_price = flex_select_nb(c.close, c.i, c.col)
elif np.isinf(order_price) and order_price < 0:
order_val_price = flex_select_nb(c.open, c.i, c.col)
else:
order_val_price = order_price
elif np.isnan(order_val_price) or (np.isinf(order_val_price) and order_val_price < 0):
order_val_price = c.last_val_price[c.col]
order_size, _ = resolve_size_nb(
size=order_size,
size_type=order_size_type,
position=c.last_position[c.col],
val_price=order_val_price,
value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
)
if not np.isnan(min_order_size):
min_order_size, _ = resolve_size_nb(
size=min_order_size,
size_type=order_size_type,
position=c.last_position[c.col],
val_price=order_val_price,
value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
as_requirement=True,
)
if not np.isnan(max_order_size):
max_order_size, _ = resolve_size_nb(
size=max_order_size,
size_type=order_size_type,
position=c.last_position[c.col],
val_price=order_val_price,
value=c.last_value[c.group] if c.cash_sharing else c.last_value[c.col],
as_requirement=True,
)
order_size_type = SizeType.Amount
size[_i, c.col] = abs(order_size)
size_type[_i, c.col] = order_size_type
min_size[_i, c.col] = min_order_size
max_size[_i, c.col] = max_order_size
else:
size[_i, c.col] = abs(order_size)
if order_size > 0:
if c.last_position[c.col] < 0 and order_direction == Direction.ShortOnly:
return False, False, False, True
return True, False, False, False
if order_size < 0:
if c.last_position[c.col] > 0 and order_direction == Direction.LongOnly:
return False, True, False, False
return False, False, True, False
return False, False, False, False
# %
# %
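# Worked example for the target resolution in `order_signal_func_nb` above (comments only,
# assuming `resolve_size_nb` resolves targets against the current position as elsewhere in this
# package): with `SizeType.TargetAmount`, a target of 5 units and a current position of 2 units
# resolve to an order of +3 units, which is returned as a long entry; a target of 0 resolves to
# -2 units, which is returned as a long exit. The resolved absolute size and `SizeType.Amount`
# are then written back into the `size` and `size_type` arrays, which is why those arrays must
# be writeable and non-flexible.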
# % blocks["post_segment_func_nb"]
@register_jitted
def save_post_segment_func_nb( # % line.replace("save_post_segment_func_nb", "post_segment_func_nb")
c: SignalSegmentContext,
save_state: bool = True,
save_value: bool = True,
save_returns: bool = True,
) -> None:
"""`post_segment_func_nb` that saves state, value, and returns."""
if save_state:
for col in range(c.from_col, c.to_col):
c.in_outputs.position[c.i, col] = c.last_position[col]
c.in_outputs.debt[c.i, col] = c.last_debt[col]
c.in_outputs.locked_cash[c.i, col] = c.last_locked_cash[col]
if c.cash_sharing:
c.in_outputs.cash[c.i, c.group] = c.last_cash[c.group]
c.in_outputs.free_cash[c.i, c.group] = c.last_free_cash[c.group]
else:
for col in range(c.from_col, c.to_col):
c.in_outputs.cash[c.i, col] = c.last_cash[col]
c.in_outputs.free_cash[c.i, col] = c.last_free_cash[col]
if save_value:
if c.cash_sharing:
c.in_outputs.value[c.i, c.group] = c.last_value[c.group]
else:
for col in range(c.from_col, c.to_col):
c.in_outputs.value[c.i, col] = c.last_value[col]
if save_returns:
if c.cash_sharing:
c.in_outputs.returns[c.i, c.group] = c.last_return[c.group]
else:
for col in range(c.from_col, c.to_col):
c.in_outputs.returns[c.i, col] = c.last_return[col]
# %
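# Usage sketch for `save_post_segment_func_nb` (hypothetical, not part of the original module):
# the `in_outputs` named tuple passed to the simulation must expose 2-dim float arrays named
# `position`, `debt`, `locked_cash`, `cash`, `free_cash`, `value`, and `returns` (depending on
# the save_* flags); cash, value, and return arrays are written per group when cash sharing is
# enabled and per column otherwise.
# >>> from collections import namedtuple
# >>> InOutputs = namedtuple("InOutputs", ["position", "debt", "locked_cash", "cash", "free_cash", "value", "returns"])
# >>> in_outputs = InOutputs(*(np.full(target_shape, np.nan) for _ in range(7)))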
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for iterative portfolio simulation."""
from vectorbtpro import _typing as tp
from vectorbtpro.base.flex_indexing import flex_select_nb
from vectorbtpro.generic.nb.iter_ import (
iter_above_nb as _iter_above_nb,
iter_below_nb as _iter_below_nb,
iter_crossed_above_nb as _iter_crossed_above_nb,
iter_crossed_below_nb as _iter_crossed_below_nb,
)
from vectorbtpro.registries.jit_registry import register_jitted
@register_jitted
def select_nb(
c: tp.NamedTuple,
arr: tp.FlexArray2d,
i: tp.Optional[int] = None,
col: tp.Optional[int] = None,
) -> tp.Scalar:
"""Get the current element using flexible indexing.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
if col is None:
_col = c.col
else:
_col = col
return flex_select_nb(arr, _i, _col)
@register_jitted
def select_from_col_nb(
c: tp.NamedTuple,
col: int,
arr: tp.FlexArray2d,
i: tp.Optional[int] = None,
) -> tp.Scalar:
"""Get the current element from a specific column using flexible indexing.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
return flex_select_nb(arr, _i, col)
@register_jitted
def iter_above_nb(
c: tp.NamedTuple,
arr1: tp.FlexArray2d,
arr2: tp.FlexArray2d,
i: tp.Optional[int] = None,
col: tp.Optional[int] = None,
) -> bool:
"""Call `vectorbtpro.generic.nb.iter_.iter_above_nb` on the context.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
if col is None:
_col = c.col
else:
_col = col
return _iter_above_nb(arr1, arr2, _i, _col)
@register_jitted
def iter_below_nb(
c: tp.NamedTuple,
arr1: tp.FlexArray2d,
arr2: tp.FlexArray2d,
i: tp.Optional[int] = None,
col: tp.Optional[int] = None,
) -> bool:
"""Call `vectorbtpro.generic.nb.iter_.iter_below_nb` on the context.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
if col is None:
_col = c.col
else:
_col = col
return _iter_below_nb(arr1, arr2, _i, _col)
@register_jitted
def iter_crossed_above_nb(
c: tp.NamedTuple,
arr1: tp.FlexArray2d,
arr2: tp.FlexArray2d,
i: tp.Optional[int] = None,
col: tp.Optional[int] = None,
) -> bool:
"""Call `vectorbtpro.generic.nb.iter_.iter_crossed_above_nb` on the context.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
if col is None:
_col = c.col
else:
_col = col
return _iter_crossed_above_nb(arr1, arr2, _i, _col)
@register_jitted
def iter_crossed_below_nb(
c: tp.NamedTuple,
arr1: tp.FlexArray2d,
arr2: tp.FlexArray2d,
i: tp.Optional[int] = None,
col: tp.Optional[int] = None,
) -> bool:
"""Call `vectorbtpro.generic.nb.iter_.iter_crossed_below_nb` on the context.
If any of the arguments are None, will use the respective value from the context."""
if i is None:
_i = c.i
else:
_i = i
if col is None:
_col = c.col
else:
_col = col
return _iter_crossed_below_nb(arr1, arr2, _i, _col)
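# Example (hypothetical, not part of the original module): these helpers are meant to be called
# from a user-defined callback that receives a context, for instance a `signal_func_nb` used by
# the signal-based simulator. A minimal sketch, assuming `fast_ma` and `slow_ma` are flexible
# 2-dim arrays passed through the callback's arguments:
# @register_jitted
# def crossover_signal_func_nb(c, fast_ma, slow_ma):
#     long_entry = iter_crossed_above_nb(c, fast_ma, slow_ma)
#     long_exit = iter_crossed_below_nb(c, fast_ma, slow_ma)
#     return long_entry, long_exit, False, False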
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio records."""
from numba import prange
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.reshaping import to_1d_array_nb, to_2d_array_nb
from vectorbtpro.portfolio.nb.core import *
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.math_ import is_close_nb, is_close_or_less_nb, is_less_nb, add_nb
from vectorbtpro.utils.template import Rep
invalid_size_msg = "Encountered an order with size 0 or less"
invalid_price_msg = "Encountered an order with price less than 0"
@register_jitted(cache=True)
def records_within_sim_range_nb(
target_shape: tp.Shape,
records: tp.RecordArray,
col_arr: tp.Array1d,
idx_arr: tp.Array1d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Return records within simulation range."""
out = np.empty(len(records), dtype=records.dtype)
k = 0
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=target_shape,
sim_start=sim_start,
sim_end=sim_end,
)
for r in range(len(records)):
_sim_start = sim_start_[col_arr[r]]
_sim_end = sim_end_[col_arr[r]]
if _sim_start >= _sim_end:
continue
if _sim_start <= idx_arr[r] < _sim_end:
out[k] = records[r]
k += 1
return out[:k]
@register_jitted(cache=True)
def apply_weights_to_orders_nb(
order_records: tp.RecordArray,
col_arr: tp.Array1d,
weights: tp.Array1d,
) -> tp.RecordArray:
"""Apply weights to order records."""
order_records = order_records.copy()
out = np.empty(len(order_records), dtype=order_records.dtype)
k = 0
for r in range(len(order_records)):
order_record = order_records[r]
col = col_arr[r]
if not np.isnan(weights[col]):
order_record["size"] = weights[col] * order_record["size"]
order_record["fees"] = weights[col] * order_record["fees"]
if order_record["size"] != 0:
out[k] = order_record
k += 1
else:
out[k] = order_record
k += 1
return out[:k]
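# Worked example for `apply_weights_to_orders_nb` (comments only): with a weight of 0.5 for a
# column, an order of size 2.0 and fees 0.1 becomes size 1.0 and fees 0.05; orders whose
# weighted size becomes exactly 0 are dropped, columns with a NaN weight are passed through
# unchanged, and order prices are left untouched.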
@register_jitted(cache=True)
def weighted_price_reduce_meta_nb(
idxs: tp.Array1d,
col: int,
size_arr: tp.Array1d,
price_arr: tp.Array1d,
) -> float:
"""Size-weighted price average."""
if len(idxs) == 0:
return np.nan
size_price_sum = 0.0
size_sum = 0.0
for i in range(len(idxs)):
j = idxs[i]
if not np.isnan(size_arr[j]) and not np.isnan(price_arr[j]):
size_price_sum += size_arr[j] * price_arr[j]
size_sum += size_arr[j]
if size_sum == 0:
return np.nan
return size_price_sum / size_sum
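# Worked example for `weighted_price_reduce_meta_nb` (comments only): sizes [1.0, 3.0] and
# prices [10.0, 14.0] give (1 * 10 + 3 * 14) / (1 + 3) = 13.0; elements with a NaN size or
# price are skipped, and a total size of 0 yields NaN.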
@register_jitted(cache=True)
def fill_trade_record_nb(
new_records: tp.RecordArray,
r: int,
col: int,
size: float,
entry_order_id: int,
entry_idx: int,
entry_price: float,
entry_fees: float,
exit_order_id: int,
exit_idx: int,
exit_price: float,
exit_fees: float,
direction: int,
status: int,
parent_id: int,
) -> None:
"""Fill a trade record."""
# Calculate PnL and return
pnl, ret = get_trade_stats_nb(size, entry_price, entry_fees, exit_price, exit_fees, direction)
# Save trade
new_records["id"][r] = r
new_records["col"][r] = col
new_records["size"][r] = size
new_records["entry_order_id"][r] = entry_order_id
new_records["entry_idx"][r] = entry_idx
new_records["entry_price"][r] = entry_price
new_records["entry_fees"][r] = entry_fees
new_records["exit_order_id"][r] = exit_order_id
new_records["exit_idx"][r] = exit_idx
new_records["exit_price"][r] = exit_price
new_records["exit_fees"][r] = exit_fees
new_records["pnl"][r] = pnl
new_records["return"][r] = ret
new_records["direction"][r] = direction
new_records["status"][r] = status
new_records["parent_id"][r] = parent_id
new_records["parent_id"][r] = parent_id
@register_jitted(cache=True)
def fill_entry_trades_in_position_nb(
order_records: tp.RecordArray,
col_map: tp.GroupMap,
col: int,
sim_start: int,
sim_end: int,
first_c: int,
last_c: int,
init_price: float,
first_entry_size: float,
first_entry_fees: float,
exit_idx: int,
exit_size_sum: float,
exit_gross_sum: float,
exit_fees_sum: float,
direction: int,
status: int,
parent_id: int,
new_records: tp.RecordArray,
r: int,
) -> int:
"""Fill entry trades located within a single position.
Returns the next trade id."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
# Iterate over orders located within a single position
for c in range(first_c, last_c + 1):
if c == -1:
entry_order_id = -1
entry_idx = -1
entry_size = first_entry_size
entry_price = init_price
entry_fees = first_entry_fees
else:
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < sim_start or order_record["idx"] >= sim_end:
continue
entry_order_id = order_record["id"]
entry_idx = order_record["idx"]
entry_price = order_record["price"]
order_side = order_record["side"]
# Ignore exit orders
if (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
continue
if c == first_c:
entry_size = first_entry_size
entry_fees = first_entry_fees
else:
entry_size = order_record["size"]
entry_fees = order_record["fees"]
# Take a size-weighted average of exit price
exit_price = exit_gross_sum / exit_size_sum
# Take a fraction of exit fees
size_fraction = entry_size / exit_size_sum
exit_fees = size_fraction * exit_fees_sum
# Fill the record
if status == TradeStatus.Closed:
exit_order_record = order_records[col_idxs[col_start_idxs[col] + last_c]]
if exit_order_record["idx"] < sim_start or exit_order_record["idx"] >= sim_end:
exit_order_id = -1
else:
exit_order_id = exit_order_record["id"]
else:
exit_order_id = -1
fill_trade_record_nb(
new_records,
r,
col,
entry_size,
entry_order_id,
entry_idx,
entry_price,
entry_fees,
exit_order_id,
exit_idx,
exit_price,
exit_fees,
direction,
status,
parent_id,
)
r += 1
return r
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
close=base_ch.FlexArraySlicer(axis=1),
col_map=base_ch.GroupMapSlicer(),
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_entry_trades_nb(
order_records: tp.RecordArray,
close: tp.FlexArray2dLike,
col_map: tp.GroupMap,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Fill entry trade records by aggregating order records.
Entry trade records are buy orders in a long position and sell orders in a short position.
Usage:
```pycon
>>> from vectorbtpro import *
>>> close = order_price = np.array([
... [1, 6],
... [2, 5],
... [3, 4],
... [4, 3],
... [5, 2],
... [6, 1]
... ])
>>> size = np.array([
... [1, -1],
... [0.1, -0.1],
... [-1, 1],
... [-0.1, 0.1],
... [1, -1],
... [-2, 2]
... ])
>>> target_shape = close.shape
>>> group_lens = np.full(target_shape[1], 1)
>>> init_cash = np.full(target_shape[1], 100)
>>> sim_out = vbt.pf_nb.from_orders_nb(
... target_shape,
... group_lens,
... init_cash=init_cash,
... size=size,
... price=close,
... fees=np.asarray([[0.01]]),
... slippage=np.asarray([[0.01]])
... )
>>> col_map = vbt.rec_nb.col_map_nb(sim_out.order_records['col'], target_shape[1])
>>> entry_trade_records = vbt.pf_nb.get_entry_trades_nb(sim_out.order_records, close, col_map)
>>> pd.DataFrame.from_records(entry_trade_records)
id col size entry_order_id entry_idx entry_price entry_fees \\
0 0 0 1.0 0 0 1.01 0.01010
1 1 0 0.1 1 1 2.02 0.00202
2 2 0 1.0 4 4 5.05 0.05050
3 3 0 1.0 5 5 5.94 0.05940
4 0 1 1.0 0 0 5.94 0.05940
5 1 1 0.1 1 1 4.95 0.00495
6 2 1 1.0 4 4 1.98 0.01980
7 3 1 1.0 5 5 1.01 0.01010
exit_order_id exit_idx exit_price exit_fees pnl return \\
0 3 3 3.060000 0.030600 2.009300 1.989406
1 3 3 3.060000 0.003060 0.098920 0.489703
2 5 5 5.940000 0.059400 0.780100 0.154475
3 -1 5 6.000000 0.000000 -0.119400 -0.020101
4 3 3 3.948182 0.039482 1.892936 0.318676
5 3 3 3.948182 0.003948 0.091284 0.184411
6 5 5 1.010000 0.010100 0.940100 0.474798
7 -1 5 1.000000 0.000000 -0.020100 -0.019901
direction status parent_id
0 0 1 0
1 0 1 0
2 0 1 1
3 1 0 2
4 1 1 0
5 1 1 0
6 1 1 1
7 0 0 2
```
"""
close_ = to_2d_array_nb(np.asarray(close))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
max_records = np.max(col_lens) + 1
new_records = np.empty((max_records, len(col_lens)), dtype=trade_dt)
counts = np.full(len(col_lens), 0, dtype=int_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(close_.shape[0], col_lens.shape[0]),
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
# Prepare initial position
first_c = -1
in_position = True
parent_id = 0
if _init_position >= 0:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = abs(_init_position)
entry_gross_sum = abs(_init_position) * _init_price
entry_fees_sum = 0.0
exit_size_sum = 0.0
exit_gross_sum = 0.0
exit_fees_sum = 0.0
first_entry_size = _init_position
first_entry_fees = 0.0
else:
in_position = False
parent_id = -1
col_len = col_lens[col]
if col_len == 0 and not in_position:
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
order_idx = order_record["idx"]
order_size = order_record["size"]
order_price = order_record["price"]
order_fees = order_record["fees"]
order_side = order_record["side"]
if order_size <= 0.0:
raise ValueError(invalid_size_msg)
if order_price < 0.0:
raise ValueError(invalid_price_msg)
if not in_position:
# New position opened
first_c = c
in_position = True
parent_id += 1
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = 0.0
entry_gross_sum = 0.0
entry_fees_sum = 0.0
exit_size_sum = 0.0
exit_gross_sum = 0.0
exit_fees_sum = 0.0
first_entry_size = order_size
first_entry_fees = order_fees
if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
direction == TradeDirection.Short and order_side == OrderSide.Sell
):
# Position increased
entry_size_sum += order_size
entry_gross_sum += order_size * order_price
entry_fees_sum += order_fees
elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
if is_close_nb(exit_size_sum + order_size, entry_size_sum):
# Position closed
last_c = c
in_position = False
exit_size_sum = entry_size_sum
exit_gross_sum += order_size * order_price
exit_fees_sum += order_fees
# Fill trade records
counts[col] = fill_entry_trades_in_position_nb(
order_records,
col_map,
col,
_sim_start,
_sim_end,
first_c,
last_c,
_init_price,
first_entry_size,
first_entry_fees,
order_idx,
exit_size_sum,
exit_gross_sum,
exit_fees_sum,
direction,
TradeStatus.Closed,
parent_id,
new_records[:, col],
counts[col],
)
elif is_less_nb(exit_size_sum + order_size, entry_size_sum):
# Position decreased
exit_size_sum += order_size
exit_gross_sum += order_size * order_price
exit_fees_sum += order_fees
else:
# Position closed
last_c = c
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
exit_size_sum = entry_size_sum
exit_gross_sum += remaining_size * order_price
exit_fees_sum += remaining_size / order_size * order_fees
# Fill trade records
counts[col] = fill_entry_trades_in_position_nb(
order_records,
col_map,
col,
_sim_start,
_sim_end,
first_c,
last_c,
_init_price,
first_entry_size,
first_entry_fees,
order_idx,
exit_size_sum,
exit_gross_sum,
exit_fees_sum,
direction,
TradeStatus.Closed,
parent_id,
new_records[:, col],
counts[col],
)
# New position opened
first_c = c
parent_id += 1
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = add_nb(order_size, -remaining_size)
entry_gross_sum = entry_size_sum * order_price
entry_fees_sum = entry_size_sum / order_size * order_fees
first_entry_size = entry_size_sum
first_entry_fees = entry_fees_sum
exit_size_sum = 0.0
exit_gross_sum = 0.0
exit_fees_sum = 0.0
if in_position and is_less_nb(exit_size_sum, entry_size_sum):
# Position hasn't been closed
last_c = col_len - 1
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
exit_size_sum = entry_size_sum
last_close = flex_select_nb(close_, _sim_end - 1, col)
if np.isnan(last_close):
for ri in range(_sim_end - 1, -1, -1):
_close = flex_select_nb(close_, ri, col)
if not np.isnan(_close):
last_close = _close
break
exit_gross_sum += remaining_size * last_close
exit_idx = _sim_end - 1
# Fill trade records
counts[col] = fill_entry_trades_in_position_nb(
order_records,
col_map,
col,
_sim_start,
_sim_end,
first_c,
last_c,
_init_price,
first_entry_size,
first_entry_fees,
exit_idx,
exit_size_sum,
exit_gross_sum,
exit_fees_sum,
direction,
TradeStatus.Open,
parent_id,
new_records[:, col],
counts[col],
)
return generic_nb.repartition_nb(new_records, counts)
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
close=base_ch.FlexArraySlicer(axis=1),
col_map=base_ch.GroupMapSlicer(),
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_exit_trades_nb(
order_records: tp.RecordArray,
close: tp.FlexArray2dLike,
col_map: tp.GroupMap,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Fill exit trade records by aggregating order records.
Exit trade records are sell orders in a long position and buy orders in a short position.
Usage:
* Building upon the example in `get_entry_trades_nb`:
```pycon
>>> exit_trade_records = vbt.pf_nb.get_exit_trades_nb(sim_out.order_records, close, col_map)
>>> pd.DataFrame.from_records(exit_trade_records)
id col size entry_order_id entry_idx entry_price entry_fees \\
0 0 0 1.0 0 0 1.101818 0.011018
1 1 0 0.1 0 0 1.101818 0.001102
2 2 0 1.0 4 4 5.050000 0.050500
3 3 0 1.0 5 5 5.940000 0.059400
4 0 1 1.0 0 0 5.850000 0.058500
5 1 1 0.1 0 0 5.850000 0.005850
6 2 1 1.0 4 4 1.980000 0.019800
7 3 1 1.0 5 5 1.010000 0.010100
exit_order_id exit_idx exit_price exit_fees pnl return \\
0 2 2 2.97 0.02970 1.827464 1.658589
1 3 3 3.96 0.00396 0.280756 2.548119
2 5 5 5.94 0.05940 0.780100 0.154475
3 -1 5 6.00 0.00000 -0.119400 -0.020101
4 2 2 4.04 0.04040 1.711100 0.292496
5 3 3 3.03 0.00303 0.273120 0.466872
6 5 5 1.01 0.01010 0.940100 0.474798
7 -1 5 1.00 0.00000 -0.020100 -0.019901
direction status parent_id
0 0 1 0
1 0 1 0
2 0 1 1
3 1 0 2
4 1 1 0
5 1 1 0
6 1 1 1
7 0 0 2
```
"""
close_ = to_2d_array_nb(np.asarray(close))
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
max_records = np.max(col_lens) + 1
new_records = np.empty((max_records, len(col_lens)), dtype=trade_dt)
counts = np.full(len(col_lens), 0, dtype=int_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(close_.shape[0], col_lens.shape[0]),
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
# Prepare initial position
in_position = True
parent_id = 0
entry_order_id = -1
entry_idx = -1
if _init_position >= 0:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = abs(_init_position)
entry_gross_sum = abs(_init_position) * _init_price
entry_fees_sum = 0.0
else:
in_position = False
parent_id = -1
col_len = col_lens[col]
if col_len == 0 and not in_position:
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
order_idx = order_record["idx"]
order_id = order_record["id"]
order_size = order_record["size"]
order_price = order_record["price"]
order_fees = order_record["fees"]
order_side = order_record["side"]
if order_size <= 0.0:
raise ValueError(invalid_size_msg)
if order_price < 0.0:
raise ValueError(invalid_price_msg)
if not in_position:
# Trade opened
in_position = True
entry_order_id = order_id
entry_idx = order_idx
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
parent_id += 1
entry_size_sum = 0.0
entry_gross_sum = 0.0
entry_fees_sum = 0.0
if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
direction == TradeDirection.Short and order_side == OrderSide.Sell
):
# Position increased
entry_size_sum += order_size
entry_gross_sum += order_size * order_price
entry_fees_sum += order_fees
elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
if is_close_or_less_nb(order_size, entry_size_sum):
# Trade closed
if is_close_nb(order_size, entry_size_sum):
exit_size = entry_size_sum
else:
exit_size = order_size
exit_price = order_price
exit_fees = order_fees
exit_order_id = order_id
exit_idx = order_idx
# Take a size-weighted average of entry price
entry_price = entry_gross_sum / entry_size_sum
# Take a fraction of entry fees
size_fraction = exit_size / entry_size_sum
entry_fees = size_fraction * entry_fees_sum
fill_trade_record_nb(
new_records[:, col],
counts[col],
col,
exit_size,
entry_order_id,
entry_idx,
entry_price,
entry_fees,
exit_order_id,
exit_idx,
exit_price,
exit_fees,
direction,
TradeStatus.Closed,
parent_id,
)
counts[col] += 1
if is_close_nb(order_size, entry_size_sum):
# Position closed
entry_order_id = -1
entry_idx = -1
direction = -1
in_position = False
else:
# Position decreased, previous orders have now less impact
size_fraction = (entry_size_sum - order_size) / entry_size_sum
entry_size_sum *= size_fraction
entry_gross_sum *= size_fraction
entry_fees_sum *= size_fraction
else:
# Trade reversed
# Close current trade
cl_exit_size = entry_size_sum
cl_exit_price = order_price
cl_exit_fees = cl_exit_size / order_size * order_fees
cl_exit_order_id = order_id
cl_exit_idx = order_idx
# Take a size-weighted average of entry price
entry_price = entry_gross_sum / entry_size_sum
# Take a fraction of entry fees
size_fraction = cl_exit_size / entry_size_sum
entry_fees = size_fraction * entry_fees_sum
fill_trade_record_nb(
new_records[:, col],
counts[col],
col,
cl_exit_size,
entry_order_id,
entry_idx,
entry_price,
entry_fees,
cl_exit_order_id,
cl_exit_idx,
cl_exit_price,
cl_exit_fees,
direction,
TradeStatus.Closed,
parent_id,
)
counts[col] += 1
# Open a new trade
entry_size_sum = order_size - cl_exit_size
entry_gross_sum = entry_size_sum * order_price
entry_fees_sum = order_fees - cl_exit_fees
entry_order_id = order_id
entry_idx = order_idx
if direction == TradeDirection.Long:
direction = TradeDirection.Short
else:
direction = TradeDirection.Long
parent_id += 1
if in_position and is_less_nb(-entry_size_sum, 0):
# Trade hasn't been closed
exit_size = entry_size_sum
last_close = flex_select_nb(close_, _sim_end - 1, col)
if np.isnan(last_close):
for ri in range(_sim_end - 1, -1, -1):
_close = flex_select_nb(close_, ri, col)
if not np.isnan(_close):
last_close = _close
break
exit_price = last_close
exit_fees = 0.0
exit_order_id = -1
exit_idx = _sim_end - 1
# Take a size-weighted average of entry price
entry_price = entry_gross_sum / entry_size_sum
# Take a fraction of entry fees
size_fraction = exit_size / entry_size_sum
entry_fees = size_fraction * entry_fees_sum
fill_trade_record_nb(
new_records[:, col],
counts[col],
col,
exit_size,
entry_order_id,
entry_idx,
entry_price,
entry_fees,
exit_order_id,
exit_idx,
exit_price,
exit_fees,
direction,
TradeStatus.Open,
parent_id,
)
counts[col] += 1
return generic_nb.repartition_nb(new_records, counts)
@register_jitted(cache=True)
def fill_position_record_nb(new_records: tp.RecordArray, r: int, trade_records: tp.RecordArray) -> None:
"""Fill a position record by aggregating trade records."""
# Aggregate trades
col = trade_records["col"][0]
size = np.sum(trade_records["size"])
entry_order_id = trade_records["entry_order_id"][0]
entry_idx = trade_records["entry_idx"][0]
entry_price = np.sum(trade_records["size"] * trade_records["entry_price"]) / size
entry_fees = np.sum(trade_records["entry_fees"])
exit_order_id = trade_records["exit_order_id"][-1]
exit_idx = trade_records["exit_idx"][-1]
exit_price = np.sum(trade_records["size"] * trade_records["exit_price"]) / size
exit_fees = np.sum(trade_records["exit_fees"])
direction = trade_records["direction"][-1]
status = trade_records["status"][-1]
pnl, ret = get_trade_stats_nb(size, entry_price, entry_fees, exit_price, exit_fees, direction)
# Save position
new_records["id"][r] = r
new_records["col"][r] = col
new_records["size"][r] = size
new_records["entry_order_id"][r] = entry_order_id
new_records["entry_idx"][r] = entry_idx
new_records["entry_price"][r] = entry_price
new_records["entry_fees"][r] = entry_fees
new_records["exit_order_id"][r] = exit_order_id
new_records["exit_idx"][r] = exit_idx
new_records["exit_price"][r] = exit_price
new_records["exit_fees"][r] = exit_fees
new_records["pnl"][r] = pnl
new_records["return"][r] = ret
new_records["direction"][r] = direction
new_records["status"][r] = status
new_records["parent_id"][r] = r
@register_jitted(cache=True)
def copy_trade_record_nb(new_records: tp.RecordArray, r: int, trade_record: tp.Record) -> None:
"""Copy a trade record."""
new_records["id"][r] = r
new_records["col"][r] = trade_record["col"]
new_records["size"][r] = trade_record["size"]
new_records["entry_order_id"][r] = trade_record["entry_order_id"]
new_records["entry_idx"][r] = trade_record["entry_idx"]
new_records["entry_price"][r] = trade_record["entry_price"]
new_records["entry_fees"][r] = trade_record["entry_fees"]
new_records["exit_order_id"][r] = trade_record["exit_order_id"]
new_records["exit_idx"][r] = trade_record["exit_idx"]
new_records["exit_price"][r] = trade_record["exit_price"]
new_records["exit_fees"][r] = trade_record["exit_fees"]
new_records["pnl"][r] = trade_record["pnl"]
new_records["return"][r] = trade_record["return"]
new_records["direction"][r] = trade_record["direction"]
new_records["status"][r] = trade_record["status"]
new_records["parent_id"][r] = r
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
trade_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_positions_nb(trade_records: tp.RecordArray, col_map: tp.GroupMap) -> tp.RecordArray:
"""Fill position records by aggregating trade records.
Trades can be entry trades, exit trades, and even positions themselves - all will produce the same results.
Usage:
* Building upon the example in `get_exit_trades_nb`:
```pycon
>>> col_map = vbt.rec_nb.col_map_nb(exit_trade_records['col'], target_shape[1])
>>> position_records = vbt.pf_nb.get_positions_nb(exit_trade_records, col_map)
>>> pd.DataFrame.from_records(position_records)
id col size entry_order_id entry_idx entry_price entry_fees \\
0 0 0 1.1 0 0 1.101818 0.01212
1 1 0 1.0 4 4 5.050000 0.05050
2 2 0 1.0 5 5 5.940000 0.05940
3 0 1 1.1 0 0 5.850000 0.06435
4 1 1 1.0 4 4 1.980000 0.01980
5 2 1 1.0 5 5 1.010000 0.01010
exit_order_id exit_idx exit_price exit_fees pnl return \\
0 3 3 3.060000 0.03366 2.10822 1.739455
1 5 5 5.940000 0.05940 0.78010 0.154475
2 -1 5 6.000000 0.00000 -0.11940 -0.020101
3 3 3 3.948182 0.04343 1.98422 0.308348
4 5 5 1.010000 0.01010 0.94010 0.474798
5 -1 5 1.000000 0.00000 -0.02010 -0.019901
direction status parent_id
0 0 1 0
1 0 1 1
2 1 0 2
3 1 1 0
4 1 1 1
5 0 0 2
```
"""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
new_records = np.empty((np.max(col_lens), len(col_lens)), dtype=trade_dt)
counts = np.full(len(col_lens), 0, dtype=int_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
last_id = -1
last_position_id = -1
from_trade_r = -1
for c in range(col_len):
trade_r = col_idxs[col_start_idxs[col] + c]
record = trade_records[trade_r]
if record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = record["id"]
parent_id = record["parent_id"]
if parent_id != last_position_id:
if last_position_id != -1:
if trade_r - from_trade_r > 1:
fill_position_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r:trade_r])
else:
# Speed up
copy_trade_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r])
counts[col] += 1
from_trade_r = trade_r
last_position_id = parent_id
if trade_r - from_trade_r > 0:
fill_position_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r : trade_r + 1])
else:
# Speed up
copy_trade_record_nb(new_records[:, col], counts[col], trade_records[from_trade_r])
counts[col] += 1
return generic_nb.repartition_nb(new_records, counts)
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
close=ch.ArraySlicer(axis=1),
col_map=base_ch.GroupMapSlicer(),
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_long_view_orders_nb(
order_records: tp.RecordArray,
close: tp.Array2d,
col_map: tp.GroupMap,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Get view of orders in long positions only."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
order_records = order_records.copy()
out = np.empty(order_records.shape, dtype=order_records.dtype)
r = 0
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(close.shape[0], col_lens.shape[0]),
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
# Prepare initial position
in_position = True
if _init_position >= 0:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = abs(_init_position)
entry_gross_sum = abs(_init_position) * _init_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
else:
in_position = False
col_len = col_lens[col]
if col_len == 0 and not in_position:
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
order_size = order_record["size"]
order_price = order_record["price"]
order_side = order_record["side"]
order_fees = order_record["fees"]
if order_size <= 0.0:
raise ValueError(invalid_size_msg)
if order_price < 0.0:
raise ValueError(invalid_price_msg)
if not in_position:
# New position opened
in_position = True
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = 0.0
entry_gross_sum = 0.0
exit_size_sum = 0.0
exit_gross_sum = 0.0
if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
direction == TradeDirection.Short and order_side == OrderSide.Sell
):
# Position increased
entry_size_sum += order_size
entry_gross_sum += order_size * order_price
if direction == TradeDirection.Long:
out[r] = order_record
r += 1
elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
if is_close_nb(exit_size_sum + order_size, entry_size_sum):
# Position closed
in_position = False
exit_size_sum = entry_size_sum
exit_gross_sum += order_size * order_price
if direction == TradeDirection.Long:
out[r] = order_record
r += 1
elif is_less_nb(exit_size_sum + order_size, entry_size_sum):
# Position decreased
exit_size_sum += order_size
exit_gross_sum += order_size * order_price
if direction == TradeDirection.Long:
out[r] = order_record
r += 1
else:
# Position closed
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
# New position opened
entry_size_sum = add_nb(order_size, -remaining_size)
entry_gross_sum = entry_size_sum * order_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
if direction == TradeDirection.Long:
out[r] = order_record
out["size"][r] = remaining_size
out["fees"][r] = remaining_size / order_size * order_fees
r += 1
else:
out[r] = order_record
out["size"][r] = entry_size_sum
out["fees"][r] = entry_size_sum / order_size * order_fees
r += 1
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
return out[:r]
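# Usage sketch for `get_long_view_orders_nb` (not part of the original module), building upon
# the example in `get_entry_trades_nb` and assuming the function is exposed as
# `vbt.pf_nb.get_long_view_orders_nb` like the other examples in this package:
# >>> long_orders = vbt.pf_nb.get_long_view_orders_nb(sim_out.order_records, close, col_map)
# Only orders that participate in long positions are kept; an order that reverses a position is
# split, and the part attributed to the long side keeps a proportional share of the fees.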
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
close=ch.ArraySlicer(axis=1),
col_map=base_ch.GroupMapSlicer(),
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_short_view_orders_nb(
order_records: tp.RecordArray,
close: tp.Array2d,
col_map: tp.GroupMap,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.RecordArray:
"""Get view of orders in short positions only."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
order_records = order_records.copy()
out = np.empty(order_records.shape, dtype=order_records.dtype)
r = 0
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(close.shape[0], col_lens.shape[0]),
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
# Prepare initial position
in_position = True
if _init_position >= 0:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = abs(_init_position)
entry_gross_sum = abs(_init_position) * _init_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
else:
in_position = False
col_len = col_lens[col]
if col_len == 0 and not in_position:
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
order_size = order_record["size"]
order_price = order_record["price"]
order_side = order_record["side"]
order_fees = order_record["fees"]
if order_size <= 0.0:
raise ValueError(invalid_size_msg)
if order_price < 0.0:
raise ValueError(invalid_price_msg)
if not in_position:
# New position opened
in_position = True
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = 0.0
entry_gross_sum = 0.0
exit_size_sum = 0.0
exit_gross_sum = 0.0
if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
direction == TradeDirection.Short and order_side == OrderSide.Sell
):
# Position increased
entry_size_sum += order_size
entry_gross_sum += order_size * order_price
if direction == TradeDirection.Short:
out[r] = order_record
r += 1
elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
if is_close_nb(exit_size_sum + order_size, entry_size_sum):
# Position closed
in_position = False
exit_size_sum = entry_size_sum
exit_gross_sum += order_size * order_price
if direction == TradeDirection.Short:
out[r] = order_record
r += 1
elif is_less_nb(exit_size_sum + order_size, entry_size_sum):
# Position decreased
exit_size_sum += order_size
exit_gross_sum += order_size * order_price
if direction == TradeDirection.Short:
out[r] = order_record
r += 1
else:
# Position closed
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
# New position opened
entry_size_sum = add_nb(order_size, -remaining_size)
entry_gross_sum = entry_size_sum * order_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
if direction == TradeDirection.Short:
out[r] = order_record
out["size"][r] = remaining_size
out["fees"][r] = remaining_size / order_size * order_fees
r += 1
else:
out[r] = order_record
out["size"][r] = entry_size_sum
out["fees"][r] = entry_size_sum / order_size * order_fees
r += 1
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
return out[:r]
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
order_records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
close=ch.ArraySlicer(axis=1),
col_map=base_ch.GroupMapSlicer(),
feature=None,
init_position=base_ch.FlexArraySlicer(),
init_price=base_ch.FlexArraySlicer(),
fill_closed_position=None,
fill_exit_price=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def get_position_feature_nb(
order_records: tp.RecordArray,
close: tp.Array2d,
col_map: tp.GroupMap,
feature: int = PositionFeature.EntryPrice,
init_position: tp.FlexArray1dLike = 0.0,
init_price: tp.FlexArray1dLike = np.nan,
fill_closed_position: bool = False,
fill_exit_price: bool = True,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Get the position's feature at each time step.
For the list of supported features see `vectorbtpro.portfolio.enums.PositionFeature`.
If `fill_exit_price` is True and a part of the position is not yet closed, the exit
price is filled as if that part had been closed at the current close price.
If `fill_closed_position` is True, missing values are forward-filled with the prices
of the previously closed position."""
init_position_ = to_1d_array_nb(np.asarray(init_position))
init_price_ = to_1d_array_nb(np.asarray(init_price))
out = np.full(close.shape, np.nan, dtype=float_)
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=(close.shape[0], col_lens.shape[0]),
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(col_lens.shape[0]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_position = float(flex_select_1d_pc_nb(init_position_, col))
_init_price = float(flex_select_1d_pc_nb(init_price_, col))
if _init_position != 0:
# Prepare initial position
in_position = True
was_in_position = True
if _init_position >= 0:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = abs(_init_position)
entry_gross_sum = abs(_init_position) * _init_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
last_order_idx = 0
else:
in_position = False
was_in_position = False
col_len = col_lens[col]
if col_len == 0 and not in_position:
continue
last_id = -1
for c in range(col_len):
order_record = order_records[col_idxs[col_start_idxs[col] + c]]
if order_record["idx"] < _sim_start or order_record["idx"] >= _sim_end:
continue
if order_record["id"] < last_id:
raise ValueError("Ids must come in ascending order per column")
last_id = order_record["id"]
order_idx = order_record["idx"]
order_size = order_record["size"]
order_price = order_record["price"]
order_side = order_record["side"]
if order_size <= 0.0:
raise ValueError(invalid_size_msg)
if order_price < 0.0:
raise ValueError(invalid_price_msg)
if in_position:
if feature == PositionFeature.EntryPrice:
if entry_size_sum != 0:
entry_price = entry_gross_sum / entry_size_sum
out[last_order_idx:order_idx, col] = entry_price
elif feature == PositionFeature.ExitPrice:
if fill_exit_price:
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
for i in range(last_order_idx, order_idx):
open_exit_size_sum = entry_size_sum
open_exit_gross_sum = exit_gross_sum + remaining_size * close[i, col]
exit_price = open_exit_gross_sum / open_exit_size_sum
out[i, col] = exit_price
else:
if exit_size_sum != 0:
exit_price = exit_gross_sum / exit_size_sum
out[last_order_idx:order_idx, col] = exit_price
else:
if was_in_position and fill_closed_position:
if feature == PositionFeature.EntryPrice:
if entry_size_sum != 0:
entry_price = entry_gross_sum / entry_size_sum
out[last_order_idx:order_idx, col] = entry_price
elif feature == PositionFeature.ExitPrice:
if exit_size_sum != 0:
exit_price = exit_gross_sum / exit_size_sum
out[last_order_idx:order_idx, col] = exit_price
# New position opened
in_position = True
was_in_position = True
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = 0.0
entry_gross_sum = 0.0
exit_size_sum = 0.0
exit_gross_sum = 0.0
if (direction == TradeDirection.Long and order_side == OrderSide.Buy) or (
direction == TradeDirection.Short and order_side == OrderSide.Sell
):
# Position increased
entry_size_sum += order_size
entry_gross_sum += order_size * order_price
elif (direction == TradeDirection.Long and order_side == OrderSide.Sell) or (
direction == TradeDirection.Short and order_side == OrderSide.Buy
):
if is_close_nb(exit_size_sum + order_size, entry_size_sum):
# Position closed
in_position = False
exit_size_sum = entry_size_sum
exit_gross_sum += order_size * order_price
elif is_less_nb(exit_size_sum + order_size, entry_size_sum):
# Position decreased
exit_size_sum += order_size
exit_gross_sum += order_size * order_price
else:
# Position closed
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
# New position opened
if order_side == OrderSide.Buy:
direction = TradeDirection.Long
else:
direction = TradeDirection.Short
entry_size_sum = add_nb(order_size, -remaining_size)
entry_gross_sum = entry_size_sum * order_price
exit_size_sum = 0.0
exit_gross_sum = 0.0
last_order_idx = order_idx
if in_position:
if feature == PositionFeature.EntryPrice:
if entry_size_sum != 0:
entry_price = entry_gross_sum / entry_size_sum
out[last_order_idx:_sim_end, col] = entry_price
elif feature == PositionFeature.ExitPrice:
if fill_exit_price:
remaining_size = add_nb(entry_size_sum, -exit_size_sum)
for i in range(last_order_idx, _sim_end):
open_exit_size_sum = entry_size_sum
open_exit_gross_sum = exit_gross_sum + remaining_size * close[i, col]
exit_price = open_exit_gross_sum / open_exit_size_sum
out[i, col] = exit_price
else:
if exit_size_sum != 0:
exit_price = exit_gross_sum / exit_size_sum
out[last_order_idx:_sim_end, col] = exit_price
elif was_in_position and fill_closed_position:
if feature == PositionFeature.EntryPrice:
if entry_size_sum != 0:
entry_price = entry_gross_sum / entry_size_sum
out[last_order_idx:_sim_end, col] = entry_price
elif feature == PositionFeature.ExitPrice:
if exit_size_sum != 0:
exit_price = exit_gross_sum / exit_size_sum
out[last_order_idx:_sim_end, col] = exit_price
return out
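# Editor's note (not part of the library): the entry and exit prices produced above are
# size-weighted averages of the respective fills, i.e. entry_price = entry_gross_sum /
# entry_size_sum. For example, buying 1 unit @ 10 and then 1 unit @ 12 yields an entry
# price of (1 * 10 + 1 * 12) / 2 = 11 for as long as the position stays open.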
@register_jitted(cache=True)
def price_status_nb(
records: tp.RecordArray,
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
) -> tp.Array1d:
"""Return the status of the order's price related to high and low.
See `vectorbtpro.portfolio.enums.OrderPriceStatus`."""
out = np.full(len(records), 0, dtype=int_)
for i in range(len(records)):
order = records[i]
if high is not None:
_high = float(flex_select_nb(high, order["idx"], order["col"]))
else:
_high = np.nan
if low is not None:
_low = flex_select_nb(low, order["idx"], order["col"])
else:
_low = np.nan
if not np.isnan(_high) and order["price"] > _high:
out[i] = OrderPriceStatus.AboveHigh
elif not np.isnan(_low) and order["price"] < _low:
out[i] = OrderPriceStatus.BelowLow
elif np.isnan(_high) or np.isnan(_low):
out[i] = OrderPriceStatus.Unknown
else:
out[i] = OrderPriceStatus.OK
return out
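# Usage sketch (editor's illustration; the minimal dtype below is hypothetical and only
# contains the fields accessed by `price_status_nb`, unlike the library's full order dtype):
#
# >>> import numpy as np
# >>> recs = np.array(
# ...     [(0, 0, 105.0), (1, 0, 99.5)],
# ...     dtype=[("idx", np.int64), ("col", np.int64), ("price", np.float64)],
# ... )
# >>> high = np.array([[104.0], [101.0]])
# >>> low = np.array([[100.0], [98.0]])
# >>> price_status_nb(recs, high, low)  # [OrderPriceStatus.AboveHigh, OrderPriceStatus.OK]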
@register_jitted(cache=True)
def trade_winning_streak_nb(records: tp.RecordArray) -> tp.Array1d:
"""Return the current winning streak of each trade."""
out = np.full(len(records), 0, dtype=int_)
curr_rank = 0
for i in range(len(records)):
if records[i]["pnl"] > 0:
curr_rank += 1
else:
curr_rank = 0
out[i] = curr_rank
return out
@register_jitted(cache=True)
def trade_losing_streak_nb(records: tp.RecordArray) -> tp.Array1d:
"""Return the current losing streak of each trade."""
out = np.full(len(records), 0, dtype=int_)
curr_rank = 0
for i in range(len(records)):
if records[i]["pnl"] < 0:
curr_rank += 1
else:
curr_rank = 0
out[i] = curr_rank
return out
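# Usage sketch (editor's illustration; the one-field dtype is hypothetical, real trade
# records carry many more fields): streaks accumulate and reset on the opposite outcome.
#
# >>> import numpy as np
# >>> recs = np.array([(1.0,), (2.0,), (-0.5,), (3.0,)], dtype=[("pnl", np.float64)])
# >>> trade_winning_streak_nb(recs)  # [1, 2, 0, 1]
# >>> trade_losing_streak_nb(recs)   # [0, 0, 1, 0]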
@register_jitted(cache=True)
def win_rate_reduce_nb(pnl_arr: tp.Array1d) -> float:
"""Win rate of a PnL array."""
if pnl_arr.shape[0] == 0:
return np.nan
win_count = 0
count = 0
for i in range(len(pnl_arr)):
if not np.isnan(pnl_arr[i]):
count += 1
if pnl_arr[i] > 0:
win_count += 1
if count == 0:
return np.nan
return win_count / count
@register_jitted(cache=True)
def profit_factor_reduce_nb(pnl_arr: tp.Array1d) -> float:
"""Profit factor of a PnL array."""
if pnl_arr.shape[0] == 0:
return np.nan
win_sum = 0
loss_sum = 0
count = 0
for i in range(len(pnl_arr)):
if not np.isnan(pnl_arr[i]):
count += 1
if pnl_arr[i] > 0:
win_sum += pnl_arr[i]
elif pnl_arr[i] < 0:
loss_sum += abs(pnl_arr[i])
if loss_sum == 0:
return np.inf
return win_sum / loss_sum
@register_jitted(cache=True)
def expectancy_reduce_nb(pnl_arr: tp.Array1d) -> float:
"""Expectancy of a PnL array."""
if pnl_arr.shape[0] == 0:
return np.nan
win_count = 0
win_sum = 0
loss_count = 0
loss_sum = 0
count = 0
for i in range(len(pnl_arr)):
if not np.isnan(pnl_arr[i]):
count += 1
if pnl_arr[i] > 0:
win_count += 1
win_sum += pnl_arr[i]
elif pnl_arr[i] < 0:
loss_count += 1
loss_sum += abs(pnl_arr[i])
if count == 0:
return np.nan
win_rate = win_count / count
if win_count == 0:
win_mean = 0.0
else:
win_mean = win_sum / win_count
loss_rate = loss_count / count
if loss_count == 0:
loss_mean = 0.0
else:
loss_mean = loss_sum / loss_count
return win_rate * win_mean - loss_rate * loss_mean
@register_jitted(cache=True)
def sqn_reduce_nb(pnl_arr: tp.Array1d, ddof: int = 0) -> float:
"""SQN of a PnL array."""
count = generic_nb.nancnt_1d_nb(pnl_arr)
mean = np.nanmean(pnl_arr)
std = generic_nb.nanstd_1d_nb(pnl_arr, ddof=ddof)
if std == 0:
return np.nan
return np.sqrt(count) * mean / std
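# Worked example (editor's sketch) for the PnL reducers above, using a small array with
# one NaN entry that the counting logic skips:
#
# >>> import numpy as np
# >>> pnl = np.array([1.0, -0.5, 2.0, np.nan, -1.0])
# >>> win_rate_reduce_nb(pnl)       # 2 wins out of 4 valid trades -> 0.5
# >>> profit_factor_reduce_nb(pnl)  # 3.0 / 1.5 -> 2.0
# >>> expectancy_reduce_nb(pnl)     # 0.5 * 1.5 - 0.5 * 0.75 -> 0.375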
@register_jitted(cache=True)
def trade_best_worst_price_nb(
trade: tp.Record,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
idx_relative: bool = True,
cont_idx: int = -1,
one_iteration: bool = False,
vmin: float = np.nan,
vmax: float = np.nan,
imin: int = -1,
imax: int = -1,
) -> tp.Tuple[float, float, int, int]:
"""Best price, worst price, and their indices during a trade."""
from_i = trade["entry_idx"]
to_i = trade["exit_idx"]
trade_open = trade["status"] == TradeStatus.Open
trade_long = trade["direction"] == TradeDirection.Long
if cont_idx == -1 or cont_idx == from_i:
cont_idx = from_i
vmin = np.nan
vmax = np.nan
imin = -1
imax = -1
else:
if trade_long:
vmin, vmax, imin, imax = vmax, vmin, imax, imin
if idx_relative:
imin = from_i + imin
imax = from_i + imax
for i in range(cont_idx, to_i + 1):
if i == from_i:
if np.isnan(vmin) or trade["entry_price"] < vmin:
vmin = trade["entry_price"]
imin = i
if np.isnan(vmax) or trade["entry_price"] > vmax:
vmax = trade["entry_price"]
imax = i
if i > from_i or entry_price_open:
if open is not None:
_open = flex_select_nb(open, i, trade["col"])
if np.isnan(vmin) or _open < vmin:
vmin = _open
imin = i
if np.isnan(vmax) or _open > vmax:
vmax = _open
imax = i
if (i > from_i or entry_price_open) and (i < to_i or exit_price_close or trade_open):
if low is not None:
_low = flex_select_nb(low, i, trade["col"])
if np.isnan(vmin) or _low < vmin:
vmin = _low
imin = i
if high is not None:
_high = flex_select_nb(high, i, trade["col"])
if np.isnan(vmax) or _high > vmax:
vmax = _high
imax = i
if i < to_i or exit_price_close or trade_open:
_close = flex_select_nb(close, i, trade["col"])
if np.isnan(vmin) or _close < vmin:
vmin = _close
imin = i
if np.isnan(vmax) or _close > vmax:
vmax = _close
imax = i
if max_duration is not None:
if from_i + max_duration == i:
break
if i == to_i:
if np.isnan(vmin) or trade["exit_price"] < vmin:
vmin = trade["exit_price"]
imin = i
if np.isnan(vmax) or trade["exit_price"] > vmax:
vmax = trade["exit_price"]
imax = i
if one_iteration:
break
if idx_relative:
imin = imin - from_i
imax = imax - from_i
if trade_long:
return vmax, vmin, imax, imin
return vmin, vmax, imin, imax
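# Editor's note on the return convention above: vmin/vmax track the lowest and highest
# price seen during the trade, so for long trades the returned tuple is
# (best, worst, best_idx, worst_idx) = (vmax, vmin, imax, imin), while for short trades
# the best price is the minimum, i.e. (vmin, vmax, imin, imax).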
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
open=None,
high=None,
low=None,
close=None,
entry_price_open=None,
exit_price_close=None,
max_duration=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def best_price_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
) -> tp.Array1d:
"""Get best price by applying `trade_best_worst_price_nb` on each trade."""
out = np.empty(len(records), dtype=float_)
for r in prange(len(records)):
trade = records[r]
out[r] = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)[0]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
open=None,
high=None,
low=None,
close=None,
entry_price_open=None,
exit_price_close=None,
max_duration=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def worst_price_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
) -> tp.Array1d:
"""Get worst price by applying `trade_best_worst_price_nb` on each trade."""
out = np.empty(len(records), dtype=float_)
for r in prange(len(records)):
trade = records[r]
out[r] = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)[1]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
open=None,
high=None,
low=None,
close=None,
entry_price_open=None,
exit_price_close=None,
max_duration=None,
relative=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def best_price_idx_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
relative: bool = True,
) -> tp.Array1d:
"""Get index of best price by applying `trade_best_worst_price_nb` on each trade."""
out = np.empty(len(records), dtype=float_)
for r in prange(len(records)):
trade = records[r]
out[r] = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
idx_relative=relative,
)[2]
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
open=None,
high=None,
low=None,
close=None,
entry_price_open=None,
exit_price_close=None,
max_duration=None,
relative=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def worst_price_idx_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
relative: bool = True,
) -> tp.Array1d:
"""Get worst price by applying `trade_best_worst_price_nb` on each trade."""
out = np.empty(len(records), dtype=float_)
for r in prange(len(records)):
trade = records[r]
out[r] = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
idx_relative=relative,
)[3]
return out
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_best_price_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
) -> tp.Array2d:
"""Get expanding best price of each trade."""
if max_duration is None:
_max_duration = 0
for r in range(len(records)):
trade = records[r]
trade_duration = trade["exit_idx"] - trade["entry_idx"]
if trade_duration > _max_duration:
_max_duration = trade_duration
else:
_max_duration = max_duration
out = np.full((_max_duration + 1, len(records)), np.nan, dtype=float_)
for r in prange(len(records)):
trade = records[r]
from_i = trade["entry_idx"]
to_i = trade["exit_idx"]
vmin = np.nan
vmax = np.nan
imin = -1
imax = -1
for i in range(from_i, to_i + 1):
vmin, vmax, imin, imax = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
cont_idx=i,
one_iteration=True,
vmin=vmin,
vmax=vmax,
imin=imin,
imax=imax,
)
out[i - from_i, r] = vmin
if max_duration is not None:
if from_i + max_duration == i:
break
return out
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_worst_price_nb(
records: tp.RecordArray,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
) -> tp.Array2d:
"""Get expanding worst price of each trade."""
if max_duration is None:
_max_duration = 0
for r in range(len(records)):
trade = records[r]
trade_duration = trade["exit_idx"] - trade["entry_idx"]
if trade_duration > _max_duration:
_max_duration = trade_duration
else:
_max_duration = max_duration
out = np.full((_max_duration + 1, len(records)), np.nan, dtype=float_)
for r in prange(len(records)):
trade = records[r]
from_i = trade["entry_idx"]
to_i = trade["exit_idx"]
vmin = np.nan
vmax = np.nan
imin = -1
imax = -1
for i in range(from_i, to_i + 1):
vmin, vmax, imin, imax = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
cont_idx=i,
one_iteration=True,
vmin=vmin,
vmax=vmax,
imin=imin,
imax=imax,
)
out[i - from_i, r] = vmax
if max_duration is not None:
if from_i + max_duration == i:
break
return out
@register_jitted(cache=True)
def trade_mfe_nb(
size: float,
direction: int,
entry_price: float,
best_price: float,
use_returns: bool = False,
) -> float:
"""Compute Maximum Favorable Excursion (MFE)."""
if direction == TradeDirection.Long:
if use_returns:
return (best_price - entry_price) / entry_price
return (best_price - entry_price) * size
if use_returns:
return (entry_price - best_price) / best_price
return (entry_price - best_price) * size
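# Worked example (editor's sketch): for a long trade of size 2 entered at 10 with a best
# price of 12, the MFE is (12 - 10) * 2 = 4.0 in absolute terms, or (12 - 10) / 10 = 0.2
# when `use_returns` is True.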
@register_chunkable(
size=ch.ArraySizer(arg_query="size", axis=0),
arg_take_spec=dict(
size=ch.ArraySlicer(axis=0),
direction=ch.ArraySlicer(axis=0),
entry_price=ch.ArraySlicer(axis=0),
best_price=ch.ArraySlicer(axis=0),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mfe_nb(
size: tp.Array1d,
direction: tp.Array1d,
entry_price: tp.Array1d,
best_price: tp.Array1d,
use_returns: bool = False,
) -> tp.Array1d:
"""Apply `trade_mfe_nb` on each trade."""
out = np.empty(size.shape[0], dtype=float_)
for r in prange(size.shape[0]):
out[r] = trade_mfe_nb(
size=size[r],
direction=direction[r],
entry_price=entry_price[r],
best_price=best_price[r],
use_returns=use_returns,
)
return out
@register_jitted(cache=True)
def trade_mae_nb(
size: float,
direction: int,
entry_price: float,
worst_price: float,
use_returns: bool = False,
) -> float:
"""Compute Maximum Adverse Excursion (MAE)."""
if direction == TradeDirection.Long:
if use_returns:
return (worst_price - entry_price) / entry_price
return (worst_price - entry_price) * size
if use_returns:
return (entry_price - worst_price) / worst_price
return (entry_price - worst_price) * size
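# Worked example (editor's sketch): for the same long trade of size 2 entered at 10, a
# worst price of 9 gives an MAE of (9 - 10) * 2 = -2.0, or (9 - 10) / 10 = -0.1 with
# `use_returns=True`; for short trades the return-based formula uses (entry - worst) / worst.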
@register_chunkable(
size=ch.ArraySizer(arg_query="size", axis=0),
arg_take_spec=dict(
size=ch.ArraySlicer(axis=0),
direction=ch.ArraySlicer(axis=0),
entry_price=ch.ArraySlicer(axis=0),
worst_price=ch.ArraySlicer(axis=0),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mae_nb(
size: tp.Array1d,
direction: tp.Array1d,
entry_price: tp.Array1d,
worst_price: tp.Array1d,
use_returns: bool = False,
) -> tp.Array1d:
"""Apply `trade_mae_nb` on each trade."""
out = np.empty(size.shape[0], dtype=float_)
for r in prange(size.shape[0]):
out[r] = trade_mae_nb(
size=size[r],
direction=direction[r],
entry_price=entry_price[r],
worst_price=worst_price[r],
use_returns=use_returns,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
expanding_best_price=ch.ArraySlicer(axis=1),
use_returns=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_mfe_nb(
records: tp.RecordArray,
expanding_best_price: tp.Array2d,
use_returns: bool = False,
) -> tp.Array2d:
"""Get expanding MFE of each trade."""
out = np.empty_like(expanding_best_price, dtype=float_)
for r in prange(expanding_best_price.shape[1]):
for i in range(expanding_best_price.shape[0]):
out[i, r] = trade_mfe_nb(
size=records["size"][r],
direction=records["direction"][r],
entry_price=records["entry_price"][r],
best_price=expanding_best_price[i, r],
use_returns=use_returns,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0),
expanding_worst_price=ch.ArraySlicer(axis=1),
use_returns=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def expanding_mae_nb(
records: tp.RecordArray,
expanding_worst_price: tp.Array2d,
use_returns: bool = False,
) -> tp.Array2d:
"""Get expanding MAE of each trade."""
out = np.empty_like(expanding_worst_price, dtype=float_)
for r in prange(expanding_worst_price.shape[1]):
for i in range(expanding_worst_price.shape[0]):
out[i, r] = trade_mae_nb(
size=records["size"][r],
direction=records["direction"][r],
entry_price=records["entry_price"][r],
worst_price=expanding_worst_price[i, r],
use_returns=use_returns,
)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
records=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
open=None,
high=None,
low=None,
close=None,
volatility=None,
entry_price_open=None,
exit_price_close=None,
max_duration=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def edge_ratio_nb(
records: tp.RecordArray,
col_map: tp.GroupMap,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
volatility: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
) -> tp.Array1d:
"""Get edge ratio of each column."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(len(col_lens), np.nan, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
norm_mfe_sum = 0.0
norm_mfe_cnt = 0
norm_mae_sum = 0.0
norm_mae_cnt = 0
for r in ridxs:
trade = records[r]
best_price, worst_price, _, _ = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)
mfe = abs(
trade_mfe_nb(
size=trade["size"],
direction=trade["direction"],
entry_price=trade["entry_price"],
best_price=best_price,
use_returns=False,
)
)
mae = abs(
trade_mae_nb(
size=trade["size"],
direction=trade["direction"],
entry_price=trade["entry_price"],
worst_price=worst_price,
use_returns=False,
)
)
_volatility = flex_select_nb(volatility, trade["entry_idx"], trade["col"])
if _volatility == 0:
norm_mfe = np.nan
norm_mae = np.nan
else:
norm_mfe = mfe / _volatility
norm_mae = mae / _volatility
if not np.isnan(norm_mfe):
norm_mfe_sum += norm_mfe
norm_mfe_cnt += 1
if not np.isnan(norm_mae):
norm_mae_sum += norm_mae
norm_mae_cnt += 1
if norm_mfe_cnt == 0:
mean_mfe = np.nan
else:
mean_mfe = norm_mfe_sum / norm_mfe_cnt
if norm_mae_cnt == 0:
mean_mae = np.nan
else:
mean_mae = norm_mae_sum / norm_mae_cnt
if mean_mae == 0:
out[col] = np.nan
else:
out[col] = mean_mfe / mean_mae
return out
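# Editor's note: the edge ratio computed above is the mean volatility-normalized MFE
# divided by the mean volatility-normalized MAE over all trades in a column, i.e.
# mean(|MFE| / volatility) / mean(|MAE| / volatility), with volatility sampled at each
# trade's entry index. A ratio above 1 means favorable excursions tend to exceed adverse ones.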
@register_jitted(cache=True, tags={"can_parallel"})
def running_edge_ratio_nb(
records: tp.RecordArray,
col_map: tp.GroupMap,
open: tp.Optional[tp.FlexArray2d],
high: tp.Optional[tp.FlexArray2d],
low: tp.Optional[tp.FlexArray2d],
close: tp.FlexArray2d,
volatility: tp.FlexArray2d,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
incl_shorter: bool = False,
) -> tp.Array2d:
"""Get running edge ratio of each column."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
if max_duration is None:
_max_duration = 0
for r in range(len(records)):
trade = records[r]
trade_duration = trade["exit_idx"] - trade["entry_idx"]
if trade_duration > _max_duration:
_max_duration = trade_duration
else:
_max_duration = max_duration
out = np.full((_max_duration, len(col_lens)), np.nan, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
ridxs = col_idxs[col_start_idx : col_start_idx + col_len]
for k in range(_max_duration):
norm_mfe_sum = 0.0
norm_mfe_cnt = 0
norm_mae_sum = 0.0
norm_mae_cnt = 0
for r in ridxs:
trade = records[r]
if not incl_shorter:
trade_duration = trade["exit_idx"] - trade["entry_idx"]
if trade_duration < k + 1:
continue
best_price, worst_price, _, _ = trade_best_worst_price_nb(
trade=trade,
open=open,
high=high,
low=low,
close=close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=k + 1,
)
mfe = abs(
trade_mfe_nb(
size=trade["size"],
direction=trade["direction"],
entry_price=trade["entry_price"],
best_price=best_price,
use_returns=False,
)
)
mae = abs(
trade_mae_nb(
size=trade["size"],
direction=trade["direction"],
entry_price=trade["entry_price"],
worst_price=worst_price,
use_returns=False,
)
)
_volatility = flex_select_nb(volatility, trade["entry_idx"], trade["col"])
if _volatility == 0:
norm_mfe = np.nan
norm_mae = np.nan
else:
norm_mfe = mfe / _volatility
norm_mae = mae / _volatility
if not np.isnan(norm_mfe):
norm_mfe_sum += norm_mfe
norm_mfe_cnt += 1
if not np.isnan(norm_mae):
norm_mae_sum += norm_mae
norm_mae_cnt += 1
if norm_mfe_cnt == 0:
mean_mfe = np.nan
else:
mean_mfe = norm_mfe_sum / norm_mfe_cnt
if norm_mae_cnt == 0:
mean_mae = np.nan
else:
mean_mae = norm_mae_sum / norm_mae_cnt
if mean_mae == 0:
out[k, col] = np.nan
else:
out[k, col] = mean_mfe / mean_mae
return out
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules with classes and utilities for portfolio optimization."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.portfolio.pfopt.base import *
from vectorbtpro.portfolio.pfopt.nb import *
from vectorbtpro.portfolio.pfopt.records import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base functions and classes for portfolio optimization."""
import inspect
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.indexes import combine_indexes, stack_indexes, select_levels
from vectorbtpro.base.indexing import point_idxr_defaults, range_idxr_defaults
from vectorbtpro.base.merging import row_stack_arrays
from vectorbtpro.base.reshaping import to_pd_array, to_1d_array, to_2d_array, to_dict, broadcast_array_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.data.base import Data
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.enums import RangeStatus
from vectorbtpro.generic.splitting.base import Splitter, Takeable
from vectorbtpro.portfolio.enums import Direction
from vectorbtpro.portfolio.pfopt import nb
from vectorbtpro.portfolio.pfopt.records import AllocRanges, AllocPoints
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.annotations import has_annotatables
from vectorbtpro.utils.config import merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.execution import Task, execute
from vectorbtpro.utils.params import Param, combine_params, Parameterizer
from vectorbtpro.utils.parsing import (
get_func_arg_names,
annotate_args,
flatten_ann_args,
unflatten_ann_args,
ann_args_to_args,
)
from vectorbtpro.utils.pickling import pdict
from vectorbtpro.utils.random_ import set_seed_nb
from vectorbtpro.utils.template import substitute_templates, Rep, RepFunc, CustomTemplate
from vectorbtpro.utils.warnings_ import warn, warn_stdout, WarningsFiltered
if tp.TYPE_CHECKING:
from vectorbtpro.portfolio.base import Portfolio as PortfolioT
else:
PortfolioT = "Portfolio"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from pypfopt.base_optimizer import BaseOptimizer as BaseOptimizerT
except ImportError:
BaseOptimizerT = "BaseOptimizer"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from riskfolio import Portfolio as RPortfolio, HCPortfolio as RHCPortfolio
RPortfolioT = tp.TypeVar("RPortfolioT", bound=tp.Union[RPortfolio, RHCPortfolio])
except ImportError:
RPortfolioT = "RPortfolio"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from universal.algo import Algo
from universal.result import AlgoResult
AlgoT = tp.TypeVar("AlgoT", bound=Algo)
AlgoResultT = tp.TypeVar("AlgoResultT", bound=AlgoResult)
except ImportError:
AlgoT = "Algo"
AlgoResultT = "AlgoResult"
__all__ = [
"pfopt_func_dict",
"pypfopt_optimize",
"riskfolio_optimize",
"PortfolioOptimizer",
"PFO",
]
__pdoc__ = {}
# ############# PyPortfolioOpt ############# #
class pfopt_func_dict(pdict):
"""Dict that contains optimization functions as keys.
Keys can be functions themselves, their names, or `_def` for the default value."""
pass
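# Usage sketch (editor's illustration; the function name key and values are assumptions):
# per-function keyword arguments can be grouped by function name, with `_def` as the fallback:
#
# >>> kwargs = pfopt_func_dict({
# ...     "EfficientFrontier": dict(weight_bounds=(-1, 1)),
# ...     "_def": dict(),
# ... })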
def select_pfopt_func_kwargs(
pypfopt_func: tp.Callable,
kwargs: tp.Union[None, tp.Kwargs, pfopt_func_dict] = None,
) -> tp.Kwargs:
"""Select keyword arguments belonging to `pypfopt_func`."""
if kwargs is None:
return {}
if isinstance(kwargs, pfopt_func_dict):
if pypfopt_func in kwargs:
_kwargs = kwargs[pypfopt_func]
elif pypfopt_func.__name__ in kwargs:
_kwargs = kwargs[pypfopt_func.__name__]
elif "_def" in kwargs:
_kwargs = kwargs["_def"]
else:
_kwargs = {}
else:
_kwargs = {}
for k, v in kwargs.items():
if isinstance(v, pfopt_func_dict):
if pypfopt_func in v:
_kwargs[k] = v[pypfopt_func]
elif pypfopt_func.__name__ in v:
_kwargs[k] = v[pypfopt_func.__name__]
elif "_def" in v:
_kwargs[k] = v["_def"]
else:
_kwargs[k] = v
return _kwargs
def resolve_pypfopt_func_kwargs(
pypfopt_func: tp.Callable,
cache: tp.KwargsLike = None,
var_kwarg_names: tp.Optional[tp.Iterable[str]] = None,
used_arg_names: tp.Optional[tp.Set[str]] = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments passed to any optimization function with the layout of PyPortfolioOpt.
Parses the signature of `pypfopt_func`, and for each accepted argument, looks for an argument
with the same name in `kwargs`. If not found, tries to resolve that argument using other arguments
or by calling other optimization functions.
Argument `frequency` gets resolved with (global) `freq` and `year_freq` using
`vectorbtpro.returns.accessors.ReturnsAccessor.get_ann_factor`.
Any argument in `kwargs` can be wrapped using `pfopt_func_dict` to define the argument
per function rather than globally.
!!! note
When providing custom functions, make sure that the arguments they accept are visible
in the signature (that is, no variable arguments) and have the same naming as in PyPortfolioOpt.
Functions `market_implied_prior_returns` and `BlackLittermanModel.bl_weights` take `risk_aversion`,
which is different from arguments with the same name in other functions. To set it, pass `delta`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pypfopt")
signature = inspect.signature(pypfopt_func)
kwargs = select_pfopt_func_kwargs(pypfopt_func, kwargs)
if cache is None:
cache = {}
arg_names = get_func_arg_names(pypfopt_func)
if len(arg_names) == 0:
return {}
if used_arg_names is None:
used_arg_names = set()
pass_kwargs = dict()
def _process_arg(arg_name, arg_value):
orig_arg_name = arg_name
if pypfopt_func.__name__ in ("market_implied_prior_returns", "bl_weights"):
if arg_name == "risk_aversion":
# In some methods, risk_aversion is expected to be an array and actually means delta
arg_name = "delta"
def _get_kwarg(*args):
used_arg_names.add(args[0])
return kwargs.get(*args)
def _get_prices():
prices = None
if "prices" in cache:
prices = cache["prices"]
elif "prices" in kwargs:
if not _get_kwarg("returns_data", False):
prices = _get_kwarg("prices")
return prices
def _get_returns():
returns = None
if "returns" in cache:
returns = cache["returns"]
elif "returns" in kwargs:
returns = _get_kwarg("returns")
elif "prices" in kwargs and _get_kwarg("returns_data", False):
returns = _get_kwarg("prices")
return returns
def _prices_from_returns():
from pypfopt.expected_returns import prices_from_returns
cache["prices"] = prices_from_returns(_get_returns(), _get_kwarg("log_returns", False))
return cache["prices"]
def _returns_from_prices():
from pypfopt.expected_returns import returns_from_prices
cache["returns"] = returns_from_prices(_get_prices(), _get_kwarg("log_returns", False))
return cache["returns"]
if arg_name == "expected_returns":
if arg_name in kwargs:
used_arg_names.add(arg_name)
if "expected_returns" not in cache:
cache["expected_returns"] = resolve_pypfopt_expected_returns(
cache=cache,
used_arg_names=used_arg_names,
**kwargs,
)
pass_kwargs[orig_arg_name] = cache["expected_returns"]
elif arg_name == "cov_matrix":
if arg_name in kwargs:
used_arg_names.add(arg_name)
if "cov_matrix" not in cache:
cache["cov_matrix"] = resolve_pypfopt_cov_matrix(
cache=cache,
used_arg_names=used_arg_names,
**kwargs,
)
pass_kwargs[orig_arg_name] = cache["cov_matrix"]
elif arg_name == "optimizer":
if arg_name in kwargs:
used_arg_names.add(arg_name)
if "optimizer" not in cache:
cache["optimizer"] = resolve_pypfopt_optimizer(
cache=cache,
used_arg_names=used_arg_names,
**kwargs,
)
pass_kwargs[orig_arg_name] = cache["optimizer"]
if orig_arg_name not in pass_kwargs:
if arg_name in kwargs:
if arg_name == "market_prices":
if pypfopt_func.__name__ != "market_implied_risk_aversion" and checks.is_series(
_get_kwarg(arg_name)
):
pass_kwargs[orig_arg_name] = _get_kwarg(arg_name).to_frame().copy(deep=False)
else:
pass_kwargs[orig_arg_name] = _get_kwarg(arg_name).copy(deep=False)
else:
pass_kwargs[orig_arg_name] = _get_kwarg(arg_name)
else:
if arg_name == "frequency":
ann_factor = ReturnsAccessor.get_ann_factor(
year_freq=_get_kwarg("year_freq", None),
freq=_get_kwarg("freq", None),
)
if ann_factor is not None:
pass_kwargs[orig_arg_name] = ann_factor
elif arg_name == "prices":
if "returns_data" in arg_names:
if "returns_data" in kwargs:
if _get_kwarg("returns_data", False):
if _get_returns() is not None:
pass_kwargs[orig_arg_name] = _get_returns()
elif _get_prices() is not None:
pass_kwargs[orig_arg_name] = _returns_from_prices()
else:
if _get_prices() is not None:
pass_kwargs[orig_arg_name] = _get_prices()
elif _get_returns() is not None:
pass_kwargs[orig_arg_name] = _prices_from_returns()
else:
if _get_prices() is not None:
pass_kwargs[orig_arg_name] = _get_prices()
pass_kwargs["returns_data"] = False
elif _get_returns() is not None:
pass_kwargs[orig_arg_name] = _get_returns()
pass_kwargs["returns_data"] = True
else:
if _get_prices() is not None:
pass_kwargs[orig_arg_name] = _get_prices()
elif _get_returns() is not None:
pass_kwargs[orig_arg_name] = _prices_from_returns()
elif arg_name == "returns":
if _get_returns() is not None:
pass_kwargs[orig_arg_name] = _get_returns()
elif _get_prices() is not None:
pass_kwargs[orig_arg_name] = _returns_from_prices()
elif arg_name == "latest_prices":
from pypfopt.discrete_allocation import get_latest_prices
if _get_prices() is not None:
pass_kwargs[orig_arg_name] = cache["latest_prices"] = get_latest_prices(_get_prices())
elif _get_returns() is not None:
pass_kwargs[orig_arg_name] = cache["latest_prices"] = get_latest_prices(_prices_from_returns())
elif arg_name == "delta":
if "delta" not in cache:
from pypfopt.black_litterman import market_implied_risk_aversion
cache["delta"] = resolve_pypfopt_func_call(
market_implied_risk_aversion,
cache=cache,
used_arg_names=used_arg_names,
**kwargs,
)
pass_kwargs[orig_arg_name] = cache["delta"]
elif arg_name == "pi":
if "pi" not in cache:
from pypfopt.black_litterman import market_implied_prior_returns
cache["pi"] = resolve_pypfopt_func_call(
market_implied_prior_returns,
cache=cache,
used_arg_names=used_arg_names,
**kwargs,
)
pass_kwargs[orig_arg_name] = cache["pi"]
if orig_arg_name not in pass_kwargs:
if arg_value.default != inspect.Parameter.empty:
pass_kwargs[orig_arg_name] = arg_value.default
for arg_name, arg_value in signature.parameters.items():
if arg_value.kind == inspect.Parameter.VAR_POSITIONAL:
raise TypeError(f"Variable positional arguments in {pypfopt_func} cannot be parsed")
elif arg_value.kind == inspect.Parameter.VAR_KEYWORD:
if var_kwarg_names is None:
var_kwarg_names = []
for var_arg_name in var_kwarg_names:
_process_arg(var_arg_name, arg_value)
else:
_process_arg(arg_name, arg_value)
return pass_kwargs
def resolve_pypfopt_func_call(pypfopt_func: tp.Callable, **kwargs) -> tp.Any:
"""Resolve arguments using `resolve_pypfopt_func_kwargs` and call the function with that arguments."""
return pypfopt_func(**resolve_pypfopt_func_kwargs(pypfopt_func, **kwargs))
def resolve_pypfopt_expected_returns(
expected_returns: tp.Union[tp.Callable, tp.AnyArray, str] = "mean_historical_return",
**kwargs,
) -> tp.AnyArray:
"""Resolve the expected returns.
`expected_returns` can be an array, an attribute of `pypfopt.expected_returns`, a function,
or one of the following options:
* 'mean_historical_return': `pypfopt.expected_returns.mean_historical_return`
* 'ema_historical_return': `pypfopt.expected_returns.ema_historical_return`
* 'capm_return': `pypfopt.expected_returns.capm_return`
* 'bl_returns': `pypfopt.black_litterman.BlackLittermanModel.bl_returns`
Any function is resolved using `resolve_pypfopt_func_call`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pypfopt")
if isinstance(expected_returns, str):
if expected_returns.lower() == "mean_historical_return":
from pypfopt.expected_returns import mean_historical_return
return resolve_pypfopt_func_call(mean_historical_return, **kwargs)
if expected_returns.lower() == "ema_historical_return":
from pypfopt.expected_returns import ema_historical_return
return resolve_pypfopt_func_call(ema_historical_return, **kwargs)
if expected_returns.lower() == "capm_return":
from pypfopt.expected_returns import capm_return
return resolve_pypfopt_func_call(capm_return, **kwargs)
if expected_returns.lower() == "bl_returns":
from pypfopt.black_litterman import BlackLittermanModel
return resolve_pypfopt_func_call(
BlackLittermanModel,
var_kwarg_names=["market_caps", "risk_free_rate"],
**kwargs,
).bl_returns()
import pypfopt.expected_returns
if hasattr(pypfopt.expected_returns, expected_returns):
return resolve_pypfopt_func_call(getattr(pypfopt.expected_returns, expected_returns), **kwargs)
raise NotImplementedError("Return model '{}' is not supported".format(expected_returns))
if callable(expected_returns):
return resolve_pypfopt_func_call(expected_returns, **kwargs)
return expected_returns
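# Usage sketch (editor's illustration; requires PyPortfolioOpt and a hypothetical price
# DataFrame `price_df`): the string shortcuts above resolve to the corresponding
# PyPortfolioOpt functions, e.g.
#
# >>> mu = resolve_pypfopt_expected_returns("capm_return", prices=price_df)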
def resolve_pypfopt_cov_matrix(
cov_matrix: tp.Union[tp.Callable, tp.AnyArray, str] = "ledoit_wolf",
**kwargs,
) -> tp.AnyArray:
"""Resolve the covariance matrix.
`cov_matrix` can be an array, an attribute of `pypfopt.risk_models`, a function,
or one of the following options:
* 'sample_cov': `pypfopt.risk_models.sample_cov`
* 'semicovariance' or 'semivariance': `pypfopt.risk_models.semicovariance`
* 'exp_cov': `pypfopt.risk_models.exp_cov`
* 'ledoit_wolf' or 'ledoit_wolf_constant_variance': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
with 'constant_variance' as shrinkage factor
* 'ledoit_wolf_single_factor': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
with 'single_factor' as shrinkage factor
* 'ledoit_wolf_constant_correlation': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
with 'constant_correlation' as shrinkage factor
* 'oracle_approximating': `pypfopt.risk_models.CovarianceShrinkage.ledoit_wolf`
with 'oracle_approximating' as shrinkage factor
Any function is resolved using `resolve_pypfopt_func_call`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pypfopt")
if isinstance(cov_matrix, str):
if cov_matrix.lower() == "sample_cov":
from pypfopt.risk_models import sample_cov
return resolve_pypfopt_func_call(sample_cov, var_kwarg_names=["fix_method"], **kwargs)
if cov_matrix.lower() == "semicovariance" or cov_matrix.lower() == "semivariance":
from pypfopt.risk_models import semicovariance
return resolve_pypfopt_func_call(semicovariance, var_kwarg_names=["fix_method"], **kwargs)
if cov_matrix.lower() == "exp_cov":
from pypfopt.risk_models import exp_cov
return resolve_pypfopt_func_call(exp_cov, var_kwarg_names=["fix_method"], **kwargs)
if cov_matrix.lower() == "ledoit_wolf" or cov_matrix.lower() == "ledoit_wolf_constant_variance":
from pypfopt.risk_models import CovarianceShrinkage
return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf()
if cov_matrix.lower() == "ledoit_wolf_single_factor":
from pypfopt.risk_models import CovarianceShrinkage
return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf(
shrinkage_target="single_factor"
)
if cov_matrix.lower() == "ledoit_wolf_constant_correlation":
from pypfopt.risk_models import CovarianceShrinkage
return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).ledoit_wolf(
shrinkage_target="constant_correlation"
)
if cov_matrix.lower() == "oracle_approximating":
from pypfopt.risk_models import CovarianceShrinkage
return resolve_pypfopt_func_call(CovarianceShrinkage, **kwargs).oracle_approximating()
import pypfopt.risk_models
if hasattr(pypfopt.risk_models, cov_matrix):
return resolve_pypfopt_func_call(getattr(pypfopt.risk_models, cov_matrix), **kwargs)
raise NotImplementedError("Risk model '{}' is not supported".format(cov_matrix))
if callable(cov_matrix):
return resolve_pypfopt_func_call(cov_matrix, **kwargs)
return cov_matrix
def resolve_pypfopt_optimizer(
optimizer: tp.Union[tp.Callable, BaseOptimizerT, str] = "efficient_frontier",
**kwargs,
) -> BaseOptimizerT:
"""Resolve the optimizer.
`optimizer` can be an instance of `pypfopt.base_optimizer.BaseOptimizer`, an attribute of `pypfopt`,
a subclass of `pypfopt.base_optimizer.BaseOptimizer`, or one of the following options:
* 'efficient_frontier': `pypfopt.efficient_frontier.EfficientFrontier`
* 'efficient_cdar': `pypfopt.efficient_frontier.EfficientCDaR`
* 'efficient_cvar': `pypfopt.efficient_frontier.EfficientCVaR`
* 'efficient_semivariance': `pypfopt.efficient_frontier.EfficientSemivariance`
* 'black_litterman' or 'bl': `pypfopt.black_litterman.BlackLittermanModel`
* 'hierarchical_portfolio', 'hrpopt', or 'hrp': `pypfopt.hierarchical_portfolio.HRPOpt`
* 'cla': `pypfopt.cla.CLA`
Any function is resolved using `resolve_pypfopt_func_call`."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pypfopt")
from pypfopt.base_optimizer import BaseOptimizer
if isinstance(optimizer, str):
if optimizer.lower() == "efficient_frontier":
from pypfopt.efficient_frontier import EfficientFrontier
return resolve_pypfopt_func_call(EfficientFrontier, **kwargs)
if optimizer.lower() == "efficient_cdar":
from pypfopt.efficient_frontier import EfficientCDaR
return resolve_pypfopt_func_call(EfficientCDaR, **kwargs)
if optimizer.lower() == "efficient_cvar":
from pypfopt.efficient_frontier import EfficientCVaR
return resolve_pypfopt_func_call(EfficientCVaR, **kwargs)
if optimizer.lower() == "efficient_semivariance":
from pypfopt.efficient_frontier import EfficientSemivariance
return resolve_pypfopt_func_call(EfficientSemivariance, **kwargs)
if optimizer.lower() == "black_litterman" or optimizer.lower() == "bl":
from pypfopt.black_litterman import BlackLittermanModel
return resolve_pypfopt_func_call(
BlackLittermanModel,
var_kwarg_names=["market_caps", "risk_free_rate"],
**kwargs,
)
if optimizer.lower() == "hierarchical_portfolio" or optimizer.lower() == "hrpopt" or optimizer.lower() == "hrp":
from pypfopt.hierarchical_portfolio import HRPOpt
return resolve_pypfopt_func_call(HRPOpt, **kwargs)
if optimizer.lower() == "cla":
from pypfopt.cla import CLA
return resolve_pypfopt_func_call(CLA, **kwargs)
import pypfopt
if hasattr(pypfopt, optimizer):
return resolve_pypfopt_func_call(getattr(pypfopt, optimizer), **kwargs)
raise NotImplementedError("Optimizer '{}' is not supported".format(optimizer))
if isinstance(optimizer, type) and issubclass(optimizer, BaseOptimizer):
return resolve_pypfopt_func_call(optimizer, **kwargs)
if isinstance(optimizer, BaseOptimizer):
return optimizer
raise NotImplementedError("Optimizer {} is not supported".format(optimizer))
def pypfopt_optimize(
target: tp.Optional[tp.Union[tp.Callable, str]] = None,
target_is_convex: tp.Optional[bool] = None,
weights_sum_to_one: tp.Optional[bool] = None,
target_constraints: tp.Optional[tp.List[tp.Kwargs]] = None,
target_solver: tp.Optional[str] = None,
target_initial_guess: tp.Optional[tp.Array] = None,
objectives: tp.Optional[tp.MaybeIterable[tp.Union[tp.Callable, str]]] = None,
constraints: tp.Optional[tp.MaybeIterable[tp.Callable]] = None,
sector_mapper: tp.Optional[dict] = None,
sector_lower: tp.Optional[dict] = None,
sector_upper: tp.Optional[dict] = None,
discrete_allocation: tp.Optional[bool] = None,
allocation_method: tp.Optional[str] = None,
silence_warnings: tp.Optional[bool] = None,
ignore_opt_errors: tp.Optional[bool] = None,
ignore_errors: tp.Optional[bool] = None,
**kwargs,
) -> tp.Dict[str, float]:
"""Get allocation using PyPortfolioOpt.
First, it resolves the optimizer using `resolve_pypfopt_optimizer`. Depending upon which arguments it takes,
it may further resolve expected returns, covariance matrix, etc. Then, it adds objectives and constraints
to the optimizer instance, calls the target metric, extracts the weights, and finally, converts
the weights to an integer allocation (if requested).
To specify the optimizer, use `optimizer` (see `resolve_pypfopt_optimizer`).
To specify the expected returns, use `expected_returns` (see `resolve_pypfopt_expected_returns`).
To specify the covariance matrix, use `cov_matrix` (see `resolve_pypfopt_cov_matrix`).
All other keyword arguments in `**kwargs` are used by `resolve_pypfopt_func_call`.
Each objective can be a function, an attribute of `pypfopt.objective_functions`, or an iterable of such.
Each constraint can be a function or an iterable of such.
The target can be an attribute of the optimizer, or a stand-alone function.
If `target_is_convex` is True, the function is added as a convex function.
Otherwise, the function is added as a non-convex function. The keyword arguments
`weights_sum_to_one` and those starting with `target` are passed to
`pypfopt.base_optimizer.BaseConvexOptimizer.convex_objective`
and `pypfopt.base_optimizer.BaseConvexOptimizer.nonconvex_objective`, respectively.
Set `ignore_opt_errors` to True to ignore any target optimization errors.
Set `ignore_errors` to True to ignore any errors, even those caused by the user.
If `discrete_allocation` is True, resolves `pypfopt.discrete_allocation.DiscreteAllocation`
and calls `allocation_method` as an attribute of the allocation object.
Any function is resolved using `resolve_pypfopt_func_call`.
For defaults, see `pypfopt` under `vectorbtpro._settings.pfopt`.
Usage:
* Using mean historical returns, Ledoit-Wolf covariance matrix with constant variance,
and efficient frontier:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(["MSFT", "AMZN", "KO", "MA"])
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> vbt.pypfopt_optimize(prices=data.get("Close"))
{'MSFT': 0.13324, 'AMZN': 0.10016, 'KO': 0.03229, 'MA': 0.73431}
```
* EMA historical returns and sample covariance:
```pycon
>>> vbt.pypfopt_optimize(
... prices=data.get("Close"),
... expected_returns="ema_historical_return",
... cov_matrix="sample_cov"
... )
{'MSFT': 0.08984, 'AMZN': 0.0, 'KO': 0.91016, 'MA': 0.0}
```
* EMA historical returns, efficient Conditional Value at Risk, and other parameters automatically
passed to their respective functions. Optimized towards lowest CVaR:
```pycon
>>> vbt.pypfopt_optimize(
... prices=data.get("Close"),
... expected_returns="ema_historical_return",
... optimizer="efficient_cvar",
... beta=0.9,
... weight_bounds=(-1, 1),
... target="min_cvar"
... )
{'MSFT': 0.14779, 'AMZN': 0.07224, 'KO': 0.77552, 'MA': 0.00445}
```
* Adding custom objectives:
```pycon
>>> vbt.pypfopt_optimize(
... prices=data.get("Close"),
... objectives=["L2_reg"],
... gamma=0.1,
... target="min_volatility"
... )
{'MSFT': 0.22228, 'AMZN': 0.15685, 'KO': 0.28712, 'MA': 0.33375}
```
* Adding custom constraints:
```pycon
>>> vbt.pypfopt_optimize(
... prices=data.get("Close"),
... constraints=[lambda w: w[data.symbols.index("MSFT")] <= 0.1]
... )
{'MSFT': 0.1, 'AMZN': 0.10676, 'KO': 0.04341, 'MA': 0.74982}
```
* Optimizing towards a custom convex objective (to add a non-convex objective,
set `target_is_convex` to False):
```pycon
>>> import cvxpy as cp
>>> def logarithmic_barrier_objective(w, cov_matrix, k=0.1):
... log_sum = cp.sum(cp.log(w))
... var = cp.quad_form(w, cov_matrix)
... return var - k * log_sum
>>> pypfopt_optimize(
... prices=data.get("Close"),
... target=logarithmic_barrier_objective
... )
{'MSFT': 0.24595, 'AMZN': 0.23047, 'KO': 0.25862, 'MA': 0.26496}
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("pypfopt")
from pypfopt.exceptions import OptimizationError
from cvxpy.error import SolverError
from vectorbtpro._settings import settings
pypfopt_cfg = dict(settings["pfopt"]["pypfopt"])
def _resolve_setting(k, v):
setting = pypfopt_cfg.pop(k)
if v is None:
return setting
return v
target = _resolve_setting("target", target)
target_is_convex = _resolve_setting("target_is_convex", target_is_convex)
weights_sum_to_one = _resolve_setting("weights_sum_to_one", weights_sum_to_one)
target_constraints = _resolve_setting("target_constraints", target_constraints)
target_solver = _resolve_setting("target_solver", target_solver)
target_initial_guess = _resolve_setting("target_initial_guess", target_initial_guess)
objectives = _resolve_setting("objectives", objectives)
constraints = _resolve_setting("constraints", constraints)
sector_mapper = _resolve_setting("sector_mapper", sector_mapper)
sector_lower = _resolve_setting("sector_lower", sector_lower)
sector_upper = _resolve_setting("sector_upper", sector_upper)
discrete_allocation = _resolve_setting("discrete_allocation", discrete_allocation)
allocation_method = _resolve_setting("allocation_method", allocation_method)
silence_warnings = _resolve_setting("silence_warnings", silence_warnings)
ignore_opt_errors = _resolve_setting("ignore_opt_errors", ignore_opt_errors)
ignore_errors = _resolve_setting("ignore_errors", ignore_errors)
kwargs = merge_dicts(pypfopt_cfg, kwargs)
if "cache" not in kwargs:
kwargs["cache"] = {}
if "used_arg_names" not in kwargs:
kwargs["used_arg_names"] = set()
try:
with WarningsFiltered(entries="ignore" if silence_warnings else None):
optimizer = kwargs["optimizer"] = resolve_pypfopt_optimizer(**kwargs)
if objectives is not None:
if not checks.is_iterable(objectives) or isinstance(objectives, str):
objectives = [objectives]
for objective in objectives:
if isinstance(objective, str):
import pypfopt.objective_functions
objective = getattr(pypfopt.objective_functions, objective)
objective_kwargs = resolve_pypfopt_func_kwargs(objective, **kwargs)
optimizer.add_objective(objective, **objective_kwargs)
if constraints is not None:
if not checks.is_iterable(constraints):
constraints = [constraints]
for constraint in constraints:
optimizer.add_constraint(constraint)
if sector_mapper is not None:
if sector_lower is None:
sector_lower = {}
if sector_upper is None:
sector_upper = {}
optimizer.add_sector_constraints(sector_mapper, sector_lower, sector_upper)
try:
if isinstance(target, str):
resolve_pypfopt_func_call(getattr(optimizer, target), **kwargs)
else:
if target_is_convex:
optimizer.convex_objective(
target,
weights_sum_to_one=weights_sum_to_one,
**resolve_pypfopt_func_kwargs(target, **kwargs),
)
else:
optimizer.nonconvex_objective(
target,
objective_args=tuple(resolve_pypfopt_func_kwargs(target, **kwargs).values()),
weights_sum_to_one=weights_sum_to_one,
constraints=target_constraints,
solver=target_solver,
initial_guess=target_initial_guess,
)
except (OptimizationError, SolverError, ValueError) as e:
if isinstance(e, ValueError) and "expected return exceeding the risk-free rate" not in str(e):
raise e
if ignore_opt_errors:
warn(str(e))
return {}
raise e
weights = kwargs["weights"] = resolve_pypfopt_func_call(optimizer.clean_weights, **kwargs)
if discrete_allocation:
from pypfopt.discrete_allocation import DiscreteAllocation
allocator = resolve_pypfopt_func_call(DiscreteAllocation, **kwargs)
return resolve_pypfopt_func_call(getattr(allocator, allocation_method), **kwargs)[0]
passed_arg_names = set(kwargs.keys())
passed_arg_names.remove("cache")
passed_arg_names.remove("used_arg_names")
passed_arg_names.remove("optimizer")
passed_arg_names.remove("weights")
unused_arg_names = passed_arg_names.difference(kwargs["used_arg_names"])
if len(unused_arg_names) > 0:
warn(f"Some arguments were not used: {unused_arg_names}")
if not discrete_allocation:
weights = {k: 1 if v >= 1 else v for k, v in weights.items()}
return dict(weights)
except Exception as e:
if ignore_errors:
return {}
raise e
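# The defaults resolved above come from `settings["pfopt"]["pypfopt"]`, so they can also be
# overridden globally instead of per call (an illustrative sketch):
# >>> from vectorbtpro._settings import settings
# >>> settings["pfopt"]["pypfopt"]["silence_warnings"] = True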
# ############# Riskfolio-Lib ############# #
def prepare_returns(
returns: tp.AnyArray2d,
nan_to_zero: bool = True,
dropna_rows: bool = True,
dropna_cols: bool = True,
dropna_any: bool = True,
) -> tp.Frame:
"""Prepare returns."""
returns = to_pd_array(returns)
if not isinstance(returns, pd.DataFrame):
raise ValueError("Returns must be a two-dimensional array")
if returns.size == 0:
return returns
if nan_to_zero or dropna_rows or dropna_cols or dropna_any:
returns = returns.replace([np.inf, -np.inf], np.nan)
if nan_to_zero:
returns = returns.fillna(0.0)
if dropna_rows or dropna_cols:
if nan_to_zero:
valid_mask = returns != 0
else:
valid_mask = ~returns.isnull()
if dropna_rows:
if nan_to_zero or not dropna_any:
returns = returns.loc[valid_mask.any(axis=1)]
if returns.size == 0:
return returns
if dropna_cols:
returns = returns.loc[:, valid_mask.any(axis=0)]
if returns.size == 0:
return returns
if not nan_to_zero and dropna_any:
returns = returns.dropna()
return returns
def resolve_riskfolio_func_kwargs(
riskfolio_func: tp.Callable,
unused_arg_names: tp.Optional[tp.Set[str]] = None,
func_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Kwargs:
"""Select keyword arguments belonging to `riskfolio_func`."""
func_arg_names = get_func_arg_names(riskfolio_func)
matched_kwargs = dict()
for k, v in kwargs.items():
if k in func_arg_names:
matched_kwargs[k] = v
if unused_arg_names is not None:
if k in unused_arg_names:
unused_arg_names.remove(k)
if func_kwargs is not None:
return merge_dicts(
select_pfopt_func_kwargs(riskfolio_func, matched_kwargs),
select_pfopt_func_kwargs(riskfolio_func, pfopt_func_dict(func_kwargs)),
)
return select_pfopt_func_kwargs(riskfolio_func, matched_kwargs)
def resolve_asset_classes(
asset_classes: tp.Union[None, tp.Frame, tp.Sequence],
columns: tp.Index,
col_indices: tp.Optional[tp.Sequence[int]] = None,
) -> tp.Frame:
"""Resolve asset classes for Riskfolio-Lib.
Supports the following formats:
* None: Takes columns where the bottom-most level is assumed to be assets
* Index: Each level in the index must be a different asset class set
* Nested dict: Each sub-dict must be a different asset class set
* Sequence of strings or ints: Matches them against level names in the columns.
If the columns have a single level, or some level names were not found, uses the sequence
directly as one asset class set named 'Class'.
* Sequence of dicts: Each dict becomes a row in the new DataFrame
* DataFrame where the first column is the asset list and the next columns are the
different asset class sets (this is the target format accepted by Riskfolio-Lib).
See an example [here](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints).
!!! note
If `asset_classes` is neither None nor a DataFrame, the bottom-most level in `columns`
gets renamed to 'Assets' and becomes the first column of the new DataFrame."""
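# An illustrative sketch (hypothetical inputs): with single-level columns, a plain list of
# class labels becomes a single class set named 'Class', prefixed by an 'Assets' column, e.g.
# >>> resolve_asset_classes(["C1", "C1", "C2", "C2"], pd.Index(["MSFT", "AMZN", "KO", "MA"]))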
if asset_classes is None:
asset_classes = columns.to_frame().reset_index(drop=True).iloc[:, ::-1]
asset_classes = asset_classes.rename(columns={asset_classes.columns[0]: "Assets"})
if not isinstance(asset_classes, pd.DataFrame):
if isinstance(asset_classes, dict):
asset_classes = pd.DataFrame(asset_classes)
elif isinstance(asset_classes, pd.Index):
asset_classes = asset_classes.to_frame().reset_index(drop=True)
elif checks.is_sequence(asset_classes) and isinstance(asset_classes[0], int):
asset_classes = select_levels(columns, asset_classes).to_frame().reset_index(drop=True)
elif checks.is_sequence(asset_classes) and isinstance(asset_classes[0], str):
if isinstance(columns, pd.MultiIndex) and set(asset_classes) <= set(columns.names):
asset_classes = select_levels(columns, asset_classes).to_frame().reset_index(drop=True)
else:
asset_classes = pd.Index(asset_classes, name="Class").to_frame().reset_index(drop=True)
else:
asset_classes = pd.DataFrame.from_records(asset_classes)
if isinstance(columns, pd.MultiIndex):
assets = columns.get_level_values(-1)
else:
assets = columns
if col_indices is not None and len(col_indices) > 0:
asset_classes = asset_classes.iloc[col_indices]
asset_classes.insert(loc=0, column="Assets", value=assets)
return asset_classes
def resolve_assets_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
"""Resolve asset constraints for Riskfolio-Lib.
Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints),
also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
Dicts don't have to specify all column names; the function will autofill any missing elements/columns."""
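# An illustrative sketch: a partial dict row such as
# {"Type": "Assets", "Position": "MSFT", "Sign": "<=", "Weight": 0.01}
# is expanded to the full column set, with missing fields filled with "" and
# "Disabled" defaulting to False.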
if not isinstance(constraints, pd.DataFrame):
if isinstance(constraints, dict):
constraints = pd.DataFrame(constraints)
else:
constraints = pd.DataFrame.from_records(constraints)
constraints.columns = constraints.columns.str.title()
new_constraints = pd.DataFrame(
columns=[
"Disabled",
"Type",
"Set",
"Position",
"Sign",
"Weight",
"Type Relative",
"Relative Set",
"Relative",
"Factor",
],
dtype=object,
)
for c in new_constraints.columns:
if c in constraints.columns:
new_constraints[c] = constraints[c]
new_constraints.fillna("", inplace=True)
new_constraints["Disabled"].replace("", False, inplace=True)
constraints = new_constraints
return constraints
def resolve_factors_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
"""Resolve factors constraints for Riskfolio-Lib.
Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_constraints),
also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
Dicts don't have to specify all column names; the function will autofill any missing elements/columns."""
if not isinstance(constraints, pd.DataFrame):
if isinstance(constraints, dict):
constraints = pd.DataFrame(constraints)
else:
constraints = pd.DataFrame.from_records(constraints)
constraints.columns = constraints.columns.str.title()
new_constraints = pd.DataFrame(
columns=[
"Disabled",
"Factor",
"Sign",
"Value",
"Relative Factor",
],
dtype=object,
)
for c in new_constraints.columns:
if c in constraints.columns:
new_constraints[c] = constraints[c]
new_constraints.fillna("", inplace=True)
new_constraints["Disabled"].replace("", False, inplace=True)
constraints = new_constraints
return constraints
def resolve_assets_views(views: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
"""Resolve asset views for Riskfolio-Lib.
Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_views),
also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
Dicts don't have to specify all column names; the function will autofill any missing elements/columns."""
if not isinstance(views, pd.DataFrame):
if isinstance(views, dict):
views = pd.DataFrame(views)
else:
views = pd.DataFrame.from_records(views)
views.columns = views.columns.str.title()
new_views = pd.DataFrame(
columns=[
"Disabled",
"Type",
"Set",
"Position",
"Sign",
"Return",
"Type Relative",
"Relative Set",
"Relative",
],
dtype=object,
)
for c in new_views.columns:
if c in views.columns:
new_views[c] = views[c]
new_views.fillna("", inplace=True)
new_views["Disabled"].replace("", False, inplace=True)
views = new_views
return views
def resolve_factors_views(views: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
"""Resolve factors views for Riskfolio-Lib.
Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_views),
also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
Dicts don't have to specify all column names; the function will autofill any missing elements/columns."""
if not isinstance(views, pd.DataFrame):
if isinstance(views, dict):
views = pd.DataFrame(views)
else:
views = pd.DataFrame.from_records(views)
views.columns = views.columns.str.title()
new_views = pd.DataFrame(
columns=[
"Disabled",
"Factor",
"Sign",
"Value",
"Relative Factor",
],
dtype=object,
)
for c in new_views.columns:
if c in views.columns:
new_views[c] = views[c]
new_views.fillna("", inplace=True)
new_views["Disabled"].replace("", False, inplace=True)
views = new_views
return views
def resolve_hrp_constraints(constraints: tp.Union[tp.Frame, tp.Sequence]) -> tp.Frame:
"""Resolve HRP constraints for Riskfolio-Lib.
Apart from the [target format](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.hrp_constraints),
also accepts a sequence of dicts such that each dict becomes a row in a new DataFrame.
Dicts don't have to specify all column names; the function will autofill any missing elements/columns."""
if not isinstance(constraints, pd.DataFrame):
if isinstance(constraints, dict):
constraints = pd.DataFrame(constraints)
else:
constraints = pd.DataFrame.from_records(constraints)
constraints.columns = constraints.columns.str.title()
new_constraints = pd.DataFrame(
columns=[
"Disabled",
"Type",
"Set",
"Position",
"Sign",
"Weight",
],
dtype=object,
)
for c in new_constraints.columns:
if c in constraints.columns:
new_constraints[c] = constraints[c]
new_constraints.fillna("", inplace=True)
new_constraints["Disabled"].replace("", False, inplace=True)
constraints = new_constraints
return constraints
def riskfolio_optimize(
returns: tp.AnyArray2d,
nan_to_zero: tp.Optional[bool] = None,
dropna_rows: tp.Optional[bool] = None,
dropna_cols: tp.Optional[bool] = None,
dropna_any: tp.Optional[bool] = None,
factors: tp.Optional[tp.AnyArray2d] = None,
port: tp.Optional[RPortfolioT] = None,
port_cls: tp.Union[None, str, tp.Type] = None,
opt_method: tp.Union[None, str, tp.Callable] = None,
stats_methods: tp.Optional[tp.Sequence[str]] = None,
model: tp.Optional[str] = None,
asset_classes: tp.Union[None, tp.Frame, tp.Sequence] = None,
constraints_method: tp.Optional[str] = None,
constraints: tp.Union[None, tp.Frame, tp.Sequence] = None,
views_method: tp.Optional[str] = None,
views: tp.Union[None, tp.Frame, tp.Sequence] = None,
solvers: tp.Optional[tp.Sequence[str]] = None,
sol_params: tp.KwargsLike = None,
freq: tp.Optional[tp.FrequencyLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
pre_opt: tp.Optional[bool] = None,
pre_opt_kwargs: tp.KwargsLike = None,
pre_opt_as_w: tp.Optional[bool] = None,
func_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
return_port: tp.Optional[bool] = None,
ignore_errors: tp.Optional[bool] = None,
**kwargs,
) -> tp.Union[tp.Dict[str, float], tp.Tuple[tp.Dict[str, float], RPortfolioT]]:
"""Get allocation using Riskfolio-Lib.
Args:
returns (array_like): A dataframe that contains the returns of the assets.
nan_to_zero (bool): Whether to convert NaN values to zero.
dropna_rows (bool): Whether to drop rows with all NaN/zero values.
Gets applied only if `nan_to_zero` is True or `dropna_any` is False.
dropna_cols (bool): Whether to drop columns with all NaN/zero values.
dropna_any (bool): Whether to drop any NaN values.
Gets applied only if `nan_to_zero` is False.
factors (array_like): A dataframe that contains the factors.
port (Portfolio or HCPortfolio): Already initialized portfolio.
port_cls (str or type): Portfolio class.
Supports the following values:
* None: Uses `Portfolio`
* 'hc' or 'hcportfolio' (case-insensitive): Uses `HCPortfolio`
* Other string: Uses attribute of `riskfolio`
* Class: Uses a custom class
opt_method (str or callable): Optimization method.
Supports the following values:
* None or 'optimization': Uses `port.optimization` (where `port` is a portfolio instance)
* 'wc' or 'wc_optimization': Uses `port.wc_optimization`
* 'rp' or 'rp_optimization': Uses `port.rp_optimization`
* 'rrp' or 'rrp_optimization': Uses `port.rrp_optimization`
* 'owa' or 'owa_optimization': Uses `port.owa_optimization`
* String: Uses attribute of `port`
* Callable: Uses a custom optimization function
stats_methods (str or sequence of str): Sequence of stats methods to call before optimization.
If None, tries to automatically populate the sequence using `opt_method` and `model`.
For example, calls `port.assets_stats` if `model="Classic"` is used.
Also, if `func_kwargs` is not empty, adds all functions whose name ends with '_stats'.
model (str): The model used to optimize the portfolio.
asset_classes (any): Asset classes matrix.
See `resolve_asset_classes` for possible formats.
constraints_method (str): Constraints method.
Supports the following values:
* 'assets' or 'assets_constraints': [assets constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_constraints)
* 'factors' or 'factors_constraints': [factors constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_constraints)
* 'hrp' or 'hrp_constraints': [HRP constraints](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.hrp_constraints)
If None and the class `Portfolio` is used, will use factors constraints if `factors_stats` is used,
otherwise assets constraints. If the class `HCPortfolio` is used, will use HRP constraints.
constraints (any): Constraints matrix.
See `resolve_assets_constraints` for possible formats of assets constraints,
`resolve_factors_constraints` for possible formats of factors constraints, and
`resolve_hrp_constraints` for possible formats of HRP constraints.
views_method (str): Views method.
Supports the following values:
* 'assets' or 'assets_views': [assets views](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.assets_views)
* 'factors' or 'factors_views': [factors views](https://riskfolio-lib.readthedocs.io/en/latest/constraints.html#ConstraintsFunctions.factors_views)
If None, will use factors views if `blfactors_stats` is used, otherwise assets views.
views (any): Views matrix.
See `resolve_assets_views` for possible formats of assets views and
`resolve_factors_views` for possible formats of factors views.
solvers (list of str): Solvers.
sol_params (dict): Solver parameters.
freq (frequency_like): Frequency to be used to compute the annualization factor.
Make sure to provide it when using views.
year_freq (frequency_like): Year frequency to be used to compute the annualization factor.
Make sure to provide it when using views.
pre_opt (bool): Whether to pre-optimize the portfolio with `pre_opt_kwargs`.
pre_opt_kwargs (dict): Call `riskfolio_optimize` with these keyword arguments
and use the returned portfolio for further optimization.
pre_opt_as_w (bool): Whether to use the weights as `w` from the pre-optimization step.
func_kwargs (dict): Further keyword arguments by function.
Can be used to override any arguments from `kwargs` matched with the function,
or to add more arguments. Will be wrapped with `pfopt_func_dict` and passed to
`select_pfopt_func_kwargs` when calling each Riskfolio-Lib's function.
silence_warnings (bool): Whether to silence all warnings.
return_port (bool): Whether to also return the portfolio.
ignore_errors (bool): Whether to ignore any errors, even those caused by the user.
**kwargs: Keyword arguments that will be passed to any Riskfolio-Lib's function
that needs them (i.e., lists any of them in its signature).
For defaults, see `riskfolio` under `vectorbtpro._settings.pfopt`.
Usage:
* Classic Mean Risk Optimization:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(["MSFT", "AMZN", "KO", "MA"])
>>> returns = data.close.vbt.to_returns()
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> vbt.riskfolio_optimize(
... returns,
... method_mu='hist', method_cov='hist', d=0.94, # assets_stats
... model='Classic', rm='MV', obj='Sharpe', hist=True, rf=0, l=0 # optimization
... )
{'MSFT': 0.26297126323056036,
'AMZN': 0.13984467450137006,
'KO': 0.35870315943426767,
'MA': 0.238480902833802}
```
* The same by splitting arguments:
```pycon
>>> vbt.riskfolio_optimize(
... returns,
... func_kwargs=dict(
... assets_stats=dict(method_mu='hist', method_cov='hist', d=0.94),
... optimization=dict(model='Classic', rm='MV', obj='Sharpe', hist=True, rf=0, l=0)
... )
... )
{'MSFT': 0.26297126323056036,
'AMZN': 0.13984467450137006,
'KO': 0.35870315943426767,
'MA': 0.238480902833802}
```
* Asset constraints:
```pycon
>>> vbt.riskfolio_optimize(
... returns,
... constraints=[
... {
... "Type": "Assets",
... "Position": "MSFT",
... "Sign": "<=",
... "Weight": 0.01
... }
... ]
... )
{'MSFT': 0.009999990814976588,
'AMZN': 0.19788481506569947,
'KO': 0.4553600308839969,
'MA': 0.336755163235327}
```
* Asset class constraints:
```pycon
>>> vbt.riskfolio_optimize(
... returns,
... asset_classes=["C1", "C1", "C2", "C2"],
... constraints=[
... {
... "Type": "Classes",
... "Set": "Class",
... "Position": "C1",
... "Sign": "<=",
... "Weight": 0.1
... }
... ]
... )
{'MSFT': 0.03501297245802569,
'AMZN': 0.06498702655063979,
'KO': 0.4756624658301967,
'MA': 0.4243375351611379}
```
* Hierarchical Risk Parity (HRP) Portfolio Optimization:
```pycon
>>> vbt.riskfolio_optimize(
... returns,
... port_cls="HCPortfolio",
... model='HRP',
... codependence='pearson',
... rm='MV',
... rf=0,
... linkage='single',
... max_k=10,
... leaf_order=True
... )
{'MSFT': 0.19091632057853536,
'AMZN': 0.11069893826556164,
'KO': 0.28589872132122485,
'MA': 0.41248601983467814}
```
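* Returning the Riskfolio-Lib portfolio object along with the weights for further
inspection (a minimal sketch using `return_port`):
```pycon
>>> weights, port = vbt.riskfolio_optimize(returns, return_port=True)
```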
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("riskfolio")
import riskfolio as rp
from vectorbtpro._settings import settings
riskfolio_cfg = dict(settings["pfopt"]["riskfolio"])
def _resolve_setting(k, v):
setting = riskfolio_cfg.pop(k)
if v is None:
return setting
return v
nan_to_zero = _resolve_setting("nan_to_zero", nan_to_zero)
dropna_rows = _resolve_setting("dropna_rows", dropna_rows)
dropna_cols = _resolve_setting("dropna_cols", dropna_cols)
dropna_any = _resolve_setting("dropna_any", dropna_any)
factors = _resolve_setting("factors", factors)
port = _resolve_setting("port", port)
port_cls = _resolve_setting("port_cls", port_cls)
opt_method = _resolve_setting("opt_method", opt_method)
stats_methods = _resolve_setting("stats_methods", stats_methods)
model = _resolve_setting("model", model)
asset_classes = _resolve_setting("asset_classes", asset_classes)
constraints_method = _resolve_setting("constraints_method", constraints_method)
constraints = _resolve_setting("constraints", constraints)
views_method = _resolve_setting("views_method", views_method)
views = _resolve_setting("views", views)
solvers = _resolve_setting("solvers", solvers)
sol_params = _resolve_setting("sol_params", sol_params)
freq = _resolve_setting("freq", freq)
year_freq = _resolve_setting("year_freq", year_freq)
pre_opt = _resolve_setting("pre_opt", pre_opt)
pre_opt_kwargs = merge_dicts(riskfolio_cfg.pop("pre_opt_kwargs"), pre_opt_kwargs)
pre_opt_as_w = _resolve_setting("pre_opt_as_w", pre_opt_as_w)
func_kwargs = merge_dicts(riskfolio_cfg.pop("func_kwargs"), func_kwargs)
silence_warnings = _resolve_setting("silence_warnings", silence_warnings)
return_port = _resolve_setting("return_port", return_port)
ignore_errors = _resolve_setting("ignore_errors", ignore_errors)
kwargs = merge_dicts(riskfolio_cfg, kwargs)
if pre_opt_kwargs is None:
pre_opt_kwargs = {}
if func_kwargs is None:
func_kwargs = {}
func_kwargs = pfopt_func_dict(func_kwargs)
unused_arg_names = set(kwargs.keys())
try:
with WarningsFiltered(entries="ignore" if silence_warnings else None):
# Prepare returns
new_returns = prepare_returns(
returns,
nan_to_zero=nan_to_zero,
dropna_rows=dropna_rows,
dropna_cols=dropna_cols,
dropna_any=dropna_any,
)
col_indices = [i for i, c in enumerate(returns.columns) if c in new_returns.columns]
returns = new_returns
if returns.size == 0:
return {}
# Pre-optimize
if pre_opt:
w, port = riskfolio_optimize(returns, port=port, return_port=True, **pre_opt_kwargs)
if pre_opt_as_w:
w = pd.DataFrame.from_records([w]).T.rename(columns={0: "weights"})
kwargs["w"] = w
unused_arg_names.add("w")
# Build portfolio
if port_cls is None:
port_cls = rp.Portfolio
elif isinstance(port_cls, str) and port_cls.lower() in ("hc", "hcportfolio"):
port_cls = rp.HCPortfolio
elif isinstance(port_cls, str):
port_cls = getattr(rp, port_cls)
matched_kwargs = resolve_riskfolio_func_kwargs(
port_cls,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
if port is None:
port = port_cls(returns, **matched_kwargs)
else:
for k, v in matched_kwargs.items():
setattr(port, k, v)
if solvers is not None:
port.solvers = list(solvers)
if sol_params is not None:
port.sol_params = dict(sol_params)
if factors is not None:
factors = to_pd_array(factors).dropna()
port.factors = factors
# Resolve optimization and stats methods
if opt_method is None:
if len(func_kwargs) > 0:
for name_or_func in func_kwargs:
if isinstance(name_or_func, str):
if name_or_func.endswith("optimization"):
if opt_method is not None:
raise ValueError("Function keyword arguments list multiple optimization methods")
opt_method = name_or_func
if opt_method is None:
opt_method = "optimization"
if stats_methods is None:
if len(func_kwargs) > 0:
for name_or_func in func_kwargs:
if isinstance(name_or_func, str):
if name_or_func.endswith("_stats"):
if stats_methods is None:
stats_methods = []
stats_methods.append(name_or_func)
if isinstance(port, rp.Portfolio):
if isinstance(opt_method, str) and opt_method.lower() == "optimization":
opt_func = port.optimization
if model is None:
opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs)
model = opt_func_kwargs.get("model", "Classic")
if model.lower() == "classic":
model = "Classic"
if stats_methods is None:
stats_methods = ["assets_stats"]
elif model.lower() == "fm":
model = "FM"
if stats_methods is None:
stats_methods = ["assets_stats", "factors_stats"]
elif model.lower() == "bl":
model = "BL"
if stats_methods is None:
stats_methods = ["assets_stats", "blacklitterman_stats"]
elif model.lower() in ("bl_fm", "blfm"):
model = "BL_FM"
if stats_methods is None:
stats_methods = ["assets_stats", "factors_stats", "blfactors_stats"]
elif isinstance(opt_method, str) and opt_method.lower() in ("wc", "wc_optimization"):
opt_func = port.wc_optimization
if stats_methods is None:
stats_methods = ["assets_stats", "wc_stats"]
elif isinstance(opt_method, str) and opt_method.lower() in ("rp", "rp_optimization"):
opt_func = port.rp_optimization
if model is None:
opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs)
model = opt_func_kwargs.get("model", "Classic")
if model.lower() == "classic":
model = "Classic"
if stats_methods is None:
stats_methods = ["assets_stats"]
elif model.lower() == "fm":
model = "FM"
if stats_methods is None:
stats_methods = ["assets_stats", "factors_stats"]
elif isinstance(opt_method, str) and opt_method.lower() in ("rrp", "rrp_optimization"):
opt_func = port.rrp_optimization
if model is None:
opt_func_kwargs = select_pfopt_func_kwargs(opt_func, func_kwargs)
model = opt_func_kwargs.get("model", "Classic")
if model.lower() == "classic":
model = "Classic"
if stats_methods is None:
stats_methods = ["assets_stats"]
elif model.lower() == "fm":
model = "FM"
if stats_methods is None:
stats_methods = ["assets_stats", "factors_stats"]
elif isinstance(opt_method, str) and opt_method.lower() in ("owa", "owa_optimization"):
opt_func = port.owa_optimization
if stats_methods is None:
stats_methods = ["assets_stats"]
elif isinstance(opt_method, str):
opt_func = getattr(port, opt_method)
else:
opt_func = opt_method
else:
if isinstance(opt_method, str):
opt_func = getattr(port, opt_method)
else:
opt_func = opt_method
if model is not None:
kwargs["model"] = model
unused_arg_names.add("model")
if stats_methods is None:
stats_methods = []
# Apply constraints
if constraints is not None:
if constraints_method is None:
if isinstance(port, rp.Portfolio):
if "factors_stats" in stats_methods:
constraints_method = "factors"
else:
constraints_method = "assets"
elif isinstance(port, rp.HCPortfolio):
constraints_method = "hrp"
else:
raise ValueError("Constraints method is required")
if constraints_method.lower() in ("assets", "assets_constraints"):
asset_classes = resolve_asset_classes(asset_classes, returns.columns, col_indices)
kwargs["asset_classes"] = asset_classes
unused_arg_names.add("asset_classes")
constraints = resolve_assets_constraints(constraints)
kwargs["constraints"] = constraints
unused_arg_names.add("constraints")
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.assets_constraints,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
port.ainequality, port.binequality = warn_stdout(rp.assets_constraints)(**matched_kwargs)
elif constraints_method.lower() in ("factors", "factors_constraints"):
if "loadings" not in kwargs:
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.loadings_matrix,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
if "X" not in matched_kwargs:
matched_kwargs["X"] = port.factors
if "Y" not in matched_kwargs:
matched_kwargs["Y"] = port.returns
loadings = warn_stdout(rp.loadings_matrix)(**matched_kwargs)
kwargs["loadings"] = loadings
unused_arg_names.add("loadings")
constraints = resolve_factors_constraints(constraints)
kwargs["constraints"] = constraints
unused_arg_names.add("constraints")
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.factors_constraints,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
port.ainequality, port.binequality = warn_stdout(rp.factors_constraints)(**matched_kwargs)
elif constraints_method.lower() in ("hrp", "hrp_constraints"):
asset_classes = resolve_asset_classes(asset_classes, returns.columns, col_indices)
kwargs["asset_classes"] = asset_classes
unused_arg_names.add("asset_classes")
constraints = resolve_hrp_constraints(constraints)
kwargs["constraints"] = constraints
unused_arg_names.add("constraints")
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.hrp_constraints,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
port.w_max, port.w_min = warn_stdout(rp.hrp_constraints)(**matched_kwargs)
else:
raise ValueError(f"Constraints method '{constraints_method}' is not supported")
# Resolve views
if views is not None:
if views_method is None:
if "blfactors_stats" in stats_methods:
views_method = "factors"
else:
views_method = "assets"
if views_method.lower() in ("assets", "assets_views"):
asset_classes = resolve_asset_classes(asset_classes, returns.columns, col_indices)
kwargs["asset_classes"] = asset_classes
unused_arg_names.add("asset_classes")
views = resolve_assets_views(views)
kwargs["views"] = views
unused_arg_names.add("views")
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.assets_views,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
P, Q = warn_stdout(rp.assets_views)(**matched_kwargs)
ann_factor = ReturnsAccessor.get_ann_factor(year_freq=year_freq, freq=freq)
if ann_factor is not None:
Q /= ann_factor
else:
warn(f"Set frequency and year frequency to adjust expected returns")
kwargs["P"] = P
unused_arg_names.add("P")
kwargs["Q"] = Q
unused_arg_names.add("Q")
elif views_method.lower() in ("factors", "factors_views"):
if "loadings" not in kwargs:
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.loadings_matrix,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
if "X" not in matched_kwargs:
matched_kwargs["X"] = port.factors
if "Y" not in matched_kwargs:
matched_kwargs["Y"] = port.returns
loadings = warn_stdout(rp.loadings_matrix)(**matched_kwargs)
kwargs["loadings"] = loadings
unused_arg_names.add("loadings")
if "B" not in kwargs:
kwargs["B"] = kwargs["loadings"]
unused_arg_names.add("B")
views = resolve_factors_views(views)
kwargs["views"] = views
unused_arg_names.add("views")
matched_kwargs = resolve_riskfolio_func_kwargs(
rp.factors_views,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
P_f, Q_f = warn_stdout(rp.factors_views)(**matched_kwargs)
ann_factor = ReturnsAccessor.get_ann_factor(year_freq=year_freq, freq=freq)
if ann_factor is not None:
Q_f /= ann_factor
else:
warn(f"Set frequency and year frequency to adjust expected returns")
kwargs["P_f"] = P_f
unused_arg_names.add("P_f")
kwargs["Q_f"] = Q_f
unused_arg_names.add("Q_f")
else:
raise ValueError(f"Views method '{constraints_method}' is not supported")
# Run stats
for stats_method in stats_methods:
stats_func = getattr(port, stats_method)
matched_kwargs = resolve_riskfolio_func_kwargs(
stats_func,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
warn_stdout(stats_func)(**matched_kwargs)
# Run optimization
matched_kwargs = resolve_riskfolio_func_kwargs(
opt_func,
unused_arg_names=unused_arg_names,
func_kwargs=func_kwargs,
**kwargs,
)
weights = warn_stdout(opt_func)(**matched_kwargs)
# Post-process weights
if len(unused_arg_names) > 0:
warn(f"Some arguments were not used: {unused_arg_names}")
if weights is None:
weights = {}
if isinstance(weights, pd.DataFrame):
if "weights" not in weights.columns:
raise ValueError("Weights column wasn't returned")
weights = weights["weights"]
if return_port:
return dict(weights), port
return dict(weights)
except Exception as e:
if ignore_errors:
return {}
raise e
# ############# PortfolioOptimizer ############# #
PortfolioOptimizerT = tp.TypeVar("PortfolioOptimizerT", bound="PortfolioOptimizer")
class PortfolioOptimizer(Analyzable):
"""Class that exposes methods for generating allocations."""
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[PortfolioOptimizerT],
*objs: tp.MaybeTuple[PortfolioOptimizerT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Stack multiple `PortfolioOptimizer` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PortfolioOptimizer):
raise TypeError("Each object to be merged must be an instance of PortfolioOptimizer")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(
*[obj.wrapper for obj in objs],
stack_columns=False,
**wrapper_kwargs,
)
if "alloc_records" not in kwargs:
alloc_records_type = None
for obj in objs:
if alloc_records_type is None:
alloc_records_type = type(obj.alloc_records)
elif not isinstance(obj.alloc_records, alloc_records_type):
raise TypeError("Objects to be merged must have the same type for alloc_records")
kwargs["alloc_records"] = alloc_records_type.row_stack(
*[obj.alloc_records for obj in objs],
wrapper_kwargs=wrapper_kwargs,
)
if "allocations" not in kwargs:
record_indices = type(kwargs["alloc_records"]).get_row_stack_record_indices(
*[obj.alloc_records for obj in objs],
wrapper=kwargs["alloc_records"].wrapper,
)
kwargs["allocations"] = row_stack_arrays([obj._allocations for obj in objs])[record_indices]
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[PortfolioOptimizerT],
*objs: tp.MaybeTuple[PortfolioOptimizerT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Stack multiple `PortfolioOptimizer` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, PortfolioOptimizer):
raise TypeError("Each object to be merged must be an instance of PortfolioOptimizer")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
union_index=False,
**wrapper_kwargs,
)
if "alloc_records" not in kwargs:
alloc_records_type = None
for obj in objs:
if alloc_records_type is None:
alloc_records_type = type(obj.alloc_records)
elif not isinstance(obj.alloc_records, alloc_records_type):
raise TypeError("Objects to be merged must have the same type for alloc_records")
kwargs["alloc_records"] = alloc_records_type.column_stack(
*[obj.alloc_records for obj in objs],
wrapper_kwargs=wrapper_kwargs,
)
if "allocations" not in kwargs:
record_indices = type(kwargs["alloc_records"]).get_column_stack_record_indices(
*[obj.alloc_records for obj in objs],
wrapper=kwargs["alloc_records"].wrapper,
)
kwargs["allocations"] = row_stack_arrays([obj._allocations for obj in objs])[record_indices]
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
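# An illustrative sketch (hypothetical `pfo1` and `pfo2` sharing the same index):
# >>> pfo = vbt.PortfolioOptimizer.column_stack((pfo1, pfo2))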
def __init__(
self,
wrapper: ArrayWrapper,
alloc_records: tp.Union[AllocRanges, AllocPoints],
allocations: tp.Array2d,
**kwargs,
) -> None:
Analyzable.__init__(
self,
wrapper,
alloc_records=alloc_records,
allocations=allocations,
**kwargs,
)
self._alloc_records = alloc_records
self._allocations = allocations
# Only slices of rows can be selected
self._range_only_select = True
def indexing_func(
self: PortfolioOptimizerT,
*args,
wrapper_meta: tp.DictLike = None,
alloc_wrapper_meta: tp.DictLike = None,
alloc_records_meta: tp.DictLike = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Perform indexing on `PortfolioOptimizer`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
if alloc_records_meta is None:
alloc_records_meta = self.alloc_records.indexing_func_meta(
*args,
wrapper_meta=alloc_wrapper_meta,
**kwargs,
)
new_alloc_records = self.alloc_records.indexing_func(
*args,
records_meta=alloc_records_meta,
**kwargs,
)
new_allocations = to_2d_array(self._allocations)[alloc_records_meta["new_indices"]]
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
alloc_records=new_alloc_records,
allocations=new_allocations,
)
def resample(self: PortfolioOptimizerT, *args, **kwargs) -> PortfolioOptimizerT:
"""Perform resampling on `PortfolioOptimizer`."""
new_wrapper = self.wrapper.resample(*args, **kwargs)
new_alloc_records = self.alloc_records.resample(*args, **kwargs)
return self.replace(
wrapper=new_wrapper,
alloc_records=new_alloc_records,
)
# ############# Class methods ############# #
@classmethod
def run_allocation_group(
cls,
wrapper: ArrayWrapper,
group_configs: tp.List[dict],
group_index: tp.Index,
group_idx: int,
pre_group_func: tp.Optional[tp.Callable] = None,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
"""Run an allocation group."""
group_config = dict(group_configs[group_idx])
if pre_group_func is not None:
pre_group_func(group_config)
allocate_func = group_config.pop("allocate_func")
every = group_config.pop("every")
normalize_every = group_config.pop("normalize_every")
at_time = group_config.pop("at_time")
start = group_config.pop("start")
end = group_config.pop("end")
exact_start = group_config.pop("exact_start")
on = group_config.pop("on")
add_delta = group_config.pop("add_delta")
kind = group_config.pop("kind")
indexer_method = group_config.pop("indexer_method")
indexer_tolerance = group_config.pop("indexer_tolerance")
skip_not_found = group_config.pop("skip_not_found")
index_points = group_config.pop("index_points")
rescale_to = group_config.pop("rescale_to")
jitted_loop = group_config.pop("jitted_loop")
jitted = group_config.pop("jitted")
chunked = group_config.pop("chunked")
template_context = group_config.pop("template_context")
execute_kwargs = group_config.pop("execute_kwargs")
args = group_config.pop("args")
kwargs = group_config
template_context = merge_dicts(
dict(
group_configs=group_configs,
group_index=group_index,
group_idx=group_idx,
wrapper=wrapper,
allocate_func=allocate_func,
every=every,
normalize_every=normalize_every,
at_time=at_time,
start=start,
end=end,
exact_start=exact_start,
on=on,
add_delta=add_delta,
kind=kind,
indexer_method=indexer_method,
indexer_tolerance=indexer_tolerance,
skip_not_found=skip_not_found,
index_points=index_points,
rescale_to=rescale_to,
jitted_loop=jitted_loop,
jitted=jitted,
chunked=chunked,
execute_kwargs=execute_kwargs,
args=args,
kwargs=kwargs,
),
template_context,
)
if index_points is None:
get_index_points_kwargs = substitute_templates(
dict(
every=every,
normalize_every=normalize_every,
at_time=at_time,
start=start,
end=end,
exact_start=exact_start,
on=on,
add_delta=add_delta,
kind=kind,
indexer_method=indexer_method,
indexer_tolerance=indexer_tolerance,
skip_not_found=skip_not_found,
),
template_context,
eval_id="get_index_points_defaults",
strict=True,
)
index_points = wrapper.get_index_points(**get_index_points_kwargs)
template_context = merge_dicts(
template_context,
get_index_points_kwargs,
dict(index_points=index_points),
)
else:
index_points = substitute_templates(
index_points,
template_context,
eval_id="index_points",
strict=True,
)
index_points = to_1d_array(index_points)
template_context = merge_dicts(template_context, dict(index_points=index_points))
if jitted_loop:
allocate_func = substitute_templates(
allocate_func,
template_context,
eval_id="allocate_func",
strict=True,
)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
func = jit_reg.resolve_option(nb.allocate_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
_allocations = func(len(wrapper.columns), index_points, allocate_func, *args, **kwargs)
else:
tasks = []
keys = []
for i in range(len(index_points)):
_template_context = merge_dicts(
dict(
i=i,
index_point=index_points[i],
),
template_context,
)
_allocate_func = substitute_templates(
allocate_func,
_template_context,
eval_id="allocate_func",
strict=True,
)
_args = substitute_templates(args, _template_context, eval_id="args")
_kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs")
tasks.append(Task(_allocate_func, *_args, **_kwargs))
if isinstance(wrapper.index, pd.DatetimeIndex):
keys.append(dt.readable_datetime(wrapper.index[index_points[i]], freq=wrapper.freq))
else:
keys.append(str(wrapper.index[index_points[i]]))
results = execute(tasks, keys=keys, **execute_kwargs)
_allocations = pd.DataFrame(results, columns=wrapper.columns)
if isinstance(_allocations.columns, pd.RangeIndex):
_allocations = _allocations.values
else:
_allocations = _allocations[list(wrapper.columns)].values
if rescale_to is not None:
_allocations = nb.rescale_allocations_nb(_allocations, rescale_to)
return nb.prepare_alloc_points_nb(index_points, _allocations, group_idx)
@classmethod
def from_allocate_func(
cls: tp.Type[PortfolioOptimizerT],
wrapper: ArrayWrapper,
allocate_func: tp.Callable,
*args,
every: tp.Union[None, tp.FrequencyLike, Param] = point_idxr_defaults["every"],
normalize_every: tp.Union[bool, Param] = point_idxr_defaults["normalize_every"],
at_time: tp.Union[None, tp.TimeLike, Param] = point_idxr_defaults["at_time"],
start: tp.Union[None, int, tp.DatetimeLike, Param] = point_idxr_defaults["start"],
end: tp.Union[None, int, tp.DatetimeLike, Param] = point_idxr_defaults["end"],
exact_start: tp.Union[bool, Param] = point_idxr_defaults["exact_start"],
on: tp.Union[None, int, tp.DatetimeLike, tp.IndexLike, Param] = point_idxr_defaults["on"],
add_delta: tp.Union[None, tp.FrequencyLike, Param] = point_idxr_defaults["add_delta"],
kind: tp.Union[None, str, Param] = point_idxr_defaults["kind"],
indexer_method: tp.Union[None, str, Param] = point_idxr_defaults["indexer_method"],
indexer_tolerance: tp.Union[None, str, Param] = point_idxr_defaults["indexer_tolerance"],
skip_not_found: tp.Union[bool, Param] = point_idxr_defaults["skip_not_found"],
index_points: tp.Union[None, tp.MaybeSequence[int], Param] = None,
rescale_to: tp.Union[None, tp.Tuple[float, float], Param] = None,
parameterizer: tp.Optional[tp.MaybeType[Parameterizer]] = None,
param_search_kwargs: tp.KwargsLike = None,
name_tuple_to_str: tp.Union[None, bool, tp.Callable] = None,
group_configs: tp.Union[None, tp.Dict[tp.Hashable, tp.Kwargs], tp.Sequence[tp.Kwargs]] = None,
pre_group_func: tp.Optional[tp.Callable] = None,
jitted_loop: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
template_context: tp.KwargsLike = None,
group_execute_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
random_subset: tp.Optional[int] = None,
clean_index_kwargs: tp.KwargsLike = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Generate allocations from an allocation function.
Generates date points and allocates at those points.
Similar to `PortfolioOptimizer.from_optimize_func`, but generates points using
`vectorbtpro.base.wrapping.ArrayWrapper.get_index_points` and makes each point available
as `index_point` in the context.
Templates can use the following variables:
* `i`: Allocation step
* `index_point`: Allocation index
If `jitted_loop` is True, see `vectorbtpro.portfolio.pfopt.nb.allocate_meta_nb`.
Also, in contrast to `PortfolioOptimizer.from_optimize_func`, creates records of type
`vectorbtpro.portfolio.pfopt.records.AllocPoints`.
Usage:
* Allocate uniformly every day:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... ["MSFT", "AMZN", "AAPL"],
... start="2010-01-01",
... end="2020-01-01"
... )
>>> close = data.get("Close")
>>> def uniform_allocate_func(n_cols):
... return np.full(n_cols, 1 / n_cols)
>>> pfo = vbt.PortfolioOptimizer.from_allocate_func(
... close.vbt.wrapper,
... uniform_allocate_func,
... close.shape[1]
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2010-01-04 00:00:00-05:00 0.333333 0.333333 0.333333
2010-01-05 00:00:00-05:00 0.333333 0.333333 0.333333
2010-01-06 00:00:00-05:00 0.333333 0.333333 0.333333
2010-01-07 00:00:00-05:00 0.333333 0.333333 0.333333
2010-01-08 00:00:00-05:00 0.333333 0.333333 0.333333
... ... ... ...
2019-12-24 00:00:00-05:00 0.333333 0.333333 0.333333
2019-12-26 00:00:00-05:00 0.333333 0.333333 0.333333
2019-12-27 00:00:00-05:00 0.333333 0.333333 0.333333
2019-12-30 00:00:00-05:00 0.333333 0.333333 0.333333
2019-12-31 00:00:00-05:00 0.333333 0.333333 0.333333
[2516 rows x 3 columns]
```
* Allocate randomly every first date of the year:
```pycon
>>> def random_allocate_func(n_cols):
... weights = np.random.uniform(size=n_cols)
... return weights / weights.sum()
>>> pfo = vbt.PortfolioOptimizer.from_allocate_func(
... close.vbt.wrapper,
... random_allocate_func,
... close.shape[1],
... every="AS-JAN"
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2011-01-03 00:00:00+00:00 0.160335 0.122434 0.717231
2012-01-03 00:00:00+00:00 0.071386 0.469564 0.459051
2013-01-02 00:00:00+00:00 0.125853 0.168480 0.705668
2014-01-02 00:00:00+00:00 0.391565 0.169205 0.439231
2015-01-02 00:00:00+00:00 0.115075 0.602844 0.282081
2016-01-04 00:00:00+00:00 0.244070 0.046547 0.709383
2017-01-03 00:00:00+00:00 0.316065 0.335000 0.348935
2018-01-02 00:00:00+00:00 0.422142 0.252154 0.325704
2019-01-02 00:00:00+00:00 0.368748 0.195147 0.436106
```
* Specify index points manually:
```pycon
>>> pfo = vbt.PortfolioOptimizer.from_allocate_func(
... close.vbt.wrapper,
... random_allocate_func,
... close.shape[1],
... index_points=[0, 30, 60]
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2010-01-04 00:00:00+00:00 0.257878 0.308287 0.433835
2010-02-17 00:00:00+00:00 0.090927 0.471980 0.437094
2010-03-31 00:00:00+00:00 0.395855 0.148516 0.455629
```
* Specify allocations manually:
```pycon
>>> def manual_allocate_func(weights):
... return weights
>>> pfo = vbt.PortfolioOptimizer.from_allocate_func(
... close.vbt.wrapper,
... manual_allocate_func,
... vbt.RepEval("weights[i]", context=dict(weights=[
... [1, 0, 0],
... [0, 1, 0],
... [0, 0, 1]
... ])),
... index_points=[0, 30, 60]
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2010-01-04 00:00:00+00:00 1 0 0
2010-02-17 00:00:00+00:00 0 1 0
2010-03-31 00:00:00+00:00 0 0 1
```
* Use Numba-compiled loop:
```pycon
>>> @njit
... def random_allocate_func_nb(i, idx, n_cols):
... weights = np.random.uniform(0, 1, n_cols)
... return weights / weights.sum()
>>> pfo = vbt.PortfolioOptimizer.from_allocate_func(
... close.vbt.wrapper,
... random_allocate_func_nb,
... close.shape[1],
... index_points=[0, 30, 60],
... jitted_loop=True
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2010-01-04 00:00:00+00:00 0.231925 0.351085 0.416990
2010-02-17 00:00:00+00:00 0.163050 0.070292 0.766658
2010-03-31 00:00:00+00:00 0.497465 0.500215 0.002319
```
!!! hint
There is no big reason to use the Numba-compiled loop unless you have to rebalance
many thousands of times. Usually, a regular Python loop combined with a
Numba-compiled allocation function should suffice.
"""
from vectorbtpro._settings import settings
params_cfg = settings["params"]
if parameterizer is None:
parameterizer = params_cfg["parameterizer"]
if parameterizer is None:
parameterizer = Parameterizer
param_search_kwargs = merge_dicts(params_cfg["param_search_kwargs"], param_search_kwargs)
if group_execute_kwargs is None:
group_execute_kwargs = {}
if execute_kwargs is None:
execute_kwargs = {}
if clean_index_kwargs is None:
clean_index_kwargs = {}
# Prepare group config names
gc_names = []
gc_names_none = True
n_configs = 0
if group_configs is not None:
if isinstance(group_configs, dict):
new_group_configs = []
for k, v in group_configs.items():
v = dict(v)
v["_name"] = k
new_group_configs.append(v)
group_configs = new_group_configs
else:
group_configs = list(group_configs)
for i, group_config in enumerate(group_configs):
group_config = dict(group_config)
if "args" in group_config:
for k, arg in enumerate(group_config.pop("args")):
group_config[f"args_{k}"] = arg
if "kwargs" in group_config:
for k, v in group_config.pop("kwargs").items():
group_config[k] = v
if "_name" in group_config and group_config["_name"] is not None:
gc_names.append(group_config.pop("_name"))
gc_names_none = False
else:
gc_names.append(n_configs)
group_configs[i] = group_config
n_configs += 1
else:
group_configs = []
# Combine parameters
paramable_kwargs = {
"every": every,
"normalize_every": normalize_every,
"at_time": at_time,
"start": start,
"end": end,
"exact_start": exact_start,
"on": on,
"add_delta": add_delta,
"kind": kind,
"indexer_method": indexer_method,
"indexer_tolerance": indexer_tolerance,
"skip_not_found": skip_not_found,
"index_points": index_points,
"rescale_to": rescale_to,
**{f"args_{i}": args[i] for i in range(len(args))},
**kwargs,
}
param_dct = parameterizer.find_params_in_obj(paramable_kwargs, **param_search_kwargs)
param_columns = None
if len(param_dct) > 0:
param_product, param_columns = combine_params(
param_dct,
random_subset=random_subset,
clean_index_kwargs=clean_index_kwargs,
name_tuple_to_str=name_tuple_to_str,
)
if param_columns is None:
n_param_configs = len(param_product[list(param_product.keys())[0]])
param_columns = pd.RangeIndex(stop=n_param_configs, name="param_config")
product_group_configs = parameterizer.param_product_to_objs(paramable_kwargs, param_product)
if len(group_configs) == 0:
group_configs = product_group_configs
else:
new_group_configs = []
for i in range(len(product_group_configs)):
for group_config in group_configs:
new_group_config = merge_dicts(product_group_configs[i], group_config)
new_group_configs.append(new_group_config)
group_configs = new_group_configs
# Build group index
n_config_params = len(gc_names)
if param_columns is not None:
if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
group_index = param_columns
else:
group_index = combine_indexes(
(
param_columns,
pd.Index(gc_names, name="group_config"),
),
**clean_index_kwargs,
)
else:
if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
group_index = pd.Index(["group"], name="group")
else:
group_index = pd.Index(gc_names, name="group_config")
# Create group config from arguments if empty
if len(group_configs) == 0:
single_group = True
group_configs.append(dict())
else:
single_group = False
# Resolve each group
groupable_kwargs = {
"allocate_func": allocate_func,
**paramable_kwargs,
"jitted_loop": jitted_loop,
"jitted": jitted,
"chunked": chunked,
"template_context": template_context,
"execute_kwargs": execute_kwargs,
}
new_group_configs = []
for group_config in group_configs:
new_group_config = merge_dicts(groupable_kwargs, group_config)
_args = ()
while True:
if f"args_{len(_args)}" in new_group_config:
_args += (new_group_config.pop(f"args_{len(_args)}"),)
else:
break
new_group_config["args"] = _args
new_group_configs.append(new_group_config)
group_configs = new_group_configs
# Generate allocations
tasks = []
for group_idx, group_config in enumerate(group_configs):
tasks.append(
Task(
cls.run_allocation_group,
wrapper=wrapper,
group_configs=group_configs,
group_index=group_index,
group_idx=group_idx,
pre_group_func=pre_group_func,
)
)
group_execute_kwargs = merge_dicts(dict(show_progress=False if single_group else None), group_execute_kwargs)
results = execute(tasks, keys=group_index, **group_execute_kwargs)
alloc_points, allocations = zip(*results)
# Build column hierarchy
new_columns = combine_indexes((group_index, wrapper.columns), **clean_index_kwargs)
# Create instance
wrapper_kwargs = merge_dicts(
dict(
index=wrapper.index,
columns=new_columns,
ndim=2,
freq=wrapper.freq,
column_only_select=False,
range_only_select=True,
group_select=True,
grouped_ndim=1 if single_group else 2,
group_by=group_index.names if group_index.nlevels > 1 else group_index.name,
allow_enable=False,
allow_disable=True,
allow_modify=False,
),
wrapper_kwargs,
)
new_wrapper = ArrayWrapper(**wrapper_kwargs)
alloc_points = AllocPoints(
ArrayWrapper(
index=wrapper.index,
columns=new_wrapper.get_columns(),
ndim=new_wrapper.get_ndim(),
freq=wrapper.freq,
column_only_select=False,
range_only_select=True,
),
np.concatenate(alloc_points),
)
allocations = row_stack_arrays(allocations)
return cls(new_wrapper, alloc_points, allocations)
@classmethod
def from_allocations(
cls: tp.Type[PortfolioOptimizerT],
wrapper: ArrayWrapper,
allocations: tp.ArrayLike,
**kwargs,
) -> PortfolioOptimizerT:
"""Pick allocations from a (flexible) array.
Uses `PortfolioOptimizer.from_allocate_func`.
If `allocations` is a DataFrame, uses its index as labels. If it's a Series or dict, uses it
as a single allocation without an index, which by default gets applied at each index point.
If it's neither one of the above nor a NumPy array, tries to convert it into a NumPy array.
If `allocations` is a NumPy array, uses `vectorbtpro.portfolio.pfopt.nb.pick_idx_allocate_func_nb`
and a Numba-compiled loop. Otherwise, uses a regular Python function to pick each allocation
(which can be a dict, Series, etc.). Elements are selected in a flexible manner: a single
element is applied to all rows, while a one-dimensional array is also applied to all rows
and broadcast across columns (as opposed to rows)."""
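# A hedged usage sketch (hypothetical `close` frame): a dict is treated as a single
# allocation that gets reused at every index point, e.g.
# >>> pfo = vbt.PortfolioOptimizer.from_allocations(
# ...     close.vbt.wrapper,
# ...     {"MSFT": 0.5, "AMZN": 0.3, "AAPL": 0.2}
# ... )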
if isinstance(allocations, pd.Series):
allocations = allocations.to_dict()
if isinstance(allocations, dict):
allocations = pd.DataFrame([allocations], columns=wrapper.columns)
allocations = allocations.values
if isinstance(allocations, pd.DataFrame):
kwargs = merge_dicts(dict(on=allocations.index, kind="labels"), kwargs)
allocations = allocations.values
if not isinstance(allocations, np.ndarray):
with WarningsFiltered():
try:
new_allocations = np.asarray(allocations)
if new_allocations.dtype != object:
allocations = new_allocations
except Exception as e:
pass
if isinstance(allocations, np.ndarray):
def _resolve_allocations(index_points):
target_shape = (len(index_points), len(wrapper.columns))
return broadcast_array_to(allocations, target_shape, expand_axis=0)
return cls.from_allocate_func(
wrapper,
nb.pick_idx_allocate_func_nb,
RepFunc(_resolve_allocations),
jitted_loop=True,
**kwargs,
)
def _pick_allocate_func(index_points, i):
if not checks.is_sequence(allocations):
return allocations
if len(allocations) == 1:
return allocations[0]
if len(index_points) != len(allocations):
raise ValueError(f"Allocation array must have {len(index_points)} rows")
return allocations[i]
return cls.from_allocate_func(wrapper, _pick_allocate_func, Rep("index_points"), Rep("i"), **kwargs)
@classmethod
def from_initial(
cls: tp.Type[PortfolioOptimizerT],
wrapper: ArrayWrapper,
allocations: tp.ArrayLike,
**kwargs,
) -> PortfolioOptimizerT:
"""Allocate once at the first index.
Uses `PortfolioOptimizer.from_allocations` with `on=0`."""
return cls.from_allocations(wrapper, allocations, on=0, **kwargs)
@classmethod
def from_filled_allocations(
cls: tp.Type[PortfolioOptimizerT],
allocations: tp.AnyArray2d,
valid_only: bool = True,
nonzero_only: bool = True,
unique_only: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Pick allocations from an already filled array.
Uses `PortfolioOptimizer.from_allocate_func`.
Uses `vectorbtpro.portfolio.pfopt.nb.pick_point_allocate_func_nb` and a Numba-compiled loop.
Extracts allocation points using `vectorbtpro.portfolio.pfopt.nb.get_alloc_points_nb`."""
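# A hedged sketch (hypothetical `filled_weights`, a DataFrame of target weights with the same
# shape as the wrapper): with the default flags above, only valid, non-zero, and changed rows
# become allocation points, e.g.
# >>> pfo = vbt.PortfolioOptimizer.from_filled_allocations(filled_weights)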
if wrapper is None:
if checks.is_frame(allocations):
wrapper = ArrayWrapper.from_obj(allocations)
else:
raise TypeError("Wrapper is required if allocations is not a DataFrame")
allocations = to_2d_array(allocations, expand_axis=0)
if allocations.shape != wrapper.shape_2d:
raise ValueError("Allocation array must have the same shape as wrapper")
on = nb.get_alloc_points_nb(
allocations,
valid_only=valid_only,
nonzero_only=nonzero_only,
unique_only=unique_only,
)
kwargs = merge_dicts(dict(on=on), kwargs)
return cls.from_allocate_func(
wrapper,
nb.pick_point_allocate_func_nb,
allocations,
jitted_loop=True,
**kwargs,
)
@classmethod
def from_uniform(cls: tp.Type[PortfolioOptimizerT], wrapper: ArrayWrapper, **kwargs) -> PortfolioOptimizerT:
"""Generate uniform allocations.
Uses `PortfolioOptimizer.from_allocate_func`."""
def _uniform_allocate_func():
return np.full(wrapper.shape_2d[1], 1 / wrapper.shape_2d[1])
return cls.from_allocate_func(wrapper, _uniform_allocate_func, **kwargs)
@classmethod
def from_random(
cls: tp.Type[PortfolioOptimizerT],
wrapper: ArrayWrapper,
direction: tp.Union[str, int] = "longonly",
n: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Generate random allocations.
Uses `PortfolioOptimizer.from_allocate_func`.
Uses `vectorbtpro.portfolio.pfopt.nb.random_allocate_func_nb` and a Numba-compiled loop."""
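# A hedged sketch (hypothetical `close` frame): allocate randomly every month, reproducibly
# via a fixed seed; extra keyword arguments such as `every` are forwarded to `from_allocate_func`, e.g.
# >>> pfo = vbt.PortfolioOptimizer.from_random(close.vbt.wrapper, seed=42, every="M")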
if isinstance(direction, str):
direction = map_enum_fields(direction, Direction)
if seed is not None:
set_seed_nb(seed)
return cls.from_allocate_func(
wrapper,
nb.random_allocate_func_nb,
wrapper.shape_2d[1],
direction,
n,
jitted_loop=True,
**kwargs,
)
@classmethod
def from_universal_algo(
cls: tp.Type[PortfolioOptimizerT],
algo: tp.Union[str, tp.Type[AlgoT], AlgoT, AlgoResultT],
S: tp.Optional[tp.AnyArray2d] = None,
n_jobs: int = 1,
log_progress: bool = False,
valid_only: bool = True,
nonzero_only: bool = True,
unique_only: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Generate allocations using [Universal Portfolios](https://github.com/Marigold/universal-portfolios).
`S` can be any price data, while `algo` must be either the name of an algorithm in `universal.algos`,
a subclass of `universal.algo.Algo`, an instance of `universal.algo.Algo`, or an instance of `universal.result.AlgoResult`.
Extracts allocation points using `vectorbtpro.portfolio.pfopt.nb.get_alloc_points_nb`."""
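# A hedged usage sketch (requires the optional `universal-portfolios` package; "CRP" is one of
# the algorithm names available in `universal.algos`):
# >>> pfo = vbt.PortfolioOptimizer.from_universal_algo("CRP", close)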
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("universal")
from universal.algo import Algo
from universal.result import AlgoResult
if wrapper is None:
if S is None or not checks.is_frame(S):
raise TypeError("Wrapper is required if allocations is not a DataFrame")
else:
wrapper = ArrayWrapper.from_obj(S)
def _pre_group_func(group_config, _algo=algo):
_ = group_config.pop("args", ())
if isinstance(_algo, str):
import universal.algos
_algo = getattr(universal.algos, _algo)
if isinstance(_algo, type) and issubclass(_algo, Algo):
reserved_arg_names = get_func_arg_names(cls.from_allocate_func)
algo_keys = set(group_config.keys()).difference(reserved_arg_names)
algo_kwargs = {}
for k in algo_keys:
algo_kwargs[k] = group_config.pop(k)
_algo = _algo(**algo_kwargs)
if isinstance(_algo, Algo):
if S is None:
raise ValueError("S is required")
_algo = _algo.run(S, n_jobs=n_jobs, log_progress=log_progress)
if isinstance(_algo, AlgoResult):
weights = _algo.weights[wrapper.columns].values
else:
raise TypeError(f"Algo {_algo} not supported")
if "on" not in kwargs:
group_config["on"] = nb.get_alloc_points_nb(
weights,
valid_only=valid_only,
nonzero_only=nonzero_only,
unique_only=unique_only,
)
group_config["args"] = (weights,)
return cls.from_allocate_func(
wrapper,
nb.pick_point_allocate_func_nb,
jitted_loop=True,
pre_group_func=_pre_group_func,
**kwargs,
)
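# Example (hypothetical sketch; assumes the `universal-portfolios` package is installed and
# that it exposes an algorithm named "CRP" in `universal.algos`):
#
#     pfo = vbt.PortfolioOptimizer.from_universal_algo("CRP", close)  # close is a price DataFrame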
@classmethod
def run_optimization_group(
cls,
wrapper: ArrayWrapper,
group_configs: tp.List[dict],
group_index: tp.Index,
group_idx: int,
pre_group_func: tp.Optional[tp.Callable] = None,
silence_warnings: bool = False,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
"""Run an optimization group."""
group_config = dict(group_configs[group_idx])
if pre_group_func is not None:
pre_group_func(group_config)
optimize_func = group_config.pop("optimize_func")
every = group_config.pop("every")
normalize_every = group_config.pop("normalize_every")
split_every = group_config.pop("split_every")
start_time = group_config.pop("start_time")
end_time = group_config.pop("end_time")
lookback_period = group_config.pop("lookback_period")
start = group_config.pop("start")
end = group_config.pop("end")
exact_start = group_config.pop("exact_start")
fixed_start = group_config.pop("fixed_start")
closed_start = group_config.pop("closed_start")
closed_end = group_config.pop("closed_end")
add_start_delta = group_config.pop("add_start_delta")
add_end_delta = group_config.pop("add_end_delta")
kind = group_config.pop("kind")
skip_not_found = group_config.pop("skip_not_found")
index_ranges = group_config.pop("index_ranges")
index_loc = group_config.pop("index_loc")
rescale_to = group_config.pop("rescale_to")
alloc_wait = group_config.pop("alloc_wait")
splitter_cls = group_config.pop("splitter_cls")
eval_id = group_config.pop("eval_id")
jitted_loop = group_config.pop("jitted_loop")
jitted = group_config.pop("jitted")
chunked = group_config.pop("chunked")
template_context = group_config.pop("template_context")
execute_kwargs = group_config.pop("execute_kwargs")
args = group_config.pop("args")
kwargs = group_config
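# Whatever remains in the group config after popping the reserved keys above is forwarded
# to the optimization function as keyword arguments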
if splitter_cls is None:
splitter_cls = Splitter
template_context = merge_dicts(
dict(
group_configs=group_configs,
group_index=group_index,
group_idx=group_idx,
wrapper=wrapper,
optimize_func=optimize_func,
every=every,
normalize_every=normalize_every,
split_every=split_every,
start_time=start_time,
end_time=end_time,
lookback_period=lookback_period,
start=start,
end=end,
exact_start=exact_start,
fixed_start=fixed_start,
closed_start=closed_start,
closed_end=closed_end,
add_start_delta=add_start_delta,
add_end_delta=add_end_delta,
kind=kind,
skip_not_found=skip_not_found,
index_ranges=index_ranges,
index_loc=index_loc,
rescale_to=rescale_to,
alloc_wait=alloc_wait,
splitter_cls=splitter_cls,
eval_id=eval_id,
jitted_loop=jitted_loop,
jitted=jitted,
chunked=chunked,
args=args,
kwargs=kwargs,
execute_kwargs=execute_kwargs,
),
template_context,
)
if index_ranges is None:
get_index_ranges_defaults = substitute_templates(
dict(
every=every,
normalize_every=normalize_every,
split_every=split_every,
start_time=start_time,
end_time=end_time,
lookback_period=lookback_period,
start=start,
end=end,
exact_start=exact_start,
fixed_start=fixed_start,
closed_start=closed_start,
closed_end=closed_end,
add_start_delta=add_start_delta,
add_end_delta=add_end_delta,
kind=kind,
skip_not_found=skip_not_found,
jitted=jitted,
),
template_context,
eval_id="get_index_ranges_defaults",
strict=True,
)
index_ranges = wrapper.get_index_ranges(**get_index_ranges_defaults)
template_context = merge_dicts(
template_context,
get_index_ranges_defaults,
dict(index_ranges=index_ranges),
)
else:
index_ranges = substitute_templates(
index_ranges,
template_context,
eval_id="index_ranges",
strict=True,
)
if isinstance(index_ranges, np.ndarray):
index_ranges = (index_ranges[:, 0], index_ranges[:, 1])
elif not isinstance(index_ranges[0], np.ndarray) and not isinstance(index_ranges[1], np.ndarray):
index_ranges = to_2d_array(index_ranges, expand_axis=0)
index_ranges = (index_ranges[:, 0], index_ranges[:, 1])
template_context = merge_dicts(template_context, dict(index_ranges=index_ranges))
if index_loc is not None:
index_loc = substitute_templates(
index_loc,
template_context,
eval_id="index_loc",
strict=True,
)
index_loc = to_1d_array(index_loc)
template_context = merge_dicts(template_context, dict(index_loc=index_loc))
if jitted_loop:
optimize_func = substitute_templates(
optimize_func,
template_context,
eval_id="optimize_func",
strict=True,
)
args = substitute_templates(args, template_context, eval_id="args")
kwargs = substitute_templates(kwargs, template_context, eval_id="kwargs")
func = jit_reg.resolve_option(nb.optimize_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
_allocations = func(
len(wrapper.columns),
index_ranges[0],
index_ranges[1],
optimize_func,
*args,
**kwargs,
)
else:
tasks = []
keys = []
for i in range(len(index_ranges[0])):
index_slice = slice(max(0, index_ranges[0][i]), index_ranges[1][i])
_template_context = merge_dicts(
dict(
i=i,
index_slice=index_slice,
index_start=index_ranges[0][i],
index_end=index_ranges[1][i],
),
template_context,
)
_optimize_func = substitute_templates(
optimize_func,
_template_context,
eval_id="optimize_func",
strict=True,
)
_args = substitute_templates(args, _template_context, eval_id="args")
_kwargs = substitute_templates(kwargs, _template_context, eval_id="kwargs")
if has_annotatables(_optimize_func):
ann_args = annotate_args(
_optimize_func,
_args,
_kwargs,
attach_annotations=True,
)
flat_ann_args = flatten_ann_args(ann_args)
flat_ann_args = splitter_cls.parse_and_inject_takeables(flat_ann_args, eval_id=eval_id)
ann_args = unflatten_ann_args(flat_ann_args)
_args, _kwargs = ann_args_to_args(ann_args)
__args = ()
for v in _args:
if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
v = splitter_cls.take_range_from_takeable(
v,
index_slice,
template_context=_template_context,
silence_warnings=silence_warnings,
index=wrapper.index,
freq=wrapper.freq,
)
__args += (v,)
__kwargs = {}
for k, v in _kwargs.items():
if isinstance(v, Takeable) and v.meets_eval_id(eval_id):
v = splitter_cls.take_range_from_takeable(
v,
index_slice,
template_context=_template_context,
silence_warnings=silence_warnings,
index=wrapper.index,
freq=wrapper.freq,
)
__kwargs[k] = v
tasks.append(Task(_optimize_func, *__args, **__kwargs))
if isinstance(wrapper.index, pd.DatetimeIndex):
keys.append(
"{} → {}".format(
dt.readable_datetime(wrapper.index[index_ranges[0][i]], freq=wrapper.freq),
dt.readable_datetime(wrapper.index[index_ranges[1][i] - 1], freq=wrapper.freq),
)
)
else:
keys.append(
"{} → {}".format(
str(wrapper.index[index_ranges[0][i]]),
str(wrapper.index[index_ranges[1][i] - 1]),
)
)
results = execute(tasks, keys=keys, **execute_kwargs)
_allocations = pd.DataFrame(results, columns=wrapper.columns)
if isinstance(_allocations.columns, pd.RangeIndex):
_allocations = _allocations.values
else:
_allocations = _allocations[list(wrapper.columns)].values
if rescale_to is not None:
_allocations = nb.rescale_allocations_nb(_allocations, rescale_to)
if index_loc is None:
alloc_wait = substitute_templates(
alloc_wait,
template_context,
eval_id="alloc_wait",
strict=True,
)
alloc_idx = index_ranges[1] - 1 + alloc_wait
else:
alloc_idx = index_loc
status = np.where(
alloc_idx >= len(wrapper.index),
RangeStatus.Open,
RangeStatus.Closed,
)
return nb.prepare_alloc_ranges_nb(
index_ranges[0],
index_ranges[1],
alloc_idx,
status,
_allocations,
group_idx,
)
@classmethod
def from_optimize_func(
cls: tp.Type[PortfolioOptimizerT],
wrapper: ArrayWrapper,
optimize_func: tp.Callable,
*args,
every: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["every"],
normalize_every: tp.Union[bool, Param] = range_idxr_defaults["normalize_every"],
split_every: tp.Union[bool, Param] = range_idxr_defaults["split_every"],
start_time: tp.Union[None, tp.TimeLike, Param] = range_idxr_defaults["start_time"],
end_time: tp.Union[None, tp.TimeLike, Param] = range_idxr_defaults["end_time"],
lookback_period: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["lookback_period"],
start: tp.Union[None, int, tp.DatetimeLike, tp.IndexLike, Param] = range_idxr_defaults["start"],
end: tp.Union[None, int, tp.DatetimeLike, tp.IndexLike, Param] = range_idxr_defaults["end"],
exact_start: tp.Union[bool, Param] = range_idxr_defaults["exact_start"],
fixed_start: tp.Union[bool, Param] = range_idxr_defaults["fixed_start"],
closed_start: tp.Union[bool, Param] = range_idxr_defaults["closed_start"],
closed_end: tp.Union[bool, Param] = range_idxr_defaults["closed_end"],
add_start_delta: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["add_start_delta"],
add_end_delta: tp.Union[None, tp.FrequencyLike, Param] = range_idxr_defaults["add_end_delta"],
kind: tp.Union[None, str, Param] = range_idxr_defaults["kind"],
skip_not_found: tp.Union[bool, Param] = range_idxr_defaults["skip_not_found"],
index_ranges: tp.Union[None, tp.MaybeSequence[tp.MaybeSequence[int]], Param] = None,
index_loc: tp.Union[None, tp.MaybeSequence[int], Param] = None,
rescale_to: tp.Union[None, tp.Tuple[float, float], Param] = None,
alloc_wait: tp.Union[int, Param] = 1,
parameterizer: tp.Optional[tp.MaybeType[Parameterizer]] = None,
param_search_kwargs: tp.KwargsLike = None,
name_tuple_to_str: tp.Union[None, bool, tp.Callable] = None,
group_configs: tp.Union[None, tp.Dict[tp.Hashable, tp.Kwargs], tp.Sequence[tp.Kwargs]] = None,
pre_group_func: tp.Optional[tp.Callable] = None,
splitter_cls: tp.Optional[tp.Type[Splitter]] = None,
eval_id: tp.Optional[tp.Hashable] = None,
jitted_loop: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
template_context: tp.KwargsLike = None,
group_execute_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
random_subset: tp.Optional[int] = None,
clean_index_kwargs: tp.KwargsLike = None,
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PortfolioOptimizerT:
"""Generate allocations from an optimization function.
Generates date ranges, performs optimization on the subset of data that belongs to each date range,
and allocates at the end of each range.
This is a parameterized method that allows testing multiple combinations of most arguments.
First, it checks whether any of the arguments is wrapped with `vectorbtpro.utils.params.Param`
and combines their values. It then combines them over `group_configs`, if provided.
Before execution, it additionally processes the group config using `pre_group_func`.
It then resolves the date ranges, either using the ready-to-use `index_ranges` or
by passing all the arguments ranging from `every` to `jitted` to
`vectorbtpro.base.wrapping.ArrayWrapper.get_index_ranges`. The optimization
function `optimize_func` is then called on each date range by first substituting
any templates found in `*args` and `**kwargs`. To forward any reserved arguments
such as `jitted` to the optimization function, specify their names in `forward_args`
and `forward_kwargs`.
!!! note
Make sure to use vectorbt's own templates to select the current date range
(available as `index_slice` in the context mapping) from each array.
If `jitted_loop` is True, see `vectorbtpro.portfolio.pfopt.nb.optimize_meta_nb`.
Otherwise, `optimize_func` must take the template-substituted `*args` and `**kwargs`, and return
an array or a dictionary with asset allocations (which may also be empty).
Templates can use the following variables:
* `i`: Optimization step
* `index_start`: Optimization start index (including)
* `index_end`: Optimization end index (excluding)
* `index_slice`: `slice(index_start, index_end)`
!!! note
When `jitted_loop` is True and in case of multiple groups, use templates
to substitute by the current group index (available as `group_idx` in the context mapping).
All allocations of all groups are stacked into one big 2-dim array where columns are assets
and rows are allocations. Furthermore, date ranges are used to fill a record array of type
`vectorbtpro.portfolio.pfopt.records.AllocRanges` that acts as an indexer for allocations.
For example, the field `col` stores the group index corresponding to each allocation. Since
this record array does not hold any information on assets themselves, it has its own wrapper
that holds groups instead of columns, while the wrapper of the `PortfolioOptimizer` instance
contains regular columns grouped by groups.
Usage:
* Allocate once:
```pycon
>>> from vectorbtpro import *
>>> data = vbt.YFData.pull(
... ["MSFT", "AMZN", "AAPL"],
... start="2010-01-01",
... end="2020-01-01"
... )
>>> close = data.get("Close")
>>> def optimize_func(df):
... sharpe = df.mean() / df.std()
... return sharpe / sharpe.sum()
>>> df_arg = vbt.RepEval("close.iloc[index_slice]", context=dict(close=close))
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func,
... df_arg,
... end="2015-01-01"
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
alloc_group Date
group 2015-01-02 00:00:00+00:00 0.402459 0.309351 0.288191
```
* Allocate every first date of the year:
```pycon
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func,
... df_arg,
... every="AS-JAN"
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
alloc_group Date
group 2011-01-03 00:00:00+00:00 0.480693 0.257317 0.261990
2012-01-03 00:00:00+00:00 0.489893 0.215381 0.294727
2013-01-02 00:00:00+00:00 0.540165 0.228755 0.231080
2014-01-02 00:00:00+00:00 0.339649 0.273996 0.386354
2015-01-02 00:00:00+00:00 0.350406 0.418638 0.230956
2016-01-04 00:00:00+00:00 0.332212 0.141090 0.526698
2017-01-03 00:00:00+00:00 0.390852 0.225379 0.383769
2018-01-02 00:00:00+00:00 0.337711 0.317683 0.344606
2019-01-02 00:00:00+00:00 0.411852 0.282680 0.305468
```
* Specify index ranges manually:
```pycon
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func,
... df_arg,
... index_ranges=[
... (0, 30),
... (30, 60),
... (60, 90)
... ]
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
alloc_group Date
group 2010-02-16 00:00:00+00:00 0.340641 0.285897 0.373462
2010-03-30 00:00:00+00:00 0.596392 0.206317 0.197291
2010-05-12 00:00:00+00:00 0.437481 0.283160 0.279358
```
* Test multiple combinations of one argument:
```pycon
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func,
... df_arg,
... every="AS-JAN",
... start="2015-01-01",
... lookback_period=vbt.Param(["3MS", "6MS"])
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
lookback_period Date
3MS 2016-01-04 00:00:00+00:00 0.282725 0.234970 0.482305
2017-01-03 00:00:00+00:00 0.318100 0.269355 0.412545
2018-01-02 00:00:00+00:00 0.387499 0.236432 0.376068
2019-01-02 00:00:00+00:00 0.575464 0.254808 0.169728
6MS 2016-01-04 00:00:00+00:00 0.265035 0.198619 0.536346
2017-01-03 00:00:00+00:00 0.314144 0.409020 0.276836
2018-01-02 00:00:00+00:00 0.322741 0.282639 0.394621
2019-01-02 00:00:00+00:00 0.565691 0.234760 0.199549
```
* Test multiple cross-argument combinations:
```pycon
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func,
... df_arg,
... every="AS-JAN",
... group_configs=[
... dict(start="2015-01-01"),
... dict(start="2019-06-01", every="MS"),
... dict(end="2014-01-01")
... ]
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
group_config Date
0 2016-01-04 00:00:00+00:00 0.332212 0.141090 0.526698
2017-01-03 00:00:00+00:00 0.390852 0.225379 0.383769
2018-01-02 00:00:00+00:00 0.337711 0.317683 0.344606
2019-01-02 00:00:00+00:00 0.411852 0.282680 0.305468
1 2019-07-01 00:00:00+00:00 0.351461 0.327334 0.321205
2019-08-01 00:00:00+00:00 0.418411 0.249799 0.331790
2019-09-03 00:00:00+00:00 0.400439 0.374044 0.225517
2019-10-01 00:00:00+00:00 0.509387 0.250497 0.240117
2019-11-01 00:00:00+00:00 0.349983 0.469181 0.180835
2019-12-02 00:00:00+00:00 0.260437 0.380563 0.359000
2 2012-01-03 00:00:00+00:00 0.489892 0.215381 0.294727
2013-01-02 00:00:00+00:00 0.540165 0.228755 0.231080
2014-01-02 00:00:00+00:00 0.339649 0.273997 0.386354
```
* Use Numba-compiled loop:
```pycon
>>> @njit
... def optimize_func_nb(i, from_idx, to_idx, close):
... mean = vbt.nb.nanmean_nb(close[from_idx:to_idx])
... std = vbt.nb.nanstd_nb(close[from_idx:to_idx])
... sharpe = mean / std
... return sharpe / np.sum(sharpe)
>>> pfo = vbt.PortfolioOptimizer.from_optimize_func(
... close.vbt.wrapper,
... optimize_func_nb,
... np.asarray(close),
... index_ranges=[
... (0, 30),
... (30, 60),
... (60, 90)
... ],
... jitted_loop=True
... )
>>> pfo.allocations
symbol MSFT AMZN AAPL
Date
2010-02-17 00:00:00+00:00 0.336384 0.289598 0.374017
2010-03-31 00:00:00+00:00 0.599417 0.207158 0.193425
2010-05-13 00:00:00+00:00 0.434084 0.281246 0.284670
```
!!! hint
There is little reason to use the Numba-compiled loop, apart from when you have to
rebalance many thousands of times. Usually, a regular Python loop combined with
a Numba-compiled optimization function suffices.
"""
from vectorbtpro._settings import settings
params_cfg = settings["params"]
if parameterizer is None:
parameterizer = params_cfg["parameterizer"]
if parameterizer is None:
parameterizer = Parameterizer
param_search_kwargs = merge_dicts(params_cfg["param_search_kwargs"], param_search_kwargs)
if group_execute_kwargs is None:
group_execute_kwargs = {}
if execute_kwargs is None:
execute_kwargs = {}
if clean_index_kwargs is None:
clean_index_kwargs = {}
# Prepare group config names
gc_names = []
gc_names_none = True
n_configs = 0
if group_configs is not None:
if isinstance(group_configs, dict):
new_group_configs = []
for k, v in group_configs.items():
v = dict(v)
v["_name"] = k
new_group_configs.append(v)
group_configs = new_group_configs
else:
group_configs = list(group_configs)
for i, group_config in enumerate(group_configs):
if "args" in group_config:
for k, arg in enumerate(group_config.pop("args")):
group_config[f"args_{k}"] = arg
if "kwargs" in group_config:
for k, v in enumerate(group_config.pop("kwargs")):
group_config[k] = v
if "_name" in group_config and group_config["_name"] is not None:
gc_names.append(group_config.pop("_name"))
gc_names_none = False
else:
gc_names.append(n_configs)
group_configs[i] = group_config
n_configs += 1
else:
group_configs = []
# Combine parameters
paramable_kwargs = {
"every": every,
"normalize_every": normalize_every,
"split_every": split_every,
"start_time": start_time,
"end_time": end_time,
"lookback_period": lookback_period,
"start": start,
"end": end,
"exact_start": exact_start,
"fixed_start": fixed_start,
"closed_start": closed_start,
"closed_end": closed_end,
"add_start_delta": add_start_delta,
"add_end_delta": add_end_delta,
"kind": kind,
"skip_not_found": skip_not_found,
"index_ranges": index_ranges,
"index_loc": index_loc,
"rescale_to": rescale_to,
"alloc_wait": alloc_wait,
**{f"args_{i}": args[i] for i in range(len(args))},
**kwargs,
}
param_dct = parameterizer.find_params_in_obj(paramable_kwargs, **param_search_kwargs)
param_columns = None
if len(param_dct) > 0:
param_product, param_columns = combine_params(
param_dct,
random_subset=random_subset,
clean_index_kwargs=clean_index_kwargs,
name_tuple_to_str=name_tuple_to_str,
)
if param_columns is None:
n_param_configs = len(param_product[list(param_product.keys())[0]])
param_columns = pd.RangeIndex(stop=n_param_configs, name="param_config")
product_group_configs = parameterizer.param_product_to_objs(paramable_kwargs, param_product)
if len(group_configs) == 0:
group_configs = product_group_configs
else:
new_group_configs = []
for i in range(len(product_group_configs)):
for group_config in group_configs:
new_group_config = merge_dicts(product_group_configs[i], group_config)
new_group_configs.append(new_group_config)
group_configs = new_group_configs
# Build group index
n_config_params = len(gc_names)
if param_columns is not None:
if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
group_index = param_columns
else:
group_index = combine_indexes(
(
param_columns,
pd.Index(gc_names, name="group_config"),
),
**clean_index_kwargs,
)
else:
if n_config_params == 0 or (n_config_params == 1 and gc_names_none):
group_index = pd.Index(["group"], name="group")
else:
group_index = pd.Index(gc_names, name="group_config")
# Create group config from arguments if empty
if len(group_configs) == 0:
single_group = True
group_configs.append(dict())
else:
single_group = False
# Resolve each group
groupable_kwargs = {
"optimize_func": optimize_func,
**paramable_kwargs,
"splitter_cls": splitter_cls,
"eval_id": eval_id,
"jitted_loop": jitted_loop,
"jitted": jitted,
"chunked": chunked,
"template_context": template_context,
"execute_kwargs": execute_kwargs,
}
new_group_configs = []
for group_config in group_configs:
new_group_config = merge_dicts(groupable_kwargs, group_config)
_args = ()
while True:
if f"args_{len(_args)}" in new_group_config:
_args += (new_group_config.pop(f"args_{len(_args)}"),)
else:
break
new_group_config["args"] = _args
new_group_configs.append(new_group_config)
group_configs = new_group_configs
# Generate allocations
tasks = []
for group_idx, group_config in enumerate(group_configs):
tasks.append(
Task(
cls.run_optimization_group,
wrapper=wrapper,
group_configs=group_configs,
group_index=group_index,
group_idx=group_idx,
pre_group_func=pre_group_func,
)
)
group_execute_kwargs = merge_dicts(dict(show_progress=False if single_group else None), group_execute_kwargs)
results = execute(tasks, keys=group_index, **group_execute_kwargs)
alloc_ranges, allocations = zip(*results)
# Build column hierarchy
new_columns = combine_indexes((group_index, wrapper.columns), **clean_index_kwargs)
# Create instance
wrapper_kwargs = merge_dicts(
dict(
index=wrapper.index,
columns=new_columns,
ndim=2,
freq=wrapper.freq,
column_only_select=False,
range_only_select=True,
group_select=True,
grouped_ndim=1 if single_group else 2,
group_by=group_index.names if group_index.nlevels > 1 else group_index.name,
allow_enable=False,
allow_disable=True,
allow_modify=False,
),
wrapper_kwargs,
)
new_wrapper = ArrayWrapper(**wrapper_kwargs)
alloc_ranges = AllocRanges(
ArrayWrapper(
index=wrapper.index,
columns=new_wrapper.get_columns(),
ndim=new_wrapper.get_ndim(),
freq=wrapper.freq,
column_only_select=False,
range_only_select=True,
),
np.concatenate(alloc_ranges),
)
allocations = row_stack_arrays(allocations)
return cls(new_wrapper, alloc_ranges, allocations)
@classmethod
def from_pypfopt(
cls: tp.Type[PortfolioOptimizerT],
wrapper: tp.Optional[ArrayWrapper] = None,
**kwargs,
) -> PortfolioOptimizerT:
"""`PortfolioOptimizer.from_optimize_func` applied on `pypfopt_optimize`.
If a wrapper is not provided, parses the wrapper from `prices` or `returns`, if provided."""
if wrapper is None:
if "prices" in kwargs:
wrapper = ArrayWrapper.from_obj(kwargs["prices"])
elif "returns" in kwargs:
wrapper = ArrayWrapper.from_obj(kwargs["returns"])
else:
raise TypeError("Must provide a wrapper if price and returns are not set")
else:
checks.assert_instance_of(wrapper, ArrayWrapper, arg_name="wrapper")
if "prices" in kwargs and not isinstance(kwargs["prices"], CustomTemplate):
kwargs["prices"] = RepFunc(lambda index_slice, _prices=kwargs["prices"]: _prices.iloc[index_slice])
if "returns" in kwargs and not isinstance(kwargs["returns"], CustomTemplate):
kwargs["returns"] = RepFunc(lambda index_slice, _returns=kwargs["returns"]: _returns.iloc[index_slice])
return cls.from_optimize_func(wrapper, pypfopt_optimize, **kwargs)
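# Example (hypothetical sketch): `prices` is wrapped into a template so that each optimization
# step only sees the rows of the current date range:
#
#     pfo = vbt.PortfolioOptimizer.from_pypfopt(prices=close, every="MS")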
@classmethod
def from_riskfolio(
cls: tp.Type[PortfolioOptimizerT],
returns: tp.AnyArray2d,
wrapper: tp.Optional[ArrayWrapper] = None,
**kwargs,
) -> PortfolioOptimizerT:
"""`PortfolioOptimizer.from_optimize_func` applied on Riskfolio-Lib."""
if wrapper is None:
if not isinstance(returns, CustomTemplate):
wrapper = ArrayWrapper.from_obj(returns)
else:
raise TypeError("Must provide a wrapper if returns are a template")
else:
checks.assert_instance_of(wrapper, ArrayWrapper, arg_name="wrapper")
if not isinstance(returns, CustomTemplate):
returns = RepFunc(lambda index_slice, _returns=returns: _returns.iloc[index_slice])
return cls.from_optimize_func(wrapper, riskfolio_optimize, returns, **kwargs)
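# Example (hypothetical sketch): `returns` is wrapped into a template in the same way, so each
# optimization step receives only the returns of the current date range:
#
#     pfo = vbt.PortfolioOptimizer.from_riskfolio(returns_df, every="MS")  # returns_df: asset returns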
# ############# Properties ############# #
@property
def alloc_records(self) -> tp.Union[AllocRanges, AllocPoints]:
"""Allocation ranges of type `vectorbtpro.portfolio.pfopt.records.AllocRanges`
or points of type `vectorbtpro.portfolio.pfopt.records.AllocPoints`."""
return self._alloc_records
def get_allocations(self, squeeze_groups: bool = True) -> tp.Frame:
"""Get a DataFrame with allocation groups concatenated along the index axis."""
idx_arr = self.alloc_records.get_field_arr("idx")
group_arr = self.alloc_records.col_arr
allocations = self._allocations
if isinstance(self.alloc_records, AllocRanges):
closed_mask = self.alloc_records.get_field_arr("status") == RangeStatus.Closed
idx_arr = idx_arr[closed_mask]
group_arr = group_arr[closed_mask]
allocations = allocations[closed_mask]
if squeeze_groups and self.wrapper.grouped_ndim == 1:
index = self.wrapper.index[idx_arr]
else:
index = stack_indexes((self.alloc_records.wrapper.columns[group_arr], self.wrapper.index[idx_arr]))
n_group_levels = self.wrapper.grouper.get_index().nlevels
columns = self.wrapper.columns.droplevel(tuple(range(n_group_levels))).unique()
return pd.DataFrame(allocations, index=index, columns=columns)
@property
def allocations(self) -> tp.Frame:
"""Calls `PortfolioOptimizer.get_allocations` with default arguments."""
return self.get_allocations()
@property
def mean_allocation(self) -> tp.Series:
"""Get the mean allocation per column."""
group_level_names = self.wrapper.grouper.get_index().names
return self.get_allocations(squeeze_groups=False).groupby(group_level_names).mean().transpose()
def fill_allocations(
self,
dropna: tp.Optional[str] = None,
fill_value: tp.Scalar = np.nan,
wrap_kwargs: tp.KwargsLike = None,
squeeze_groups: bool = True,
) -> tp.Frame:
"""Fill an empty DataFrame with allocations.
Set `dropna` to 'all' to remove all NaN rows, or to 'head' to remove any rows coming before
the first allocation."""
if wrap_kwargs is None:
wrap_kwargs = {}
out = self.wrapper.fill(fill_value, group_by=False, **wrap_kwargs)
idx_arr = self.alloc_records.get_field_arr("idx")
group_arr = self.alloc_records.col_arr
allocations = self._allocations
if isinstance(self.alloc_records, AllocRanges):
status_arr = self.alloc_records.get_field_arr("status")
closed_mask = status_arr == RangeStatus.Closed
idx_arr = idx_arr[closed_mask]
group_arr = group_arr[closed_mask]
allocations = allocations[closed_mask]
for g in range(len(self.alloc_records.wrapper.columns)):
group_mask = group_arr == g
index_mask = np.full(len(self.wrapper.index), False)
index_mask[idx_arr[group_mask]] = True
column_mask = self.wrapper.grouper.get_groups() == g
out.loc[index_mask, column_mask] = allocations[group_mask]
if dropna is not None:
if dropna.lower() == "all":
out = out.dropna(how="all")
elif dropna.lower() == "head":
out = out.iloc[idx_arr.min() :]
else:
raise ValueError(f"Invalid dropna: '{dropna}'")
if squeeze_groups and self.wrapper.grouped_ndim == 1:
n_group_levels = self.wrapper.grouper.get_index().nlevels
out = out.droplevel(tuple(range(n_group_levels)), axis=1)
return out
@property
def filled_allocations(self) -> tp.Frame:
"""Calls `PortfolioOptimizer.fill_allocations` with default arguments."""
return self.fill_allocations()
# ############# Simulation ############# #
def simulate(self, close: tp.Union[tp.ArrayLike, Data], **kwargs) -> PortfolioT:
"""Run `vectorbtpro.portfolio.base.Portfolio.from_optimizer` on this instance."""
from vectorbtpro.portfolio.base import Portfolio
return Portfolio.from_optimizer(close, self, **kwargs)
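# Example (minimal sketch): turn the optimizer into a portfolio by simulating orders from its
# allocations against a price array (`close` is assumed to be a price DataFrame):
#
#     pf = pfo.simulate(close)
#     pf.total_return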
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `PortfolioOptimizer.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.pfopt`."""
from vectorbtpro._settings import settings
pfopt_stats_cfg = settings["pfopt"]["stats"]
return merge_dicts(Analyzable.stats_defaults.__get__(self), pfopt_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total_records=dict(title="Total Records", calc_func="alloc_records.count", tags="alloc_records"),
coverage=dict(
title="Coverage",
calc_func="alloc_records.get_coverage",
overlapping=False,
check_alloc_ranges=True,
tags=["alloc_ranges", "coverage"],
),
overlap_coverage=dict(
title="Overlap Coverage",
calc_func="alloc_records.get_coverage",
overlapping=True,
check_alloc_ranges=True,
tags=["alloc_ranges", "coverage"],
),
mean_allocation=dict(
title="Mean Allocation",
calc_func="mean_allocation",
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags="allocations",
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
column: tp.Optional[tp.Label] = None,
dropna: tp.Optional[str] = "head",
line_shape: str = "hv",
plot_rb_dates: tp.Optional[bool] = None,
trace_kwargs: tp.KwargsLikeSequence = None,
add_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot allocations.
Args:
column (str): Name of the allocation group to plot.
dropna (str): See `PortfolioOptimizer.fill_allocations`.
line_shape (str): Line shape.
plot_rb_dates (bool): Whether to plot rebalancing dates.
Defaults to True if there are no more than 20 rebalancing dates.
trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter`.
add_shape_kwargs (dict): Keyword arguments passed to `fig.add_shape` for rebalancing dates.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
* Continuing with the examples under `PortfolioOptimizer.from_optimize_func`:
```pycon
>>> from vectorbtpro import *
>>> pfo = vbt.PortfolioOptimizer.from_random(
... vbt.ArrayWrapper(
... index=pd.date_range("2020-01-01", "2021-01-01"),
... columns=["MSFT", "AMZN", "AAPL"],
... ndim=2
... ),
... every="MS",
... seed=40
... )
>>> pfo.plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from vectorbtpro.utils.figure import make_figure
self_group = self.select_col(column=column)
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
if self_group.alloc_records.count() > 0:
filled_allocations = self_group.fill_allocations(dropna=dropna).ffill()
fig = filled_allocations.vbt.areaplot(
line_shape=line_shape,
trace_kwargs=trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if plot_rb_dates is None or (isinstance(plot_rb_dates, bool) and plot_rb_dates):
rb_dates = self_group.allocations.index
if plot_rb_dates is None:
plot_rb_dates = len(rb_dates) <= 20
if plot_rb_dates:
add_shape_kwargs = merge_dicts(
dict(
type="line",
line=dict(
color=fig.layout.template.layout.plot_bgcolor,
dash="dot",
width=1,
),
xref="x",
yref="paper",
y0=0,
y1=1,
),
add_shape_kwargs,
)
for rb_date in rb_dates:
fig.add_shape(x0=rb_date, x1=rb_date, **add_shape_kwargs)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `PortfolioOptimizer.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.pfopt`."""
from vectorbtpro._settings import settings
pfopt_plots_cfg = settings["pfopt"]["plots"]
return merge_dicts(Analyzable.plots_defaults.__get__(self), pfopt_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
alloc_ranges=dict(
title="Allocation Ranges",
plot_func="alloc_records.plot",
check_alloc_ranges=True,
tags="alloc_ranges",
),
plot=dict(
title="Allocations",
plot_func="plot",
tags="allocations",
),
)
)
@property
def subplots(self) -> Config:
return self._subplots
PortfolioOptimizer.override_metrics_doc(__pdoc__)
PortfolioOptimizer.override_subplots_doc(__pdoc__)
PFO = PortfolioOptimizer
"""Shortcut for `PortfolioOptimizer`."""
__pdoc__["PFO"] = False
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for portfolio optimization."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.portfolio.enums import Direction, alloc_point_dt, alloc_range_dt
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import rescale_nb
__all__ = []
@register_jitted(cache=True)
def get_alloc_points_nb(
filled_allocations: tp.Array2d,
valid_only: bool = True,
nonzero_only: bool = True,
unique_only: bool = True,
) -> tp.Array1d:
"""Get allocation points from filled allocations.
If `valid_only` is True, does not register a new allocation when all points are NaN.
If `nonzero_only` is True, does not register a new allocation when all points are zero.
If `unique_only` is True, does not register a new allocation when it's the same as the last one."""
out = np.empty(len(filled_allocations), dtype=int_)
k = 0
for i in range(filled_allocations.shape[0]):
all_nan = True
all_zeros = True
all_same = True
for col in range(filled_allocations.shape[1]):
if not np.isnan(filled_allocations[i, col]):
all_nan = False
if abs(filled_allocations[i, col]) > 0:
all_zeros = False
if k == 0 or (k > 0 and filled_allocations[i, col] != filled_allocations[out[k - 1], col]):
all_same = False
if valid_only and all_nan:
continue
if nonzero_only and all_zeros:
continue
if unique_only and all_same:
continue
out[k] = i
k += 1
return out[:k]
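# Worked example (illustrative): with the default flags and
# filled_allocations = [[nan, nan], [0.5, 0.5], [0.5, 0.5], [0.6, 0.4]],
# row 0 is skipped (all NaN), row 1 is registered, row 2 is skipped (same as row 1),
# and row 3 is registered, so the function returns array([1, 3]).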
@register_chunkable(
size=ch.ArraySizer(arg_query="range_starts", axis=0),
arg_take_spec=dict(
n_cols=None,
range_starts=ch.ArraySlicer(axis=0),
range_ends=ch.ArraySlicer(axis=0),
optimize_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def optimize_meta_nb(
n_cols: int,
range_starts: tp.Array1d,
range_ends: tp.Array1d,
optimize_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Optimize by reducing each index range.
`optimize_func_nb` must take the range index, the range start, the range end, and `*args`.
Must return a 1-dim array of size `n_cols`.
out = np.empty((range_starts.shape[0], n_cols), dtype=float_)
for i in prange(len(range_starts)):
out[i] = optimize_func_nb(i, range_starts[i], range_ends[i], *args)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="index_points", axis=0),
arg_take_spec=dict(
n_cols=None,
index_points=ch.ArraySlicer(axis=0),
allocate_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="row_stack",
)
@register_jitted(tags={"can_parallel"})
def allocate_meta_nb(
n_cols: int,
index_points: tp.Array1d,
allocate_func_nb: tp.Callable,
*args,
) -> tp.Array2d:
"""Allocate by mapping each index point.
`allocate_func_nb` must take the point index, the index point, and `*args`.
Must return a 1-dim array of size `n_cols`.
out = np.empty((index_points.shape[0], n_cols), dtype=float_)
for i in prange(len(index_points)):
out[i] = allocate_func_nb(i, index_points[i], *args)
return out
@register_jitted(cache=True)
def pick_idx_allocate_func_nb(i: int, index_point: int, allocations: tp.Array2d) -> tp.Array1d:
"""Pick the allocation at an absolute position in an array."""
return allocations[i]
@register_jitted(cache=True)
def pick_point_allocate_func_nb(i: int, index_point: int, allocations: tp.Array2d) -> tp.Array1d:
"""Pick the allocation at an index point in an array."""
return allocations[index_point]
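# Note the difference between the two pickers above: `pick_idx_allocate_func_nb` uses the running
# call index `i` (the i-th allocation picks the i-th row), while `pick_point_allocate_func_nb`
# uses the absolute row position `index_point` within the wrapper's index.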
@register_jitted(cache=True)
def random_allocate_func_nb(
i: int,
index_point: int,
n_cols: int,
direction: int = Direction.LongOnly,
n: tp.Optional[int] = None,
) -> tp.Array1d:
"""Generate a random allocation."""
weights = np.full(n_cols, np.nan, dtype=float_)
pos_sum = 0
neg_sum = 0
if n is None:
for c in range(n_cols):
w = np.random.uniform(0, 1)
if direction == Direction.ShortOnly:
w = -w
elif direction == Direction.Both:
if np.random.randint(0, 2) == 0:
w = -w
if w >= 0:
pos_sum += w
else:
neg_sum += abs(w)
weights[c] = w
else:
rand_indices = np.random.choice(n_cols, size=n, replace=False)
for k in range(len(rand_indices)):
w = np.random.uniform(0, 1)
if direction == Direction.ShortOnly:
w = -w
elif direction == Direction.Both:
if np.random.randint(0, 2) == 0:
w = -w
if w >= 0:
pos_sum += w
else:
neg_sum += abs(w)
weights[rand_indices[k]] = w
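# Normalize each leg separately: positive weights are scaled to sum to 1 and negative weights
# to sum to -1; assets that received no weight keep NaN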
for c in range(n_cols):
if not np.isnan(weights[c]):
if weights[c] >= 0:
if pos_sum > 0:
weights[c] = weights[c] / pos_sum
else:
if neg_sum > 0:
weights[c] = weights[c] / neg_sum
else:
weights[c] = 0.0
return weights
@register_jitted(cache=True)
def prepare_alloc_points_nb(
index_points: tp.Array1d,
allocations: tp.Array2d,
group: int,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
"""Prepare allocation points."""
alloc_points = np.empty_like(index_points, dtype=alloc_point_dt)
new_allocations = np.empty_like(allocations)
k = 0
for i in range(allocations.shape[0]):
all_nan = True
for col in range(allocations.shape[1]):
if not np.isnan(allocations[i, col]):
all_nan = False
break
if all_nan:
continue
if k > 0 and alloc_points["alloc_idx"][k - 1] == index_points[i]:
new_allocations[k - 1] = allocations[i]
else:
alloc_points["id"][k] = k
alloc_points["col"][k] = group
alloc_points["alloc_idx"][k] = index_points[i]
new_allocations[k] = allocations[i]
k += 1
return alloc_points[:k], new_allocations[:k]
@register_jitted(cache=True)
def prepare_alloc_ranges_nb(
start_idx: tp.Array1d,
end_idx: tp.Array1d,
alloc_idx: tp.Array1d,
status: tp.Array1d,
allocations: tp.Array2d,
group: int,
) -> tp.Tuple[tp.RecordArray, tp.Array2d]:
"""Prepare allocation ranges."""
alloc_ranges = np.empty_like(alloc_idx, dtype=alloc_range_dt)
new_allocations = np.empty_like(allocations)
k = 0
for i in range(allocations.shape[0]):
all_nan = True
for col in range(allocations.shape[1]):
if not np.isnan(allocations[i, col]):
all_nan = False
break
if all_nan:
continue
if k > 0 and alloc_ranges["alloc_idx"][k - 1] == alloc_idx[i]:
new_allocations[k - 1] = allocations[i]
else:
alloc_ranges["id"][k] = k
alloc_ranges["col"][k] = group
alloc_ranges["start_idx"][k] = start_idx[i]
alloc_ranges["end_idx"][k] = end_idx[i]
alloc_ranges["alloc_idx"][k] = alloc_idx[i]
alloc_ranges["status"][k] = status[i]
new_allocations[k] = allocations[i]
k += 1
return alloc_ranges[:k], new_allocations[:k]
@register_jitted(cache=True)
def rescale_allocations_nb(allocations: tp.Array2d, to_range: tp.Tuple[float, float]) -> tp.Array2d:
"""Rescale allocations to a new scale.
Positive and negative weights are rescaled separately from each other."""
new_min, new_max = to_range
if np.isnan(new_min) or np.isinf(new_min):
raise ValueError("Minimum of the new scale must be finite")
if np.isnan(new_max) or np.isinf(new_max):
raise ValueError("Maximum of the new scale must be finite")
if new_min >= new_max:
raise ValueError("Minimum cannot be equal to or higher than maximum")
out = np.empty_like(allocations, dtype=float_)
for i in range(allocations.shape[0]):
all_nan = True
all_zero = True
pos_sum = 0.0
neg_sum = 0.0
for col in range(allocations.shape[1]):
if np.isnan(allocations[i, col]):
continue
all_nan = False
if allocations[i, col] > 0:
all_zero = False
pos_sum += allocations[i, col]
elif allocations[i, col] < 0:
all_zero = False
neg_sum += abs(allocations[i, col])
if all_nan:
out[i] = np.nan
continue
if all_zero:
out[i] = 0.0
continue
if new_max <= 0 and pos_sum > 0:
raise ValueError("Cannot rescale positive weights to a negative scale")
if new_min >= 0 and neg_sum > 0:
raise ValueError("Cannot rescale negative weights to a positive scale")
for col in range(allocations.shape[1]):
if np.isnan(allocations[i, col]):
out[i, col] = np.nan
continue
if allocations[i, col] > 0:
out[i, col] = rescale_nb(allocations[i, col] / pos_sum, (0.0, 1.0), (max(0.0, new_min), new_max))
elif allocations[i, col] < 0:
out[i, col] = rescale_nb(abs(allocations[i, col]) / neg_sum, (0.0, 1.0), (min(new_max, 0.0), new_min))
else:
out[i, col] = 0.0
return out
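# Illustrative example (assuming `rescale_nb` maps values linearly between ranges): rescaling the
# allocation [0.2, 0.8] to the range (0.0, 0.5) keeps the positive leg normalized (it already sums
# to 1) and maps each weight into the new range, giving [0.1, 0.4].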
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for working with allocation records."""
from vectorbtpro import _typing as tp
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.portfolio.enums import alloc_range_dt, alloc_point_dt
from vectorbtpro.records.base import Records
from vectorbtpro.records.decorators import override_field_config
from vectorbtpro.utils.config import ReadonlyConfig, Config
__all__ = [
"AllocRanges",
"AllocPoints",
]
__pdoc__ = {}
# ############# AllocRanges ############# #
alloc_ranges_field_config = ReadonlyConfig(
dict(
dtype=alloc_range_dt,
settings={
"idx": dict(name="alloc_idx"), # remap field of Records
"col": dict(title="Group", mapping="groups", group_indexing=True), # remap field of Records
"alloc_idx": dict(title="Allocation Index", mapping="index"),
},
)
)
"""_"""
__pdoc__[
"alloc_ranges_field_config"
] = f"""Field config for `AllocRanges`.
```python
{alloc_ranges_field_config.prettify()}
```
"""
AllocRangesT = tp.TypeVar("AllocRangesT", bound="AllocRanges")
@override_field_config(alloc_ranges_field_config)
class AllocRanges(Ranges):
"""Extends `vectorbtpro.records.base.Records` for working with allocation point records."""
@property
def field_config(self) -> Config:
return self._field_config
AllocRanges.override_field_config_doc(__pdoc__)
# ############# AllocPoints ############# #
alloc_points_field_config = ReadonlyConfig(
dict(
dtype=alloc_point_dt,
settings={
"idx": dict(name="alloc_idx"), # remap field of Records
"col": dict(title="Group", mapping="groups", group_indexing=True), # remap field of Records
"alloc_idx": dict(title="Allocation Index", mapping="index"),
},
)
)
"""_"""
__pdoc__[
"alloc_points_field_config"
] = f"""Field config for `AllocRanges`.
```python
{alloc_points_field_config.prettify()}
```
"""
AllocPointsT = tp.TypeVar("AllocPointsT", bound="AllocPoints")
@override_field_config(alloc_points_field_config)
class AllocPoints(Records):
"""Extends `vectorbtpro.generic.ranges.Ranges` for working with allocation range records."""
@property
def field_config(self) -> Config:
return self._field_config
AllocPoints.override_field_config_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with portfolio."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.portfolio.nb import *
from vectorbtpro.portfolio.pfopt import *
from vectorbtpro.portfolio.base import *
from vectorbtpro.portfolio.call_seq import *
from vectorbtpro.portfolio.chunking import *
from vectorbtpro.portfolio.decorators import *
from vectorbtpro.portfolio.logs import *
from vectorbtpro.portfolio.orders import *
from vectorbtpro.portfolio.preparing import *
from vectorbtpro.portfolio.trades import *
__exclude_from__all__ = [
"enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for simulating a portfolio and measuring its performance."""
import inspect
import string
from functools import partial
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.indexes import ExceptLevel
from vectorbtpro.base.merging import row_stack_arrays
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import (
to_1d_array,
to_2d_array,
broadcast_array_to,
to_pd_array,
to_2d_shape,
)
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.data.base import OHLCDataMixin
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.generic.sim_range import SimRangeMixin
from vectorbtpro.portfolio import nb, enums
from vectorbtpro.portfolio.decorators import attach_shortcut_properties, attach_returns_acc_methods
from vectorbtpro.portfolio.logs import Logs
from vectorbtpro.portfolio.orders import Orders, FSOrders
from vectorbtpro.portfolio.pfopt.base import PortfolioOptimizer
from vectorbtpro.portfolio.preparing import (
PFPrepResult,
BasePFPreparer,
FOPreparer,
FSPreparer,
FOFPreparer,
FDOFPreparer,
)
from vectorbtpro.portfolio.trades import Trades, EntryTrades, ExitTrades, Positions
from vectorbtpro.records.base import Records
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.colors import adjust_opacity
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, ReadonlyConfig, HybridConfig, atomic_dict
from vectorbtpro.utils.decorators import custom_property, cached_property, hybrid_method
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.parsing import get_func_kwargs
from vectorbtpro.utils.template import Rep, RepEval, RepFunc
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from vectorbtpro.returns.qs_adapter import QSAdapter as QSAdapterT
except ImportError:
QSAdapterT = "QSAdapter"
__all__ = [
"Portfolio",
"PF",
]
__pdoc__ = {}
def fix_wrapper_for_records(pf: "Portfolio") -> ArrayWrapper:
"""Allow flags for records that were restricted for portfolio."""
if pf.cash_sharing:
return pf.wrapper.replace(allow_enable=True, allow_modify=True)
return pf.wrapper
def records_indexing_func(
self: "Portfolio",
obj: tp.RecordArray,
wrapper_meta: dict,
cls: tp.Union[type, str],
groups_only: bool = False,
**kwargs,
) -> tp.RecordArray:
"""Apply indexing function on records."""
wrapper = fix_wrapper_for_records(self)
if groups_only:
wrapper = wrapper.resolve()
wrapper_meta = dict(wrapper_meta)
wrapper_meta["col_idxs"] = wrapper_meta["group_idxs"]
if isinstance(cls, str):
cls = getattr(self, cls)
records = cls(wrapper, obj)
records_meta = records.indexing_func_meta(wrapper_meta=wrapper_meta)
return records.indexing_func(records_meta=records_meta).values
def records_resample_func(
self: "Portfolio",
obj: tp.ArrayLike,
resampler: tp.Union[Resampler, tp.PandasResampler],
wrapper: ArrayWrapper,
cls: tp.Union[type, str],
**kwargs,
) -> tp.RecordArray:
"""Apply resampling function on records."""
if isinstance(cls, str):
cls = getattr(self, cls)
return cls(wrapper, obj).resample(resampler).values
def returns_resample_func(
self: "Portfolio",
obj: tp.ArrayLike,
resampler: tp.Union[Resampler, tp.PandasResampler],
wrapper: ArrayWrapper,
fill_with_zero: bool = True,
log_returns: bool = False,
**kwargs,
):
"""Apply resampling function on returns."""
return (
pd.DataFrame(obj, index=wrapper.index)
.vbt.returns(log_returns=log_returns)
.resample(
resampler,
fill_with_zero=fill_with_zero,
)
.obj.values
)
returns_acc_config = ReadonlyConfig(
{
"daily_returns": dict(source_name="daily"),
"annual_returns": dict(source_name="annual"),
"cumulative_returns": dict(source_name="cumulative"),
"annualized_return": dict(source_name="annualized"),
"annualized_volatility": dict(),
"calmar_ratio": dict(),
"omega_ratio": dict(),
"sharpe_ratio": dict(),
"sharpe_ratio_std": dict(),
"prob_sharpe_ratio": dict(),
"deflated_sharpe_ratio": dict(),
"downside_risk": dict(),
"sortino_ratio": dict(),
"information_ratio": dict(),
"beta": dict(),
"alpha": dict(),
"tail_ratio": dict(),
"value_at_risk": dict(),
"cond_value_at_risk": dict(),
"capture_ratio": dict(),
"up_capture_ratio": dict(),
"down_capture_ratio": dict(),
"drawdown": dict(),
"max_drawdown": dict(),
}
)
"""_"""
__pdoc__[
"returns_acc_config"
] = f"""Config of returns accessor methods to be attached to `Portfolio`.
```python
{returns_acc_config.prettify()}
```
"""
shortcut_config = ReadonlyConfig(
{
"filled_close": dict(group_by_aware=False, decorator=cached_property),
"filled_bm_close": dict(group_by_aware=False, decorator=cached_property),
"weights": dict(group_by_aware=False, decorator=cached_property, obj_type="red_array"),
"long_view": dict(obj_type="portfolio"),
"short_view": dict(obj_type="portfolio"),
"orders": dict(
obj_type="records",
field_aliases=("order_records",),
wrap_func=lambda pf, obj, **kwargs: pf.orders_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="orders_cls"),
resample_func=partial(records_resample_func, cls="orders_cls"),
),
"logs": dict(
obj_type="records",
field_aliases=("log_records",),
wrap_func=lambda pf, obj, **kwargs: pf.logs_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="logs_cls"),
resample_func=partial(records_resample_func, cls="logs_cls"),
),
"entry_trades": dict(
obj_type="records",
field_aliases=("entry_trade_records",),
wrap_func=lambda pf, obj, **kwargs: pf.entry_trades_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="entry_trades_cls"),
resample_func=partial(records_resample_func, cls="entry_trades_cls"),
),
"exit_trades": dict(
obj_type="records",
field_aliases=("exit_trade_records",),
wrap_func=lambda pf, obj, **kwargs: pf.exit_trades_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="exit_trades_cls"),
resample_func=partial(records_resample_func, cls="exit_trades_cls"),
),
"trades": dict(
obj_type="records",
field_aliases=("trade_records",),
wrap_func=lambda pf, obj, **kwargs: pf.trades_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="trades_cls"),
resample_func=partial(records_resample_func, cls="trades_cls"),
),
"trade_history": dict(),
"signals": dict(),
"positions": dict(
obj_type="records",
field_aliases=("position_records",),
wrap_func=lambda pf, obj, **kwargs: pf.positions_cls.from_records(
fix_wrapper_for_records(pf),
obj,
open=pf.open_flex,
high=pf.high_flex,
low=pf.low_flex,
close=pf.close_flex,
),
indexing_func=partial(records_indexing_func, cls="positions_cls"),
resample_func=partial(records_resample_func, cls="positions_cls"),
),
"drawdowns": dict(
obj_type="records",
field_aliases=("drawdown_records",),
wrap_func=lambda pf, obj, **kwargs: pf.drawdowns_cls.from_records(
fix_wrapper_for_records(pf).resolve(),
obj,
),
indexing_func=partial(records_indexing_func, cls="drawdowns_cls", groups_only=True),
resample_func=partial(records_resample_func, cls="drawdowns_cls"),
),
"init_position": dict(obj_type="red_array", group_by_aware=False),
"asset_flow": dict(
group_by_aware=False,
resample_func="sum",
resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)),
),
"long_asset_flow": dict(
method_name="get_asset_flow",
group_by_aware=False,
method_kwargs=dict(direction="longonly"),
resample_func="sum",
resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)),
),
"short_asset_flow": dict(
method_name="get_asset_flow",
group_by_aware=False,
method_kwargs=dict(direction="shortonly"),
resample_func="sum",
resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)),
),
"assets": dict(group_by_aware=False),
"long_assets": dict(
method_name="get_assets",
group_by_aware=False,
method_kwargs=dict(direction="longonly"),
),
"short_assets": dict(
method_name="get_assets",
group_by_aware=False,
method_kwargs=dict(direction="shortonly"),
),
"position_mask": dict(),
"long_position_mask": dict(method_name="get_position_mask", method_kwargs=dict(direction="longonly")),
"short_position_mask": dict(method_name="get_position_mask", method_kwargs=dict(direction="shortonly")),
"position_coverage": dict(obj_type="red_array"),
"long_position_coverage": dict(
method_name="get_position_coverage",
obj_type="red_array",
method_kwargs=dict(direction="longonly"),
),
"short_position_coverage": dict(
method_name="get_position_coverage",
obj_type="red_array",
method_kwargs=dict(direction="shortonly"),
),
"position_entry_price": dict(group_by_aware=False),
"position_exit_price": dict(group_by_aware=False),
"init_cash": dict(obj_type="red_array"),
"cash_deposits": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))),
"total_cash_deposits": dict(obj_type="red_array"),
"cash_earnings": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))),
"total_cash_earnings": dict(obj_type="red_array"),
"cash_flow": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))),
"free_cash_flow": dict(
method_name="get_cash_flow",
method_kwargs=dict(free=True),
resample_func="sum",
resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0)),
),
"cash": dict(),
"position": dict(method_name="get_assets", group_by_aware=False),
"debt": dict(method_name=None, group_by_aware=False),
"locked_cash": dict(method_name=None, group_by_aware=False),
"free_cash": dict(method_name="get_cash", method_kwargs=dict(free=True)),
"init_price": dict(obj_type="red_array", group_by_aware=False),
"init_position_value": dict(obj_type="red_array"),
"init_value": dict(obj_type="red_array"),
"input_value": dict(obj_type="red_array"),
"asset_value": dict(),
"long_asset_value": dict(method_name="get_asset_value", method_kwargs=dict(direction="longonly")),
"short_asset_value": dict(method_name="get_asset_value", method_kwargs=dict(direction="shortonly")),
"gross_exposure": dict(),
"long_gross_exposure": dict(method_name="get_gross_exposure", method_kwargs=dict(direction="longonly")),
"short_gross_exposure": dict(method_name="get_gross_exposure", method_kwargs=dict(direction="shortonly")),
"net_exposure": dict(),
"value": dict(),
"allocations": dict(group_by_aware=False),
"long_allocations": dict(
method_name="get_allocations",
method_kwargs=dict(direction="longonly"),
group_by_aware=False,
),
"short_allocations": dict(
method_name="get_allocations",
method_kwargs=dict(direction="shortonly"),
group_by_aware=False,
),
"total_profit": dict(obj_type="red_array"),
"final_value": dict(obj_type="red_array"),
"total_return": dict(obj_type="red_array"),
"returns": dict(resample_func=returns_resample_func),
"log_returns": dict(
method_name="get_returns",
method_kwargs=dict(log_returns=True),
resample_func=partial(returns_resample_func, log_returns=True),
),
"daily_log_returns": dict(
method_name="get_returns",
method_kwargs=dict(daily_returns=True, log_returns=True),
resample_func=partial(returns_resample_func, log_returns=True),
),
"asset_pnl": dict(resample_func="sum", resample_kwargs=dict(wrap_kwargs=dict(fillna=0.0))),
"asset_returns": dict(resample_func=returns_resample_func),
"market_value": dict(),
"market_returns": dict(resample_func=returns_resample_func),
"bm_value": dict(),
"bm_returns": dict(resample_func=returns_resample_func),
"total_market_return": dict(obj_type="red_array"),
"daily_returns": dict(resample_func=returns_resample_func),
"annual_returns": dict(resample_func=returns_resample_func),
"cumulative_returns": dict(),
"annualized_return": dict(obj_type="red_array"),
"annualized_volatility": dict(obj_type="red_array"),
"calmar_ratio": dict(obj_type="red_array"),
"omega_ratio": dict(obj_type="red_array"),
"sharpe_ratio": dict(obj_type="red_array"),
"sharpe_ratio_std": dict(obj_type="red_array"),
"prob_sharpe_ratio": dict(obj_type="red_array"),
"deflated_sharpe_ratio": dict(obj_type="red_array"),
"downside_risk": dict(obj_type="red_array"),
"sortino_ratio": dict(obj_type="red_array"),
"information_ratio": dict(obj_type="red_array"),
"beta": dict(obj_type="red_array"),
"alpha": dict(obj_type="red_array"),
"tail_ratio": dict(obj_type="red_array"),
"value_at_risk": dict(obj_type="red_array"),
"cond_value_at_risk": dict(obj_type="red_array"),
"capture_ratio": dict(obj_type="red_array"),
"up_capture_ratio": dict(obj_type="red_array"),
"down_capture_ratio": dict(obj_type="red_array"),
"drawdown": dict(),
"max_drawdown": dict(obj_type="red_array"),
}
)
"""_"""
__pdoc__[
"shortcut_config"
] = f"""Config of shortcut properties to be attached to `Portfolio`.
```python
{shortcut_config.prettify()}
```
"""
PortfolioT = tp.TypeVar("PortfolioT", bound="Portfolio")
PortfolioResultT = tp.Union[PortfolioT, BasePFPreparer, PFPrepResult, enums.SimulationOutput]
class MetaPortfolio(type(Analyzable)):
"""Metaclass for `Portfolio`."""
@property
def in_output_config(cls) -> Config:
"""In-output config."""
return cls._in_output_config
@attach_shortcut_properties(shortcut_config)
@attach_returns_acc_methods(returns_acc_config)
class Portfolio(Analyzable, SimRangeMixin, metaclass=MetaPortfolio):
"""Class for simulating a portfolio and measuring its performance.
Args:
wrapper (ArrayWrapper): Array wrapper.
See `vectorbtpro.base.wrapping.ArrayWrapper`.
close (array_like): Last asset price at each time step.
order_records (array_like): A structured NumPy array of order records.
open (array_like): Open price of each bar.
high (array_like): High price of each bar.
low (array_like): Low price of each bar.
log_records (array_like): A structured NumPy array of log records.
cash_sharing (bool): Whether to share cash within the same group.
init_cash (InitCashMode or array_like of float): Initial capital.
Can be provided in a format suitable for flexible indexing.
init_position (array_like of float): Initial position.
Can be provided in a format suitable for flexible indexing.
init_price (array_like of float): Initial position price.
Can be provided in a format suitable for flexible indexing.
cash_deposits (array_like of float): Cash deposited/withdrawn at each timestamp.
Can be provided in a format suitable for flexible indexing.
cash_earnings (array_like of float): Earnings added at each timestamp.
Can be provided in a format suitable for flexible indexing.
sim_start (int, datetime_like, or array_like): Simulation start per column. Defaults to None.
sim_end (int, datetime_like, or array_like): Simulation end per column. Defaults to None.
call_seq (array_like of int): Sequence of calls per row and group. Defaults to None.
in_outputs (namedtuple): Named tuple with in-output objects.
To substitute `Portfolio` attributes, provide already broadcasted and grouped objects.
Also see `Portfolio.in_outputs_indexing_func` on how in-output objects are indexed.
use_in_outputs (bool): Whether to return in-output objects when calling properties.
bm_close (array_like): Last benchmark asset price at each time step.
fillna_close (bool): Whether to forward and backward fill NaN values in `close`.
Applied after the simulation to avoid NaNs in asset value.
See `Portfolio.get_filled_close`.
weights (array_like): Asset weights.
Applied to the initial position, initial cash, cash deposits, cash earnings, and orders.
trades_type (str or int): Default `vectorbtpro.portfolio.trades.Trades` to use across `Portfolio`.
See `vectorbtpro.portfolio.enums.TradesType`.
orders_cls (type): Class for wrapping order records.
logs_cls (type): Class for wrapping log records.
trades_cls (type): Class for wrapping trade records.
entry_trades_cls (type): Class for wrapping entry trade records.
exit_trades_cls (type): Class for wrapping exit trade records.
positions_cls (type): Class for wrapping position records.
drawdowns_cls (type): Class for wrapping drawdown records.
For defaults, see `vectorbtpro._settings.portfolio`.
!!! note
Use class methods with `from_` prefix to build a portfolio.
The `__init__` method is reserved for indexing purposes.
!!! note
This class is meant to be immutable. To change any attribute, use `Portfolio.replace`."""
_writeable_attrs: tp.WriteableAttrs = {"_in_output_config"}
@classmethod
def row_stack_objs(
cls: tp.Type[PortfolioT],
objs: tp.Sequence[tp.Any],
wrappers: tp.Sequence[ArrayWrapper],
grouping: str = "columns_or_groups",
obj_name: tp.Optional[str] = None,
obj_type: tp.Optional[str] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
cash_sharing: bool = False,
row_stack_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> tp.Any:
"""Stack (two-dimensional) objects along rows.
`row_stack_func` must take the portfolio class, and all the arguments passed to this method.
If you don't need any of the arguments, make `row_stack_func` accept them as `**kwargs`.
If all the objects are None, boolean, or empty, returns the first one."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
all_none = True
for obj in objs:
if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
if not checks.is_deep_equal(obj, objs[0]):
raise ValueError(f"Cannot unify scalar in-outputs with the name '{obj_name}'")
else:
all_none = False
break
if all_none:
return objs[0]
if row_stack_func is not None:
return row_stack_func(
cls,
objs,
wrappers,
grouping=grouping,
obj_name=obj_name,
obj_type=obj_type,
wrapper=wrapper,
**kwargs,
)
if grouping == "columns_or_groups":
obj_group_by = None
elif grouping == "columns":
obj_group_by = False
elif grouping == "groups":
obj_group_by = None
elif grouping == "cash_sharing":
obj_group_by = None if cash_sharing else False
else:
raise ValueError(f"Grouping '{grouping}' is not supported")
if obj_type is None and checks.is_np_array(objs[0]):
n_cols = wrapper.get_shape_2d(group_by=obj_group_by)[1]
can_stack = (objs[0].ndim == 1 and n_cols == 1) or (objs[0].ndim == 2 and objs[0].shape[1] == n_cols)
elif obj_type is not None and obj_type == "array":
can_stack = True
else:
can_stack = False
if can_stack:
wrapped_objs = []
for i, obj in enumerate(objs):
wrapped_objs.append(wrappers[i].wrap(obj, group_by=obj_group_by))
return wrapper.row_stack_arrs(*wrapped_objs, group_by=obj_group_by, wrap=False)
raise ValueError(f"Cannot figure out how to stack in-outputs with the name '{obj_name}' along rows")
@classmethod
def row_stack_in_outputs(
cls: tp.Type[PortfolioT],
*objs: tp.MaybeTuple[PortfolioT],
**kwargs,
) -> tp.Optional[tp.NamedTuple]:
"""Stack `Portfolio.in_outputs` along rows.
All in-output tuples must be either None or have the same fields.
If the field can be found in the attributes of this `Portfolio` instance, reads the
attribute's options to get requirements for the type and layout of the in-output object.
For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
`Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
Performs stacking on the in-output objects of the same field using `Portfolio.row_stack_objs`."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
all_none = True
for obj in objs:
if obj.in_outputs is not None:
all_none = False
break
if all_none:
return None
all_keys = set()
for obj in objs:
all_keys |= set(obj.in_outputs._asdict().keys())
for obj in objs:
if obj.in_outputs is None or len(all_keys.difference(set(obj.in_outputs._asdict().keys()))) > 0:
raise ValueError("Objects to be merged must have the same in-output fields")
cls_dir = set(dir(cls))
new_in_outputs = {}
for field in objs[0].in_outputs._asdict().keys():
field_options = merge_dicts(
cls.parse_field_options(field),
cls.in_output_config.get(field, None),
)
if field_options.get("field", field) in cls_dir:
prop = getattr(cls, field_options["field"])
prop_options = getattr(prop, "options", {})
obj_type = prop_options.get("obj_type", "array")
group_by_aware = prop_options.get("group_by_aware", True)
row_stack_func = prop_options.get("row_stack_func", None)
else:
obj_type = None
group_by_aware = True
row_stack_func = None
_kwargs = merge_dicts(
dict(
grouping=field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
obj_name=field_options.get("field", field),
obj_type=field_options.get("obj_type", obj_type),
row_stack_func=field_options.get("row_stack_func", row_stack_func),
),
kwargs,
)
new_field_obj = cls.row_stack_objs(
[getattr(obj.in_outputs, field) for obj in objs],
[obj.wrapper for obj in objs],
**_kwargs,
)
new_in_outputs[field] = new_field_obj
return type(objs[0].in_outputs)(**new_in_outputs)
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[PortfolioT],
*objs: tp.MaybeTuple[PortfolioT],
wrapper_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
combine_init_cash: bool = False,
combine_init_position: bool = False,
combine_init_price: bool = False,
**kwargs,
) -> PortfolioT:
"""Stack multiple `Portfolio` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.
Cash sharing must be the same among all objects.
Close, benchmark close, cash deposits, cash earnings, call sequence, and other two-dimensional arrays
are stacked using `vectorbtpro.base.wrapping.ArrayWrapper.row_stack_arrs`. In-outputs
are stacked using `Portfolio.row_stack_in_outputs`. Records are stacked using
`vectorbtpro.records.base.Records.row_stack_records_arrs`.
If the initial cash of each object is one of the options in `vectorbtpro.portfolio.enums.InitCashMode`,
it will be retained for the resulting object. Once any of the objects has the initial cash listed
as an absolute amount or an array, the initial cash of the first object will be copied over to
the final object, while the initial cash of all other objects will be resolved and used
as cash deposits, unless they all are zero. Set `combine_init_cash` to True to simply sum all
initial cash arrays.
If only the first object has an initial position greater than zero, it will be copied over to
the final object. Otherwise, an error will be thrown, unless `combine_init_position` is enabled
to sum all initial position arrays. The same goes for the initial price, which becomes
a candidate for stacking only if any of the arrays are not NaN.
!!! note
When possible, avoid using initial position and price in objects to be stacked:
there is currently no way of injecting them in the correct order, while simply taking
the sum or weighted average may distort the reality since they weren't available
prior to the actual simulation."""
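# Minimal sketch of the initial-cash handling described above (portfolio names hypothetical):
# pf_full = vbt.Portfolio.row_stack(pf_2020, pf_2021)
# Here the initial cash of `pf_2021` is injected as a cash deposit at its first bar,
# unless `combine_init_cash=True`, in which case both initial cash arrays are summed.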
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Portfolio):
raise TypeError("Each object to be merged must be an instance of Portfolio")
_objs = list(map(lambda x: x.disable_weights(), objs))
if "wrapper" not in kwargs:
wrapper_kwargs = merge_dicts(dict(group_by=group_by), wrapper_kwargs)
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in _objs], **wrapper_kwargs)
for i in range(1, len(_objs)):
if _objs[i].cash_sharing != _objs[0].cash_sharing:
raise ValueError("Objects to be merged must have the same 'cash_sharing'")
kwargs["cash_sharing"] = _objs[0].cash_sharing
cs_group_by = None if kwargs["cash_sharing"] else False
cs_n_cols = kwargs["wrapper"].get_shape_2d(group_by=cs_group_by)[1]
n_cols = kwargs["wrapper"].shape_2d[1]
if "close" not in kwargs:
kwargs["close"] = kwargs["wrapper"].row_stack_arrs(
*[obj.close for obj in _objs],
group_by=False,
wrap=False,
)
if "open" not in kwargs:
stack_open_objs = True
for obj in _objs:
if obj._open is None:
stack_open_objs = False
break
if stack_open_objs:
kwargs["open"] = kwargs["wrapper"].row_stack_arrs(
*[obj.open for obj in _objs],
group_by=False,
wrap=False,
)
if "high" not in kwargs:
stack_high_objs = True
for obj in _objs:
if obj._high is None:
stack_high_objs = False
break
if stack_high_objs:
kwargs["high"] = kwargs["wrapper"].row_stack_arrs(
*[obj.high for obj in _objs],
group_by=False,
wrap=False,
)
if "low" not in kwargs:
stack_low_objs = True
for obj in _objs:
if obj._low is None:
stack_low_objs = False
break
if stack_low_objs:
kwargs["low"] = kwargs["wrapper"].row_stack_arrs(
*[obj.low for obj in _objs],
group_by=False,
wrap=False,
)
if "order_records" not in kwargs:
kwargs["order_records"] = Orders.row_stack_records_arrs(*[obj.orders for obj in _objs], **kwargs)
if "log_records" not in kwargs:
kwargs["log_records"] = Logs.row_stack_records_arrs(*[obj.logs for obj in _objs], **kwargs)
if "init_cash" not in kwargs:
stack_init_cash_objs = False
for obj in _objs:
if not checks.is_int(obj._init_cash) or obj._init_cash not in enums.InitCashMode:
stack_init_cash_objs = True
break
if stack_init_cash_objs:
stack_init_cash_objs = False
init_cash_objs = []
for i, obj in enumerate(_objs):
init_cash_obj = obj.get_init_cash(group_by=cs_group_by)
init_cash_obj = to_1d_array(init_cash_obj)
init_cash_obj = broadcast_array_to(init_cash_obj, cs_n_cols)
if i > 0 and (init_cash_obj != 0).any():
stack_init_cash_objs = True
init_cash_objs.append(init_cash_obj)
if stack_init_cash_objs:
if not combine_init_cash:
cash_deposits_objs = []
for i, obj in enumerate(_objs):
cash_deposits_obj = obj.get_cash_deposits(group_by=cs_group_by)
cash_deposits_obj = to_2d_array(cash_deposits_obj)
cash_deposits_obj = broadcast_array_to(
cash_deposits_obj,
(cash_deposits_obj.shape[0], cs_n_cols),
)
cash_deposits_obj = cash_deposits_obj.copy()
if i > 0:
cash_deposits_obj[0] = init_cash_objs[i]
cash_deposits_objs.append(cash_deposits_obj)
kwargs["cash_deposits"] = row_stack_arrays(cash_deposits_objs)
kwargs["init_cash"] = init_cash_objs[0]
else:
kwargs["init_cash"] = np.asarray(init_cash_objs).sum(axis=0)
else:
kwargs["init_cash"] = init_cash_objs[0]
if "init_position" not in kwargs:
stack_init_position_objs = False
init_position_objs = []
for i, obj in enumerate(_objs):
init_position_obj = obj.get_init_position()
init_position_obj = to_1d_array(init_position_obj)
init_position_obj = broadcast_array_to(init_position_obj, n_cols)
if i > 0 and (init_position_obj != 0).any():
stack_init_position_objs = True
init_position_objs.append(init_position_obj)
if stack_init_position_objs:
if not combine_init_position:
raise ValueError("Initial position cannot be stacked along rows")
kwargs["init_position"] = np.asarray(init_position_objs).sum(axis=0)
else:
kwargs["init_position"] = init_position_objs[0]
if "init_price" not in kwargs:
stack_init_price_objs = False
init_position_objs = []
init_price_objs = []
for i, obj in enumerate(_objs):
init_position_obj = obj.get_init_position()
init_position_obj = to_1d_array(init_position_obj)
init_position_obj = broadcast_array_to(init_position_obj, n_cols)
init_price_obj = obj.get_init_price()
init_price_obj = to_1d_array(init_price_obj)
init_price_obj = broadcast_array_to(init_price_obj, n_cols)
if i > 0 and (init_position_obj != 0).any() and not np.isnan(init_price_obj).all():
stack_init_price_objs = True
init_position_objs.append(init_position_obj)
init_price_objs.append(init_price_obj)
if stack_init_price_objs:
if not combine_init_price:
raise ValueError("Initial price cannot be stacked along rows")
init_position_objs = np.asarray(init_position_objs)
init_price_objs = np.asarray(init_price_objs)
mask1 = (init_position_objs != 0).any(axis=1)
mask2 = (~np.isnan(init_price_objs)).any(axis=1)
mask = mask1 & mask2
init_position_objs = init_position_objs[mask]
init_price_objs = init_price_objs[mask]
num = (init_position_objs * init_price_objs).sum(axis=0)
denom = init_position_objs.sum(axis=0)
kwargs["init_price"] = num / denom
else:
kwargs["init_price"] = init_price_objs[0]
if "cash_deposits" not in kwargs:
stack_cash_deposits_objs = False
for obj in _objs:
if obj._cash_deposits.size > 1 or obj._cash_deposits.item() != 0:
stack_cash_deposits_objs = True
break
if stack_cash_deposits_objs:
kwargs["cash_deposits"] = kwargs["wrapper"].row_stack_arrs(
*[obj.get_cash_deposits(group_by=cs_group_by) for obj in _objs],
group_by=cs_group_by,
wrap=False,
)
else:
kwargs["cash_deposits"] = np.array([[0.0]])
if "cash_earnings" not in kwargs:
stack_cash_earnings_objs = False
for obj in _objs:
if obj._cash_earnings.size > 1 or obj._cash_earnings.item() != 0:
stack_cash_earnings_objs = True
break
if stack_cash_earnings_objs:
kwargs["cash_earnings"] = kwargs["wrapper"].row_stack_arrs(
*[obj.get_cash_earnings(group_by=False) for obj in _objs],
group_by=False,
wrap=False,
)
else:
kwargs["cash_earnings"] = np.array([[0.0]])
if "call_seq" not in kwargs:
stack_call_seq_objs = True
for obj in _objs:
if obj.config["call_seq"] is None:
stack_call_seq_objs = False
break
if stack_call_seq_objs:
kwargs["call_seq"] = kwargs["wrapper"].row_stack_arrs(
*[obj.call_seq for obj in _objs],
group_by=False,
wrap=False,
)
if "bm_close" not in kwargs:
stack_bm_close_objs = True
for obj in _objs:
if obj._bm_close is None or isinstance(obj._bm_close, bool):
stack_bm_close_objs = False
break
if stack_bm_close_objs:
kwargs["bm_close"] = kwargs["wrapper"].row_stack_arrs(
*[obj.bm_close for obj in _objs],
group_by=False,
wrap=False,
)
if "in_outputs" not in kwargs:
kwargs["in_outputs"] = cls.row_stack_in_outputs(*_objs, **kwargs)
if "sim_start" not in kwargs:
kwargs["sim_start"] = cls.row_stack_sim_start(kwargs["wrapper"], *_objs)
if "sim_end" not in kwargs:
kwargs["sim_end"] = cls.row_stack_sim_end(kwargs["wrapper"], *_objs)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
@classmethod
def column_stack_objs(
cls: tp.Type[PortfolioT],
objs: tp.Sequence[tp.Any],
wrappers: tp.Sequence[ArrayWrapper],
grouping: str = "columns_or_groups",
obj_name: tp.Optional[str] = None,
obj_type: tp.Optional[str] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
cash_sharing: bool = False,
column_stack_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> tp.Any:
"""Stack (one and two-dimensional) objects along column.
`column_stack_func` must take the portfolio class, and all the arguments passed to this method.
If you don't need any of the arguments, make `column_stack_func` accept them as `**kwargs`.
If all the objects are None, boolean, or empty, returns the first one."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
all_none = True
for obj in objs:
if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
if not checks.is_deep_equal(obj, objs[0]):
raise ValueError(f"Cannot unify scalar in-outputs with the name '{obj_name}'")
else:
all_none = False
break
if all_none:
return objs[0]
if column_stack_func is not None:
return column_stack_func(
cls,
objs,
wrappers,
grouping=grouping,
obj_name=obj_name,
obj_type=obj_type,
wrapper=wrapper,
**kwargs,
)
if grouping == "columns_or_groups":
obj_group_by = None
elif grouping == "columns":
obj_group_by = False
elif grouping == "groups":
obj_group_by = None
elif grouping == "cash_sharing":
obj_group_by = None if cash_sharing else False
else:
raise ValueError(f"Grouping '{grouping}' is not supported")
if obj_type is None and checks.is_np_array(objs[0]):
if to_2d_shape(objs[0].shape) == wrappers[0].get_shape_2d(group_by=obj_group_by):
can_stack = True
reduced = False
elif objs[0].shape == (wrappers[0].get_shape_2d(group_by=obj_group_by)[1],):
can_stack = True
reduced = True
else:
can_stack = False
elif obj_type is not None and obj_type == "array":
can_stack = True
reduced = False
elif obj_type is not None and obj_type == "red_array":
can_stack = True
reduced = True
else:
can_stack = False
if can_stack:
if reduced:
wrapped_objs = []
for i, obj in enumerate(objs):
wrapped_objs.append(wrappers[i].wrap_reduced(obj, group_by=obj_group_by))
return wrapper.concat_arrs(*wrapped_objs, group_by=obj_group_by).values
wrapped_objs = []
for i, obj in enumerate(objs):
wrapped_objs.append(wrappers[i].wrap(obj, group_by=obj_group_by))
return wrapper.column_stack_arrs(*wrapped_objs, group_by=obj_group_by, wrap=False)
raise ValueError(f"Cannot figure out how to stack in-outputs with the name '{obj_name}' along columns")
@classmethod
def column_stack_in_outputs(
cls: tp.Type[PortfolioT],
*objs: tp.MaybeTuple[PortfolioT],
**kwargs,
) -> tp.Optional[tp.NamedTuple]:
"""Stack `Portfolio.in_outputs` along columns.
All in-output tuples must be either None or have the same fields.
If the field can be found in the attributes of this `Portfolio` instance, reads the
attribute's options to get requirements for the type and layout of the in-output object.
For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
`Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
Performs stacking on the in-output objects of the same field using `Portfolio.column_stack_objs`."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
all_none = True
for obj in objs:
if obj.in_outputs is not None:
all_none = False
break
if all_none:
return None
all_keys = set()
for obj in objs:
all_keys |= set(obj.in_outputs._asdict().keys())
for obj in objs:
if obj.in_outputs is None or len(all_keys.difference(set(obj.in_outputs._asdict().keys()))) > 0:
raise ValueError("Objects to be merged must have the same in-output fields")
cls_dir = set(dir(cls))
new_in_outputs = {}
for field in objs[0].in_outputs._asdict().keys():
field_options = merge_dicts(
cls.parse_field_options(field),
cls.in_output_config.get(field, None),
)
if field_options.get("field", field) in cls_dir:
prop = getattr(cls, field_options["field"])
prop_options = getattr(prop, "options", {})
obj_type = prop_options.get("obj_type", "array")
group_by_aware = prop_options.get("group_by_aware", True)
column_stack_func = prop_options.get("column_stack_func", None)
else:
obj_type = None
group_by_aware = True
column_stack_func = None
_kwargs = merge_dicts(
dict(
grouping=field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
obj_name=field_options.get("field", field),
obj_type=field_options.get("obj_type", obj_type),
column_stack_func=field_options.get("column_stack_func", column_stack_func),
),
kwargs,
)
new_field_obj = cls.column_stack_objs(
[getattr(obj.in_outputs, field) for obj in objs],
[obj.wrapper for obj in objs],
**_kwargs,
)
new_in_outputs[field] = new_field_obj
return type(objs[0].in_outputs)(**new_in_outputs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[PortfolioT],
*objs: tp.MaybeTuple[PortfolioT],
wrapper_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
ffill_close: bool = False,
fbfill_close: bool = False,
**kwargs,
) -> PortfolioT:
"""Stack multiple `Portfolio` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.
Cash sharing must be the same among all objects.
Two-dimensional arrays are stacked using
`vectorbtpro.base.wrapping.ArrayWrapper.column_stack_arrs`
while one-dimensional arrays are stacked using
`vectorbtpro.base.wrapping.ArrayWrapper.concat_arrs`.
In-outputs are stacked using `Portfolio.column_stack_in_outputs`. Records are stacked using
`vectorbtpro.records.base.Records.column_stack_records_arrs`."""
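# Minimal usage sketch (portfolio names hypothetical): stacking two single-asset portfolios
# into one multi-column portfolio, forward-filling close prices where the indexes differ:
# pf_both = vbt.Portfolio.column_stack(pf_btc, pf_eth, ffill_close=True)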
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Portfolio):
raise TypeError("Each object to be merged must be an instance of Portfolio")
_objs = list(map(lambda x: x.disable_weights(), objs))
if "wrapper" not in kwargs:
wrapper_kwargs = merge_dicts(dict(group_by=group_by), wrapper_kwargs)
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in _objs],
**wrapper_kwargs,
)
for i in range(1, len(_objs)):
if _objs[i].cash_sharing != _objs[0].cash_sharing:
raise ValueError("Objects to be merged must have the same 'cash_sharing'")
if "cash_sharing" not in kwargs:
kwargs["cash_sharing"] = _objs[0].cash_sharing
cs_group_by = None if kwargs["cash_sharing"] else False
if "close" not in kwargs:
new_close = kwargs["wrapper"].column_stack_arrs(
*[obj.close for obj in _objs],
group_by=False,
)
if fbfill_close:
new_close = new_close.vbt.fbfill()
elif ffill_close:
new_close = new_close.vbt.ffill()
kwargs["close"] = new_close
if "open" not in kwargs:
stack_open_objs = True
for obj in _objs:
if obj._open is None:
stack_open_objs = False
break
if stack_open_objs:
kwargs["open"] = kwargs["wrapper"].column_stack_arrs(
*[obj.open for obj in _objs],
group_by=False,
wrap=False,
)
if "high" not in kwargs:
stack_high_objs = True
for obj in _objs:
if obj._high is None:
stack_high_objs = False
break
if stack_high_objs:
kwargs["high"] = kwargs["wrapper"].column_stack_arrs(
*[obj.high for obj in _objs],
group_by=False,
wrap=False,
)
if "low" not in kwargs:
stack_low_objs = True
for obj in _objs:
if obj._low is None:
stack_low_objs = False
break
if stack_low_objs:
kwargs["low"] = kwargs["wrapper"].column_stack_arrs(
*[obj.low for obj in _objs],
group_by=False,
wrap=False,
)
if "order_records" not in kwargs:
kwargs["order_records"] = Orders.column_stack_records_arrs(*[obj.orders for obj in _objs], **kwargs)
if "log_records" not in kwargs:
kwargs["log_records"] = Logs.column_stack_records_arrs(*[obj.logs for obj in _objs], **kwargs)
if "init_cash" not in kwargs:
stack_init_cash_objs = False
for obj in _objs:
if not checks.is_int(obj._init_cash) or obj._init_cash not in enums.InitCashMode:
stack_init_cash_objs = True
break
if stack_init_cash_objs:
kwargs["init_cash"] = to_1d_array(
kwargs["wrapper"].concat_arrs(
*[obj.get_init_cash(group_by=cs_group_by) for obj in _objs],
group_by=cs_group_by,
)
)
if "init_position" not in kwargs:
stack_init_position_objs = False
for obj in _objs:
if (to_1d_array(obj.init_position) != 0).any():
stack_init_position_objs = True
break
if stack_init_position_objs:
kwargs["init_position"] = to_1d_array(
kwargs["wrapper"].concat_arrs(
*[obj.init_position for obj in _objs],
group_by=False,
),
)
else:
kwargs["init_position"] = np.array([0.0])
if "init_price" not in kwargs:
stack_init_price_objs = False
for obj in _objs:
if not np.isnan(to_1d_array(obj.init_price)).all():
stack_init_price_objs = True
break
if stack_init_price_objs:
kwargs["init_price"] = to_1d_array(
kwargs["wrapper"].concat_arrs(
*[obj.init_price for obj in _objs],
group_by=False,
),
)
else:
kwargs["init_price"] = np.array([np.nan])
if "cash_deposits" not in kwargs:
stack_cash_deposits_objs = False
for obj in _objs:
if obj._cash_deposits.size > 1 or obj._cash_deposits.item() != 0:
stack_cash_deposits_objs = True
break
if stack_cash_deposits_objs:
kwargs["cash_deposits"] = kwargs["wrapper"].column_stack_arrs(
*[obj.get_cash_deposits(group_by=cs_group_by) for obj in _objs],
group_by=cs_group_by,
reindex_kwargs=dict(fill_value=0),
wrap=False,
)
else:
kwargs["cash_deposits"] = np.array([[0.0]])
if "cash_earnings" not in kwargs:
stack_cash_earnings_objs = False
for obj in _objs:
if obj._cash_earnings.size > 1 or obj._cash_earnings.item() != 0:
stack_cash_earnings_objs = True
break
if stack_cash_earnings_objs:
kwargs["cash_earnings"] = kwargs["wrapper"].column_stack_arrs(
*[obj.get_cash_earnings(group_by=False) for obj in _objs],
group_by=False,
reindex_kwargs=dict(fill_value=0),
wrap=False,
)
else:
kwargs["cash_earnings"] = np.array([[0.0]])
if "call_seq" not in kwargs:
stack_call_seq_objs = True
for obj in _objs:
if obj.config["call_seq"] is None:
stack_call_seq_objs = False
break
if stack_call_seq_objs:
kwargs["call_seq"] = kwargs["wrapper"].column_stack_arrs(
*[obj.call_seq for obj in _objs],
group_by=False,
reindex_kwargs=dict(fill_value=0),
wrap=False,
)
if "bm_close" not in kwargs:
stack_bm_close_objs = True
for obj in _objs:
if obj._bm_close is None or isinstance(obj._bm_close, bool):
stack_bm_close_objs = False
break
if stack_bm_close_objs:
new_bm_close = kwargs["wrapper"].column_stack_arrs(
*[obj.bm_close for obj in _objs],
group_by=False,
wrap=False,
)
if fbfill_close:
new_bm_close = new_bm_close.vbt.fbfill()
elif ffill_close:
new_bm_close = new_bm_close.vbt.ffill()
kwargs["bm_close"] = new_bm_close
if "in_outputs" not in kwargs:
kwargs["in_outputs"] = cls.column_stack_in_outputs(*_objs, **kwargs)
if "sim_start" not in kwargs:
kwargs["sim_start"] = cls.column_stack_sim_start(kwargs["wrapper"], *_objs)
if "sim_end" not in kwargs:
kwargs["sim_end"] = cls.column_stack_sim_end(kwargs["wrapper"], *_objs)
if "weights" not in kwargs:
stack_weights_objs = False
obj_weights = []
for obj in objs:
if obj.weights is not None:
stack_weights_objs = True
obj_weights.append(obj.weights)
else:
obj_weights.append([np.nan] * obj.wrapper.shape_2d[1])
if stack_weights_objs:
kwargs["weights"] = to_1d_array(
kwargs["wrapper"].concat_arrs(
*obj_weights,
group_by=False,
),
)
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
def __init__(
self,
wrapper: ArrayWrapper,
order_records: tp.Union[tp.RecordArray, enums.SimulationOutput],
*,
close: tp.ArrayLike,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
log_records: tp.Optional[tp.RecordArray] = None,
cash_sharing: bool = False,
init_cash: tp.Union[str, tp.ArrayLike] = "auto",
init_position: tp.ArrayLike = 0.0,
init_price: tp.ArrayLike = np.nan,
cash_deposits: tp.ArrayLike = 0.0,
cash_deposits_as_input: tp.Optional[bool] = None,
cash_earnings: tp.ArrayLike = 0.0,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
call_seq: tp.Optional[tp.Array2d] = None,
in_outputs: tp.Optional[tp.NamedTuple] = None,
use_in_outputs: tp.Optional[bool] = None,
bm_close: tp.Optional[tp.ArrayLike] = None,
fillna_close: tp.Optional[bool] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
returns_acc_defaults: tp.KwargsLike = None,
trades_type: tp.Optional[tp.Union[str, int]] = None,
orders_cls: tp.Optional[type] = None,
logs_cls: tp.Optional[type] = None,
trades_cls: tp.Optional[type] = None,
entry_trades_cls: tp.Optional[type] = None,
exit_trades_cls: tp.Optional[type] = None,
positions_cls: tp.Optional[type] = None,
drawdowns_cls: tp.Optional[type] = None,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
**kwargs,
) -> None:
from vectorbtpro._settings import settings
portfolio_cfg = settings["portfolio"]
if cash_sharing:
if wrapper.grouper.allow_enable or wrapper.grouper.allow_modify:
wrapper = wrapper.replace(allow_enable=False, allow_modify=False)
if isinstance(order_records, enums.SimulationOutput):
sim_out = order_records
order_records = sim_out.order_records
log_records = sim_out.log_records
cash_deposits = sim_out.cash_deposits
cash_earnings = sim_out.cash_earnings
sim_start = sim_out.sim_start
sim_end = sim_out.sim_end
call_seq = sim_out.call_seq
in_outputs = sim_out.in_outputs
close = to_2d_array(close)
if open is not None:
open = to_2d_array(open)
if high is not None:
high = to_2d_array(high)
if low is not None:
low = to_2d_array(low)
if isinstance(init_cash, str):
init_cash = map_enum_fields(init_cash, enums.InitCashMode)
if not checks.is_int(init_cash) or init_cash not in enums.InitCashMode:
init_cash = to_1d_array(init_cash)
init_position = to_1d_array(init_position)
init_price = to_1d_array(init_price)
cash_deposits = to_2d_array(cash_deposits)
cash_earnings = to_2d_array(cash_earnings)
if cash_deposits_as_input is None:
cash_deposits_as_input = portfolio_cfg["cash_deposits_as_input"]
if bm_close is not None and not isinstance(bm_close, bool):
bm_close = to_2d_array(bm_close)
if log_records is None:
log_records = np.array([], dtype=enums.log_dt)
if use_in_outputs is None:
use_in_outputs = portfolio_cfg["use_in_outputs"]
if fillna_close is None:
fillna_close = portfolio_cfg["fillna_close"]
if weights is None:
weights = portfolio_cfg["weights"]
if trades_type is None:
trades_type = portfolio_cfg["trades_type"]
if isinstance(trades_type, str):
trades_type = map_enum_fields(trades_type, enums.TradesType)
Analyzable.__init__(
self,
wrapper,
order_records=order_records,
open=open,
high=high,
low=low,
close=close,
log_records=log_records,
cash_sharing=cash_sharing,
init_cash=init_cash,
init_position=init_position,
init_price=init_price,
cash_deposits=cash_deposits,
cash_deposits_as_input=cash_deposits_as_input,
cash_earnings=cash_earnings,
sim_start=sim_start,
sim_end=sim_end,
call_seq=call_seq,
in_outputs=in_outputs,
use_in_outputs=use_in_outputs,
bm_close=bm_close,
fillna_close=fillna_close,
year_freq=year_freq,
returns_acc_defaults=returns_acc_defaults,
trades_type=trades_type,
orders_cls=orders_cls,
logs_cls=logs_cls,
trades_cls=trades_cls,
entry_trades_cls=entry_trades_cls,
exit_trades_cls=exit_trades_cls,
positions_cls=positions_cls,
drawdowns_cls=drawdowns_cls,
weights=weights,
**kwargs,
)
SimRangeMixin.__init__(self, sim_start=sim_start, sim_end=sim_end)
self._open = open
self._high = high
self._low = low
self._close = close
self._order_records = order_records
self._log_records = log_records
self._cash_sharing = cash_sharing
self._init_cash = init_cash
self._init_position = init_position
self._init_price = init_price
self._cash_deposits = cash_deposits
self._cash_deposits_as_input = cash_deposits_as_input
self._cash_earnings = cash_earnings
self._call_seq = call_seq
self._in_outputs = in_outputs
self._use_in_outputs = use_in_outputs
self._bm_close = bm_close
self._fillna_close = fillna_close
self._year_freq = year_freq
self._returns_acc_defaults = returns_acc_defaults
self._trades_type = trades_type
self._orders_cls = orders_cls
self._logs_cls = logs_cls
self._trades_cls = trades_cls
self._entry_trades_cls = entry_trades_cls
self._exit_trades_cls = exit_trades_cls
self._positions_cls = positions_cls
self._drawdowns_cls = drawdowns_cls
self._weights = weights
# Only slices of rows can be selected
self._range_only_select = True
# Copy writeable attrs
self._in_output_config = type(self)._in_output_config.copy()
# ############# In-outputs ############# #
_in_output_config: tp.ClassVar[Config] = HybridConfig(
dict(
cash=dict(grouping="cash_sharing"),
position=dict(grouping="columns"),
debt=dict(grouping="columns"),
locked_cash=dict(grouping="columns"),
free_cash=dict(grouping="cash_sharing"),
returns=dict(grouping="cash_sharing"),
)
)
@property
def in_output_config(self) -> Config:
"""In-output config of `${cls_name}`.
```python
${in_output_config}
```
Returns `${cls_name}._in_output_config`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change in_outputs, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._in_output_config`.
"""
return self._in_output_config
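# Sketch of per-instance customization (assuming the config supports item assignment like a
# dict; the field name is hypothetical):
# pf._in_output_config["my_metric"] = dict(grouping="columns", obj_type="red_array")
# Since `_in_output_config` is copied on instantiation, this affects only `pf`, not the class.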
@classmethod
def parse_field_options(cls, field: str) -> tp.Kwargs:
"""Parse options based on the name of a field.
Returns a dictionary with the parsed grouping, object type, and cleaned field name.
Grouping is parsed by looking for the following suffixes:
* '_cs': per group if grouped with cash sharing, otherwise per column
* '_pcg': per group if grouped, otherwise per column
* '_pg': per group
* '_pc': per column
Object type is parsed by looking for the following suffixes:
* '_records': records
* '_2d': element per timestamp and column or group (time series)
* '_1d': element per column or group (reduced time series)
Those substrings are then removed to produce a clean field name."""
options = dict()
new_parts = []
for part in field.split("_"):
if part == "1d":
options["obj_type"] = "red_array"
elif part == "2d":
options["obj_type"] = "array"
elif part == "records":
options["obj_type"] = "records"
elif part == "pc":
options["grouping"] = "columns"
elif part == "pg":
options["grouping"] = "groups"
elif part == "pcg":
options["grouping"] = "columns_or_groups"
elif part == "cs":
options["grouping"] = "cash_sharing"
else:
new_parts.append(part)
field = "_".join(new_parts)
options["field"] = field
return options
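# Example of the parsing above (field name hypothetical):
# parse_field_options("my_metric_pg_2d")
# -> {"grouping": "groups", "obj_type": "array", "field": "my_metric"}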
def matches_field_options(
self,
options: tp.Kwargs,
obj_type: tp.Optional[str] = None,
group_by_aware: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> bool:
"""Return whether options of a field match the requirements.
Requirements include the type of the object (array, reduced array, records) and
the grouping of the object (1/2 dimensions, group/column-wise layout). The current
grouping and cash sharing of this portfolio object are also taken into account.
When an option is not in `options`, it's automatically marked as matching."""
field_obj_type = options.get("obj_type", None)
field_grouping = options.get("grouping", None)
if field_obj_type is not None and obj_type is not None:
if field_obj_type != obj_type:
return False
if field_grouping is not None:
if wrapper is None:
wrapper = self.wrapper
is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
if is_grouped:
if group_by_aware:
if field_grouping == "groups":
return True
if field_grouping == "columns_or_groups":
return True
if self.cash_sharing:
if field_grouping == "cash_sharing":
return True
else:
if field_grouping == "columns":
return True
if not self.cash_sharing:
if field_grouping == "cash_sharing":
return True
else:
if field_grouping == "columns":
return True
if field_grouping == "columns_or_groups":
return True
if field_grouping == "cash_sharing":
return True
return False
return True
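# Example of the matching above: with a grouped wrapper and `group_by_aware=True`, parsed
# options {"grouping": "groups"} or {"grouping": "columns_or_groups"} match, whereas
# {"grouping": "columns"} does not; without grouping, "columns", "columns_or_groups",
# and "cash_sharing" all match, while "groups" does not.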
def wrap_obj(
self,
obj: tp.Any,
obj_name: tp.Optional[str] = None,
grouping: str = "columns_or_groups",
obj_type: tp.Optional[str] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_func: tp.Optional[tp.Callable] = None,
wrap_kwargs: tp.KwargsLike = None,
force_wrapping: bool = False,
silence_warnings: bool = False,
**kwargs,
) -> tp.Any:
"""Wrap an object.
`wrap_func` must take the portfolio, `obj`, all the arguments passed to this method, and `**kwargs`.
If you don't need any of the arguments, make `wrap_func` accept them as `**kwargs`.
If the object is None or boolean, returns as-is."""
if obj is None or isinstance(obj, bool):
return obj
if wrapper is None:
wrapper = self.wrapper
is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
if wrap_func is not None:
return wrap_func(
self,
obj,
obj_name=obj_name,
grouping=grouping,
obj_type=obj_type,
wrapper=wrapper,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
force_wrapping=force_wrapping,
silence_warnings=silence_warnings,
**kwargs,
)
def _wrap_reduced_grouped(obj):
_wrap_kwargs = merge_dicts(dict(name_or_index=obj_name), wrap_kwargs)
return wrapper.wrap_reduced(obj, group_by=group_by, **_wrap_kwargs)
def _wrap_reduced(obj):
_wrap_kwargs = merge_dicts(dict(name_or_index=obj_name), wrap_kwargs)
return wrapper.wrap_reduced(obj, group_by=False, **_wrap_kwargs)
def _wrap_grouped(obj):
return wrapper.wrap(obj, group_by=group_by, **resolve_dict(wrap_kwargs))
def _wrap(obj):
return wrapper.wrap(obj, group_by=False, **resolve_dict(wrap_kwargs))
if obj_type is not None and obj_type not in {"records"}:
if grouping == "cash_sharing":
if obj_type == "array":
if is_grouped and self.cash_sharing:
return _wrap_grouped(obj)
return _wrap(obj)
if obj_type == "red_array":
if is_grouped and self.cash_sharing:
return _wrap_reduced_grouped(obj)
return _wrap_reduced(obj)
if obj.ndim == 2:
if is_grouped and self.cash_sharing:
return _wrap_grouped(obj)
return _wrap(obj)
if obj.ndim == 1:
if is_grouped and self.cash_sharing:
return _wrap_reduced_grouped(obj)
return _wrap_reduced(obj)
if grouping == "columns_or_groups":
if obj_type == "array":
if is_grouped:
return _wrap_grouped(obj)
return _wrap(obj)
if obj_type == "red_array":
if is_grouped:
return _wrap_reduced_grouped(obj)
return _wrap_reduced(obj)
if obj.ndim == 2:
if is_grouped:
return _wrap_grouped(obj)
return _wrap(obj)
if obj.ndim == 1:
if is_grouped:
return _wrap_reduced_grouped(obj)
return _wrap_reduced(obj)
if grouping == "groups":
if obj_type == "array":
return _wrap_grouped(obj)
if obj_type == "red_array":
return _wrap_reduced_grouped(obj)
if obj.ndim == 2:
return _wrap_grouped(obj)
if obj.ndim == 1:
return _wrap_reduced_grouped(obj)
if grouping == "columns":
if obj_type == "array":
return _wrap(obj)
if obj_type == "red_array":
return _wrap_reduced(obj)
if obj.ndim == 2:
return _wrap(obj)
if obj.ndim == 1:
return _wrap_reduced(obj)
if obj_type not in {"records"}:
if checks.is_np_array(obj) and not checks.is_record_array(obj):
if is_grouped:
if obj_type is not None and obj_type == "array":
return _wrap_grouped(obj)
if obj_type is not None and obj_type == "red_array":
return _wrap_reduced_grouped(obj)
if to_2d_shape(obj.shape) == wrapper.get_shape_2d():
return _wrap_grouped(obj)
if obj.shape == (wrapper.get_shape_2d()[1],):
return _wrap_reduced_grouped(obj)
if obj_type is not None and obj_type == "array":
return _wrap(obj)
if obj_type is not None and obj_type == "red_array":
return _wrap_reduced(obj)
if to_2d_shape(obj.shape) == wrapper.shape_2d:
return _wrap(obj)
if obj.shape == (wrapper.shape_2d[1],):
return _wrap_reduced(obj)
if force_wrapping:
raise NotImplementedError(f"Cannot wrap object '{obj_name}'")
if not silence_warnings:
warn(f"Cannot figure out how to wrap object '{obj_name}'")
return obj
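# Illustrative sketch of `wrap_obj` (assumed shapes): a 2-dim array under grouping
# "columns_or_groups" is wrapped group-wise when the wrapper is grouped and column-wise
# otherwise, while a 1-dim array is wrapped as a reduced object under the same rule.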
def get_in_output(
self,
field: str,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> tp.Union[None, bool, tp.AnyArray]:
"""Find and wrap an in-output object matching the field.
If the field can be found in the attributes of this `Portfolio` instance, reads the
attribute's options to get requirements for the type and layout of the in-output object.
For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
`Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
If `field` is not in `Portfolio.in_outputs`, searches for the field in aliases and options.
In such case, to narrow down the number of candidates, options are additionally matched against
the requirements using `Portfolio.matches_field_options`. Finally, the matched in-output object is
wrapped using `Portfolio.wrap_obj`."""
if self.in_outputs is None:
raise ValueError("No in-outputs attached")
if field in self.cls_dir:
prop = getattr(type(self), field)
prop_options = getattr(prop, "options", {})
obj_type = prop_options.get("obj_type", "array")
group_by_aware = prop_options.get("group_by_aware", True)
wrap_func = prop_options.get("wrap_func", None)
wrap_kwargs = prop_options.get("wrap_kwargs", None)
force_wrapping = prop_options.get("force_wrapping", False)
silence_warnings = prop_options.get("silence_warnings", False)
field_aliases = prop_options.get("field_aliases", None)
if field_aliases is None:
field_aliases = []
field_aliases = {field, *field_aliases}
found_attr = True
else:
obj_type = None
group_by_aware = True
wrap_func = None
wrap_kwargs = None
force_wrapping = False
silence_warnings = False
field_aliases = {field}
found_attr = False
found_field = None
found_field_options = None
for _field in set(self.in_outputs._fields):
_field_options = merge_dicts(
self.parse_field_options(_field),
self.in_output_config.get(_field, None),
)
if (not found_attr and field == _field) or (
(_field in field_aliases or _field_options.get("field", _field) in field_aliases)
and self.matches_field_options(
_field_options,
obj_type=obj_type,
group_by_aware=group_by_aware,
wrapper=wrapper,
group_by=group_by,
)
):
if found_field is not None:
raise ValueError(f"Multiple fields for '{field}' found in in_outputs")
found_field = _field
found_field_options = _field_options
if found_field is None:
raise AttributeError(f"No compatible field for '{field}' found in in_outputs")
obj = getattr(self.in_outputs, found_field)
if found_attr and checks.is_np_array(obj) and obj.shape == (0, 0): # for returns
return None
kwargs = merge_dicts(
dict(
grouping=found_field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
obj_type=found_field_options.get("obj_type", obj_type),
wrap_func=found_field_options.get("wrap_func", wrap_func),
wrap_kwargs=found_field_options.get("wrap_kwargs", wrap_kwargs),
force_wrapping=found_field_options.get("force_wrapping", force_wrapping),
silence_warnings=found_field_options.get("silence_warnings", silence_warnings),
),
kwargs,
)
return self.wrap_obj(
obj,
found_field_options.get("field", found_field),
wrapper=wrapper,
group_by=group_by,
**kwargs,
)
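# Illustrative sketch (field name hypothetical): if `in_outputs` holds a field named
# "my_metric_pcg_2d", then `pf.get_in_output("my_metric")` locates it by its cleaned name
# and wraps it per group when the wrapper is grouped, otherwise per column.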
# ############# Indexing ############# #
def index_obj(
self,
obj: tp.Any,
wrapper_meta: dict,
obj_name: tp.Optional[str] = None,
grouping: str = "columns_or_groups",
obj_type: tp.Optional[str] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
indexing_func: tp.Optional[tp.Callable] = None,
force_indexing: bool = False,
silence_warnings: bool = False,
**kwargs,
) -> tp.Any:
"""Perform indexing on an object.
`indexing_func` must take the portfolio, all the arguments passed to this method, and `**kwargs`.
If you don't need any of the arguments, make `indexing_func` accept them as `**kwargs`.
If the object is None, boolean, or empty, returns as-is."""
if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
return obj
if wrapper is None:
wrapper = self.wrapper
if indexing_func is not None:
return indexing_func(
self,
obj,
wrapper_meta,
obj_name=obj_name,
grouping=grouping,
obj_type=obj_type,
wrapper=wrapper,
group_by=group_by,
force_indexing=force_indexing,
silence_warnings=silence_warnings,
**kwargs,
)
def _index_1d_by_group(obj: tp.ArrayLike) -> tp.ArrayLike:
return to_1d_array(obj)[wrapper_meta["group_idxs"]]
def _index_1d_by_col(obj: tp.ArrayLike) -> tp.ArrayLike:
return to_1d_array(obj)[wrapper_meta["col_idxs"]]
def _index_2d_by_group(obj: tp.ArrayLike) -> tp.ArrayLike:
return to_2d_array(obj)[wrapper_meta["row_idxs"], :][:, wrapper_meta["group_idxs"]]
def _index_2d_by_col(obj: tp.ArrayLike) -> tp.ArrayLike:
return to_2d_array(obj)[wrapper_meta["row_idxs"], :][:, wrapper_meta["col_idxs"]]
def _index_records(obj: tp.RecordArray) -> tp.RecordArray:
records = Records(wrapper, obj)
records_meta = records.indexing_func_meta(wrapper_meta=wrapper_meta)
return records.indexing_func(records_meta=records_meta).values
is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
if obj_type is not None and obj_type == "records":
return _index_records(obj)
if grouping == "cash_sharing":
if obj_type is not None and obj_type == "array":
if is_grouped and self.cash_sharing:
return _index_2d_by_group(obj)
return _index_2d_by_col(obj)
if obj_type is not None and obj_type == "red_array":
if is_grouped and self.cash_sharing:
return _index_1d_by_group(obj)
return _index_1d_by_col(obj)
if obj.ndim == 2:
if is_grouped and self.cash_sharing:
return _index_2d_by_group(obj)
return _index_2d_by_col(obj)
if obj.ndim == 1:
if is_grouped and self.cash_sharing:
return _index_1d_by_group(obj)
return _index_1d_by_col(obj)
if grouping == "columns_or_groups":
if obj_type is not None and obj_type == "array":
if is_grouped:
return _index_2d_by_group(obj)
return _index_2d_by_col(obj)
if obj_type is not None and obj_type == "red_array":
if is_grouped:
return _index_1d_by_group(obj)
return _index_1d_by_col(obj)
if obj.ndim == 2:
if is_grouped:
return _index_2d_by_group(obj)
return _index_2d_by_col(obj)
if obj.ndim == 1:
if is_grouped:
return _index_1d_by_group(obj)
return _index_1d_by_col(obj)
if grouping == "groups":
if obj_type is not None and obj_type == "array":
return _index_2d_by_group(obj)
if obj_type is not None and obj_type == "red_array":
return _index_1d_by_group(obj)
if obj.ndim == 2:
return _index_2d_by_group(obj)
if obj.ndim == 1:
return _index_1d_by_group(obj)
if grouping == "columns":
if obj_type is not None and obj_type == "array":
return _index_2d_by_col(obj)
if obj_type is not None and obj_type == "red_array":
return _index_1d_by_col(obj)
if obj.ndim == 2:
return _index_2d_by_col(obj)
if obj.ndim == 1:
return _index_1d_by_col(obj)
if checks.is_np_array(obj):
if is_grouped:
if obj_type is not None and obj_type == "array":
return _index_2d_by_group(obj)
if obj_type is not None and obj_type == "red_array":
return _index_1d_by_group(obj)
if to_2d_shape(obj.shape) == wrapper.get_shape_2d():
return _index_2d_by_group(obj)
if obj.shape == (wrapper.get_shape_2d()[1],):
return _index_1d_by_group(obj)
if obj_type is not None and obj_type == "array":
return _index_2d_by_col(obj)
if obj_type is not None and obj_type == "red_array":
return _index_1d_by_col(obj)
if to_2d_shape(obj.shape) == wrapper.shape_2d:
return _index_2d_by_col(obj)
if obj.shape == (wrapper.shape_2d[1],):
return _index_1d_by_col(obj)
if force_indexing:
raise NotImplementedError(f"Cannot index object '{obj_name}'")
if not silence_warnings:
warn(f"Cannot figure out how to index object '{obj_name}'")
return obj
def in_outputs_indexing_func(self, wrapper_meta: dict, **kwargs) -> tp.Optional[tp.NamedTuple]:
"""Perform indexing on `Portfolio.in_outputs`.
If the field can be found in the attributes of this `Portfolio` instance, reads the
attribute's options to get requirements for the type and layout of the in-output object.
For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
`Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
Performs indexing on the in-output object using `Portfolio.index_obj`."""
if self.in_outputs is None:
return None
new_in_outputs = {}
for field, obj in self.in_outputs._asdict().items():
field_options = merge_dicts(
self.parse_field_options(field),
self.in_output_config.get(field, None),
)
if field_options.get("field", field) in self.cls_dir:
prop = getattr(type(self), field_options["field"])
prop_options = getattr(prop, "options", {})
obj_type = prop_options.get("obj_type", "array")
group_by_aware = prop_options.get("group_by_aware", True)
indexing_func = prop_options.get("indexing_func", None)
force_indexing = prop_options.get("force_indexing", False)
silence_warnings = prop_options.get("silence_warnings", False)
else:
obj_type = None
group_by_aware = True
indexing_func = None
force_indexing = False
silence_warnings = False
_kwargs = merge_dicts(
dict(
grouping=field_options.get("grouping", "columns_or_groups" if group_by_aware else "columns"),
obj_name=field_options.get("field", field),
obj_type=field_options.get("obj_type", obj_type),
indexing_func=field_options.get("indexing_func", indexing_func),
force_indexing=field_options.get("force_indexing", force_indexing),
silence_warnings=field_options.get("silence_warnings", silence_warnings),
),
kwargs,
)
new_obj = self.index_obj(obj, wrapper_meta, **_kwargs)
new_in_outputs[field] = new_obj
return type(self.in_outputs)(**new_in_outputs)
def indexing_func(
self: PortfolioT,
*args,
in_output_kwargs: tp.KwargsLike = None,
wrapper_meta: tp.DictLike = None,
**kwargs,
) -> PortfolioT:
"""Perform indexing on `Portfolio`.
In-outputs are indexed using `Portfolio.in_outputs_indexing_func`."""
_self = self.disable_weights()
if wrapper_meta is None:
wrapper_meta = _self.wrapper.indexing_func_meta(
*args,
column_only_select=_self.column_only_select,
range_only_select=_self.range_only_select,
group_select=_self.group_select,
**kwargs,
)
new_wrapper = wrapper_meta["new_wrapper"]
row_idxs = wrapper_meta["row_idxs"]
rows_changed = wrapper_meta["rows_changed"]
col_idxs = wrapper_meta["col_idxs"]
columns_changed = wrapper_meta["columns_changed"]
group_idxs = wrapper_meta["group_idxs"]
new_close = ArrayWrapper.select_from_flex_array(
_self._close,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
if _self._open is not None:
new_open = ArrayWrapper.select_from_flex_array(
_self._open,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
else:
new_open = _self._open
if _self._high is not None:
new_high = ArrayWrapper.select_from_flex_array(
_self._high,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
else:
new_high = _self._high
if _self._low is not None:
new_low = ArrayWrapper.select_from_flex_array(
_self._low,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
else:
new_low = _self._low
new_order_records = _self.orders.indexing_func_meta(wrapper_meta=wrapper_meta)["new_records_arr"]
new_log_records = _self.logs.indexing_func_meta(wrapper_meta=wrapper_meta)["new_records_arr"]
new_init_cash = _self._init_cash
if not checks.is_int(new_init_cash):
new_init_cash = to_1d_array(new_init_cash)
if rows_changed and row_idxs.start > 0:
if _self.wrapper.grouper.is_grouped() and not _self.cash_sharing:
cash = _self.get_cash(group_by=False)
else:
cash = _self.cash
new_init_cash = to_1d_array(cash.iloc[row_idxs.start - 1])
if columns_changed and new_init_cash.shape[0] > 1:
if _self.cash_sharing:
new_init_cash = new_init_cash[group_idxs]
else:
new_init_cash = new_init_cash[col_idxs]
new_init_position = to_1d_array(_self._init_position)
if rows_changed and row_idxs.start > 0:
new_init_position = to_1d_array(_self.assets.iloc[row_idxs.start - 1])
if columns_changed and new_init_position.shape[0] > 1:
new_init_position = new_init_position[col_idxs]
new_init_price = to_1d_array(_self._init_price)
if rows_changed and row_idxs.start > 0:
new_init_price = to_1d_array(_self.close.iloc[: row_idxs.start].ffill().iloc[-1])
if columns_changed and new_init_price.shape[0] > 1:
new_init_price = new_init_price[col_idxs]
new_cash_deposits = ArrayWrapper.select_from_flex_array(
_self._cash_deposits,
row_idxs=row_idxs,
col_idxs=group_idxs if _self.cash_sharing else col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
new_cash_earnings = ArrayWrapper.select_from_flex_array(
_self._cash_earnings,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
if _self._call_seq is not None:
new_call_seq = ArrayWrapper.select_from_flex_array(
_self._call_seq,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
else:
new_call_seq = None
if _self._bm_close is not None and not isinstance(_self._bm_close, bool):
new_bm_close = ArrayWrapper.select_from_flex_array(
_self._bm_close,
row_idxs=row_idxs,
col_idxs=col_idxs,
rows_changed=rows_changed,
columns_changed=columns_changed,
)
else:
new_bm_close = _self._bm_close
new_in_outputs = _self.in_outputs_indexing_func(wrapper_meta, **resolve_dict(in_output_kwargs))
new_sim_start = _self.sim_start_indexing_func(wrapper_meta)
new_sim_end = _self.sim_end_indexing_func(wrapper_meta)
if self.weights is not None:
new_weights = to_1d_array(self.weights)
if columns_changed and new_weights.shape[0] > 1:
new_weights = new_weights[col_idxs]
else:
new_weights = self._weights
return self.replace(
wrapper=new_wrapper,
order_records=new_order_records,
open=new_open,
high=new_high,
low=new_low,
close=new_close,
log_records=new_log_records,
init_cash=new_init_cash,
init_position=new_init_position,
init_price=new_init_price,
cash_deposits=new_cash_deposits,
cash_earnings=new_cash_earnings,
call_seq=new_call_seq,
in_outputs=new_in_outputs,
bm_close=new_bm_close,
sim_start=new_sim_start,
sim_end=new_sim_end,
weights=new_weights,
)
# ############# Resampling ############# #
def resample_obj(
self,
obj: tp.Any,
resampler: tp.Union[Resampler, tp.PandasResampler],
obj_name: tp.Optional[str] = None,
obj_type: tp.Optional[str] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
resample_func: tp.Union[None, str, tp.Callable] = None,
resample_kwargs: tp.KwargsLike = None,
force_resampling: bool = False,
silence_warnings: bool = False,
**kwargs,
) -> tp.Any:
"""Resample an object.
`resample_func` must take the portfolio, `obj`, `resampler`, all the arguments passed to this method,
and `**kwargs`. If you don't need any of the arguments, make `resample_func` accept them as `**kwargs`.
If `resample_func` is a string, will use it as `reduce_func_nb` in
`vectorbtpro.generic.accessors.GenericAccessor.resample_apply`. Default is 'last'.
If the object is None, boolean, or empty, returns as-is."""
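# Illustrative sketch (assumption, mirroring the string-based path implemented below): a custom
# `resample_func` receives the portfolio, the object, and the resampler, plus the remaining
# keyword arguments, e.g.
#
#     def sum_resample_func(pf, obj, resampler, wrapper=None, **kwargs):
#         wrapped = ArrayWrapper.from_obj(obj, index=wrapper.index).wrap(obj)
#         return wrapped.vbt.resample_apply(resampler, "sum").values
#
#     pf.resample_obj(arr, resampler, resample_func=sum_resample_func)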
if obj is None or isinstance(obj, bool) or (checks.is_np_array(obj) and obj.size == 0):
return obj
if wrapper is None:
wrapper = self.wrapper
if resample_func is None:
resample_func = "last"
if not isinstance(resample_func, str):
return resample_func(
self,
obj,
resampler,
obj_name=obj_name,
obj_type=obj_type,
wrapper=wrapper,
group_by=group_by,
resample_kwargs=resample_kwargs,
force_resampling=force_resampling,
silence_warnings=silence_warnings,
**kwargs,
)
def _resample(obj: tp.Array) -> tp.SeriesFrame:
wrapped_obj = ArrayWrapper.from_obj(obj, index=wrapper.index).wrap(obj)
return wrapped_obj.vbt.resample_apply(resampler, resample_func, **resolve_dict(resample_kwargs)).values
if obj_type is not None and obj_type == "red_array":
return obj
if obj_type is None or obj_type == "array":
is_grouped = wrapper.grouper.is_grouped(group_by=group_by)
if checks.is_np_array(obj):
if is_grouped:
if to_2d_shape(obj.shape) == wrapper.get_shape_2d():
return _resample(obj)
if obj.shape == (wrapper.get_shape_2d()[1],):
return obj
if to_2d_shape(obj.shape) == wrapper.shape_2d:
return _resample(obj)
if obj.shape == (wrapper.shape_2d[1],):
return obj
if force_resampling:
raise NotImplementedError(f"Cannot resample object '{obj_name}'")
if not silence_warnings:
warn(f"Cannot figure out how to resample object '{obj_name}'")
return obj
def resample_in_outputs(
self,
resampler: tp.Union[Resampler, tp.PandasResampler],
**kwargs,
) -> tp.Optional[tp.NamedTuple]:
"""Resample `Portfolio.in_outputs`.
For each field in `Portfolio.in_outputs`, resolves the field's options by parsing its name with
`Portfolio.parse_field_options` and also looks for options in `Portfolio.in_output_config`.
If the field can be found in the attributes of this `Portfolio` instance, reads the
attribute's options to get requirements for the type and layout of the in-output object.
Performs resampling on each in-output object using `Portfolio.resample_obj`."""
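# Illustrative note (assumption): an in-output field whose resolved name matches a Portfolio
# attribute (e.g. "returns") reuses that attribute's options, such as `obj_type` and
# `resample_func`, while unknown fields fall back to the generic defaults resolved below.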
if self.in_outputs is None:
return None
new_in_outputs = {}
for field, obj in self.in_outputs._asdict().items():
field_options = merge_dicts(
self.parse_field_options(field),
self.in_output_config.get(field, None),
)
if field_options.get("field", field) in self.cls_dir:
prop = getattr(type(self), field_options["field"])
prop_options = getattr(prop, "options", {})
obj_type = prop_options.get("obj_type", "array")
resample_func = prop_options.get("resample_func", None)
resample_kwargs = prop_options.get("resample_kwargs", None)
force_resampling = prop_options.get("force_resampling", False)
silence_warnings = prop_options.get("silence_warnings", False)
else:
obj_type = None
resample_func = None
resample_kwargs = None
force_resampling = False
silence_warnings = False
_kwargs = merge_dicts(
dict(
obj_name=field_options.get("field", field),
obj_type=field_options.get("obj_type", obj_type),
resample_func=field_options.get("resample_func", resample_func),
resample_kwargs=field_options.get("resample_kwargs", resample_kwargs),
force_resampling=field_options.get("force_resampling", force_resampling),
silence_warnings=field_options.get("silence_warnings", silence_warnings),
),
kwargs,
)
new_obj = self.resample_obj(obj, resampler, **_kwargs)
new_in_outputs[field] = new_obj
return type(self.in_outputs)(**new_in_outputs)
def resample(
self: PortfolioT,
*args,
ffill_close: bool = False,
fbfill_close: bool = False,
in_output_kwargs: tp.KwargsLike = None,
wrapper_meta: tp.DictLike = None,
**kwargs,
) -> PortfolioT:
"""Resample the `Portfolio` instance.
!!! warning
Downsampling is associated with information loss:
* Cash deposits and earnings are assumed to be added/removed at the beginning of each time step.
Imagine depositing $100 and using them up in the same bar, and then depositing another $100
and using them up. Downsampling both bars into a single bar will aggregate cash deposits
and earnings, and put both of them at the beginning of the new bar, even though the second
deposit was added later in time.
* Market/benchmark returns are computed by applying the initial value to the close price
of the first bar and by tracking the price change to simulate holding. Moving the close
price of the first bar further into the future will affect this computation and almost
certainly produce a different market value and returns. To mitigate this, make sure
to downsample to an index whose first bar contains only the first bar from the
original timeframe."""
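# Usage sketch (illustrative, assuming a sub-weekly index): downsample to weekly bars and
# compare aggregate metrics, e.g.
#
#     pf_weekly = pf.resample("W")
#     pf_weekly.total_return  # should be close to pf.total_return, subject to the caveats above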
_self = self.disable_weights()
if _self._call_seq is not None:
raise ValueError("Cannot resample call_seq")
if wrapper_meta is None:
wrapper_meta = _self.wrapper.resample_meta(*args, **kwargs)
resampler = wrapper_meta["resampler"]
new_wrapper = wrapper_meta["new_wrapper"]
new_close = _self.close.vbt.resample_apply(resampler, "last")
if fbfill_close:
new_close = new_close.vbt.fbfill()
elif ffill_close:
new_close = new_close.vbt.ffill()
new_close = new_close.values
if _self._open is not None:
new_open = _self.open.vbt.resample_apply(resampler, "first").values
else:
new_open = _self._open
if _self._high is not None:
new_high = _self.high.vbt.resample_apply(resampler, "max").values
else:
new_high = _self._high
if _self._low is not None:
new_low = _self.low.vbt.resample_apply(resampler, "min").values
else:
new_low = _self._low
new_order_records = _self.orders.resample_records_arr(resampler)
new_log_records = _self.logs.resample_records_arr(resampler)
if _self._cash_deposits.size > 1 or _self._cash_deposits.item() != 0:
new_cash_deposits = _self.get_cash_deposits(group_by=None if _self.cash_sharing else False)
new_cash_deposits = new_cash_deposits.vbt.resample_apply(resampler, generic_nb.sum_reduce_nb)
new_cash_deposits = new_cash_deposits.fillna(0.0)
new_cash_deposits = new_cash_deposits.values
else:
new_cash_deposits = _self._cash_deposits
if _self._cash_earnings.size > 1 or _self._cash_earnings.item() != 0:
new_cash_earnings = _self.get_cash_earnings(group_by=False)
new_cash_earnings = new_cash_earnings.vbt.resample_apply(resampler, generic_nb.sum_reduce_nb)
new_cash_earnings = new_cash_earnings.fillna(0.0)
new_cash_earnings = new_cash_earnings.values
else:
new_cash_earnings = _self._cash_earnings
if _self._bm_close is not None and not isinstance(_self._bm_close, bool):
new_bm_close = _self.bm_close.vbt.resample_apply(resampler, "last")
if fbfill_close:
new_bm_close = new_bm_close.vbt.fbfill()
elif ffill_close:
new_bm_close = new_bm_close.vbt.ffill()
new_bm_close = new_bm_close.values
else:
new_bm_close = _self._bm_close
if _self._in_outputs is not None:
new_in_outputs = _self.resample_in_outputs(resampler, **resolve_dict(in_output_kwargs))
else:
new_in_outputs = None
new_sim_start = _self.resample_sim_start(new_wrapper)
new_sim_end = _self.resample_sim_end(new_wrapper)
return self.replace(
wrapper=new_wrapper,
order_records=new_order_records,
open=new_open,
high=new_high,
low=new_low,
close=new_close,
log_records=new_log_records,
cash_deposits=new_cash_deposits,
cash_earnings=new_cash_earnings,
in_outputs=new_in_outputs,
bm_close=new_bm_close,
sim_start=new_sim_start,
sim_end=new_sim_end,
)
# ############# Class methods ############# #
@classmethod
def from_orders(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin, FOPreparer, PFPrepResult],
size: tp.Optional[tp.ArrayLike] = None,
size_type: tp.Optional[tp.ArrayLike] = None,
direction: tp.Optional[tp.ArrayLike] = None,
price: tp.Optional[tp.ArrayLike] = None,
fees: tp.Optional[tp.ArrayLike] = None,
fixed_fees: tp.Optional[tp.ArrayLike] = None,
slippage: tp.Optional[tp.ArrayLike] = None,
min_size: tp.Optional[tp.ArrayLike] = None,
max_size: tp.Optional[tp.ArrayLike] = None,
size_granularity: tp.Optional[tp.ArrayLike] = None,
leverage: tp.Optional[tp.ArrayLike] = None,
leverage_mode: tp.Optional[tp.ArrayLike] = None,
reject_prob: tp.Optional[tp.ArrayLike] = None,
price_area_vio_mode: tp.Optional[tp.ArrayLike] = None,
allow_partial: tp.Optional[tp.ArrayLike] = None,
raise_reject: tp.Optional[tp.ArrayLike] = None,
log: tp.Optional[tp.ArrayLike] = None,
val_price: tp.Optional[tp.ArrayLike] = None,
from_ago: tp.Optional[tp.ArrayLike] = None,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
init_cash: tp.Optional[tp.ArrayLike] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_earnings: tp.Optional[tp.ArrayLike] = None,
cash_dividends: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
call_seq: tp.Optional[tp.ArrayLike] = None,
attach_call_seq: tp.Optional[bool] = None,
ffill_val_price: tp.Optional[bool] = None,
update_value: tp.Optional[bool] = None,
save_state: tp.Optional[bool] = None,
save_value: tp.Optional[bool] = None,
save_returns: tp.Optional[bool] = None,
skip_empty: tp.Optional[bool] = None,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
group_by: tp.GroupByLike = None,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
bm_close: tp.Optional[tp.ArrayLike] = None,
records: tp.Optional[tp.RecordsLike] = None,
return_preparer: bool = False,
return_prep_result: bool = False,
return_sim_out: bool = False,
**kwargs,
) -> PortfolioResultT:
"""Simulate portfolio from orders - size, price, fees, and other information.
See `vectorbtpro.portfolio.nb.from_orders.from_orders_nb`.
Prepared by `vectorbtpro.portfolio.preparing.FOPreparer`.
Args:
close (array_like, OHLCDataMixin, FOPreparer, or PFPrepResult): Latest asset price at each time step.
Will broadcast.
Used for calculating unrealized PnL and portfolio value.
If an instance of `vectorbtpro.data.base.OHLCDataMixin`, will extract the open, high, low, and close price.
If an instance of `vectorbtpro.portfolio.preparing.FOPreparer`, will use it as a preparer.
If an instance of `vectorbtpro.portfolio.preparing.PFPrepResult`, will use it as a preparer result.
size (float or array_like): Size to order.
See `vectorbtpro.portfolio.enums.Order.size`. Will broadcast.
size_type (SizeType or array_like): See `vectorbtpro.portfolio.enums.SizeType` and
`vectorbtpro.portfolio.enums.Order.size_type`. Will broadcast.
direction (Direction or array_like): See `vectorbtpro.portfolio.enums.Direction` and
`vectorbtpro.portfolio.enums.Order.direction`. Will broadcast.
price (array_like of float): Order price.
Will broadcast.
See `vectorbtpro.portfolio.enums.Order.price`. Can also be provided as
`vectorbtpro.portfolio.enums.PriceType`. Options `PriceType.NextOpen` and `PriceType.NextClose`
are only applicable per column, that is, they cannot be used inside full arrays.
In addition, they require the argument `from_ago` to be None.
fees (float or array_like): Fees in percentage of the order value.
See `vectorbtpro.portfolio.enums.Order.fees`. Will broadcast.
fixed_fees (float or array_like): Fixed amount of fees to pay per order.
See `vectorbtpro.portfolio.enums.Order.fixed_fees`. Will broadcast.
slippage (float or array_like): Slippage in percentage of price.
See `vectorbtpro.portfolio.enums.Order.slippage`. Will broadcast.
min_size (float or array_like): Minimum size for an order to be accepted.
See `vectorbtpro.portfolio.enums.Order.min_size`. Will broadcast.
max_size (float or array_like): Maximum size for an order.
See `vectorbtpro.portfolio.enums.Order.max_size`. Will broadcast.
Will be partially filled if exceeded.
size_granularity (float or array_like): Granularity of the size.
See `vectorbtpro.portfolio.enums.Order.size_granularity`. Will broadcast.
leverage (float or array_like): Leverage.
See `vectorbtpro.portfolio.enums.Order.leverage`. Will broadcast.
leverage_mode (LeverageMode or array_like): Leverage mode.
See `vectorbtpro.portfolio.enums.Order.leverage_mode`. Will broadcast.
reject_prob (float or array_like): Order rejection probability.
See `vectorbtpro.portfolio.enums.Order.reject_prob`. Will broadcast.
price_area_vio_mode (PriceAreaVioMode or array_like): See `vectorbtpro.portfolio.enums.PriceAreaVioMode`.
Will broadcast.
allow_partial (bool or array_like): Whether to allow partial fills.
See `vectorbtpro.portfolio.enums.Order.allow_partial`. Will broadcast.
Does not apply when size is `np.inf`.
raise_reject (bool or array_like): Whether to raise an exception if order gets rejected.
See `vectorbtpro.portfolio.enums.Order.raise_reject`. Will broadcast.
log (bool or array_like): Whether to log orders.
See `vectorbtpro.portfolio.enums.Order.log`. Will broadcast.
val_price (array_like of float): Asset valuation price.
Will broadcast.
Can also be provided as `vectorbtpro.portfolio.enums.ValPriceType`.
* Any `-np.inf` element is replaced by the latest valuation price
(`open` or the latest known valuation price if `ffill_val_price`).
* Any `np.inf` element is replaced by the current order price.
Used at the time of decision making to calculate value of each asset in the group,
for example, to convert target value into target amount.
!!! note
In contrast to `Portfolio.from_order_func`, order price is known beforehand (kind of),
thus `val_price` is set to the current order price (using `np.inf`) by default.
To valuate using previous close, set it in the settings to `-np.inf`.
!!! note
Make sure to use a timestamp for `val_price` that comes before the timestamps of
all orders in the group with cash sharing (the previous `close`, for example),
otherwise you're cheating yourself.
open (array_like of float): First asset price at each time step.
Defaults to `np.nan`. Will broadcast.
Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
high (array_like of float): Highest asset price at each time step.
Defaults to `np.nan`. Will broadcast.
Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
low (array_like of float): Lowest asset price at each time step.
Defaults to `np.nan`. Will broadcast.
Used as a price boundary (see `vectorbtpro.portfolio.enums.PriceArea`).
init_cash (InitCashMode, float or array_like): Initial capital.
By default, will broadcast to the final number of columns.
But if cash sharing is enabled, will broadcast to the number of groups.
See `vectorbtpro.portfolio.enums.InitCashMode` to find optimal initial cash.
!!! note
Mode `InitCashMode.AutoAlign` is applied after the portfolio is initialized
to set the same initial cash for all columns/groups. Changing grouping
will change the initial cash, so be aware when indexing.
init_position (float or array_like): Initial position.
By default, will broadcast to the final number of columns.
init_price (float or array_like): Initial position price.
By default, will broadcast to the final number of columns.
cash_deposits (float or array_like): Cash to be deposited/withdrawn at each timestamp.
Will broadcast to the final shape. Must have the same number of columns as `init_cash`.
Applied at the beginning of each timestamp.
cash_earnings (float or array_like): Earnings in cash to be added at each timestamp.
Will broadcast to the final shape.
Applied at the end of each timestamp.
cash_dividends (float or array_like): Dividends in cash to be added at each timestamp.
Will broadcast to the final shape.
Gets multiplied by the position and saved into `cash_earnings`.
Applied at the end of each timestamp.
cash_sharing (bool): Whether to share cash within the same group.
If `group_by` is None and `cash_sharing` is True, `group_by` becomes True to form a single
group with cash sharing.
!!! warning
Introduces cross-asset dependencies.
This method presumes that in a group of assets that share the same capital all
orders will be executed within the same tick and retain their price regardless
of their position in the queue, even though they depend upon each other and thus
cannot be executed in parallel.
from_ago (int or array_like): Take order information from a number of bars ago.
Will broadcast.
Negative numbers will be cast to positive to avoid look-ahead bias. Defaults to 0.
Remember to account for it if you're using a custom signal function! A short sketch of
delaying execution with `from_ago` is included at the end of the usage examples below.
sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
Can be "auto", which will be substituted by the index of the first non-NA size value.
sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
Can be "auto", which will be substituted by the index of the first non-NA size value.
call_seq (CallSeqType or array_like): Default sequence of calls per row and group.
Each value in this sequence must indicate the position of the column in the group to
call next. Processing of `call_seq` always goes from left to right.
For example, `[2, 0, 1]` would first call column 'c', then 'a', and finally 'b'.
Supported are multiple options:
* Set to None to generate the default call sequence on the fly. Will create a
full array only if `attach_call_seq` is True.
* Use `vectorbtpro.portfolio.enums.CallSeqType` to create a full array of a specific type.
* Set to array to specify a custom call sequence.
If `CallSeqType.Auto` is selected, rearranges calls dynamically based on order value.
Calculates value of all orders per row and group, and sorts them by this value.
Sell orders will be executed first to release funds for buy orders.
!!! warning
`CallSeqType.Auto` should be used with caution:
* It not only presumes that order prices are known beforehand, but also that
orders can be executed in arbitrary order and still retain their price.
In reality, this is hardly the case: after processing one asset, some time
has passed and the price for other assets might have already changed.
* Even if you're able to specify a slippage large enough to compensate for
this behavior, slippage itself should depend upon execution order.
This method doesn't let you do that.
* Orders in the same queue are executed regardless of whether previous orders
have been filled, which can leave them without required funds.
For more control, use `Portfolio.from_order_func`.
attach_call_seq (bool): Whether to attach `call_seq` to the instance.
Makes sense if you want to analyze the simulation order. Otherwise, just takes memory.
ffill_val_price (bool): Whether to track valuation price only if it's known.
Otherwise, unknown `close` will lead to NaN in valuation price at the next timestamp.
update_value (bool): Whether to update group value after each filled order.
save_state (bool): Whether to save the state.
The arrays will be available as `cash`, `position`, `debt`, `locked_cash`, and `free_cash` in in-outputs.
save_value (bool): Whether to save the value.
The array will be available as `value` in in-outputs.
save_returns (bool): Whether to save the returns.
The array will be available as `returns` in in-outputs.
skip_empty (bool): Whether to skip rows with no order.
max_order_records (int): The max number of order records expected to be filled at each column.
Defaults to the maximum number of non-NaN values across all columns of the size array.
Set to a lower number if you run out of memory, and to 0 to not fill.
max_log_records (int): The max number of log records expected to be filled at each column.
Defaults to the maximum number of True values across all columns of the log array.
Set to a lower number if you run out of memory, and to 0 to not fill.
seed (int): Seed to be set for both `call_seq` and at the beginning of the simulation.
group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
broadcast_kwargs (dict): Keyword arguments passed to `vectorbtpro.base.reshaping.broadcast`.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
chunked (any): See `vectorbtpro.utils.chunking.resolve_chunked_option`.
bm_close (array_like): Latest benchmark price at each time step.
Will broadcast.
If not provided, will use `close`. If False, will not use any benchmark.
records (array_like): Records to construct arrays from.
See `vectorbtpro.base.indexing.IdxRecords`.
return_preparer (bool): Whether to return the preparer of the type
`vectorbtpro.portfolio.preparing.FOPreparer`.
!!! note
Seed won't be set in this case, you need to explicitly call `preparer.set_seed()`.
return_prep_result (bool): Whether to return the preparer result of the type
`vectorbtpro.portfolio.preparing.PFPrepResult`.
return_sim_out (bool): Whether to return the simulation output of the type
`vectorbtpro.portfolio.enums.SimulationOutput`.
**kwargs: Keyword arguments passed to the `Portfolio` constructor.
All broadcastable arguments will broadcast using `vectorbtpro.base.reshaping.broadcast`
but keep original shape to utilize flexible indexing and to save memory.
For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill
NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN
for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO`
with `fill_value` to override.
!!! note
When `call_seq` is not `CallSeqType.Auto`, at each timestamp, processing of the assets in
a group goes strictly in order defined in `call_seq`. This order can't be changed dynamically.
This has one big implication for this particular method: the last asset in the call stack
cannot be processed until other assets are processed. This is the reason why rebalancing
cannot work properly in this setting: one has to specify percentages for all assets beforehand
and then tweak the processing order to sell to-be-sold assets first in order to release funds
for to-be-bought assets. This can be automatically done by using `CallSeqType.Auto`.
!!! hint
All broadcastable arguments can be set per frame, series, row, column, or element.
Usage:
* Buy 10 units each tick:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_orders(close, 10)
>>> pf.assets
0 10.0
1 20.0
2 30.0
3 40.0
4 40.0
dtype: float64
>>> pf.cash
0 90.0
1 70.0
2 40.0
3 0.0
4 0.0
dtype: float64
```
* Reverse each position by first closing it:
```pycon
>>> size = [1, 0, -1, 0, 1]
>>> pf = vbt.Portfolio.from_orders(close, size, size_type='targetpercent')
>>> pf.assets
0 100.000000
1 0.000000
2 -66.666667
3 0.000000
4 26.666667
dtype: float64
>>> pf.cash
0 0.000000
1 200.000000
2 400.000000
3 133.333333
4 0.000000
dtype: float64
```
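* A minimal, illustrative sketch (output omitted) of starting the simulation with a pre-existing
position using `init_position` and `init_price`:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_orders(
...     close,
...     size=np.nan,       # no new orders
...     init_position=10,  # start with 10 units already held
...     init_price=0.5     # assumed to be acquired at 0.5 before the simulation
... )
>>> pf.assets  # should stay at 10 throughout
```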
* Regularly deposit cash at open and invest it within the same bar at close:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> cash_deposits = pd.Series([10., 0., 10., 0., 10.])
>>> pf = vbt.Portfolio.from_orders(
... close,
... size=cash_deposits, # invest the amount deposited
... size_type='value',
... cash_deposits=cash_deposits
... )
>>> pf.cash
0 100.0
1 100.0
2 100.0
3 100.0
4 100.0
dtype: float64
>>> pf.asset_flow
0 10.000000
1 0.000000
2 3.333333
3 0.000000
4 2.000000
dtype: float64
```
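* A rough, illustrative sketch (output omitted) of receiving cash dividends on an open position;
`cash_dividends` is multiplied by the current position and added to cash at the end of each bar:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_orders(close, size=10, cash_dividends=0.1)
>>> pf.cash  # includes 0.1 per held unit added at the end of each bar
```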
* Equal-weighted portfolio as in `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb` example
(it's more compact but has less control over execution):
```pycon
>>> np.random.seed(42)
>>> close = pd.DataFrame(np.random.uniform(1, 10, size=(5, 3)))
>>> size = pd.Series(np.full(5, 1/3)) # each column 33.3%
>>> size[1::2] = np.nan # skip every second tick
>>> pf = vbt.Portfolio.from_orders(
... close, # acts both as reference and order price here
... size,
... size_type='targetpercent',
... direction='longonly',
... call_seq='auto', # first sell then buy
... group_by=True, # one group
... cash_sharing=True, # assets share the same cash
... fees=0.001, fixed_fees=1., slippage=0.001 # costs
... )
>>> pf.get_asset_value(group_by=False).vbt.plot().show()
```
* Test 10 random weight combinations:
```pycon
>>> np.random.seed(42)
>>> close = pd.DataFrame(
... np.random.uniform(1, 10, size=(5, 3)),
... columns=pd.Index(['a', 'b', 'c'], name='asset'))
>>> # Generate random weight combinations
>>> rand_weights = []
>>> for i in range(10):
... rand_weights.append(np.random.dirichlet(np.ones(close.shape[1]), size=1)[0])
>>> rand_weights
[array([0.15474873, 0.27706078, 0.5681905 ]),
array([0.30468598, 0.18545189, 0.50986213]),
array([0.15780486, 0.36292607, 0.47926907]),
array([0.25697713, 0.64902589, 0.09399698]),
array([0.43310548, 0.53836359, 0.02853093]),
array([0.78628605, 0.15716865, 0.0565453 ]),
array([0.37186671, 0.42150531, 0.20662798]),
array([0.22441579, 0.06348919, 0.71209502]),
array([0.41619664, 0.09338007, 0.49042329]),
array([0.01279537, 0.87770864, 0.10949599])]
>>> # Bring close and rand_weights to the same shape
>>> rand_weights = np.concatenate(rand_weights)
>>> close = close.vbt.tile(10, keys=pd.Index(np.arange(10), name='weights_vector'))
>>> size = vbt.broadcast_to(rand_weights, close).copy()
>>> size[1::2] = np.nan
>>> size
weights_vector 0 ... 9
asset a b c ... a b c
0 0.154749 0.277061 0.56819 ... 0.012795 0.877709 0.109496
1 NaN NaN NaN ... NaN NaN NaN
2 0.154749 0.277061 0.56819 ... 0.012795 0.877709 0.109496
3 NaN NaN NaN ... NaN NaN NaN
4 0.154749 0.277061 0.56819 ... 0.012795 0.877709 0.109496
[5 rows x 30 columns]
>>> pf = vbt.Portfolio.from_orders(
... close,
... size,
... size_type='targetpercent',
... direction='longonly',
... call_seq='auto',
... group_by='weights_vector', # group by column level
... cash_sharing=True,
... fees=0.001, fixed_fees=1., slippage=0.001
... )
>>> pf.total_return
weights_vector
0 -0.294372
1 0.139207
2 -0.281739
3 0.041242
4 0.467566
5 0.829925
6 0.320672
7 -0.087452
8 0.376681
9 -0.702773
Name: total_return, dtype: float64
```
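* A rough, illustrative sketch (output omitted) of delaying execution by one bar with `from_ago`,
so that the size decided at one bar is executed at the next one:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_orders(close, size=10, from_ago=1)
>>> pf.orders.records_readable  # the first fill should land on the second bar
```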
"""
if isinstance(close, FOPreparer):
preparer = close
prep_result = None
elif isinstance(close, PFPrepResult):
preparer = None
prep_result = close
else:
local_kwargs = locals()
local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
del local_kwargs["kwargs"]
del local_kwargs["cls"]
del local_kwargs["return_preparer"]
del local_kwargs["return_prep_result"]
del local_kwargs["return_sim_out"]
parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
if parsed_data is not None:
local_kwargs["data"] = parsed_data
local_kwargs["close"] = None
preparer = FOPreparer(**local_kwargs)
if not return_preparer:
preparer.set_seed()
prep_result = None
if return_preparer:
return preparer
if prep_result is None:
prep_result = preparer.result
if return_prep_result:
return prep_result
sim_out = prep_result.target_func(**prep_result.target_args)
if return_sim_out:
return sim_out
return cls(order_records=sim_out, **prep_result.pf_args)
@classmethod
def from_signals(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin, FSPreparer, PFPrepResult],
entries: tp.Optional[tp.ArrayLike] = None,
exits: tp.Optional[tp.ArrayLike] = None,
*,
direction: tp.Optional[tp.ArrayLike] = None,
long_entries: tp.Optional[tp.ArrayLike] = None,
long_exits: tp.Optional[tp.ArrayLike] = None,
short_entries: tp.Optional[tp.ArrayLike] = None,
short_exits: tp.Optional[tp.ArrayLike] = None,
adjust_func_nb: tp.Union[None, tp.PathLike, nb.AdjustFuncT] = None,
adjust_args: tp.Args = (),
signal_func_nb: tp.Union[None, tp.PathLike, nb.SignalFuncT] = None,
signal_args: tp.ArgsLike = (),
post_segment_func_nb: tp.Union[None, tp.PathLike, nb.PostSegmentFuncT] = None,
post_segment_args: tp.ArgsLike = (),
order_mode: bool = False,
size: tp.Optional[tp.ArrayLike] = None,
size_type: tp.Optional[tp.ArrayLike] = None,
price: tp.Optional[tp.ArrayLike] = None,
fees: tp.Optional[tp.ArrayLike] = None,
fixed_fees: tp.Optional[tp.ArrayLike] = None,
slippage: tp.Optional[tp.ArrayLike] = None,
min_size: tp.Optional[tp.ArrayLike] = None,
max_size: tp.Optional[tp.ArrayLike] = None,
size_granularity: tp.Optional[tp.ArrayLike] = None,
leverage: tp.Optional[tp.ArrayLike] = None,
leverage_mode: tp.Optional[tp.ArrayLike] = None,
reject_prob: tp.Optional[tp.ArrayLike] = None,
price_area_vio_mode: tp.Optional[tp.ArrayLike] = None,
allow_partial: tp.Optional[tp.ArrayLike] = None,
raise_reject: tp.Optional[tp.ArrayLike] = None,
log: tp.Optional[tp.ArrayLike] = None,
val_price: tp.Optional[tp.ArrayLike] = None,
accumulate: tp.Optional[tp.ArrayLike] = None,
upon_long_conflict: tp.Optional[tp.ArrayLike] = None,
upon_short_conflict: tp.Optional[tp.ArrayLike] = None,
upon_dir_conflict: tp.Optional[tp.ArrayLike] = None,
upon_opposite_entry: tp.Optional[tp.ArrayLike] = None,
order_type: tp.Optional[tp.ArrayLike] = None,
limit_delta: tp.Optional[tp.ArrayLike] = None,
limit_tif: tp.Optional[tp.ArrayLike] = None,
limit_expiry: tp.Optional[tp.ArrayLike] = None,
limit_reverse: tp.Optional[tp.ArrayLike] = None,
limit_order_price: tp.Optional[tp.ArrayLike] = None,
upon_adj_limit_conflict: tp.Optional[tp.ArrayLike] = None,
upon_opp_limit_conflict: tp.Optional[tp.ArrayLike] = None,
use_stops: tp.Optional[bool] = None,
stop_ladder: tp.Optional[bool] = None,
sl_stop: tp.Optional[tp.ArrayLike] = None,
tsl_stop: tp.Optional[tp.ArrayLike] = None,
tsl_th: tp.Optional[tp.ArrayLike] = None,
tp_stop: tp.Optional[tp.ArrayLike] = None,
td_stop: tp.Optional[tp.ArrayLike] = None,
dt_stop: tp.Optional[tp.ArrayLike] = None,
stop_entry_price: tp.Optional[tp.ArrayLike] = None,
stop_exit_price: tp.Optional[tp.ArrayLike] = None,
stop_exit_type: tp.Optional[tp.ArrayLike] = None,
stop_order_type: tp.Optional[tp.ArrayLike] = None,
stop_limit_delta: tp.Optional[tp.ArrayLike] = None,
upon_stop_update: tp.Optional[tp.ArrayLike] = None,
upon_adj_stop_conflict: tp.Optional[tp.ArrayLike] = None,
upon_opp_stop_conflict: tp.Optional[tp.ArrayLike] = None,
delta_format: tp.Optional[tp.ArrayLike] = None,
time_delta_format: tp.Optional[tp.ArrayLike] = None,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
init_cash: tp.Optional[tp.ArrayLike] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_earnings: tp.Optional[tp.ArrayLike] = None,
cash_dividends: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = None,
from_ago: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
call_seq: tp.Optional[tp.ArrayLike] = None,
attach_call_seq: tp.Optional[bool] = None,
ffill_val_price: tp.Optional[bool] = None,
update_value: tp.Optional[bool] = None,
fill_pos_info: tp.Optional[bool] = None,
save_state: tp.Optional[bool] = None,
save_value: tp.Optional[bool] = None,
save_returns: tp.Optional[bool] = None,
skip_empty: tp.Optional[bool] = None,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = None,
in_outputs: tp.Optional[tp.MappingLike] = None,
seed: tp.Optional[int] = None,
group_by: tp.GroupByLike = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
staticized: tp.StaticizedOption = None,
bm_close: tp.Optional[tp.ArrayLike] = None,
records: tp.Optional[tp.RecordsLike] = None,
return_preparer: bool = False,
return_prep_result: bool = False,
return_sim_out: bool = False,
**kwargs,
) -> PortfolioResultT:
"""Simulate portfolio from entry and exit signals.
Supports the following modes:
1. `entries` and `exits`:
Uses `vectorbtpro.portfolio.nb.from_signals.dir_signal_func_nb` as `signal_func_nb`
if an adjustment function is provided (not cacheable), otherwise translates signals using
`vectorbtpro.portfolio.nb.from_signals.dir_to_ls_signals_nb` then simulates statically (cacheable)
2. `entries` (acting as long entries), `exits` (acting as long exits), `short_entries`, and `short_exits`:
Uses `vectorbtpro.portfolio.nb.from_signals.ls_signal_func_nb` as `signal_func_nb`
if an adjustment function is provided (not cacheable), otherwise simulates statically (cacheable)
3. `order_mode=True` without signals:
Uses `vectorbtpro.portfolio.nb.from_signals.order_signal_func_nb` as `signal_func_nb` (not cacheable)
4. `signal_func_nb` and `signal_args`: Custom signal function (not cacheable)
Prepared by `vectorbtpro.portfolio.preparing.FSPreparer`.
Args:
close (array_like, OHLCDataMixin, FSPreparer, or PFPrepResult): See `Portfolio.from_orders`.
entries (array_like of bool): Boolean array of entry signals.
Defaults to True if all other signal arrays are not set, otherwise False. Will broadcast.
* If `short_entries` and `short_exits` are not set: Acts as a long signal if `direction`
is 'both' or 'longonly', otherwise short.
* If `short_entries` or `short_exits` are set: Acts as `long_entries`.
exits (array_like of bool): Boolean array of exit signals.
Defaults to False. Will broadcast.
* If `short_entries` and `short_exits` are not set: Acts as a short signal if `direction`
is 'both' or 'longonly', otherwise long.
* If `short_entries` or `short_exits` are set: Acts as `long_exits`.
direction (Direction or array_like): See `Portfolio.from_orders`.
Only takes effect if `short_entries` and `short_exits` are not set.
long_entries (array_like of bool): Boolean array of long entry signals.
Defaults to False. Will broadcast.
long_exits (array_like of bool): Boolean array of long exit signals.
Defaults to False. Will broadcast.
short_entries (array_like of bool): Boolean array of short entry signals.
Defaults to False. Will broadcast.
short_exits (array_like of bool): Boolean array of short exit signals.
Defaults to False. Will broadcast.
adjust_func_nb (path_like or callable): User-defined function to adjust the current simulation state.
Defaults to `vectorbtpro.portfolio.nb.from_signals.no_adjust_func_nb`.
Passed as argument to `vectorbtpro.portfolio.nb.from_signals.dir_signal_func_nb`,
`vectorbtpro.portfolio.nb.from_signals.ls_signal_func_nb`, and
`vectorbtpro.portfolio.nb.from_signals.order_signal_func_nb`. Has no effect
when using other signal functions.
Can be a path to a module when using staticizing.
adjust_args (tuple): Packed arguments passed to `adjust_func_nb`.
signal_func_nb (path_like or callable): Function called to generate signals.
See `vectorbtpro.portfolio.nb.from_signals.from_signal_func_nb`.
Can be a path to a module when using staticizing.
signal_args (tuple): Packed arguments passed to `signal_func_nb`.
post_segment_func_nb (path_like or callable): Post-segment function.
See `vectorbtpro.portfolio.nb.from_signals.from_signal_func_nb`.
Can be a path to a module when using staticizing.
post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`.
order_mode (bool): Whether to simulate as orders without signals.
size (float or array_like): See `Portfolio.from_orders`.
!!! note
Negative size is not allowed. You must express direction using signals.
size_type (SizeType or array_like): See `Portfolio.from_orders`.
Only `SizeType.Amount`, `SizeType.Value`, `SizeType.Percent(100)`, and
`SizeType.ValuePercent(100)` are supported. Other modes such as target percentage
are not compatible with signals since their logic may contradict the direction of the signal.
!!! note
`SizeType.Percent(100)` does not support position reversal. Switch to a single
direction or use `OppositeEntryMode.Close` to close the position first.
See warning in `Portfolio.from_orders`.
price (array_like of float): See `Portfolio.from_orders`.
fees (float or array_like): See `Portfolio.from_orders`.
fixed_fees (float or array_like): See `Portfolio.from_orders`.
slippage (float or array_like): See `Portfolio.from_orders`.
min_size (float or array_like): See `Portfolio.from_orders`.
max_size (float or array_like): See `Portfolio.from_orders`.
Will be partially filled if exceeded. You might not be able to properly close
the position if accumulation is enabled and `max_size` is too low.
size_granularity (float or array_like): See `Portfolio.from_orders`.
leverage (float or array_like): See `Portfolio.from_orders`.
leverage_mode (LeverageMode or array_like): See `Portfolio.from_orders`.
reject_prob (float or array_like): See `Portfolio.from_orders`.
price_area_vio_mode (PriceAreaVioMode or array_like): See `Portfolio.from_orders`.
allow_partial (bool or array_like): See `Portfolio.from_orders`.
raise_reject (bool or array_like): See `Portfolio.from_orders`.
log (bool or array_like): See `Portfolio.from_orders`.
val_price (array_like of float): See `Portfolio.from_orders`.
accumulate (bool, AccumulationMode or array_like): See `vectorbtpro.portfolio.enums.AccumulationMode`.
If True, becomes 'both'. If False, becomes 'disabled'. Will broadcast.
When enabled, `Portfolio.from_signals` behaves similarly to `Portfolio.from_orders`.
upon_long_conflict (ConflictMode or array_like): Conflict mode for long signals.
See `vectorbtpro.portfolio.enums.ConflictMode`. Will broadcast.
upon_short_conflict (ConflictMode or array_like): Conflict mode for short signals.
See `vectorbtpro.portfolio.enums.ConflictMode`. Will broadcast.
upon_dir_conflict (DirectionConflictMode or array_like): See `vectorbtpro.portfolio.enums.DirectionConflictMode`.
Will broadcast.
upon_opposite_entry (OppositeEntryMode or array_like): See `vectorbtpro.portfolio.enums.OppositeEntryMode`.
Will broadcast.
order_type (OrderType or array_like): See `vectorbtpro.portfolio.enums.OrderType`.
Only one active limit order is allowed at a time.
limit_delta (float or array_like): Delta from `price` to build the limit price.
Will broadcast.
If NaN, `price` becomes the limit price. Otherwise, applied on top of `price` depending
on the current direction: if the direction-aware size is positive (= buying), a positive delta
will decrease the limit price; if the direction-aware size is negative (= selling), a positive delta
will increase the limit price. Delta can be negative.
Set an element to `np.nan` to disable. Use `delta_format` to specify the format.
limit_tif (frequency_like or array_like): Time in force for limit signals.
Will broadcast.
Any frequency-like object is converted using `vectorbtpro.utils.datetime_.to_timedelta64`.
Any array must either contain timedeltas or integers, and will be cast into integer format
after broadcasting. If the object provided is of data type `object`, will be converted
to timedelta automatically.
Measured as the distance after the open time of the signal bar. If the expiration time falls
in the middle of the current bar, we pessimistically assume that the order has expired.
The check is performed at the beginning of the bar, and the first check is performed at the
next bar after the signal. For example, if the format is `TimeDeltaFormat.Rows`, 0 or 1 means
the order must execute at the same bar or not at all; 2 means the order must execute at the
same or next bar or not at all.
Set an element to `-1` to disable. Use `time_delta_format` to specify the format.
limit_expiry (frequency_like, datetime_like, or array_like): Expiration time.
Will broadcast.
Any frequency-like object is used to build a period index, such that each timestamp in the original
index is pointing to the timestamp where the period ends. For example, providing "d" will
make any limit order expire on the next day. Any array must either contain timestamps or integers
(not timedeltas!), and will be cast into integer format after broadcasting. If the object
provided is of data type `object`, will be converted to datetime and its timezone will
be removed automatically (as done on the index).
Behaves in a similar way as `limit_tif`.
Set an element to `-1` or `pd.Timestamp.max` to disable. Use `time_delta_format` to specify the format.
limit_reverse (bool or array_like): Whether to reverse the price hit detection.
Will broadcast.
If True, a buy/sell limit price will be checked against high/low (not low/high).
Also, the limit delta will be applied above/below (not below/above) the initial price.
limit_order_price (LimitOrderPrice or array_like): See `vectorbtpro.portfolio.enums.LimitOrderPrice`.
Will broadcast.
If provided on per-element basis, gets applied upon order creation. If a positive value is provided,
used directly as a price, otherwise used as an enumerated value.
upon_adj_limit_conflict (PendingConflictMode or array_like): Conflict mode for limit and user-defined
signals of adjacent sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast.
upon_opp_limit_conflict (PendingConflictMode or array_like): Conflict mode for limit and user-defined
signals of opposite sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast.
use_stops (bool): Whether to use stops.
Defaults to None, which becomes True if any of the stops are not NaN or
the adjustment function is not the default one.
Disable this to make simulation a bit faster for simple use cases.
stop_ladder (bool or StopLadderMode): Whether and which kind of stop laddering to use.
See `vectorbtpro.portfolio.enums.StopLadderMode`.
If so, rows in the supplied arrays will become ladder steps. Make sure that
they are increasing. If one column should have fewer steps, pad it with NaN
for price-based stops and -1 for time-based stops.
Rows in each array can be of an arbitrary length but columns must broadcast against
the number of columns in the data. Applied on all stop types.
sl_stop (array_like of float): Stop loss.
Will broadcast.
Set an element to `np.nan` to disable. Use `delta_format` to specify the format.
tsl_stop (array_like of float): Trailing stop loss.
Will broadcast.
Set an element to `np.nan` to disable. Use `delta_format` to specify the format.
tsl_th (array_like of float): Take profit threshold for the trailing stop loss.
Will broadcast.
Set an element to `np.nan` to disable. Use `delta_format` to specify the format.
tp_stop (array_like of float): Take profit.
Will broadcast.
Set an element to `np.nan` to disable. Use `delta_format` to specify the format.
td_stop (frequency_like or array_like): Timedelta-stop.
Will broadcast.
Set an element to `-1` to disable. Use `time_delta_format` to specify the format.
dt_stop (frequency_like, datetime_like, or array_like): Datetime-stop.
Will broadcast.
Set an element to `-1` to disable. Use `time_delta_format` to specify the format.
stop_entry_price (StopEntryPrice or array_like): See `vectorbtpro.portfolio.enums.StopEntryPrice`.
Will broadcast.
If provided on per-element basis, gets applied upon entry. If a positive value is provided,
used directly as a price, otherwise used as an enumerated value.
stop_exit_price (StopExitPrice or array_like): See `vectorbtpro.portfolio.enums.StopExitPrice`.
Will broadcast.
If provided on per-element basis, gets applied upon entry. If a positive value is provided,
used directly as a price, otherwise used as an enumerated value.
stop_exit_type (StopExitType or array_like): See `vectorbtpro.portfolio.enums.StopExitType`.
Will broadcast.
If provided on per-element basis, gets applied upon entry.
stop_order_type (OrderType or array_like): Similar to `order_type` but for stop orders.
Will broadcast.
If provided on per-element basis, gets applied upon entry.
stop_limit_delta (float or array_like): Similar to `limit_delta` but for stop orders.
Will broadcast.
upon_stop_update (StopUpdateMode or array_like): See `vectorbtpro.portfolio.enums.StopUpdateMode`.
Will broadcast.
Only has effect if accumulation is enabled.
If provided on per-element basis, gets applied upon repeated entry.
upon_adj_stop_conflict (PendingConflictMode or array_like): Conflict mode for stop and user-defined
signals of adjacent sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast.
upon_opp_stop_conflict (PendingConflictMode or array_like): Conflict mode for stop and user-defined
signals of opposite sign. See `vectorbtpro.portfolio.enums.PendingConflictMode`. Will broadcast.
delta_format (DeltaFormat or array_like): See `vectorbtpro.portfolio.enums.DeltaFormat`.
Will broadcast.
time_delta_format (TimeDeltaFormat or array_like): See `vectorbtpro.portfolio.enums.TimeDeltaFormat`.
Will broadcast.
open (array_like of float): See `Portfolio.from_orders`.
For stop signals, `np.nan` gets replaced by `close`.
high (array_like of float): See `Portfolio.from_orders`.
For stop signals, `np.nan` gets replaced by the maximum of `open` and `close`.
low (array_like of float): See `Portfolio.from_orders`.
For stop signals, `np.nan` gets replaced by the minimum of `open` and `close`.
init_cash (InitCashMode, float or array_like): See `Portfolio.from_orders`.
init_position (float or array_like): See `Portfolio.from_orders`.
init_price (float or array_like): See `Portfolio.from_orders`.
cash_deposits (float or array_like): See `Portfolio.from_orders`.
cash_earnings (float or array_like): See `Portfolio.from_orders`.
cash_dividends (float or array_like): See `Portfolio.from_orders`.
cash_sharing (bool): See `Portfolio.from_orders`.
from_ago (int or array_like): See `Portfolio.from_orders`.
Takes effect only for user-defined signals, not for stop signals.
sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
Can be "auto", which will be substituted by the index of the first signal across
long and short entries and long and short exits.
sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
Can be "auto", which will be substituted by the index of the last signal across
long and short entries and long and short exits.
call_seq (CallSeqType or array_like): See `Portfolio.from_orders`.
attach_call_seq (bool): See `Portfolio.from_orders`.
ffill_val_price (bool): See `Portfolio.from_orders`.
update_value (bool): See `Portfolio.from_orders`.
fill_pos_info (bool): Whether to fill the position record.
Disable this to make simulation faster for simple use cases.
save_state (bool): See `Portfolio.from_orders`.
save_value (bool): See `Portfolio.from_orders`.
save_returns (bool): See `Portfolio.from_orders`.
skip_empty (bool): See `Portfolio.from_orders`.
max_order_records (int): See `Portfolio.from_orders`.
max_log_records (int): See `Portfolio.from_orders`.
in_outputs (mapping_like): Mapping with in-output objects. Only for flexible mode.
Will be available via `Portfolio.in_outputs` as a named tuple.
To substitute `Portfolio` attributes, provide already broadcasted and grouped objects,
for example, by using `broadcast_named_args` and templates. Also see
`Portfolio.in_outputs_indexing_func` on how in-output objects are indexed.
When chunking, make sure to provide the chunk taking specification and the merging function.
See `vectorbtpro.portfolio.chunking.merge_sim_outs`.
!!! note
When using Numba below 0.54, `in_outputs` cannot be a mapping, but must be a named tuple
defined globally so Numba can introspect its attributes for pickling.
seed (int): See `Portfolio.from_orders`.
group_by (any): See `Portfolio.from_orders`.
broadcast_named_args (dict): Dictionary with named arguments to broadcast.
You can then pass argument names wrapped with `vectorbtpro.utils.template.Rep`
and this method will substitute them by their corresponding broadcasted objects.
broadcast_kwargs (dict): See `Portfolio.from_orders`.
template_context (mapping): Context used to substitute templates in arguments.
jitted (any): See `Portfolio.from_orders`.
chunked (any): See `Portfolio.from_orders`.
staticized (bool, dict, hashable, or callable): Keyword arguments or task id for staticizing.
If True or dictionary, will be passed as keyword arguments to
`vectorbtpro.utils.cutting.cut_and_save_func` to save a cacheable version of the
simulator to a file. If a hashable or callable, will be used as a task id of an
already registered jittable and chunkable simulator. Dictionary allows additional options
`override` and `reload` to override and reload an already existing module respectively.
bm_close (array_like): See `Portfolio.from_orders`.
records (array_like): See `Portfolio.from_orders`.
return_preparer (bool): See `Portfolio.from_orders`.
return_prep_result (bool): See `Portfolio.from_orders`.
return_sim_out (bool): See `Portfolio.from_orders`.
**kwargs: Keyword arguments passed to the `Portfolio` constructor.
All broadcastable arguments will broadcast using `vectorbtpro.base.reshaping.broadcast`
but keep original shape to utilize flexible indexing and to save memory.
For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill
NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN
for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO`
with `fill_value` to override.
Also see notes and hints for `Portfolio.from_orders`.
Usage:
* By default, if all signal arrays are None, `entries` becomes True,
which opens a position at the very first tick and does nothing else:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_signals(close, size=1)
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 0.0
4 0.0
dtype: float64
```
* Entry opens long, exit closes long:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1,
... direction='longonly'
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 -1.0
4 0.0
dtype: float64
>>> # Using direction-aware arrays instead of `direction`
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]), # long_entries
... exits=pd.Series([False, False, True, True, True]), # long_exits
... short_entries=False,
... short_exits=False,
... size=1
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 -1.0
4 0.0
dtype: float64
```
Notice how both `short_entries` and `short_exits` are provided as constants - as any other
broadcastable argument, they are treated as arrays where each element is False.
* Entry opens short, exit closes short:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1,
... direction='shortonly'
... )
>>> pf.asset_flow
0 -1.0
1 0.0
2 0.0
3 1.0
4 0.0
dtype: float64
>>> # Using direction-aware arrays instead of `direction`
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=False, # long_entries
... exits=False, # long_exits
... short_entries=pd.Series([True, True, True, False, False]),
... short_exits=pd.Series([False, False, True, True, True]),
... size=1
... )
>>> pf.asset_flow
0 -1.0
1 0.0
2 0.0
3 1.0
4 0.0
dtype: float64
```
* Entry opens long and closes short, exit closes long and opens short:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1,
... direction='both'
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 -2.0
4 0.0
dtype: float64
>>> # Using direction-aware arrays instead of `direction`
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]), # long_entries
... exits=False, # long_exits
... short_entries=pd.Series([False, False, True, True, True]),
... short_exits=False,
... size=1
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 -2.0
4 0.0
dtype: float64
```
* More complex signal combinations are best expressed using direction-aware arrays.
For example, ignore opposite signals as long as the current position is open:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries =pd.Series([True, False, False, False, False]), # long_entries
... exits =pd.Series([False, False, True, False, False]), # long_exits
... short_entries=pd.Series([False, True, False, True, False]),
... short_exits =pd.Series([False, False, False, False, True]),
... size=1,
... upon_opposite_entry='ignore'
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 -1.0
3 -1.0
4 1.0
dtype: float64
```
* First opposite signal closes the position, second one opens a new position:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1,
... direction='both',
... upon_opposite_entry='close'
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 0.0
3 -1.0
4 -1.0
dtype: float64
```
* If both long entry and exit signals are True (a signal conflict), choose exit:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1.,
... direction='longonly',
... upon_long_conflict='exit')
>>> pf.asset_flow
0 1.0
1 0.0
2 -1.0
3 0.0
4 0.0
dtype: float64
```
* If both long entry and short entry signal are True (a direction conflict), choose short:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1.,
... direction='both',
... upon_dir_conflict='short')
>>> pf.asset_flow
0 1.0
1 0.0
2 -2.0
3 0.0
4 0.0
dtype: float64
```
!!! note
Remember that when direction is set to 'both', entries become `long_entries` and exits become
`short_entries`, so this becomes a conflict of directions rather than signals.
* If there are both signal and direction conflicts:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=True, # long_entries
... exits=True, # long_exits
... short_entries=True,
... short_exits=True,
... size=1,
... upon_long_conflict='entry',
... upon_short_conflict='entry',
... upon_dir_conflict='short'
... )
>>> pf.asset_flow
0 -1.0
1 0.0
2 0.0
3 0.0
4 0.0
dtype: float64
```
* Turn on accumulation of signals. Entry means long order, exit means short order
(acts similar to `from_orders`):
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1.,
... direction='both',
... accumulate=True)
>>> pf.asset_flow
0 1.0
1 1.0
2 0.0
3 -1.0
4 -1.0
dtype: float64
```
* Allow increasing a position (of any direction), deny decreasing a position:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... size=1.,
... direction='both',
... accumulate='addonly')
>>> pf.asset_flow
0 1.0 << open a long position
1 1.0 << add to the position
2 0.0
3 -3.0 << close and open a short position
4 -1.0 << add to the position
dtype: float64
```
* Test multiple parameters via regular broadcasting:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... direction=[list(Direction)],
... broadcast_kwargs=dict(columns_from=pd.Index(vbt.pf_enums.Direction._fields, name='direction')))
>>> pf.asset_flow
direction LongOnly ShortOnly Both
0 100.0 -100.0 100.0
1 0.0 0.0 0.0
2 0.0 0.0 0.0
3 -100.0 50.0 -200.0
4 0.0 0.0 0.0
```
* Test multiple parameters via `vectorbtpro.base.reshaping.BCO`:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close,
... entries=pd.Series([True, True, True, False, False]),
... exits=pd.Series([False, False, True, True, True]),
... direction=vbt.Param(Direction))
>>> pf.asset_flow
direction LongOnly ShortOnly Both
0 100.0 -100.0 100.0
1 0.0 0.0 0.0
2 0.0 0.0 0.0
3 -100.0 50.0 -200.0
4 0.0 0.0 0.0
```
* Set risk/reward ratio by passing trailing stop loss and take profit thresholds:
```pycon
>>> close = pd.Series([10, 11, 12, 11, 10, 9])
>>> entries = pd.Series([True, False, False, False, False, False])
>>> exits = pd.Series([False, False, False, False, False, True])
>>> pf = vbt.Portfolio.from_signals(
... close, entries, exits,
... tsl_stop=0.1, tp_stop=0.2) # take profit hit
>>> pf.asset_flow
0 10.0
1 0.0
2 -10.0
3 0.0
4 0.0
5 0.0
dtype: float64
>>> pf = vbt.Portfolio.from_signals(
... close, entries, exits,
... tsl_stop=0.1, tp_stop=0.3) # trailing stop loss hit
>>> pf.asset_flow
0 10.0
1 0.0
2 0.0
3 0.0
4 -10.0
5 0.0
dtype: float64
>>> pf = vbt.Portfolio.from_signals(
... close, entries, exits,
... tsl_stop=np.inf, tp_stop=np.inf) # nothing hit, exit as usual
>>> pf.asset_flow
0 10.0
1 0.0
2 0.0
3 0.0
4 0.0
5 -10.0
dtype: float64
```
* Test different stop combinations:
```pycon
>>> pf = vbt.Portfolio.from_signals(
... close, entries, exits,
... tsl_stop=vbt.Param([0.1, 0.2]),
... tp_stop=vbt.Param([0.2, 0.3])
... )
>>> pf.asset_flow
tsl_stop 0.1 0.2
tp_stop 0.2 0.3 0.2 0.3
0 10.0 10.0 10.0 10.0
1 0.0 0.0 0.0 0.0
2 -10.0 0.0 -10.0 0.0
3 0.0 0.0 0.0 0.0
4 0.0 -10.0 0.0 0.0
5 0.0 0.0 0.0 -10.0
```
This works because `pd.Index` automatically translates into `vectorbtpro.base.reshaping.BCO`
with `product` set to True.
* We can implement our own stop loss or take profit, or adjust the existing one at each time step.
Let's implement [stepped stop-loss](https://www.freqtrade.io/en/stable/strategy-advanced/#stepped-stoploss):
```pycon
>>> @njit
... def adjust_func_nb(c):
... val_price_now = c.last_val_price[c.col]
... tsl_init_price = c.last_tsl_info["init_price"][c.col]
... current_profit = (val_price_now - tsl_init_price) / tsl_init_price
... if current_profit >= 0.40:
... c.last_tsl_info["stop"][c.col] = 0.25
... elif current_profit >= 0.25:
... c.last_tsl_info["stop"][c.col] = 0.15
... elif current_profit >= 0.20:
... c.last_tsl_info["stop"][c.col] = 0.07
>>> close = pd.Series([10, 11, 12, 11, 10])
>>> pf = vbt.Portfolio.from_signals(close, adjust_func_nb=adjust_func_nb)
>>> pf.asset_flow
0 10.0
1 0.0
2 0.0
3 -10.0 # 7% from 12 hit
4 11.16 # position re-opened
dtype: float64
```
* Sometimes there is a need to provide or transform signals dynamically. For this, we can implement
a custom signal function `signal_func_nb`. For example, let's implement a signal function that
takes two numerical arrays - a long one and a short one - and transforms them into four direction-aware
boolean arrays that vectorbt understands:
```pycon
>>> @njit
... def signal_func_nb(c, long_num_arr, short_num_arr):
... long_num = vbt.pf_nb.select_nb(c, long_num_arr)
... short_num = vbt.pf_nb.select_nb(c, short_num_arr)
... is_long_entry = long_num > 0
... is_long_exit = long_num < 0
... is_short_entry = short_num > 0
... is_short_exit = short_num < 0
... return is_long_entry, is_long_exit, is_short_entry, is_short_exit
>>> pf = vbt.Portfolio.from_signals(
... pd.Series([1, 2, 3, 4, 5]),
... signal_func_nb=signal_func_nb,
... signal_args=(vbt.Rep('long_num_arr'), vbt.Rep('short_num_arr')),
... broadcast_named_args=dict(
... long_num_arr=pd.Series([1, 0, -1, 0, 0]),
... short_num_arr=pd.Series([0, 1, 0, 1, -1])
... ),
... size=1,
... upon_opposite_entry='ignore'
... )
>>> pf.asset_flow
0 1.0
1 0.0
2 -1.0
3 -1.0
4 1.0
dtype: float64
```
Passing both arrays via `broadcast_named_args` broadcasts them internally like any other array,
so we don't have to worry about their dimensions every time we change our data.
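For instance, the same call accepts a two-column long array without any change to the signal
function (a minimal sketch; the column labels are made up for illustration and the output is omitted):
```pycon
>>> pf = vbt.Portfolio.from_signals(
...     pd.Series([1, 2, 3, 4, 5]),
...     signal_func_nb=signal_func_nb,
...     signal_args=(vbt.Rep('long_num_arr'), vbt.Rep('short_num_arr')),
...     broadcast_named_args=dict(
...         long_num_arr=pd.DataFrame({'a': [1, 0, -1, 0, 0], 'b': [0, 1, 0, -1, 0]}),
...         short_num_arr=pd.Series([0, 1, 0, 1, -1])
...     ),
...     size=1,
...     upon_opposite_entry='ignore'
... )
```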
"""
if isinstance(close, FSPreparer):
preparer = close
prep_result = None
elif isinstance(close, PFPrepResult):
preparer = None
prep_result = close
else:
local_kwargs = locals()
local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
del local_kwargs["kwargs"]
del local_kwargs["cls"]
del local_kwargs["return_preparer"]
del local_kwargs["return_prep_result"]
del local_kwargs["return_sim_out"]
parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
if parsed_data is not None:
local_kwargs["data"] = parsed_data
local_kwargs["close"] = None
preparer = FSPreparer(**local_kwargs)
if not return_preparer:
preparer.set_seed()
prep_result = None
if return_preparer:
return preparer
if prep_result is None:
prep_result = preparer.result
if return_prep_result:
return prep_result
sim_out = prep_result.target_func(**prep_result.target_args)
if return_sim_out:
return sim_out
return cls(order_records=sim_out, **prep_result.pf_args)
@classmethod
def from_holding(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin],
direction: tp.Optional[int] = None,
at_first_valid_in: tp.Optional[str] = "close",
close_at_end: tp.Optional[bool] = None,
dynamic_mode: bool = False,
**kwargs,
) -> PortfolioResultT:
"""Simulate portfolio from plain holding using signals.
If `close_at_end` is True, an opposite signal will be placed at the very last bar.
`**kwargs` are passed to the class method `Portfolio.from_signals`.
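Usage:
* Hold the asset from the first valid bar (a minimal sketch; the exact orders depend on defaults
such as `init_cash`, `size`, and `close_at_end`, so the output is omitted):
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_holding(close)
>>> pf = vbt.Portfolio.from_holding(close, close_at_end=True)  # also place an exit at the last bar
```
"""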
from vectorbtpro._settings import settings
portfolio_cfg = settings["portfolio"]
if direction is None:
direction = portfolio_cfg["hold_direction"]
direction = map_enum_fields(direction, enums.Direction)
if not checks.is_int(direction):
raise TypeError("Direction must be a scalar")
if close_at_end is None:
close_at_end = portfolio_cfg["close_at_end"]
if dynamic_mode:
def _substitute_signal_args(preparer):
return (
direction,
close_at_end,
*((preparer.adjust_func_nb,) if preparer.staticized is None else ()),
preparer.adjust_args,
)
return cls.from_signals(
close,
signal_func_nb=nb.holding_enex_signal_func_nb,
signal_args=RepFunc(_substitute_signal_args),
accumulate=False,
**kwargs,
)
def _entries(wrapper, new_objs):
if at_first_valid_in is None:
entries = np.full((wrapper.shape_2d[0], 1), False)
entries[0] = True
return entries
ts = new_objs[at_first_valid_in]
valid_index = generic_nb.first_valid_index_nb(ts)
if (valid_index == -1).all():
return np.array([[False]])
if (valid_index == 0).all():
entries = np.full((wrapper.shape_2d[0], 1), False)
entries[0] = True
return entries
entries = np.full(wrapper.shape_2d, False)
entries[valid_index, np.arange(wrapper.shape_2d[1])] = True
return entries
def _exits(wrapper):
if close_at_end:
exits = np.full((wrapper.shape_2d[0], 1), False)
exits[-1] = True
else:
exits = np.array([[False]])
return exits
return cls.from_signals(
close,
entries=RepFunc(_entries),
exits=RepFunc(_exits),
direction=direction,
accumulate=False,
**kwargs,
)
@classmethod
def from_random_signals(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin],
n: tp.Optional[tp.ArrayLike] = None,
prob: tp.Optional[tp.ArrayLike] = None,
entry_prob: tp.Optional[tp.ArrayLike] = None,
exit_prob: tp.Optional[tp.ArrayLike] = None,
param_product: bool = False,
seed: tp.Optional[int] = None,
run_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PortfolioResultT:
"""Simulate portfolio from random entry and exit signals.
Generates signals based either on the number of signals `n` or the probability
of encountering a signal `prob`.
* If `n` is set, see `vectorbtpro.signals.generators.randnx.RANDNX`.
* If `prob` is set, see `vectorbtpro.signals.generators.rprobnx.RPROBNX`.
Based on `Portfolio.from_signals`.
!!! note
To generate random signals, the shape of `close` is used. Broadcasting with other
arrays happens after the generation.
Usage:
* Test multiple combinations of random entries and exits:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_random_signals(close, n=[2, 1, 0], seed=42)
>>> pf.orders.count()
randnx_n
2 4
1 2
0 0
Name: count, dtype: int64
```
* Test the Cartesian product of entry and exit encounter probabilities:
```pycon
>>> pf = vbt.Portfolio.from_random_signals(
... close,
... entry_prob=[0, 0.5, 1],
... exit_prob=[0, 0.5, 1],
... param_product=True,
... seed=42)
>>> pf.orders.count()
rprobnx_entry_prob rprobnx_exit_prob
0.0 0.0 0
0.5 0
1.0 0
0.5 0.0 1
0.5 4
1.0 3
1.0 0.0 1
0.5 4
1.0 5
Name: count, dtype: int64
```
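* The `prob` argument is a shortcut that sets both `entry_prob` and `exit_prob` at once
(a minimal sketch; the output is omitted):
```pycon
>>> pf = vbt.Portfolio.from_random_signals(close, prob=0.5, seed=42)
```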
"""
from vectorbtpro._settings import settings
portfolio_cfg = settings["portfolio"]
parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
if parsed_data is not None:
data = parsed_data
close = data.close
if close is None:
raise ValueError("Column for close couldn't be found in data")
close_wrapper = data.symbol_wrapper
else:
close = to_pd_array(close)
close_wrapper = ArrayWrapper.from_obj(close)
data = close
if entry_prob is None:
entry_prob = prob
if exit_prob is None:
exit_prob = prob
if seed is None:
seed = portfolio_cfg["seed"]
if run_kwargs is None:
run_kwargs = {}
if n is not None and (entry_prob is not None or exit_prob is not None):
raise ValueError("Must provide either n or entry_prob and exit_prob")
if n is not None:
from vectorbtpro.signals.generators.randnx import RANDNX
rand = RANDNX.run(
n=n,
input_shape=close_wrapper.shape,
input_index=close_wrapper.index,
input_columns=close_wrapper.columns,
seed=seed,
**run_kwargs,
)
entries = rand.entries
exits = rand.exits
elif entry_prob is not None and exit_prob is not None:
from vectorbtpro.signals.generators.rprobnx import RPROBNX
rprobnx = RPROBNX.run(
entry_prob=entry_prob,
exit_prob=exit_prob,
param_product=param_product,
input_shape=close_wrapper.shape,
input_index=close_wrapper.index,
input_columns=close_wrapper.columns,
seed=seed,
**run_kwargs,
)
entries = rprobnx.entries
exits = rprobnx.exits
else:
raise ValueError("Must provide at least n or entry_prob and exit_prob")
return cls.from_signals(data, entries, exits, seed=seed, **kwargs)
@classmethod
def from_optimizer(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin],
optimizer: PortfolioOptimizer,
pf_method: str = "from_orders",
squeeze_groups: bool = True,
dropna: tp.Optional[str] = None,
fill_value: tp.Scalar = np.nan,
size_type: tp.ArrayLike = "targetpercent",
direction: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = True,
call_seq: tp.Optional[tp.ArrayLike] = "auto",
group_by: tp.GroupByLike = None,
**kwargs,
) -> PortfolioResultT:
"""Build portfolio from an optimizer of type `vectorbtpro.portfolio.pfopt.base.PortfolioOptimizer`.
Uses `Portfolio.from_orders` as the base simulation method by default; set `pf_method`
to 'from_signals' to route the allocations through `Portfolio.from_signals` in order mode.
The size type is 'targetpercent'. If there are both positive and negative values, the direction
is automatically set to 'both', otherwise to 'longonly' for positive-only and 'shortonly'
for negative-only values. Also, cash sharing is set to True, the call sequence is set
to 'auto', and the grouper is set to the grouper of the optimizer by default.
Usage:
```pycon
>>> close = pd.DataFrame({
... "MSFT": [1, 2, 3, 4, 5],
... "GOOG": [5, 4, 3, 2, 1],
... "AAPL": [1, 2, 3, 2, 1]
... }, index=pd.date_range(start="2020-01-01", periods=5))
>>> pfo = vbt.PortfolioOptimizer.from_random(
... close.vbt.wrapper,
... every="2D",
... seed=42
... )
>>> pfo.fill_allocations()
MSFT GOOG AAPL
2020-01-01 0.182059 0.462129 0.355812
2020-01-02 NaN NaN NaN
2020-01-03 0.657381 0.171323 0.171296
2020-01-04 NaN NaN NaN
2020-01-05 0.038078 0.567845 0.394077
>>> pf = vbt.Portfolio.from_optimizer(close, pfo)
>>> pf.get_asset_value(group_by=False).vbt / pf.value
alloc_group group
MSFT GOOG AAPL
2020-01-01 0.182059 0.462129 0.355812 << rebalanced
2020-01-02 0.251907 0.255771 0.492322
2020-01-03 0.657381 0.171323 0.171296 << rebalanced
2020-01-04 0.793277 0.103369 0.103353
2020-01-05 0.038078 0.567845 0.394077 << rebalanced
```
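The same optimizer can also be routed through signal-based simulation by setting `pf_method`
to 'from_signals', which runs `Portfolio.from_signals` in order mode with accumulation enabled
(a minimal sketch; the output is omitted):
```pycon
>>> pf = vbt.Portfolio.from_optimizer(close, pfo, pf_method="from_signals")
```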
"""
size = optimizer.fill_allocations(squeeze_groups=squeeze_groups, dropna=dropna, fill_value=fill_value)
if direction is None:
pos_size_any = (size.values > 0).any()
neg_size_any = (size.values < 0).any()
if pos_size_any and neg_size_any:
direction = "both"
elif pos_size_any:
direction = "longonly"
else:
direction = "shortonly"
size = size.abs()
if group_by is None:
def _substitute_group_by(index):
columns = optimizer.wrapper.columns
if squeeze_groups and optimizer.wrapper.grouped_ndim == 1:
columns = columns.droplevel(level=0)
if not index.equals(columns):
if "symbol" in index.names:
return ExceptLevel("symbol")
raise ValueError("Column hierarchy has changed. Disable squeeze_groups and provide group_by.")
return optimizer.wrapper.grouper.group_by
group_by = RepFunc(_substitute_group_by)
if pf_method.lower() == "from_orders":
return cls.from_orders(
close,
size=size,
size_type=size_type,
direction=direction,
cash_sharing=cash_sharing,
call_seq=call_seq,
group_by=group_by,
**kwargs,
)
elif pf_method.lower() == "from_signals":
return cls.from_signals(
close,
order_mode=True,
size=size,
size_type=size_type,
direction=direction,
accumulate=True,
cash_sharing=cash_sharing,
call_seq=call_seq,
group_by=group_by,
**kwargs,
)
else:
raise ValueError(f"Invalid pf_method: '{pf_method}'")
@classmethod
def from_order_func(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin, FOFPreparer, PFPrepResult],
*,
init_cash: tp.Optional[tp.ArrayLike] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_earnings: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
call_seq: tp.Optional[tp.ArrayLike] = None,
attach_call_seq: tp.Optional[bool] = None,
segment_mask: tp.Optional[tp.ArrayLike] = None,
call_pre_segment: tp.Optional[bool] = None,
call_post_segment: tp.Optional[bool] = None,
pre_sim_func_nb: tp.Optional[nb.PreSimFuncT] = None,
pre_sim_args: tp.Args = (),
post_sim_func_nb: tp.Optional[nb.PostSimFuncT] = None,
post_sim_args: tp.Args = (),
pre_group_func_nb: tp.Optional[nb.PreGroupFuncT] = None,
pre_group_args: tp.Args = (),
post_group_func_nb: tp.Optional[nb.PostGroupFuncT] = None,
post_group_args: tp.Args = (),
pre_row_func_nb: tp.Optional[nb.PreRowFuncT] = None,
pre_row_args: tp.Args = (),
post_row_func_nb: tp.Optional[nb.PostRowFuncT] = None,
post_row_args: tp.Args = (),
pre_segment_func_nb: tp.Optional[nb.PreSegmentFuncT] = None,
pre_segment_args: tp.Args = (),
post_segment_func_nb: tp.Optional[nb.PostSegmentFuncT] = None,
post_segment_args: tp.Args = (),
order_func_nb: tp.Optional[nb.OrderFuncT] = None,
order_args: tp.Args = (),
flex_order_func_nb: tp.Optional[nb.FlexOrderFuncT] = None,
flex_order_args: tp.Args = (),
post_order_func_nb: tp.Optional[nb.PostOrderFuncT] = None,
post_order_args: tp.Args = (),
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
ffill_val_price: tp.Optional[bool] = None,
update_value: tp.Optional[bool] = None,
fill_pos_info: tp.Optional[bool] = None,
track_value: tp.Optional[bool] = None,
row_wise: tp.Optional[bool] = None,
max_order_records: tp.Optional[int] = None,
max_log_records: tp.Optional[int] = None,
in_outputs: tp.Optional[tp.MappingLike] = None,
seed: tp.Optional[int] = None,
group_by: tp.GroupByLike = None,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
keep_inout_flex: tp.Optional[bool] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
staticized: tp.StaticizedOption = None,
bm_close: tp.Optional[tp.ArrayLike] = None,
records: tp.Optional[tp.RecordsLike] = None,
return_preparer: bool = False,
return_prep_result: bool = False,
return_sim_out: bool = False,
**kwargs,
) -> PortfolioResultT:
"""Build portfolio from a custom order function.
!!! hint
See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb` for illustrations and argument definitions.
For more details on individual simulation functions:
* `order_func_nb`: See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb`
* `order_func_nb` and `row_wise`: See `vectorbtpro.portfolio.nb.from_order_func.from_order_func_rw_nb`
* `flex_order_func_nb`: See `vectorbtpro.portfolio.nb.from_order_func.from_flex_order_func_nb`
* `flex_order_func_nb` and `row_wise`: See `vectorbtpro.portfolio.nb.from_order_func.from_flex_order_func_rw_nb`
Prepared by `vectorbtpro.portfolio.preparing.FOFPreparer`.
Args:
close (array_like, OHLCDataMixin, FOFPreparer, or PFPrepResult): Latest asset price at each time step.
Will broadcast.
If an instance of `vectorbtpro.data.base.OHLCDataMixin`, will extract the open, high,
low, and close price.
Used for calculating unrealized PnL and portfolio value.
init_cash (InitCashMode, float or array_like): See `Portfolio.from_orders`.
init_position (float or array_like): See `Portfolio.from_orders`.
init_price (float or array_like): See `Portfolio.from_orders`.
cash_deposits (float or array_like): See `Portfolio.from_orders`.
cash_earnings (float or array_like): See `Portfolio.from_orders`.
cash_sharing (bool): Whether to share cash within the same group.
If `group_by` is None, `group_by` becomes True to form a single group with cash sharing.
sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
call_seq (CallSeqType or array_like): Default sequence of calls per row and group.
* Use `vectorbtpro.portfolio.enums.CallSeqType` to select a sequence type.
* Set to array to specify custom sequence. Will not broadcast.
!!! note
CallSeqType.Auto must be implemented manually.
Use `vectorbtpro.portfolio.nb.from_order_func.sort_call_seq_1d_nb`
or `vectorbtpro.portfolio.nb.from_order_func.sort_call_seq_out_1d_nb` in `pre_segment_func_nb`.
attach_call_seq (bool): See `Portfolio.from_orders`.
segment_mask (int or array_like of bool): Mask of whether a particular segment should be executed.
Supplying an integer will activate every n-th row.
Supplying a boolean or an array of booleans will broadcast to the number of rows and groups.
Does not broadcast together with `close` and `broadcast_named_args`, only against the final shape.
call_pre_segment (bool): Whether to call `pre_segment_func_nb` regardless of `segment_mask`.
call_post_segment (bool): Whether to call `post_segment_func_nb` regardless of `segment_mask`.
pre_sim_func_nb (callable): Function called before simulation.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`.
pre_sim_args (tuple): Packed arguments passed to `pre_sim_func_nb`.
post_sim_func_nb (callable): Function called after simulation.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`.
post_sim_args (tuple): Packed arguments passed to `post_sim_func_nb`.
pre_group_func_nb (callable): Function called before each group.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`.
Called only if `row_wise` is False.
pre_group_args (tuple): Packed arguments passed to `pre_group_func_nb`.
post_group_func_nb (callable): Function called after each group.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`.
Called only if `row_wise` is False.
post_group_args (tuple): Packed arguments passed to `post_group_func_nb`.
pre_row_func_nb (callable): Function called before each row.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`.
Called only if `row_wise` is True.
pre_row_args (tuple): Packed arguments passed to `pre_row_func_nb`.
post_row_func_nb (callable): Function called after each row.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`.
Called only if `row_wise` is True.
post_row_args (tuple): Packed arguments passed to `post_row_func_nb`.
pre_segment_func_nb (callable): Function called before each segment.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_pre_func_nb`.
pre_segment_args (tuple): Packed arguments passed to `pre_segment_func_nb`.
post_segment_func_nb (callable): Function called after each segment.
Defaults to `vectorbtpro.portfolio.nb.from_order_func.no_post_func_nb`.
post_segment_args (tuple): Packed arguments passed to `post_segment_func_nb`.
order_func_nb (callable): Order generation function.
order_args (tuple): Packed arguments passed to `order_func_nb`.
flex_order_func_nb (callable): Flexible order generation function.
flex_order_args (tuple): Packed arguments passed to `flex_order_func_nb`.
post_order_func_nb (callable): Callback that is called after the order has been processed.
post_order_args (tuple): Packed arguments passed to `post_order_func_nb`.
open (array_like of float): See `Portfolio.from_orders`.
high (array_like of float): See `Portfolio.from_orders`.
low (array_like of float): See `Portfolio.from_orders`.
ffill_val_price (bool): Whether to track valuation price only if it's known.
Otherwise, unknown `close` will lead to NaN in valuation price at the next timestamp.
update_value (bool): Whether to update group value after each filled order.
fill_pos_info (bool): Whether to fill position record.
Disable this to make simulation faster for simple use cases.
track_value (bool): Whether to track value metrics such as
the current valuation price, value, and return.
Disable this to make simulation faster for simple use cases.
row_wise (bool): Whether to iterate over rows rather than columns/groups.
max_order_records (int): The max number of order records expected to be filled at each column.
Defaults to the number of rows in the broadcasted shape.
Set to a lower number if you run out of memory, to 0 to not fill, and to a higher number
if more than one order is expected at each timestamp.
max_log_records (int): The max number of log records expected to be filled at each column.
Defaults to the number of rows in the broadcasted shape.
Set to a lower number if you run out of memory, to 0 to not fill, and to a higher number
if more than one log record is expected at each timestamp.
in_outputs (mapping_like): Mapping with in-output objects.
Will be available via `Portfolio.in_outputs` as a named tuple.
To substitute `Portfolio` attributes, provide already broadcasted and grouped objects,
for example, by using `broadcast_named_args` and templates. Also see
`Portfolio.in_outputs_indexing_func` on how in-output objects are indexed.
When chunking, make sure to provide the chunk taking specification and the merging function.
See `vectorbtpro.portfolio.chunking.merge_sim_outs`.
!!! note
When using Numba below 0.54, `in_outputs` cannot be a mapping, but must be a named tuple
defined globally so Numba can introspect its attributes for pickling.
seed (int): See `Portfolio.from_orders`.
group_by (any): See `Portfolio.from_orders`.
broadcast_named_args (dict): See `Portfolio.from_signals`.
broadcast_kwargs (dict): See `Portfolio.from_orders`.
template_context (mapping): See `Portfolio.from_signals`.
keep_inout_flex (bool): Whether to keep arrays that can be edited in-place in their raw form when broadcasting.
Disable this to be able to edit `segment_mask`, `cash_deposits`, and
`cash_earnings` during the simulation.
jitted (any): See `Portfolio.from_orders`.
!!! note
Disabling jitting will not disable the jitter (such as Numba) on other functions,
only on the main (simulation) function. If necessary, you should ensure that every other
function is not compiled as well. For example, when working with Numba, you can do this
by using the `py_func` attribute of that function. Or, you can disable Numba
entirely by running `os.environ['NUMBA_DISABLE_JIT'] = '1'` before importing vectorbtpro.
!!! warning
Parallelization assumes that groups are independent and there is no data flowing between them.
chunked (any): See `vectorbtpro.utils.chunking.resolve_chunked_option`.
staticized (bool, dict, hashable, or callable): Keyword arguments or task id for staticizing.
If True or dictionary, will be passed as keyword arguments to
`vectorbtpro.utils.cutting.cut_and_save_func` to save a cacheable version of the
simulator to a file. If a hashable or callable, will be used as a task id of an
already registered jittable and chunkable simulator. Dictionary allows additional options
`override` and `reload` to override and reload an already existing module respectively.
bm_close (array_like): See `Portfolio.from_orders`.
records (array_like): See `Portfolio.from_orders`.
return_preparer (bool): See `Portfolio.from_orders`.
return_prep_result (bool): See `Portfolio.from_orders`.
return_sim_out (bool): See `Portfolio.from_orders`.
**kwargs: Keyword arguments passed to the `Portfolio` constructor.
For defaults, see `vectorbtpro._settings.portfolio`. Those defaults are not used to fill
NaN values after reindexing: vectorbt uses its own sensible defaults, which are usually NaN
for floating arrays and default flags for integer arrays. Use `vectorbtpro.base.reshaping.BCO`
with `fill_value` to override.
!!! note
All passed functions must be Numba-compiled if Numba is enabled.
Also see notes on `Portfolio.from_orders`.
!!! note
In contrast to other methods, the valuation price is the previous `close` instead of the order price,
since the price of an order is unknown before the call (which is also more realistic).
You can still override the valuation price in `pre_segment_func_nb`.
Usage:
* Buy 10 units each tick using closing price:
```pycon
>>> @njit
... def order_func_nb(c, size):
... return vbt.pf_nb.order_nb(size=size)
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_order_func(
... close,
... order_func_nb=order_func_nb,
... order_args=(10,)
... )
>>> pf.assets
0 10.0
1 20.0
2 30.0
3 40.0
4 40.0
dtype: float64
>>> pf.cash
0 90.0
1 70.0
2 40.0
3 0.0
4 0.0
dtype: float64
```
* Reverse each position by first closing it. Keep state of last position to determine
which position to open next (just as an example, there are easier ways to do this):
```pycon
>>> @njit
... def pre_group_func_nb(c):
... last_pos_state = np.array([-1])
... return (last_pos_state,)
>>> @njit
... def order_func_nb(c, last_pos_state):
... if c.position_now != 0:
... return vbt.pf_nb.close_position_nb()
...
... if last_pos_state[0] == 1:
... size = -np.inf # open short
... last_pos_state[0] = -1
... else:
... size = np.inf # open long
... last_pos_state[0] = 1
... return vbt.pf_nb.order_nb(size=size)
>>> pf = vbt.Portfolio.from_order_func(
... close,
... order_func_nb=order_func_nb,
... pre_group_func_nb=pre_group_func_nb
... )
>>> pf.assets
0 100.000000
1 0.000000
2 -66.666667
3 0.000000
4 26.666667
dtype: float64
>>> pf.cash
0 0.000000
1 200.000000
2 400.000000
3 133.333333
4 0.000000
dtype: float64
```
* Equal-weighted portfolio as in the example under `vectorbtpro.portfolio.nb.from_order_func.from_order_func_nb`:
```pycon
>>> @njit
... def pre_group_func_nb(c):
... order_value_out = np.empty(c.group_len, dtype=float_)
... return (order_value_out,)
>>> @njit
... def pre_segment_func_nb(c, order_value_out, size, price, size_type, direction):
... for col in range(c.from_col, c.to_col):
... c.last_val_price[col] = vbt.pf_nb.select_from_col_nb(c, col, price)
... vbt.pf_nb.sort_call_seq_nb(c, size, size_type, direction, order_value_out)
... return ()
>>> @njit
... def order_func_nb(c, size, price, size_type, direction, fees, fixed_fees, slippage):
... return vbt.pf_nb.order_nb(
... size=vbt.pf_nb.select_nb(c, size),
... price=vbt.pf_nb.select_nb(c, price),
... size_type=vbt.pf_nb.select_nb(c, size_type),
... direction=vbt.pf_nb.select_nb(c, direction),
... fees=vbt.pf_nb.select_nb(c, fees),
... fixed_fees=vbt.pf_nb.select_nb(c, fixed_fees),
... slippage=vbt.pf_nb.select_nb(c, slippage)
... )
>>> np.random.seed(42)
>>> close = np.random.uniform(1, 10, size=(5, 3))
>>> size_template = vbt.RepEval('np.array([[1 / group_lens[0]]])')
>>> pf = vbt.Portfolio.from_order_func(
... close,
... order_func_nb=order_func_nb,
... order_args=(
... size_template,
... vbt.Rep('price'),
... vbt.Rep('size_type'),
... vbt.Rep('direction'),
... vbt.Rep('fees'),
... vbt.Rep('fixed_fees'),
... vbt.Rep('slippage'),
... ),
... segment_mask=2, # rebalance every second tick
... pre_group_func_nb=pre_group_func_nb,
... pre_segment_func_nb=pre_segment_func_nb,
... pre_segment_args=(
... size_template,
... vbt.Rep('price'),
... vbt.Rep('size_type'),
... vbt.Rep('direction')
... ),
... broadcast_named_args=dict( # broadcast against each other
... price=close,
... size_type=vbt.pf_enums.SizeType.TargetPercent,
... direction=vbt.pf_enums.Direction.LongOnly,
... fees=0.001,
... fixed_fees=1.,
... slippage=0.001
... ),
... template_context=dict(np=np), # required by size_template
... cash_sharing=True, group_by=True, # one group with cash sharing
... )
>>> pf.get_asset_value(group_by=False).vbt.plot().show()
```
Templates are a very powerful tool to prepare any custom arguments after they are broadcast and
before they are passed to the simulation function. In the example above, we use `broadcast_named_args`
to broadcast some arguments against each other and templates to pass those objects to callbacks.
Additionally, we used an evaluation template to compute the size based on the number of assets in each group.
You may ask: why should we bother using broadcasting and templates if we could just pass `size=1/3`?
Because of the flexibility those features provide: we can now pass whatever parameter combinations we want
and it will work flawlessly. For example, to create two groups of equally-allocated positions,
we need to change only two parameters:
```pycon
>>> close = np.random.uniform(1, 10, size=(5, 6)) # 6 columns instead of 3
>>> group_by = ['g1', 'g1', 'g1', 'g2', 'g2', 'g2'] # 2 groups instead of 1
>>> # Replace close and group_by in the example above
>>> pf['g1'].get_asset_value(group_by=False).vbt.plot().show()
```
```pycon
>>> pf['g2'].get_asset_value(group_by=False).vbt.plot().show()
```
* Combine multiple exit conditions. Exit early if the price hits some threshold before an actual exit:
```pycon
>>> @njit
... def pre_sim_func_nb(c):
... # We need to define stop price per column once
... stop_price = np.full(c.target_shape[1], np.nan, dtype=float_)
... return (stop_price,)
>>> @njit
... def order_func_nb(c, stop_price, entries, exits, size):
... # Select info related to this order
... entry_now = vbt.pf_nb.select_nb(c, entries)
... exit_now = vbt.pf_nb.select_nb(c, exits)
... size_now = vbt.pf_nb.select_nb(c, size)
... price_now = vbt.pf_nb.select_nb(c, c.close)
... stop_price_now = stop_price[c.col]
...
... # Our logic
... if entry_now:
... if c.position_now == 0:
... return vbt.pf_nb.order_nb(
... size=size_now,
... price=price_now,
... direction=vbt.pf_enums.Direction.LongOnly)
... elif exit_now or price_now >= stop_price_now:
... if c.position_now > 0:
... return vbt.pf_nb.order_nb(
... size=-size_now,
... price=price_now,
... direction=vbt.pf_enums.Direction.LongOnly)
... return vbt.pf_enums.NoOrder
>>> @njit
... def post_order_func_nb(c, stop_price, stop):
... # Same broadcasting as for size
... stop_now = vbt.pf_nb.select_nb(c, stop)
...
... if c.order_result.status == vbt.pf_enums.OrderStatus.Filled:
... if c.order_result.side == vbt.pf_enums.OrderSide.Buy:
... # Position entered: Set stop condition
... stop_price[c.col] = (1 + stop_now) * c.order_result.price
... else:
... # Position exited: Remove stop condition
... stop_price[c.col] = np.nan
>>> def simulate(close, entries, exits, size, stop):
... return vbt.Portfolio.from_order_func(
... close,
... order_func_nb=order_func_nb,
... order_args=(vbt.Rep('entries'), vbt.Rep('exits'), vbt.Rep('size')),
... pre_sim_func_nb=pre_sim_func_nb,
... post_order_func_nb=post_order_func_nb,
... post_order_args=(vbt.Rep('stop'),),
... broadcast_named_args=dict( # broadcast against each other
... entries=entries,
... exits=exits,
... size=size,
... stop=stop
... )
... )
>>> close = pd.Series([10, 11, 12, 13, 14])
>>> entries = pd.Series([True, True, False, False, False])
>>> exits = pd.Series([False, False, False, True, True])
>>> simulate(close, entries, exits, np.inf, 0.1).asset_flow
0 10.0
1 0.0
2 -10.0
3 0.0
4 0.0
dtype: float64
>>> simulate(close, entries, exits, np.inf, 0.2).asset_flow
0 10.0
1 0.0
2 -10.0
3 0.0
4 0.0
dtype: float64
>>> simulate(close, entries, exits, np.inf, np.nan).asset_flow
0 10.0
1 0.0
2 0.0
3 -10.0
4 0.0
dtype: float64
```
The reason why a stop of 10% does not result in an order at the second time step is that
it comes at the same time as an entry, so it must wait until no entry is present.
This can be changed by replacing the "elif" statement with "if", which would execute
an exit regardless of whether an entry is present (similar to using `ConflictMode.Opposite` in
`Portfolio.from_signals`), as sketched below.
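Here is a minimal sketch of that change; only `order_func_nb` is redefined, with the exit branch
made independent of the entry branch (assumption: everything else from the example above stays the same):
```pycon
>>> @njit
... def order_func_nb(c, stop_price, entries, exits, size):
...     entry_now = vbt.pf_nb.select_nb(c, entries)
...     exit_now = vbt.pf_nb.select_nb(c, exits)
...     size_now = vbt.pf_nb.select_nb(c, size)
...     price_now = vbt.pf_nb.select_nb(c, c.close)
...     stop_price_now = stop_price[c.col]
...
...     if entry_now:
...         if c.position_now == 0:
...             return vbt.pf_nb.order_nb(
...                 size=size_now,
...                 price=price_now,
...                 direction=vbt.pf_enums.Direction.LongOnly)
...     if exit_now or price_now >= stop_price_now:  # "if" instead of "elif"
...         if c.position_now > 0:
...             return vbt.pf_nb.order_nb(
...                 size=-size_now,
...                 price=price_now,
...                 direction=vbt.pf_enums.Direction.LongOnly)
...     return vbt.pf_enums.NoOrder
```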
We can also test the parameter combinations above all at once (thanks to broadcasting
using `vectorbtpro.base.reshaping.broadcast`):
```pycon
>>> stop = pd.DataFrame([[0.1, 0.2, np.nan]])
>>> simulate(close, entries, exits, np.inf, stop).asset_flow
0 1 2
0 10.0 10.0 10.0
1 0.0 0.0 0.0
2 -10.0 -10.0 0.0
3 0.0 0.0 -10.0
4 0.0 0.0 0.0
```
Or, much simpler, using a Cartesian product:
```pycon
>>> stop = vbt.Param([0.1, 0.2, np.nan])
>>> simulate(close, entries, exits, np.inf, stop).asset_flow
threshold 0.1 0.2 NaN
0 10.0 10.0 10.0
1 0.0 0.0 0.0
2 -10.0 -10.0 0.0
3 0.0 0.0 -10.0
4 0.0 0.0 0.0
```
This works because `pd.Index` automatically translates into `vectorbtpro.base.reshaping.BCO`
with `product` set to True.
* Let's illustrate how to generate multiple orders per symbol and bar.
For each bar, buy at open and sell at close:
```pycon
>>> @njit
... def flex_order_func_nb(c, size):
... if c.call_idx == 0:
... return c.from_col, vbt.pf_nb.order_nb(size=size, price=c.open[c.i, c.from_col])
... if c.call_idx == 1:
... return c.from_col, vbt.pf_nb.close_position_nb(price=c.close[c.i, c.from_col])
... return -1, vbt.pf_enums.NoOrder
>>> open = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
>>> close = pd.DataFrame({'a': [2, 3, 4], 'b': [3, 4, 5]})
>>> size = 1
>>> pf = vbt.Portfolio.from_order_func(
... close,
... flex_order_func_nb=flex_order_func_nb,
... flex_order_args=(size,),
... open=open,
... max_order_records=close.shape[0] * 2
... )
>>> pf.orders.readable
Order Id Column Timestamp Size Price Fees Side
0 0 a 0 1.0 1.0 0.0 Buy
1 1 a 0 1.0 2.0 0.0 Sell
2 2 a 1 1.0 2.0 0.0 Buy
3 3 a 1 1.0 3.0 0.0 Sell
4 4 a 2 1.0 3.0 0.0 Buy
5 5 a 2 1.0 4.0 0.0 Sell
6 0 b 0 1.0 4.0 0.0 Buy
7 1 b 0 1.0 3.0 0.0 Sell
8 2 b 1 1.0 5.0 0.0 Buy
9 3 b 1 1.0 4.0 0.0 Sell
10 4 b 2 1.0 6.0 0.0 Buy
11 5 b 2 1.0 5.0 0.0 Sell
```
!!! warning
Each bar is effectively a black box - we don't know how the price moves in-between.
Since trades should come in an order that closely replicates that of the real world, the only
pieces of information that always remain in the correct order are the opening and closing price.
"""
if isinstance(close, FOFPreparer):
preparer = close
prep_result = None
elif isinstance(close, PFPrepResult):
preparer = None
prep_result = close
else:
local_kwargs = locals()
local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
del local_kwargs["kwargs"]
del local_kwargs["cls"]
del local_kwargs["return_preparer"]
del local_kwargs["return_prep_result"]
del local_kwargs["return_sim_out"]
parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
if parsed_data is not None:
local_kwargs["data"] = parsed_data
local_kwargs["close"] = None
preparer = FOFPreparer(**local_kwargs)
if not return_preparer:
preparer.set_seed()
prep_result = None
if return_preparer:
return preparer
if prep_result is None:
prep_result = preparer.result
if return_prep_result:
return prep_result
sim_out = prep_result.target_func(**prep_result.target_args)
if return_sim_out:
return sim_out
return cls(order_records=sim_out, **prep_result.pf_args)
@classmethod
def from_def_order_func(
cls: tp.Type[PortfolioT],
close: tp.Union[tp.ArrayLike, OHLCDataMixin, FDOFPreparer, PFPrepResult],
size: tp.Optional[tp.ArrayLike] = None,
size_type: tp.Optional[tp.ArrayLike] = None,
direction: tp.Optional[tp.ArrayLike] = None,
price: tp.Optional[tp.ArrayLike] = None,
fees: tp.Optional[tp.ArrayLike] = None,
fixed_fees: tp.Optional[tp.ArrayLike] = None,
slippage: tp.Optional[tp.ArrayLike] = None,
min_size: tp.Optional[tp.ArrayLike] = None,
max_size: tp.Optional[tp.ArrayLike] = None,
size_granularity: tp.Optional[tp.ArrayLike] = None,
leverage: tp.Optional[tp.ArrayLike] = None,
leverage_mode: tp.Optional[tp.ArrayLike] = None,
reject_prob: tp.Optional[tp.ArrayLike] = None,
price_area_vio_mode: tp.Optional[tp.ArrayLike] = None,
allow_partial: tp.Optional[tp.ArrayLike] = None,
raise_reject: tp.Optional[tp.ArrayLike] = None,
log: tp.Optional[tp.ArrayLike] = None,
pre_segment_func_nb: tp.Optional[nb.PreSegmentFuncT] = None,
order_func_nb: tp.Optional[nb.OrderFuncT] = None,
flex_order_func_nb: tp.Optional[nb.FlexOrderFuncT] = None,
val_price: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
call_seq: tp.Optional[tp.ArrayLike] = None,
flexible: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
chunked: tp.ChunkedOption = None,
return_preparer: bool = False,
return_prep_result: bool = False,
return_sim_out: bool = False,
**kwargs,
) -> PortfolioResultT:
"""Build portfolio from the default order function.
The default order function takes size, price, fees, and other available information, and issues
an order at each column and time step. Additionally, it uses a segment preprocessing function
that overrides the valuation price and sorts the call sequence. This way, it behaves similarly to
`Portfolio.from_orders`, but allows injecting pre- and postprocessing functions for more
control over the execution. It also knows how to chunk each argument. The only disadvantage is
that `Portfolio.from_orders` is better optimized for performance (up to 5x faster).
If `flexible` is True:
* `pre_segment_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_flex_pre_segment_func_nb`
* `flex_order_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_flex_order_func_nb`
If `flexible` is False:
* `pre_segment_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_pre_segment_func_nb`
* `order_func_nb` is `vectorbtpro.portfolio.nb.from_order_func.def_order_func_nb`
Prepared by `vectorbtpro.portfolio.preparing.FDOFPreparer`.
For details on other arguments, see `Portfolio.from_orders` and `Portfolio.from_order_func`.
Usage:
* Working with `Portfolio.from_def_order_func` is a similar experience as working
with `Portfolio.from_orders`:
```pycon
>>> close = pd.Series([1, 2, 3, 4, 5])
>>> pf = vbt.Portfolio.from_def_order_func(close, 10)
>>> pf.assets
0 10.0
1 20.0
2 30.0
3 40.0
4 40.0
dtype: float64
>>> pf.cash
0 90.0
1 70.0
2 40.0
3 0.0
4 0.0
dtype: float64
```
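* The same call can be routed through the flexible order function by setting `flexible=True`,
which swaps in `def_flex_pre_segment_func_nb` and `def_flex_order_func_nb` as described above
(a minimal sketch; the output is omitted):
```pycon
>>> pf = vbt.Portfolio.from_def_order_func(close, 10, flexible=True)
```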
* Equal-weighted portfolio as in the example under `Portfolio.from_order_func`
but much less verbose and with asset value pre-computed during the simulation (= faster):
```pycon
>>> np.random.seed(42)
>>> close = np.random.uniform(1, 10, size=(5, 3))
>>> @njit
... def post_segment_func_nb(c):
... for col in range(c.from_col, c.to_col):
... c.in_outputs.asset_value_pc[c.i, col] = c.last_position[col] * c.last_val_price[col]
>>> pf = vbt.Portfolio.from_def_order_func(
... close,
... size=1/3,
... size_type='targetpercent',
... direction='longonly',
... fees=0.001,
... fixed_fees=1.,
... slippage=0.001,
... segment_mask=2,
... cash_sharing=True,
... group_by=True,
... call_seq='auto',
... post_segment_func_nb=post_segment_func_nb,
... call_post_segment=True,
... in_outputs=dict(asset_value_pc=vbt.RepEval('np.empty_like(close)'))
... )
>>> asset_value = pf.wrapper.wrap(pf.in_outputs.asset_value_pc, group_by=False)
>>> asset_value.vbt.plot().show()
```
"""
if isinstance(close, FDOFPreparer):
preparer = close
prep_result = None
elif isinstance(close, PFPrepResult):
preparer = None
prep_result = close
else:
local_kwargs = locals()
local_kwargs = {**local_kwargs, **local_kwargs["kwargs"]}
del local_kwargs["kwargs"]
del local_kwargs["cls"]
del local_kwargs["return_preparer"]
del local_kwargs["return_prep_result"]
del local_kwargs["return_sim_out"]
parsed_data = BasePFPreparer.parse_data(close, all_ohlc=True)
if parsed_data is not None:
local_kwargs["data"] = parsed_data
local_kwargs["close"] = None
preparer = FDOFPreparer(**local_kwargs)
if not return_preparer:
preparer.set_seed()
prep_result = None
if return_preparer:
return preparer
if prep_result is None:
prep_result = preparer.result
if return_prep_result:
return prep_result
sim_out = prep_result.target_func(**prep_result.target_args)
if return_sim_out:
return sim_out
return cls(order_records=sim_out, **prep_result.pf_args)
# ############# Grouping ############# #
def regroup(self: PortfolioT, group_by: tp.GroupByLike, **kwargs) -> PortfolioT:
"""Regroup this object.
See `vectorbtpro.base.wrapping.Wrapping.regroup`.
!!! note
All cached objects will be lost."""
if self.cash_sharing:
if self.wrapper.grouper.is_grouping_modified(group_by=group_by):
raise ValueError("Cannot modify grouping globally when cash_sharing=True")
return Wrapping.regroup(self, group_by, **kwargs)
# ############# Properties ############# #
@property
def cash_sharing(self) -> bool:
"""Whether to share cash within the same group."""
return self._cash_sharing
@property
def in_outputs(self) -> tp.Optional[tp.NamedTuple]:
"""Named tuple with in-output objects."""
return self._in_outputs
@property
def use_in_outputs(self) -> bool:
"""Whether to return in-output objects when calling properties."""
return self._use_in_outputs
@property
def fillna_close(self) -> bool:
"""Whether to forward-backward fill NaN values in `Portfolio.close`."""
return self._fillna_close
@property
def year_freq(self) -> tp.Optional[tp.PandasFrequency]:
"""Year frequency."""
return ReturnsAccessor.get_year_freq(
year_freq=self._year_freq,
index=self.wrapper.index,
freq=self.wrapper.freq,
)
@property
def returns_acc_defaults(self) -> tp.KwargsLike:
"""Defaults for `vectorbtpro.returns.accessors.ReturnsAccessor`."""
return self._returns_acc_defaults
@property
def trades_type(self) -> int:
"""Default `vectorbtpro.portfolio.trades.Trades` to use across `Portfolio`."""
return self._trades_type
@property
def orders_cls(self) -> type:
"""Class for wrapping order records."""
if self._orders_cls is None:
return Orders
return self._orders_cls
@property
def logs_cls(self) -> type:
"""Class for wrapping log records."""
if self._logs_cls is None:
return Logs
return self._logs_cls
@property
def trades_cls(self) -> type:
"""Class for wrapping trade records."""
if self._trades_cls is None:
return Trades
return self._trades_cls
@property
def entry_trades_cls(self) -> type:
"""Class for wrapping entry trade records."""
if self._entry_trades_cls is None:
return EntryTrades
return self._entry_trades_cls
@property
def exit_trades_cls(self) -> type:
"""Class for wrapping exit trade records."""
if self._exit_trades_cls is None:
return ExitTrades
return self._exit_trades_cls
@property
def positions_cls(self) -> type:
"""Class for wrapping position records."""
if self._positions_cls is None:
return Positions
return self._positions_cls
@property
def drawdowns_cls(self) -> type:
"""Class for wrapping drawdown records."""
if self._drawdowns_cls is None:
return Drawdowns
return self._drawdowns_cls
@custom_property(group_by_aware=False)
def call_seq(self) -> tp.Optional[tp.SeriesFrame]:
"""Sequence of calls per row and group."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "call_seq"):
call_seq = self.in_outputs.call_seq
else:
call_seq = self._call_seq
if call_seq is None:
return None
return self.wrapper.wrap(call_seq, group_by=False)
@property
def cash_deposits_as_input(self) -> bool:
"""Whether to add cash deposits to the input value when calculating returns.
Otherwise, will subtract them from the output value."""
return self._cash_deposits_as_input
# ############# Price ############# #
@property
def open_flex(self) -> tp.Optional[tp.ArrayLike]:
"""`Portfolio.open` in a format suitable for flexible indexing."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "open"):
open = self.in_outputs.open
else:
open = self._open
return open
@property
def high_flex(self) -> tp.Optional[tp.ArrayLike]:
"""`Portfolio.high` in a format suitable for flexible indexing."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "high"):
high = self.in_outputs.high
else:
high = self._high
return high
@property
def low_flex(self) -> tp.Optional[tp.ArrayLike]:
"""`Portfolio.low` in a format suitable for flexible indexing."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "low"):
low = self.in_outputs.low
else:
low = self._low
return low
@property
def close_flex(self) -> tp.ArrayLike:
"""`Portfolio.close` in a format suitable for flexible indexing."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "close"):
close = self.in_outputs.close
else:
close = self._close
return close
@custom_property(group_by_aware=False, resample_func="first")
def open(self) -> tp.Optional[tp.SeriesFrame]:
"""Open price of each bar."""
if self.open_flex is None:
return None
return self.wrapper.wrap(self.open_flex, group_by=False)
@custom_property(group_by_aware=False, resample_func="max")
def high(self) -> tp.Optional[tp.SeriesFrame]:
"""High price of each bar."""
if self.high_flex is None:
return None
return self.wrapper.wrap(self.high_flex, group_by=False)
@custom_property(group_by_aware=False, resample_func="min")
def low(self) -> tp.Optional[tp.SeriesFrame]:
"""Low price of each bar."""
if self.low_flex is None:
return None
return self.wrapper.wrap(self.low_flex, group_by=False)
@custom_property(group_by_aware=False, resample_func="last")
def close(self) -> tp.SeriesFrame:
"""Last asset price at each time step."""
return self.wrapper.wrap(self.close_flex, group_by=False)
@hybrid_method
def get_filled_close(
cls_or_self,
close: tp.Optional[tp.SeriesFrame] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get forward and backward filled closing price.
See `vectorbtpro.generic.nb.base.fbfill_nb`."""
if not isinstance(cls_or_self, type):
if close is None:
close = cls_or_self.close
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(close, arg_name="close")
checks.assert_not_none(wrapper, arg_name="wrapper")
func = jit_reg.resolve_option(generic_nb.fbfill_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
filled_close = func(to_2d_array(close))
return wrapper.wrap(filled_close, group_by=False, **resolve_dict(wrap_kwargs))
@custom_property(group_by_aware=False, resample_func="last")
def bm_close(self) -> tp.Union[None, bool, tp.SeriesFrame]:
"""Benchmark price per unit series."""
if self.use_in_outputs and self.in_outputs is not None and hasattr(self.in_outputs, "bm_close"):
bm_close = self.in_outputs.bm_close
else:
bm_close = self._bm_close
if bm_close is None or isinstance(bm_close, bool):
return bm_close
return self.wrapper.wrap(bm_close, group_by=False)
@hybrid_method
def get_filled_bm_close(
cls_or_self,
bm_close: tp.Optional[tp.SeriesFrame] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[None, bool, tp.SeriesFrame]:
"""Get forward and backward filled benchmark closing price.
See `vectorbtpro.generic.nb.base.fbfill_nb`."""
if not isinstance(cls_or_self, type):
if bm_close is None:
bm_close = cls_or_self.bm_close
if bm_close is None or isinstance(bm_close, bool):
return bm_close
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(bm_close, arg_name="bm_close")
checks.assert_not_none(wrapper, arg_name="wrapper")
func = jit_reg.resolve_option(generic_nb.fbfill_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
filled_bm_close = func(to_2d_array(bm_close))
return wrapper.wrap(filled_bm_close, group_by=False, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_weights(
cls_or_self,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
) -> tp.Union[None, tp.ArrayLike, tp.Series]:
"""Get asset weights."""
if not isinstance(cls_or_self, type):
if weights is None:
weights = cls_or_self._weights
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
if weights is None or weights is False:
return None
return wrapper.wrap_reduced(weights, group_by=False)
# ############# Views ############# #
def apply_weights(
self: PortfolioT,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
rescale: bool = False,
group_by: tp.GroupByLike = None,
apply_group_by: bool = False,
**kwargs,
) -> PortfolioT:
"""Get view of portfolio with asset weights applied and optionally rescaled.
If `rescale` is True, weights are rescaled with respect to other weights in the same group.
For example, weights 0.5 and 0.5 are rescaled to 1.0 and 1.0 respectively, while
weights 0.7 and 0.3 are rescaled to 1.4 (1.4 * 0.5 = 0.7) and 0.6 (0.6 * 0.5 = 0.3) respectively."""
if weights is not None and weights is not False:
weights = to_1d_array(self.get_weights(weights=weights))
if rescale:
if self.wrapper.grouper.is_grouped(group_by=group_by):
new_weights = np.empty(len(weights), dtype=float_)
for group_idxs in self.wrapper.grouper.iter_group_idxs(group_by=group_by):
group_weights = weights[group_idxs]
new_weights[group_idxs] = group_weights * len(group_weights) / group_weights.sum()
weights = new_weights
else:
weights = weights * len(weights) / weights.sum()
if group_by is not None and apply_group_by:
_self = self.regroup(group_by=group_by)
else:
_self = self
return _self.replace(weights=weights, **kwargs)
def disable_weights(self: PortfolioT, **kwargs) -> PortfolioT:
"""Get view of portfolio with asset weights disabled."""
return self.replace(weights=False, **kwargs)
def get_long_view(
self: PortfolioT,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> PortfolioT:
"""Get view of portfolio with long positions only."""
if orders is None:
orders = self.resolve_shortcut_attr(
"orders",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
weights=False,
)
if init_position is None:
init_position = self._init_position
if init_price is None:
init_price = self._init_price
new_order_records = orders.get_long_view(
init_position=init_position,
init_price=init_price,
jitted=jitted,
chunked=chunked,
).values
init_position = broadcast_array_to(init_position, self.wrapper.shape_2d[1])
init_price = broadcast_array_to(init_price, self.wrapper.shape_2d[1])
new_init_position = np.where(init_position > 0, init_position, 0)
new_init_price = np.where(init_position > 0, init_price, np.nan)
return self.replace(
order_records=new_order_records,
init_position=new_init_position,
init_price=new_init_price,
**kwargs,
)
def get_short_view(
self: PortfolioT,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> PortfolioT:
"""Get view of portfolio with short positions only."""
if orders is None:
orders = self.resolve_shortcut_attr(
"orders",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
weights=False,
)
if init_position is None:
init_position = self._init_position
if init_price is None:
init_price = self._init_price
new_order_records = orders.get_short_view(
init_position=init_position,
init_price=init_price,
jitted=jitted,
chunked=chunked,
).values
init_position = broadcast_array_to(init_position, self.wrapper.shape_2d[1])
init_price = broadcast_array_to(init_price, self.wrapper.shape_2d[1])
new_init_position = np.where(init_position < 0, init_position, 0)
new_init_price = np.where(init_position < 0, init_price, np.nan)
return self.replace(
order_records=new_order_records,
init_position=new_init_position,
init_price=new_init_price,
**kwargs,
)
# ############# Records ############# #
@property
def order_records(self) -> tp.RecordArray:
"""A structured NumPy array of order records."""
return self._order_records
@hybrid_method
def get_orders(
cls_or_self,
order_records: tp.Optional[tp.RecordArray] = None,
open: tp.Optional[tp.SeriesFrame] = None,
high: tp.Optional[tp.SeriesFrame] = None,
low: tp.Optional[tp.SeriesFrame] = None,
close: tp.Optional[tp.SeriesFrame] = None,
orders_cls: tp.Optional[type] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> Orders:
"""Get order records.
See `vectorbtpro.portfolio.orders.Orders`."""
if not isinstance(cls_or_self, type):
if order_records is None:
order_records = cls_or_self.order_records
if open is None:
open = cls_or_self.open_flex
if high is None:
high = cls_or_self.high_flex
if low is None:
low = cls_or_self.low_flex
if close is None:
close = cls_or_self.close_flex
if orders_cls is None:
orders_cls = cls_or_self.orders_cls
if weights is None:
weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
elif weights is False:
weights = None
if wrapper is None:
wrapper = fix_wrapper_for_records(cls_or_self)
else:
checks.assert_not_none(order_records, arg_name="order_records")
if orders_cls is None:
orders_cls = Orders
checks.assert_not_none(wrapper, arg_name="wrapper")
weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
if sim_start is not None or sim_end is not None:
func = jit_reg.resolve_option(nb.records_within_sim_range_nb, jitted)
order_records = func(
wrapper.shape_2d,
order_records,
order_records["col"],
order_records["idx"],
sim_start=sim_start,
sim_end=sim_end,
)
if weights is not None:
func = jit_reg.resolve_option(nb.apply_weights_to_orders_nb, jitted)
order_records = func(
order_records,
order_records["col"],
to_1d_array(weights),
)
return orders_cls(
wrapper,
order_records,
open=open,
high=high,
low=low,
close=close,
**kwargs,
).regroup(group_by)
@property
def log_records(self) -> tp.RecordArray:
"""A structured NumPy array of log records."""
return self._log_records
@hybrid_method
def get_logs(
cls_or_self,
log_records: tp.Optional[tp.RecordArray] = None,
open: tp.Optional[tp.SeriesFrame] = None,
high: tp.Optional[tp.SeriesFrame] = None,
low: tp.Optional[tp.SeriesFrame] = None,
close: tp.Optional[tp.SeriesFrame] = None,
logs_cls: tp.Optional[type] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> Logs:
"""Get log records.
See `vectorbtpro.portfolio.logs.Logs`."""
if not isinstance(cls_or_self, type):
if log_records is None:
log_records = cls_or_self.log_records
if open is None:
open = cls_or_self.open_flex
if high is None:
high = cls_or_self.high_flex
if low is None:
low = cls_or_self.low_flex
if close is None:
close = cls_or_self.close_flex
if logs_cls is None:
logs_cls = cls_or_self.logs_cls
if wrapper is None:
wrapper = fix_wrapper_for_records(cls_or_self)
else:
checks.assert_not_none(log_records, arg_name="log_records")
if logs_cls is None:
logs_cls = Logs
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
if sim_start is not None or sim_end is not None:
func = jit_reg.resolve_option(nb.records_within_sim_range_nb, jitted)
log_records = func(
wrapper.shape_2d,
log_records,
log_records["col"],
log_records["idx"],
sim_start=sim_start,
sim_end=sim_end,
)
return logs_cls(
wrapper,
log_records,
open=open,
high=high,
low=low,
close=close,
**kwargs,
).regroup(group_by)
@hybrid_method
def get_entry_trades(
cls_or_self,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
entry_trades_cls: tp.Optional[type] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> EntryTrades:
"""Get entry trade records.
See `vectorbtpro.portfolio.trades.EntryTrades`."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if entry_trades_cls is None:
entry_trades_cls = cls_or_self.entry_trades_cls
else:
checks.assert_not_none(orders, arg_name="orders")
if init_position is None:
init_position = 0.0
if entry_trades_cls is None:
entry_trades_cls = EntryTrades
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, group_by=False)
return entry_trades_cls.from_orders(
orders,
init_position=init_position,
init_price=init_price,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
@hybrid_method
def get_exit_trades(
cls_or_self,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
exit_trades_cls: tp.Optional[type] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> ExitTrades:
"""Get exit trade records.
See `vectorbtpro.portfolio.trades.ExitTrades`."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if exit_trades_cls is None:
exit_trades_cls = cls_or_self.exit_trades_cls
else:
checks.assert_not_none(orders, arg_name="orders")
if init_position is None:
init_position = 0.0
if exit_trades_cls is None:
exit_trades_cls = ExitTrades
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, group_by=False)
return exit_trades_cls.from_orders(
orders,
init_position=init_position,
init_price=init_price,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
@hybrid_method
def get_positions(
cls_or_self,
trades: tp.Optional[Trades] = None,
positions_cls: tp.Optional[type] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> Positions:
"""Get position records.
See `vectorbtpro.portfolio.trades.Positions`."""
if not isinstance(cls_or_self, type):
if trades is None:
trades = cls_or_self.resolve_shortcut_attr(
"exit_trades",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if positions_cls is None:
positions_cls = cls_or_self.positions_cls
else:
checks.assert_not_none(trades, arg_name="trades")
if positions_cls is None:
positions_cls = Positions
return positions_cls.from_trades(trades, **kwargs)
def get_trades(
self,
trades_type: tp.Optional[tp.Union[str, int]] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> Trades:
"""Get trade/position records depending upon `Portfolio.trades_type`."""
if trades_type is None:
trades_type = self.trades_type
else:
if isinstance(trades_type, str):
trades_type = map_enum_fields(trades_type, enums.TradesType)
if trades_type == enums.TradesType.EntryTrades:
return self.resolve_shortcut_attr(
"entry_trades",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
**kwargs,
)
if trades_type == enums.TradesType.ExitTrades:
return self.resolve_shortcut_attr(
"exit_trades",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
**kwargs,
)
if trades_type == enums.TradesType.Positions:
return self.resolve_shortcut_attr(
"positions",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
**kwargs,
)
raise NotImplementedError
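# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. It shows how `get_trades` dispatches on `trades_type`
# using the members compared against above:
#
#     pf.get_trades(trades_type=enums.TradesType.EntryTrades)  # -> EntryTrades
#     pf.get_trades(trades_type=enums.TradesType.ExitTrades)   # -> ExitTrades
#     pf.get_trades(trades_type=enums.TradesType.Positions)    # -> Positions
#
# A string such as "entrytrades" should also work, as it is mapped via `map_enum_fields`.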
@hybrid_method
def get_trade_history(
cls_or_self,
orders: tp.Optional[Orders] = None,
entry_trades: tp.Optional[EntryTrades] = None,
exit_trades: tp.Optional[ExitTrades] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> tp.Frame:
"""Get order history merged with entry and exit trades as a readable DataFrame.
!!! note
The P&L and return aggregated across the DataFrame may not match the actual total P&L and
return, as this DataFrame annotates entry and exit orders with the performance relative to
their respective trade types. To obtain accurate total statistics, aggregate only the statistics
of either trade type. Additionally, entry orders include open statistics, whereas exit orders do not.
"""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if entry_trades is None:
entry_trades = cls_or_self.resolve_shortcut_attr(
"entry_trades",
orders=orders,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if exit_trades is None:
exit_trades = cls_or_self.resolve_shortcut_attr(
"exit_trades",
orders=orders,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
else:
checks.assert_not_none(orders, arg_name="orders")
checks.assert_not_none(entry_trades, arg_name="entry_trades")
checks.assert_not_none(exit_trades, arg_name="exit_trades")
order_history = orders.records_readable
del order_history["Size"]
del order_history["Price"]
del order_history["Fees"]
entry_trade_history = entry_trades.records_readable
del entry_trade_history["Entry Index"]
del entry_trade_history["Exit Order Id"]
del entry_trade_history["Exit Index"]
del entry_trade_history["Avg Exit Price"]
del entry_trade_history["Exit Fees"]
entry_trade_history.rename(columns={"Entry Order Id": "Order Id"}, inplace=True)
entry_trade_history.rename(columns={"Avg Entry Price": "Price"}, inplace=True)
entry_trade_history.rename(columns={"Entry Fees": "Fees"}, inplace=True)
exit_trade_history = exit_trades.records_readable
del exit_trade_history["Exit Index"]
del exit_trade_history["Entry Order Id"]
del exit_trade_history["Entry Index"]
del exit_trade_history["Avg Entry Price"]
del exit_trade_history["Entry Fees"]
exit_trade_history.rename(columns={"Exit Order Id": "Order Id"}, inplace=True)
exit_trade_history.rename(columns={"Avg Exit Price": "Price"}, inplace=True)
exit_trade_history.rename(columns={"Exit Fees": "Fees"}, inplace=True)
trade_history = pd.concat((entry_trade_history, exit_trade_history), axis=0)
trade_history = pd.merge(order_history, trade_history, on=["Column", "Order Id"])
trade_history = trade_history.sort_values(by=["Column", "Order Id", "Position Id"])
trade_history = trade_history.reset_index(drop=True)
trade_history["Entry Trade Id"] = trade_history["Entry Trade Id"].fillna(-1).astype(int)
trade_history["Exit Trade Id"] = trade_history["Exit Trade Id"].fillna(-1).astype(int)
trade_history["Entry Trade Id"] = trade_history.pop("Entry Trade Id")
trade_history["Exit Trade Id"] = trade_history.pop("Exit Trade Id")
trade_history["Position Id"] = trade_history.pop("Position Id")
return trade_history
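# Hypothetical usage sketch (not part of the original source), illustrating the note in
# the docstring above. `pf` is assumed to be an existing Portfolio instance; the id
# columns below are created by this method (missing ids are filled with -1):
#
#     th = pf.get_trade_history()
#     entry_rows = th[th["Entry Trade Id"] != -1]  # rows annotated as entry trades
#     exit_rows = th[th["Exit Trade Id"] != -1]    # rows annotated as exit trades
#
# Aggregate P&L or return columns over only one of the two subsets to avoid mixing
# entry-side and exit-side statistics.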
@hybrid_method
def get_signals(
cls_or_self,
orders: tp.Optional[Orders] = None,
entry_trades: tp.Optional[EntryTrades] = None,
exit_trades: tp.Optional[ExitTrades] = None,
idx_arr: tp.Union[None, str, tp.Array1d] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame, tp.SeriesFrame, tp.SeriesFrame]:
"""Get long entries, long exits, short entries, and short exits.
Returns per group if grouping is enabled. Pass `group_by=False` to disable."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if entry_trades is None:
entry_trades = cls_or_self.resolve_shortcut_attr(
"entry_trades",
orders=orders,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if exit_trades is None:
exit_trades = cls_or_self.resolve_shortcut_attr(
"exit_trades",
orders=orders,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
else:
checks.assert_not_none(orders, arg_name="orders")
checks.assert_not_none(entry_trades, arg_name="entry_trades")
checks.assert_not_none(exit_trades, arg_name="exit_trades")
if isinstance(orders, FSOrders) and idx_arr is None:
idx_arr = "signal_idx"
if idx_arr is not None:
if isinstance(idx_arr, str):
idx_ma = orders.map_field(idx_arr, idx_arr=idx_arr)
else:
idx_ma = orders.map_array(idx_arr, idx_arr=idx_arr)
else:
idx_ma = orders.idx
order_index = pd.MultiIndex.from_arrays(
(orders.col_mapper.get_col_arr(group_by=group_by), orders.id_arr), names=["col", "id"]
)
order_idx_sr = pd.Series(idx_ma.values, index=order_index, name="idx")
def _get_type_signals(type_order_ids):
type_order_ids = type_order_ids.apply_mask(type_order_ids.values != -1)
type_order_index = pd.MultiIndex.from_arrays(
(type_order_ids.col_mapper.get_col_arr(group_by=group_by), type_order_ids.values), names=["col", "id"]
)
type_idx_df = order_idx_sr.loc[type_order_index].reset_index()
type_signals = orders.wrapper.fill(False, group_by=group_by)
if isinstance(type_signals, pd.Series):
type_signals.values[type_idx_df["idx"].values] = True
else:
type_signals.values[type_idx_df["idx"].values, type_idx_df["col"].values] = True
return type_signals
return (
_get_type_signals(entry_trades.long_view.entry_order_id),
_get_type_signals(exit_trades.long_view.exit_order_id),
_get_type_signals(entry_trades.short_view.entry_order_id),
_get_type_signals(exit_trades.short_view.exit_order_id),
)
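# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. `get_signals` returns four boolean masks that can be
# unpacked directly:
#
#     long_entries, long_exits, short_entries, short_exits = pf.get_signals()
#
# Pass `group_by=False` to get the masks per column instead of per group.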
@hybrid_method
def get_drawdowns(
cls_or_self,
value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
drawdowns_cls: tp.Optional[type] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> Drawdowns:
"""Get drawdown records from `Portfolio.get_value`.
See `vectorbtpro.generic.drawdowns.Drawdowns`."""
if not isinstance(cls_or_self, type):
if value is None:
value = cls_or_self.resolve_shortcut_attr(
"value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if drawdowns_cls is None:
drawdowns_cls = cls_or_self.drawdowns_cls
if wrapper is None:
wrapper = fix_wrapper_for_records(cls_or_self)
else:
checks.assert_not_none(value, arg_name="value")
if drawdowns_cls is None:
drawdowns_cls = Drawdowns
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, group_by=False)
if wrapper is not None:
wrapper = wrapper.resolve(group_by=group_by)
return drawdowns_cls.from_price(
value,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
**kwargs,
)
# ############# Assets ############# #
@hybrid_method
def get_init_position(
cls_or_self,
init_position_raw: tp.Optional[tp.ArrayLike] = None,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
keep_flex: bool = False,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
"""Get initial position per column."""
if not isinstance(cls_or_self, type):
if init_position_raw is None:
init_position_raw = cls_or_self._init_position
if weights is None:
weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
elif weights is False:
weights = None
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_position_raw, arg_name="init_position_raw")
checks.assert_not_none(wrapper, arg_name="wrapper")
weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
if keep_flex and weights is None:
return init_position_raw
init_position = broadcast_array_to(init_position_raw, wrapper.shape_2d[1])
if weights is not None:
weights = to_1d_array(weights)
init_position = np.where(np.isnan(weights), init_position, weights * init_position)
if keep_flex:
return init_position
wrap_kwargs = merge_dicts(dict(name_or_index="init_position"), wrap_kwargs)
return wrapper.wrap_reduced(init_position, group_by=False, **wrap_kwargs)
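# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. With `keep_flex=True` the raw, broadcastable format is
# returned; otherwise the result is wrapped as one value per column:
#
#     pf.get_init_position()                # wrapped, one value per column
#     pf.get_init_position(keep_flex=True)  # flexible format, lower memory footprint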
@hybrid_method
def get_asset_flow(
cls_or_self,
direction: tp.Union[str, int] = "both",
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get asset flow series per column.
Returns the total transacted amount of assets at each time step."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=None,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(orders, arg_name="orders")
if init_position is None:
init_position = 0.0
if wrapper is None:
wrapper = orders.wrapper
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
direction = map_enum_fields(direction, enums.Direction)
func = jit_reg.resolve_option(nb.asset_flow_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
asset_flow = func(
wrapper.shape_2d,
orders.values,
orders.col_mapper.col_map,
direction=direction,
init_position=to_1d_array(init_position),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(asset_flow, group_by=False, **resolve_dict(wrap_kwargs))
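# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Asset flow is the amount transacted at each bar and can
# be restricted to a single direction:
#
#     pf.get_asset_flow()                       # both directions
#     pf.get_asset_flow(direction="longonly")   # long transactions only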
@hybrid_method
def get_assets(
cls_or_self,
direction: tp.Union[str, int] = "both",
asset_flow: tp.Optional[tp.SeriesFrame] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get asset series per column.
Returns the position at each time step."""
if not isinstance(cls_or_self, type):
if asset_flow is None:
asset_flow = cls_or_self.resolve_shortcut_attr(
"asset_flow",
direction=enums.Direction.Both,
init_position=init_position,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(asset_flow, arg_name="asset_flow")
if init_position is None:
init_position = 0.0
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
direction = map_enum_fields(direction, enums.Direction)
func = jit_reg.resolve_option(nb.assets_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
assets = func(
to_2d_array(asset_flow),
direction=direction,
init_position=to_1d_array(init_position),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(assets, group_by=False, **resolve_dict(wrap_kwargs))
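# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Assets are the running position, built from the asset
# flow plus the initial position:
#
#     assets = pf.get_assets()                        # position held at each bar
#     short_assets = pf.get_assets(direction="shortonly")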
@hybrid_method
def get_position_mask(
cls_or_self,
direction: tp.Union[str, int] = "both",
assets: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get position mask per column or group.
An element is True if there is a position at the given time step."""
if not isinstance(cls_or_self, type):
if assets is None:
assets = cls_or_self.resolve_shortcut_attr(
"assets",
direction=direction,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(assets, arg_name="assets")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.position_mask_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
position_mask = func(
to_2d_array(assets),
group_lens=group_lens,
sim_start=sim_start,
sim_end=sim_end,
)
else:
func = jit_reg.resolve_option(nb.position_mask_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
position_mask = func(
to_2d_array(assets),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(position_mask, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_position_coverage(
cls_or_self,
direction: tp.Union[str, int] = "both",
assets: tp.Optional[tp.SeriesFrame] = None,
granular_groups: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get position coverage per column or group.
Position coverage is the number of time steps in the market divided by the total number of time steps."""
if not isinstance(cls_or_self, type):
if assets is None:
assets = cls_or_self.resolve_shortcut_attr(
"assets",
direction=direction,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(assets, arg_name="assets")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.position_coverage_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
position_coverage = func(
to_2d_array(assets),
group_lens=group_lens,
granular_groups=granular_groups,
sim_start=sim_start,
sim_end=sim_end,
)
else:
func = jit_reg.resolve_option(nb.position_coverage_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
position_coverage = func(
to_2d_array(assets),
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="position_coverage"), wrap_kwargs)
return wrapper.wrap_reduced(position_coverage, group_by=group_by, **wrap_kwargs)
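# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Coverage = bars in the market / total bars, so a value
# of 0.25 means a position was open during roughly a quarter of the simulated period:
#
#     pf.get_position_coverage()                      # per column or group
#     pf.get_position_coverage(granular_groups=True)  # controls how grouped columns are counted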
@hybrid_method
def get_position_entry_price(
cls_or_self,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
fill_closed_position: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get the position's entry price at each time step."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=None,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(orders, arg_name="orders")
if init_position is None:
init_position = 0.0
if init_price is None:
init_price = np.nan
if wrapper is None:
wrapper = orders.wrapper
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.get_position_feature_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
entry_price = func(
orders.values,
to_2d_array(orders.close),
orders.col_mapper.col_map,
feature=enums.PositionFeature.EntryPrice,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
fill_closed_position=fill_closed_position,
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(entry_price, group_by=False, **resolve_dict(wrap_kwargs))
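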
@hybrid_method
def get_position_exit_price(
cls_or_self,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
fill_closed_position: bool = False,
fill_exit_price: bool = True,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get the position's exit price at each time step."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=None,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(orders, arg_name="orders")
if init_position is None:
init_position = 0.0
if init_price is None:
init_price = np.nan
if wrapper is None:
wrapper = orders.wrapper
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.get_position_feature_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
exit_price = func(
orders.values,
to_2d_array(orders.close),
orders.col_mapper.col_map,
feature=enums.PositionFeature.ExitPrice,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
fill_closed_position=fill_closed_position,
fill_exit_price=fill_exit_price,
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(exit_price, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Cash ############# #
@hybrid_method
def get_cash_deposits(
cls_or_self,
cash_deposits_raw: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = None,
split_shared: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
keep_flex: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
"""Get cash deposit series per column or group.
Set `keep_flex` to True to keep the format suitable for flexible indexing, which consumes less memory.
if not isinstance(cls_or_self, type):
if cash_deposits_raw is None:
cash_deposits_raw = cls_or_self._cash_deposits
if cash_sharing is None:
cash_sharing = cls_or_self.cash_sharing
if weights is None:
weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
elif weights is False:
weights = None
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
if cash_deposits_raw is None:
cash_deposits_raw = 0.0
checks.assert_not_none(cash_sharing, arg_name="cash_sharing")
checks.assert_not_none(wrapper, arg_name="wrapper")
weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
cash_deposits_arr = to_2d_array(cash_deposits_raw)
if keep_flex and not cash_deposits_arr.any():
return cash_deposits_raw
if wrapper.grouper.is_grouped(group_by=group_by):
if keep_flex and cash_sharing and weights is None and sim_start is None and sim_end is None:
return cash_deposits_raw
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.cash_deposits_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_deposits = func(
wrapper.shape_2d,
group_lens,
cash_sharing,
cash_deposits_raw=cash_deposits_arr,
weights=to_1d_array(weights) if weights is not None else None,
sim_start=sim_start,
sim_end=sim_end,
)
else:
if keep_flex and not cash_sharing and weights is None and sim_start is None and sim_end is None:
return cash_deposits_raw
group_lens = wrapper.grouper.get_group_lens()
func = jit_reg.resolve_option(nb.cash_deposits_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_deposits = func(
wrapper.shape_2d,
group_lens,
cash_sharing,
cash_deposits_raw=cash_deposits_arr,
split_shared=split_shared,
weights=to_1d_array(weights) if weights is not None else None,
sim_start=sim_start,
sim_end=sim_end,
)
if keep_flex:
return cash_deposits
return wrapper.wrap(cash_deposits, group_by=group_by, **resolve_dict(wrap_kwargs))
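# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. `keep_flex=True` returns the raw, broadcastable array
# (and may short-circuit to the stored input), which consumes less memory than the
# fully wrapped frame:
#
#     pf.get_cash_deposits()                # wrapped per column or group
#     pf.get_cash_deposits(keep_flex=True)  # flexible format for further Numba calls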
@hybrid_method
def get_total_cash_deposits(
cls_or_self,
cash_deposits_raw: tp.Optional[tp.ArrayLike] = None,
cash_sharing: tp.Optional[bool] = None,
split_shared: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.ArrayLike:
"""Get total cash deposit series per column or group."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
cash_deposits = cls_or_self.get_cash_deposits(
cash_deposits_raw=cash_deposits_raw,
cash_sharing=cash_sharing,
split_shared=split_shared,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
weights=weights,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
total_cash_deposits = np.nansum(cash_deposits, axis=0)
return wrapper.wrap_reduced(total_cash_deposits, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_cash_earnings(
cls_or_self,
cash_earnings_raw: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
keep_flex: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
"""Get earnings in cash series per column or group.
Set `keep_flex` to True to keep the format suitable for flexible indexing, which consumes less memory.
if not isinstance(cls_or_self, type):
if cash_earnings_raw is None:
cash_earnings_raw = cls_or_self._cash_earnings
if weights is None:
weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
elif weights is False:
weights = None
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
if cash_earnings_raw is None:
cash_earnings_raw = 0.0
checks.assert_not_none(wrapper, arg_name="wrapper")
weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
cash_earnings_arr = to_2d_array(cash_earnings_raw)
if keep_flex and not cash_earnings_arr.any():
return cash_earnings_raw
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.cash_earnings_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_earnings = func(
wrapper.shape_2d,
group_lens,
cash_earnings_raw=cash_earnings_arr,
weights=to_1d_array(weights) if weights is not None else None,
sim_start=sim_start,
sim_end=sim_end,
)
else:
if keep_flex and weights is None and sim_start is None and sim_end is None:
return cash_earnings_raw
func = jit_reg.resolve_option(nb.cash_earnings_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_earnings = func(
wrapper.shape_2d,
cash_earnings_raw=cash_earnings_arr,
weights=to_1d_array(weights) if weights is not None else None,
sim_start=sim_start,
sim_end=sim_end,
)
if keep_flex:
return cash_earnings
return wrapper.wrap(cash_earnings, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_total_cash_earnings(
cls_or_self,
cash_earnings_raw: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.ArrayLike:
"""Get total cash earning series per column or group."""
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
cash_earnings = cls_or_self.get_cash_earnings(
cash_earnings_raw=cash_earnings_raw,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
weights=weights,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
total_cash_earnings = np.nansum(cash_earnings, axis=0)
return wrapper.wrap_reduced(total_cash_earnings, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_cash_flow(
cls_or_self,
free: bool = False,
orders: tp.Optional[Orders] = None,
cash_earnings: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get cash flow series per column or group.
Use `free` to return the flow of the free cash, which never goes above the initial level,
because an operation always costs money.
!!! note
Does not include cash deposits, but includes earnings.
Using `free` yields the same result as during the simulation only when `leverage=1`.
For anything else, prefill the state instead of reconstructing it."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
weights=weights,
wrapper=wrapper,
group_by=None,
)
if cash_earnings is None:
cash_earnings = cls_or_self.resolve_shortcut_attr(
"cash_earnings",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
weights=weights,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(orders, arg_name="orders")
if cash_earnings is None:
cash_earnings = 0.0
if wrapper is None:
wrapper = orders.wrapper
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.cash_flow_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_flow = func(
wrapper.shape_2d,
orders.values,
orders.col_mapper.col_map,
free=free,
cash_earnings=to_2d_array(cash_earnings),
sim_start=sim_start,
sim_end=sim_end,
)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.cash_flow_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash_flow = func(
cash_flow,
group_lens,
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(cash_flow, group_by=group_by, **resolve_dict(wrap_kwargs))
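# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. The free variant reconstructs the flow of free cash and,
# per the note in the docstring above, matches the simulation only when `leverage=1`:
#
#     cash_flow = pf.get_cash_flow()               # regular cash flow
#     free_cash_flow = pf.get_cash_flow(free=True) # free cash flow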
@hybrid_method
def get_init_cash(
cls_or_self,
init_cash_raw: tp.Optional[tp.ArrayLike] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
free_cash_flow: tp.Optional[tp.SeriesFrame] = None,
cash_sharing: tp.Optional[bool] = None,
split_shared: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
weights: tp.Union[None, bool, tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get initial amount of cash per column or group."""
if not isinstance(cls_or_self, type):
if init_cash_raw is None:
init_cash_raw = cls_or_self._init_cash
if checks.is_int(init_cash_raw) and init_cash_raw in enums.InitCashMode:
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
weights=weights,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if free_cash_flow is None:
free_cash_flow = cls_or_self.resolve_shortcut_attr(
"cash_flow",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
weights=weights,
free=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_sharing is None:
cash_sharing = cls_or_self.cash_sharing
if weights is None:
weights = cls_or_self.resolve_shortcut_attr("weights", wrapper=wrapper)
elif weights is False:
weights = None
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_cash_raw, arg_name="init_cash_raw")
if checks.is_int(init_cash_raw) and init_cash_raw in enums.InitCashMode:
checks.assert_not_none(free_cash_flow, arg_name="free_cash_flow")
if cash_deposits is None:
cash_deposits = 0.0
checks.assert_not_none(cash_sharing, arg_name="cash_sharing")
checks.assert_not_none(wrapper, arg_name="wrapper")
weights = cls_or_self.get_weights(weights=weights, wrapper=wrapper)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
if checks.is_int(init_cash_raw) and init_cash_raw in enums.InitCashMode:
func = jit_reg.resolve_option(nb.align_init_cash_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
init_cash = func(
init_cash_raw,
to_2d_array(free_cash_flow),
cash_deposits=to_2d_array(cash_deposits),
sim_start=sim_start,
sim_end=sim_end,
)
else:
init_cash_raw = to_1d_array(init_cash_raw)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.init_cash_grouped_nb, jitted)
init_cash = func(
init_cash_raw,
group_lens,
cash_sharing,
weights=to_1d_array(weights) if weights is not None else None,
)
else:
group_lens = wrapper.grouper.get_group_lens()
func = jit_reg.resolve_option(nb.init_cash_nb, jitted)
init_cash = func(
init_cash_raw,
group_lens,
cash_sharing,
split_shared=split_shared,
weights=to_1d_array(weights) if weights is not None else None,
)
wrap_kwargs = merge_dicts(dict(name_or_index="init_cash"), wrap_kwargs)
return wrapper.wrap_reduced(init_cash, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_cash(
cls_or_self,
free: bool = False,
init_cash: tp.Optional[tp.ArrayLike] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_flow: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get cash balance series per column or group.
For `free`, see `Portfolio.get_cash_flow`."""
if not isinstance(cls_or_self, type):
if init_cash is None:
init_cash = cls_or_self.resolve_shortcut_attr(
"init_cash",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_flow is None:
cash_flow = cls_or_self.resolve_shortcut_attr(
"cash_flow",
free=free,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_cash, arg_name="init_cash")
if cash_deposits is None:
cash_deposits = 0.0
checks.assert_not_none(cash_flow, arg_name="cash_flow")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.cash_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cash = func(
to_2d_array(cash_flow),
to_1d_array(init_cash),
cash_deposits=to_2d_array(cash_deposits),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(cash, group_by=group_by, **resolve_dict(wrap_kwargs))
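# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. The cash balance is built from the initial cash, any
# cash deposits, and the (optionally free) cash flow:
#
#     cash = pf.get_cash()
#     free_cash = pf.get_cash(free=True)  # uses the free cash flow, see get_cash_flow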
# ############# Value ############# #
@hybrid_method
def get_init_price(
cls_or_self,
init_price_raw: tp.Optional[tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
keep_flex: bool = False,
) -> tp.Union[tp.ArrayLike, tp.MaybeSeries]:
"""Get initial price per column."""
if not isinstance(cls_or_self, type):
if init_price_raw is None:
init_price_raw = cls_or_self._init_price
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_price_raw, arg_name="init_price_raw")
checks.assert_not_none(wrapper, arg_name="wrapper")
if keep_flex:
return init_price_raw
init_price = broadcast_array_to(init_price_raw, wrapper.shape_2d[1])
if keep_flex:
return init_price
wrap_kwargs = merge_dicts(dict(name_or_index="init_price"), wrap_kwargs)
return wrapper.wrap_reduced(init_price, group_by=False, **wrap_kwargs)
@hybrid_method
def get_init_position_value(
cls_or_self,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get initial position value per column."""
if not isinstance(cls_or_self, type):
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
if init_position is None:
init_position = 0.0
checks.assert_not_none(init_price, arg_name="init_price")
checks.assert_not_none(wrapper, arg_name="wrapper")
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.init_position_value_grouped_nb, jitted)
init_position_value = func(
group_lens,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
)
else:
func = jit_reg.resolve_option(nb.init_position_value_nb, jitted)
init_position_value = func(
wrapper.shape_2d[1],
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
)
wrap_kwargs = merge_dicts(dict(name_or_index="init_position_value"), wrap_kwargs)
return wrapper.wrap_reduced(init_position_value, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_init_value(
cls_or_self,
init_position_value: tp.Optional[tp.MaybeSeries] = None,
init_cash: tp.Optional[tp.MaybeSeries] = None,
split_shared: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get initial value per column or group.
Includes initial cash and the value of initial position."""
if not isinstance(cls_or_self, type):
if init_position_value is None:
init_position_value = cls_or_self.resolve_shortcut_attr(
"init_position_value",
jitted=jitted,
wrapper=wrapper,
group_by=group_by,
)
if init_cash is None:
init_cash = cls_or_self.resolve_shortcut_attr(
"init_cash",
split_shared=split_shared,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_position_value, arg_name="init_position_value")
checks.assert_not_none(init_cash, arg_name="init_cash")
checks.assert_not_none(wrapper, arg_name="wrapper")
func = jit_reg.resolve_option(nb.init_value_nb, jitted)
init_value = func(to_1d_array(init_position_value), to_1d_array(init_cash))
wrap_kwargs = merge_dicts(dict(name_or_index="init_value"), wrap_kwargs)
return wrapper.wrap_reduced(init_value, group_by=group_by, **wrap_kwargs)
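# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. The initial value is the initial cash plus the value of
# the initial position:
#
#     init_value = pf.get_init_value()
#     # roughly equivalent to: pf.get_init_cash() + pf.get_init_position_value()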
@hybrid_method
def get_input_value(
cls_or_self,
total_cash_deposits: tp.Optional[tp.ArrayLike] = None,
init_value: tp.Optional[tp.MaybeSeries] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get total input value per column or group.
Includes initial value and any cash deposited at any point in time."""
if not isinstance(cls_or_self, type):
if total_cash_deposits is None:
total_cash_deposits = cls_or_self.resolve_shortcut_attr(
"total_cash_deposits",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
if total_cash_deposits is None:
total_cash_deposits = 0.0
checks.assert_not_none(init_value, arg_name="init_value")
checks.assert_not_none(wrapper, arg_name="wrapper")
input_value = to_1d_array(total_cash_deposits) + to_1d_array(init_value)
wrap_kwargs = merge_dicts(dict(name_or_index="input_value"), wrap_kwargs)
return wrapper.wrap_reduced(input_value, group_by=group_by, **wrap_kwargs)
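# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. As computed above, the input value is the initial value
# plus the total of all cash deposits:
#
#     input_value = pf.get_input_value()
#     # equivalent to: pf.get_total_cash_deposits() + pf.get_init_value()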
@hybrid_method
def get_asset_value(
cls_or_self,
direction: tp.Union[str, int] = "both",
close: tp.Optional[tp.SeriesFrame] = None,
assets: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get asset value series per column or group."""
if not isinstance(cls_or_self, type):
if close is None:
if cls_or_self.fillna_close:
close = cls_or_self.filled_close
else:
close = cls_or_self.close
if assets is None:
assets = cls_or_self.resolve_shortcut_attr(
"assets",
direction=direction,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(close, arg_name="close")
checks.assert_not_none(assets, arg_name="assets")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.asset_value_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
asset_value = func(
to_2d_array(close),
to_2d_array(assets),
sim_start=sim_start,
sim_end=sim_end,
)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.asset_value_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
asset_value = func(
asset_value,
group_lens,
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(asset_value, group_by=group_by, **resolve_dict(wrap_kwargs))
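# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Asset value is the position valued at the close price,
# aggregated over columns when grouping is enabled:
#
#     pf.get_asset_value()                       # per column or group
#     pf.get_asset_value(direction="shortonly")  # short leg only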
@hybrid_method
def get_value(
cls_or_self,
cash: tp.Optional[tp.SeriesFrame] = None,
asset_value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get portfolio value series per column or group.
By default, will generate portfolio value for each asset based on cash flows and thus
independent of other assets, with the initial cash balance and position being that of the
entire group. Useful for generating returns and comparing assets within the same group."""
if not isinstance(cls_or_self, type):
if cash is None:
cash = cls_or_self.resolve_shortcut_attr(
"cash",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if asset_value is None:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(cash, arg_name="cash")
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.value_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
value = func(
to_2d_array(cash),
to_2d_array(asset_value),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(value, group_by=group_by, **resolve_dict(wrap_kwargs))
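# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Portfolio value is cash plus asset value, so the
# following should agree element-wise up to floating-point error:
#
#     value = pf.get_value()
#     # roughly: pf.get_cash() + pf.get_asset_value()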
@hybrid_method
def get_gross_exposure(
cls_or_self,
direction: tp.Union[str, int] = "both",
asset_value: tp.Optional[tp.SeriesFrame] = None,
value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get gross exposure.
!!! note
When the direction is both, `asset_value` must be the sum of the absolute long-only and short-only asset values."""
direction = map_enum_fields(direction, enums.Direction)
if not isinstance(cls_or_self, type):
if asset_value is None:
if direction == enums.Direction.Both and cls_or_self.wrapper.grouper.is_grouped(group_by=group_by):
long_asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
direction="longonly",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
short_asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
direction="shortonly",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
asset_value = long_asset_value + short_asset_value
else:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
direction=direction,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if value is None:
value = cls_or_self.resolve_shortcut_attr(
"value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(value, arg_name="value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.gross_exposure_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
gross_exposure = func(
to_2d_array(asset_value),
to_2d_array(value),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(gross_exposure, group_by=group_by, **resolve_dict(wrap_kwargs))
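# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. For grouped portfolios with both directions, the
# long-only and short-only asset values are added before dividing by value, as noted
# in the docstring above:
#
#     pf.get_gross_exposure()                      # both directions
#     pf.get_gross_exposure(direction="longonly")  # long exposure only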
@hybrid_method
def get_net_exposure(
cls_or_self,
long_exposure: tp.Optional[tp.SeriesFrame] = None,
short_exposure: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get net exposure."""
if not isinstance(cls_or_self, type):
if long_exposure is None:
long_exposure = cls_or_self.resolve_shortcut_attr(
"gross_exposure",
direction=enums.Direction.LongOnly,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if short_exposure is None:
short_exposure = cls_or_self.resolve_shortcut_attr(
"gross_exposure",
direction=enums.Direction.ShortOnly,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(long_exposure, arg_name="long_exposure")
checks.assert_not_none(short_exposure, arg_name="short_exposure")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.net_exposure_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
net_exposure = func(
to_2d_array(long_exposure),
to_2d_array(short_exposure),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(net_exposure, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_allocations(
cls_or_self,
direction: tp.Union[str, int] = "both",
asset_value: tp.Optional[tp.SeriesFrame] = None,
value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get portfolio allocation series per column."""
if not isinstance(cls_or_self, type):
if asset_value is None:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
direction=direction,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
if value is None:
value = cls_or_self.resolve_shortcut_attr(
"value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(value, arg_name="value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.allocations_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
allocations = func(
to_2d_array(asset_value),
to_2d_array(value),
group_lens,
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(allocations, group_by=False, **resolve_dict(wrap_kwargs))
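# Hypothetical usage sketch (not part of the original source), assuming `pf` is an
# existing Portfolio instance. Allocations are per-column asset values divided by the
# value of their group, so for a fully invested long-only group each row should sum to
# roughly 1:
#
#     allocations = pf.get_allocations()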
@hybrid_method
def get_total_profit(
cls_or_self,
close: tp.Optional[tp.SeriesFrame] = None,
orders: tp.Optional[Orders] = None,
init_position: tp.Optional[tp.ArrayLike] = None,
init_price: tp.Optional[tp.ArrayLike] = None,
cash_earnings: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get total profit per column or group.
Calculated directly from order records (fast)."""
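# Editor's sketch (assumption: `pf` is an existing `Portfolio` instance). Because this
# metric is derived from order records rather than full value series, it is cheap to call:
# >>> pf.get_total_profit()               # per column
# >>> pf.get_total_profit(group_by=True)  # aggregated per group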
if not isinstance(cls_or_self, type):
if close is None:
if cls_or_self.fillna_close:
close = cls_or_self.filled_close
else:
close = cls_or_self.close
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=None,
)
if init_position is None:
init_position = cls_or_self.resolve_shortcut_attr(
"init_position",
wrapper=wrapper,
keep_flex=True,
)
if init_price is None:
init_price = cls_or_self.resolve_shortcut_attr(
"init_price",
wrapper=wrapper,
keep_flex=True,
)
if cash_earnings is None:
cash_earnings = cls_or_self.resolve_shortcut_attr(
"cash_earnings",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(orders, arg_name="orders")
if close is None:
close = orders.close
checks.assert_not_none(close, arg_name="close")
checks.assert_not_none(init_price, arg_name="init_price")
if init_position is None:
init_position = 0.0
if cash_earnings is None:
cash_earnings = 0.0
if wrapper is None:
wrapper = orders.wrapper
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.total_profit_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
total_profit = func(
wrapper.shape_2d,
to_2d_array(close),
orders.values,
orders.col_mapper.col_map,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
cash_earnings=to_2d_array(cash_earnings),
sim_start=sim_start,
sim_end=sim_end,
)
if wrapper.grouper.is_grouped(group_by=group_by):
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.total_profit_grouped_nb, jitted)
total_profit = func(total_profit, group_lens)
wrap_kwargs = merge_dicts(dict(name_or_index="total_profit"), wrap_kwargs)
return wrapper.wrap_reduced(total_profit, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_final_value(
cls_or_self,
input_value: tp.Optional[tp.MaybeSeries] = None,
total_profit: tp.Optional[tp.MaybeSeries] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get total profit per column or group."""
if not isinstance(cls_or_self, type):
if input_value is None:
input_value = cls_or_self.resolve_shortcut_attr(
"input_value",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if total_profit is None:
total_profit = cls_or_self.resolve_shortcut_attr(
"total_profit",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(input_value, arg_name="input_value")
checks.assert_not_none(total_profit, arg_name="total_profit")
checks.assert_not_none(wrapper, arg_name="wrapper")
final_value = to_1d_array(input_value) + to_1d_array(total_profit)
wrap_kwargs = merge_dicts(dict(name_or_index="final_value"), wrap_kwargs)
return wrapper.wrap_reduced(final_value, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_total_return(
cls_or_self,
input_value: tp.Optional[tp.MaybeSeries] = None,
total_profit: tp.Optional[tp.MaybeSeries] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get total return per column or group."""
if not isinstance(cls_or_self, type):
if input_value is None:
input_value = cls_or_self.resolve_shortcut_attr(
"input_value",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if total_profit is None:
total_profit = cls_or_self.resolve_shortcut_attr(
"total_profit",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(input_value, arg_name="input_value")
checks.assert_not_none(total_profit, arg_name="total_profit")
checks.assert_not_none(wrapper, arg_name="wrapper")
total_return = to_1d_array(total_profit) / to_1d_array(input_value)
wrap_kwargs = merge_dicts(dict(name_or_index="total_return"), wrap_kwargs)
return wrapper.wrap_reduced(total_return, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_returns(
cls_or_self,
init_value: tp.Optional[tp.MaybeSeries] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_deposits_as_input: tp.Optional[bool] = None,
value: tp.Optional[tp.SeriesFrame] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get return series per column or group based on portfolio value."""
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_deposits_as_input is None:
cash_deposits_as_input = cls_or_self.cash_deposits_as_input
if value is None:
value = cls_or_self.resolve_shortcut_attr(
"value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_value, arg_name="init_value")
if cash_deposits is None:
cash_deposits = 0.0
if cash_deposits_as_input is None:
cash_deposits_as_input = False
checks.assert_not_none(value, arg_name="value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
returns = func(
to_2d_array(value),
to_1d_array(init_value),
cash_deposits=to_2d_array(cash_deposits),
cash_deposits_as_input=cash_deposits_as_input,
log_returns=log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
returns = wrapper.wrap(returns, group_by=group_by, **resolve_dict(wrap_kwargs))
if daily_returns:
returns = returns.vbt.returns(log_returns=log_returns).daily(jitted=jitted)
return returns
@hybrid_method
def get_asset_pnl(
cls_or_self,
init_position_value: tp.Optional[tp.MaybeSeries] = None,
asset_value: tp.Optional[tp.SeriesFrame] = None,
cash_flow: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get asset (realized and unrealized) PnL series per column or group."""
if not isinstance(cls_or_self, type):
if init_position_value is None:
init_position_value = cls_or_self.resolve_shortcut_attr(
"init_position_value",
jitted=jitted,
wrapper=wrapper,
group_by=group_by,
)
if asset_value is None:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_flow is None:
cash_flow = cls_or_self.resolve_shortcut_attr(
"cash_flow",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_position_value, arg_name="init_position_value")
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(cash_flow, arg_name="cash_flow")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.asset_pnl_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
asset_pnl = func(
to_2d_array(asset_value),
to_2d_array(cash_flow),
init_position_value=to_1d_array(init_position_value),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(asset_pnl, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_asset_returns(
cls_or_self,
init_position_value: tp.Optional[tp.MaybeSeries] = None,
asset_value: tp.Optional[tp.SeriesFrame] = None,
cash_flow: tp.Optional[tp.SeriesFrame] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get asset return series per column or group.
This type of returns is based solely on cash flows and asset value rather than portfolio
value. It ignores passive cash and thus it will return the same numbers irrespective of the amount of
cash currently available, even `np.inf`. The scale of returns is comparable to that of going
all in and keeping available cash at zero."""
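# Editor's sketch of the practical difference (assumption: `pf` is an existing `Portfolio`
# instance). Portfolio returns are diluted by idle cash, asset returns are not:
# >>> pf.get_returns().vbt.returns.total()        # affected by the cash balance
# >>> pf.get_asset_returns().vbt.returns.total()  # driven by cash flows and asset value only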
if not isinstance(cls_or_self, type):
if init_position_value is None:
init_position_value = cls_or_self.resolve_shortcut_attr(
"init_position_value",
jitted=jitted,
wrapper=wrapper,
group_by=group_by,
)
if asset_value is None:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_flow is None:
cash_flow = cls_or_self.resolve_shortcut_attr(
"cash_flow",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_position_value, arg_name="init_position_value")
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(cash_flow, arg_name="cash_flow")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.asset_returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
asset_returns = func(
to_2d_array(asset_value),
to_2d_array(cash_flow),
init_position_value=to_1d_array(init_position_value),
log_returns=log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
asset_returns = wrapper.wrap(asset_returns, group_by=group_by, **resolve_dict(wrap_kwargs))
if daily_returns:
asset_returns = asset_returns.vbt.returns(log_returns=log_returns).daily(jitted=jitted)
return asset_returns
@hybrid_method
def get_market_value(
cls_or_self,
close: tp.Optional[tp.SeriesFrame] = None,
init_value: tp.Optional[tp.MaybeSeries] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get market value series per column or group.
If grouped, evenly distributes the initial cash among assets in the group.
!!! note
Does not take into account fees and slippage. To incorporate them, create a separate portfolio."""
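# Editor's sketch (assumption: `pf` is an existing grouped `Portfolio` instance): the market
# value simulates buying and holding `close` with the initial cash, split evenly across the
# assets of each group, ignoring fees and slippage:
# >>> pf.get_market_value(group_by=True)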
if not isinstance(cls_or_self, type):
if close is None:
if cls_or_self.fillna_close:
close = cls_or_self.filled_close
else:
close = cls_or_self.close
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(close, arg_name="close")
checks.assert_not_none(init_value, arg_name="init_value")
if cash_deposits is None:
cash_deposits = 0.0
checks.assert_not_none(wrapper, arg_name="wrapper")
if wrapper.grouper.is_grouped(group_by=group_by):
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
split_shared=True,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
split_shared=True,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
group_lens = wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.market_value_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
market_value = func(
to_2d_array(close),
group_lens,
to_1d_array(init_value),
cash_deposits=to_2d_array(cash_deposits),
sim_start=sim_start,
sim_end=sim_end,
)
else:
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=False,
)
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.market_value_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
market_value = func(
to_2d_array(close),
to_1d_array(init_value),
cash_deposits=to_2d_array(cash_deposits),
sim_start=sim_start,
sim_end=sim_end,
)
return wrapper.wrap(market_value, group_by=group_by, **resolve_dict(wrap_kwargs))
@hybrid_method
def get_market_returns(
cls_or_self,
init_value: tp.Optional[tp.MaybeSeries] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
cash_deposits_as_input: tp.Optional[bool] = None,
market_value: tp.Optional[tp.SeriesFrame] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get market return series per column or group."""
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_deposits is None:
cash_deposits = cls_or_self.resolve_shortcut_attr(
"cash_deposits",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
keep_flex=True,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if cash_deposits_as_input is None:
cash_deposits_as_input = cls_or_self.cash_deposits_as_input
if market_value is None:
market_value = cls_or_self.resolve_shortcut_attr(
"market_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_value, arg_name="init_value")
if cash_deposits is None:
cash_deposits = 0.0
if cash_deposits_as_input is None:
cash_deposits_as_input = False
checks.assert_not_none(market_value, arg_name="market_value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
market_returns = func(
to_2d_array(market_value),
to_1d_array(init_value),
cash_deposits=to_2d_array(cash_deposits),
cash_deposits_as_input=cash_deposits_as_input,
log_returns=log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
market_returns = wrapper.wrap(market_returns, group_by=group_by, **resolve_dict(wrap_kwargs))
if daily_returns:
market_returns = market_returns.vbt.returns(log_returns=log_returns).daily(jitted=jitted)
return market_returns
@hybrid_method
def get_total_market_return(
cls_or_self,
input_value: tp.Optional[tp.MaybeSeries] = None,
market_value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Get total market return."""
if not isinstance(cls_or_self, type):
if input_value is None:
input_value = cls_or_self.resolve_shortcut_attr(
"input_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if market_value is None:
market_value = cls_or_self.resolve_shortcut_attr(
"market_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(input_value, arg_name="input_value")
checks.assert_not_none(market_value, arg_name="market_value")
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
func = jit_reg.resolve_option(nb.total_market_return_nb, jitted)
total_market_return = func(
to_2d_array(market_value),
to_1d_array(input_value),
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="total_market_return"), wrap_kwargs)
return wrapper.wrap_reduced(total_market_return, group_by=group_by, **wrap_kwargs)
@hybrid_method
def get_bm_value(
cls_or_self,
bm_close: tp.Optional[tp.ArrayLike] = None,
init_value: tp.Optional[tp.MaybeSeries] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.SeriesFrame]:
"""Get benchmark value series per column or group.
Based on `Portfolio.bm_close` and `Portfolio.get_market_value`."""
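# Editor's note on the branches below: `bm_close=False` disables benchmarking (returns None),
# `bm_close=True` falls back to `close`, and any other array is used as the benchmark price.
# Sketch, assuming `pf` is an existing `Portfolio` instance with `bm_close` set:
# >>> pf.get_bm_value()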
if not isinstance(cls_or_self, type):
if bm_close is None:
bm_close = cls_or_self.bm_close
if isinstance(bm_close, bool):
if not bm_close:
return None
bm_close = None
if bm_close is not None:
if cls_or_self.fillna_close:
bm_close = cls_or_self.filled_bm_close
return cls_or_self.get_market_value(
close=bm_close,
init_value=init_value,
cash_deposits=cash_deposits,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
@hybrid_method
def get_bm_returns(
cls_or_self,
init_value: tp.Optional[tp.MaybeSeries] = None,
cash_deposits: tp.Optional[tp.ArrayLike] = None,
bm_value: tp.Optional[tp.SeriesFrame] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Optional[tp.SeriesFrame]:
"""Get benchmark return series per column or group.
Based on `Portfolio.bm_close` and `Portfolio.get_market_returns`."""
if not isinstance(cls_or_self, type):
if bm_value is None:
bm_value = cls_or_self.resolve_shortcut_attr(
"bm_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if bm_value is None:
return None
return cls_or_self.get_market_returns(
init_value=init_value,
cash_deposits=cash_deposits,
market_value=bm_value,
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
wrap_kwargs=wrap_kwargs,
)
@hybrid_method
def get_returns_acc(
cls_or_self,
returns: tp.Optional[tp.SeriesFrame] = None,
use_asset_returns: bool = False,
bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
freq: tp.Optional[tp.FrequencyLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> ReturnsAccessor:
"""Get returns accessor of type `vectorbtpro.returns.accessors.ReturnsAccessor`.
!!! hint
You can find most methods of this accessor as (cacheable) attributes of this portfolio."""
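# Editor's sketch (assumption: `pf` is an existing `Portfolio` instance). The accessor exposes
# the return metrics that the stats config below references via "returns_acc.*":
# >>> pf.returns_acc.sharpe_ratio()
# >>> pf.get_returns_acc(use_asset_returns=True).sortino_ratio()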
if not isinstance(cls_or_self, type):
if returns is None:
if use_asset_returns:
returns = cls_or_self.resolve_shortcut_attr(
"asset_returns",
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
else:
returns = cls_or_self.resolve_shortcut_attr(
"returns",
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
if bm_returns is None or (isinstance(bm_returns, bool) and bm_returns):
bm_returns = cls_or_self.resolve_shortcut_attr(
"bm_returns",
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
elif isinstance(bm_returns, bool) and not bm_returns:
bm_returns = None
if freq is None:
freq = cls_or_self.wrapper.freq
if year_freq is None:
year_freq = cls_or_self.year_freq
defaults = merge_dicts(cls_or_self.returns_acc_defaults, defaults)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(returns, arg_name="returns")
sim_start = cls_or_self.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=group_by)
sim_end = cls_or_self.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=group_by)
if daily_returns:
freq = "D"
if wrapper is not None:
wrapper = wrapper.resolve(group_by=group_by)
return returns.vbt.returns(
wrapper=wrapper,
bm_returns=bm_returns,
log_returns=log_returns,
freq=freq,
year_freq=year_freq,
defaults=defaults,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
@property
def returns_acc(self) -> ReturnsAccessor:
"""`Portfolio.get_returns_acc` with default arguments."""
return self.get_returns_acc()
@hybrid_method
def get_qs(
cls_or_self,
returns: tp.Optional[tp.SeriesFrame] = None,
use_asset_returns: bool = False,
bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
freq: tp.Optional[tp.FrequencyLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> QSAdapterT:
"""Get quantstats adapter of type `vectorbtpro.returns.qs_adapter.QSAdapter`.
`**kwargs` are passed to the adapter constructor."""
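# Editor's sketch (assumptions: quantstats is installed, `pf` is an existing `Portfolio`
# instance, and the adapter forwards quantstats stats functions such as `sharpe`):
# >>> pf.qs.sharpe()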
from vectorbtpro.returns.qs_adapter import QSAdapter
returns_acc = cls_or_self.get_returns_acc(
returns=returns,
use_asset_returns=use_asset_returns,
bm_returns=bm_returns,
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
freq=freq,
year_freq=year_freq,
defaults=defaults,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
return QSAdapter(returns_acc, **kwargs)
@property
def qs(self) -> QSAdapterT:
"""`Portfolio.get_qs` with default arguments."""
return self.get_qs()
# ############# Resolution ############# #
@property
def self_aliases(self) -> tp.Set[str]:
"""Names to associate with this object."""
return {"self", "portfolio", "pf"}
def pre_resolve_attr(self, attr: str, final_kwargs: tp.KwargsLike = None) -> str:
"""Pre-process an attribute before resolution.
Uses the following keys:
* `use_asset_returns`: Whether to use `Portfolio.get_asset_returns` when resolving `returns` argument.
* `trades_type`: Which trade type to use when resolving `trades` argument."""
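# Editor's illustration of the remapping below (descriptive only): with
# `use_asset_returns=True`, "returns" resolves to "asset_returns"; with a `trades_type`
# differing from the portfolio's default, "trades" resolves to "entry_trades",
# "exit_trades", or "positions".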
if "use_asset_returns" in final_kwargs:
if attr == "returns" and final_kwargs["use_asset_returns"]:
attr = "asset_returns"
if "trades_type" in final_kwargs:
trades_type = final_kwargs["trades_type"]
if isinstance(final_kwargs["trades_type"], str):
trades_type = map_enum_fields(trades_type, enums.TradesType)
if attr == "trades" and trades_type != self.trades_type:
if trades_type == enums.TradesType.EntryTrades:
attr = "entry_trades"
elif trades_type == enums.TradesType.ExitTrades:
attr = "exit_trades"
else:
attr = "positions"
return attr
def post_resolve_attr(self, attr: str, out: tp.Any, final_kwargs: tp.KwargsLike = None) -> tp.Any:
"""Post-process an object after resolution.
Uses the following keys:
* `incl_open`: Whether to include open trades/positions when resolving an argument
that is an instance of `vectorbtpro.portfolio.trades.Trades`."""
if "incl_open" in final_kwargs:
if isinstance(out, Trades) and not final_kwargs["incl_open"]:
out = out.status_closed
return out
def resolve_shortcut_attr(self, attr_name: str, *args, **kwargs) -> tp.Any:
"""Resolve an attribute that may have shortcut properties.
If `attr_name` has a prefix `get_`, checks whether the respective shortcut property can be called.
This way, complex call hierarchies can utilize cacheable properties."""
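# Editor's sketch (assumption: `pf` is an existing `Portfolio` instance). Both calls resolve
# the same attribute, but the first can short-circuit to the cacheable `pf.value` property
# when no overriding arguments are passed:
# >>> pf.resolve_shortcut_attr("value")
# >>> pf.get_value()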
if not attr_name.startswith("get_"):
if "get_" + attr_name not in self.cls_dir or (len(args) == 0 and len(kwargs) == 0):
if isinstance(getattr(type(self), attr_name), property):
return getattr(self, attr_name)
return getattr(self, attr_name)(*args, **kwargs)
attr_name = "get_" + attr_name
if len(args) == 0:
naked_attr_name = attr_name[4:]
prop_name = naked_attr_name
_kwargs = dict(kwargs)
if "free" in _kwargs:
if _kwargs.pop("free"):
prop_name = "free_" + naked_attr_name
if "direction" in _kwargs:
direction = map_enum_fields(_kwargs.pop("direction"), enums.Direction)
if direction == enums.Direction.LongOnly:
prop_name = "long_" + naked_attr_name
elif direction == enums.Direction.ShortOnly:
prop_name = "short_" + naked_attr_name
if prop_name in self.cls_dir:
prop = getattr(type(self), prop_name)
options = getattr(prop, "options", {})
can_call_prop = True
if "group_by" in _kwargs:
group_by = _kwargs.pop("group_by")
group_aware = options.get("group_aware", True)
if group_aware:
if self.wrapper.grouper.is_grouping_modified(group_by=group_by):
can_call_prop = False
else:
group_by = _kwargs.pop("group_by")
if self.wrapper.grouper.is_grouping_enabled(group_by=group_by):
can_call_prop = False
if can_call_prop:
_kwargs.pop("jitted", None)
_kwargs.pop("chunked", None)
for k, v in get_func_kwargs(getattr(type(self), attr_name)).items():
if k in _kwargs and v is not _kwargs.pop(k):
can_call_prop = False
break
if can_call_prop:
if len(_kwargs) > 0:
can_call_prop = False
if can_call_prop:
return getattr(self, prop_name)
return getattr(self, attr_name)(*args, **kwargs)
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Portfolio.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.portfolio`."""
from vectorbtpro._settings import settings
portfolio_stats_cfg = settings["portfolio"]["stats"]
return merge_dicts(
Analyzable.stats_defaults.__get__(self),
dict(settings=dict(trades_type=self.trades_type)),
portfolio_stats_cfg,
)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func="sim_start_index",
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func="sim_end_index",
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func="sim_duration",
apply_to_timedelta=True,
tags="wrapper",
),
start_value=dict(
title="Start Value",
calc_func="init_value",
tags="portfolio",
),
min_value=dict(
title="Min Value",
calc_func="value.vbt.min",
tags="portfolio",
),
max_value=dict(
title="Max Value",
calc_func="value.vbt.max",
tags="portfolio",
),
end_value=dict(
title="End Value",
calc_func="final_value",
tags="portfolio",
),
cash_deposits=dict(
title="Total Cash Deposits",
calc_func="total_cash_deposits",
check_has_cash_deposits=True,
tags="portfolio",
),
cash_earnings=dict(
title="Total Cash Earnings",
calc_func="total_cash_earnings",
check_has_cash_earnings=True,
tags="portfolio",
),
total_return=dict(
title="Total Return [%]",
calc_func="total_return",
post_calc_func=lambda self, out, settings: out * 100,
tags="portfolio",
),
bm_return=dict(
title="Benchmark Return [%]",
calc_func="bm_returns.vbt.returns.total",
post_calc_func=lambda self, out, settings: out * 100,
check_has_bm_returns=True,
tags="portfolio",
),
total_time_exposure=dict(
title="Position Coverage [%]",
calc_func="position_coverage",
post_calc_func=lambda self, out, settings: out * 100,
tags="portfolio",
),
max_gross_exposure=dict(
title="Max Gross Exposure [%]",
calc_func="gross_exposure.vbt.max",
post_calc_func=lambda self, out, settings: out * 100,
tags="portfolio",
),
max_dd=dict(
title="Max Drawdown [%]",
calc_func="drawdowns.max_drawdown",
post_calc_func=lambda self, out, settings: -out * 100,
tags=["portfolio", "drawdowns"],
),
max_dd_duration=dict(
title="Max Drawdown Duration",
calc_func="drawdowns.max_duration",
fill_wrap_kwargs=True,
tags=["portfolio", "drawdowns", "duration"],
),
total_orders=dict(
title="Total Orders",
calc_func="orders.count",
tags=["portfolio", "orders"],
),
total_fees_paid=dict(
title="Total Fees Paid",
calc_func="orders.fees.sum",
tags=["portfolio", "orders"],
),
total_trades=dict(
title="Total Trades",
calc_func="trades.count",
incl_open=True,
tags=["portfolio", "trades"],
),
win_rate=dict(
title="Win Rate [%]",
calc_func="trades.win_rate",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['portfolio', 'trades', *incl_open_tags]"),
),
best_trade=dict(
title="Best Trade [%]",
calc_func="trades.returns.max",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['portfolio', 'trades', *incl_open_tags]"),
),
worst_trade=dict(
title="Worst Trade [%]",
calc_func="trades.returns.min",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['portfolio', 'trades', *incl_open_tags]"),
),
avg_winning_trade=dict(
title="Avg Winning Trade [%]",
calc_func="trades.winning.returns.mean",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'winning']"),
),
avg_losing_trade=dict(
title="Avg Losing Trade [%]",
calc_func="trades.losing.returns.mean",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'losing']"),
),
avg_winning_trade_duration=dict(
title="Avg Winning Trade Duration",
calc_func="trades.winning.duration.mean",
apply_to_timedelta=True,
tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'winning', 'duration']"),
),
avg_losing_trade_duration=dict(
title="Avg Losing Trade Duration",
calc_func="trades.losing.duration.mean",
apply_to_timedelta=True,
tags=RepEval("['portfolio', 'trades', *incl_open_tags, 'losing', 'duration']"),
),
profit_factor=dict(
title="Profit Factor",
calc_func="trades.profit_factor",
tags=RepEval("['portfolio', 'trades', *incl_open_tags]"),
),
expectancy=dict(
title="Expectancy",
calc_func="trades.expectancy",
tags=RepEval("['portfolio', 'trades', *incl_open_tags]"),
),
sharpe_ratio=dict(
title="Sharpe Ratio",
calc_func="returns_acc.sharpe_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags=["portfolio", "returns"],
),
calmar_ratio=dict(
title="Calmar Ratio",
calc_func="returns_acc.calmar_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags=["portfolio", "returns"],
),
omega_ratio=dict(
title="Omega Ratio",
calc_func="returns_acc.omega_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags=["portfolio", "returns"],
),
sortino_ratio=dict(
title="Sortino Ratio",
calc_func="returns_acc.sortino_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags=["portfolio", "returns"],
),
)
)
@property
def metrics(self) -> Config:
"""Metrics supported by `Portfolio.stats`."""
return self._metrics
def returns_stats(
self,
use_asset_returns: bool = False,
bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
log_returns: bool = False,
daily_returns: bool = False,
freq: tp.Optional[tp.FrequencyLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
chunked: tp.ChunkedOption = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Compute various statistics on returns of this portfolio.
See `Portfolio.returns_acc` and `vectorbtpro.returns.accessors.ReturnsAccessor.metrics`.
`kwargs` will be passed to `vectorbtpro.returns.accessors.ReturnsAccessor.stats` method.
If `bm_returns` is not set, uses `Portfolio.get_market_returns`."""
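# Editor's sketch (assumption: `pf` is an existing `Portfolio` instance): equivalent to
# building the returns accessor and calling its `stats` method:
# >>> pf.returns_stats(use_asset_returns=True, year_freq="365 days")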
returns_acc = self.get_returns_acc(
use_asset_returns=use_asset_returns,
bm_returns=bm_returns,
log_returns=log_returns,
daily_returns=daily_returns,
freq=freq,
year_freq=year_freq,
defaults=defaults,
chunked=chunked,
group_by=group_by,
)
return returns_acc.stats(**kwargs)
# ############# Plotting ############# #
@hybrid_method
def plot_orders(
cls_or_self,
column: tp.Optional[tp.Label] = None,
orders: tp.Optional[Orders] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column of orders.
`**kwargs` are passed to `vectorbtpro.portfolio.orders.Orders.plot`."""
if not isinstance(cls_or_self, type):
if orders is None:
orders = cls_or_self.resolve_shortcut_attr(
"orders",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
)
else:
checks.assert_not_none(orders, arg_name="orders")
fig = orders.plot(column=column, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_trades(
cls_or_self,
column: tp.Optional[tp.Label] = None,
trades: tp.Optional[Trades] = None,
trades_type: tp.Optional[tp.Union[str, int]] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
xref: str = "x",
yref: str = "y",
**kwargs,
) -> tp.BaseFigure:
"""Plot one column of trades.
`**kwargs` are passed to `vectorbtpro.portfolio.trades.Trades.plot`."""
if not isinstance(cls_or_self, type):
if trades is None:
trades = cls_or_self.resolve_shortcut_attr(
"trades",
trades_type=trades_type,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
)
else:
checks.assert_not_none(trades, arg_name="trades")
fig = trades.plot(column=column, xref=xref, yref=yref, **kwargs)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_trade_pnl(
cls_or_self,
column: tp.Optional[tp.Label] = None,
trades: tp.Optional[Trades] = None,
trades_type: tp.Optional[tp.Union[str, int]] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
pct_scale: bool = False,
xref: str = "x",
yref: str = "y",
**kwargs,
) -> tp.BaseFigure:
"""Plot one column of trade P&L.
`**kwargs` are passed to `vectorbtpro.portfolio.trades.Trades.plot_pnl`."""
if not isinstance(cls_or_self, type):
if trades is None:
trades = cls_or_self.resolve_shortcut_attr(
"trades",
trades_type=trades_type,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
)
else:
checks.assert_not_none(trades, arg_name="trades")
fig = trades.plot_pnl(column=column, pct_scale=pct_scale, xref=xref, yref=yref, **kwargs)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_trade_signals(
cls_or_self,
column: tp.Optional[tp.Label] = None,
entry_trades: tp.Optional[EntryTrades] = None,
exit_trades: tp.Optional[ExitTrades] = None,
positions: tp.Optional[Positions] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
plot_positions: tp.Union[bool, str] = "zones",
long_entry_trace_kwargs: tp.KwargsLike = None,
short_entry_trace_kwargs: tp.KwargsLike = None,
long_exit_trace_kwargs: tp.KwargsLike = None,
short_exit_trace_kwargs: tp.KwargsLike = None,
long_shape_kwargs: tp.KwargsLike = None,
short_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of trade signals.
Markers and shapes are colored by trade direction (green = long, red = short)."""
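# Editor's sketch (assumptions: `pf` is an existing `Portfolio` instance and "BTC-USD" is a
# hypothetical column label): positions can be drawn as shaded zones (default), as dotted
# entry-to-exit lines, or skipped entirely via `plot_positions=False`:
# >>> pf.plot_trade_signals(column="BTC-USD", plot_positions="lines")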
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if entry_trades is None:
entry_trades = cls_or_self.resolve_shortcut_attr(
"entry_trades",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=False,
)
if exit_trades is None:
exit_trades = cls_or_self.resolve_shortcut_attr(
"exit_trades",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=False,
)
if positions is None:
positions = cls_or_self.resolve_shortcut_attr(
"positions",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=False,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(entry_trades, arg_name="entry_trades")
checks.assert_not_none(exit_trades, arg_name="exit_trades")
checks.assert_not_none(positions, arg_name="positions")
if wrapper is None:
wrapper = entry_trades.wrapper
fig = entry_trades.plot_signals(
column=column,
long_entry_trace_kwargs=long_entry_trace_kwargs,
short_entry_trace_kwargs=short_entry_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**kwargs,
)
fig = exit_trades.plot_signals(
column=column,
plot_ohlc=False,
plot_close=False,
long_exit_trace_kwargs=long_exit_trace_kwargs,
short_exit_trace_kwargs=short_exit_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
if isinstance(plot_positions, bool):
if plot_positions:
plot_positions = "zones"
else:
plot_positions = None
if plot_positions is not None:
if plot_positions.lower() == "zones":
long_shape_kwargs = merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["green"]),
long_shape_kwargs,
)
short_shape_kwargs = merge_dicts(
dict(fillcolor=plotting_cfg["contrast_color_schema"]["red"]),
short_shape_kwargs,
)
elif plot_positions.lower() == "lines":
base_shape_kwargs = dict(
type="line",
line=dict(dash="dot"),
xref=Rep("xref"),
yref=Rep("yref"),
x0=Rep("start_index"),
x1=Rep("end_index"),
y0=RepFunc(lambda record: record["entry_price"]),
y1=RepFunc(lambda record: record["exit_price"]),
opacity=0.75,
)
long_shape_kwargs = atomic_dict(
merge_dicts(
base_shape_kwargs,
dict(line=dict(color=plotting_cfg["contrast_color_schema"]["green"])),
long_shape_kwargs,
)
)
short_shape_kwargs = atomic_dict(
merge_dicts(
base_shape_kwargs,
dict(line=dict(color=plotting_cfg["contrast_color_schema"]["red"])),
short_shape_kwargs,
)
)
else:
raise ValueError(f"Invalid plot_positions: '{plot_positions}'")
fig = positions.direction_long.plot_shapes(
column=column,
plot_ohlc=False,
plot_close=False,
shape_kwargs=long_shape_kwargs,
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
fig = positions.direction_short.plot_shapes(
column=column,
plot_ohlc=False,
plot_close=False,
shape_kwargs=short_shape_kwargs,
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_cash_flow(
cls_or_self,
column: tp.Optional[tp.Label] = None,
free: bool = False,
cash_flow: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of cash flow.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if cash_flow is None:
cash_flow = cls_or_self.resolve_shortcut_attr(
"cash_flow",
free=free,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(cash_flow, arg_name="cash_flow")
checks.assert_not_none(wrapper, arg_name="wrapper")
cash_flow = wrapper.select_col_from_obj(cash_flow, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["green"],
shape=line_shape,
),
name="Cash",
)
),
kwargs,
)
fig = cash_flow.vbt.lineplot(**kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0.0,
x1=x_domain[1],
y1=0.0,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_cash(
cls_or_self,
column: tp.Optional[tp.Label] = None,
free: bool = False,
init_cash: tp.Optional[tp.MaybeSeries] = None,
cash: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of cash balance.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if init_cash is None:
init_cash = cls_or_self.resolve_shortcut_attr(
"init_cash",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if cash is None:
cash = cls_or_self.resolve_shortcut_attr(
"cash",
free=free,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_cash, arg_name="init_cash")
checks.assert_not_none(cash, arg_name="cash")
checks.assert_not_none(wrapper, arg_name="wrapper")
init_cash = wrapper.select_col_from_obj(init_cash, column=column, group_by=group_by)
cash = wrapper.select_col_from_obj(cash, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["green"],
shape=line_shape,
),
name="Cash",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["green"], 0.3),
line=dict(shape=line_shape),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3),
line=dict(shape=line_shape),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = cash.vbt.plot_against(init_cash, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=init_cash,
x1=x_domain[1],
y1=init_cash,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_asset_flow(
cls_or_self,
column: tp.Optional[tp.Label] = None,
direction: tp.Union[str, int] = "both",
asset_flow: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column of asset flow.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if asset_flow is None:
asset_flow = cls_or_self.resolve_shortcut_attr(
"asset_flow",
direction=direction,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(asset_flow, arg_name="asset_flow")
checks.assert_not_none(wrapper, arg_name="wrapper")
asset_flow = wrapper.select_col_from_obj(asset_flow, column=column, group_by=False)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["blue"],
shape=line_shape,
),
name="Assets",
)
),
kwargs,
)
fig = asset_flow.vbt.lineplot(**kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0,
x1=x_domain[1],
y1=0,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_assets(
cls_or_self,
column: tp.Optional[tp.Label] = None,
direction: tp.Union[str, int] = "both",
assets: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column of assets.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if assets is None:
assets = cls_or_self.resolve_shortcut_attr(
"assets",
direction=direction,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(assets, arg_name="assets")
checks.assert_not_none(wrapper, arg_name="wrapper")
assets = wrapper.select_col_from_obj(assets, column=column, group_by=False)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["blue"],
shape=line_shape,
),
name="Assets",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["blue"], 0.3),
line=dict(shape=line_shape),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3),
line=dict(shape=line_shape),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = assets.vbt.plot_against(0, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0.0,
x1=x_domain[1],
y1=0.0,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=False,
xref=xref,
)
return fig
@hybrid_method
def plot_asset_value(
cls_or_self,
column: tp.Optional[tp.Label] = None,
direction: tp.Union[str, int] = "both",
asset_value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of asset value.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if asset_value is None:
asset_value = cls_or_self.resolve_shortcut_attr(
"asset_value",
direction=direction,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(asset_value, arg_name="asset_value")
checks.assert_not_none(wrapper, arg_name="wrapper")
asset_value = wrapper.select_col_from_obj(asset_value, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["purple"],
shape=line_shape,
),
name="Value",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["purple"], 0.3),
line=dict(shape=line_shape),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3),
line=dict(shape=line_shape),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = asset_value.vbt.plot_against(0, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0.0,
x1=x_domain[1],
y1=0.0,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_value(
cls_or_self,
column: tp.Optional[tp.Label] = None,
init_value: tp.Optional[tp.MaybeSeries] = None,
value: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of value.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if value is None:
value = cls_or_self.resolve_shortcut_attr(
"value",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_value, arg_name="init_value")
checks.assert_not_none(value, arg_name="value")
checks.assert_not_none(wrapper, arg_name="wrapper")
init_value = wrapper.select_col_from_obj(init_value, column=column, group_by=group_by)
value = wrapper.select_col_from_obj(value, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["purple"],
),
name="Value",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["purple"], 0.3),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["red"], 0.3),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = value.vbt.plot_against(init_value, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=init_value,
x1=x_domain[1],
y1=init_value,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_cumulative_returns(
cls_or_self,
column: tp.Optional[tp.Label] = None,
returns_acc: tp.Optional[ReturnsAccessor] = None,
use_asset_returns: bool = False,
bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
pct_scale: bool = False,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of cumulative returns.
If `bm_returns` is None, will use `Portfolio.get_market_returns`.
`**kwargs` are passed to `vectorbtpro.returns.accessors.ReturnsSRAccessor.plot_cumulative`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if returns_acc is None:
returns_acc = cls_or_self.resolve_shortcut_attr(
"returns_acc",
use_asset_returns=use_asset_returns,
bm_returns=bm_returns,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
else:
checks.assert_not_none(returns_acc, arg_name="returns_acc")
kwargs = merge_dicts(
dict(
main_kwargs=dict(
trace_kwargs=dict(name="Value"),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["purple"], 0.3),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["red"], 0.3),
),
),
hline_shape_kwargs=dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
),
),
kwargs,
)
fig = returns_acc.plot_cumulative(
column=column,
fit_sim_range=fit_sim_range,
pct_scale=pct_scale,
**kwargs,
)
return fig
@hybrid_method
def plot_drawdowns(
cls_or_self,
column: tp.Optional[tp.Label] = None,
drawdowns: tp.Optional[Drawdowns] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
xref: str = "x",
yref: str = "y",
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of drawdowns.
`**kwargs` are passed to `vectorbtpro.generic.drawdowns.Drawdowns.plot`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if drawdowns is None:
drawdowns = cls_or_self.resolve_shortcut_attr(
"drawdowns",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
else:
checks.assert_not_none(drawdowns, arg_name="drawdowns")
kwargs = merge_dicts(
dict(
close_trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["purple"],
),
name="Value",
),
),
kwargs,
)
fig = drawdowns.plot(column=column, xref=xref, yref=yref, **kwargs)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_underwater(
cls_or_self,
column: tp.Optional[tp.Label] = None,
init_value: tp.Optional[tp.MaybeSeries] = None,
returns_acc: tp.Optional[ReturnsAccessor] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
pct_scale: bool = True,
xref: str = "x",
yref: str = "y",
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of underwater.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericAccessor.plot`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if init_value is None:
init_value = cls_or_self.resolve_shortcut_attr(
"init_value",
sim_start=sim_start if rec_sim_range else None,
sim_end=sim_end if rec_sim_range else None,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if returns_acc is None:
returns_acc = cls_or_self.resolve_shortcut_attr(
"returns_acc",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(init_value, arg_name="init_value")
checks.assert_not_none(returns_acc, arg_name="returns_acc")
checks.assert_not_none(wrapper, arg_name="wrapper")
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
drawdown = returns_acc.drawdown()
drawdown = wrapper.select_col_from_obj(drawdown, column=column, group_by=group_by)
if not pct_scale:
cumulative_returns = returns_acc.cumulative()
cumret = wrapper.select_col_from_obj(cumulative_returns, column=column, group_by=group_by)
init_value = wrapper.select_col_from_obj(init_value, column=column, group_by=group_by)
drawdown = cumret * init_value * drawdown / (1 + drawdown)
default_kwargs = dict(
trace_kwargs=dict(
line=dict(color=plotting_cfg["color_schema"]["red"]),
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["red"], 0.3),
fill="tozeroy",
name="Drawdown",
)
)
if pct_scale:
yaxis = "yaxis" + yref[1:]
default_kwargs[yaxis] = dict(tickformat=".2%")
kwargs = merge_dicts(default_kwargs, kwargs)
fig = drawdown.vbt.lineplot(**kwargs)
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0,
x1=x_domain[1],
y1=0,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_gross_exposure(
cls_or_self,
column: tp.Optional[tp.Label] = None,
direction: tp.Union[str, int] = "both",
gross_exposure: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of gross exposure.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if gross_exposure is None:
gross_exposure = cls_or_self.resolve_shortcut_attr(
"gross_exposure",
direction=direction,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(gross_exposure, arg_name="gross_exposure")
checks.assert_not_none(wrapper, arg_name="wrapper")
gross_exposure = wrapper.select_col_from_obj(gross_exposure, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["pink"],
shape=line_shape,
),
name="Exposure",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3),
line=dict(shape=line_shape),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.3),
line=dict(shape=line_shape),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = gross_exposure.vbt.plot_against(1, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=1,
x1=x_domain[1],
y1=1,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_net_exposure(
cls_or_self,
column: tp.Optional[tp.Label] = None,
net_exposure: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column or group of net exposure.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
if not isinstance(cls_or_self, type):
if net_exposure is None:
net_exposure = cls_or_self.resolve_shortcut_attr(
"net_exposure",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(net_exposure, arg_name="net_exposure")
checks.assert_not_none(wrapper, arg_name="wrapper")
net_exposure = wrapper.select_col_from_obj(net_exposure, column=column, group_by=group_by)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["pink"],
shape=line_shape,
),
name="Exposure",
),
pos_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["orange"], 0.3),
line=dict(shape=line_shape),
),
neg_trace_kwargs=dict(
fillcolor=adjust_opacity(plotting_cfg["color_schema"]["pink"], 0.3),
line=dict(shape=line_shape),
),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = net_exposure.vbt.plot_against(1, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=1,
x1=x_domain[1],
y1=1,
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@hybrid_method
def plot_allocations(
cls_or_self,
column: tp.Optional[tp.Label] = None,
allocations: tp.Optional[tp.SeriesFrame] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
fit_sim_range: bool = True,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
line_shape: str = "hv",
line_visible: bool = True,
colorway: tp.Union[None, str, tp.Sequence[str]] = "Vivid",
xref: tp.Optional[str] = None,
yref: tp.Optional[str] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one group of allocations."""
if not isinstance(cls_or_self, type):
if allocations is None:
allocations = cls_or_self.resolve_shortcut_attr(
"allocations",
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
wrapper=wrapper,
group_by=group_by,
)
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(allocations, arg_name="allocations")
checks.assert_not_none(wrapper, arg_name="wrapper")
allocations = wrapper.select_col_from_obj(allocations, column=column, obj_ungrouped=True, group_by=group_by)
if wrapper.grouper.is_grouped(group_by=group_by):
group_names = wrapper.grouper.get_index(group_by=group_by).names
allocations = allocations.vbt.drop_levels(group_names, strict=False)
if isinstance(allocations, pd.Series) and allocations.name is None:
allocations.name = "Allocation"
fig = allocations.vbt.areaplot(line_shape=line_shape, line_visible=line_visible, colorway=colorway, **kwargs)
if xref is None:
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
if yref is None:
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
if fit_sim_range:
fig = cls_or_self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
wrapper=wrapper,
group_by=group_by,
xref=xref,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Portfolio.plot`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.portfolio`."""
from vectorbtpro._settings import settings
portfolio_plots_cfg = settings["portfolio"]["plots"]
return merge_dicts(
Analyzable.plots_defaults.__get__(self),
dict(settings=dict(trades_type=self.trades_type)),
portfolio_plots_cfg,
)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
orders=dict(
title="Orders",
yaxis_kwargs=dict(title="Price"),
check_is_not_grouped=True,
plot_func="plot_orders",
pass_add_trace_kwargs=True,
tags=["portfolio", "orders"],
),
trades=dict(
title="Trades",
yaxis_kwargs=dict(title="Price"),
check_is_not_grouped=True,
plot_func="plot_trades",
pass_add_trace_kwargs=True,
tags=["portfolio", "trades"],
),
trade_pnl=dict(
title="Trade PnL",
yaxis_kwargs=dict(title="PnL"),
check_is_not_grouped=True,
plot_func="plot_trade_pnl",
pct_scale=True,
pass_add_trace_kwargs=True,
tags=["portfolio", "trades"],
),
trade_signals=dict(
title="Trade Signals",
yaxis_kwargs=dict(title="Price"),
check_is_not_grouped=True,
plot_func="plot_trade_signals",
tags=["portfolio", "trades"],
),
cash_flow=dict(
title="Cash Flow",
yaxis_kwargs=dict(title="Amount"),
plot_func="plot_cash_flow",
pass_add_trace_kwargs=True,
tags=["portfolio", "cash"],
),
cash=dict(
title="Cash",
yaxis_kwargs=dict(title="Amount"),
plot_func="plot_cash",
pass_add_trace_kwargs=True,
tags=["portfolio", "cash"],
),
asset_flow=dict(
title="Asset Flow",
yaxis_kwargs=dict(title="Amount"),
check_is_not_grouped=True,
plot_func="plot_asset_flow",
pass_add_trace_kwargs=True,
tags=["portfolio", "assets"],
),
assets=dict(
title="Assets",
yaxis_kwargs=dict(title="Amount"),
check_is_not_grouped=True,
plot_func="plot_assets",
pass_add_trace_kwargs=True,
tags=["portfolio", "assets"],
),
asset_value=dict(
title="Asset Value",
yaxis_kwargs=dict(title="Value"),
plot_func="plot_asset_value",
pass_add_trace_kwargs=True,
tags=["portfolio", "assets", "value"],
),
value=dict(
title="Value",
yaxis_kwargs=dict(title="Value"),
plot_func="plot_value",
pass_add_trace_kwargs=True,
tags=["portfolio", "value"],
),
cumulative_returns=dict(
title="Cumulative Returns",
yaxis_kwargs=dict(title="Cumulative return"),
plot_func="plot_cumulative_returns",
pass_hline_shape_kwargs=True,
pass_add_trace_kwargs=True,
pass_xref=True,
pass_yref=True,
tags=["portfolio", "returns"],
),
drawdowns=dict(
title="Drawdowns",
yaxis_kwargs=dict(title="Value"),
plot_func="plot_drawdowns",
pass_add_trace_kwargs=True,
pass_xref=True,
pass_yref=True,
tags=["portfolio", "value", "drawdowns"],
),
underwater=dict(
title="Underwater",
yaxis_kwargs=dict(title="Drawdown"),
plot_func="plot_underwater",
pass_add_trace_kwargs=True,
tags=["portfolio", "value", "drawdowns"],
),
gross_exposure=dict(
title="Gross Exposure",
yaxis_kwargs=dict(title="Exposure"),
plot_func="plot_gross_exposure",
pass_add_trace_kwargs=True,
tags=["portfolio", "exposure"],
),
net_exposure=dict(
title="Net Exposure",
yaxis_kwargs=dict(title="Exposure"),
plot_func="plot_net_exposure",
pass_add_trace_kwargs=True,
tags=["portfolio", "exposure"],
),
allocations=dict(
title="Allocations",
yaxis_kwargs=dict(title="Allocation"),
plot_func="plot_allocations",
pass_add_trace_kwargs=True,
tags=["portfolio", "allocations"],
),
)
)
plot = Analyzable.plots
@property
def subplots(self) -> Config:
return self._subplots
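# Example (illustrative, not taken from the package): subplots defined in `_subplots` can be
# selected by name when building a figure via the plots builder, e.g.
# `pf.plots(subplots=["value", "drawdowns"])` on a `Portfolio` instance `pf`.
# The exact call signature is defined by `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`.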
# ############# Docs ############# #
@classmethod
def build_in_output_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build in-output config documentation."""
if source_cls is None:
source_cls = Portfolio
return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "in_output_config").__doc__)).substitute(
{"in_output_config": cls.in_output_config.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_in_output_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `Portfolio.in_output_config`."""
__pdoc__[cls.__name__ + ".in_output_config"] = cls.build_in_output_config_doc(source_cls=source_cls)
Portfolio.override_in_output_config_doc(__pdoc__)
Portfolio.override_metrics_doc(__pdoc__)
Portfolio.override_subplots_doc(__pdoc__)
__pdoc__["Portfolio.plot"] = "See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots`."
PF = Portfolio
"""Shortcut for `Portfolio`."""
__pdoc__["PF"] = False
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Functions for working with call sequence arrays."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.portfolio.enums import CallSeqType
from vectorbtpro.registries.jit_registry import register_jitted
__all__ = []
@register_jitted(cache=True)
def shuffle_call_seq_nb(call_seq: tp.Array2d, group_lens: tp.GroupLens) -> None:
"""Shuffle the call sequence array."""
from_col = 0
for group in range(len(group_lens)):
to_col = from_col + group_lens[group]
for i in range(call_seq.shape[0]):
np.random.shuffle(call_seq[i, from_col:to_col])
from_col = to_col
@register_jitted(cache=True)
def build_call_seq_nb(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
call_seq_type: int = CallSeqType.Default,
) -> tp.Array2d:
"""Build a new call sequence array."""
if call_seq_type == CallSeqType.Reversed:
out = np.full(target_shape[1], 1, dtype=int_)
out[np.cumsum(group_lens)[1:] - group_lens[1:] - 1] -= group_lens[1:]
out = np.cumsum(out[::-1])[::-1] - 1
out = out * np.ones((target_shape[0], 1), dtype=int_)
return out
out = np.full(target_shape[1], 1, dtype=int_)
out[np.cumsum(group_lens)[:-1]] -= group_lens[:-1]
out = np.cumsum(out) - 1
out = out * np.ones((target_shape[0], 1), dtype=int_)
if call_seq_type == CallSeqType.Random:
shuffle_call_seq_nb(out, group_lens)
return out
def require_call_seq(call_seq: tp.Array2d) -> tp.Array2d:
"""Force the call sequence array to pass our requirements."""
return np.require(call_seq, dtype=int_, requirements=["A", "O", "W", "F"])
def build_call_seq(
target_shape: tp.Shape,
group_lens: tp.GroupLens,
call_seq_type: int = CallSeqType.Default,
) -> tp.Array2d:
"""Not compiled but faster version of `build_call_seq_nb`."""
call_seq = np.full(target_shape[1], 1, dtype=int_)
if call_seq_type == CallSeqType.Reversed:
call_seq[np.cumsum(group_lens)[1:] - group_lens[1:] - 1] -= group_lens[1:]
call_seq = np.cumsum(call_seq[::-1])[::-1] - 1
else:
call_seq[np.cumsum(group_lens[:-1])] -= group_lens[:-1]
call_seq = np.cumsum(call_seq) - 1
call_seq = np.broadcast_to(call_seq, target_shape)
if call_seq_type == CallSeqType.Random:
call_seq = require_call_seq(call_seq)
shuffle_call_seq_nb(call_seq, group_lens)
return require_call_seq(call_seq)
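# Example (illustrative sketch): building call sequences for two groups of two columns
# each over three rows. Outputs below were derived by hand from the logic above.
#
# >>> group_lens = np.array([2, 2])
# >>> build_call_seq((3, 4), group_lens, call_seq_type=CallSeqType.Default)
# array([[0, 1, 0, 1],
#        [0, 1, 0, 1],
#        [0, 1, 0, 1]])
# >>> build_call_seq((3, 4), group_lens, call_seq_type=CallSeqType.Reversed)[0]
# array([1, 0, 1, 0])
#
# With `CallSeqType.Random`, each row of each group is shuffled in place via
# `shuffle_call_seq_nb`.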
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Extensions for chunking of portfolio."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.merging import concat_arrays, column_stack_arrays
from vectorbtpro.portfolio.enums import SimulationOutput
from vectorbtpro.records.chunking import merge_records
from vectorbtpro.utils.chunking import ChunkMeta, ArraySlicer
from vectorbtpro.utils.config import ReadonlyConfig
from vectorbtpro.utils.template import Rep
__all__ = []
def get_init_cash_slicer(ann_args: tp.AnnArgs) -> ArraySlicer:
"""Get slicer for `init_cash` based on cash sharing."""
cash_sharing = ann_args["cash_sharing"]["value"]
if cash_sharing:
return base_ch.FlexArraySlicer()
return base_ch.flex_1d_array_gl_slicer
def get_cash_deposits_slicer(ann_args: tp.AnnArgs) -> ArraySlicer:
"""Get slicer for `cash_deposits` based on cash sharing."""
cash_sharing = ann_args["cash_sharing"]["value"]
if cash_sharing:
return base_ch.FlexArraySlicer(axis=1)
return base_ch.flex_array_gl_slicer
def in_outputs_merge_func(
results: tp.List[SimulationOutput],
chunk_meta: tp.Iterable[ChunkMeta],
ann_args: tp.AnnArgs,
mapper: base_ch.GroupLensMapper,
):
"""Merge chunks of in-output objects.
Concatenates 1-dim arrays, stacks columns of 2-dim arrays, and fixes and concatenates record arrays
using `vectorbtpro.records.chunking.merge_records`. Other objects will throw an error."""
in_outputs = dict()
for k, v in results[0].in_outputs._asdict().items():
if v is None:
in_outputs[k] = None
continue
if not isinstance(v, np.ndarray):
raise TypeError(f"Cannot merge in-output object '{k}' of type {type(v)}")
if v.ndim == 2:
in_outputs[k] = column_stack_arrays([getattr(r.in_outputs, k) for r in results])
elif v.ndim == 1:
if v.dtype.fields is None:
in_outputs[k] = np.concatenate([getattr(r.in_outputs, k) for r in results])
else:
records = [getattr(r.in_outputs, k) for r in results]
in_outputs[k] = merge_records(records, chunk_meta, ann_args=ann_args, mapper=mapper)
else:
raise ValueError(f"Cannot merge in-output object '{k}' with number of dimensions {v.ndim}")
return type(results[0].in_outputs)(**in_outputs)
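# Example (illustrative sketch): suppose two chunks each return an in-outputs named tuple
# with a 1-dim array `col_sum` of shape (2,) and a 2-dim array `returns` of shape (5, 2).
# Following the rules above, the merged in-outputs would hold `col_sum` of shape (4,)
# (concatenated) and `returns` of shape (5, 4) (column-stacked), while structured record
# arrays would be routed through `merge_records`. The field names are made up for illustration.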
def merge_sim_outs(
results: tp.List[SimulationOutput],
chunk_meta: tp.Iterable[ChunkMeta],
ann_args: tp.AnnArgs,
mapper: base_ch.GroupLensMapper,
in_outputs_merge_func: tp.Callable = in_outputs_merge_func,
**kwargs,
) -> SimulationOutput:
"""Merge chunks of `vectorbtpro.portfolio.enums.SimulationOutput` instances.
If `SimulationOutput.in_outputs` is not None, must provide `in_outputs_merge_func` or similar."""
order_records = [r.order_records for r in results]
order_records = merge_records(order_records, chunk_meta, ann_args=ann_args, mapper=mapper)
log_records = [r.log_records for r in results]
log_records = merge_records(log_records, chunk_meta, ann_args=ann_args, mapper=mapper)
target_shape = ann_args["target_shape"]["value"]
if results[0].cash_deposits.shape == target_shape:
cash_deposits = column_stack_arrays([r.cash_deposits for r in results])
else:
cash_deposits = results[0].cash_deposits
if results[0].cash_earnings.shape == target_shape:
cash_earnings = column_stack_arrays([r.cash_earnings for r in results])
else:
cash_earnings = results[0].cash_earnings
if results[0].call_seq is not None:
call_seq = column_stack_arrays([r.call_seq for r in results])
else:
call_seq = None
if results[0].in_outputs is not None:
in_outputs = in_outputs_merge_func(results, chunk_meta, ann_args, mapper, **kwargs)
else:
in_outputs = None
if results[0].sim_start is not None:
sim_start = concat_arrays([r.sim_start for r in results])
else:
sim_start = None
if results[0].sim_end is not None:
sim_end = concat_arrays([r.sim_end for r in results])
else:
sim_end = None
return SimulationOutput(
order_records=order_records,
log_records=log_records,
cash_deposits=cash_deposits,
cash_earnings=cash_earnings,
call_seq=call_seq,
in_outputs=in_outputs,
sim_start=sim_start,
sim_end=sim_end,
)
merge_sim_outs_config = ReadonlyConfig(
dict(
merge_func=merge_sim_outs,
merge_kwargs=dict(
chunk_meta=Rep("chunk_meta"),
ann_args=Rep("ann_args"),
mapper=base_ch.group_lens_mapper,
),
)
)
"""Config for merging using `merge_sim_outs`."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for portfolio."""
from vectorbtpro import _typing as tp
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Config, resolve_dict
from vectorbtpro.utils.decorators import cacheable_property, cached_property
from vectorbtpro.utils.parsing import get_func_arg_names
__all__ = []
def attach_returns_acc_methods(config: Config) -> tp.ClassWrapper:
"""Class decorator to attach returns accessor methods.
`config` must contain target method names (keys) and settings (values) with the following keys:
* `source_name`: Name of the source method. Defaults to the target name.
* `docstring`: Method docstring.
The class must be a subclass of `vectorbtpro.portfolio.base.Portfolio`."""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "Portfolio")
for target_name, settings in config.items():
source_name = settings.get("source_name", target_name)
docstring = settings.get("docstring", f"See `vectorbtpro.returns.accessors.ReturnsAccessor.{source_name}`.")
def new_method(
self,
*,
returns: tp.Optional[tp.SeriesFrame] = None,
use_asset_returns: bool = False,
bm_returns: tp.Union[None, bool, tp.ArrayLike] = None,
log_returns: bool = False,
daily_returns: bool = False,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
rec_sim_range: bool = False,
freq: tp.Optional[tp.FrequencyLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
_source_name: str = source_name,
**kwargs,
) -> tp.Any:
returns_acc = self.get_returns_acc(
returns=returns,
use_asset_returns=use_asset_returns,
bm_returns=bm_returns,
log_returns=log_returns,
daily_returns=daily_returns,
sim_start=sim_start,
sim_end=sim_end,
rec_sim_range=rec_sim_range,
freq=freq,
year_freq=year_freq,
defaults=defaults,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
group_by=group_by,
)
ret_method = getattr(returns_acc, _source_name)
if "jitted" in get_func_arg_names(ret_method):
kwargs["jitted"] = jitted
return ret_method(**kwargs)
new_method.__name__ = "get_" + target_name
new_method.__module__ = cls.__module__
new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
new_method.__doc__ = docstring
setattr(cls, new_method.__name__, new_method)
return cls
return wrapper
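# Example (hypothetical sketch): attaching a `get_sharpe_ratio` method that forwards to
# `ReturnsAccessor.sharpe_ratio`. The config contents and subclass name are illustrative
# and not taken from the package.
#
# >>> returns_acc_config = Config(dict(sharpe_ratio=dict(source_name="sharpe_ratio")))
# >>> @attach_returns_acc_methods(returns_acc_config)
# ... class MyPortfolio(Portfolio):
# ...     pass
# >>> # MyPortfolio.get_sharpe_ratio(...) now builds a returns accessor from the given
# >>> # arguments and calls `sharpe_ratio` on it, forwarding `jitted` if supported.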
def attach_shortcut_properties(config: Config) -> tp.ClassWrapper:
"""Class decorator to attach shortcut properties.
`config` must contain target property names (keys) and settings (values) with the following keys:
* `method_name`: Name of the source method. Defaults to the target name prepended with the prefix `get_`.
* `use_in_outputs`: Whether the property can return an in-place output. Defaults to True.
* `method_kwargs`: Keyword arguments passed to the source method. Defaults to None.
* `decorator`: Defaults to `vectorbtpro.utils.decorators.cached_property` for object types
'records' and 'red_array'. Otherwise, to `vectorbtpro.utils.decorators.cacheable_property`.
* `docstring`: Method docstring.
* Other keyword arguments are passed to the decorator and can include settings for wrapping,
indexing, resampling, stacking, etc.
The class must be a subclass of `vectorbtpro.portfolio.base.Portfolio`."""
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "Portfolio")
for target_name, settings in config.items():
settings = dict(settings)
if target_name.startswith("get_"):
raise ValueError(f"Property names cannot have prefix 'get_' ('{target_name}')")
method_name = settings.pop("method_name", "get_" + target_name)
use_in_outputs = settings.pop("use_in_outputs", True)
method_kwargs = settings.pop("method_kwargs", None)
method_kwargs = resolve_dict(method_kwargs)
decorator = settings.pop("decorator", None)
if decorator is None:
if settings.get("obj_type", "array") in ("red_array", "records"):
decorator = cached_property
else:
decorator = cacheable_property
docstring = settings.pop("docstring", None)
if docstring is None:
if len(method_kwargs) == 0:
docstring = f"`{cls.__name__}.{method_name}` with default arguments."
else:
docstring = f"`{cls.__name__}.{method_name}` with arguments `{method_kwargs}`."
def new_prop(
self,
_method_name: tp.Optional[str] = method_name,
_target_name: str = target_name,
_use_in_outputs: bool = use_in_outputs,
_method_kwargs: tp.Kwargs = method_kwargs,
) -> tp.Any:
if _use_in_outputs and self.use_in_outputs and self.in_outputs is not None:
try:
out = self.get_in_output(_target_name)
if out is not None:
return out
except AttributeError:
pass
if _method_name is None:
raise ValueError(f"Field '{_target_name}' must be prefilled")
return getattr(self, _method_name)(**_method_kwargs)
new_prop.__name__ = target_name
new_prop.__module__ = cls.__module__
new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
new_prop.__doc__ = docstring
setattr(cls, new_prop.__name__, decorator(new_prop, **settings))
return cls
return wrapper
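# Example (hypothetical sketch): attaching a cached `total_profit` property that resolves
# to `get_total_profit()` with default arguments. The property name and settings are
# illustrative and not taken from the package.
#
# >>> shortcut_config = Config(dict(total_profit=dict(obj_type="red_array")))
# >>> @attach_shortcut_properties(shortcut_config)
# ... class MyPortfolio(Portfolio):
# ...     pass
# >>> # Accessing `my_pf.total_profit` first checks for an in-place output named
# >>> # "total_profit" and otherwise falls back to `my_pf.get_total_profit()`.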
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for portfolio.
Defines enums and other schemas for `vectorbtpro.portfolio`."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = [
"RejectedOrderError",
"PriceType",
"ValPriceType",
"InitCashMode",
"CallSeqType",
"PendingConflictMode",
"AccumulationMode",
"ConflictMode",
"DirectionConflictMode",
"OppositeEntryMode",
"DeltaFormat",
"TimeDeltaFormat",
"StopLadderMode",
"StopEntryPrice",
"StopExitPrice",
"StopExitType",
"StopUpdateMode",
"SizeType",
"Direction",
"LeverageMode",
"PriceAreaVioMode",
"OrderStatus",
"OrderStatusInfo",
"status_info_desc",
"OrderSide",
"OrderType",
"LimitOrderPrice",
"TradeDirection",
"TradeStatus",
"TradesType",
"OrderPriceStatus",
"PositionFeature",
"PriceArea",
"NoPriceArea",
"AccountState",
"ExecState",
"SimulationOutput",
"SimulationContext",
"GroupContext",
"RowContext",
"SegmentContext",
"OrderContext",
"PostOrderContext",
"FlexOrderContext",
"Order",
"NoOrder",
"OrderResult",
"SignalSegmentContext",
"SignalContext",
"PostSignalContext",
"FSInOutputs",
"FOInOutputs",
"order_fields",
"order_dt",
"fs_order_fields",
"fs_order_dt",
"trade_fields",
"trade_dt",
"log_fields",
"log_dt",
"alloc_range_fields",
"alloc_range_dt",
"alloc_point_fields",
"alloc_point_dt",
"main_info_fields",
"main_info_dt",
"limit_info_fields",
"limit_info_dt",
"sl_info_fields",
"sl_info_dt",
"tsl_info_fields",
"tsl_info_dt",
"tp_info_fields",
"tp_info_dt",
"time_info_fields",
"time_info_dt",
]
__pdoc__ = {}
# ############# Errors ############# #
class RejectedOrderError(Exception):
"""Rejected order error."""
pass
# ############# Enums ############# #
class PriceTypeT(tp.NamedTuple):
Open: int = -np.inf
Close: int = np.inf
NextOpen: int = -1
NextClose: int = -2
NextValidOpen: int = -3
NextValidClose: int = -4
PriceType = PriceTypeT()
"""_"""
__pdoc__[
"PriceType"
] = f"""Price type.
```python
{prettify(PriceType)}
```
Attributes:
Open: Opening price.
Will be substituted by `-np.inf`.
Close: Closing price.
Will be substituted by `np.inf`.
NextOpen: Next opening price.
Will be substituted by `-np.inf` and `from_ago` will be set to 1.
NextClose: Next closing price.
Will be substituted by `np.inf` and `from_ago` will be set to 1.
NextValidOpen: Next valid (non-NA) opening price.
Will be substituted by `-np.inf` and `from_ago` will be set to the distance to the previous valid value.
NextValidClose: Next valid (non-NA) closing price.
Will be substituted by `np.inf` and `from_ago` will be set to the distance to the previous valid value.
"""
class ValPriceTypeT(tp.NamedTuple):
Latest: int = -np.inf
Price: int = np.inf
ValPriceType = ValPriceTypeT()
"""_"""
__pdoc__[
"ValPriceType"
] = f"""Asset valuation price type.
```python
{prettify(ValPriceType)}
```
Attributes:
Latest: Latest price. Will be substituted by `-np.inf`.
Price: Order price. Will be substituted by `np.inf`.
"""
class InitCashModeT(tp.NamedTuple):
Auto: int = -1
AutoAlign: int = -2
InitCashMode = InitCashModeT()
"""_"""
__pdoc__[
"InitCashMode"
] = f"""Initial cash mode.
```python
{prettify(InitCashMode)}
```
Attributes:
Auto: Initial cash is infinite within simulation, and then set to the total cash spent.
AutoAlign: Initial cash is set to the total cash spent across all columns.
"""
class CallSeqTypeT(tp.NamedTuple):
Default: int = 0
Reversed: int = 1
Random: int = 2
Auto: int = 3
CallSeqType = CallSeqTypeT()
"""_"""
__pdoc__[
"CallSeqType"
] = f"""Call sequence type.
```python
{prettify(CallSeqType)}
```
Attributes:
Default: Place calls from left to right.
Reversed: Place calls from right to left.
Random: Place calls randomly.
Auto: Place calls dynamically based on order value.
"""
class PendingConflictModeT(tp.NamedTuple):
KeepIgnore: int = 0
KeepExecute: int = 1
CancelIgnore: int = 2
CancelExecute: int = 3
PendingConflictMode = PendingConflictModeT()
"""_"""
__pdoc__[
"PendingConflictMode"
] = f"""Conflict mode for pending signals.
```python
{prettify(PendingConflictMode)}
```
What should happen if an executable signal occurs during a pending signal?
Attributes:
KeepIgnore: Keep the pending signal and cancel the user-defined signal.
KeepExecute: Keep the pending signal and execute the user-defined signal.
CancelIgnore: Cancel the pending signal and ignore the user-defined signal.
CancelExecute: Cancel the pending signal and execute the user-defined signal.
"""
class AccumulationModeT(tp.NamedTuple):
Disabled: int = 0
Both: int = 1
AddOnly: int = 2
RemoveOnly: int = 3
AccumulationMode = AccumulationModeT()
"""_"""
__pdoc__[
"AccumulationMode"
] = f"""Accumulation mode.
```python
{prettify(AccumulationMode)}
```
Accumulation allows gradually increasing and decreasing positions by a size.
Attributes:
Disabled: Disable accumulation.
Can also be provided as False.
Both: Allow both adding to and removing from the position.
Can also be provided as True.
AddOnly: Allow accumulation to only add to the position.
RemoveOnly: Allow accumulation to only remove from the position.
!!! note
Accumulation acts differently for exits and opposite entries: exits reduce the current position
but won't enter the opposite one, while opposite entries reduce the position by the same amount,
but as soon as this position is closed, they begin to increase the opposite position.
The behavior for opposite entries can be changed by `OppositeEntryMode` and for stop orders by `StopExitType`.
"""
class ConflictModeT(tp.NamedTuple):
Ignore: int = 0
Entry: int = 1
Exit: int = 2
Adjacent: int = 3
Opposite: int = 4
ConflictMode = ConflictModeT()
"""_"""
__pdoc__[
"ConflictMode"
] = f"""Conflict mode.
```python
{prettify(ConflictMode)}
```
What should happen if both an entry signal and an exit signal occur simultaneously?
Attributes:
Ignore: Ignore both signals.
Entry: Execute the entry signal.
Exit: Execute the exit signal.
Adjacent: Execute the signal adjacent to the current position.
Takes effect only when in position, otherwise ignores.
Opposite: Execute the signal opposite to the current position.
Takes effect only when in position, otherwise ignores.
"""
class DirectionConflictModeT(tp.NamedTuple):
Ignore: int = 0
Long: int = 1
Short: int = 2
Adjacent: int = 3
Opposite: int = 4
DirectionConflictMode = DirectionConflictModeT()
"""_"""
__pdoc__[
"DirectionConflictMode"
] = f"""Direction conflict mode.
```python
{prettify(DirectionConflictMode)}
```
What should happen if both a long entry signal and a short entry signal occur simultaneously?
Attributes:
Ignore: Ignore both entry signals.
Long: Execute the long entry signal.
Short: Execute the short entry signal.
Adjacent: Execute the adjacent entry signal.
Takes effect only when in position, otherwise ignores.
Opposite: Execute the opposite entry signal.
Takes effect only when in position, otherwise ignores.
"""
class OppositeEntryModeT(tp.NamedTuple):
Ignore: int = 0
Close: int = 1
CloseReduce: int = 2
Reverse: int = 3
ReverseReduce: int = 4
OppositeEntryMode = OppositeEntryModeT()
"""_"""
__pdoc__[
"OppositeEntryMode"
] = f"""Opposite entry mode.
```python
{prettify(OppositeEntryMode)}
```
What should happen if an entry signal of opposite direction occurs before an exit signal?
Attributes:
Ignore: Ignore the opposite entry signal.
Close: Close the current position.
CloseReduce: Close the current position or reduce it if accumulation is enabled.
Reverse: Reverse the current position.
ReverseReduce: Reverse the current position or reduce it if accumulation is enabled.
"""
class DeltaFormatT(tp.NamedTuple):
Absolute: int = 0
Percent: int = 1
Percent100: int = 2
Target: int = 3
DeltaFormat = DeltaFormatT()
"""_"""
__pdoc__[
"DeltaFormat"
] = f"""Delta format.
```python
{prettify(DeltaFormat)}
```
In which format is the delta provided?
Attributes:
Absolute: Absolute format where 0.1 is the absolute difference between the initial and target value.
Percent: Percentage format where 0.1 is 10% applied to the initial value to get the target value.
Percent100: Percentage format where 0.1 is 0.1% applied to the initial value to get the target value.
Target: Target format where 0.1 is the target value itself.
"""
class TimeDeltaFormatT(tp.NamedTuple):
Rows: int = 0
Index: int = 1
TimeDeltaFormat = TimeDeltaFormatT()
"""_"""
__pdoc__[
"TimeDeltaFormat"
] = f"""Time delta format.
```python
{prettify(TimeDeltaFormat)}
```
In which format is the time delta provided?
Attributes:
Rows: Row format where 1 means one row (simulation step) has passed.
Doesn't require the index to be provided.
Index: Index format where 1 means one value in index has passed.
If index is datetime-like, 1 means one nanosecond.
Requires the index to be provided.
"""
class StopLadderModeT(tp.NamedTuple):
Disabled: int = 0
Uniform: int = 1
Weighted: int = 2
AdaptUniform: int = 3
AdaptWeighted: int = 4
Dynamic: int = 5
StopLadderMode = StopLadderModeT()
"""_"""
__pdoc__[
"StopLadderMode"
] = f"""Stop ladder mode.
```python
{prettify(StopLadderMode)}
```
Attributes:
Disabled: Disable the stop ladder.
Can also be provided as False.
Uniform: Enable the stop ladder with a uniform exit size.
Can also be provided as True.
Weighted: Enable the stop ladder with a stop-weighted exit size.
AdaptUniform: Enable the stop ladder with a uniform exit size that adapts to the current position.
AdaptWeighted: Enable the stop ladder with a stop-weighted exit size that adapts to the current position.
Dynamic: Enable the stop ladder but do not use stop arrays.
!!! note
When disabled, make sure that stop arrays broadcast against the target shape.
When enabled, make sure that rows in stop arrays represent steps in the ladder.
"""
class StopEntryPriceT(tp.NamedTuple):
ValPrice: int = -1
Open: int = -2
Price: int = -3
FillPrice: int = -4
Close: int = -5
StopEntryPrice = StopEntryPriceT()
"""_"""
__pdoc__[
"StopEntryPrice"
] = f"""Stop entry price.
```python
{prettify(StopEntryPrice)}
```
Which price to use as an initial stop price?
Attributes:
ValPrice: Asset valuation price.
Open: Opening price.
Price: Order price.
FillPrice: Filled order price (that is, slippage is already applied).
Close: Closing price.
!!! note
Each flag is negative, thus if a positive value is provided, it's used directly as price.
"""
class StopExitPriceT(tp.NamedTuple):
Stop: int = -1
HardStop: int = -2
Close: int = -3
StopExitPrice = StopExitPriceT()
"""_"""
__pdoc__[
"StopExitPrice"
] = f"""Stop exit price.
```python
{prettify(StopExitPrice)}
```
Which price to use when exiting a position upon a stop signal?
Attributes:
Stop: Stop price.
If the target price is first hit by the opening price, the opening price is used.
HardStop: Hard stop price.
The stop price is used regardless of whether the target price is first hit by the opening price.
Close: Closing price.
!!! note
Each flag is negative, thus if a positive value is provided, it's used directly as price.
"""
class StopExitTypeT(tp.NamedTuple):
Close: int = 0
CloseReduce: int = 1
Reverse: int = 2
ReverseReduce: int = 3
StopExitType = StopExitTypeT()
"""_"""
__pdoc__[
"StopExitType"
] = f"""Stop exit type.
```python
{prettify(StopExitType)}
```
How to exit the current position upon a stop signal?
Attributes:
Close: Close the current position.
CloseReduce: Close the current position or reduce it if accumulation is enabled.
Reverse: Reverse the current position.
ReverseReduce: Reverse the current position or reduce it if accumulation is enabled.
"""
class StopUpdateModeT(tp.NamedTuple):
Keep: int = 0
Override: int = 1
OverrideNaN: int = 2
StopUpdateMode = StopUpdateModeT()
"""_"""
__pdoc__[
"StopUpdateMode"
] = f"""Stop update mode.
```python
{prettify(StopUpdateMode)}
```
What to do with the old stop upon a new entry/accumulation?
Attributes:
Keep: Keep the old stop.
Override: Override the old stop, but only if the new stop is not NaN.
OverrideNaN: Override the old stop, even if the new stop is NaN.
"""
class SizeTypeT(tp.NamedTuple):
Amount: int = 0
Value: int = 1
Percent: int = 2
Percent100: int = 3
ValuePercent: int = 4
ValuePercent100: int = 5
TargetAmount: int = 6
TargetValue: int = 7
TargetPercent: int = 8
TargetPercent100: int = 9
SizeType = SizeTypeT()
"""_"""
__pdoc__[
"SizeType"
] = f"""Size type.
```python
{prettify(SizeType)}
```
Attributes:
Amount: Amount of assets to trade.
Value: Asset value to trade.
Gets converted into `SizeType.Amount` using `ExecState.val_price`.
Percent: Percentage of available resources to use in either direction (not to be confused with
the percentage of position value!) where 0.01 means 1%
* When long buying, the percentage of `ExecState.free_cash`
* When long selling, the percentage of `ExecState.position`
* When short selling, the percentage of `ExecState.free_cash`
* When short buying, the percentage of `ExecState.free_cash`, `ExecState.debt`, and `ExecState.locked_cash`
* When reversing, the percentage is applied to the final position
Percent100: `SizeType.Percent` where 1.0 means 1%.
ValuePercent: Percentage of total value.
Uses `ExecState.value` to get the current total value.
Gets converted into `SizeType.Value`.
ValuePercent100: `SizeType.ValuePercent` where 1.0 means 1%.
TargetAmount: Target amount of assets to hold (= target position).
Uses `ExecState.position` to get the current position.
Gets converted into `SizeType.Amount`.
TargetValue: Target asset value.
Uses `ExecState.val_price` to get the current asset value.
Gets converted into `SizeType.TargetAmount`.
TargetPercent: Target percentage of total value.
Uses `ExecState.value` to get the current total value.
Gets converted into `SizeType.TargetValue`.
TargetPercent100: `SizeType.TargetPercent` where 1.0 means 1%.
"""
class DirectionT(tp.NamedTuple):
LongOnly: int = 0
ShortOnly: int = 1
Both: int = 2
Direction = DirectionT()
"""_"""
__pdoc__[
"Direction"
] = f"""Position direction.
```python
{prettify(Direction)}
```
Attributes:
LongOnly: Only long positions.
ShortOnly: Only short positions.
Both: Both long and short positions.
"""
class LeverageModeT(tp.NamedTuple):
Lazy: int = 0
Eager: int = 1
LeverageMode = LeverageModeT()
"""_"""
__pdoc__[
"LeverageMode"
] = f"""Leverage mode.
```python
{prettify(LeverageMode)}
```
Attributes:
Lazy: Applies leverage only if free cash has been exhausted.
Eager: Applies leverage to each order.
"""
class PriceAreaVioModeT(tp.NamedTuple):
Ignore: int = 0
Cap: int = 1
Error: int = 2
PriceAreaVioMode = PriceAreaVioModeT()
"""_"""
__pdoc__[
"PriceAreaVioMode"
] = f"""Price are violation mode.
```python
{prettify(PriceAreaVioMode)}
```
Attributes:
Ignore: Ignore any violation.
Cap: Cap price to prevent violation.
Error: Throw an error upon violation.
"""
class OrderStatusT(tp.NamedTuple):
Filled: int = 0
Ignored: int = 1
Rejected: int = 2
OrderStatus = OrderStatusT()
"""_"""
__pdoc__[
"OrderStatus"
] = f"""Order status.
```python
{prettify(OrderStatus)}
```
Attributes:
Filled: Order has been filled.
Ignored: Order has been ignored.
Rejected: Order has been rejected.
"""
class OrderStatusInfoT(tp.NamedTuple):
SizeNaN: int = 0
PriceNaN: int = 1
ValPriceNaN: int = 2
ValueNaN: int = 3
ValueZeroNeg: int = 4
SizeZero: int = 5
NoCash: int = 6
NoOpenPosition: int = 7
MaxSizeExceeded: int = 8
RandomEvent: int = 9
CantCoverFees: int = 10
MinSizeNotReached: int = 11
PartialFill: int = 12
OrderStatusInfo = OrderStatusInfoT()
"""_"""
__pdoc__[
"OrderStatusInfo"
] = f"""Order status information.
```python
{prettify(OrderStatusInfo)}
```
"""
status_info_desc = [
"Size is NaN",
"Price is NaN",
"Asset valuation price is NaN",
"Asset/group value is NaN",
"Asset/group value is zero or negative",
"Size is zero",
"Not enough cash",
"No open position to reduce/close",
"Size is greater than maximum allowed",
"Random event happened",
"Not enough cash to cover fees",
"Final size is less than minimum allowed",
"Final size is less than requested",
]
"""_"""
__pdoc__[
"status_info_desc"
] = f"""Order status description.
```python
{prettify(status_info_desc)}
```
"""
class OrderSideT(tp.NamedTuple):
Buy: int = 0
Sell: int = 1
OrderSide = OrderSideT()
"""_"""
__pdoc__[
"OrderSide"
] = f"""Order side.
```python
{prettify(OrderSide)}
```
"""
class OrderTypeT(tp.NamedTuple):
Market: int = 0
Limit: int = 1
OrderType = OrderTypeT()
"""_"""
__pdoc__[
"OrderType"
] = f"""Order type.
```python
{prettify(OrderType)}
```
"""
class LimitOrderPriceT(tp.NamedTuple):
Limit: int = -1
HardLimit: int = -2
Close: int = -3
LimitOrderPrice = LimitOrderPriceT()
"""_"""
__pdoc__[
"LimitOrderPrice"
] = f"""Limit order price.
```python
{prettify(LimitOrderPrice)}
```
Which price to use when executing a limit order?
Attributes:
Limit: Limit price.
If the target price is first hit by the opening price, the opening price is used.
HardLimit: Hard limit price.
The limit price is used regardless of whether the target price is first hit by the opening price.
Close: Closing price.
!!! note
Each flag is negative, thus if a positive value is provided, it's used directly as price.
"""
class TradeDirectionT(tp.NamedTuple):
Long: int = 0
Short: int = 1
TradeDirection = TradeDirectionT()
"""_"""
__pdoc__[
"TradeDirection"
] = f"""Trade direction.
```python
{prettify(TradeDirection)}
```
"""
class TradeStatusT(tp.NamedTuple):
Open: int = 0
Closed: int = 1
TradeStatus = TradeStatusT()
"""_"""
__pdoc__[
"TradeStatus"
] = f"""Event status.
```python
{prettify(TradeStatus)}
```
"""
class TradesTypeT(tp.NamedTuple):
Trades: int = 0
EntryTrades: int = 1
ExitTrades: int = 2
Positions: int = 3
TradesType = TradesTypeT()
"""_"""
__pdoc__[
"TradesType"
] = f"""Trades type.
```python
{prettify(TradesType)}
```
"""
class OrderPriceStatusT(tp.NamedTuple):
OK: int = 0
AboveHigh: int = 1
BelowLow: int = 2
Unknown: int = 3
OrderPriceStatus = OrderPriceStatusT()
"""_"""
__pdoc__[
"OrderPriceStatus"
] = f"""Order price error.
```python
{prettify(OrderPriceStatus)}
```
Attributes:
OK: Order price is within OHLC.
AboveHigh: Order price is above high.
BelowLow: Order price is below low.
Unknown: High and/or low are unknown.
"""
class PositionFeatureT(tp.NamedTuple):
EntryPrice: int = 0
ExitPrice: int = 1
PositionFeature = PositionFeatureT()
"""_"""
__pdoc__[
"PositionFeature"
] = f"""Position feature.
```python
{prettify(PositionFeature)}
```
"""
# ############# Named tuples ############# #
class PriceArea(tp.NamedTuple):
open: float
high: float
low: float
close: float
__pdoc__[
"PriceArea"
] = """Price area defined by four boundaries.
Used together with `PriceAreaVioMode`."""
__pdoc__["PriceArea.open"] = "Opening price of the time step."
__pdoc__[
"PriceArea.high"
] = """Highest price of the time step.
Violation takes place when adjusted price goes above this value.
"""
__pdoc__[
"PriceArea.low"
] = """Lowest price of the time step.
Violation takes place when adjusted price goes below this value.
"""
__pdoc__[
"PriceArea.close"
] = """Closing price of the time step.
Violation takes place when adjusted price goes beyond this value.
"""
NoPriceArea = PriceArea(open=np.nan, high=np.nan, low=np.nan, close=np.nan)
"""_"""
__pdoc__["NoPriceArea"] = "No price area."
class AccountState(tp.NamedTuple):
cash: float
position: float
debt: float
locked_cash: float
free_cash: float
__pdoc__["AccountState"] = "State of the account."
__pdoc__[
"AccountState.cash"
] = """Cash.
Per group with cash sharing, otherwise per column."""
__pdoc__[
"AccountState.position"
] = """Position.
Per column."""
__pdoc__[
"AccountState.debt"
] = """Debt.
Per column."""
__pdoc__[
"AccountState.locked_cash"
] = """Locked cash.
Per column."""
__pdoc__[
"AccountState.free_cash"
] = """Free cash.
Per group with cash sharing, otherwise per column."""
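# Example (illustrative): `AccountState` is a plain named tuple, so instances can be
# constructed directly, e.g. when testing custom callbacks.
# >>> AccountState(cash=100.0, position=0.0, debt=0.0, locked_cash=0.0, free_cash=100.0)
# AccountState(cash=100.0, position=0.0, debt=0.0, locked_cash=0.0, free_cash=100.0)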
class ExecState(tp.NamedTuple):
cash: float
position: float
debt: float
locked_cash: float
free_cash: float
val_price: float
value: float
__pdoc__["ExecState"] = "State before or after order execution."
__pdoc__["ExecState.cash"] = "See `AccountState.cash`."
__pdoc__["ExecState.position"] = "See `AccountState.position`."
__pdoc__["ExecState.debt"] = "See `AccountState.debt`."
__pdoc__["ExecState.locked_cash"] = "See `AccountState.locked_cash`."
__pdoc__["ExecState.free_cash"] = "See `AccountState.free_cash`."
__pdoc__["ExecState.val_price"] = "Valuation price in the current column."
__pdoc__["ExecState.value"] = "Value in the current column (or group with cash sharing)."
class SimulationOutput(tp.NamedTuple):
order_records: tp.RecordArray2d
log_records: tp.RecordArray2d
cash_deposits: tp.Array2d
cash_earnings: tp.Array2d
call_seq: tp.Optional[tp.Array2d]
in_outputs: tp.Optional[tp.NamedTuple]
sim_start: tp.Optional[tp.Array1d]
sim_end: tp.Optional[tp.Array1d]
__pdoc__["SimulationOutput"] = "A named tuple representing the output of a simulation."
__pdoc__["SimulationOutput.order_records"] = "Order records (flattened)."
__pdoc__["SimulationOutput.log_records"] = "Log records (flattened)."
__pdoc__[
"SimulationOutput.cash_deposits"
] = """Cash deposited/withdrawn at each timestamp.
If not tracked, becomes zero of shape `(1, 1)`."""
__pdoc__[
"SimulationOutput.cash_earnings"
] = """Cash earnings added/removed at each timestamp.
If not tracked, becomes zero of shape `(1, 1)`."""
__pdoc__[
"SimulationOutput.call_seq"
] = """Call sequence.
If not tracked, becomes None."""
__pdoc__[
"SimulationOutput.in_outputs"
] = """Named tuple with in-output objects.
If not tracked, becomes None."""
__pdoc__[
"SimulationOutput.sim_start"
] = """Start of the simulation per column.
Use `vectorbtpro.generic.nb.sim_range.prepare_ungrouped_sim_range_nb` to ungroup the array.
If not tracked, becomes None."""
__pdoc__[
"SimulationOutput.sim_end"
] = """End of the simulation per column.
Use `vectorbtpro.generic.nb.sim_range.prepare_ungrouped_sim_range_nb` to ungroup the array.
If not tracked, becomes None."""
class SimulationContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
__pdoc__[
"SimulationContext"
] = """A named tuple representing the context of a simulation.
Contains general information available to all other contexts.
Passed to `pre_sim_func_nb` and `post_sim_func_nb`."""
__pdoc__[
"SimulationContext.target_shape"
] = """Target shape of the simulation.
A tuple with exactly two elements: the number of rows and columns.
Example:
One day of minute data for three assets would yield a `target_shape` of `(1440, 3)`,
where the first axis represents rows (minutes) and the second axis represents columns (assets).
"""
__pdoc__[
"SimulationContext.group_lens"
] = """Number of columns in each group.
Even if columns are not grouped, `group_lens` contains ones - one column per group.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
Example:
In pairs trading, `group_lens` would be `np.array([2])`, while three independent
columns would be represented by `group_lens` of `np.array([1, 1, 1])`.
"""
__pdoc__["SimulationContext.cash_sharing"] = "Whether cash sharing is enabled."
__pdoc__[
"SimulationContext.call_seq"
] = """Default sequence of calls per segment.
Controls the sequence in which `order_func_nb` is executed within each segment.
Has shape `SimulationContext.target_shape` and each value must exist in the range `[0, group_len)`.
Can also be None if not provided.
!!! note
To use `sort_call_seq_1d_nb`, must be generated using `CallSeqType.Default`.
To change the call sequence dynamically, better change `SegmentContext.call_seq_now` in-place.
Example:
The default call sequence for three data points and two groups with three columns each:
```python
np.array([
[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2],
[0, 1, 2, 0, 1, 2]
])
```
"""
__pdoc__[
"SimulationContext.init_cash"
] = """Initial capital per column (or per group with cash sharing).
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`.
Must broadcast to shape `(group_lens.shape[0],)` with cash sharing, otherwise `(target_shape[1],)`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
Example:
Consider three columns, each having $100 of starting capital. If we built one group of two columns
and one group of one column, the `init_cash` would be `np.array([200, 100])` with cash sharing
and `np.array([100, 100, 100])` without cash sharing.
"""
__pdoc__[
"SimulationContext.init_position"
] = """Initial position per column.
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`.
Must broadcast to shape `(target_shape[1],)`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.init_price"
] = """Initial position price per column.
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_1d_pc_nb`.
Must broadcast to shape `(target_shape[1],)`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.cash_deposits"
] = """Cash to be deposited/withdrawn per column
(or per group with cash sharing).
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`.
Must broadcast to shape `(target_shape[0], group_lens.shape[0])`.
Cash is deposited/withdrawn right after `pre_segment_func_nb`.
You can modify this array in `pre_segment_func_nb`.
!!! note
To modify the array in place, make sure to build an array of the full shape.
"""
__pdoc__[
"SimulationContext.cash_earnings"
] = """Earnings to be added per column.
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`.
Must broadcast to shape `SimulationContext.target_shape`.
Earnings are added right before `post_segment_func_nb` and are already included
in the value of each group. You can modify this array in `pre_segment_func_nb` or `post_order_func_nb`.
!!! note
To modify the array in place, make sure to build an array of the full shape.
"""
__pdoc__[
"SimulationContext.segment_mask"
] = """Mask of whether a particular segment should be executed.
A segment is simply a sequence of `order_func_nb` calls under the same group and row.
If a segment is inactive, any callback function inside of it will not be executed.
You can still execute the segment's pre- and postprocessing function by enabling
`SimulationContext.call_pre_segment` and `SimulationContext.call_post_segment` respectively.
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`.
Must broadcast to shape `(target_shape[0], group_lens.shape[0])`.
!!! note
To modify the array in place, make sure to build an array of the full shape.
Example:
Consider two groups with two columns each and the following activity mask:
```python
np.array([[ True, False],
[False, True]])
```
The first group is only executed in the first row and the second group is only executed in the second row.
"""
__pdoc__[
"SimulationContext.call_pre_segment"
] = """Whether to call `pre_segment_func_nb` regardless of
`SimulationContext.segment_mask`."""
__pdoc__[
"SimulationContext.call_post_segment"
] = """Whether to call `post_segment_func_nb` regardless of
`SimulationContext.segment_mask`.
Allows, for example, to write user-defined arrays such as returns at the end of each segment."""
__pdoc__[
"SimulationContext.index"
] = """Index in integer (nanosecond) format.
If datetime-like, assumed to have the UTC timezone. Preset simulation methods
will automatically format any index as UTC without actually converting it to UTC, that is,
`12:00 +02:00` will become `12:00 +00:00` to avoid timezone conversion issues."""
__pdoc__["SimulationContext.freq"] = """Frequency of index in integer (nanosecond) format."""
__pdoc__[
"SimulationContext.open"
] = """Opening price.
Replaces `Order.price` in case it's `-np.inf`.
Similar behavior to that of `SimulationContext.close`."""
__pdoc__[
"SimulationContext.high"
] = """Highest price.
Similar behavior to that of `SimulationContext.close`."""
__pdoc__[
"SimulationContext.low"
] = """Lowest price.
Similar behavior to that of `SimulationContext.close`."""
__pdoc__[
"SimulationContext.close"
] = """Closing price at each time step.
Replaces `Order.price` in case it's `np.inf`.
Acts as a boundary - see `PriceArea.close`.
Utilizes flexible indexing using `vectorbtpro.base.flex_indexing.flex_select_nb`.
Must broadcast to shape `SimulationContext.target_shape`.
!!! note
To modify the array in place, make sure to build an array of the full shape.
"""
__pdoc__[
"SimulationContext.bm_close"
] = """Benchmark closing price at each time step.
Must broadcast to shape `SimulationContext.target_shape`."""
__pdoc__[
"SimulationContext.ffill_val_price"
] = """Whether to track valuation price only if it's known.
Otherwise, unknown `SimulationContext.close` will lead to NaN in valuation price at the next timestamp."""
__pdoc__[
"SimulationContext.update_value"
] = """Whether to update group value after each filled order.
Otherwise, stays the same for all columns in the group (the value is calculated
only once, before executing any order).
The change is marginal and mostly driven by transaction costs and slippage."""
__pdoc__[
"SimulationContext.fill_pos_info"
] = """Whether to fill position record.
Disable this to make simulation faster for simple use cases."""
__pdoc__[
"SimulationContext.track_value"
] = """Whether to track value metrics such as
the current valuation price, value, and return.
If False, `SimulationContext.last_val_price`, `SimulationContext.last_value`, and
`SimulationContext.last_return` will stay NaN and the statistics of any open position won't be updated.
You won't be able to use `SizeType.Value`, `SizeType.TargetValue`, and `SizeType.TargetPercent`.
Disable this to make simulation faster for simple use cases."""
__pdoc__[
"SimulationContext.order_records"
] = """Order records per column.
It's a 2-dimensional array with records of type `order_dt`.
The array is initialized with empty records first (they contain random data), and then
gradually filled with order data. The number of empty records depends upon `max_order_records`,
but usually it matches the number of rows, meaning there is at most one order record per element.
`max_order_records` can be set lower to save memory if not every `order_func_nb` call leads to a filled order.
It can also be chosen higher if more than one order per element is expected.
You can use `SimulationContext.order_counts` to get the number of filled orders in each column.
To get all order records filled up to this point in a column, do `order_records[:order_counts[col], col]`.
Example:
Before filling, each order record looks like this:
```python
np.array([(-8070450532247928832, -8070450532247928832, 4, 0., 0., 0., 5764616306889786413)])
```
After filling, it becomes like this:
```python
np.array([(0, 0, 1, 50., 1., 0., 1)])
```
"""
__pdoc__[
"SimulationContext.order_counts"
] = """Number of filled order records in each column.
Points to `SimulationContext.order_records` and has shape `(target_shape[1],)`.
Example:
`order_counts` of `np.array([2, 100, 0])` means the latest filled order is `order_records[1, 0]` in the
first column, `order_records[99, 1]` in the second column, and no orders have been filled yet
in the third column (`order_records[0, 2]` is empty).
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.log_records"
] = """Log records per column.
Similar to `SimulationContext.order_records` but of type `log_dt` and count `SimulationContext.log_counts`."""
__pdoc__[
"SimulationContext.log_counts"
] = """Number of filled log records in each column.
Similar to `SimulationContext.order_counts` but for log records.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.in_outputs"
] = """Named tuple with in-output objects.
Can contain objects of arbitrary shape and type. Will be returned as part of `SimulationOutput`."""
__pdoc__[
"SimulationContext.last_cash"
] = """Latest cash per column (or per group with cash sharing).
At the very first timestamp, contains initial capital.
Gets updated right after `order_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_position"
] = """Latest position per column.
At the very first timestamp, contains initial position.
Has shape `(target_shape[1],)`.
Gets updated right after `order_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_debt"
] = """Latest debt from leverage or shorting per column.
Has shape `(target_shape[1],)`.
Gets updated right after `order_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_locked_cash"
] = """Latest locked cash from leverage or shorting per column.
Has shape `(target_shape[1],)`.
Gets updated right after `order_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_free_cash"
] = """Latest free cash per column (or per group with cash sharing).
Free cash never goes above the initial level, because an operation always costs money.
Has shape `(target_shape[1],)`.
Gets updated right after `order_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_val_price"
] = """Latest valuation price per column.
Has shape `(target_shape[1],)`.
Enables `SizeType.Value`, `SizeType.TargetValue`, and `SizeType.TargetPercent`.
Gets multiplied by the current position to get the value of the column (see `SimulationContext.last_value`).
Gets updated right before `pre_segment_func_nb` using `SimulationContext.open`.
Then, gets updated right after `pre_segment_func_nb` - you can use `pre_segment_func_nb` to
override `last_val_price` in-place, such that `order_func_nb` can use the new group value.
If `SimulationContext.update_value`, gets also updated right after `order_func_nb` using
filled order price as the latest known price. Finally, gets updated right before
`post_segment_func_nb` using `SimulationContext.close`.
If `SimulationContext.ffill_val_price`, gets updated only if the value is not NaN.
For example, close of `[1, 2, np.nan, np.nan, 5]` yields valuation price of `[1, 2, 2, 2, 5]`.
!!! note
You are not allowed to use `-np.inf` or `np.inf` - only finite values.
If `SimulationContext.open` is NaN in the first row, the `last_val_price` is also NaN.
Example:
Consider 10 units in column 1 and 20 units in column 2. The current opening price of them is
$40 and $50 respectively, which is also the default valuation price in the current row,
available as `last_val_price` in `pre_segment_func_nb`. If both columns are in the same group
with cash sharing, the group is valued at $1400 before any `order_func_nb` is called, and can
be later accessed via `OrderContext.value_now`.
"""
__pdoc__[
"SimulationContext.last_value"
] = """Latest value per column (or per group with cash sharing).
Calculated by multiplying the valuation price by the current position and adding the cash.
The value in each column in a group with cash sharing is summed to get the value of the entire group.
Gets updated right before `pre_segment_func_nb`. Then, gets updated right after `pre_segment_func_nb`.
If `SimulationContext.update_value`, gets also updated right after `order_func_nb` using
filled order price as the latest known price (the difference will be minimal,
only affected by costs). Finally, gets updated right before `post_segment_func_nb`.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_return"
] = """Latest return per column (or per group with cash sharing).
Has the same shape as `SimulationContext.last_value`.
Calculated by comparing the current `SimulationContext.last_value` to the last one of the previous row.
Gets updated each time `SimulationContext.last_value` is updated.
!!! note
Changing this array may produce results inconsistent with those of `vectorbtpro.portfolio.base.Portfolio`.
"""
__pdoc__[
"SimulationContext.last_pos_info"
] = """Latest position record in each column.
It's a 1-dimensional array with records of type `trade_dt`.
Has shape `(target_shape[1],)`.
If `SimulationContext.init_position` is not zero in a column, that column's position record
is automatically filled before the simulation with `entry_price` set to `SimulationContext.init_price` and
`entry_idx` of -1.
The fields `entry_price` and `exit_price` are average entry and exit price respectively.
The average exit price does **not** contain open statistics, as opposed to `vectorbtpro.portfolio.trades.Positions`.
On the other hand, fields `pnl` and `return` contain statistics as if the position has been closed and are
re-calculated using `SimulationContext.last_val_price` right before and after `pre_segment_func_nb`,
right after `order_func_nb`, and right before `post_segment_func_nb`.
!!! note
In an open position record, the field `exit_price` doesn't reflect the latest valuation price,
but keeps the average price at which the position has been reduced.
"""
__pdoc__[
"SimulationContext.sim_start"
] = """Start of the simulation per column or group (also without cash sharing).
Changing in-place won't apply to the current simulation."""
__pdoc__[
"SimulationContext.sim_start"
] = """End of the simulation per column or group (also without cash sharing).
Changing in-place will apply to the current simulation if it's lower than the initial value."""
class GroupContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
__pdoc__[
"GroupContext"
] = """A named tuple representing the context of a group.
A group is a set of nearby columns that are somehow related (for example, by sharing the same capital).
In each row, the columns under the same group are bound to the same segment.
Contains all fields from `SimulationContext` plus fields describing the current group.
Passed to `pre_group_func_nb` and `post_group_func_nb`.
Example:
Consider a group of three columns, a group of two columns, and one more column:
| group | group_len | from_col | to_col |
| ----- | --------- | -------- | ------ |
| 0 | 3 | 0 | 3 |
| 1 | 2 | 3 | 5 |
| 2 | 1 | 5 | 6 |
"""
for field in GroupContext._fields:
if field in SimulationContext._fields:
__pdoc__["GroupContext." + field] = f"See `SimulationContext.{field}`."
__pdoc__[
"GroupContext.group"
] = """Index of the current group.
Has range `[0, group_lens.shape[0])`.
"""
__pdoc__[
"GroupContext.group_len"
] = """Number of columns in the current group.
Scalar value. Same as `group_lens[group]`.
"""
__pdoc__[
"GroupContext.from_col"
] = """Index of the first column in the current group.
Has range `[0, target_shape[1])`.
"""
__pdoc__[
"GroupContext.to_col"
] = """Index of the last column in the current group plus one.
Has range `[1, target_shape[1] + 1)`.
If columns are not grouped, equals to `from_col + 1`.
!!! warning
In the last group, `to_col` points at a column that doesn't exist.
"""
class RowContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
i: int
__pdoc__[
"RowContext"
] = """A named tuple representing the context of a row.
A row is a time step in which segments are executed.
Contains all fields from `SimulationContext` plus fields describing the current row.
Passed to `pre_row_func_nb` and `post_row_func_nb`.
"""
for field in RowContext._fields:
if field in SimulationContext._fields:
__pdoc__["RowContext." + field] = f"See `SimulationContext.{field}`."
__pdoc__[
"RowContext.i"
] = """Index of the current row.
Has range `[0, target_shape[0])`.
"""
class SegmentContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
call_seq_now: tp.Optional[tp.Array1d]
__pdoc__[
"SegmentContext"
] = """A named tuple representing the context of a segment.
A segment is an intersection between groups and rows. It's an entity that defines
how and in which order elements within the same group and row are processed.
Contains all fields from `SimulationContext`, `GroupContext`, and `RowContext`, plus fields
describing the current segment.
Passed to `pre_segment_func_nb` and `post_segment_func_nb`.
"""
for field in SegmentContext._fields:
if field in SimulationContext._fields:
__pdoc__["SegmentContext." + field] = f"See `SimulationContext.{field}`."
elif field in GroupContext._fields:
__pdoc__["SegmentContext." + field] = f"See `GroupContext.{field}`."
elif field in RowContext._fields:
__pdoc__["SegmentContext." + field] = f"See `RowContext.{field}`."
__pdoc__[
"SegmentContext.call_seq_now"
] = """Sequence of calls within the current segment.
Has shape `(group_len,)`.
Each value in this sequence indicates the position of the column in the group to call next.
Processing always goes from left to right.
You can use `pre_segment_func_nb` to override `call_seq_now`.
Example:
`[2, 0, 1]` would first call column 2, then 0, and finally 1.
"""
class OrderContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
call_seq_now: tp.Optional[tp.Array1d]
col: int
call_idx: int
cash_now: float
position_now: float
debt_now: float
locked_cash_now: float
free_cash_now: float
val_price_now: float
value_now: float
return_now: float
pos_info_now: tp.Record
__pdoc__[
"OrderContext"
] = """A named tuple representing the context of an order.
Contains all fields from `SegmentContext` plus fields describing the current state.
Passed to `order_func_nb`.
"""
for field in OrderContext._fields:
if field in SimulationContext._fields:
__pdoc__["OrderContext." + field] = f"See `SimulationContext.{field}`."
elif field in GroupContext._fields:
__pdoc__["OrderContext." + field] = f"See `GroupContext.{field}`."
elif field in RowContext._fields:
__pdoc__["OrderContext." + field] = f"See `RowContext.{field}`."
elif field in SegmentContext._fields:
__pdoc__["OrderContext." + field] = f"See `SegmentContext.{field}`."
__pdoc__[
"OrderContext.col"
] = """Current column.
Has range `[0, target_shape[1])` and is always within `[from_col, to_col)`.
"""
__pdoc__[
"OrderContext.call_idx"
] = """Index of the current call in `SegmentContext.call_seq_now`.
Has range `[0, group_len)`.
"""
__pdoc__["OrderContext.cash_now"] = "`SimulationContext.last_cash` for the current column/group."
__pdoc__["OrderContext.position_now"] = "`SimulationContext.last_position` for the current column."
__pdoc__["OrderContext.debt_now"] = "`SimulationContext.last_debt` for the current column."
__pdoc__["OrderContext.locked_cash_now"] = "`SimulationContext.last_locked_cash` for the current column."
__pdoc__["OrderContext.free_cash_now"] = "`SimulationContext.last_free_cash` for the current column/group."
__pdoc__["OrderContext.val_price_now"] = "`SimulationContext.last_val_price` for the current column."
__pdoc__["OrderContext.value_now"] = "`SimulationContext.last_value` for the current column/group."
__pdoc__["OrderContext.return_now"] = "`SimulationContext.last_return` for the current column/group."
__pdoc__["OrderContext.pos_info_now"] = "`SimulationContext.last_pos_info` for the current column."
class PostOrderContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
call_seq_now: tp.Optional[tp.Array1d]
col: int
call_idx: int
cash_before: float
position_before: float
debt_before: float
locked_cash_before: float
free_cash_before: float
val_price_before: float
value_before: float
order_result: "OrderResult"
cash_now: float
position_now: float
debt_now: float
locked_cash_now: float
free_cash_now: float
val_price_now: float
value_now: float
return_now: float
pos_info_now: tp.Record
__pdoc__[
"PostOrderContext"
] = """A named tuple representing the context after an order has been processed.
Contains all fields from `OrderContext` plus fields describing the order result and the previous state.
Passed to `post_order_func_nb`.
"""
for field in PostOrderContext._fields:
if field in SimulationContext._fields:
__pdoc__["PostOrderContext." + field] = f"See `SimulationContext.{field}`."
elif field in GroupContext._fields:
__pdoc__["PostOrderContext." + field] = f"See `GroupContext.{field}`."
elif field in RowContext._fields:
__pdoc__["PostOrderContext." + field] = f"See `RowContext.{field}`."
elif field in SegmentContext._fields:
__pdoc__["PostOrderContext." + field] = f"See `SegmentContext.{field}`."
elif field in OrderContext._fields:
__pdoc__["PostOrderContext." + field] = f"See `OrderContext.{field}`."
__pdoc__["PostOrderContext.cash_before"] = "`OrderContext.cash_now` before execution."
__pdoc__["PostOrderContext.position_before"] = "`OrderContext.position_now` before execution."
__pdoc__["PostOrderContext.debt_before"] = "`OrderContext.debt_now` before execution."
__pdoc__["PostOrderContext.locked_cash_before"] = "`OrderContext.locked_cash_now` before execution."
__pdoc__["PostOrderContext.free_cash_before"] = "`OrderContext.free_cash_now` before execution."
__pdoc__["PostOrderContext.val_price_before"] = "`OrderContext.val_price_now` before execution."
__pdoc__["PostOrderContext.value_before"] = "`OrderContext.value_now` before execution."
__pdoc__[
"PostOrderContext.order_result"
] = """Order result of type `OrderResult`.
Can be used to check whether the order has been filled, ignored, or rejected.
"""
__pdoc__["PostOrderContext.cash_now"] = "`OrderContext.cash_now` after execution."
__pdoc__["PostOrderContext.position_now"] = "`OrderContext.position_now` after execution."
__pdoc__["PostOrderContext.debt_now"] = "`OrderContext.debt_now` after execution."
__pdoc__["PostOrderContext.locked_cash_now"] = "`OrderContext.locked_cash_now` after execution."
__pdoc__["PostOrderContext.free_cash_now"] = "`OrderContext.free_cash_now` after execution."
__pdoc__[
"PostOrderContext.val_price_now"
] = """`OrderContext.val_price_now` after execution.
If `SimulationContext.update_value`, gets replaced with the fill price,
as it becomes the most recently known price. Otherwise, stays the same.
"""
__pdoc__[
"PostOrderContext.value_now"
] = """`OrderContext.value_now` after execution.
If `SimulationContext.update_value`, gets updated with the new cash and value of the column. Otherwise, stays the same.
"""
__pdoc__["PostOrderContext.return_now"] = "`OrderContext.return_now` after execution."
__pdoc__["PostOrderContext.pos_info_now"] = "`OrderContext.pos_info_now` after execution."
class FlexOrderContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
call_seq: tp.Optional[tp.Array2d]
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
cash_deposits: tp.FlexArray2d
cash_earnings: tp.FlexArray2d
segment_mask: tp.FlexArray2d
call_pre_segment: bool
call_post_segment: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
bm_close: tp.FlexArray2d
ffill_val_price: bool
update_value: bool
fill_pos_info: bool
track_value: bool
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.RecordArray
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
call_seq_now: None
call_idx: int
__pdoc__[
"FlexOrderContext"
] = """A named tuple representing the context of a flexible order.
Contains all fields from `SegmentContext` plus the current call index.
Passed to `flex_order_func_nb`.
"""
for field in FlexOrderContext._fields:
if field in SimulationContext._fields:
__pdoc__["FlexOrderContext." + field] = f"See `SimulationContext.{field}`."
elif field in GroupContext._fields:
__pdoc__["FlexOrderContext." + field] = f"See `GroupContext.{field}`."
elif field in RowContext._fields:
__pdoc__["FlexOrderContext." + field] = f"See `RowContext.{field}`."
elif field in SegmentContext._fields:
__pdoc__["FlexOrderContext." + field] = f"See `SegmentContext.{field}`."
__pdoc__["FlexOrderContext.call_idx"] = "Index of the current call."
class Order(tp.NamedTuple):
size: float = np.inf
price: float = np.inf
size_type: int = SizeType.Amount
direction: int = Direction.Both
fees: float = 0.0
fixed_fees: float = 0.0
slippage: float = 0.0
min_size: float = np.nan
max_size: float = np.nan
size_granularity: float = np.nan
leverage: float = 1.0
leverage_mode: int = LeverageMode.Lazy
reject_prob: float = 0.0
price_area_vio_mode: int = PriceAreaVioMode.Ignore
allow_partial: bool = True
raise_reject: bool = False
log: bool = False
__pdoc__[
"Order"
] = """A named tuple representing an order.
!!! note
Currently, Numba has issues with using defaults when filling named tuples.
Use `vectorbtpro.portfolio.nb.core.order_nb` to create an order."""
__pdoc__[
"Order.size"
] = """Size in units.
Behavior depends upon `Order.size_type` and `Order.direction`.
For any fixed size:
* Set to any number to buy/sell some fixed amount or value.
* Set to `np.inf` to buy for all cash, or `-np.inf` to sell for all free cash.
If `Order.direction` is not `Direction.Both`, `-np.inf` will close the position.
* Set to `np.nan` or 0 to skip.
For any target size:
* Set to any number to buy/sell an amount relative to the current position or value.
* Set to 0 to close the current position.
* Set to `np.nan` to skip.
"""
__pdoc__[
"Order.price"
] = """Price per unit.
Final price will depend upon slippage.
* If `-np.inf`, gets replaced by the current open.
* If `np.inf`, gets replaced by the current close.
!!! note
Make sure to use prices that lie between (and ideally not including) the current open and close."""
__pdoc__["Order.size_type"] = "See `SizeType`."
__pdoc__["Order.direction"] = "See `Direction`."
__pdoc__[
"Order.fees"
] = """Fees in percentage of the order value.
Negative trading fees like -0.05 mean earning 5% per trade instead of paying a fee.
!!! note
0.01 = 1%."""
__pdoc__[
"Order.fixed_fees"
] = """Fixed amount of fees to pay for this order.
Similar to `Order.fees`, can be negative."""
__pdoc__[
"Order.slippage"
] = """Slippage in percentage of `Order.price`.
Slippage is a penalty applied on the price.
!!! note
0.01 = 1%."""
__pdoc__[
"Order.min_size"
] = """Minimum size in both directions.
Depends on `Order.size_type`. Orders with a smaller size will be rejected."""
__pdoc__[
"Order.max_size"
] = """Maximum size in both directions.
Depends on `Order.size_type`. Orders with a larger size will be partially filled."""
__pdoc__[
"Order.size_granularity"
] = """Granularity of the size.
For example, a granularity of 1.0 makes the quantity behave like an integer.
Placing an order of 12.5 shares (in any direction) will order exactly 12.0 shares.
!!! note
The filled size remains a floating number."""
__pdoc__["Order.leverage"] = "Leverage."
__pdoc__["Order.leverage_mode"] = "See `LeverageMode`."
__pdoc__[
"Order.reject_prob"
] = """Probability of rejecting this order to simulate a random rejection event.
Not everything goes smoothly in real life. Use random rejections to test your order management for robustness."""
__pdoc__["Order.price_area_vio_mode"] = "See `PriceAreaVioMode`."
__pdoc__[
"Order.allow_partial"
] = """Whether to allow partial fill.
Otherwise, the order gets rejected.
Does not apply when `Order.size` is `np.inf`."""
__pdoc__[
"Order.raise_reject"
] = """Whether to raise exception if order has been rejected.
Terminates the simulation."""
__pdoc__[
"Order.log"
] = """Whether to log this order by filling a log record.
Remember to increase `max_log_records`."""
NoOrder = Order(
size=np.nan,
price=np.nan,
size_type=-1,
direction=-1,
fees=np.nan,
fixed_fees=np.nan,
slippage=np.nan,
min_size=np.nan,
max_size=np.nan,
size_granularity=np.nan,
leverage=1.0,
leverage_mode=LeverageMode.Lazy,
reject_prob=np.nan,
price_area_vio_mode=-1,
allow_partial=False,
raise_reject=False,
log=False,
)
"""_"""
__pdoc__["NoOrder"] = "Order that should not be processed."
class OrderResult(tp.NamedTuple):
size: float
price: float
fees: float
side: int
status: int
status_info: int
__pdoc__["OrderResult"] = "A named tuple representing an order result."
__pdoc__["OrderResult.size"] = "Filled size."
__pdoc__["OrderResult.price"] = "Filled price per unit, adjusted with slippage."
__pdoc__["OrderResult.fees"] = "Total fees paid for this order."
__pdoc__["OrderResult.side"] = "See `OrderSide`."
__pdoc__["OrderResult.status"] = "See `OrderStatus`."
__pdoc__["OrderResult.status_info"] = "See `OrderStatusInfo`."
class SignalSegmentContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
track_cash_deposits: bool
cash_deposits_out: tp.Array2d
track_cash_earnings: bool
cash_earnings_out: tp.Array2d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.Array1d
last_limit_info: tp.Array1d
last_sl_info: tp.Array1d
last_tsl_info: tp.Array1d
last_tp_info: tp.Array1d
last_td_info: tp.Array1d
last_dt_info: tp.Array1d
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
__pdoc__[
"SignalSegmentContext"
] = """A named tuple representing the context of a segment in a from-signals simulation.
Contains information related to the cascade of the simulation, such as OHLC, but also
internal information that is not passed by the user but created at the beginning of the simulation.
To make use of other information, such as order size, use templates.
Passed to `post_segment_func_nb`."""
for field in SignalSegmentContext._fields:
if field in SimulationContext._fields:
__pdoc__["SignalSegmentContext." + field] = f"See `SimulationContext.{field}`."
for field in SignalSegmentContext._fields:
if field in GroupContext._fields:
__pdoc__["SignalSegmentContext." + field] = f"See `GroupContext.{field}`."
for field in SignalSegmentContext._fields:
if field in RowContext._fields:
__pdoc__["SignalSegmentContext." + field] = f"See `RowContext.{field}`."
__pdoc__[
"SignalSegmentContext.track_cash_deposits"
] = """Whether to track cash deposits.
Becomes True if any value in `cash_deposits` is not zero."""
__pdoc__["SignalSegmentContext.cash_deposits_out"] = "See `SimulationOutput.cash_deposits`."
__pdoc__[
"SignalSegmentContext.track_cash_earnings"
] = """Whether to track cash earnings.
Becomes True if any value in `cash_earnings` is not zero."""
__pdoc__["SignalSegmentContext.cash_earnings_out"] = "See `SimulationOutput.cash_earnings`."
__pdoc__["SignalSegmentContext.in_outputs"] = "See `FSInOutputs`."
__pdoc__[
"SignalSegmentContext.last_limit_info"
] = """Record of type `limit_info_dt` per column.
Accessible via `c.last_limit_info[field][col]`."""
__pdoc__[
"SignalSegmentContext.last_sl_info"
] = """Record of type `sl_info_dt` per column.
Accessible via `c.last_sl_info[field][col]`."""
__pdoc__[
"SignalSegmentContext.last_tsl_info"
] = """Record of type `tsl_info_dt` per column.
Accessible via `c.last_tsl_info[field][col]`."""
__pdoc__[
"SignalSegmentContext.last_tp_info"
] = """Record of type `tp_info_dt` per column.
Accessible via `c.last_tp_info[field][col]`."""
__pdoc__[
"SignalSegmentContext.last_td_info"
] = """Record of type `time_info_dt` per column.
Accessible via `c.last_td_info[field][col]`."""
__pdoc__[
"SignalSegmentContext.last_dt_info"
] = """Record of type `time_info_dt` per column.
Accessible via `c.last_dt_info[field][col]`."""
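# Illustrative sketch (not part of the original module): reading and updating the
# current stop-loss record from inside a signal callback, using the `record[field][col]`
# access pattern documented above. Assumes the stop is expressed in a percentage delta format.
def _example_tighten_sl(c):
    if c.last_position[c.col] != 0 and not np.isnan(c.last_sl_info["stop"][c.col]):
        c.last_sl_info["stop"][c.col] = 0.05  # tighten an existing stop loss to 5%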
class SignalContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
track_cash_deposits: bool
cash_deposits_out: tp.Array2d
track_cash_earnings: bool
cash_earnings_out: tp.Array2d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.Array1d
last_limit_info: tp.Array1d
last_sl_info: tp.Array1d
last_tsl_info: tp.Array1d
last_tp_info: tp.Array1d
last_td_info: tp.Array1d
last_dt_info: tp.Array1d
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
col: int
__pdoc__[
"SignalContext"
] = """A named tuple representing the context of an element in a from-signals simulation.
Contains all fields from `SignalSegmentContext` plus the column field.
Passed to `signal_func_nb` and `adjust_func_nb`.
"""
for field in SignalContext._fields:
if field in SignalSegmentContext._fields:
__pdoc__["SignalContext." + field] = f"See `SignalSegmentContext.{field}`."
__pdoc__["SignalContext.col"] = "See `OrderContext.col`."
class PostSignalContext(tp.NamedTuple):
target_shape: tp.Shape
group_lens: tp.GroupLens
cash_sharing: bool
index: tp.Optional[tp.Array1d]
freq: tp.Optional[int]
open: tp.FlexArray2d
high: tp.FlexArray2d
low: tp.FlexArray2d
close: tp.FlexArray2d
init_cash: tp.FlexArray1d
init_position: tp.FlexArray1d
init_price: tp.FlexArray1d
order_records: tp.RecordArray2d
order_counts: tp.Array1d
log_records: tp.RecordArray2d
log_counts: tp.Array1d
track_cash_deposits: bool
cash_deposits_out: tp.Array2d
track_cash_earnings: bool
cash_earnings_out: tp.Array2d
in_outputs: tp.Optional[tp.NamedTuple]
last_cash: tp.Array1d
last_position: tp.Array1d
last_debt: tp.Array1d
last_locked_cash: tp.Array1d
last_free_cash: tp.Array1d
last_val_price: tp.Array1d
last_value: tp.Array1d
last_return: tp.Array1d
last_pos_info: tp.Array1d
last_limit_info: tp.Array1d
last_sl_info: tp.Array1d
last_tsl_info: tp.Array1d
last_tp_info: tp.Array1d
last_td_info: tp.Array1d
last_dt_info: tp.Array1d
sim_start: tp.Array1d
sim_end: tp.Array1d
group: int
group_len: int
from_col: int
to_col: int
i: int
col: int
cash_before: float
position_before: float
debt_before: float
locked_cash_before: float
free_cash_before: float
val_price_before: float
value_before: float
order_result: "OrderResult"
__pdoc__[
"PostSignalContext"
] = """A named tuple representing the context after an order has been processed in a from-signals simulation.
Contains all fields from `SignalContext` plus the previous balances and order result.
Passed to `post_signal_func_nb`.
"""
for field in PostSignalContext._fields:
if field in SignalContext._fields:
__pdoc__["PostSignalContext." + field] = f"See `SignalContext.{field}`."
__pdoc__["PostSignalContext.cash_before"] = "`ExecState.cash` before execution."
__pdoc__["PostSignalContext.position_before"] = "`ExecState.position` before execution."
__pdoc__["PostSignalContext.debt_before"] = "`ExecState.debt` before execution."
__pdoc__["PostSignalContext.locked_cash_before"] = "`ExecState.free_cash` before execution."
__pdoc__["PostSignalContext.free_cash_before"] = "`ExecState.val_price` before execution."
__pdoc__["PostSignalContext.val_price_before"] = "`ExecState.value` before execution."
__pdoc__["PostSignalContext.order_result"] = "`PostOrderContext.order_result`."
# ############# In-outputs ############# #
class FOInOutputs(tp.NamedTuple):
cash: tp.Array2d
position: tp.Array2d
debt: tp.Array2d
locked_cash: tp.Array2d
free_cash: tp.Array2d
value: tp.Array2d
returns: tp.Array2d
__pdoc__["FOInOutputs"] = "A named tuple representing the in-outputs for simulation based on orders."
__pdoc__[
"FOInOutputs.cash"
] = """See `AccountState.cash`.
Follows groups if cash sharing is enabled, otherwise columns.
Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.position"
] = """See `AccountState.position`.
Follows columns.
Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.debt"
] = """See `AccountState.debt`.
Follows columns.
Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.locked_cash"
] = """See `AccountState.locked_cash`.
Follows columns.
Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.free_cash"
] = """See `AccountState.free_cash`.
Follows groups if cash sharing is enabled, otherwise columns.
Gets filled if `save_state` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.value"
] = """Value.
Follows groups if cash sharing is enabled, otherwise columns.
Gets filled if `save_value` is True, otherwise has the shape `(0, 0)`."""
__pdoc__[
"FOInOutputs.returns"
] = """Returns.
Follows groups if cash sharing is enabled, otherwise columns.
Gets filled if `save_returns` is True, otherwise has the shape `(0, 0)`."""
class FSInOutputs(tp.NamedTuple):
cash: tp.Array2d
position: tp.Array2d
debt: tp.Array2d
locked_cash: tp.Array2d
free_cash: tp.Array2d
value: tp.Array2d
returns: tp.Array2d
__pdoc__["FSInOutputs"] = "A named tuple representing the in-outputs for simulation based on signals."
__pdoc__["FSInOutputs.cash"] = "See `FOInOutputs.cash`."
__pdoc__["FSInOutputs.position"] = "See `FOInOutputs.position`."
__pdoc__["FSInOutputs.debt"] = "See `FOInOutputs.debt`."
__pdoc__["FSInOutputs.locked_cash"] = "See `FOInOutputs.locked_cash`."
__pdoc__["FSInOutputs.free_cash"] = "See `FOInOutputs.free_cash`."
__pdoc__["FSInOutputs.value"] = "See `FOInOutputs.value`."
__pdoc__["FSInOutputs.returns"] = "See `FOInOutputs.returns`."
# ############# Records ############# #
order_fields = [
("id", int_),
("col", int_),
("idx", int_),
("size", float_),
("price", float_),
("fees", float_),
("side", int_),
]
"""Fields for `order_dt`."""
order_dt = np.dtype(order_fields, align=True)
"""_"""
__pdoc__[
"order_dt"
] = f"""`np.dtype` of order records.
```python
{prettify(order_dt)}
```
"""
fs_order_fields = [
("id", int_),
("col", int_),
("signal_idx", int_),
("creation_idx", int_),
("idx", int_),
("size", float_),
("price", float_),
("fees", float_),
("side", int_),
("type", int_),
("stop_type", int_),
]
"""Fields for `fs_order_dt`."""
fs_order_dt = np.dtype(fs_order_fields, align=True)
"""_"""
__pdoc__[
"fs_order_dt"
] = f"""`np.dtype` of order records generated from signals.
```python
{prettify(fs_order_dt)}
```
"""
trade_fields = [
("id", int_),
("col", int_),
("size", float_),
("entry_order_id", int_),
("entry_idx", int_),
("entry_price", float_),
("entry_fees", float_),
("exit_order_id", int_),
("exit_idx", int_),
("exit_price", float_),
("exit_fees", float_),
("pnl", float_),
("return", float_),
("direction", int_),
("status", int_),
("parent_id", int_),
]
"""Fields for `trade_dt`."""
trade_dt = np.dtype(trade_fields, align=True)
"""_"""
__pdoc__[
"trade_dt"
] = f"""`np.dtype` of trade records.
```python
{prettify(trade_dt)}
```
"""
log_fields = [
("id", int_),
("group", int_),
("col", int_),
("idx", int_),
("price_area_open", float_),
("price_area_high", float_),
("price_area_low", float_),
("price_area_close", float_),
("st0_cash", float_),
("st0_position", float_),
("st0_debt", float_),
("st0_locked_cash", float_),
("st0_free_cash", float_),
("st0_val_price", float_),
("st0_value", float_),
("req_size", float_),
("req_price", float_),
("req_size_type", int_),
("req_direction", int_),
("req_fees", float_),
("req_fixed_fees", float_),
("req_slippage", float_),
("req_min_size", float_),
("req_max_size", float_),
("req_size_granularity", float_),
("req_leverage", float_),
("req_leverage_mode", int_),
("req_reject_prob", float_),
("req_price_area_vio_mode", int_),
("req_allow_partial", np.bool_),
("req_raise_reject", np.bool_),
("req_log", np.bool_),
("res_size", float_),
("res_price", float_),
("res_fees", float_),
("res_side", int_),
("res_status", int_),
("res_status_info", int_),
("st1_cash", float_),
("st1_position", float_),
("st1_debt", float_),
("st1_locked_cash", float_),
("st1_free_cash", float_),
("st1_val_price", float_),
("st1_value", float_),
("order_id", int_),
]
"""Fields for `log_fields`."""
log_dt = np.dtype(log_fields, align=True)
"""_"""
__pdoc__[
"log_dt"
] = f"""`np.dtype` of log records.
```python
{prettify(log_dt)}
```
"""
alloc_range_fields = [
("id", int_),
("col", int_),
("start_idx", int_),
("end_idx", int_),
("alloc_idx", int_),
("status", int_),
]
"""Fields for `alloc_range_dt`."""
alloc_range_dt = np.dtype(alloc_range_fields, align=True)
"""_"""
__pdoc__[
"alloc_range_dt"
] = f"""`np.dtype` of allocation range records.
```python
{prettify(alloc_range_dt)}
```
"""
alloc_point_fields = [
("id", int_),
("col", int_),
("alloc_idx", int_),
]
"""Fields for `alloc_point_dt`."""
alloc_point_dt = np.dtype(alloc_point_fields, align=True)
"""_"""
__pdoc__[
"alloc_point_dt"
] = f"""`np.dtype` of allocation point records.
```python
{prettify(alloc_point_dt)}
```
"""
# ############# Info records ############# #
main_info_fields = [
("bar_zone", int_),
("signal_idx", int_),
("creation_idx", int_),
("idx", int_),
("val_price", float_),
("price", float_),
("size", float_),
("size_type", int_),
("direction", int_),
("type", int_),
("stop_type", int_),
]
"""Fields for `main_info_dt`."""
main_info_dt = np.dtype(main_info_fields, align=True)
"""_"""
__pdoc__[
"main_info_dt"
] = f"""`np.dtype` of main information records.
```python
{prettify(main_info_dt)}
```
Attributes:
bar_zone: See `vectorbtpro.generic.enums.BarZone`.
signal_idx: Row where signal was placed.
creation_idx: Row where order was created.
idx: Row from where order information was taken.
val_price: Valuation price.
price: Requested price.
size: Order size.
size_type: See `SizeType`.
direction: See `Direction`.
type: See `OrderType`.
stop_type: See `vectorbtpro.signals.enums.StopType`.
"""
limit_info_fields = [
("signal_idx", int_),
("creation_idx", int_),
("init_idx", int_),
("init_price", float_),
("init_size", float_),
("init_size_type", int_),
("init_direction", int_),
("init_stop_type", int_),
("delta", float_),
("delta_format", int_),
("tif", np.int64),
("expiry", np.int64),
("time_delta_format", int_),
("reverse", float_),
("order_price", float_),
]
"""Fields for `limit_info_dt`."""
limit_info_dt = np.dtype(limit_info_fields, align=True)
"""_"""
__pdoc__[
"limit_info_dt"
] = f"""`np.dtype` of limit information records.
```python
{prettify(limit_info_dt)}
```
Attributes:
signal_idx: Signal row.
creation_idx: Limit creation row.
init_idx: Initial row from where order information is taken.
init_price: Initial price.
init_size: Order size.
init_size_type: See `SizeType`.
init_direction: See `Direction`.
init_stop_type: See `vectorbtpro.signals.enums.StopType`.
delta: Delta from the initial price.
delta_format: See `DeltaFormat`.
tif: Time in force in integer format. Set to `-1` to disable.
expiry: Expiry time in integer format. Set to `-1` to disable.
time_delta_format: See `TimeDeltaFormat`.
reverse: Whether to reverse the price hit detection.
order_price: See `LimitOrderPrice`.
"""
sl_info_fields = [
("init_idx", int_),
("init_price", float_),
("init_position", float_),
("stop", float_),
("exit_price", float_),
("exit_size", float_),
("exit_size_type", int_),
("exit_type", int_),
("order_type", int_),
("limit_delta", float_),
("delta_format", int_),
("ladder", int_),
("step", int_),
("step_idx", int_),
]
"""Fields for `sl_info_dt`."""
sl_info_dt = np.dtype(sl_info_fields, align=True)
"""_"""
__pdoc__[
"sl_info_dt"
] = f"""`np.dtype` of SL information records.
```python
{prettify(sl_info_dt)}
```
Attributes:
init_idx: Initial row.
init_price: Initial price.
init_position: Initial position.
stop: Latest updated stop value.
exit_price: See `StopExitPrice`.
exit_size: Order size.
exit_size_type: See `SizeType`.
exit_type: See `StopExitType`.
order_type: See `OrderType`.
limit_delta: Delta from the hit price. Only for `StopType.Limit`.
delta_format: See `DeltaFormat`.
ladder: See `StopLadderMode`.
step: Step in the ladder (i.e., the number of times the stop was executed).
step_idx: Step row.
"""
tsl_info_fields = [
("init_idx", int_),
("init_price", float_),
("init_position", float_),
("peak_idx", int_),
("peak_price", float_),
("stop", float_),
("th", float_),
("exit_price", float_),
("exit_size", float_),
("exit_size_type", int_),
("exit_type", int_),
("order_type", int_),
("limit_delta", float_),
("delta_format", int_),
("ladder", int_),
("step", int_),
("step_idx", int_),
]
"""Fields for `tsl_info_dt`."""
tsl_info_dt = np.dtype(tsl_info_fields, align=True)
"""_"""
__pdoc__[
"tsl_info_dt"
] = f"""`np.dtype` of TSL information records.
```python
{prettify(tsl_info_dt)}
```
Attributes:
init_idx: Initial row.
init_price: Initial price.
init_position: Initial position.
peak_idx: Row of the highest/lowest price.
peak_price: Highest/lowest price.
stop: Latest updated stop value.
th: Latest updated threshold value.
exit_price: See `StopExitPrice`.
exit_size: Order size.
exit_size_type: See `SizeType`.
exit_type: See `StopExitType`.
order_type: See `OrderType`.
limit_delta: Delta from the hit price. Only for `StopType.Limit`.
delta_format: See `DeltaFormat`.
ladder: See `StopLadderMode`.
step: Step in the ladder (i.e., the number of times the stop was executed).
step_idx: Step row.
"""
tp_info_fields = [
("init_idx", int_),
("init_price", float_),
("init_position", float_),
("stop", float_),
("exit_price", float_),
("exit_size", float_),
("exit_size_type", int_),
("exit_type", int_),
("order_type", int_),
("limit_delta", float_),
("delta_format", int_),
("ladder", int_),
("step", int_),
("step_idx", int_),
]
"""Fields for `tp_info_dt`."""
tp_info_dt = np.dtype(tp_info_fields, align=True)
"""_"""
__pdoc__[
"tp_info_dt"
] = f"""`np.dtype` of TP information records.
```python
{prettify(tp_info_dt)}
```
Attributes:
init_idx: Initial row.
init_price: Initial price.
init_position: Initial position.
stop: Latest updated stop value.
exit_price: See `StopExitPrice`.
exit_size: Order size.
exit_size_type: See `SizeType`.
exit_type: See `StopExitType`.
order_type: See `OrderType`.
limit_delta: Delta from the hit price. Only for `StopType.Limit`.
delta_format: See `DeltaFormat`.
ladder: See `StopLadderMode`.
step: Step in the ladder (i.e., the number of times the stop was executed).
step_idx: Step row.
"""
time_info_fields = [
("init_idx", int_),
("init_position", float_),
("stop", np.int64),
("exit_price", float_),
("exit_size", float_),
("exit_size_type", int_),
("exit_type", int_),
("order_type", int_),
("limit_delta", float_),
("delta_format", int_),
("time_delta_format", int_),
("ladder", int_),
("step", int_),
("step_idx", int_),
]
"""Fields for `time_info_dt`."""
time_info_dt = np.dtype(time_info_fields, align=True)
"""_"""
__pdoc__[
"time_info_dt"
] = f"""`np.dtype` of time information records.
```python
{prettify(time_info_dt)}
```
Attributes:
init_idx: Initial row.
init_position: Initial position.
stop: Latest updated stop value.
exit_price: See `StopExitPrice`.
exit_size: Order size.
exit_size_type: See `SizeType`.
exit_type: See `StopExitType`.
order_type: See `OrderType`.
limit_delta: Delta from the hit price. Only for `StopType.Limit`.
delta_format: See `DeltaFormat`. Only for `StopType.Limit`.
time_delta_format: See `TimeDeltaFormat`.
ladder: See `StopLadderMode`.
step: Step in the ladder (i.e., the number of times the stop was executed).
step_idx: Step row.
"""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with log records.
Log records capture information on each order request processed during simulation. Logs are populated when
simulating a portfolio and can be accessed as `vectorbtpro.portfolio.base.Portfolio.logs`.
```pycon
>>> from vectorbtpro import *
>>> np.random.seed(42)
>>> price = pd.DataFrame({
... 'a': np.random.uniform(1, 2, size=100),
... 'b': np.random.uniform(1, 2, size=100)
... }, index=[datetime(2020, 1, 1) + timedelta(days=i) for i in range(100)])
>>> size = pd.DataFrame({
... 'a': np.random.uniform(-100, 100, size=100),
... 'b': np.random.uniform(-100, 100, size=100),
... }, index=[datetime(2020, 1, 1) + timedelta(days=i) for i in range(100)])
>>> pf = vbt.Portfolio.from_orders(price, size, fees=0.01, freq='d', log=True)
>>> logs = pf.logs
>>> logs.filled.count()
a 88
b 99
Name: count, dtype: int64
>>> logs.ignored.count()
a 0
b 0
Name: count, dtype: int64
>>> logs.rejected.count()
a 12
b 1
Name: count, dtype: int64
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Logs.metrics`.
```pycon
>>> logs['a'].stats()
Start 2020-01-01 00:00:00
End 2020-04-09 00:00:00
Period 100 days 00:00:00
Total Records 100
Status Counts: None 0
Status Counts: Filled 88
Status Counts: Ignored 0
Status Counts: Rejected 12
Status Info Counts: None 88
Status Info Counts: NoCashLong 12
Name: a, dtype: object
```
`Logs.stats` also supports (re-)grouping:
```pycon
>>> logs.stats(group_by=True)
Start 2020-01-01 00:00:00
End 2020-04-09 00:00:00
Period 100 days 00:00:00
Total Records 200
Status Counts: None 0
Status Counts: Filled 187
Status Counts: Ignored 0
Status Counts: Rejected 13
Status Info Counts: None 187
Status Info Counts: NoCashLong 13
Name: group, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Logs.subplots`.
This class does not have any subplots.
"""
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_dict
from vectorbtpro.generic.price_records import PriceRecords
from vectorbtpro.portfolio.enums import (
log_dt,
SizeType,
LeverageMode,
PriceAreaVioMode,
Direction,
OrderSide,
OrderStatus,
OrderStatusInfo,
)
from vectorbtpro.records.decorators import attach_fields, override_field_config
from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig
__all__ = [
"Logs",
]
__pdoc__ = {}
logs_field_config = ReadonlyConfig(
dict(
dtype=log_dt,
settings=dict(
id=dict(title="Log Id"),
col=dict(title="Column"),
idx=dict(title="Index"),
group=dict(title="Group"),
price_area_open=dict(title="[PA] Open"),
price_area_high=dict(title="[PA] High"),
price_area_low=dict(title="[PA] Low"),
price_area_close=dict(title="[PA] Close"),
st0_cash=dict(title="[ST0] Cash"),
st0_position=dict(title="[ST0] Position"),
st0_debt=dict(title="[ST0] Debt"),
st0_locked_cash=dict(title="[ST0] Locked Cash"),
st0_free_cash=dict(title="[ST0] Free Cash"),
st0_val_price=dict(title="[ST0] Valuation Price"),
st0_value=dict(title="[ST0] Value"),
req_size=dict(title="[REQ] Size"),
req_price=dict(title="[REQ] Price"),
req_size_type=dict(title="[REQ] Size Type", mapping=SizeType),
req_direction=dict(title="[REQ] Direction", mapping=Direction),
req_fees=dict(title="[REQ] Fees"),
req_fixed_fees=dict(title="[REQ] Fixed Fees"),
req_slippage=dict(title="[REQ] Slippage"),
req_min_size=dict(title="[REQ] Min Size"),
req_max_size=dict(title="[REQ] Max Size"),
req_size_granularity=dict(title="[REQ] Size Granularity"),
req_leverage=dict(title="[REQ] Leverage"),
req_leverage_mode=dict(title="[REQ] Leverage Mode", mapping=LeverageMode),
req_reject_prob=dict(title="[REQ] Rejection Prob"),
req_price_area_vio_mode=dict(title="[REQ] Price Area Violation Mode", mapping=PriceAreaVioMode),
req_allow_partial=dict(title="[REQ] Allow Partial"),
req_raise_reject=dict(title="[REQ] Raise Rejection"),
req_log=dict(title="[REQ] Log"),
res_size=dict(title="[RES] Size"),
res_price=dict(title="[RES] Price"),
res_fees=dict(title="[RES] Fees"),
res_side=dict(title="[RES] Side", mapping=OrderSide),
res_status=dict(title="[RES] Status", mapping=OrderStatus),
res_status_info=dict(title="[RES] Status Info", mapping=OrderStatusInfo),
st1_cash=dict(title="[ST1] Cash"),
st1_position=dict(title="[ST1] Position"),
st1_debt=dict(title="[ST1] Debt"),
st1_locked_cash=dict(title="[ST1] Locked Cash"),
st1_free_cash=dict(title="[ST1] Free Cash"),
st1_val_price=dict(title="[ST1] Valuation Price"),
st1_value=dict(title="[ST1] Value"),
order_id=dict(title="Order Id", mapping="ids"),
),
)
)
"""_"""
__pdoc__[
"logs_field_config"
] = f"""Field config for `Logs`.
```python
{logs_field_config.prettify()}
```
"""
logs_attach_field_config = ReadonlyConfig(
dict(
res_side=dict(attach_filters=True),
res_status=dict(attach_filters=True),
res_status_info=dict(attach_filters=True),
)
)
"""_"""
__pdoc__[
"logs_attach_field_config"
] = f"""Config of fields to be attached to `Logs`.
```python
{logs_attach_field_config.prettify()}
```
"""
LogsT = tp.TypeVar("LogsT", bound="Logs")
@attach_fields(logs_attach_field_config)
@override_field_config(logs_field_config)
class Logs(PriceRecords):
"""Extends `vectorbtpro.generic.price_records.PriceRecords` for working with log records."""
@property
def field_config(self) -> Config:
return self._field_config
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Logs.stats`.
Merges `vectorbtpro.generic.price_records.PriceRecords.stats_defaults` and
`stats` from `vectorbtpro._settings.logs`."""
from vectorbtpro._settings import settings
logs_stats_cfg = settings["logs"]["stats"]
return merge_dicts(PriceRecords.stats_defaults.__get__(self), logs_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total_records=dict(title="Total Records", calc_func="count", tags="records"),
res_status_counts=dict(
title="Status Counts",
calc_func="res_status.value_counts",
incl_all_keys=True,
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags=["logs", "res_status", "value_counts"],
),
res_status_info_counts=dict(
title="Status Info Counts",
calc_func="res_status_info.value_counts",
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags=["logs", "res_status_info", "value_counts"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Logs.plots`.
Merges `vectorbtpro.generic.price_records.PriceRecords.plots_defaults` and
`plots` from `vectorbtpro._settings.logs`."""
from vectorbtpro._settings import settings
logs_plots_cfg = settings["logs"]["plots"]
return merge_dicts(PriceRecords.plots_defaults.__get__(self), logs_plots_cfg)
@property
def subplots(self) -> Config:
return self._subplots
Logs.override_field_config_doc(__pdoc__)
Logs.override_metrics_doc(__pdoc__)
Logs.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with order records.
Order records capture information on filled orders. Orders are mainly populated when simulating
a portfolio and can be accessed as `vectorbtpro.portfolio.base.Portfolio.orders`.
```pycon
>>> from vectorbtpro import *
>>> price = vbt.RandomData.pull(
... ['a', 'b'],
... start=datetime(2020, 1, 1),
... end=datetime(2020, 3, 1),
... seed=vbt.key_dict(a=42, b=43)
... ).get()
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> size = pd.DataFrame({
... 'a': np.random.randint(-1, 2, size=len(price.index)),
... 'b': np.random.randint(-1, 2, size=len(price.index)),
... }, index=price.index, columns=price.columns)
>>> pf = vbt.Portfolio.from_orders(price, size, fees=0.01, freq='d')
>>> pf.orders.side_buy.count()
symbol
a 17
b 15
Name: count, dtype: int64
>>> pf.orders.side_sell.count()
symbol
a 24
b 26
Name: count, dtype: int64
```
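Orders also expose shortcut properties (see `orders_shortcut_config` below) such as the
signed size, order value, and size-weighted fill price; a quick sketch with output omitted:
```pycon
>>> pf.orders['a'].signed_size.sum()   # net size: buys minus sells
>>> pf.orders['a'].value.sum()         # sum of signed size multiplied by fill price
>>> pf.orders['a'].weighted_price      # size-weighted average fill price
```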
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Orders.metrics`.
```pycon
>>> pf.orders['a'].stats()
Start 2019-12-31 22:00:00+00:00
End 2020-02-29 22:00:00+00:00
Period 61 days 00:00:00
Total Records 41
Side Counts: Buy 17
Side Counts: Sell 24
Size: Min 0 days 19:33:05.006182372
Size: Median 1 days 00:00:00
Size: Max 1 days 00:00:00
Fees: Min 0 days 20:26:25.905776572
Fees: Median 0 days 22:46:22.693324744
Fees: Max 1 days 01:04:25.541681491
Weighted Buy Price 94.69917
Weighted Sell Price 95.742148
Name: a, dtype: object
```
`Orders.stats` also supports (re-)grouping:
```pycon
>>> pf.orders.stats(group_by=True)
Start 2019-12-31 22:00:00+00:00
End 2020-02-29 22:00:00+00:00
Period 61 days 00:00:00
Total Records 82
Side Counts: Buy 32
Side Counts: Sell 50
Size: Min 0 days 19:33:05.006182372
Size: Median 1 days 00:00:00
Size: Max 1 days 00:00:00
Fees: Min 0 days 20:26:25.905776572
Fees: Median 0 days 23:58:29.773897679
Fees: Max 1 days 02:29:08.904770159
Weighted Buy Price 98.804452
Weighted Sell Price 99.969934
Name: group, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Orders.subplots`.
The `Orders` class has a single subplot based on `Orders.plot`:
```pycon
>>> pf.orders['a'].plots().show()
```
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array, to_dict
from vectorbtpro.generic.enums import RangeStatus, range_dt
from vectorbtpro.generic.price_records import PriceRecords
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.portfolio import nb
from vectorbtpro.portfolio.enums import order_dt, OrderSide, fs_order_dt, OrderType, OrderPriceStatus
from vectorbtpro.records.decorators import attach_fields, override_field_config, attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.signals.enums import StopType
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig
__all__ = [
"Orders",
"FSOrders",
]
__pdoc__ = {}
orders_field_config = ReadonlyConfig(
dict(
dtype=order_dt,
settings=dict(
id=dict(title="Order Id"),
idx=dict(),
size=dict(title="Size"),
price=dict(title="Price"),
fees=dict(title="Fees"),
side=dict(title="Side", mapping=OrderSide, as_customdata=False),
),
)
)
"""_"""
__pdoc__[
"orders_field_config"
] = f"""Field config for `Orders`.
```python
{orders_field_config.prettify()}
```
"""
orders_attach_field_config = ReadonlyConfig(dict(side=dict(attach_filters=True)))
"""_"""
__pdoc__[
"orders_attach_field_config"
] = f"""Config of fields to be attached to `Orders`.
```python
{orders_attach_field_config.prettify()}
```
"""
orders_shortcut_config = ReadonlyConfig(
dict(
long_view=dict(),
short_view=dict(),
signed_size=dict(obj_type="mapped"),
value=dict(obj_type="mapped"),
weighted_price=dict(obj_type="red_array"),
price_status=dict(obj_type="mapped"),
)
)
"""_"""
__pdoc__[
"orders_shortcut_config"
] = f"""Config of shortcut properties to be attached to `Orders`.
```python
{orders_shortcut_config.prettify()}
```
"""
OrdersT = tp.TypeVar("OrdersT", bound="Orders")
@attach_shortcut_properties(orders_shortcut_config)
@attach_fields(orders_attach_field_config)
@override_field_config(orders_field_config)
class Orders(PriceRecords):
"""Extends `vectorbtpro.generic.price_records.PriceRecords` for working with order records."""
@property
def field_config(self) -> Config:
return self._field_config
# ############# Views ############# #
def get_long_view(
self: OrdersT,
init_position: tp.ArrayLike = 0.0,
init_price: tp.ArrayLike = np.nan,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
) -> OrdersT:
"""See `vectorbtpro.portfolio.nb.records.get_long_view_orders_nb`."""
func = jit_reg.resolve_option(nb.get_long_view_orders_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_records_arr = func(
self.records_arr,
to_2d_array(self.close),
self.col_mapper.col_map,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
)
return self.replace(records_arr=new_records_arr)
def get_short_view(
self: OrdersT,
init_position: tp.ArrayLike = 0.0,
init_price: tp.ArrayLike = np.nan,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
) -> OrdersT:
"""See `vectorbtpro.portfolio.nb.records.get_short_view_orders_nb`."""
func = jit_reg.resolve_option(nb.get_short_view_orders_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_records_arr = func(
self.records_arr,
to_2d_array(self.close),
self.col_mapper.col_map,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
)
return self.replace(records_arr=new_records_arr)
# ############# Stats ############# #
def get_signed_size(self, **kwargs) -> MappedArray:
"""Get signed size."""
size = self.get_field_arr("size").copy()
size[self.get_field_arr("side") == OrderSide.Sell] *= -1
return self.map_array(size, **kwargs)
def get_value(self, **kwargs) -> MappedArray:
"""Get order value (signed size multiplied by price)."""
return self.map_array(self.signed_size.values * self.price.values, **kwargs)
def get_weighted_price(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get size-weighted price average."""
wrap_kwargs = merge_dicts(dict(name_or_index="weighted_price"), wrap_kwargs)
return MappedArray.reduce(
nb.weighted_price_reduce_meta_nb,
self.get_field_arr("size"),
self.get_field_arr("price"),
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
col_mapper=self.col_mapper,
**kwargs,
)
def get_price_status(self, **kwargs) -> MappedArray:
"""See `vectorbtpro.portfolio.nb.records.price_status_nb`."""
return self.apply(
nb.price_status_nb,
self._high,
self._low,
mapping=OrderPriceStatus,
**kwargs,
)
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Orders.stats`.
Merges `vectorbtpro.generic.price_records.PriceRecords.stats_defaults` and
`stats` from `vectorbtpro._settings.orders`."""
from vectorbtpro._settings import settings
orders_stats_cfg = settings["orders"]["stats"]
return merge_dicts(PriceRecords.stats_defaults.__get__(self), orders_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total_records=dict(title="Total Records", calc_func="count", tags="records"),
side_counts=dict(
title="Side Counts",
calc_func="side.value_counts",
incl_all_keys=True,
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags=["orders", "side"],
),
size=dict(
title="Size",
calc_func="size.describe",
post_calc_func=lambda self, out, settings: {
"Min": out.loc["min"],
"Median": out.loc["50%"],
"Max": out.loc["max"],
},
tags=["orders", "size"],
),
fees=dict(
title="Fees",
calc_func="fees.describe",
post_calc_func=lambda self, out, settings: {
"Min": out.loc["min"],
"Median": out.loc["50%"],
"Max": out.loc["max"],
},
tags=["orders", "fees"],
),
weighted_buy_price=dict(
title="Weighted Buy Price",
calc_func="side_buy.get_weighted_price",
tags=["orders", "buy", "price"],
),
weighted_sell_price=dict(
title="Weighted Sell Price",
calc_func="side_sell.get_weighted_price",
tags=["orders", "sell", "price"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_ohlc: bool = True,
plot_close: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
buy_trace_kwargs: tp.KwargsLike = None,
sell_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot orders.
Args:
column (str): Name of the column to plot.
plot_ohlc (bool): Whether to plot OHLC.
plot_close (bool): Whether to plot close.
ohlc_type (str or BaseTraceType): Either 'OHLC', 'Candlestick', or a Plotly trace type.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Orders.close`.
buy_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Buy" markers.
sell_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Sell" markers.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=5)
>>> price = pd.Series([1., 2., 3., 2., 1.], index=index)
>>> size = pd.Series([1., 1., 1., 1., -1.], index=index)
>>> orders = vbt.Portfolio.from_orders(price, size).orders
>>> orders.plot().show()
```
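Marker styling can be changed through the trace keyword arguments, which are merged into
the corresponding `plotly.graph_objects.Scatter` calls; for example (a sketch):
```pycon
>>> orders.plot(
...     buy_trace_kwargs=dict(marker=dict(symbol="circle")),
...     sell_trace_kwargs=dict(marker=dict(symbol="circle-open"))
... ).show()
```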
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if buy_trace_kwargs is None:
buy_trace_kwargs = {}
if sell_trace_kwargs is None:
sell_trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# Plot price
if (
plot_ohlc
and self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc_df = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc_df.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and self_col._close is not None:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
idx = self_col.get_map_field_to_index("idx")
price = self_col.get_field_arr("price")
side = self_col.get_field_arr("side")
buy_mask = side == OrderSide.Buy
if buy_mask.any():
# Plot buy markers
buy_customdata, buy_hovertemplate = self_col.prepare_customdata(mask=buy_mask)
_buy_trace_kwargs = merge_dicts(
dict(
x=idx[buy_mask],
y=price[buy_mask],
mode="markers",
marker=dict(
symbol="triangle-up",
color=plotting_cfg["contrast_color_schema"]["green"],
size=8,
line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["green"])),
),
name="Buy",
customdata=buy_customdata,
hovertemplate=buy_hovertemplate,
),
buy_trace_kwargs,
)
buy_scatter = go.Scatter(**_buy_trace_kwargs)
fig.add_trace(buy_scatter, **add_trace_kwargs)
sell_mask = side == OrderSide.Sell
if sell_mask.any():
# Plot sell markers
sell_customdata, sell_hovertemplate = self_col.prepare_customdata(mask=sell_mask)
_sell_trace_kwargs = merge_dicts(
dict(
x=idx[sell_mask],
y=price[sell_mask],
mode="markers",
marker=dict(
symbol="triangle-down",
color=plotting_cfg["contrast_color_schema"]["red"],
size=8,
line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["red"])),
),
name="Sell",
customdata=sell_customdata,
hovertemplate=sell_hovertemplate,
),
sell_trace_kwargs,
)
sell_scatter = go.Scatter(**_sell_trace_kwargs)
fig.add_trace(sell_scatter, **add_trace_kwargs)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Orders.plots`.
Merges `vectorbtpro.generic.price_records.PriceRecords.plots_defaults` and
`plots` from `vectorbtpro._settings.orders`."""
from vectorbtpro._settings import settings
orders_plots_cfg = settings["orders"]["plots"]
return merge_dicts(PriceRecords.plots_defaults.__get__(self), orders_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="Orders",
yaxis_kwargs=dict(title="Price"),
check_is_not_grouped=True,
plot_func="plot",
tags="orders",
)
)
)
@property
def subplots(self) -> Config:
return self._subplots
Orders.override_field_config_doc(__pdoc__)
Orders.override_metrics_doc(__pdoc__)
Orders.override_subplots_doc(__pdoc__)
fs_orders_field_config = ReadonlyConfig(
dict(
dtype=fs_order_dt,
settings=dict(
idx=dict(title="Fill Index"),
signal_idx=dict(
title="Signal Index",
mapping="index",
noindex=True,
),
creation_idx=dict(
title="Creation Index",
mapping="index",
noindex=True,
),
type=dict(
title="Type",
mapping=OrderType,
),
stop_type=dict(
title="Stop Type",
mapping=StopType,
),
),
)
)
"""_"""
__pdoc__[
"fs_orders_field_config"
] = f"""Field config for `FSOrders`.
```python
{fs_orders_field_config.prettify()}
```
"""
fs_orders_attach_field_config = ReadonlyConfig(
dict(
type=dict(attach_filters=True),
stop_type=dict(attach_filters=True),
)
)
"""_"""
__pdoc__[
"fs_orders_attach_field_config"
] = f"""Config of fields to be attached to `FSOrders`.
```python
{fs_orders_attach_field_config.prettify()}
```
"""
fs_orders_shortcut_config = ReadonlyConfig(
dict(
stop_orders=dict(),
ranges=dict(),
creation_ranges=dict(),
fill_ranges=dict(),
signal_to_creation_duration=dict(obj_type="mapped_array"),
creation_to_fill_duration=dict(obj_type="mapped_array"),
signal_to_fill_duration=dict(obj_type="mapped_array"),
)
)
"""_"""
__pdoc__[
"fs_orders_shortcut_config"
] = f"""Config of shortcut properties to be attached to `FSOrders`.
```python
{fs_orders_shortcut_config.prettify()}
```
"""
FSOrdersT = tp.TypeVar("FSOrdersT", bound="FSOrders")
@attach_shortcut_properties(fs_orders_shortcut_config)
@attach_fields(fs_orders_attach_field_config)
@override_field_config(fs_orders_field_config)
class FSOrders(Orders):
"""Extends `Orders` for working with order records generated from signals."""
@property
def field_config(self) -> Config:
return self._field_config
def get_stop_orders(self, **kwargs):
"""Get stop orders."""
return self.apply_mask(self.get_field_arr("stop_type") != -1, **kwargs)
def get_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for signal-to-fill ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("signal_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("idx").copy()
new_records_arr["status"][:] = RangeStatus.Closed
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
def get_creation_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for signal-to-creation ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("signal_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("creation_idx").copy()
new_records_arr["status"][:] = RangeStatus.Closed
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
def get_fill_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges` for creation-to-fill ranges."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("creation_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("idx").copy()
new_records_arr["status"][:] = RangeStatus.Closed
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
def get_signal_to_creation_duration(self, **kwargs) -> MappedArray:
"""Get duration between signal and creation."""
duration = self.get_field_arr("creation_idx") - self.get_field_arr("signal_idx")
return self.map_array(duration, **kwargs)
def get_creation_to_fill_duration(self, **kwargs) -> MappedArray:
"""Get duration between creation and fill."""
duration = self.get_field_arr("idx") - self.get_field_arr("creation_idx")
return self.map_array(duration, **kwargs)
def get_signal_to_fill_duration(self, **kwargs) -> MappedArray:
"""Get duration between signal and fill."""
duration = self.get_field_arr("idx") - self.get_field_arr("signal_idx")
return self.map_array(duration, **kwargs)
_metrics: tp.ClassVar[Config] = HybridConfig(
start_index=Orders.metrics["start_index"],
end_index=Orders.metrics["end_index"],
total_duration=Orders.metrics["total_duration"],
total_records=Orders.metrics["total_records"],
side_counts=Orders.metrics["side_counts"],
type_counts=dict(
title="Type Counts",
calc_func="type.value_counts",
incl_all_keys=True,
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags=["orders", "type"],
),
stop_type_counts=dict(
title="Stop Type Counts",
calc_func="stop_type.value_counts",
incl_all_keys=True,
post_calc_func=lambda self, out, settings: to_dict(out, orient="index_series"),
tags=["orders", "stop_type"],
),
size=Orders.metrics["size"],
fees=Orders.metrics["fees"],
weighted_buy_price=Orders.metrics["weighted_buy_price"],
weighted_sell_price=Orders.metrics["weighted_sell_price"],
avg_signal_to_creation_duration=dict(
title="Avg Signal-Creation Duration",
calc_func="signal_to_creation_duration.mean",
apply_to_timedelta=True,
tags=["orders", "duration"],
),
avg_creation_to_fill_duration=dict(
title="Avg Creation-Fill Duration",
calc_func="creation_to_fill_duration.mean",
apply_to_timedelta=True,
tags=["orders", "duration"],
),
avg_signal_to_fill_duration=dict(
title="Avg Signal-Fill Duration",
calc_func="signal_to_fill_duration.mean",
apply_to_timedelta=True,
tags=["orders", "duration"],
),
)
@property
def metrics(self) -> Config:
return self._metrics
FSOrders.override_field_config_doc(__pdoc__)
FSOrders.override_metrics_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for preparing portfolio simulations."""
from collections import namedtuple
from functools import cached_property as cachedproperty
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.decorators import override_arg_config, attach_arg_properties
from vectorbtpro.base.preparing import BasePreparer
from vectorbtpro.base.reshaping import to_2d_array, broadcast_array_to, broadcast
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.data.base import OHLCDataMixin, Data
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.sim_range import SimRangeMixin
from vectorbtpro.portfolio import nb, enums
from vectorbtpro.portfolio.call_seq import require_call_seq, build_call_seq
from vectorbtpro.portfolio.orders import FSOrders
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg, register_jitted
from vectorbtpro.signals import nb as signals_nb
from vectorbtpro.utils import checks, chunking as ch
from vectorbtpro.utils.config import Configured, merge_dicts, ReadonlyConfig
from vectorbtpro.utils.mapping import to_field_mapping
from vectorbtpro.utils.template import CustomTemplate, substitute_templates, RepFunc
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"PFPrepResult",
"BasePFPreparer",
"FOPreparer",
"FSPreparer",
"FOFPreparer",
"FDOFPreparer",
]
__pdoc__ = {}
@register_jitted(cache=True)
def valid_price_from_ago_1d_nb(price: tp.Array1d) -> tp.Array1d:
"""Parse from_ago from a valid price."""
from_ago = np.empty(price.shape, dtype=int_)
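# Walk the rows backwards: for each non-NaN price at row i, search backwards for the
# previous non-NaN row and store the distance to it (falling back to row 0 if none is
# found); the first row and rows with a NaN price default to a distance of 1.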
for i in range(price.shape[0] - 1, -1, -1):
if i > 0 and not np.isnan(price[i]):
for j in range(i - 1, -1, -1):
if not np.isnan(price[j]):
break
from_ago[i] = i - j
else:
from_ago[i] = 1
return from_ago
PFPrepResultT = tp.TypeVar("PFPrepResultT", bound="PFPrepResult")
class PFPrepResult(Configured):
"""Result of preparation."""
def __init__(
self,
target_func: tp.Optional[tp.Callable] = None,
target_args: tp.Optional[tp.Kwargs] = None,
pf_args: tp.Optional[tp.Kwargs] = None,
**kwargs,
) -> None:
Configured.__init__(
self,
target_func=target_func,
target_args=target_args,
pf_args=pf_args,
**kwargs,
)
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
"""Target function."""
return self.config["target_func"]
@cachedproperty
def target_args(self) -> tp.Kwargs:
"""Target arguments."""
return self.config["target_args"]
@cachedproperty
def pf_args(self) -> tp.Optional[tp.Kwargs]:
"""Portfolio arguments."""
return self.config["pf_args"]
base_arg_config = ReadonlyConfig(
dict(
data=dict(),
open=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
high=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
low=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
close=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
bm_close=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
cash_earnings=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
init_cash=dict(map_enum_kwargs=dict(enum=enums.InitCashMode, look_for_type=str)),
init_position=dict(),
init_price=dict(),
cash_deposits=dict(),
group_by=dict(),
cash_sharing=dict(),
freq=dict(),
sim_start=dict(),
sim_end=dict(),
call_seq=dict(map_enum_kwargs=dict(enum=enums.CallSeqType, look_for_type=str)),
attach_call_seq=dict(),
keep_inout_flex=dict(),
in_outputs=dict(has_default=False),
)
)
"""_"""
__pdoc__[
"base_arg_config"
] = f"""Argument config for `BasePFPreparer`.
```python
{base_arg_config.prettify()}
```
"""
@attach_arg_properties
@override_arg_config(base_arg_config)
class BasePFPreparer(BasePreparer):
"""Base class for preparing portfolio simulations."""
_settings_path: tp.SettingsPath = "portfolio"
@classmethod
def find_target_func(cls, target_func_name: str) -> tp.Callable:
return getattr(nb, target_func_name)
# ############# Ready arguments ############# #
@cachedproperty
def init_cash_mode(self) -> tp.Optional[int]:
"""Initial cash mode."""
init_cash = self["init_cash"]
if checks.is_int(init_cash) and init_cash in enums.InitCashMode:
return init_cash
return None
@cachedproperty
def group_by(self) -> tp.GroupByLike:
"""Argument `group_by`."""
group_by = self["group_by"]
if group_by is None and self.cash_sharing:
return True
return group_by
@cachedproperty
def auto_call_seq(self) -> bool:
"""Whether automatic call sequence is enabled."""
call_seq = self["call_seq"]
return checks.is_int(call_seq) and call_seq == enums.CallSeqType.Auto
@classmethod
def parse_data(
cls,
data: tp.Union[None, OHLCDataMixin, str, tp.ArrayLike],
all_ohlc: bool = False,
) -> tp.Optional[OHLCDataMixin]:
"""Parse an instance with OHLC features."""
if data is None:
return None
if isinstance(data, OHLCDataMixin):
return data
if isinstance(data, str):
return Data.from_data_str(data)
if isinstance(data, pd.DataFrame):
ohlcv_acc = data.vbt.ohlcv
if all_ohlc and ohlcv_acc.has_ohlc:
return ohlcv_acc
if not all_ohlc and ohlcv_acc.has_any_ohlc:
return ohlcv_acc
return None
@cachedproperty
def data(self) -> tp.Optional[OHLCDataMixin]:
"""Argument `data`."""
return self.parse_data(self["data"])
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_open(self) -> tp.ArrayLike:
"""Argument `open` before broadcasting."""
open = self["open"]
if open is None:
if self.data is not None:
open = self.data.open
if open is None:
return np.nan
return open
@cachedproperty
def _pre_high(self) -> tp.ArrayLike:
"""Argument `high` before broadcasting."""
high = self["high"]
if high is None:
if self.data is not None:
high = self.data.high
if high is None:
return np.nan
return high
@cachedproperty
def _pre_low(self) -> tp.ArrayLike:
"""Argument `low` before broadcasting."""
low = self["low"]
if low is None:
if self.data is not None:
low = self.data.low
if low is None:
return np.nan
return low
@cachedproperty
def _pre_close(self) -> tp.ArrayLike:
"""Argument `close` before broadcasting."""
close = self["close"]
if close is None:
if self.data is not None:
close = self.data.close
if close is None:
return np.nan
return close
@cachedproperty
def _pre_bm_close(self) -> tp.Optional[tp.ArrayLike]:
"""Argument `bm_close` before broadcasting."""
bm_close = self["bm_close"]
if bm_close is not None and not isinstance(bm_close, bool):
return bm_close
return np.nan
@cachedproperty
def _pre_init_cash(self) -> tp.ArrayLike:
"""Argument `init_cash` before broadcasting."""
if self.init_cash_mode is not None:
return np.inf
return self["init_cash"]
@cachedproperty
def _pre_init_position(self) -> tp.ArrayLike:
"""Argument `init_position` before broadcasting."""
return self["init_position"]
@cachedproperty
def _pre_init_price(self) -> tp.ArrayLike:
"""Argument `init_price` before broadcasting."""
return self["init_price"]
@cachedproperty
def _pre_cash_deposits(self) -> tp.ArrayLike:
"""Argument `cash_deposits` before broadcasting."""
return self["cash_deposits"]
@cachedproperty
def _pre_freq(self) -> tp.Optional[tp.FrequencyLike]:
"""Argument `freq` before casting to nanosecond format."""
freq = self["freq"]
if freq is None and self.data is not None:
return self.data.symbol_wrapper.freq
return freq
@cachedproperty
def _pre_call_seq(self) -> tp.Optional[tp.ArrayLike]:
"""Argument `call_seq` before broadcasting."""
if self.auto_call_seq:
return None
return self["call_seq"]
@cachedproperty
def _pre_in_outputs(self) -> tp.Union[None, tp.NamedTuple, CustomTemplate]:
"""Argument `in_outputs` before broadcasting."""
in_outputs = self["in_outputs"]
if (
in_outputs is not None
and not isinstance(in_outputs, CustomTemplate)
and not checks.is_namedtuple(in_outputs)
):
in_outputs = to_field_mapping(in_outputs)
in_outputs = namedtuple("InOutputs", in_outputs)(**in_outputs)
return in_outputs
# ############# After broadcasting ############# #
@cachedproperty
def cs_group_lens(self) -> tp.GroupLens:
"""Cash sharing aware group lengths."""
cs_group_lens = self.wrapper.grouper.get_group_lens(group_by=None if self.cash_sharing else False)
checks.assert_subdtype(cs_group_lens, np.integer, arg_name="cs_group_lens")
return cs_group_lens
@cachedproperty
def group_lens(self) -> tp.GroupLens:
"""Group lengths."""
return self.wrapper.grouper.get_group_lens(group_by=self.group_by)
@cachedproperty
def sim_group_lens(self) -> tp.GroupLens:
"""Simulation group lengths."""
return self.group_lens
def align_pc_arr(
self,
arr: tp.ArrayLike,
group_lens: tp.Optional[tp.GroupLens] = None,
check_dtype: tp.Optional[tp.DTypeLike] = None,
cast_to_dtype: tp.Optional[tp.DTypeLike] = None,
reduce_func: tp.Union[None, str, tp.Callable] = None,
arg_name: tp.Optional[str] = None,
) -> tp.Array1d:
"""Align a per-column array."""
arr = np.asarray(arr)
if check_dtype is not None:
checks.assert_subdtype(arr, check_dtype, arg_name=arg_name)
if cast_to_dtype is not None:
arr = np.require(arr, dtype=cast_to_dtype)
if arr.size > 1 and group_lens is not None and reduce_func is not None:
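# First expand a per-group array onto columns if it matches the current grouping,
# then collapse per-column values into one value per target group using reduce_func.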
if len(self.group_lens) == len(arr) != len(group_lens) == len(self.wrapper.columns):
new_arr = np.empty(len(self.wrapper.columns), dtype=int_)
col_generator = self.wrapper.grouper.iter_group_idxs()
for i, cols in enumerate(col_generator):
new_arr[cols] = arr[i]
arr = new_arr
if len(self.wrapper.columns) == len(arr) != len(group_lens):
new_arr = np.empty(len(group_lens), dtype=int_)
col_generator = self.wrapper.grouper.iter_group_lens(group_lens)
for i, cols in enumerate(col_generator):
if isinstance(reduce_func, str):
new_arr[i] = getattr(arr[cols], reduce_func)()
else:
new_arr[i] = reduce_func(arr[cols])
arr = new_arr
if group_lens is not None:
return broadcast_array_to(arr, len(group_lens))
return broadcast_array_to(arr, len(self.wrapper.columns))
@cachedproperty
def init_cash(self) -> tp.Array1d:
"""Argument `init_cash`."""
return self.align_pc_arr(
self._pre_init_cash,
group_lens=self.cs_group_lens,
check_dtype=np.number,
cast_to_dtype=float_,
reduce_func="sum",
arg_name="init_cash",
)
@cachedproperty
def init_position(self) -> tp.Array1d:
"""Argument `init_position`."""
init_position = self.align_pc_arr(
self._pre_init_position,
check_dtype=np.number,
cast_to_dtype=float_,
arg_name="init_position",
)
if (((init_position > 0) | (init_position < 0)) & np.isnan(self.init_price)).any():
warn(f"Initial position has undefined price. Set init_price.")
return init_position
@cachedproperty
def init_price(self) -> tp.Array1d:
"""Argument `init_price`."""
return self.align_pc_arr(
self._pre_init_price,
check_dtype=np.number,
cast_to_dtype=float_,
arg_name="init_price",
)
@cachedproperty
def cash_deposits(self) -> tp.ArrayLike:
"""Argument `cash_deposits`."""
cash_deposits = self["cash_deposits"]
checks.assert_subdtype(cash_deposits, np.number, arg_name="cash_deposits")
return broadcast(
cash_deposits,
to_shape=(self.target_shape[0], len(self.cs_group_lens)),
to_pd=False,
keep_flex=self.keep_inout_flex,
reindex_kwargs=dict(fill_value=0.0),
require_kwargs=self.broadcast_kwargs.get("require_kwargs", {}),
)
@cachedproperty
def auto_sim_start(self) -> tp.Optional[tp.ArrayLike]:
"""Get automatic `sim_start`"""
return None
@cachedproperty
def auto_sim_end(self) -> tp.Optional[tp.ArrayLike]:
"""Get automatic `sim_end`"""
return None
@cachedproperty
def sim_start(self) -> tp.Optional[tp.ArrayLike]:
"""Argument `sim_start`."""
sim_start = self["sim_start"]
if sim_start is None:
return None
if isinstance(sim_start, str) and sim_start.lower() == "auto":
sim_start = self.auto_sim_start
if sim_start is None:
return None
sim_start_arr = np.asarray(sim_start)
if np.issubdtype(sim_start_arr.dtype, np.integer):
if sim_start_arr.ndim == 0:
return sim_start
new_sim_start = sim_start_arr
else:
if sim_start_arr.ndim == 0:
return SimRangeMixin.resolve_sim_start_value(sim_start, wrapper=self.wrapper)
new_sim_start = np.empty(len(sim_start), dtype=int_)
for i in range(len(sim_start)):
new_sim_start[i] = SimRangeMixin.resolve_sim_start_value(sim_start[i], wrapper=self.wrapper)
return self.align_pc_arr(
new_sim_start,
group_lens=self.sim_group_lens,
check_dtype=np.integer,
cast_to_dtype=int_,
reduce_func="min",
arg_name="sim_start",
)
@cachedproperty
def sim_end(self) -> tp.Optional[tp.ArrayLike]:
"""Argument `sim_end`."""
sim_end = self["sim_end"]
if sim_end is None:
return None
if isinstance(sim_end, str) and sim_end.lower() == "auto":
sim_end = self.auto_sim_end
if sim_end is None:
return None
sim_end_arr = np.asarray(sim_end)
if np.issubdtype(sim_end_arr.dtype, np.integer):
if sim_end_arr.ndim == 0:
return sim_end
new_sim_end = sim_end_arr
else:
if sim_end_arr.ndim == 0:
return SimRangeMixin.resolve_sim_end_value(sim_end, wrapper=self.wrapper)
new_sim_end = np.empty(len(sim_end), dtype=int_)
for i in range(len(sim_end)):
new_sim_end[i] = SimRangeMixin.resolve_sim_end_value(sim_end[i], wrapper=self.wrapper)
return self.align_pc_arr(
new_sim_end,
group_lens=self.sim_group_lens,
check_dtype=np.integer,
cast_to_dtype=int_,
reduce_func="max",
arg_name="sim_end",
)
@cachedproperty
def call_seq(self) -> tp.Optional[tp.ArrayLike]:
"""Argument `call_seq`."""
call_seq = self._pre_call_seq
if call_seq is None and self.attach_call_seq:
call_seq = enums.CallSeqType.Default
if call_seq is not None:
if checks.is_any_array(call_seq):
call_seq = require_call_seq(broadcast(call_seq, to_shape=self.target_shape, to_pd=False))
else:
call_seq = build_call_seq(self.target_shape, self.group_lens, call_seq_type=call_seq)
if call_seq is not None:
checks.assert_subdtype(call_seq, np.integer, arg_name="call_seq")
return call_seq
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
return merge_dicts(
dict(
group_lens=self.group_lens,
cs_group_lens=self.cs_group_lens,
cash_sharing=self.cash_sharing,
init_cash=self.init_cash,
init_position=self.init_position,
init_price=self.init_price,
cash_deposits=self.cash_deposits,
sim_start=self.sim_start,
sim_end=self.sim_end,
call_seq=self.call_seq,
auto_call_seq=self.auto_call_seq,
attach_call_seq=self.attach_call_seq,
in_outputs=self._pre_in_outputs,
),
BasePreparer.template_context.func(self),
)
@cachedproperty
def in_outputs(self) -> tp.Optional[tp.NamedTuple]:
"""Argument `in_outputs`."""
return substitute_templates(self._pre_in_outputs, self.template_context, eval_id="in_outputs")
# ############# Result ############# #
@cachedproperty
def pf_args(self) -> tp.Optional[tp.Kwargs]:
"""Arguments to be passed to the portfolio."""
kwargs = dict()
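# Forward any config keys that are not part of the argument config to the portfolio as-is.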
for k, v in self.config.items():
if k not in self.arg_config and k != "arg_config":
kwargs[k] = v
return dict(
wrapper=self.wrapper,
open=self.open if self._pre_open is not np.nan else None,
high=self.high if self._pre_high is not np.nan else None,
low=self.low if self._pre_low is not np.nan else None,
close=self.close,
cash_sharing=self.cash_sharing,
init_cash=self.init_cash if self.init_cash_mode is None else self.init_cash_mode,
init_position=self.init_position,
init_price=self.init_price,
bm_close=(
self.bm_close
if (self["bm_close"] is not None and not isinstance(self["bm_close"], bool))
else self["bm_close"]
),
**kwargs,
)
@cachedproperty
def result(self) -> PFPrepResult:
"""Result as an instance of `PFPrepResult`."""
return PFPrepResult(target_func=self.target_func, target_args=self.target_args, pf_args=self.pf_args)
BasePFPreparer.override_arg_config_doc(__pdoc__)
order_arg_config = ReadonlyConfig(
dict(
size=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
fill_default=False,
),
price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PriceType, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PriceType.Close)),
),
size_type=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.SizeType),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.SizeType.Amount)),
),
direction=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.Direction),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.Direction.Both)),
),
fees=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
fixed_fees=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
slippage=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
min_size=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
max_size=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
size_granularity=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
leverage=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=1.0)),
),
leverage_mode=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.LeverageMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.LeverageMode.Lazy)),
),
reject_prob=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
price_area_vio_mode=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PriceAreaVioMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PriceAreaVioMode.Ignore)),
),
allow_partial=dict(
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=True)),
),
raise_reject=dict(
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
log=dict(
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
)
)
"""_"""
__pdoc__[
"order_arg_config"
] = f"""Argument config for order-related information.
```python
{order_arg_config.prettify()}
```
"""
fo_arg_config = ReadonlyConfig(
dict(
cash_dividends=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
val_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.ValPriceType, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
from_ago=dict(
broadcast=True,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0)),
),
ffill_val_price=dict(),
update_value=dict(),
save_state=dict(),
save_value=dict(),
save_returns=dict(),
skip_empty=dict(),
max_order_records=dict(),
max_log_records=dict(),
)
)
"""_"""
__pdoc__[
"fo_arg_config"
] = f"""Argument config for `FOPreparer`.
```python
{fo_arg_config.prettify()}
```
"""
@attach_arg_properties
@override_arg_config(fo_arg_config)
@override_arg_config(order_arg_config)
class FOPreparer(BasePFPreparer):
"""Class for preparing `vectorbtpro.portfolio.base.Portfolio.from_orders`."""
_settings_path: tp.SettingsPath = "portfolio.from_orders"
# ############# Ready arguments ############# #
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
"""Argument `staticized`."""
raise ValueError("This method doesn't support staticization")
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_from_ago(self) -> tp.ArrayLike:
"""Argument `from_ago` before broadcasting."""
from_ago = self["from_ago"]
if from_ago is not None:
return from_ago
return 0
@cachedproperty
def _pre_max_order_records(self) -> tp.Optional[int]:
"""Argument `max_order_records` before broadcasting."""
return self["max_order_records"]
@cachedproperty
def _pre_max_log_records(self) -> tp.Optional[int]:
"""Argument `max_log_records` before broadcasting."""
return self["max_log_records"]
# ############# After broadcasting ############# #
@cachedproperty
def sim_group_lens(self) -> tp.GroupLens:
return self.cs_group_lens
@cachedproperty
def auto_sim_start(self) -> tp.Optional[tp.ArrayLike]:
size = to_2d_array(self.size)
if size.shape[0] == 1:
return None
first_valid_idx = generic_nb.first_valid_index_nb(size, check_inf=False)
first_valid_idx = np.where(first_valid_idx == -1, 0, first_valid_idx)
return first_valid_idx
@cachedproperty
def auto_sim_end(self) -> tp.Optional[tp.ArrayLike]:
size = to_2d_array(self.size)
if size.shape[0] == 1:
return None
last_valid_idx = generic_nb.last_valid_index_nb(size, check_inf=False)
last_valid_idx = np.where(last_valid_idx == -1, len(self.wrapper.index), last_valid_idx + 1)
return last_valid_idx
@cachedproperty
def price_and_from_ago(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike]:
"""Arguments `price` and `from_ago` after broadcasting."""
price = self._post_price
from_ago = self._post_from_ago
if self["from_ago"] is None:
if price.size == 1 or price.shape[0] == 1:
next_open_mask = price == enums.PriceType.NextOpen
next_close_mask = price == enums.PriceType.NextClose
next_valid_open_mask = price == enums.PriceType.NextValidOpen
next_valid_close_mask = price == enums.PriceType.NextValidClose
if next_valid_open_mask.any() or next_valid_close_mask.any():
new_price = np.empty(self.wrapper.shape_2d, float_)
new_from_ago = np.empty(self.wrapper.shape_2d, int_)
if next_valid_open_mask.any():
open = broadcast_array_to(self.open, self.wrapper.shape_2d)
if next_valid_close_mask.any():
close = broadcast_array_to(self.close, self.wrapper.shape_2d)
for i in range(price.size):
price_item = price.item(i)
if price_item == enums.PriceType.NextOpen:
new_price[:, i] = enums.PriceType.Open
new_from_ago[:, i] = 1
elif price_item == enums.PriceType.NextClose:
new_price[:, i] = enums.PriceType.Close
new_from_ago[:, i] = 1
elif price_item == enums.PriceType.NextValidOpen:
new_price[:, i] = enums.PriceType.Open
new_from_ago[:, i] = valid_price_from_ago_1d_nb(open[:, i])
elif price_item == enums.PriceType.NextValidClose:
new_price[:, i] = enums.PriceType.Close
new_from_ago[:, i] = valid_price_from_ago_1d_nb(close[:, i])
price = new_price
from_ago = new_from_ago
elif next_open_mask.any() or next_close_mask.any():
price = price.astype(float_)
price[next_open_mask] = enums.PriceType.Open
price[next_close_mask] = enums.PriceType.Close
from_ago = np.full(price.shape, 0, dtype=int_)
from_ago[next_open_mask] = 1
from_ago[next_close_mask] = 1
return price, from_ago
@cachedproperty
def price(self) -> tp.ArrayLike:
"""Argument `price`."""
return self.price_and_from_ago[0]
@cachedproperty
def from_ago(self) -> tp.ArrayLike:
"""Argument `from_ago`."""
return self.price_and_from_ago[1]
@cachedproperty
def max_order_records(self) -> tp.Optional[int]:
"""Argument `max_order_records`."""
max_order_records = self._pre_max_order_records
if max_order_records is None:
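# Estimate an upper bound on order records per column from the number of valid (non-NaN) size values.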
_size = self._post_size
if _size.size == 1:
max_order_records = self.target_shape[0] * int(not np.isnan(_size.item(0)))
else:
if _size.shape[0] == 1 and self.target_shape[0] > 1:
max_order_records = self.target_shape[0] * int(np.any(~np.isnan(_size)))
else:
max_order_records = int(np.max(np.sum(~np.isnan(_size), axis=0)))
return max_order_records
@cachedproperty
def max_log_records(self) -> tp.Optional[int]:
"""Argument `max_log_records`."""
max_log_records = self._pre_max_log_records
if max_log_records is None:
_log = self._post_log
if _log.size == 1:
max_log_records = self.target_shape[0] * int(_log.item(0))
else:
if _log.shape[0] == 1 and self.target_shape[0] > 1:
max_log_records = self.target_shape[0] * int(np.any(_log))
else:
max_log_records = int(np.max(np.sum(_log, axis=0)))
return max_log_records
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
return merge_dicts(
dict(
group_lens=self.group_lens if self.dynamic_mode else self.cs_group_lens,
ffill_val_price=self.ffill_val_price,
update_value=self.update_value,
save_state=self.save_state,
save_value=self.save_value,
save_returns=self.save_returns,
max_order_records=self.max_order_records,
max_log_records=self.max_log_records,
),
BasePFPreparer.template_context.func(self),
)
# ############# Result ############# #
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
func = jit_reg.resolve_option(nb.from_orders_nb, self.jitted)
func = ch_reg.resolve_option(func, self.chunked)
return func
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
target_arg_map = dict(BasePFPreparer.target_arg_map.func(self))
target_arg_map["group_lens"] = "cs_group_lens"
return target_arg_map
FOPreparer.override_arg_config_doc(__pdoc__)
fs_arg_config = ReadonlyConfig(
dict(
size=dict(
fill_default=True,
),
cash_dividends=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0.0)),
),
entries=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
exits=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
long_entries=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
long_exits=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
short_entries=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
short_exits=dict(
has_default=False,
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
adjust_func_nb=dict(),
adjust_args=dict(type="args", substitute_templates=True),
signal_func_nb=dict(),
signal_args=dict(type="args", substitute_templates=True),
post_signal_func_nb=dict(),
post_signal_args=dict(type="args", substitute_templates=True),
post_segment_func_nb=dict(),
post_segment_args=dict(type="args", substitute_templates=True),
order_mode=dict(),
val_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.ValPriceType, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
accumulate=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.AccumulationMode, ignore_type=(int, bool)),
subdtype=(np.integer, np.bool_),
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.AccumulationMode.Disabled)),
),
upon_long_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.ConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.ConflictMode.Ignore)),
),
upon_short_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.ConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.ConflictMode.Ignore)),
),
upon_dir_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.DirectionConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.DirectionConflictMode.Ignore)),
),
upon_opposite_entry=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.OppositeEntryMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OppositeEntryMode.ReverseReduce)),
),
order_type=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.OrderType),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OrderType.Market)),
),
limit_delta=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
limit_tif=dict(
broadcast=True,
is_td=True,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)),
),
limit_expiry=dict(
broadcast=True,
is_dt=True,
last_before=False,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)),
),
limit_reverse=dict(
broadcast=True,
subdtype=np.bool_,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=False)),
),
limit_order_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.LimitOrderPrice, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.LimitOrderPrice.Limit)),
),
upon_adj_limit_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PendingConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepIgnore)),
),
upon_opp_limit_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PendingConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.CancelExecute)),
),
use_stops=dict(),
stop_ladder=dict(map_enum_kwargs=dict(enum=enums.StopLadderMode, look_for_type=str)),
sl_stop=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
tsl_stop=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
tsl_th=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
tp_stop=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
td_stop=dict(
broadcast=True,
is_td=True,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)),
),
dt_stop=dict(
broadcast=True,
is_dt=True,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=-1)),
),
stop_entry_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.StopEntryPrice, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopEntryPrice.Close)),
),
stop_exit_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.StopExitPrice, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopExitPrice.Stop)),
),
stop_exit_type=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.StopExitType),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopExitType.Close)),
),
stop_order_type=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.OrderType),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.OrderType.Market)),
),
stop_limit_delta=dict(
broadcast=True,
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
upon_stop_update=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.StopUpdateMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.StopUpdateMode.Override)),
),
upon_adj_stop_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PendingConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepExecute)),
),
upon_opp_stop_conflict=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.PendingConflictMode),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.PendingConflictMode.KeepExecute)),
),
delta_format=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.DeltaFormat),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.DeltaFormat.Percent)),
),
time_delta_format=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.TimeDeltaFormat),
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=enums.TimeDeltaFormat.Index)),
),
from_ago=dict(
broadcast=True,
subdtype=np.integer,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=0)),
),
ffill_val_price=dict(),
update_value=dict(),
fill_pos_info=dict(),
save_state=dict(),
save_value=dict(),
save_returns=dict(),
skip_empty=dict(),
max_order_records=dict(),
max_log_records=dict(),
records=dict(
rename_fields=dict(
entry="entries",
exit="exits",
long_entry="long_entries",
long_exit="long_exits",
short_entry="short_entries",
short_exit="short_exits",
)
),
)
)
"""_"""
__pdoc__[
"fs_arg_config"
] = f"""Argument config for `FSPreparer`.
```python
{fs_arg_config.prettify()}
```
"""
@attach_arg_properties
@override_arg_config(fs_arg_config)
@override_arg_config(order_arg_config)
class FSPreparer(BasePFPreparer):
"""Class for preparing `vectorbtpro.portfolio.base.Portfolio.from_signals`."""
_settings_path: tp.SettingsPath = "portfolio.from_signals"
# ############# Mode resolution ############# #
@cachedproperty
def _pre_staticized(self) -> tp.StaticizedOption:
"""Argument `staticized` before its resolution."""
staticized = self["staticized"]
if isinstance(staticized, bool):
if staticized:
staticized = dict()
else:
staticized = None
if isinstance(staticized, dict):
staticized = dict(staticized)
if "func" not in staticized:
staticized["func"] = nb.from_signal_func_nb
return staticized
@cachedproperty
def order_mode(self) -> bool:
"""Argument `order_mode`."""
order_mode = self["order_mode"]
if order_mode is None:
order_mode = False
return order_mode
@cachedproperty
def dynamic_mode(self) -> bool:
"""Whether the dynamic mode is enabled."""
return (
self["adjust_func_nb"] is not None
or self["signal_func_nb"] is not None
or self["post_signal_func_nb"] is not None
or self["post_segment_func_nb"] is not None
or self.order_mode
or self._pre_staticized is not None
)
@cachedproperty
def implicit_mode(self) -> bool:
"""Whether the explicit mode is enabled."""
return self["entries"] is not None or self["exits"] is not None
@cachedproperty
def explicit_mode(self) -> bool:
"""Whether the explicit mode is enabled."""
return self["long_entries"] is not None or self["long_exits"] is not None
@cachedproperty
def _pre_ls_mode(self) -> bool:
"""Whether direction-aware mode is enabled before resolution."""
return self.explicit_mode or self["short_entries"] is not None or self["short_exits"] is not None
@cachedproperty
def _pre_signals_mode(self) -> bool:
"""Whether signals mode is enabled before resolution."""
return self.implicit_mode or self._pre_ls_mode
@cachedproperty
def ls_mode(self) -> bool:
"""Whether direction-aware mode is enabled."""
if not self._pre_signals_mode and not self.order_mode and self["signal_func_nb"] is None:
return True
ls_mode = self._pre_ls_mode
if self.config.get("direction", None) is not None and ls_mode:
raise ValueError("Direction and short signal arrays cannot be used together")
return ls_mode
@cachedproperty
def signals_mode(self) -> bool:
"""Whether signals mode is enabled."""
if not self._pre_signals_mode and not self.order_mode and self["signal_func_nb"] is None:
return True
signals_mode = self._pre_signals_mode
if signals_mode and self.order_mode:
raise ValueError("Signal arrays and order mode cannot be used together")
return signals_mode
@cachedproperty
def signal_func_mode(self) -> bool:
"""Whether signal function mode is enabled."""
return self.dynamic_mode and not self.signals_mode and not self.order_mode
@cachedproperty
def adjust_func_nb(self) -> tp.Optional[tp.Callable]:
"""Argument `adjust_func_nb`."""
if self.dynamic_mode:
if self["adjust_func_nb"] is None:
return nb.no_adjust_func_nb
return self["adjust_func_nb"]
return None
@cachedproperty
def signal_func_nb(self) -> tp.Optional[tp.Callable]:
"""Argument `signal_func_nb`."""
if self.dynamic_mode:
if self["signal_func_nb"] is None:
if self.ls_mode:
return nb.ls_signal_func_nb
if self.signals_mode:
return nb.dir_signal_func_nb
if self.order_mode:
return nb.order_signal_func_nb
return None
return self["signal_func_nb"]
return None
@cachedproperty
def post_signal_func_nb(self) -> tp.Optional[tp.Callable]:
"""Argument `post_signal_func_nb`."""
if self.dynamic_mode:
if self["post_signal_func_nb"] is None:
return nb.no_post_func_nb
return self["post_signal_func_nb"]
return None
@cachedproperty
def post_segment_func_nb(self) -> tp.Optional[tp.Callable]:
"""Argument `post_segment_func_nb`."""
if self.dynamic_mode:
if self["post_segment_func_nb"] is None:
if self.save_state or self.save_value or self.save_returns:
return nb.save_post_segment_func_nb
return nb.no_post_func_nb
return self["post_segment_func_nb"]
return None
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
"""Argument `staticized`."""
staticized = self._pre_staticized
if isinstance(staticized, dict):
staticized = dict(staticized)
if self.dynamic_mode:
if self["signal_func_nb"] is None:
if self.ls_mode:
self.adapt_staticized_to_udf(staticized, "ls_signal_func_nb", "signal_func_nb")
staticized["suggest_fname"] = "from_ls_signal_func_nb"
elif self.signals_mode:
self.adapt_staticized_to_udf(staticized, "dir_signal_func_nb", "signal_func_nb")
staticized["suggest_fname"] = "from_dir_signal_func_nb"
elif self.order_mode:
self.adapt_staticized_to_udf(staticized, "order_signal_func_nb", "signal_func_nb")
staticized["suggest_fname"] = "from_order_signal_func_nb"
else:
self.adapt_staticized_to_udf(staticized, self["signal_func_nb"], "signal_func_nb")
if self["adjust_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["adjust_func_nb"], "adjust_func_nb")
if self["post_signal_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_signal_func_nb"], "post_signal_func_nb")
if self["post_segment_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_segment_func_nb"], "post_segment_func_nb")
elif self.save_state or self.save_value or self.save_returns:
self.adapt_staticized_to_udf(staticized, "save_post_segment_func_nb", "post_segment_func_nb")
return staticized
@cachedproperty
def _pre_chunked(self) -> tp.ChunkedOption:
"""Argument `chunked` before template substitution."""
return self["chunked"]
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_entries(self) -> tp.ArrayLike:
"""Argument `entries` before broadcasting."""
return self["entries"] if self["entries"] is not None else False
@cachedproperty
def _pre_exits(self) -> tp.ArrayLike:
"""Argument `exits` before broadcasting."""
return self["exits"] if self["exits"] is not None else False
@cachedproperty
def _pre_long_entries(self) -> tp.ArrayLike:
"""Argument `long_entries` before broadcasting."""
return self["long_entries"] if self["long_entries"] is not None else False
@cachedproperty
def _pre_long_exits(self) -> tp.ArrayLike:
"""Argument `long_exits` before broadcasting."""
return self["long_exits"] if self["long_exits"] is not None else False
@cachedproperty
def _pre_short_entries(self) -> tp.ArrayLike:
"""Argument `short_entries` before broadcasting."""
return self["short_entries"] if self["short_entries"] is not None else False
@cachedproperty
def _pre_short_exits(self) -> tp.ArrayLike:
"""Argument `short_exits` before broadcasting."""
return self["short_exits"] if self["short_exits"] is not None else False
@cachedproperty
def _pre_from_ago(self) -> tp.ArrayLike:
"""Argument `from_ago` before broadcasting."""
from_ago = self["from_ago"]
if from_ago is not None:
return from_ago
return 0
@cachedproperty
def _pre_max_log_records(self) -> tp.Optional[int]:
"""Argument `max_log_records` before broadcasting."""
return self["max_log_records"]
@classmethod
def init_in_outputs(
cls,
wrapper: ArrayWrapper,
group_lens: tp.Optional[tp.GroupLens] = None,
cash_sharing: bool = False,
save_state: bool = True,
save_value: bool = True,
save_returns: bool = True,
) -> enums.FSInOutputs:
"""Initialize `vectorbtpro.portfolio.enums.FSInOutputs`."""
if cash_sharing:
if group_lens is None:
group_lens = wrapper.grouper.get_group_lens()
return nb.init_FSInOutputs_nb(
wrapper.shape_2d,
group_lens,
cash_sharing=cash_sharing,
save_state=save_state,
save_value=save_value,
save_returns=save_returns,
)
@cachedproperty
def _pre_in_outputs(self) -> tp.Union[None, tp.NamedTuple, CustomTemplate]:
if self.dynamic_mode:
if self["post_segment_func_nb"] is None:
if self.save_state or self.save_value or self.save_returns:
return RepFunc(self.init_in_outputs)
return BasePFPreparer._pre_in_outputs.func(self)
if self["in_outputs"] is not None:
raise ValueError("Argument in_outputs cannot be used in fixed mode")
return None
# ############# Broadcasting ############# #
@cachedproperty
def def_broadcast_kwargs(self) -> tp.Kwargs:
def_broadcast_kwargs = dict(BasePFPreparer.def_broadcast_kwargs.func(self))
new_def_broadcast_kwargs = dict()
if self.order_mode:
new_def_broadcast_kwargs["keep_flex"] = dict(
size=False,
size_type=False,
min_size=False,
max_size=False,
)
new_def_broadcast_kwargs["min_ndim"] = dict(
size=2,
size_type=2,
min_size=2,
max_size=2,
)
new_def_broadcast_kwargs["require_kwargs"] = dict(
size=dict(requirements="O"),
size_type=dict(requirements="O"),
min_size=dict(requirements="O"),
max_size=dict(requirements="O"),
)
if self.stop_ladder:
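# With a stop ladder, stop arrays are broadcast along the column axis only; rows act as ladder steps and are aligned from the start, padded with the fill value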
new_def_broadcast_kwargs["axis"] = dict(
sl_stop=1,
tsl_stop=1,
tp_stop=1,
td_stop=1,
dt_stop=1,
)
new_def_broadcast_kwargs["merge_kwargs"] = dict(
sl_stop=dict(reset_index="from_start", fill_value=np.nan),
tsl_stop=dict(reset_index="from_start", fill_value=np.nan),
tp_stop=dict(reset_index="from_start", fill_value=np.nan),
td_stop=dict(reset_index="from_start", fill_value=-1),
dt_stop=dict(reset_index="from_start", fill_value=-1),
)
return merge_dicts(def_broadcast_kwargs, new_def_broadcast_kwargs)
# ############# After broadcasting ############# #
@cachedproperty
def sim_group_lens(self) -> tp.GroupLens:
if not self.dynamic_mode:
return self.cs_group_lens
return self.group_lens
@cachedproperty
def signals(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike, tp.ArrayLike, tp.ArrayLike]:
"""Arguments `entries`, `exits`, `short_entries`, and `short_exits` after broadcasting."""
if not self.dynamic_mode and not self.ls_mode:
entries = self._post_entries
exits = self._post_exits
direction = self._post_direction
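# When direction broadcasts to a single value, map entries/exits to long/short arrays directly
# (with Direction.Both, entries open long and exits open short positions, i.e. reversals)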
if direction.size == 1:
_direction = direction.item(0)
if _direction == enums.Direction.LongOnly:
long_entries = entries
long_exits = exits
short_entries = np.array([[False]])
short_exits = np.array([[False]])
elif _direction == enums.Direction.ShortOnly:
long_entries = np.array([[False]])
long_exits = np.array([[False]])
short_entries = entries
short_exits = exits
else:
long_entries = entries
long_exits = np.array([[False]])
short_entries = exits
short_exits = np.array([[False]])
else:
return nb.dir_to_ls_signals_nb(
target_shape=self.target_shape,
entries=entries,
exits=exits,
direction=direction,
)
else:
if self.explicit_mode and self.implicit_mode:
long_entries = self._post_entries | self._post_long_entries
long_exits = self._post_exits | self._post_long_exits
short_entries = self._post_entries | self._post_short_entries
short_exits = self._post_exits | self._post_short_exits
elif self.explicit_mode:
long_entries = self._post_long_entries
long_exits = self._post_long_exits
short_entries = self._post_short_entries
short_exits = self._post_short_exits
else:
long_entries = self._post_entries
long_exits = self._post_exits
short_entries = self._post_short_entries
short_exits = self._post_short_exits
return long_entries, long_exits, short_entries, short_exits
@cachedproperty
def long_entries(self) -> tp.ArrayLike:
"""Argument `long_entries`."""
return self.signals[0]
@cachedproperty
def long_exits(self) -> tp.ArrayLike:
"""Argument `long_exits`."""
return self.signals[1]
@cachedproperty
def short_entries(self) -> tp.ArrayLike:
"""Argument `short_entries`."""
return self.signals[2]
@cachedproperty
def short_exits(self) -> tp.ArrayLike:
"""Argument `short_exits`."""
return self.signals[3]
@cachedproperty
def combined_mask(self) -> tp.Array2d:
"""Signals combined using the OR rule into a mask."""
long_entries = to_2d_array(self.long_entries)
long_exits = to_2d_array(self.long_exits)
short_entries = to_2d_array(self.short_entries)
short_exits = to_2d_array(self.short_exits)
return long_entries | long_exits | short_entries | short_exits
@cachedproperty
def auto_sim_start(self) -> tp.Optional[tp.ArrayLike]:
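# Start the simulation at the first signal of each column; columns without signals start at 0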
if self.combined_mask.shape[0] == 1:
return None
first_signal_idx = signals_nb.nth_index_nb(self.combined_mask, 0)
return np.where(first_signal_idx == -1, 0, first_signal_idx)
@cachedproperty
def auto_sim_end(self) -> tp.Optional[tp.ArrayLike]:
if self.combined_mask.shape[0] == 1:
return None
last_signal_idx = signals_nb.nth_index_nb(self.combined_mask, -1)
return np.where(last_signal_idx == -1, len(self.wrapper.index), last_signal_idx + 1)
@cachedproperty
def price_and_from_ago(self) -> tp.Tuple[tp.ArrayLike, tp.ArrayLike]:
"""Arguments `price` and `from_ago` after broadcasting."""
price = self._post_price
from_ago = self._post_from_ago
if self["from_ago"] is None:
if price.size == 1 or price.shape[0] == 1:
next_open_mask = price == enums.PriceType.NextOpen
next_close_mask = price == enums.PriceType.NextClose
if next_open_mask.any() or next_close_mask.any():
price = price.astype(float_)
price[next_open_mask] = enums.PriceType.Open
price[next_close_mask] = enums.PriceType.Close
from_ago = np.full(price.shape, 0, dtype=int_)
from_ago[next_open_mask] = 1
from_ago[next_close_mask] = 1
return price, from_ago
@cachedproperty
def price(self) -> tp.ArrayLike:
"""Argument `price`."""
return self.price_and_from_ago[0]
@cachedproperty
def from_ago(self) -> tp.ArrayLike:
"""Argument `from_ago`."""
return self.price_and_from_ago[1]
@cachedproperty
def max_log_records(self) -> tp.Optional[int]:
"""Argument `max_log_records`."""
max_log_records = self._pre_max_log_records
if max_log_records is None:
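# Estimate the number of log records: one per row if logging is enabled globally, otherwise the maximum per-column count of enabled log flags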
_log = self._post_log
if _log.size == 1:
max_log_records = self.target_shape[0] * int(_log.item(0))
else:
if _log.shape[0] == 1 and self.target_shape[0] > 1:
max_log_records = self.target_shape[0] * int(np.any(_log))
else:
max_log_records = int(np.max(np.sum(_log, axis=0)))
return max_log_records
@cachedproperty
def use_stops(self) -> bool:
"""Argument `use_stops`."""
if self["use_stops"] is None:
if self.stop_ladder:
use_stops = True
else:
if self.dynamic_mode:
use_stops = True
else:
if (
not np.all(np.isnan(self.sl_stop))
or not np.all(np.isnan(self.tsl_stop))
or not np.all(np.isnan(self.tp_stop))
or np.any(self.td_stop != -1)
or np.any(self.dt_stop != -1)
):
use_stops = True
else:
use_stops = False
else:
use_stops = self["use_stops"]
return use_stops
@cachedproperty
def use_limit_orders(self) -> bool:
"""Whether to use limit orders."""
if np.any(self.order_type == enums.OrderType.Limit):
return True
if self.use_stops and np.any(self.stop_order_type == enums.OrderType.Limit):
return True
return False
@cachedproperty
def basic_mode(self) -> bool:
"""Whether the basic mode is enabled."""
return not self.use_stops and not self.use_limit_orders
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
return merge_dicts(
dict(
order_mode=self.order_mode,
use_stops=self.use_stops,
stop_ladder=self.stop_ladder,
adjust_func_nb=self.adjust_func_nb,
adjust_args=self._pre_adjust_args,
signal_func_nb=self.signal_func_nb,
signal_args=self._pre_signal_args,
post_signal_func_nb=self.post_signal_func_nb,
post_signal_args=self._pre_post_signal_args,
post_segment_func_nb=self.post_segment_func_nb,
post_segment_args=self._pre_post_segment_args,
ffill_val_price=self.ffill_val_price,
update_value=self.update_value,
fill_pos_info=self.fill_pos_info,
save_state=self.save_state,
save_value=self.save_value,
save_returns=self.save_returns,
max_order_records=self.max_order_records,
max_log_records=self.max_log_records,
),
BasePFPreparer.template_context.func(self),
)
@cachedproperty
def signal_args(self) -> tp.Args:
"""Argument `signal_args`."""
if self.dynamic_mode:
if self["signal_func_nb"] is None:
if self.ls_mode:
return (
self.long_entries,
self.long_exits,
self.short_entries,
self.short_exits,
self.from_ago,
*((self.adjust_func_nb,) if self.staticized is None else ()),
self.adjust_args,
)
if self.signals_mode:
return (
self.entries,
self.exits,
self.direction,
self.from_ago,
*((self.adjust_func_nb,) if self.staticized is None else ()),
self.adjust_args,
)
if self.order_mode:
return (
self.size,
self.price,
self.size_type,
self.direction,
self.min_size,
self.max_size,
self.val_price,
self.from_ago,
*((self.adjust_func_nb,) if self.staticized is None else ()),
self.adjust_args,
)
return self._post_signal_args
@cachedproperty
def post_segment_args(self) -> tp.Args:
"""Argument `post_segment_args`."""
if self.dynamic_mode:
if self["post_segment_func_nb"] is None:
if self.save_state or self.save_value or self.save_returns:
return (
self.save_state,
self.save_value,
self.save_returns,
)
return self._post_post_segment_args
@cachedproperty
def chunked(self) -> tp.ChunkedOption:
if self.dynamic_mode:
if self["signal_func_nb"] is None:
if self.ls_mode:
return ch.specialize_chunked_option(
self._pre_chunked,
arg_take_spec=dict(
signal_args=ch.ArgsTaker(
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
*((None,) if self.staticized is None else ()),
ch.ArgsTaker(),
)
),
)
if self.signals_mode:
return ch.specialize_chunked_option(
self._pre_chunked,
arg_take_spec=dict(
signal_args=ch.ArgsTaker(
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
*((None,) if self.staticized is None else ()),
ch.ArgsTaker(),
)
),
)
if self.order_mode:
return ch.specialize_chunked_option(
self._pre_chunked,
arg_take_spec=dict(
signal_args=ch.ArgsTaker(
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
base_ch.flex_array_gl_slicer,
*((None,) if self.staticized is None else ()),
ch.ArgsTaker(),
)
),
)
return self._pre_chunked
# ############# Result ############# #
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
if self.dynamic_mode:
func = self.resolve_dynamic_target_func("from_signal_func_nb", self.staticized)
elif not self.basic_mode:
func = nb.from_signals_nb
else:
func = nb.from_basic_signals_nb
func = jit_reg.resolve_option(func, self.jitted)
func = ch_reg.resolve_option(func, self.chunked)
return func
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
target_arg_map = dict(BasePFPreparer.target_arg_map.func(self))
if self.dynamic_mode:
if self.staticized is not None:
target_arg_map["signal_func_nb"] = None
target_arg_map["post_signal_func_nb"] = None
target_arg_map["post_segment_func_nb"] = None
else:
target_arg_map["group_lens"] = "cs_group_lens"
return target_arg_map
@cachedproperty
def pf_args(self) -> tp.Optional[tp.Kwargs]:
pf_args = dict(BasePFPreparer.pf_args.func(self))
pf_args["orders_cls"] = FSOrders
return pf_args
FSPreparer.override_arg_config_doc(__pdoc__)
fof_arg_config = ReadonlyConfig(
dict(
segment_mask=dict(),
call_pre_segment=dict(),
call_post_segment=dict(),
pre_sim_func_nb=dict(),
pre_sim_args=dict(type="args", substitute_templates=True),
post_sim_func_nb=dict(),
post_sim_args=dict(type="args", substitute_templates=True),
pre_group_func_nb=dict(),
pre_group_args=dict(type="args", substitute_templates=True),
post_group_func_nb=dict(),
post_group_args=dict(type="args", substitute_templates=True),
pre_row_func_nb=dict(),
pre_row_args=dict(type="args", substitute_templates=True),
post_row_func_nb=dict(),
post_row_args=dict(type="args", substitute_templates=True),
pre_segment_func_nb=dict(),
pre_segment_args=dict(type="args", substitute_templates=True),
post_segment_func_nb=dict(),
post_segment_args=dict(type="args", substitute_templates=True),
order_func_nb=dict(),
order_args=dict(type="args", substitute_templates=True),
flex_order_func_nb=dict(),
flex_order_args=dict(type="args", substitute_templates=True),
post_order_func_nb=dict(),
post_order_args=dict(type="args", substitute_templates=True),
ffill_val_price=dict(),
update_value=dict(),
fill_pos_info=dict(),
track_value=dict(),
row_wise=dict(),
max_order_records=dict(),
max_log_records=dict(),
)
)
"""_"""
__pdoc__[
"fof_arg_config"
] = f"""Argument config for `FOFPreparer`.
```python
{fof_arg_config.prettify()}
```
"""
@attach_arg_properties
@override_arg_config(fof_arg_config)
class FOFPreparer(BasePFPreparer):
"""Class for preparing `vectorbtpro.portfolio.base.Portfolio.from_order_func`."""
_settings_path: tp.SettingsPath = "portfolio.from_order_func"
# ############# Mode resolution ############# #
@cachedproperty
def _pre_staticized(self) -> tp.StaticizedOption:
"""Argument `staticized` before its resolution."""
staticized = self["staticized"]
if isinstance(staticized, bool):
if staticized:
staticized = dict()
else:
staticized = None
if isinstance(staticized, dict):
staticized = dict(staticized)
if "func" not in staticized:
if not self.flexible and not self.row_wise:
staticized["func"] = nb.from_order_func_nb
elif not self.flexible and self.row_wise:
staticized["func"] = nb.from_order_func_rw_nb
elif self.flexible and not self.row_wise:
staticized["func"] = nb.from_flex_order_func_nb
else:
staticized["func"] = nb.from_flex_order_func_rw_nb
return staticized
@cachedproperty
def flexible(self) -> bool:
"""Whether the flexible mode is enabled."""
return self["flex_order_func_nb"] is not None
@cachedproperty
def pre_sim_func_nb(self) -> tp.Callable:
"""Argument `pre_sim_func_nb`."""
pre_sim_func_nb = self["pre_sim_func_nb"]
if pre_sim_func_nb is None:
pre_sim_func_nb = nb.no_pre_func_nb
return pre_sim_func_nb
@cachedproperty
def post_sim_func_nb(self) -> tp.Callable:
"""Argument `post_sim_func_nb`."""
post_sim_func_nb = self["post_sim_func_nb"]
if post_sim_func_nb is None:
post_sim_func_nb = nb.no_post_func_nb
return post_sim_func_nb
@cachedproperty
def pre_group_func_nb(self) -> tp.Callable:
"""Argument `pre_group_func_nb`."""
pre_group_func_nb = self["pre_group_func_nb"]
if self.row_wise and pre_group_func_nb is not None:
raise ValueError("Cannot use pre_group_func_nb in a row-wise simulation")
if pre_group_func_nb is None:
pre_group_func_nb = nb.no_pre_func_nb
return pre_group_func_nb
@cachedproperty
def post_group_func_nb(self) -> tp.Callable:
"""Argument `post_group_func_nb`."""
post_group_func_nb = self["post_group_func_nb"]
if self.row_wise and post_group_func_nb is not None:
raise ValueError("Cannot use post_group_func_nb in a row-wise simulation")
if post_group_func_nb is None:
post_group_func_nb = nb.no_post_func_nb
return post_group_func_nb
@cachedproperty
def pre_row_func_nb(self) -> tp.Callable:
"""Argument `pre_row_func_nb`."""
pre_row_func_nb = self["pre_row_func_nb"]
if not self.row_wise and pre_row_func_nb is not None:
raise ValueError("Cannot use pre_row_func_nb in a column-wise simulation")
if pre_row_func_nb is None:
pre_row_func_nb = nb.no_pre_func_nb
return pre_row_func_nb
@cachedproperty
def post_row_func_nb(self) -> tp.Callable:
"""Argument `post_row_func_nb`."""
post_row_func_nb = self["post_row_func_nb"]
if not self.row_wise and post_row_func_nb is not None:
raise ValueError("Cannot use post_row_func_nb in a column-wise simulation")
if post_row_func_nb is None:
post_row_func_nb = nb.no_post_func_nb
return post_row_func_nb
@cachedproperty
def pre_segment_func_nb(self) -> tp.Callable:
"""Argument `pre_segment_func_nb`."""
pre_segment_func_nb = self["pre_segment_func_nb"]
if pre_segment_func_nb is None:
pre_segment_func_nb = nb.no_pre_func_nb
return pre_segment_func_nb
@cachedproperty
def post_segment_func_nb(self) -> tp.Callable:
"""Argument `post_segment_func_nb`."""
post_segment_func_nb = self["post_segment_func_nb"]
if post_segment_func_nb is None:
post_segment_func_nb = nb.no_post_func_nb
return post_segment_func_nb
@cachedproperty
def order_func_nb(self) -> tp.Callable:
"""Argument `order_func_nb`."""
order_func_nb = self["order_func_nb"]
if self.flexible and order_func_nb is not None:
raise ValueError("Must provide either order_func_nb or flex_order_func_nb")
if not self.flexible and order_func_nb is None:
raise ValueError("Must provide either order_func_nb or flex_order_func_nb")
if order_func_nb is None:
order_func_nb = nb.no_order_func_nb
return order_func_nb
@cachedproperty
def flex_order_func_nb(self) -> tp.Callable:
"""Argument `flex_order_func_nb`."""
flex_order_func_nb = self["flex_order_func_nb"]
if flex_order_func_nb is None:
flex_order_func_nb = nb.no_flex_order_func_nb
return flex_order_func_nb
@cachedproperty
def post_order_func_nb(self) -> tp.Callable:
"""Argument `post_order_func_nb`."""
post_order_func_nb = self["post_order_func_nb"]
if post_order_func_nb is None:
post_order_func_nb = nb.no_post_func_nb
return post_order_func_nb
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
"""Argument `staticized`."""
staticized = self._pre_staticized
if isinstance(staticized, dict):
staticized = dict(staticized)
if self["pre_sim_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["pre_sim_func_nb"], "pre_sim_func_nb")
if self["post_sim_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_sim_func_nb"], "post_sim_func_nb")
if self["pre_group_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["pre_group_func_nb"], "pre_group_func_nb")
if self["post_group_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_group_func_nb"], "post_group_func_nb")
if self["pre_row_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["pre_row_func_nb"], "pre_row_func_nb")
if self["post_row_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_row_func_nb"], "post_row_func_nb")
if self["pre_segment_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["pre_segment_func_nb"], "pre_segment_func_nb")
if self["post_segment_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_segment_func_nb"], "post_segment_func_nb")
if self["order_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["order_func_nb"], "order_func_nb")
if self["flex_order_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["flex_order_func_nb"], "flex_order_func_nb")
if self["post_order_func_nb"] is not None:
self.adapt_staticized_to_udf(staticized, self["post_order_func_nb"], "post_order_func_nb")
return staticized
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_call_seq(self) -> tp.Optional[tp.ArrayLike]:
if self.auto_call_seq:
raise ValueError(
"CallSeqType.Auto must be implemented manually. Use sort_call_seq_nb in pre_segment_func_nb."
)
return self["call_seq"]
@cachedproperty
def _pre_segment_mask(self) -> tp.ArrayLike:
"""Argument `segment_mask` before broadcasting."""
return self["segment_mask"]
# ############# After broadcasting ############# #
@cachedproperty
def sim_start(self) -> tp.Optional[tp.ArrayLike]:
sim_start = self["sim_start"]
if sim_start is None:
return None
return BasePFPreparer.sim_start.func(self)
@cachedproperty
def sim_end(self) -> tp.Optional[tp.ArrayLike]:
sim_end = self["sim_end"]
if sim_end is None:
return None
return BasePFPreparer.sim_end.func(self)
@cachedproperty
def segment_mask(self) -> tp.ArrayLike:
"""Argument `segment_mask`."""
segment_mask = self._pre_segment_mask
if checks.is_int(segment_mask):
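# An integer n activates every n-th row (starting from the first row) for all groups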
if self.keep_inout_flex:
_segment_mask = np.full((self.target_shape[0], 1), False)
else:
_segment_mask = np.full((self.target_shape[0], len(self.group_lens)), False)
_segment_mask[0::segment_mask] = True
segment_mask = _segment_mask
else:
segment_mask = broadcast(
segment_mask,
to_shape=(self.target_shape[0], len(self.group_lens)),
to_pd=False,
keep_flex=self.keep_inout_flex,
reindex_kwargs=dict(fill_value=False),
require_kwargs=self.broadcast_kwargs.get("require_kwargs", {}),
)
checks.assert_subdtype(segment_mask, np.bool_, arg_name="segment_mask")
return segment_mask
# ############# Template substitution ############# #
@cachedproperty
def template_context(self) -> tp.Kwargs:
return merge_dicts(
dict(
segment_mask=self.segment_mask,
call_pre_segment=self.call_pre_segment,
call_post_segment=self.call_post_segment,
pre_sim_func_nb=self.pre_sim_func_nb,
pre_sim_args=self._pre_pre_sim_args,
post_sim_func_nb=self.post_sim_func_nb,
post_sim_args=self._pre_post_sim_args,
pre_group_func_nb=self.pre_group_func_nb,
pre_group_args=self._pre_pre_group_args,
post_group_func_nb=self.post_group_func_nb,
post_group_args=self._pre_post_group_args,
pre_row_func_nb=self.pre_row_func_nb,
pre_row_args=self._pre_pre_row_args,
post_row_func_nb=self.post_row_func_nb,
post_row_args=self._pre_post_row_args,
pre_segment_func_nb=self.pre_segment_func_nb,
pre_segment_args=self._pre_pre_segment_args,
post_segment_func_nb=self.post_segment_func_nb,
post_segment_args=self._pre_post_segment_args,
order_func_nb=self.order_func_nb,
order_args=self._pre_order_args,
flex_order_func_nb=self.flex_order_func_nb,
flex_order_args=self._pre_flex_order_args,
post_order_func_nb=self.post_order_func_nb,
post_order_args=self._pre_post_order_args,
ffill_val_price=self.ffill_val_price,
update_value=self.update_value,
fill_pos_info=self.fill_pos_info,
track_value=self.track_value,
max_order_records=self.max_order_records,
max_log_records=self.max_log_records,
),
BasePFPreparer.template_context.func(self),
)
# ############# Result ############# #
@cachedproperty
def target_func(self) -> tp.Optional[tp.Callable]:
if not self.row_wise and not self.flexible:
func = self.resolve_dynamic_target_func("from_order_func_nb", self.staticized)
elif not self.row_wise and self.flexible:
func = self.resolve_dynamic_target_func("from_flex_order_func_nb", self.staticized)
elif self.row_wise and not self.flexible:
func = self.resolve_dynamic_target_func("from_order_func_rw_nb", self.staticized)
else:
func = self.resolve_dynamic_target_func("from_flex_order_func_rw_nb", self.staticized)
func = jit_reg.resolve_option(func, self.jitted)
func = ch_reg.resolve_option(func, self.chunked)
return func
@cachedproperty
def target_arg_map(self) -> tp.Kwargs:
target_arg_map = dict(BasePFPreparer.target_arg_map.func(self))
if self.staticized is not None:
target_arg_map["pre_sim_func_nb"] = None
target_arg_map["post_sim_func_nb"] = None
target_arg_map["pre_group_func_nb"] = None
target_arg_map["post_group_func_nb"] = None
target_arg_map["pre_row_func_nb"] = None
target_arg_map["post_row_func_nb"] = None
target_arg_map["pre_segment_func_nb"] = None
target_arg_map["post_segment_func_nb"] = None
target_arg_map["order_func_nb"] = None
target_arg_map["flex_order_func_nb"] = None
target_arg_map["post_order_func_nb"] = None
return target_arg_map
fdof_arg_config = ReadonlyConfig(
dict(
val_price=dict(
broadcast=True,
map_enum_kwargs=dict(enum=enums.ValPriceType, ignore_type=(int, float)),
subdtype=np.number,
broadcast_kwargs=dict(reindex_kwargs=dict(fill_value=np.nan)),
),
flexible=dict(),
)
)
"""_"""
__pdoc__[
"fdof_arg_config"
] = f"""Argument config for `FDOFPreparer`.
```python
{fdof_arg_config.prettify()}
```
"""
@attach_arg_properties
@override_arg_config(fdof_arg_config)
@override_arg_config(order_arg_config)
class FDOFPreparer(FOFPreparer):
"""Class for preparing `vectorbtpro.portfolio.base.Portfolio.from_def_order_func`."""
_settings_path: tp.SettingsPath = "portfolio.from_def_order_func"
# ############# Mode resolution ############# #
@cachedproperty
def flexible(self) -> bool:
return self["flexible"]
@cachedproperty
def pre_segment_func_nb(self) -> tp.Callable:
"""Argument `pre_segment_func_nb`."""
pre_segment_func_nb = self["pre_segment_func_nb"]
if pre_segment_func_nb is None:
if self.flexible:
pre_segment_func_nb = nb.def_flex_pre_segment_func_nb
else:
pre_segment_func_nb = nb.def_pre_segment_func_nb
return pre_segment_func_nb
@cachedproperty
def order_func_nb(self) -> tp.Callable:
"""Argument `order_func_nb`."""
order_func_nb = self["order_func_nb"]
if self.flexible and order_func_nb is not None:
raise ValueError("Argument order_func_nb cannot be provided when flexible=True")
if order_func_nb is None:
order_func_nb = nb.def_order_func_nb
return order_func_nb
@cachedproperty
def flex_order_func_nb(self) -> tp.Callable:
"""Argument `flex_order_func_nb`."""
flex_order_func_nb = self["flex_order_func_nb"]
if not self.flexible and flex_order_func_nb is not None:
raise ValueError("Argument flex_order_func_nb cannot be provided when flexible=False")
if flex_order_func_nb is None:
flex_order_func_nb = nb.def_flex_order_func_nb
return flex_order_func_nb
@cachedproperty
def _pre_chunked(self) -> tp.ChunkedOption:
"""Argument `chunked` before template substitution."""
return self["chunked"]
@cachedproperty
def staticized(self) -> tp.StaticizedOption:
staticized = FOFPreparer.staticized.func(self)
if isinstance(staticized, dict):
if "pre_segment_func_nb" not in staticized:
self.adapt_staticized_to_udf(staticized, self.pre_segment_func_nb, "pre_segment_func_nb")
if "order_func_nb" not in staticized:
self.adapt_staticized_to_udf(staticized, self.order_func_nb, "order_func_nb")
if "flex_order_func_nb" not in staticized:
self.adapt_staticized_to_udf(staticized, self.flex_order_func_nb, "flex_order_func_nb")
return staticized
# ############# Before broadcasting ############# #
@cachedproperty
def _pre_call_seq(self) -> tp.Optional[tp.ArrayLike]:
return BasePFPreparer._pre_call_seq.func(self)
# ############# After broadcasting ############# #
@cachedproperty
def auto_sim_start(self) -> tp.Optional[tp.ArrayLike]:
return FOPreparer.auto_sim_start.func(self)
@cachedproperty
def auto_sim_end(self) -> tp.Optional[tp.ArrayLike]:
return FOPreparer.auto_sim_end.func(self)
# ############# Template substitution ############# #
@cachedproperty
def pre_segment_args(self) -> tp.Args:
"""Argument `pre_segment_args`."""
return (
self.val_price,
self.price,
self.size,
self.size_type,
self.direction,
self.auto_call_seq,
)
@cachedproperty
def _order_args(self) -> tp.Args:
"""Either `order_args` or `flex_order_args`."""
return (
self.size,
self.price,
self.size_type,
self.direction,
self.fees,
self.fixed_fees,
self.slippage,
self.min_size,
self.max_size,
self.size_granularity,
self.leverage,
self.leverage_mode,
self.reject_prob,
self.price_area_vio_mode,
self.allow_partial,
self.raise_reject,
self.log,
)
@cachedproperty
def order_args(self) -> tp.Args:
"""Argument `order_args`."""
if self.flexible:
return self._post_order_args
return self._order_args
@cachedproperty
def flex_order_args(self) -> tp.Args:
"""Argument `flex_order_args`."""
if not self.flexible:
return self._post_flex_order_args
return self._order_args
@cachedproperty
def chunked(self) -> tp.ChunkedOption:
arg_take_spec = dict()
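# pre_segment_args holds 5 flexible arrays (val_price, price, size, size_type, direction) plus the auto_call_seq flag; the (flex_)order_args tuple holds 17 flexible arrays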
arg_take_spec["pre_segment_args"] = ch.ArgsTaker(*[base_ch.flex_array_gl_slicer] * 5, None)
if self.flexible:
arg_take_spec["flex_order_args"] = ch.ArgsTaker(*[base_ch.flex_array_gl_slicer] * 17)
else:
arg_take_spec["order_args"] = ch.ArgsTaker(*[base_ch.flex_array_gl_slicer] * 17)
return ch.specialize_chunked_option(self._pre_chunked, arg_take_spec=arg_take_spec)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with trade records.
Trade records capture information on trades.
In vectorbt, a trade is a sequence of orders that starts with an opening order and optionally ends
with a closing order. Every pair of opposite orders can be represented by a trade. Each trade has a PnL
info attached to quickly assess its performance. An interesting effect of this representation
is the ability to aggregate trades: if two or more trades are happening one after another in time,
they can be aggregated into a bigger trade. This way, for example, single-order trades can be aggregated
into positions; but also multiple positions can be aggregated into a single blob that reflects the performance
of the entire symbol.
!!! warning
All classes return both closed AND open trades/positions, which may skew your performance results.
To only consider closed trades/positions, you should explicitly query the `status_closed` attribute.
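For example, to include only closed trades (a minimal sketch, assuming a `trades` instance
like the ones built in the examples below):
```pycon
>>> closed_trades = trades.status_closed  # drop open trades/positions
>>> closed_trades.pnl.sum()  # PnL of closed trades only
```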
## Trade types
There are three main types of trades.
### Entry trades
An entry trade is created from each order that opens or adds to a position.
For example, if we have a single large buy order and 100 smaller sell orders, we will see
a single trade with the entry information copied from the buy order and the exit information being
a size-weighted average over the exit information of all sell orders. On the other hand,
if we have 100 smaller buy orders and a single sell order, we will see 100 trades,
each with the entry information copied from the buy order and the exit information being
a size-based fraction of the exit information of the sell order.
Use `vectorbtpro.portfolio.trades.EntryTrades.from_orders` to build entry trades from orders.
Also available as `vectorbtpro.portfolio.base.Portfolio.entry_trades`.
### Exit trades
An exit trade is created from each order that closes or removes from a position.
Use `vectorbtpro.portfolio.trades.ExitTrades.from_orders` to build exit trades from orders.
Also available as `vectorbtpro.portfolio.base.Portfolio.exit_trades`.
### Positions
A position is created from a sequence of entry or exit trades.
Use `vectorbtpro.portfolio.trades.Positions.from_trades` to build positions from entry or exit trades.
Also available as `vectorbtpro.portfolio.base.Portfolio.positions`.
## Example
* Increasing position:
```pycon
>>> from vectorbtpro import *
>>> # Entry trades
>>> pf_kwargs = dict(
... close=pd.Series([1., 2., 3., 4., 5.]),
... size=pd.Series([1., 1., 1., 1., -4.]),
... fixed_fees=1.
... )
>>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades
>>> entry_trades.readable
Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
1 1 0 1.0 1 1 2.0
2 2 0 1.0 2 2 3.0
3 3 0 1.0 3 3 4.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 4 4 5.0 0.25 2.75
1 1.0 4 4 5.0 0.25 1.75
2 1.0 4 4 5.0 0.25 0.75
3 1.0 4 4 5.0 0.25 -0.25
Return Direction Status Position Id
0 2.7500 Long Closed 0
1 0.8750 Long Closed 0
2 0.2500 Long Closed 0
3 -0.0625 Long Closed 0
>>> # Exit trades
>>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades
>>> exit_trades.readable
Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 4.0 0 0 2.5
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 4.0 4 4 5.0 1.0 5.0
Return Direction Status Position Id
0 0.5 Long Closed 0
>>> # Positions
>>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions
>>> positions.readable
Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 4.0 0 0 2.5
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 4.0 4 4 5.0 1.0 5.0
Return Direction Status
0 0.5 Long Closed
>>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum()
True
```
* Decreasing position:
```pycon
>>> # Entry trades
>>> pf_kwargs = dict(
... close=pd.Series([1., 2., 3., 4., 5.]),
... size=pd.Series([4., -1., -1., -1., -1.]),
... fixed_fees=1.
... )
>>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades
>>> entry_trades.readable
Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 4.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 4 4 3.5 4.0 5.0
Return Direction Status Position Id
0 1.25 Long Closed 0
>>> # Exit trades
>>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades
>>> exit_trades.readable
Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
1 1 0 1.0 0 0 1.0
2 2 0 1.0 0 0 1.0
3 3 0 1.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 0.25 1 1 2.0 1.0 -0.25
1 0.25 2 2 3.0 1.0 0.75
2 0.25 3 3 4.0 1.0 1.75
3 0.25 4 4 5.0 1.0 2.75
Return Direction Status Position Id
0 -0.25 Long Closed 0
1 0.75 Long Closed 0
2 1.75 Long Closed 0
3 2.75 Long Closed 0
>>> # Positions
>>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions
>>> positions.readable
Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 4.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 4 4 3.5 4.0 5.0
Return Direction Status
0 1.25 Long Closed
>>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum()
True
```
* Multiple reversing positions:
```pycon
>>> # Entry trades
>>> pf_kwargs = dict(
... close=pd.Series([1., 2., 3., 4., 5.]),
... size=pd.Series([1., -2., 2., -2., 1.]),
... fixed_fees=1.
... )
>>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades
>>> entry_trades.readable
Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
1 1 0 1.0 1 1 2.0
2 2 0 1.0 2 2 3.0
3 3 0 1.0 3 3 4.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 1 1 2.0 0.5 -0.5
1 0.5 2 2 3.0 0.5 -2.0
2 0.5 3 3 4.0 0.5 0.0
3 0.5 4 4 5.0 1.0 -2.5
Return Direction Status Position Id
0 -0.500 Long Closed 0
1 -1.000 Short Closed 1
2 0.000 Long Closed 2
3 -0.625 Short Closed 3
>>> # Exit trades
>>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades
>>> exit_trades.readable
Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
1 1 0 1.0 1 1 2.0
2 2 0 1.0 2 2 3.0
3 3 0 1.0 3 3 4.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 1 1 2.0 0.5 -0.5
1 0.5 2 2 3.0 0.5 -2.0
2 0.5 3 3 4.0 0.5 0.0
3 0.5 4 4 5.0 1.0 -2.5
Return Direction Status Position Id
0 -0.500 Long Closed 0
1 -1.000 Short Closed 1
2 0.000 Long Closed 2
3 -0.625 Short Closed 3
>>> # Positions
>>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions
>>> positions.readable
Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
1 1 0 1.0 1 1 2.0
2 2 0 1.0 2 2 3.0
3 3 0 1.0 3 3 4.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 1 1 2.0 0.5 -0.5
1 0.5 2 2 3.0 0.5 -2.0
2 0.5 3 3 4.0 0.5 0.0
3 0.5 4 4 5.0 1.0 -2.5
Return Direction Status
0 -0.500 Long Closed
1 -1.000 Short Closed
2 0.000 Long Closed
3 -0.625 Short Closed
>>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum()
True
```
* Open position:
```pycon
>>> # Entry trades
>>> pf_kwargs = dict(
... close=pd.Series([1., 2., 3., 4., 5.]),
... size=pd.Series([1., 0., 0., 0., 0.]),
... fixed_fees=1.
... )
>>> entry_trades = vbt.Portfolio.from_orders(**pf_kwargs).entry_trades
>>> entry_trades.readable
Entry Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 -1 4 5.0 0.0 3.0
Return Direction Status Position Id
0 3.0 Long Open 0
>>> # Exit trades
>>> exit_trades = vbt.Portfolio.from_orders(**pf_kwargs).exit_trades
>>> exit_trades.readable
Exit Trade Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 -1 4 5.0 0.0 3.0
Return Direction Status Position Id
0 3.0 Long Open 0
>>> # Positions
>>> positions = vbt.Portfolio.from_orders(**pf_kwargs).positions
>>> positions.readable
Position Id Column Size Entry Order Id Entry Index Avg Entry Price \\
0 0 0 1.0 0 0 1.0
Entry Fees Exit Order Id Exit Index Avg Exit Price Exit Fees PnL \\
0 1.0 -1 4 5.0 0.0 3.0
Return Direction Status
0 3.0 Long Open
>>> entry_trades.pnl.sum() == exit_trades.pnl.sum() == positions.pnl.sum()
True
```
Get trade count, trade PnL, and winning trade PnL:
```pycon
>>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.])
>>> size = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5])
>>> trades = vbt.Portfolio.from_orders(price, size).trades
>>> trades.count()
6
>>> trades.pnl.sum()
-3.0
>>> trades.winning.count()
2
>>> trades.winning.pnl.sum()
1.5
```
Get count and PnL of trades with duration of more than 2 days:
```pycon
>>> mask = (trades.records['exit_idx'] - trades.records['entry_idx']) > 2
>>> trades_filtered = trades.apply_mask(mask)
>>> trades_filtered.count()
2
>>> trades_filtered.pnl.sum()
-3.0
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Trades.metrics`.
```pycon
>>> price = vbt.RandomData.pull(
... ['a', 'b'],
... start=datetime(2020, 1, 1),
... end=datetime(2020, 3, 1),
... seed=vbt.symbol_dict(a=42, b=43)
... ).get()
```
[=100% "100%"]{: .candystripe .candystripe-animate }
```pycon
>>> size = pd.DataFrame({
... 'a': np.random.randint(-1, 2, size=len(price.index)),
... 'b': np.random.randint(-1, 2, size=len(price.index)),
... }, index=price.index, columns=price.columns)
>>> pf = vbt.Portfolio.from_orders(price, size, fees=0.01, init_cash="auto")
>>> pf.trades['a'].stats()
Start 2019-12-31 23:00:00+00:00
End 2020-02-29 23:00:00+00:00
Period 61 days 00:00:00
First Trade Start 2019-12-31 23:00:00+00:00
Last Trade End 2020-02-29 23:00:00+00:00
Coverage 60 days 00:00:00
Overlap Coverage 49 days 00:00:00
Total Records 19.0
Total Long Trades 2.0
Total Short Trades 17.0
Total Closed Trades 18.0
Total Open Trades 1.0
Open Trade PnL 16.063
Win Rate [%] 61.111111
Max Win Streak 11.0
Max Loss Streak 7.0
Best Trade [%] 3.526377
Worst Trade [%] -6.543679
Avg Winning Trade [%] 2.225861
Avg Losing Trade [%] -3.601313
Avg Winning Trade Duration 32 days 19:38:10.909090909
Avg Losing Trade Duration 5 days 00:00:00
Profit Factor 1.022425
Expectancy 0.028157
SQN 0.039174
Name: agg_stats, dtype: object
```
Positions share almost identical metrics with trades:
```pycon
>>> pf.positions['a'].stats()
Start 2019-12-31 23:00:00+00:00
End 2020-02-29 23:00:00+00:00
Period 61 days 00:00:00
First Trade Start 2019-12-31 23:00:00+00:00
Last Trade End 2020-02-29 23:00:00+00:00
Coverage 60 days 00:00:00
Overlap Coverage 0 days 00:00:00
Total Records 5.0
Total Long Trades 2.0
Total Short Trades 3.0
Total Closed Trades 4.0
Total Open Trades 1.0
Open Trade PnL 38.356823
Win Rate [%] 0.0
Max Win Streak 0.0
Max Loss Streak 4.0
Best Trade [%] -1.529613
Worst Trade [%] -6.543679
Avg Winning Trade [%] NaN
Avg Losing Trade [%] -3.786739
Avg Winning Trade Duration NaT
Avg Losing Trade Duration 4 days 00:00:00
Profit Factor 0.0
Expectancy -5.446748
SQN -1.794214
Name: agg_stats, dtype: object
```
To also include open trades/positions when calculating metrics such as win rate, pass `incl_open=True`:
```pycon
>>> pf.trades['a'].stats(settings=dict(incl_open=True))
Start 2019-12-31 23:00:00+00:00
End 2020-02-29 23:00:00+00:00
Period 61 days 00:00:00
First Trade Start 2019-12-31 23:00:00+00:00
Last Trade End 2020-02-29 23:00:00+00:00
Coverage 60 days 00:00:00
Overlap Coverage 49 days 00:00:00
Total Records 19.0
Total Long Trades 2.0
Total Short Trades 17.0
Total Closed Trades 18.0
Total Open Trades 1.0
Open Trade PnL 16.063
Win Rate [%] 61.111111
Max Win Streak 12.0
Max Loss Streak 7.0
Best Trade [%] 3.526377
Worst Trade [%] -6.543679
Avg Winning Trade [%] 2.238896
Avg Losing Trade [%] -3.601313
Avg Winning Trade Duration 33 days 18:00:00
Avg Losing Trade Duration 5 days 00:00:00
Profit Factor 1.733143
Expectancy 0.872096
SQN 0.804714
Name: agg_stats, dtype: object
```
`Trades.stats` also supports (re-)grouping:
```pycon
>>> pf.trades.stats(group_by=True)
Start 2019-12-31 23:00:00+00:00
End 2020-02-29 23:00:00+00:00
Period 61 days 00:00:00
First Trade Start 2019-12-31 23:00:00+00:00
Last Trade End 2020-02-29 23:00:00+00:00
Coverage 61 days 00:00:00
Overlap Coverage 61 days 00:00:00
Total Records 37
Total Long Trades 5
Total Short Trades 32
Total Closed Trades 35
Total Open Trades 2
Open Trade PnL 1.336259
Win Rate [%] 37.142857
Max Win Streak 11
Max Loss Streak 10
Best Trade [%] 3.526377
Worst Trade [%] -8.710238
Avg Winning Trade [%] 1.907799
Avg Losing Trade [%] -3.259135
Avg Winning Trade Duration 28 days 14:46:09.230769231
Avg Losing Trade Duration 14 days 00:00:00
Profit Factor 0.340493
Expectancy -1.292596
SQN -2.509223
Name: group, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Trades.subplots`.
`Trades` class has two subplots based on `Trades.plot` and `Trades.plot_pnl`:
```pycon
>>> pf.trades['a'].plots().show()
```
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.indexes import stack_indexes
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array, to_pd_array, broadcast_to
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic.enums import range_dt
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.portfolio import nb
from vectorbtpro.portfolio.enums import TradeDirection, TradeStatus, trade_dt
from vectorbtpro.portfolio.orders import Orders
from vectorbtpro.records.decorators import attach_fields, override_field_config, attach_shortcut_properties
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils.array_ import min_rel_rescale, max_rel_rescale
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import merge_dicts, Config, ReadonlyConfig, HybridConfig
from vectorbtpro.utils.template import Rep, RepEval, RepFunc
__all__ = [
"Trades",
"EntryTrades",
"ExitTrades",
"Positions",
]
__pdoc__ = {}
# ############# Trades ############# #
trades_field_config = ReadonlyConfig(
dict(
dtype=trade_dt,
settings={
"id": dict(title="Trade Id"),
"idx": dict(name="exit_idx"), # remap field of Records
"start_idx": dict(name="entry_idx"), # remap field of Ranges
"end_idx": dict(name="exit_idx"), # remap field of Ranges
"size": dict(title="Size"),
"entry_order_id": dict(title="Entry Order Id", mapping="ids"),
"entry_idx": dict(title="Entry Index", mapping="index"),
"entry_price": dict(title="Avg Entry Price"),
"entry_fees": dict(title="Entry Fees"),
"exit_order_id": dict(title="Exit Order Id", mapping="ids"),
"exit_idx": dict(title="Exit Index", mapping="index"),
"exit_price": dict(title="Avg Exit Price"),
"exit_fees": dict(title="Exit Fees"),
"pnl": dict(title="PnL"),
"return": dict(title="Return", hovertemplate="$title: %{customdata[$index]:,%}"),
"direction": dict(title="Direction", mapping=TradeDirection),
"status": dict(title="Status", mapping=TradeStatus),
"parent_id": dict(title="Position Id", mapping="ids"),
},
)
)
"""_"""
__pdoc__[
"trades_field_config"
] = f"""Field config for `Trades`.
```python
{trades_field_config.prettify()}
```
"""
trades_attach_field_config = ReadonlyConfig(
{
"return": dict(attach="returns"),
"direction": dict(attach_filters=True),
"status": dict(attach_filters=True, on_conflict="ignore"),
}
)
"""_"""
__pdoc__[
"trades_attach_field_config"
] = f"""Config of fields to be attached to `Trades`.
```python
{trades_attach_field_config.prettify()}
```
"""
trades_shortcut_config = ReadonlyConfig(
dict(
ranges=dict(),
long_view=dict(),
short_view=dict(),
winning=dict(),
losing=dict(),
winning_streak=dict(obj_type="mapped_array"),
losing_streak=dict(obj_type="mapped_array"),
win_rate=dict(obj_type="red_array"),
profit_factor=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)),
rel_profit_factor=dict(
obj_type="red_array",
method_name="get_profit_factor",
method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_profit_factor")),
),
expectancy=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)),
rel_expectancy=dict(
obj_type="red_array",
method_name="get_expectancy",
method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_expectancy")),
),
sqn=dict(obj_type="red_array", method_kwargs=dict(use_returns=False)),
rel_sqn=dict(
obj_type="red_array",
method_name="get_sqn",
method_kwargs=dict(use_returns=True, wrap_kwargs=dict(name_or_index="rel_sqn")),
),
best_price=dict(obj_type="mapped_array"),
worst_price=dict(obj_type="mapped_array"),
best_price_idx=dict(obj_type="mapped_array"),
worst_price_idx=dict(obj_type="mapped_array"),
expanding_best_price=dict(obj_type="array"),
expanding_worst_price=dict(obj_type="array"),
mfe=dict(obj_type="mapped_array"),
mfe_returns=dict(
obj_type="mapped_array",
method_name="get_mfe",
method_kwargs=dict(use_returns=True),
),
mae=dict(obj_type="mapped_array"),
mae_returns=dict(
obj_type="mapped_array",
method_name="get_mae",
method_kwargs=dict(use_returns=True),
),
expanding_mfe=dict(obj_type="array"),
expanding_mfe_returns=dict(
obj_type="array",
method_name="get_expanding_mfe",
method_kwargs=dict(use_returns=True),
),
expanding_mae=dict(obj_type="array"),
expanding_mae_returns=dict(
obj_type="array",
method_name="get_expanding_mae",
method_kwargs=dict(use_returns=True),
),
edge_ratio=dict(obj_type="red_array"),
running_edge_ratio=dict(obj_type="array"),
)
)
"""_"""
__pdoc__[
"trades_shortcut_config"
] = f"""Config of shortcut properties to be attached to `Trades`.
```python
{trades_shortcut_config.prettify()}
```
"""
TradesT = tp.TypeVar("TradesT", bound="Trades")
@attach_shortcut_properties(trades_shortcut_config)
@attach_fields(trades_attach_field_config)
@override_field_config(trades_field_config)
class Trades(Ranges):
"""Extends `vectorbtpro.generic.ranges.Ranges` for working with trade-like records, such as
entry trades, exit trades, and positions."""
@property
def field_config(self) -> Config:
return self._field_config
def get_ranges(self, **kwargs) -> Ranges:
"""Get records of type `vectorbtpro.generic.ranges.Ranges`."""
new_records_arr = np.empty(self.values.shape, dtype=range_dt)
new_records_arr["id"][:] = self.get_field_arr("id").copy()
new_records_arr["col"][:] = self.get_field_arr("col").copy()
new_records_arr["start_idx"][:] = self.get_field_arr("entry_idx").copy()
new_records_arr["end_idx"][:] = self.get_field_arr("exit_idx").copy()
new_records_arr["status"][:] = self.get_field_arr("status").copy()
return Ranges.from_records(
self.wrapper,
new_records_arr,
open=self._open,
high=self._high,
low=self._low,
close=self._close,
**kwargs,
)
# ############# Views ############# #
def get_long_view(self: TradesT, **kwargs) -> TradesT:
"""Get long view."""
filter_mask = self.get_field_arr("direction") == TradeDirection.Long
return self.apply_mask(filter_mask, **kwargs)
def get_short_view(self: TradesT, **kwargs) -> TradesT:
"""Get short view."""
filter_mask = self.get_field_arr("direction") == TradeDirection.Short
return self.apply_mask(filter_mask, **kwargs)
# ############# Stats ############# #
def get_winning(self: TradesT, **kwargs) -> TradesT:
"""Get winning trades."""
filter_mask = self.get_field_arr("pnl") > 0.0
return self.apply_mask(filter_mask, **kwargs)
def get_losing(self: TradesT, **kwargs) -> TradesT:
"""Get losing trades."""
filter_mask = self.get_field_arr("pnl") < 0.0
return self.apply_mask(filter_mask, **kwargs)
def get_winning_streak(self, **kwargs) -> MappedArray:
"""Get winning streak at each trade in the current column.
See `vectorbtpro.portfolio.nb.records.trade_winning_streak_nb`."""
return self.apply(nb.trade_winning_streak_nb, dtype=int_, **kwargs)
def get_losing_streak(self, **kwargs) -> MappedArray:
"""Get losing streak at each trade in the current column.
See `vectorbtpro.portfolio.nb.records.trade_losing_streak_nb`."""
return self.apply(nb.trade_losing_streak_nb, dtype=int_, **kwargs)
def get_win_rate(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get rate of winning trades."""
wrap_kwargs = merge_dicts(dict(name_or_index="win_rate"), wrap_kwargs)
return self.get_map_field("pnl").reduce(
nb.win_rate_reduce_nb,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_profit_factor(
self,
use_returns: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get profit factor."""
wrap_kwargs = merge_dicts(dict(name_or_index="profit_factor"), wrap_kwargs)
if use_returns:
mapped_arr = self.get_map_field("return")
else:
mapped_arr = self.get_map_field("pnl")
return mapped_arr.reduce(
nb.profit_factor_reduce_nb,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_expectancy(
self,
use_returns: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get average profitability."""
wrap_kwargs = merge_dicts(dict(name_or_index="expectancy"), wrap_kwargs)
if use_returns:
mapped_arr = self.get_map_field("return")
else:
mapped_arr = self.get_map_field("pnl")
return mapped_arr.reduce(
nb.expectancy_reduce_nb,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def get_sqn(
self,
ddof: int = 1,
use_returns: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Get System Quality Number (SQN)."""
wrap_kwargs = merge_dicts(dict(name_or_index="sqn"), wrap_kwargs)
if use_returns:
mapped_arr = self.get_map_field("return")
else:
mapped_arr = self.get_map_field("pnl")
return mapped_arr.reduce(
nb.sqn_reduce_nb,
ddof,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
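# Illustrative sketch (hypothetical `pf` and `trades` objects): the reducers above are also
# exposed as shortcut properties through `trades_shortcut_config`, so they can be accessed
# either as methods or as plain attributes.
#
# >>> trades = pf.trades
# >>> trades.win_rate                 # same as trades.get_win_rate()
# >>> trades.profit_factor            # same as trades.get_profit_factor()
# >>> trades.get_expectancy(use_returns=True)
# >>> trades.get_sqn(group_by=True)   # one SQN per column group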
def get_best_price(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
**kwargs,
) -> MappedArray:
"""Get best price.
See `vectorbtpro.portfolio.nb.records.best_price_nb`."""
return self.apply(
nb.best_price_nb,
self._open,
self._high,
self._low,
self._close,
entry_price_open,
exit_price_close,
max_duration,
**kwargs,
)
def get_worst_price(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
**kwargs,
) -> MappedArray:
"""Get worst price.
See `vectorbtpro.portfolio.nb.records.worst_price_nb`."""
return self.apply(
nb.worst_price_nb,
self._open,
self._high,
self._low,
self._close,
entry_price_open,
exit_price_close,
max_duration,
**kwargs,
)
def get_best_price_idx(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
relative: bool = True,
**kwargs,
) -> MappedArray:
"""Get (relative) index of best price.
See `vectorbtpro.portfolio.nb.records.best_price_idx_nb`."""
return self.apply(
nb.best_price_idx_nb,
self._open,
self._high,
self._low,
self._close,
entry_price_open,
exit_price_close,
max_duration,
relative,
dtype=int_,
**kwargs,
)
def get_worst_price_idx(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
relative: bool = True,
**kwargs,
) -> MappedArray:
"""Get (relative) index of worst price.
See `vectorbtpro.portfolio.nb.records.worst_price_idx_nb`."""
return self.apply(
nb.worst_price_idx_nb,
self._open,
self._high,
self._low,
self._close,
entry_price_open,
exit_price_close,
max_duration,
relative,
dtype=int_,
**kwargs,
)
def get_expanding_best_price(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
clean_index_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get expanding best price.
See `vectorbtpro.portfolio.nb.records.expanding_best_price_nb`."""
func = jit_reg.resolve_option(nb.expanding_best_price_nb, jitted)
out = func(
self.values,
self._open,
self._high,
self._low,
self._close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_columns = stack_indexes(
(
self.wrapper.columns[self.get_field_arr("col")],
pd.Index(self.get_field_arr("id"), name="id"),
),
**clean_index_kwargs,
)
if wrap_kwargs is None:
wrap_kwargs = {}
return self.wrapper.wrap(
out, group_by=False, index=pd.RangeIndex(stop=len(out)), columns=new_columns, **wrap_kwargs
)
def get_expanding_worst_price(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
clean_index_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get expanding worst price.
See `vectorbtpro.portfolio.nb.records.expanding_worst_price_nb`."""
func = jit_reg.resolve_option(nb.expanding_worst_price_nb, jitted)
out = func(
self.values,
self._open,
self._high,
self._low,
self._close,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)
if clean_index_kwargs is None:
clean_index_kwargs = {}
new_columns = stack_indexes(
(
self.wrapper.columns[self.get_field_arr("col")],
pd.Index(self.get_field_arr("id"), name="id"),
),
**clean_index_kwargs,
)
if wrap_kwargs is None:
wrap_kwargs = {}
return self.wrapper.wrap(
out, group_by=False, index=pd.RangeIndex(stop=len(out)), columns=new_columns, **wrap_kwargs
)
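# Illustrative sketch (hypothetical `trades` instance): the expanding price frames have one
# row per bar since entry (a plain RangeIndex) and one column per (source column, trade id)
# pair, as stacked by `stack_indexes` above.
#
# >>> expanding = trades.get_expanding_best_price()
# >>> expanding.columns    # MultiIndex of (column, id)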
def get_mfe(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
use_returns: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""Get MFE.
See `vectorbtpro.portfolio.nb.records.mfe_nb`."""
best_price = self.resolve_shortcut_attr(
"best_price",
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
jitted=jitted,
chunked=chunked,
)
func = jit_reg.resolve_option(nb.mfe_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mfe = func(
self.get_field_arr("size"),
self.get_field_arr("direction"),
self.get_field_arr("entry_price"),
best_price.values,
use_returns=use_returns,
)
return self.map_array(mfe, **kwargs)
def get_mae(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
use_returns: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArray:
"""Get MAE.
See `vectorbtpro.portfolio.nb.records.mae_nb`."""
worst_price = self.resolve_shortcut_attr(
"worst_price",
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
jitted=jitted,
chunked=chunked,
)
func = jit_reg.resolve_option(nb.mae_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mae = func(
self.get_field_arr("size"),
self.get_field_arr("direction"),
self.get_field_arr("entry_price"),
worst_price.values,
use_returns=use_returns,
)
return self.map_array(mae, **kwargs)
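# Illustrative sketch (hypothetical `trades` instance): MFE and MAE measure the best and
# worst open excursion of each trade, in PnL terms by default or in return terms with
# `use_returns=True`.
#
# >>> trades.get_mfe()                   # maximum favorable excursion per trade
# >>> trades.get_mae(use_returns=True)   # maximum adverse excursion as a return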
def get_expanding_mfe(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
use_returns: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Get expanding MFE.
See `vectorbtpro.portfolio.nb.records.expanding_mfe_nb`."""
expanding_best_price = self.resolve_shortcut_attr(
"expanding_best_price",
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
jitted=jitted,
**kwargs,
)
func = jit_reg.resolve_option(nb.expanding_mfe_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.values,
expanding_best_price.values,
use_returns=use_returns,
)
return ArrayWrapper.from_obj(expanding_best_price).wrap(out)
def get_expanding_mae(
self,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
use_returns: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Get expanding MAE.
See `vectorbtpro.portfolio.nb.records.expanding_mae_nb`."""
expanding_worst_price = self.resolve_shortcut_attr(
"expanding_worst_price",
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
jitted=jitted,
**kwargs,
)
func = jit_reg.resolve_option(nb.expanding_mae_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.values,
expanding_worst_price.values,
use_returns=use_returns,
)
return ArrayWrapper.from_obj(expanding_worst_price).wrap(out)
def get_edge_ratio(
self,
volatility: tp.Optional[tp.ArrayLike] = None,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get edge ratio.
See `vectorbtpro.portfolio.nb.records.edge_ratio_nb`.
If `volatility` is None, uses the 14-period ATR when both high and low are provided,
otherwise the 14-period rolling standard deviation of close."""
if self._close is None:
raise ValueError("Must provide close")
if volatility is None:
if self._high is not None and self._low is not None:
from vectorbtpro.indicators.nb import atr_nb
from vectorbtpro.generic.enums import WType
volatility = atr_nb(
high=to_2d_array(self._high),
low=to_2d_array(self._low),
close=to_2d_array(self._close),
window=14,
wtype=WType.Wilder,
)[1]
else:
from vectorbtpro.indicators.nb import msd_nb
from vectorbtpro.generic.enums import WType
volatility = msd_nb(
close=to_2d_array(self._close),
window=14,
wtype=WType.Wilder,
)
else:
volatility = broadcast_to(volatility, self.wrapper, to_pd=False, keep_flex=True)
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.edge_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.values,
col_map,
self._open,
self._high,
self._low,
self._close,
volatility,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
)
if wrap_kwargs is None:
wrap_kwargs = {}
return self.wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs)
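# Illustrative sketch (hypothetical `trades` and `price` objects): volatility defaults to
# a 14-period ATR (when high and low are known) or a 14-period rolling standard deviation
# of close; any array broadcastable against the wrapper can be passed instead.
#
# >>> trades.get_edge_ratio()                 # default volatility
# >>> vol = price.rolling(20).std()           # custom, user-defined volatility
# >>> trades.get_edge_ratio(volatility=vol)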
def get_running_edge_ratio(
self,
volatility: tp.Optional[tp.ArrayLike] = None,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
incl_shorter: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get running edge ratio.
See `vectorbtpro.portfolio.nb.records.running_edge_ratio_nb`.
If `volatility` is None, uses the 14-period ATR when both high and low are provided,
otherwise the 14-period rolling standard deviation of close."""
if self._close is None:
raise ValueError("Must provide close")
if volatility is None:
if self._high is not None and self._low is not None:
from vectorbtpro.indicators.nb import atr_nb
from vectorbtpro.generic.enums import WType
volatility = atr_nb(
high=to_2d_array(self._high),
low=to_2d_array(self._low),
close=to_2d_array(self._close),
window=14,
wtype=WType.Wilder,
)[1]
else:
from vectorbtpro.indicators.nb import msd_nb
from vectorbtpro.generic.enums import WType
volatility = msd_nb(
close=to_2d_array(self._close),
window=14,
wtype=WType.Wilder,
)
else:
volatility = broadcast_to(volatility, self.wrapper, to_pd=False, keep_flex=True)
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.running_edge_ratio_nb, jitted)
out = func(
self.values,
col_map,
self._open,
self._high,
self._low,
self._close,
volatility,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
incl_shorter=incl_shorter,
)
if wrap_kwargs is None:
wrap_kwargs = {}
return self.wrapper.wrap(out, group_by=group_by, index=pd.RangeIndex(stop=len(out)), **wrap_kwargs)
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Trades.stats`.
Merges `vectorbtpro.generic.ranges.Ranges.stats_defaults` and
`stats` from `vectorbtpro._settings.trades`."""
from vectorbtpro._settings import settings
trades_stats_cfg = settings["trades"]["stats"]
return merge_dicts(Ranges.stats_defaults.__get__(self), trades_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
first_trade_start=dict(
title="First Trade Start",
calc_func="entry_idx.nth",
n=0,
wrap_kwargs=dict(to_index=True),
tags=["trades", "index"],
),
last_trade_end=dict(
title="Last Trade End",
calc_func="exit_idx.nth",
n=-1,
wrap_kwargs=dict(to_index=True),
tags=["trades", "index"],
),
coverage=dict(
title="Coverage",
calc_func="coverage",
overlapping=False,
normalize=False,
apply_to_timedelta=True,
tags=["ranges", "coverage"],
),
overlap_coverage=dict(
title="Overlap Coverage",
calc_func="coverage",
overlapping=True,
normalize=False,
apply_to_timedelta=True,
tags=["ranges", "coverage"],
),
total_records=dict(title="Total Records", calc_func="count", tags="records"),
total_long_trades=dict(
title="Total Long Trades", calc_func="direction_long.count", tags=["trades", "long"]
),
total_short_trades=dict(
title="Total Short Trades", calc_func="direction_short.count", tags=["trades", "short"]
),
total_closed_trades=dict(
title="Total Closed Trades", calc_func="status_closed.count", tags=["trades", "closed"]
),
total_open_trades=dict(title="Total Open Trades", calc_func="status_open.count", tags=["trades", "open"]),
open_trade_pnl=dict(title="Open Trade PnL", calc_func="status_open.pnl.sum", tags=["trades", "open"]),
win_rate=dict(
title="Win Rate [%]",
calc_func="status_closed.get_win_rate",
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['trades', *incl_open_tags]"),
),
winning_streak=dict(
title="Max Win Streak",
calc_func=RepEval("'winning_streak.max' if incl_open else 'status_closed.winning_streak.max'"),
wrap_kwargs=dict(dtype=pd.Int64Dtype()),
tags=RepEval("['trades', *incl_open_tags, 'streak']"),
),
losing_streak=dict(
title="Max Loss Streak",
calc_func=RepEval("'losing_streak.max' if incl_open else 'status_closed.losing_streak.max'"),
wrap_kwargs=dict(dtype=pd.Int64Dtype()),
tags=RepEval("['trades', *incl_open_tags, 'streak']"),
),
best_trade=dict(
title="Best Trade [%]",
calc_func=RepEval("'returns.max' if incl_open else 'status_closed.returns.max'"),
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['trades', *incl_open_tags]"),
),
worst_trade=dict(
title="Worst Trade [%]",
calc_func=RepEval("'returns.min' if incl_open else 'status_closed.returns.min'"),
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['trades', *incl_open_tags]"),
),
avg_winning_trade=dict(
title="Avg Winning Trade [%]",
calc_func=RepEval("'winning.returns.mean' if incl_open else 'status_closed.winning.returns.mean'"),
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['trades', *incl_open_tags, 'winning']"),
),
avg_losing_trade=dict(
title="Avg Losing Trade [%]",
calc_func=RepEval("'losing.returns.mean' if incl_open else 'status_closed.losing.returns.mean'"),
post_calc_func=lambda self, out, settings: out * 100,
tags=RepEval("['trades', *incl_open_tags, 'losing']"),
),
avg_winning_trade_duration=dict(
title="Avg Winning Trade Duration",
calc_func=RepEval("'winning.avg_duration' if incl_open else 'status_closed.winning.get_avg_duration'"),
fill_wrap_kwargs=True,
tags=RepEval("['trades', *incl_open_tags, 'winning', 'duration']"),
),
avg_losing_trade_duration=dict(
title="Avg Losing Trade Duration",
calc_func=RepEval("'losing.avg_duration' if incl_open else 'status_closed.losing.get_avg_duration'"),
fill_wrap_kwargs=True,
tags=RepEval("['trades', *incl_open_tags, 'losing', 'duration']"),
),
profit_factor=dict(
title="Profit Factor",
calc_func=RepEval("'profit_factor' if incl_open else 'status_closed.get_profit_factor'"),
tags=RepEval("['trades', *incl_open_tags]"),
),
expectancy=dict(
title="Expectancy",
calc_func=RepEval("'expectancy' if incl_open else 'status_closed.get_expectancy'"),
tags=RepEval("['trades', *incl_open_tags]"),
),
sqn=dict(
title="SQN",
calc_func=RepEval("'sqn' if incl_open else 'status_closed.get_sqn'"),
tags=RepEval("['trades', *incl_open_tags]"),
),
edge_ratio=dict(
title="Edge Ratio",
calc_func=RepEval("'edge_ratio' if incl_open else 'status_closed.get_edge_ratio'"),
tags=RepEval("['trades', *incl_open_tags]"),
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
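# Illustrative sketch (hypothetical `trades` instance): the metric definitions above are
# consumed by the inherited stats builder; `incl_open` referenced in the templates is a
# stats setting that controls whether open trades enter the PnL-based metrics.
#
# >>> trades.stats()                                     # full report for one column
# >>> trades.stats(settings=dict(incl_open=True))        # include open trades
# >>> trades.stats(metrics=["win_rate", "expectancy"])   # selected metrics only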
# ############# Plotting ############# #
def plot_pnl(
self,
column: tp.Optional[tp.Label] = None,
group_by: tp.GroupByLike = False,
pct_scale: bool = False,
marker_size_range: tp.Tuple[float, float] = (7, 14),
opacity_range: tp.Tuple[float, float] = (0.75, 0.9),
closed_trace_kwargs: tp.KwargsLike = None,
closed_profit_trace_kwargs: tp.KwargsLike = None,
closed_loss_trace_kwargs: tp.KwargsLike = None,
open_trace_kwargs: tp.KwargsLike = None,
hline_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot trade PnL or returns.
Args:
column (str): Name of the column to plot.
group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
pct_scale (bool): Whether to plot `Trades.returns` on the y-axis instead of `Trades.pnl`.
marker_size_range (tuple): Range of marker size.
opacity_range (tuple): Range of marker opacity.
closed_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed" markers.
closed_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Profit" markers.
closed_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Loss" markers.
open_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Open" markers.
hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for zeroline.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=7)
>>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.], index=index)
>>> orders = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5], index=index)
>>> pf = vbt.Portfolio.from_orders(price, orders)
>>> pf.trades.plot_pnl().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=group_by)
if closed_trace_kwargs is None:
closed_trace_kwargs = {}
if closed_profit_trace_kwargs is None:
closed_profit_trace_kwargs = {}
if closed_loss_trace_kwargs is None:
closed_loss_trace_kwargs = {}
if open_trace_kwargs is None:
open_trace_kwargs = {}
if hline_shape_kwargs is None:
hline_shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
marker_size_range = tuple(marker_size_range)
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
if fig is None:
fig = make_figure()
def_layout_kwargs = {xaxis: {}, yaxis: {}}
if pct_scale:
def_layout_kwargs[yaxis]["tickformat"] = ".2%"
def_layout_kwargs[yaxis]["title"] = "Return"
else:
def_layout_kwargs[yaxis]["title"] = "PnL"
fig.update_layout(**def_layout_kwargs)
fig.update_layout(**layout_kwargs)
x_domain = get_domain(xref, fig)
y_domain = get_domain(yref, fig)
if self_col.count() > 0:
# Extract information
exit_idx = self_col.get_map_field_to_index("exit_idx")
pnl = self_col.get_field_arr("pnl")
returns = self_col.get_field_arr("return")
status = self_col.get_field_arr("status")
valid_mask = ~np.isnan(returns)
neutral_mask = (pnl == 0) & valid_mask
profit_mask = (pnl > 0) & valid_mask
loss_mask = (pnl < 0) & valid_mask
marker_size = min_rel_rescale(np.abs(returns), marker_size_range)
opacity = max_rel_rescale(np.abs(returns), opacity_range)
open_mask = status == TradeStatus.Open
closed_profit_mask = (~open_mask) & profit_mask
closed_loss_mask = (~open_mask) & loss_mask
open_mask &= ~neutral_mask
def _plot_scatter(mask, name, color, kwargs):
if np.any(mask):
if self_col.get_field_setting("parent_id", "ignore", False):
customdata, hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "exit_idx", "pnl", "return"], mask=mask
)
else:
customdata, hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "parent_id", "exit_idx", "pnl", "return"], mask=mask
)
_kwargs = merge_dicts(
dict(
x=exit_idx[mask],
y=returns[mask] if pct_scale else pnl[mask],
mode="markers",
marker=dict(
symbol="circle",
color=color,
size=marker_size[mask],
opacity=opacity[mask],
line=dict(width=1, color=adjust_lightness(color)),
),
name=name,
customdata=customdata,
hovertemplate=hovertemplate,
),
kwargs,
)
scatter = go.Scatter(**_kwargs)
fig.add_trace(scatter, **add_trace_kwargs)
# Plot Closed - Neutral scatter
_plot_scatter(neutral_mask, "Closed", plotting_cfg["contrast_color_schema"]["gray"], closed_trace_kwargs)
# Plot Closed - Profit scatter
_plot_scatter(
closed_profit_mask,
"Closed - Profit",
plotting_cfg["contrast_color_schema"]["green"],
closed_profit_trace_kwargs,
)
# Plot Closed - Loss scatter
_plot_scatter(
closed_loss_mask,
"Closed - Loss",
plotting_cfg["contrast_color_schema"]["red"],
closed_loss_trace_kwargs,
)
# Plot Open scatter
_plot_scatter(open_mask, "Open", plotting_cfg["contrast_color_schema"]["orange"], open_trace_kwargs)
# Plot zeroline
fig.add_shape(
**merge_dicts(
dict(
type="line",
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0,
x1=x_domain[1],
y1=0,
line=dict(
color="gray",
dash="dash",
),
),
hline_shape_kwargs,
)
)
return fig
def plot_returns(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_pnl` for `Trades.returns`."""
return self.plot_pnl(
*args,
pct_scale=True,
**kwargs,
)
def plot_against_pnl(
self,
field: tp.Union[str, tp.Array1d, MappedArray],
field_label: tp.Optional[str] = None,
column: tp.Optional[tp.Label] = None,
group_by: tp.GroupByLike = False,
pct_scale: bool = False,
field_pct_scale: bool = False,
closed_trace_kwargs: tp.KwargsLike = None,
closed_profit_trace_kwargs: tp.KwargsLike = None,
closed_loss_trace_kwargs: tp.KwargsLike = None,
open_trace_kwargs: tp.KwargsLike = None,
hline_shape_kwargs: tp.KwargsLike = None,
vline_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot a field against PnL or returns.
Args:
field (str, MappedArray, or array_like): Field to be plotted.
Can also be provided as a mapped array or 1-dim array.
field_label (str): Label of the field.
column (str): Name of the column to plot.
group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
pct_scale (bool): Whether to plot `Trades.returns` on the x-axis instead of `Trades.pnl`.
field_pct_scale (bool): Whether to make y-axis a percentage scale.
closed_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed" markers.
closed_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Profit" markers.
closed_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Closed - Loss" markers.
open_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Open" markers.
hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for horizontal zeroline.
vline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for vertical zeroline.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=10)
>>> price = pd.Series([1., 2., 3., 4., 5., 6., 5., 3., 2., 1.], index=index)
>>> orders = pd.Series([1., -0.5, 0., -0.5, 2., 0., -0.5, -0.5, 0., -0.5], index=index)
>>> pf = vbt.Portfolio.from_orders(price, orders)
>>> trades = pf.trades
>>> trades.plot_against_pnl("MFE").show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=group_by)
if closed_trace_kwargs is None:
closed_trace_kwargs = {}
if closed_profit_trace_kwargs is None:
closed_profit_trace_kwargs = {}
if closed_loss_trace_kwargs is None:
closed_loss_trace_kwargs = {}
if open_trace_kwargs is None:
open_trace_kwargs = {}
if hline_shape_kwargs is None:
hline_shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
if isinstance(field, str):
if field_label is None:
field_label = field
field = getattr(self_col, field.lower())
if isinstance(field, MappedArray):
field = field.values
if field_label is None:
field_label = "Field"
if fig is None:
fig = make_figure()
def_layout_kwargs = {xaxis: {}, yaxis: {}}
if pct_scale:
def_layout_kwargs[xaxis]["tickformat"] = ".2%"
def_layout_kwargs[xaxis]["title"] = "Return"
else:
def_layout_kwargs[xaxis]["title"] = "PnL"
if field_pct_scale:
def_layout_kwargs[yaxis]["tickformat"] = ".2%"
def_layout_kwargs[yaxis]["title"] = field_label
fig.update_layout(**def_layout_kwargs)
fig.update_layout(**layout_kwargs)
x_domain = get_domain(xref, fig)
y_domain = get_domain(yref, fig)
if self_col.count() > 0:
# Extract information
pnl = self_col.get_field_arr("pnl")
returns = self_col.get_field_arr("return")
status = self_col.get_field_arr("status")
valid_mask = ~np.isnan(returns)
neutral_mask = (pnl == 0) & valid_mask
profit_mask = (pnl > 0) & valid_mask
loss_mask = (pnl < 0) & valid_mask
open_mask = status == TradeStatus.Open
closed_profit_mask = (~open_mask) & profit_mask
closed_loss_mask = (~open_mask) & loss_mask
open_mask &= ~neutral_mask
def _plot_scatter(mask, name, color, kwargs):
if np.any(mask):
if self_col.get_field_setting("parent_id", "ignore", False):
customdata, hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "exit_idx", "pnl", "return"], mask=mask
)
else:
customdata, hovertemplate = self_col.prepare_customdata(
incl_fields=["id", "parent_id", "exit_idx", "pnl", "return"], mask=mask
)
_kwargs = merge_dicts(
dict(
x=returns[mask] if pct_scale else pnl[mask],
y=field[mask],
mode="markers",
marker=dict(
symbol="circle",
color=color,
size=7,
line=dict(width=1, color=adjust_lightness(color)),
),
name=name,
customdata=customdata,
hovertemplate=hovertemplate,
),
kwargs,
)
scatter = go.Scatter(**_kwargs)
fig.add_trace(scatter, **add_trace_kwargs)
# Plot Closed - Neutral scatter
_plot_scatter(neutral_mask, "Closed", plotting_cfg["contrast_color_schema"]["gray"], closed_trace_kwargs)
# Plot Closed - Profit scatter
_plot_scatter(
closed_profit_mask,
"Closed - Profit",
plotting_cfg["contrast_color_schema"]["green"],
closed_profit_trace_kwargs,
)
# Plot Closed - Loss scatter
_plot_scatter(
closed_loss_mask,
"Closed - Loss",
plotting_cfg["contrast_color_schema"]["red"],
closed_loss_trace_kwargs,
)
# Plot Open scatter
_plot_scatter(open_mask, "Open", plotting_cfg["contrast_color_schema"]["orange"], open_trace_kwargs)
# Plot zerolines
fig.add_shape(
**merge_dicts(
dict(
type="line",
xref="paper",
yref=yref,
x0=x_domain[0],
y0=0,
x1=x_domain[1],
y1=0,
line=dict(
color="gray",
dash="dash",
),
),
hline_shape_kwargs,
)
)
fig.add_shape(
**merge_dicts(
dict(
type="line",
xref=xref,
yref="paper",
x0=0,
y0=y_domain[0],
x1=0,
y1=y_domain[1],
line=dict(
color="gray",
dash="dash",
),
),
vline_shape_kwargs,
)
)
return fig
def plot_mfe(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_against_pnl` for `Trades.mfe`."""
return self.plot_against_pnl(
*args,
field="mfe",
field_label="MFE",
**kwargs,
)
def plot_mfe_returns(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_against_pnl` for `Trades.mfe_returns`."""
return self.plot_against_pnl(
*args,
field="mfe_returns",
field_label="MFE Return",
pct_scale=True,
field_pct_scale=True,
**kwargs,
)
def plot_mae(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_against_pnl` for `Trades.mae`."""
return self.plot_against_pnl(
*args,
field="mae",
field_label="MAE",
**kwargs,
)
def plot_mae_returns(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_against_pnl` for `Trades.mae_returns`."""
return self.plot_against_pnl(
*args,
field="mae_returns",
field_label="MAE Return",
pct_scale=True,
field_pct_scale=True,
**kwargs,
)
def plot_expanding(
self,
field: tp.Union[str, tp.Array1d, MappedArray],
field_label: tp.Optional[str] = None,
column: tp.Optional[tp.Label] = None,
group_by: tp.GroupByLike = False,
plot_bands: bool = False,
colorize: tp.Union[bool, str, tp.Callable] = "last",
field_pct_scale: bool = False,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot projections of an expanding field.
Args:
field (str or array_like): Field to be plotted.
Can also be provided as a 2-dim array.
field_label (str): Label of the field.
column (str): Name of the column to plot. Optional.
group_by (any): Group columns. See `vectorbtpro.base.grouping.base.Grouper`.
plot_bands (bool): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
colorize (bool, str or callable): See `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
field_pct_scale (bool): Whether to make y-axis a percentage scale.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericDFAccessor.plot_projections`.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=10)
>>> price = pd.Series([1., 2., 3., 2., 4., 5., 6., 5., 6., 7.], index=index)
>>> orders = pd.Series([1., 0., 0., -2., 0., 0., 2., 0., 0., -1.], index=index)
>>> pf = vbt.Portfolio.from_orders(price, orders)
>>> pf.trades.plot_expanding("MFE").show()
```
"""
if column is not None:
self_col = self.select_col(column=column, group_by=group_by)
else:
self_col = self
if isinstance(field, str):
if field_label is None:
field_label = field
if not field.lower().startswith("expanding_"):
field = "expanding_" + field
field = getattr(self_col, field.lower())
if isinstance(field, MappedArray):
field = field.values
if field_label is None:
field_label = "Field"
field = to_pd_array(field)
fig = field.vbt.plot_projections(
add_trace_kwargs=add_trace_kwargs,
fig=fig,
plot_bands=plot_bands,
colorize=colorize,
**kwargs,
)
xref = fig.data[-1]["xaxis"] if fig.data[-1]["xaxis"] is not None else "x"
yref = fig.data[-1]["yaxis"] if fig.data[-1]["yaxis"] is not None else "y"
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
if field_label is not None and "title" not in kwargs.get(yaxis, {}):
fig.update_layout(**{yaxis: dict(title=field_label)})
if field_pct_scale and "tickformat" not in kwargs.get(yaxis, {}):
fig.update_layout(**{yaxis: dict(tickformat=".2%")})
return fig
def plot_expanding_mfe(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_expanding` for `Trades.expanding_mfe`."""
return self.plot_expanding(
*args,
field="expanding_mfe",
field_label="MFE",
**kwargs,
)
def plot_expanding_mfe_returns(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_expanding` for `Trades.expanding_mfe_returns`."""
return self.plot_expanding(
*args,
field="expanding_mfe_returns",
field_label="MFE Return",
field_pct_scale=True,
**kwargs,
)
def plot_expanding_mae(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_expanding` for `Trades.expanding_mae`."""
return self.plot_expanding(
*args,
field="expanding_mae",
field_label="MAE",
**kwargs,
)
def plot_expanding_mae_returns(self, *args, **kwargs) -> tp.BaseFigure:
"""`Trades.plot_expanding` for `Trades.expanding_mae_returns`."""
return self.plot_expanding(
*args,
field="expanding_mae_returns",
field_label="MAE Return",
field_pct_scale=True,
**kwargs,
)
def plot_running_edge_ratio(
self,
column: tp.Optional[tp.Label] = None,
volatility: tp.Optional[tp.ArrayLike] = None,
entry_price_open: bool = False,
exit_price_close: bool = False,
max_duration: tp.Optional[int] = None,
incl_shorter: bool = False,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
xref: str = "x",
yref: str = "y",
hline_shape_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
"""Plot one column/group of edge ratio.
`**kwargs` are passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot_against`."""
from vectorbtpro.utils.figure import get_domain
running_edge_ratio = self.resolve_shortcut_attr(
"running_edge_ratio",
volatility=volatility,
entry_price_open=entry_price_open,
exit_price_close=exit_price_close,
max_duration=max_duration,
incl_shorter=incl_shorter,
group_by=group_by,
jitted=jitted,
)
running_edge_ratio = self.select_col_from_obj(
running_edge_ratio, column, wrapper=self.wrapper.regroup(group_by)
)
kwargs = merge_dicts(
dict(
trace_kwargs=dict(name="Edge Ratio"),
other_trace_kwargs="hidden",
),
kwargs,
)
fig = running_edge_ratio.vbt.plot_against(1, **kwargs)
x_domain = get_domain(xref, fig)
fig.add_shape(
**merge_dicts(
dict(
type="line",
line=dict(
color="gray",
dash="dash",
),
xref="paper",
yref=yref,
x0=x_domain[0],
y0=1.0,
x1=x_domain[1],
y1=1.0,
),
hline_shape_kwargs,
)
)
return fig
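# Illustrative sketch (hypothetical `pf` and column label): the running edge ratio is
# plotted against the dashed reference line at 1.0 added by the shape above.
#
# >>> pf.trades.plot_running_edge_ratio().show()
# >>> pf.trades.plot_running_edge_ratio(column="BTC-USD", max_duration=50).show()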
def plot(
self,
column: tp.Optional[tp.Label] = None,
plot_ohlc: bool = True,
plot_close: bool = True,
plot_markers: bool = True,
plot_zones: bool = True,
plot_by_type: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
entry_trace_kwargs: tp.KwargsLike = None,
exit_trace_kwargs: tp.KwargsLike = None,
exit_profit_trace_kwargs: tp.KwargsLike = None,
exit_loss_trace_kwargs: tp.KwargsLike = None,
active_trace_kwargs: tp.KwargsLike = None,
profit_shape_kwargs: tp.KwargsLike = None,
loss_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot trades.
Args:
column (str): Name of the column to plot.
plot_ohlc (bool): Whether to plot OHLC.
plot_close (bool): Whether to plot close.
plot_markers (bool): Whether to plot markers.
plot_zones (bool): Whether to plot zones.
plot_by_type (bool): Whether to plot exit trades by type.
Otherwise, the appearance will be controlled using `exit_trace_kwargs`.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `Trades.close`.
entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Entry" markers.
exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Exit" markers.
exit_profit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Exit - Profit" markers.
exit_loss_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Exit - Loss" markers.
active_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Active" markers.
profit_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for profit zones.
loss_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for loss zones.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=7)
>>> price = pd.Series([1., 2., 3., 4., 3., 2., 1.], index=index)
>>> size = pd.Series([1., -0.5, -0.5, 2., -0.5, -0.5, -0.5], index=index)
>>> pf = vbt.Portfolio.from_orders(price, size)
>>> pf.trades.plot().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if entry_trace_kwargs is None:
entry_trace_kwargs = {}
if exit_trace_kwargs is None:
exit_trace_kwargs = {}
if exit_profit_trace_kwargs is None:
exit_profit_trace_kwargs = {}
if exit_loss_trace_kwargs is None:
exit_loss_trace_kwargs = {}
if active_trace_kwargs is None:
active_trace_kwargs = {}
if profit_shape_kwargs is None:
profit_shape_kwargs = {}
if loss_shape_kwargs is None:
loss_shape_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# Plot OHLC or close
if (
plot_ohlc
and self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc_df = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc_df.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and self_col._close is not None:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
entry_idx = self_col.get_map_field_to_index("entry_idx", minus_one_to_zero=True)
entry_price = self_col.get_field_arr("entry_price")
exit_idx = self_col.get_map_field_to_index("exit_idx")
exit_price = self_col.get_field_arr("exit_price")
pnl = self_col.get_field_arr("pnl")
status = self_col.get_field_arr("status")
duration = to_1d_array(
self_col.wrapper.arr_to_timedelta(
self_col.duration.values,
to_pd=True,
silence_warnings=True,
).astype(str)
)
if plot_markers:
# Plot Entry markers
if self_col.get_field_setting("parent_id", "ignore", False):
entry_customdata, entry_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"entry_order_id",
"entry_idx",
"size",
"entry_price",
"entry_fees",
"direction",
]
)
else:
entry_customdata, entry_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"entry_order_id",
"parent_id",
"entry_idx",
"size",
"entry_price",
"entry_fees",
"direction",
]
)
_entry_trace_kwargs = merge_dicts(
dict(
x=entry_idx,
y=entry_price,
mode="markers",
marker=dict(
symbol="square",
color=plotting_cfg["contrast_color_schema"]["blue"],
size=7,
line=dict(width=1, color=adjust_lightness(plotting_cfg["contrast_color_schema"]["blue"])),
),
name="Entry",
customdata=entry_customdata,
hovertemplate=entry_hovertemplate,
),
entry_trace_kwargs,
)
entry_scatter = go.Scatter(**_entry_trace_kwargs)
fig.add_trace(entry_scatter, **add_trace_kwargs)
# Plot end markers
def _plot_end_markers(mask, name, color, kwargs, incl_status=False) -> None:
if np.any(mask):
if self_col.get_field_setting("parent_id", "ignore", False):
exit_customdata, exit_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"exit_order_id",
"exit_idx",
"size",
"exit_price",
"exit_fees",
"pnl",
"return",
"direction",
*(("status",) if incl_status else ()),
],
append_info=[(duration, "Duration")],
mask=mask,
)
else:
exit_customdata, exit_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"exit_order_id",
"parent_id",
"exit_idx",
"size",
"exit_price",
"exit_fees",
"pnl",
"return",
"direction",
*(("status",) if incl_status else ()),
],
append_info=[(duration, "Duration")],
mask=mask,
)
_kwargs = merge_dicts(
dict(
x=exit_idx[mask],
y=exit_price[mask],
mode="markers",
marker=dict(
symbol="square",
color=color,
size=7,
line=dict(width=1, color=adjust_lightness(color)),
),
name=name,
customdata=exit_customdata,
hovertemplate=exit_hovertemplate,
),
kwargs,
)
scatter = go.Scatter(**_kwargs)
fig.add_trace(scatter, **add_trace_kwargs)
if plot_by_type:
# Plot Exit markers
_plot_end_markers(
(status == TradeStatus.Closed) & (pnl == 0.0),
"Exit",
plotting_cfg["contrast_color_schema"]["gray"],
exit_trace_kwargs,
)
# Plot Exit - Profit markers
_plot_end_markers(
(status == TradeStatus.Closed) & (pnl > 0.0),
"Exit - Profit",
plotting_cfg["contrast_color_schema"]["green"],
exit_profit_trace_kwargs,
)
# Plot Exit - Loss markers
_plot_end_markers(
(status == TradeStatus.Closed) & (pnl < 0.0),
"Exit - Loss",
plotting_cfg["contrast_color_schema"]["red"],
exit_loss_trace_kwargs,
)
# Plot Active markers
_plot_end_markers(
status == TradeStatus.Open,
"Active",
plotting_cfg["contrast_color_schema"]["orange"],
active_trace_kwargs,
)
else:
# Plot Exit markers
_plot_end_markers(
np.full(len(status), True),
"Exit",
plotting_cfg["contrast_color_schema"]["pink"],
exit_trace_kwargs,
incl_status=True,
)
if plot_zones:
# Plot profit zones
self_col.winning.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(
yref=Rep("yref"),
y0=RepFunc(lambda record: record["entry_price"]),
y1=RepFunc(lambda record: record["exit_price"]),
fillcolor=plotting_cfg["contrast_color_schema"]["green"],
),
profit_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
# Plot loss zones
self_col.losing.plot_shapes(
plot_ohlc=False,
plot_close=False,
shape_kwargs=merge_dicts(
dict(
yref=Rep("yref"),
y0=RepFunc(lambda record: record["entry_price"]),
y1=RepFunc(lambda record: record["exit_price"]),
fillcolor=plotting_cfg["contrast_color_schema"]["red"],
),
loss_shape_kwargs,
),
add_trace_kwargs=add_trace_kwargs,
xref=xref,
yref=yref,
fig=fig,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Trades.plots`.
Merges `vectorbtpro.generic.ranges.Ranges.plots_defaults` and
`plots` from `vectorbtpro._settings.trades`."""
from vectorbtpro._settings import settings
trades_plots_cfg = settings["trades"]["plots"]
return merge_dicts(Ranges.plots_defaults.__get__(self), trades_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot=dict(
title="Trades",
yaxis_kwargs=dict(title="Price"),
check_is_not_grouped=True,
plot_func="plot",
tags="trades",
),
plot_pnl=dict(
title="Trade PnL",
yaxis_kwargs=dict(title="Trade PnL"),
check_is_not_grouped=True,
plot_func="plot_pnl",
tags="trades",
),
)
)
@property
def subplots(self) -> Config:
return self._subplots
Trades.override_field_config_doc(__pdoc__)
Trades.override_metrics_doc(__pdoc__)
Trades.override_subplots_doc(__pdoc__)
# ############# EntryTrades ############# #
entry_trades_field_config = ReadonlyConfig(
dict(settings={"id": dict(title="Entry Trade Id"), "idx": dict(name="entry_idx")}) # remap field of Records,
)
"""_"""
__pdoc__[
"entry_trades_field_config"
] = f"""Field config for `EntryTrades`.
```python
{entry_trades_field_config.prettify()}
```
"""
EntryTradesT = tp.TypeVar("EntryTradesT", bound="EntryTrades")
@override_field_config(entry_trades_field_config)
class EntryTrades(Trades):
"""Extends `Trades` for working with entry trade records."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def from_orders(
cls: tp.Type[EntryTradesT],
orders: Orders,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
close: tp.Optional[tp.ArrayLike] = None,
init_position: tp.ArrayLike = 0.0,
init_price: tp.ArrayLike = np.nan,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> EntryTradesT:
"""Build `EntryTrades` from `vectorbtpro.portfolio.orders.Orders`."""
if open is None:
open = orders._open
if high is None:
high = orders._high
if low is None:
low = orders._low
if close is None:
close = orders._close
func = jit_reg.resolve_option(nb.get_entry_trades_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
trade_records_arr = func(
orders.values,
to_2d_array(orders.wrapper.wrap(close, group_by=False)),
orders.col_mapper.col_map,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
sim_start=None if sim_start is None else to_1d_array(sim_start),
sim_end=None if sim_end is None else to_1d_array(sim_end),
)
return cls.from_records(
orders.wrapper,
trade_records_arr,
open=open,
high=high,
low=low,
close=close,
**kwargs,
)
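# Illustrative sketch (hypothetical `pf`; assumes `EntryTrades` is importable from the
# package root as `vbt.EntryTrades`): the classmethod above is normally invoked indirectly
# when building a portfolio, but it can also be called on order records directly.
#
# >>> entry_trades = vbt.EntryTrades.from_orders(pf.orders)
# >>> entry_trades.records_readable   # human-readable DataFrame of the records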
def plot_signals(
self,
column: tp.Optional[tp.Label] = None,
plot_ohlc: bool = True,
plot_close: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
long_entry_trace_kwargs: tp.KwargsLike = None,
short_entry_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot entry trade signals.
Args:
column (str): Name of the column to plot.
plot_ohlc (bool): Whether to plot OHLC.
plot_close (bool): Whether to plot close.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `EntryTrades.close`.
long_entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Long Entry" markers.
short_entry_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Short Entry" markers.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=7)
>>> price = pd.Series([1, 2, 3, 2, 3, 4, 3], index=index)
>>> orders = pd.Series([1, 0, -1, 0, -1, 2, -2], index=index)
>>> pf = vbt.Portfolio.from_orders(price, orders)
>>> pf.entry_trades.plot_signals().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if long_entry_trace_kwargs is None:
long_entry_trace_kwargs = {}
if short_entry_trace_kwargs is None:
short_entry_trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# Plot OHLC or close
if (
plot_ohlc
and self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc_df = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc_df.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and self_col._close is not None:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
entry_idx = self_col.get_map_field_to_index("entry_idx", minus_one_to_zero=True)
entry_price = self_col.get_field_arr("entry_price")
direction = self_col.get_field_arr("direction")
def _plot_entry_markers(mask, name, color, kwargs):
if np.any(mask):
entry_customdata, entry_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"entry_order_id",
"parent_id",
"entry_idx",
"size",
"entry_price",
"entry_fees",
"pnl",
"return",
"status",
],
mask=mask,
)
_kwargs = merge_dicts(
dict(
x=entry_idx[mask],
y=entry_price[mask],
mode="markers",
marker=dict(
symbol="circle",
color="rgba(0, 0, 0, 0)",
size=15,
line=dict(
color=color,
width=2,
),
),
name=name,
customdata=entry_customdata,
hovertemplate=entry_hovertemplate,
),
kwargs,
)
scatter = go.Scatter(**_kwargs)
fig.add_trace(scatter, **add_trace_kwargs)
# Plot Long Entry markers
_plot_entry_markers(
direction == TradeDirection.Long,
"Long Entry",
plotting_cfg["contrast_color_schema"]["green"],
long_entry_trace_kwargs,
)
# Plot Short Entry markers
_plot_entry_markers(
direction == TradeDirection.Short,
"Short Entry",
plotting_cfg["contrast_color_schema"]["red"],
short_entry_trace_kwargs,
)
return fig
EntryTrades.override_field_config_doc(__pdoc__)
# ############# ExitTrades ############# #
exit_trades_field_config = ReadonlyConfig(dict(settings={"id": dict(title="Exit Trade Id")}))
"""_"""
__pdoc__[
"exit_trades_field_config"
] = f"""Field config for `ExitTrades`.
```python
{exit_trades_field_config.prettify()}
```
"""
ExitTradesT = tp.TypeVar("ExitTradesT", bound="ExitTrades")
@override_field_config(exit_trades_field_config)
class ExitTrades(Trades):
"""Extends `Trades` for working with exit trade records."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def from_orders(
cls: tp.Type[ExitTradesT],
orders: Orders,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
close: tp.Optional[tp.ArrayLike] = None,
init_position: tp.ArrayLike = 0.0,
init_price: tp.ArrayLike = np.nan,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> ExitTradesT:
"""Build `ExitTrades` from `vectorbtpro.portfolio.orders.Orders`."""
if open is None:
open = orders._open
if high is None:
high = orders._high
if low is None:
low = orders._low
if close is None:
close = orders._close
func = jit_reg.resolve_option(nb.get_exit_trades_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
trade_records_arr = func(
orders.values,
to_2d_array(orders.wrapper.wrap(close, group_by=False)),
orders.col_mapper.col_map,
init_position=to_1d_array(init_position),
init_price=to_1d_array(init_price),
sim_start=None if sim_start is None else to_1d_array(sim_start),
sim_end=None if sim_end is None else to_1d_array(sim_end),
)
return cls.from_records(
orders.wrapper,
trade_records_arr,
open=open,
high=high,
low=low,
close=close,
**kwargs,
)
def plot_signals(
self,
column: tp.Optional[tp.Label] = None,
plot_ohlc: bool = True,
plot_close: bool = True,
ohlc_type: tp.Union[None, str, tp.BaseTraceType] = None,
ohlc_trace_kwargs: tp.KwargsLike = None,
close_trace_kwargs: tp.KwargsLike = None,
long_exit_trace_kwargs: tp.KwargsLike = None,
short_exit_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot exit trade signals.
Args:
column (str): Name of the column to plot.
plot_ohlc (bool): Whether to plot OHLC.
plot_close (bool): Whether to plot close.
ohlc_type: Either 'OHLC', 'Candlestick' or Plotly trace.
Pass None to use the default.
ohlc_trace_kwargs (dict): Keyword arguments passed to `ohlc_type`.
close_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for `ExitTrades.close`.
long_exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Long Exit" markers.
short_exit_trace_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Scatter` for "Short Exit" markers.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> index = pd.date_range("2020", periods=7)
>>> price = pd.Series([1, 2, 3, 2, 3, 4, 3], index=index)
>>> orders = pd.Series([1, 0, -1, 0, -1, 2, -2], index=index)
>>> pf = vbt.Portfolio.from_orders(price, orders)
>>> pf.exit_trades.plot_signals().show()
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import plotly.graph_objects as go
from vectorbtpro.utils.figure import make_figure
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
self_col = self.select_col(column=column, group_by=False)
if ohlc_trace_kwargs is None:
ohlc_trace_kwargs = {}
if close_trace_kwargs is None:
close_trace_kwargs = {}
close_trace_kwargs = merge_dicts(
dict(line=dict(color=plotting_cfg["color_schema"]["blue"]), name="Close"),
close_trace_kwargs,
)
if long_exit_trace_kwargs is None:
long_exit_trace_kwargs = {}
if short_exit_trace_kwargs is None:
short_exit_trace_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
if fig is None:
fig = make_figure()
fig.update_layout(**layout_kwargs)
# Plot OHLC or close
if (
plot_ohlc
and self_col._open is not None
and self_col._high is not None
and self_col._low is not None
and self_col._close is not None
):
ohlc_df = pd.DataFrame(
{
"open": self_col.open,
"high": self_col.high,
"low": self_col.low,
"close": self_col.close,
}
)
if "opacity" not in ohlc_trace_kwargs:
ohlc_trace_kwargs["opacity"] = 0.5
fig = ohlc_df.vbt.ohlcv.plot(
ohlc_type=ohlc_type,
plot_volume=False,
ohlc_trace_kwargs=ohlc_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
elif plot_close and self_col._close is not None:
fig = self_col.close.vbt.lineplot(
trace_kwargs=close_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
)
if self_col.count() > 0:
# Extract information
exit_idx = self_col.get_map_field_to_index("exit_idx", minus_one_to_zero=True)
exit_price = self_col.get_field_arr("exit_price")
direction = self_col.get_field_arr("direction")
def _plot_exit_markers(mask, name, color, kwargs):
if np.any(mask):
exit_customdata, exit_hovertemplate = self_col.prepare_customdata(
incl_fields=[
"id",
"exit_order_id",
"parent_id",
"exit_idx",
"size",
"exit_price",
"exit_fees",
"pnl",
"return",
"status",
],
mask=mask,
)
_kwargs = merge_dicts(
dict(
x=exit_idx[mask],
y=exit_price[mask],
mode="markers",
marker=dict(
symbol="circle",
color=color,
size=8,
),
name=name,
customdata=exit_customdata,
hovertemplate=exit_hovertemplate,
),
kwargs,
)
scatter = go.Scatter(**_kwargs)
fig.add_trace(scatter, **add_trace_kwargs)
# Plot Long Exit markers
_plot_exit_markers(
direction == TradeDirection.Long,
"Long Exit",
plotting_cfg["contrast_color_schema"]["green"],
long_exit_trace_kwargs,
)
# Plot Short Exit markers
_plot_exit_markers(
direction == TradeDirection.Short,
"Short Exit",
plotting_cfg["contrast_color_schema"]["red"],
short_exit_trace_kwargs,
)
return fig
ExitTrades.override_field_config_doc(__pdoc__)
# ############# Positions ############# #
positions_field_config = ReadonlyConfig(
dict(settings={"id": dict(title="Position Id"), "parent_id": dict(title="Parent Id", ignore=True)}),
)
"""_"""
__pdoc__[
"positions_field_config"
] = f"""Field config for `Positions`.
```python
{positions_field_config.prettify()}
```
"""
PositionsT = tp.TypeVar("PositionsT", bound="Positions")
@override_field_config(positions_field_config)
class Positions(Trades):
"""Extends `Trades` for working with position records."""
@property
def field_config(self) -> Config:
return self._field_config
@classmethod
def from_trades(
cls: tp.Type[PositionsT],
trades: Trades,
open: tp.Optional[tp.ArrayLike] = None,
high: tp.Optional[tp.ArrayLike] = None,
low: tp.Optional[tp.ArrayLike] = None,
close: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> PositionsT:
"""Build `Positions` from `Trades`."""
if open is None:
open = trades._open
if high is None:
high = trades._high
if low is None:
low = trades._low
if close is None:
close = trades._close
func = jit_reg.resolve_option(nb.get_positions_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
position_records_arr = func(trades.values, trades.col_mapper.col_map)
return cls.from_records(
trades.wrapper,
position_records_arr,
open=open,
high=high,
low=low,
close=close,
**kwargs,
)
Positions.override_field_config_doc(__pdoc__)
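# Usage (a hedged sketch, assuming an existing `Trades` instance named `trades`):
#
#     >>> positions = vbt.Positions.from_trades(trades)
#     >>> positions.count()
#
# `from_trades` aggregates the trade records of each column into position records using
# `nb.get_positions_nb` and reuses the OHLC arrays of `trades` unless they are overridden.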
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for plotting with Plotly Express."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.px.accessors import *
from vectorbtpro.px.decorators import *
__import_if_installed__ = dict()
__import_if_installed__["accessors"] = "plotly"
__import_if_installed__["decorators"] = "plotly"
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Pandas accessors for Plotly Express.
!!! note
Accessors do not utilize caching."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.base.accessors import BaseAccessor, BaseDFAccessor, BaseSRAccessor
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.accessors import register_vbt_accessor, register_df_vbt_accessor, register_sr_vbt_accessor
from vectorbtpro.px.decorators import attach_px_methods
__all__ = [
"PXAccessor",
"PXSRAccessor",
"PXDFAccessor",
]
@register_vbt_accessor("px")
@attach_px_methods
class PXAccessor(BaseAccessor):
"""Accessor for running Plotly Express functions.
Accessible via `pd.Series.vbt.px` and `pd.DataFrame.vbt.px`.
Usage:
```pycon
>>> from vectorbtpro import *
>>> pd.Series([1, 2, 3]).vbt.px.bar().show()
```
{: .iimg loading=lazy }

{: .iimg loading=lazy }
"""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
BaseAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@register_sr_vbt_accessor("px")
class PXSRAccessor(PXAccessor, BaseSRAccessor):
"""Accessor for running Plotly Express functions. For Series only.
Accessible via `pd.Series.vbt.px`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
BaseSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
PXAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@register_df_vbt_accessor("px")
class PXDFAccessor(PXAccessor, BaseDFAccessor):
"""Accessor for running Plotly Express functions. For DataFrames only.
Accessible via `pd.DataFrame.vbt.px`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
BaseDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
PXAccessor.__init__(self, wrapper, obj=obj, **kwargs)
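# Example (a hedged sketch): every Plotly Express function that accepts `data_frame` becomes a
# method on these accessors, so a DataFrame can be plotted directly, e.g.:
#
#     >>> df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
#     >>> df.vbt.px.scatter(x="x", y="y").show()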
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for Plotly Express accessors."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("plotly")
from inspect import getmembers, isfunction
import pandas as pd
import plotly.express as px
from vectorbtpro import _typing as tp
from vectorbtpro.base.reshaping import to_2d_array
from vectorbtpro.generic.plotting import clean_labels
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.figure import make_figure
def attach_px_methods(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
"""Class decorator to attach Plotly Express methods."""
for px_func_name, px_func in getmembers(px, isfunction):
if checks.func_accepts_arg(px_func, "data_frame") or px_func_name == "imshow":
def plot_method(
self,
*args,
_px_func_name: str = px_func_name,
_px_func: tp.Callable = px_func,
layout: tp.KwargsLike = None,
**kwargs,
) -> tp.BaseFigure:
from vectorbtpro._settings import settings
layout_cfg = settings["plotting"]["layout"]
layout_kwargs = dict(
template=kwargs.pop("template", layout_cfg["template"]),
width=kwargs.pop("width", layout_cfg["width"]),
height=kwargs.pop("height", layout_cfg["height"]),
)
layout = merge_dicts(layout_kwargs, layout)
# Fix category_orders
if "color" in kwargs:
if isinstance(kwargs["color"], str):
if isinstance(self.obj, pd.DataFrame):
if kwargs["color"] in self.obj.columns:
category_orders = dict()
category_orders[kwargs["color"]] = sorted(self.obj[kwargs["color"]].unique())
kwargs = merge_dicts(dict(category_orders=category_orders), kwargs)
# Fix Series name
obj = self.obj.copy(deep=False)
if isinstance(obj, pd.Series):
if obj.name is not None:
obj = obj.rename(str(obj.name))
else:
obj.columns = clean_labels(obj.columns)
obj.index = clean_labels(obj.index)
if _px_func_name == "imshow":
return make_figure(_px_func(to_2d_array(obj), *args, **layout_kwargs, **kwargs), layout=layout)
return make_figure(_px_func(obj, *args, **layout_kwargs, **kwargs), layout=layout)
plot_method.__name__ = px_func_name
plot_method.__module__ = cls.__module__
plot_method.__qualname__ = f"{cls.__name__}.{plot_method.__name__}"
plot_method.__doc__ = f"""Plot using `{px_func.__module__ + '.' + px_func.__name__}`."""
setattr(cls, px_func_name, plot_method)
return cls
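# A generated method (a hedged sketch of the equivalence): `sr.vbt.px.bar(**kwargs)` roughly
# corresponds to `make_figure(px.bar(sr, **kwargs), layout=default_layout)`, where the default
# template, width, and height come from `settings["plotting"]["layout"]`.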
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with records.
Records are the second form of data representation in vectorbtpro. They allow storing sparse event data
such as drawdowns, orders, trades, and positions, without converting them back to the matrix form and
occupying the user's memory."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.records.base import *
from vectorbtpro.records.chunking import *
from vectorbtpro.records.col_mapper import *
from vectorbtpro.records.mapped_array import *
from vectorbtpro.records.nb import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with records.
vectorbt works with two different representations of data: matrices and records.
A matrix, in this context, is just an array of one-dimensional arrays, each corresponding
to a separate feature. The matrix itself holds only one kind of information (one attribute).
For example, one can create a matrix for entry signals, with columns being different strategy
configurations. But what if the matrix is huge and sparse? What if there is more
information we would like to represent by each element? Creating multiple matrices would be
a waste of memory.
Records make it possible to represent complex, sparse information in a dense format. They are just
one-dimensional structured arrays with a fixed schema, where each field holds a different
kind of information. You can imagine records being a DataFrame, where each row represents a record
and each column represents a specific attribute. Read more on structured arrays
[here](https://numpy.org/doc/stable/user/basics.rec.html).
For example, let's represent two DataFrames as a single record array:
```plaintext
a b
0 1.0 5.0
attr1 = 1 2.0 NaN
2 NaN 7.0
3 4.0 8.0
a b
0 9.0 13.0
attr2 = 1 10.0 NaN
2 NaN 15.0
3 12.0 16.0
|
v
id col idx attr1 attr2
0 0 0 0 1 9
1 1 0 1 2 10
2 2 0 3 4 12
3 0 1 0 5 13
4 1 1 2 7 15
5 2 1 3 8 16
```
Another advantage of records is that they are not constrained by size. Multiple records can map
to a single element in a matrix. For example, one can define multiple orders at the same timestamp,
which is impossible to represent in a matrix form without duplicating index entries or using complex data types.
Consider the following example:
```pycon
>>> from vectorbtpro import *
>>> example_dt = np.dtype([
... ('id', int_),
... ('col', int_),
... ('idx', int_),
... ('some_field', float_)
... ])
>>> records_arr = np.array([
... (0, 0, 0, 10.),
... (1, 0, 1, 11.),
... (2, 0, 2, 12.),
... (0, 1, 0, 13.),
... (1, 1, 1, 14.),
... (2, 1, 2, 15.),
... (0, 2, 0, 16.),
... (1, 2, 1, 17.),
... (2, 2, 2, 18.)
... ], dtype=example_dt)
>>> wrapper = vbt.ArrayWrapper(index=['x', 'y', 'z'],
... columns=['a', 'b', 'c'], ndim=2, freq='1 day')
>>> records = vbt.Records(wrapper, records_arr)
```
## Printing
There are two ways to print records:
* Raw dataframe that preserves field names and data types:
```pycon
>>> records.records
id col idx some_field
0 0 0 0 10.0
1 1 0 1 11.0
2 2 0 2 12.0
3 0 1 0 13.0
4 1 1 1 14.0
5 2 1 2 15.0
6 0 2 0 16.0
7 1 2 1 17.0
8 2 2 2 18.0
```
* Readable dataframe that takes into consideration `Records.field_config`:
```pycon
>>> records.readable
Id Column Timestamp some_field
0 0 a x 10.0
1 1 a y 11.0
2 2 a z 12.0
3 0 b x 13.0
4 1 b y 14.0
5 2 b z 15.0
6 0 c x 16.0
7 1 c y 17.0
8 2 c z 18.0
```
## Mapping
`Records` are just [structured arrays](https://numpy.org/doc/stable/user/basics.rec.html) with a bunch
of methods and properties for processing them. Their main feature is to map the records array and
to reduce it by column (similar to the MapReduce paradigm). The main advantage is that it all happens
without converting to the matrix form, and thus without wasting memory.
`Records` can be mapped to `vectorbtpro.records.mapped_array.MappedArray` in several ways:
* Use `Records.map_field` to map a record field:
```pycon
>>> records.map_field('some_field')
>>> records.map_field('some_field').values
array([10., 11., 12., 13., 14., 15., 16., 17., 18.])
```
* Use `Records.map` to map records using a custom function.
```pycon
>>> @njit
... def power_map_nb(record, pow):
... return record.some_field ** pow
>>> records.map(power_map_nb, 2)
>>> records.map(power_map_nb, 2).values
array([100., 121., 144., 169., 196., 225., 256., 289., 324.])
>>> # Map using a meta function
>>> @njit
... def power_map_meta_nb(ridx, records, pow):
... return records[ridx].some_field ** pow
>>> vbt.Records.map(power_map_meta_nb, records.values, 2, col_mapper=records.col_mapper).values
array([100., 121., 144., 169., 196., 225., 256., 289., 324.])
```
* Use `Records.map_array` to convert an array to `vectorbtpro.records.mapped_array.MappedArray`.
```pycon
>>> records.map_array(records_arr['some_field'] ** 2)
>>> records.map_array(records_arr['some_field'] ** 2).values
array([100., 121., 144., 169., 196., 225., 256., 289., 324.])
```
* Use `Records.apply` to apply a function on each column/group:
```pycon
>>> @njit
... def cumsum_apply_nb(records):
... return np.cumsum(records.some_field)
>>> records.apply(cumsum_apply_nb)
>>> records.apply(cumsum_apply_nb).values
array([10., 21., 33., 13., 27., 42., 16., 33., 51.])
>>> group_by = np.array(['first', 'first', 'second'])
>>> records.apply(cumsum_apply_nb, group_by=group_by, apply_per_group=True).values
array([10., 21., 33., 46., 60., 75., 16., 33., 51.])
>>> # Apply using a meta function
>>> @njit
... def cumsum_apply_meta_nb(idxs, col, records):
... return np.cumsum(records[idxs].some_field)
>>> vbt.Records.apply(cumsum_apply_meta_nb, records.values, col_mapper=records.col_mapper).values
array([10., 21., 33., 13., 27., 42., 16., 33., 51.])
```
Notice how cumsum resets at each column in the first example and at each group in the second example.
## Filtering
Use `Records.apply_mask` to filter elements per column/group:
```pycon
>>> mask = [True, False, True, False, True, False, True, False, True]
>>> filtered_records = records.apply_mask(mask)
>>> filtered_records.records
id col idx some_field
0 0 0 0 10.0
1 2 0 2 12.0
2 1 1 1 14.0
3 0 2 0 16.0
4 2 2 2 18.0
```
## Grouping
One of the key features of `Records` is that you can perform reducing operations on a group
of columns as if they were a single column. Groups can be specified by `group_by`, which
can be anything from positions or names of column levels, to a NumPy array with actual groups.
There are multiple ways to define grouping:
* When creating `Records`, pass `group_by` to `vectorbtpro.base.wrapping.ArrayWrapper`:
```pycon
>>> group_by = np.array(['first', 'first', 'second'])
>>> grouped_wrapper = wrapper.replace(group_by=group_by)
>>> grouped_records = vbt.Records(grouped_wrapper, records_arr)
>>> grouped_records.map_field('some_field').mean()
first 12.5
second 17.0
dtype: float64
```
* Regroup an existing `Records`:
```pycon
>>> records.regroup(group_by).map_field('some_field').mean()
first 12.5
second 17.0
dtype: float64
```
* Pass `group_by` directly to the mapping method:
```pycon
>>> records.map_field('some_field', group_by=group_by).mean()
first 12.5
second 17.0
dtype: float64
```
* Pass `group_by` directly to the reducing method:
```pycon
>>> records.map_field('some_field').mean(group_by=group_by)
a 11.0
b 14.0
c 17.0
dtype: float64
```
!!! note
Grouping applies only to reducing operations; the underlying arrays remain unchanged.
## Indexing
Like any other class subclassing `vectorbtpro.base.wrapping.Wrapping`, we can do pandas indexing
on a `Records` instance, which forwards the indexing operation to each underlying object that has columns:
```pycon
>>> records['a'].records
id col idx some_field
0 0 0 0 10.0
1 1 0 1 11.0
2 2 0 2 12.0
>>> grouped_records['first'].records
id col idx some_field
0 0 0 0 10.0
1 1 0 1 11.0
2 2 0 2 12.0
3 0 1 0 13.0
4 1 1 1 14.0
5 2 1 2 15.0
```
!!! note
Changing index (time axis) is not supported. The object should be treated as a Series
rather than a DataFrame; for example, use `some_field.iloc[0]` instead of `some_field.iloc[:, 0]`
to get the first column.
Indexing behavior depends solely upon `vectorbtpro.base.wrapping.ArrayWrapper`.
For example, if `group_select` is enabled, indexing will be performed on groups when grouped,
otherwise on single columns.
## Caching
`Records` supports caching. If a method or a property requires heavy computation, it's wrapped
with `vectorbtpro.utils.decorators.cached_method` and `vectorbtpro.utils.decorators.cached_property`
respectively. Caching can be disabled globally via `vectorbtpro._settings.caching`.
!!! note
Because of caching, the class is meant to be immutable and all properties are read-only.
To change any attribute, use the `Records.replace` method and pass changes as keyword arguments.
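For example (a minimal sketch), to derive a new instance that holds only the first three records:

```pycon
>>> subset_records = records.replace(records_arr=records.values[:3])
```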
## Saving and loading
Like any other class subclassing `vectorbtpro.utils.pickling.Pickleable`, we can save a `Records`
instance to the disk with `Records.save` and load it with `Records.load`.
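For example (a minimal sketch, using an arbitrary file name):

```pycon
>>> records.save('records.pickle')
>>> records = vbt.Records.load('records.pickle')
```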
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `Records.metrics`.
```pycon
>>> records.stats(column='a')
Start x
End z
Period 3 days 00:00:00
Total Records 3
Name: a, dtype: object
```
`Records.stats` also supports (re-)grouping:
```pycon
>>> grouped_records.stats(column='first')
Start x
End z
Period 3 days 00:00:00
Total Records 6
Name: first, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `Records.subplots`.
This class is too generic to have any subplots, but feel free to add custom subplots to your subclass.
## Extending
`Records` class can be extended by subclassing.
In case some of our fields have the same meaning but different naming (such as the base field `idx`)
or other properties, we can override `field_config` using `vectorbtpro.records.decorators.override_field_config`.
It will look for configs of all base classes and merge our config on top of them. This preserves
any base class property that is not explicitly listed in our config.
```pycon
>>> from vectorbtpro.records.decorators import override_field_config
>>> my_dt = np.dtype([
... ('my_id', int_),
... ('my_col', int_),
... ('my_idx', int_)
... ])
>>> my_fields_config = dict(
... dtype=my_dt,
... settings=dict(
... id=dict(name='my_id'),
... col=dict(name='my_col'),
... idx=dict(name='my_idx')
... )
... )
>>> @override_field_config(my_fields_config)
... class MyRecords(vbt.Records):
... pass
>>> records_arr = np.array([
... (0, 0, 0),
... (1, 0, 1),
... (0, 1, 0),
... (1, 1, 1)
... ], dtype=my_dt)
>>> wrapper = vbt.ArrayWrapper(index=['x', 'y'],
... columns=['a', 'b'], ndim=2, freq='1 day')
>>> my_records = MyRecords(wrapper, records_arr)
>>> my_records.id_arr
array([0, 1, 0, 1])
>>> my_records.col_arr
array([0, 0, 1, 1])
>>> my_records.idx_arr
array([0, 1, 0, 1])
```
Alternatively, we can override the `_field_config` class attribute.
```pycon
>>> @override_field_config
... class MyRecords(vbt.Records):
... _field_config = dict(
... dtype=my_dt,
... settings=dict(
... id=dict(name='my_id'),
... idx=dict(name='my_idx'),
... col=dict(name='my_col')
... )
... )
```
!!! note
Don't forget to decorate the class with `@override_field_config` to inherit configs from base classes.
You can stop inheritance by not decorating or passing `merge_configs=False` to the decorator.
"""
import inspect
import string
from collections import defaultdict
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import to_1d_array, index_to_series, index_to_frame
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.records import nb
from vectorbtpro.records.col_mapper import ColumnMapper
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import get_dict_attr
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import cached_method, hybrid_method
from vectorbtpro.utils.random_ import set_seed_nb
from vectorbtpro.utils.template import Sub
__all__ = [
"Records",
]
__pdoc__ = {}
RecordsT = tp.TypeVar("RecordsT", bound="Records")
class MetaRecords(type(Analyzable)):
"""Metaclass for `Records`."""
@property
def field_config(cls) -> Config:
"""Field config."""
return cls._field_config
class Records(Analyzable, metaclass=MetaRecords):
"""Wraps the actual records array (such as trades) and exposes methods for mapping
it to some array of values (such as PnL of each trade).
Args:
wrapper (ArrayWrapper): Array wrapper.
See `vectorbtpro.base.wrapping.ArrayWrapper`.
records_arr (array_like): A structured NumPy array of records.
Must have the fields `id` (record index) and `col` (column index).
col_mapper (ColumnMapper): Column mapper if already known.
!!! note
It depends on `records_arr`, so make sure to invalidate `col_mapper` upon creating
a `Records` instance with a modified `records_arr`.
`Records.replace` does it automatically.
**kwargs: Custom keyword arguments passed to the config.
Useful if any subclass wants to extend the config.
"""
_writeable_attrs: tp.WriteableAttrs = {"_field_config"}
_field_config: tp.ClassVar[Config] = HybridConfig(
dict(
dtype=None,
settings=dict(
id=dict(name="id", title="Id", mapping="ids"),
col=dict(name="col", title="Column", mapping="columns", as_customdata=False),
idx=dict(name="idx", title="Index", mapping="index"),
),
)
)
@property
def field_config(self) -> Config:
"""Field config of `${cls_name}`.
```python
${field_config}
```
Returns `${cls_name}._field_config`, which gets (hybrid-) copied upon creation of each instance.
Thus, changing this config won't affect the class.
To change fields, you can either change the config in-place, override this property,
or overwrite the instance variable `${cls_name}._field_config`.
"""
return self._field_config
@classmethod
def row_stack_records_arrs(cls, *objs: tp.MaybeTuple[tp.RecordArray], **kwargs) -> tp.RecordArray:
"""Stack multiple record arrays along rows."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
records_arrs = []
for col in range(kwargs["wrapper"].shape_2d[1]):
n_rows_sum = 0
from_id = defaultdict(int)
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
col_records = None
set_columns = False
if col > 0 and obj.wrapper.shape_2d[1] == 1:
col_records = obj.records_arr[col_idxs]
set_columns = True
elif col_lens[col] > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
col_records = obj.records_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]
if col_records is not None:
col_records = col_records.copy()
for field in obj.values.dtype.names:
field_mapping = cls.field_config.get("settings", {}).get(field, {}).get("mapping", None)
if isinstance(field_mapping, str) and field_mapping == "columns" and set_columns:
col_records[field][:] = col
elif isinstance(field_mapping, str) and field_mapping == "index":
col_records[field][:] += n_rows_sum
elif isinstance(field_mapping, str) and field_mapping == "ids":
col_records[field][:] += from_id[field]
from_id[field] = col_records[field].max() + 1
records_arrs.append(col_records)
n_rows_sum += obj.wrapper.shape_2d[0]
if len(records_arrs) == 0:
return np.array([], dtype=objs[0].values.dtype)
return np.concatenate(records_arrs)
@classmethod
def get_row_stack_record_indices(cls, *objs: tp.MaybeTuple[tp.RecordArray], **kwargs) -> tp.Array1d:
"""Get the indices that map concatenated record arrays into the row-stacked record array."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
record_indices = []
cum_n_rows_sum = []
for i in range(len(objs)):
if i == 0:
cum_n_rows_sum.append(0)
else:
cum_n_rows_sum.append(cum_n_rows_sum[-1] + len(objs[i - 1].values))
for col in range(kwargs["wrapper"].shape_2d[1]):
for i, obj in enumerate(objs):
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
if col > 0 and obj.wrapper.shape_2d[1] == 1:
_record_indices = col_idxs + cum_n_rows_sum[i]
record_indices.append(_record_indices)
elif col_lens[col] > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
_record_indices = col_idxs[col_start_idxs[col] : col_end_idxs[col]] + cum_n_rows_sum[i]
record_indices.append(_record_indices)
if len(record_indices) == 0:
return np.array([], dtype=int_)
return np.concatenate(record_indices)
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[RecordsT],
*objs: tp.MaybeTuple[RecordsT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> RecordsT:
"""Stack multiple `Records` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers
and `Records.row_stack_records_arrs` to stack the record arrays.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Records):
raise TypeError("Each object to be merged must be an instance of Records")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
if "col_mapper" not in kwargs:
kwargs["col_mapper"] = ColumnMapper.row_stack(
*[obj.col_mapper for obj in objs],
wrapper=kwargs["wrapper"],
)
if "records_arr" not in kwargs:
kwargs["records_arr"] = cls.row_stack_records_arrs(*objs, **kwargs)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
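    # Example (a hedged sketch): `Records.row_stack(records_2020, records_2021)` stacks the wrappers
    # along the index and shifts index-mapped fields (e.g. "idx") and id fields of later objects so
    # that they continue where the previous objects left off.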
@classmethod
def column_stack_records_arrs(
cls,
*objs: tp.MaybeTuple[tp.RecordArray],
get_indexer_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.RecordArray:
"""Stack multiple record arrays along columns."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
if get_indexer_kwargs is None:
get_indexer_kwargs = {}
records_arrs = []
col_sum = 0
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
for col in range(len(col_lens)):
if col_lens[col] > 0:
col_records = obj.records_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]
col_records = col_records.copy()
for field in obj.values.dtype.names:
field_mapping = cls.field_config.get("settings", {}).get(field, {}).get("mapping", None)
if isinstance(field_mapping, str) and field_mapping == "columns":
col_records[field][:] += col_sum
elif isinstance(field_mapping, str) and field_mapping == "index":
old_idxs = col_records[field]
if not obj.wrapper.index.equals(kwargs["wrapper"].index):
new_idxs = kwargs["wrapper"].index.get_indexer(
obj.wrapper.index[old_idxs],
**get_indexer_kwargs,
)
else:
new_idxs = old_idxs
col_records[field][:] = new_idxs
records_arrs.append(col_records)
col_sum += obj.wrapper.shape_2d[1]
if len(records_arrs) == 0:
return np.array([], dtype=objs[0].values.dtype)
return np.concatenate(records_arrs)
@classmethod
def get_column_stack_record_indices(cls, *objs: tp.MaybeTuple[tp.RecordArray], **kwargs) -> tp.Array1d:
"""Get the indices that map concatenated record arrays into the column-stacked record array."""
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
record_indices = []
cum_n_rows_sum = []
for i in range(len(objs)):
if i == 0:
cum_n_rows_sum.append(0)
else:
cum_n_rows_sum.append(cum_n_rows_sum[-1] + len(objs[i - 1].values))
for i, obj in enumerate(objs):
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
for col in range(len(col_lens)):
if col_lens[col] > 0:
_record_indices = col_idxs[col_start_idxs[col] : col_end_idxs[col]] + cum_n_rows_sum[i]
record_indices.append(_record_indices)
if len(record_indices) == 0:
return np.array([], dtype=int_)
return np.concatenate(record_indices)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[RecordsT],
*objs: tp.MaybeTuple[RecordsT],
wrapper_kwargs: tp.KwargsLike = None,
get_indexer_kwargs: tp.KwargsLike = None,
**kwargs,
) -> RecordsT:
"""Stack multiple `Records` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers
and `Records.column_stack_records_arrs` to stack the record arrays.
`get_indexer_kwargs` are passed to
[pandas.Index.get_indexer](https://pandas.pydata.org/docs/reference/api/pandas.Index.get_indexer.html)
to translate old indices to new ones after the reindexing operation.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, Records):
raise TypeError("Each object to be merged must be an instance of Records")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
if "col_mapper" not in kwargs:
kwargs["col_mapper"] = ColumnMapper.column_stack(
*[obj.col_mapper for obj in objs],
wrapper=kwargs["wrapper"],
)
if "records_arr" not in kwargs:
kwargs["records_arr"] = cls.column_stack_records_arrs(
*objs,
get_indexer_kwargs=get_indexer_kwargs,
**kwargs,
)
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
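    # Example (a hedged sketch): `Records.column_stack(records_a, records_b)` places the columns of
    # `records_b` after those of `records_a`; column-mapped fields are offset by the number of
    # preceding columns and index-mapped fields are re-mapped onto the stacked wrapper's index.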
def __init__(
self,
wrapper: ArrayWrapper,
records_arr: tp.RecordArray,
col_mapper: tp.Optional[ColumnMapper] = None,
**kwargs,
) -> None:
# Check fields
records_arr = np.asarray(records_arr)
checks.assert_not_none(records_arr.dtype.fields)
field_names = {dct.get("name", field_name) for field_name, dct in self.field_config.get("settings", {}).items()}
dtype = self.field_config.get("dtype", None)
if dtype is not None:
for field in dtype.names:
if field not in records_arr.dtype.names:
if field not in field_names:
raise TypeError(f"Field '{field}' from {dtype} cannot be found in records or config")
if col_mapper is None:
col_mapper = ColumnMapper(wrapper, records_arr[self.get_field_name("col")])
Analyzable.__init__(self, wrapper, records_arr=records_arr, col_mapper=col_mapper, **kwargs)
self._records_arr = records_arr
self._col_mapper = col_mapper
# Only slices of rows can be selected
self._range_only_select = True
# Copy writeable attrs
self._field_config = type(self)._field_config.copy()
def replace(self: RecordsT, **kwargs) -> RecordsT:
"""See `vectorbtpro.utils.config.Configured.replace`.
Also makes sure that a stale `Records.col_mapper` is not passed to the new instance when the wrapper or the records array changes."""
if self.config.get("col_mapper", None) is not None:
if "wrapper" in kwargs:
if self.wrapper is not kwargs.get("wrapper"):
kwargs["col_mapper"] = None
if "records_arr" in kwargs:
if self.records_arr is not kwargs.get("records_arr"):
kwargs["col_mapper"] = None
return Analyzable.replace(self, **kwargs)
def select_cols(
self,
col_idxs: tp.MaybeIndexArray,
jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
"""Select columns.
Returns indices and new record array. Automatically decides whether to use column lengths or column map."""
if len(self.values) == 0:
return np.arange(len(self.values)), self.values
if isinstance(col_idxs, slice):
if col_idxs.start is None and col_idxs.stop is None:
return np.arange(len(self.values)), self.values
col_idxs = np.arange(col_idxs.start, col_idxs.stop)
if self.col_mapper.is_sorted():
func = jit_reg.resolve_option(nb.record_col_lens_select_nb, jitted)
new_indices, new_records_arr = func(self.values, self.col_mapper.col_lens, to_1d_array(col_idxs)) # faster
else:
func = jit_reg.resolve_option(nb.record_col_map_select_nb, jitted)
new_indices, new_records_arr = func(
self.values, self.col_mapper.col_map, to_1d_array(col_idxs)
) # more flexible
return new_indices, new_records_arr
def indexing_func_meta(self, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform indexing on `Records` and return metadata.
By default, all fields that are mapped to index are indexed.
To avoid indexing on some fields, set their setting `noindex` to True."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(
*args,
column_only_select=self.column_only_select,
range_only_select=self.range_only_select,
group_select=self.group_select,
**kwargs,
)
if self.get_field_setting("col", "group_indexing", False):
new_indices, new_records_arr = self.select_cols(wrapper_meta["group_idxs"])
else:
new_indices, new_records_arr = self.select_cols(wrapper_meta["col_idxs"])
if wrapper_meta["rows_changed"]:
row_idxs = wrapper_meta["row_idxs"]
index_fields = []
all_index_fields = []
for field in new_records_arr.dtype.names:
field_mapping = self.get_field_mapping(field)
noindex = self.get_field_setting(field, "noindex", False)
if isinstance(field_mapping, str) and field_mapping == "index":
all_index_fields.append(field)
if not noindex:
index_fields.append(field)
if len(index_fields) > 0:
masks = []
for field in index_fields:
field_arr = new_records_arr[field]
masks.append((field_arr >= row_idxs.start) & (field_arr < row_idxs.stop))
mask = np.array(masks).all(axis=0)
new_indices = new_indices[mask]
new_records_arr = new_records_arr[mask]
for field in all_index_fields:
new_records_arr[field] = new_records_arr[field] - row_idxs.start
return dict(
wrapper_meta=wrapper_meta,
new_indices=new_indices,
new_records_arr=new_records_arr,
)
def indexing_func(self: RecordsT, *args, records_meta: tp.DictLike = None, **kwargs) -> RecordsT:
"""Perform indexing on `Records`."""
if records_meta is None:
records_meta = self.indexing_func_meta(*args, **kwargs)
return self.replace(
wrapper=records_meta["wrapper_meta"]["new_wrapper"],
records_arr=records_meta["new_records_arr"],
)
def resample_records_arr(self, resampler: tp.Union[Resampler, tp.PandasResampler]) -> tp.RecordArray:
"""Perform resampling on the record array."""
if isinstance(resampler, Resampler):
_resampler = resampler
else:
_resampler = Resampler.from_pd_resampler(resampler)
new_records_arr = self.records_arr.copy()
for field_name in self.values.dtype.names:
field_mapping = self.get_field_mapping(field_name)
if isinstance(field_mapping, str) and field_mapping == "index":
index_map = _resampler.map_to_target_index(return_index=False)
new_records_arr[field_name] = index_map[new_records_arr[field_name]]
return new_records_arr
def resample_meta(self: RecordsT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform resampling on `Records` and return metadata."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
new_records_arr = self.resample_records_arr(wrapper_meta["resampler"])
return dict(wrapper_meta=wrapper_meta, new_records_arr=new_records_arr)
def resample(self: RecordsT, *args, records_meta: tp.DictLike = None, **kwargs) -> RecordsT:
"""Perform resampling on `Records`."""
if records_meta is None:
records_meta = self.resample_meta(*args, **kwargs)
return self.replace(
wrapper=records_meta["wrapper_meta"]["new_wrapper"],
records_arr=records_meta["new_records_arr"],
)
@property
def records_arr(self) -> tp.RecordArray:
"""Records array."""
return self._records_arr
@property
def values(self) -> tp.RecordArray:
"""Records array."""
return self.records_arr
def __len__(self) -> int:
return len(self.values)
@property
def records(self) -> tp.Frame:
"""Records."""
return pd.DataFrame.from_records(self.values)
@property
def recarray(self) -> tp.RecArray:
"""Records with field access using attributes."""
return self.values.view(np.recarray)
@property
def col_mapper(self) -> ColumnMapper:
"""Column mapper.
See `vectorbtpro.records.col_mapper.ColumnMapper`."""
return self._col_mapper
@property
def field_names(self) -> tp.List[str]:
"""Field names."""
return list(self.values.dtype.fields.keys())
def to_readable(self, expand_columns: bool = False) -> tp.Frame:
"""Get records in a human-readable format."""
new_columns = list()
field_settings = self.field_config.get("settings", {})
for field_name in self.field_names:
if field_name in field_settings:
dct = field_settings[field_name]
if dct.get("ignore", False):
continue
field_name = dct.get("name", field_name)
if "title" in dct:
title = dct["title"]
else:
title = field_name
if "mapping" in dct:
if isinstance(dct["mapping"], str) and dct["mapping"] == "index":
new_columns.append(pd.Series(self.get_map_field_to_index(field_name), name=title))
elif isinstance(dct["mapping"], str) and dct["mapping"] == "columns":
column_index = self.get_map_field_to_columns(field_name)
if expand_columns and isinstance(column_index, pd.MultiIndex):
column_frame = index_to_frame(column_index, reset_index=True)
new_columns.append(column_frame.add_prefix(f"{title}: "))
else:
column_sr = index_to_series(column_index, reset_index=True)
if expand_columns and self.wrapper.ndim == 2 and column_sr.name is not None:
new_columns.append(column_sr.rename(f"{title}: {column_sr.name}"))
else:
new_columns.append(column_sr.rename(title))
else:
new_columns.append(pd.Series(self.get_apply_mapping_arr(field_name), name=title))
else:
new_columns.append(pd.Series(self.values[field_name], name=title))
else:
new_columns.append(pd.Series(self.values[field_name], name=field_name))
records_readable = pd.concat(new_columns, axis=1)
if all([isinstance(col, tuple) for col in records_readable.columns]):
records_readable.columns = pd.MultiIndex.from_tuples(records_readable.columns)
return records_readable
@property
def records_readable(self) -> tp.Frame:
"""`Records.to_readable` with default arguments."""
return self.to_readable()
readable = records_readable
def get_field_setting(self, field: str, setting: str, default: tp.Any = None) -> tp.Any:
"""Get any setting of the field. Uses `Records.field_config`."""
return self.field_config.get("settings", {}).get(field, {}).get(setting, default)
def get_field_name(self, field: str) -> str:
"""Get the name of the field. Uses `Records.field_config`.."""
return self.get_field_setting(field, "name", field)
def get_field_title(self, field: str) -> str:
"""Get the title of the field. Uses `Records.field_config`."""
return self.get_field_setting(field, "title", field)
def get_field_mapping(self, field: str) -> tp.Optional[tp.MappingLike]:
"""Get the mapping of the field. Uses `Records.field_config`."""
return self.get_field_setting(field, "mapping", None)
def get_field_arr(self, field: str, copy: bool = False) -> tp.Array1d:
"""Get the array of the field. Uses `Records.field_config`."""
out = self.values[self.get_field_name(field)]
if copy:
out = out.copy()
return out
def get_map_field(self, field: str, **kwargs) -> MappedArray:
"""Get the mapped array of the field. Uses `Records.field_config`."""
mapping = self.get_field_mapping(field)
if isinstance(mapping, str) and mapping == "ids":
mapping = None
return self.map_field(self.get_field_name(field), mapping=mapping, **kwargs)
def get_map_field_to_index(self, field: str, minus_one_to_zero: bool = False, **kwargs) -> tp.Index:
"""Get the mapped array on the field, with index applied. Uses `Records.field_config`."""
return self.get_map_field(field, **kwargs).to_index(minus_one_to_zero=minus_one_to_zero)
def get_map_field_to_columns(self, field: str, **kwargs) -> tp.Index:
"""Get the mapped array on the field, with columns applied. Uses `Records.field_config`."""
return self.get_map_field(field, **kwargs).to_columns()
def get_apply_mapping_arr(self, field: str, mapping_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Array1d:
"""Get the mapped array on the field, with mapping applied. Uses `Records.field_config`."""
mapping = self.get_field_mapping(field)
if isinstance(mapping, str) and mapping == "index":
return self.get_map_field_to_index(field, **kwargs).values
if isinstance(mapping, str) and mapping == "columns":
return self.get_map_field_to_columns(field, **kwargs).values
return self.get_map_field(field, **kwargs).apply_mapping(mapping_kwargs=mapping_kwargs).values
def get_apply_mapping_str_arr(self, field: str, mapping_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Array1d:
"""Get the mapped array on the field, with mapping applied and stringified. Uses `Records.field_config`."""
mapping = self.get_field_mapping(field)
if isinstance(mapping, str) and mapping == "index":
return self.get_map_field_to_index(field, **kwargs).astype(str).values
if isinstance(mapping, str) and mapping == "columns":
return self.get_map_field_to_columns(field, **kwargs).astype(str).values
return self.get_map_field(field, **kwargs).apply_mapping(mapping_kwargs=mapping_kwargs).values.astype(str)
@property
def id_arr(self) -> tp.Array1d:
"""Get id array."""
return self.values[self.get_field_name("id")]
@property
def col_arr(self) -> tp.Array1d:
"""Get column array."""
return self.values[self.get_field_name("col")]
@property
def idx_arr(self) -> tp.Optional[tp.Array1d]:
"""Get index array."""
idx_field_name = self.get_field_name("idx")
if idx_field_name is None:
return None
return self.values[idx_field_name]
# ############# Sorting ############# #
@cached_method
def is_sorted(self, incl_id: bool = False, jitted: tp.JittedOption = None) -> bool:
"""Check whether records are sorted."""
if incl_id:
func = jit_reg.resolve_option(nb.is_col_id_sorted_nb, jitted)
return func(self.col_arr, self.id_arr)
func = jit_reg.resolve_option(nb.is_col_sorted_nb, jitted)
return func(self.col_arr)
def sort(self: RecordsT, incl_id: bool = False, group_by: tp.GroupByLike = None, **kwargs) -> RecordsT:
"""Sort records by columns (primary) and ids (secondary, optional).
!!! note
Sorting is expensive. A better approach is to append records already in the correct order."""
if self.is_sorted(incl_id=incl_id):
return self.replace(**kwargs).regroup(group_by)
if incl_id:
ind = np.lexsort((self.id_arr, self.col_arr)) # expensive!
else:
ind = np.argsort(self.col_arr)
return self.replace(records_arr=self.values[ind], **kwargs).regroup(group_by)
# ############# Filtering ############# #
def apply_mask(self: RecordsT, mask: tp.Array1d, group_by: tp.GroupByLike = None, **kwargs) -> RecordsT:
"""Return a new class instance, filtered by mask."""
mask_indices = np.flatnonzero(mask)
return self.replace(records_arr=np.take(self.values, mask_indices), **kwargs).regroup(group_by)
def first_n(
self: RecordsT,
n: int,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> RecordsT:
"""Return the first N records in each column."""
col_map = self.col_mapper.get_col_map(group_by=False)
func = jit_reg.resolve_option(nb.first_n_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return self.apply_mask(func(col_map, n), **kwargs)
def last_n(
self: RecordsT,
n: int,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> RecordsT:
"""Return the last N records in each column."""
col_map = self.col_mapper.get_col_map(group_by=False)
func = jit_reg.resolve_option(nb.last_n_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return self.apply_mask(func(col_map, n), **kwargs)
def random_n(
self: RecordsT,
n: int,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> RecordsT:
"""Return random N records in each column."""
if seed is not None:
set_seed_nb(seed)
col_map = self.col_mapper.get_col_map(group_by=False)
func = jit_reg.resolve_option(nb.random_n_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return self.apply_mask(func(col_map, n), **kwargs)
# ############# Mapping ############# #
def map_array(
self,
a: tp.ArrayLike,
idx_arr: tp.Union[None, str, tp.Array1d] = None,
mapping: tp.Optional[tp.MappingLike] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> MappedArray:
"""Convert array to mapped array.
The length of the array must match that of the records."""
if not isinstance(a, np.ndarray):
a = np.asarray(a)
checks.assert_shape_equal(a, self.values)
if idx_arr is None:
idx_arr = self.idx_arr
elif isinstance(idx_arr, str):
idx_arr = self.get_field_arr(idx_arr)
return MappedArray(
self.wrapper,
a,
self.col_arr,
id_arr=self.id_arr,
idx_arr=idx_arr,
mapping=mapping,
col_mapper=self.col_mapper,
**kwargs,
).regroup(group_by)
def map_field(self, field: str, **kwargs) -> MappedArray:
"""Convert field to mapped array.
`**kwargs` are passed to `Records.map_array`."""
mapped_arr = self.values[field]
return self.map_array(mapped_arr, **kwargs)
@hybrid_method
def map(
cls_or_self,
map_func_nb: tp.Union[tp.RecordsMapFunc, tp.RecordsMapMetaFunc],
*args,
dtype: tp.Optional[tp.DTypeLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
col_mapper: tp.Optional[ColumnMapper] = None,
**kwargs,
) -> MappedArray:
"""Map each record to a scalar value. Returns mapped array.
See `vectorbtpro.records.nb.map_records_nb`.
For details on the meta version, see `vectorbtpro.records.nb.map_records_meta_nb`.
`**kwargs` are passed to `Records.map_array`."""
if isinstance(cls_or_self, type):
checks.assert_not_none(col_mapper, arg_name="col_mapper")
func = jit_reg.resolve_option(nb.map_records_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(len(col_mapper.col_arr), map_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return MappedArray(col_mapper.wrapper, mapped_arr, col_mapper.col_arr, col_mapper=col_mapper, **kwargs)
else:
func = jit_reg.resolve_option(nb.map_records_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(cls_or_self.values, map_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return cls_or_self.map_array(mapped_arr, **kwargs)
@hybrid_method
def apply(
cls_or_self,
apply_func_nb: tp.Union[tp.ApplyFunc, tp.ApplyMetaFunc],
*args,
group_by: tp.GroupByLike = None,
apply_per_group: bool = False,
dtype: tp.Optional[tp.DTypeLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
col_mapper: tp.Optional[ColumnMapper] = None,
**kwargs,
) -> MappedArray:
"""Apply function on records per column/group. Returns mapped array.
Applies per group if `apply_per_group` is True.
See `vectorbtpro.records.nb.apply_nb`.
For details on the meta version, see `vectorbtpro.records.nb.apply_meta_nb`.
`**kwargs` are passed to `Records.map_array`."""
if isinstance(cls_or_self, type):
checks.assert_not_none(col_mapper, arg_name="col_mapper")
col_map = col_mapper.get_col_map(group_by=group_by if apply_per_group else False)
func = jit_reg.resolve_option(nb.apply_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(len(col_mapper.col_arr), col_map, apply_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return MappedArray(col_mapper.wrapper, mapped_arr, col_mapper.col_arr, col_mapper=col_mapper, **kwargs)
else:
col_map = cls_or_self.col_mapper.get_col_map(group_by=group_by if apply_per_group else False)
func = jit_reg.resolve_option(nb.apply_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(cls_or_self.values, col_map, apply_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return cls_or_self.map_array(mapped_arr, group_by=group_by, **kwargs)
# ############# Masking ############# #
def get_pd_mask(
self,
idx_arr: tp.Union[None, str, tp.Array1d] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get mask in form of a Series/DataFrame from row and column indices."""
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
elif isinstance(idx_arr, str):
idx_arr = self.get_field_arr(idx_arr)
col_arr = self.col_mapper.get_col_arr(group_by=group_by)
target_shape = self.wrapper.get_shape_2d(group_by=group_by)
out_arr = np.full(target_shape, False)
out_arr[idx_arr, col_arr] = True
return self.wrapper.wrap(out_arr, group_by=group_by, **resolve_dict(wrap_kwargs))
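    # Example (a hedged sketch): `records.get_pd_mask()` returns a boolean Series/DataFrame of the
    # wrapper's 2-dim shape with True at every (index, column) position covered by at least one record.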
@property
def pd_mask(self) -> tp.SeriesFrame:
"""`MappedArray.get_pd_mask` with default arguments."""
return self.get_pd_mask()
# ############# Reducing ############# #
@cached_method
def count(self, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
"""Get count by column."""
wrap_kwargs = merge_dicts(dict(name_or_index="count"), wrap_kwargs)
return self.wrapper.wrap_reduced(
self.col_mapper.get_col_map(group_by=group_by)[1],
group_by=group_by,
**wrap_kwargs,
)
# ############# Conflicts ############# #
@cached_method
def has_conflicts(self, **kwargs) -> bool:
"""See `vectorbtpro.records.mapped_array.MappedArray.has_conflicts`."""
return self.get_map_field("col").has_conflicts(**kwargs)
def coverage_map(self, **kwargs) -> tp.SeriesFrame:
"""See `vectorbtpro.records.mapped_array.MappedArray.coverage_map`."""
return self.get_map_field("col").coverage_map(**kwargs)
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `Records.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.records`."""
from vectorbtpro._settings import settings
records_stats_cfg = settings["records"]["stats"]
return merge_dicts(Analyzable.stats_defaults.__get__(self), records_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
count=dict(title="Count", calc_func="count", tags="records"),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def prepare_customdata(
self,
incl_fields: tp.Optional[tp.Sequence[str]] = None,
excl_fields: tp.Optional[tp.Sequence[str]] = None,
append_info: tp.Optional[tp.Sequence[tp.Tuple]] = None,
mask: tp.Optional[tp.Array1d] = None,
) -> tp.Tuple[tp.Array2d, str]:
"""Prepare customdata and hoverinfo for Plotly.
Will display all fields in the data type or only those in `incl_fields`, unless any of them has
the field config setting `as_customdata` disabled, or it's listed in `excl_fields`.
Additionally, you can define `hovertemplate` in the field config such as by using
`vectorbtpro.utils.template.Sub` where `title` is substituted by the title and `index` is
substituted by the (final) index in the customdata. If provided as a string, will be wrapped with
`vectorbtpro.utils.template.Sub`. Defaults to "$title: %{customdata[$index]}". Mapped fields
will be stringified automatically.
To append one or more custom arrays, provide `append_info` as a list of tuples, each consisting
of a 1-dim NumPy array, title, and optionally hoverinfo. If the array's data type is `object`,
will treat it as strings, otherwise as numbers."""
customdata_info = []
if incl_fields is not None:
iterate_over_names = incl_fields
else:
iterate_over_names = self.field_config.get("dtype").names
for field in iterate_over_names:
if excl_fields is not None and field in excl_fields:
continue
field_as_customdata = self.get_field_setting(field, "as_customdata", True)
if field_as_customdata:
numeric_customdata = self.get_field_setting(field, "mapping", None)
if numeric_customdata is not None:
field_arr = self.get_apply_mapping_str_arr(field)
field_hovertemplate = self.get_field_setting(
field,
"hovertemplate",
"$title: %{customdata[$index]}",
)
else:
field_arr = self.get_apply_mapping_arr(field)
field_hovertemplate = self.get_field_setting(
field,
"hovertemplate",
"$title: %{customdata[$index]:,}",
)
if isinstance(field_hovertemplate, str):
field_hovertemplate = Sub(field_hovertemplate)
field_title = self.get_field_title(field)
customdata_info.append((field_arr, field_title, field_hovertemplate))
if append_info is not None:
for info in append_info:
checks.assert_instance_of(info, tuple)
if len(info) == 2:
if info[0].dtype == object:
info += ("$title: %{customdata[$index]}",)
else:
info += ("$title: %{customdata[$index]:,}",)
if isinstance(info[2], str):
info = (info[0], info[1], Sub(info[2]))
customdata_info.append(info)
customdata = []
hovertemplate = []
for i in range(len(customdata_info)):
if mask is not None:
customdata.append(customdata_info[i][0][mask])
else:
customdata.append(customdata_info[i][0])
_hovertemplate = customdata_info[i][2].substitute(dict(title=customdata_info[i][1], index=i))
if not _hovertemplate.startswith(" "):
_hovertemplate = " " + _hovertemplate
hovertemplate.append(_hovertemplate)
return np.stack(customdata, axis=1), "\n".join(hovertemplate)
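    # Example (a hedged sketch): `prepare_customdata(incl_fields=["id", "idx"])` stacks the two field
    # arrays into a (n_records, 2) customdata array and joins the per-field hovertemplates with
    # newlines; numeric fields use a thousands-separated format, mapped fields are stringified.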
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `Records.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.records`."""
from vectorbtpro._settings import settings
records_plots_cfg = settings["records"]["plots"]
return merge_dicts(Analyzable.plots_defaults.__get__(self), records_plots_cfg)
@property
def subplots(self) -> Config:
return self._subplots
# ############# Docs ############# #
@classmethod
def build_field_config_doc(cls, source_cls: tp.Optional[type] = None) -> str:
"""Build field config documentation."""
if source_cls is None:
source_cls = Records
return string.Template(inspect.cleandoc(get_dict_attr(source_cls, "field_config").__doc__)).substitute(
{"field_config": cls.field_config.prettify(), "cls_name": cls.__name__},
)
@classmethod
def override_field_config_doc(cls, __pdoc__: dict, source_cls: tp.Optional[type] = None) -> None:
"""Call this method on each subclass that overrides `Records.field_config`."""
__pdoc__[cls.__name__ + ".field_config"] = cls.build_field_config_doc(source_cls=source_cls)
Records.override_field_config_doc(__pdoc__)
Records.override_metrics_doc(__pdoc__)
Records.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Extensions for chunking records and mapped arrays."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.chunking import GroupLensMapper, GroupIdxsMapper
from vectorbtpro.utils.chunking import ChunkMeta, ChunkMapper
from vectorbtpro.utils.parsing import Regex
__all__ = []
col_lens_mapper = GroupLensMapper(arg_query=Regex(r"(col_lens|col_map)"))
"""Default instance of `vectorbtpro.base.chunking.GroupLensMapper` for per-column lengths."""
col_idxs_mapper = GroupIdxsMapper(arg_query="col_map")
"""Default instance of `vectorbtpro.base.chunking.GroupIdxsMapper` for per-column indices."""
def fix_field_in_records(
record_arrays: tp.List[tp.RecordArray],
chunk_meta: tp.Iterable[ChunkMeta],
ann_args: tp.Optional[tp.AnnArgs] = None,
mapper: tp.Optional[ChunkMapper] = None,
field: str = "col",
) -> None:
"""Fix a field of the record array in each chunk."""
for _chunk_meta in chunk_meta:
if mapper is None:
record_arrays[_chunk_meta.idx][field] += _chunk_meta.start
else:
_chunk_meta_mapped = mapper.map(_chunk_meta, ann_args=ann_args)
record_arrays[_chunk_meta.idx][field] += _chunk_meta_mapped.start
def merge_records(
results: tp.List[tp.RecordArray],
chunk_meta: tp.Iterable[ChunkMeta],
ann_args: tp.Optional[tp.AnnArgs] = None,
mapper: tp.Optional[ChunkMapper] = None,
) -> tp.RecordArray:
"""Merge chunks of record arrays.
Mapper is only applied on the column field."""
if "col" in results[0].dtype.fields:
fix_field_in_records(results, chunk_meta, ann_args=ann_args, mapper=mapper, field="col")
if "group" in results[0].dtype.fields:
fix_field_in_records(results, chunk_meta, field="group")
return np.concatenate(results)
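# Example (a hedged sketch): when a chunked function returns one record array per chunk of columns,
# `merge_records(results, chunk_meta)` offsets each chunk's "col" field by the chunk's starting
# column (via `fix_field_in_records`) and concatenates the chunks into a single record array.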
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class for mapping column arrays."""
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.base.grouping import nb as grouping_nb
from vectorbtpro.base.reshaping import to_1d_array
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.records import nb
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils.decorators import hybrid_method, cached_property, cached_method
__all__ = [
"ColumnMapper",
]
ColumnMapperT = tp.TypeVar("ColumnMapperT", bound="ColumnMapper")
class ColumnMapper(Wrapping):
"""Used by `vectorbtpro.records.base.Records` and `vectorbtpro.records.mapped_array.MappedArray`
classes to make use of column and group metadata."""
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[ColumnMapperT],
*objs: tp.MaybeTuple[ColumnMapperT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> ColumnMapperT:
"""Stack multiple `ColumnMapper` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, ColumnMapper):
raise TypeError("Each object to be merged must be an instance of ColumnMapper")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
if "col_arr" not in kwargs:
col_arrs = []
for col in range(kwargs["wrapper"].shape_2d[1]):
for obj in objs:
col_idxs, col_lens = obj.col_map
if len(col_idxs) > 0:
if col > 0 and obj.wrapper.shape_2d[1] == 1:
col_arrs.append(np.full(col_lens[0], col))
elif col_lens[col] > 0:
col_arrs.append(np.full(col_lens[col], col))
kwargs["col_arr"] = np.concatenate(col_arrs)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[ColumnMapperT],
*objs: tp.MaybeTuple[ColumnMapperT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> ColumnMapperT:
"""Stack multiple `ColumnMapper` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, ColumnMapper):
raise TypeError("Each object to be merged must be an instance of ColumnMapper")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
if "col_arr" not in kwargs:
col_arrs = []
col_sum = 0
for obj in objs:
col_idxs, col_lens = obj.col_map
if len(col_idxs) > 0:
col_arrs.append(obj.col_arr[col_idxs] + col_sum)
col_sum += obj.wrapper.shape_2d[1]
kwargs["col_arr"] = np.concatenate(col_arrs)
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
def __init__(self, wrapper: ArrayWrapper, col_arr: tp.Array1d, **kwargs) -> None:
Wrapping.__init__(self, wrapper, col_arr=col_arr, **kwargs)
self._col_arr = col_arr
# Cannot select rows
self._column_only_select = True
def select_cols(
self,
col_idxs: tp.MaybeIndexArray,
jitted: tp.JittedOption = None,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Select columns.
Returns indices and new column array. Automatically decides whether to use column lengths or column map."""
if len(self.col_arr) == 0:
return np.arange(len(self.col_arr)), self.col_arr
if isinstance(col_idxs, slice):
if col_idxs.start is None and col_idxs.stop is None:
return np.arange(len(self.col_arr)), self.col_arr
col_idxs = np.arange(col_idxs.start, col_idxs.stop)
if self.is_sorted():
func = jit_reg.resolve_option(grouping_nb.group_lens_select_nb, jitted)
new_indices, new_col_arr = func(self.col_lens, to_1d_array(col_idxs)) # faster
else:
func = jit_reg.resolve_option(grouping_nb.group_map_select_nb, jitted)
new_indices, new_col_arr = func(self.col_map, to_1d_array(col_idxs)) # more flexible
return new_indices, new_col_arr
def indexing_func_meta(self, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform indexing on `ColumnMapper` and return metadata."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(
*args,
column_only_select=self.column_only_select,
group_select=self.group_select,
**kwargs,
)
new_indices, new_col_arr = self.select_cols(wrapper_meta["col_idxs"])
return dict(
wrapper_meta=wrapper_meta,
new_indices=new_indices,
new_col_arr=new_col_arr,
)
def indexing_func(self: ColumnMapperT, *args, col_mapper_meta: tp.DictLike = None, **kwargs) -> ColumnMapperT:
"""Perform indexing on `ColumnMapper`."""
if col_mapper_meta is None:
col_mapper_meta = self.indexing_func_meta(*args, **kwargs)
return self.replace(
wrapper=col_mapper_meta["wrapper_meta"]["new_wrapper"],
col_arr=col_mapper_meta["new_col_arr"],
)
@property
def col_arr(self) -> tp.Array1d:
"""Column array."""
return self._col_arr
@cached_method(whitelist=True)
def get_col_arr(self, group_by: tp.GroupByLike = None) -> tp.Array1d:
"""Get group-aware column array."""
group_arr = self.wrapper.grouper.get_groups(group_by=group_by)
if group_arr is not None:
col_arr = group_arr[self.col_arr]
else:
col_arr = self.col_arr
return col_arr
@cached_property(whitelist=True)
def col_lens(self) -> tp.GroupLens:
"""Column lengths.
Faster than `ColumnMapper.col_map` but only compatible with sorted columns."""
func = jit_reg.resolve_option(nb.col_lens_nb, None)
return func(self.col_arr, len(self.wrapper.columns))
@cached_method(whitelist=True)
def get_col_lens(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None) -> tp.GroupLens:
"""Get group-aware column lengths."""
if not self.wrapper.grouper.is_grouped(group_by=group_by):
return self.col_lens
col_arr = self.get_col_arr(group_by=group_by)
columns = self.wrapper.get_columns(group_by=group_by)
func = jit_reg.resolve_option(nb.col_lens_nb, jitted)
return func(col_arr, len(columns))
@cached_property(whitelist=True)
def col_map(self) -> tp.GroupMap:
"""Column map.
More flexible than `ColumnMapper.col_lens`.
More suited for mapped arrays."""
func = jit_reg.resolve_option(nb.col_map_nb, None)
return func(self.col_arr, len(self.wrapper.columns))
@cached_method(whitelist=True)
def get_col_map(self, group_by: tp.GroupByLike = None, jitted: tp.JittedOption = None) -> tp.GroupMap:
"""Get group-aware column map."""
if not self.wrapper.grouper.is_grouped(group_by=group_by):
return self.col_map
col_arr = self.get_col_arr(group_by=group_by)
columns = self.wrapper.get_columns(group_by=group_by)
func = jit_reg.resolve_option(nb.col_map_nb, jitted)
return func(col_arr, len(columns))
@cached_method(whitelist=True)
def is_sorted(self, jitted: tp.JittedOption = None) -> bool:
"""Check whether column array is sorted."""
func = jit_reg.resolve_option(nb.is_col_sorted_nb, jitted)
return func(self.col_arr)
@cached_property(whitelist=True)
def new_id_arr(self) -> tp.Array1d:
"""Generate a new id array."""
func = jit_reg.resolve_option(nb.generate_ids_nb, None)
return func(self.col_arr, self.wrapper.shape_2d[1])
@cached_method(whitelist=True)
def get_new_id_arr(self, group_by: tp.GroupByLike = None) -> tp.Array1d:
"""Generate a new group-aware id array."""
group_arr = self.wrapper.grouper.get_groups(group_by=group_by)
if group_arr is not None:
col_arr = group_arr[self.col_arr]
else:
col_arr = self.col_arr
columns = self.wrapper.get_columns(group_by=group_by)
func = jit_reg.resolve_option(nb.generate_ids_nb, None)
return func(col_arr, len(columns))
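# A minimal sketch contrasting `ColumnMapper.col_lens` and `ColumnMapper.col_map`
# (illustrative only; the wrapper arguments follow `vectorbtpro.base.wrapping.ArrayWrapper`):
#
# >>> wrapper = ArrayWrapper(index=[0, 1, 2], columns=["a", "b"], ndim=2)
# >>> cm = ColumnMapper(wrapper, np.array([0, 0, 1]))
# >>> cm.col_lens  # number of elements per column; requires a sorted column array
# array([2, 1])
# >>> cm.col_map  # (element indices sorted by column, per-column lengths)
# (array([0, 1, 2]), array([2, 1]))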
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Class decorators for records."""
import keyword
from functools import partial
from vectorbtpro import _typing as tp
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import cacheable_property, cached_property
from vectorbtpro.utils.mapping import to_value_mapping
__all__ = []
def override_field_config(config: Config, merge_configs: bool = True) -> tp.ClassWrapper:
"""Class decorator to override the field config of a class subclassing
`vectorbtpro.records.base.Records`.
Instead of overriding the `_field_config` class attribute, you can pass `config` directly to this decorator.
Disable `merge_configs` to skip merging, which effectively disables field inheritance."""
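# A minimal usage sketch (the `MyRecords` class and its `side` field are hypothetical;
# `Records` refers to `vectorbtpro.records.base.Records`, whose dtype is assumed to
# define a `side` field):
#
# >>> @override_field_config(Config(dict(settings=dict(side=dict(mapping={0: "Buy", 1: "Sell"})))))
# ... class MyRecords(Records):
# ...     pass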
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "Records")
if merge_configs:
new_config = merge_dicts(cls.field_config, config)
else:
new_config = config
if not isinstance(new_config, Config):
new_config = HybridConfig(new_config)
setattr(cls, "_field_config", new_config)
return cls
return wrapper
def attach_fields(*args, on_conflict: str = "raise") -> tp.FlexClassWrapper:
"""Class decorator to attach field properties in a `vectorbtpro.records.base.Records` class.
Will extract `dtype` and other relevant information from `vectorbtpro.records.base.Records.field_config`
and map its fields as properties. This behavior can be changed by using `config`.
!!! note
Make sure to run `attach_fields` after `override_field_config`.
`config` must contain fields (keys) and dictionaries (values) with the following keys:
* `attach`: Whether to attach the field property. Can be provided as a string to be used
as a target attribute name. Defaults to True.
* `defaults`: Dictionary with default keyword arguments for `vectorbtpro.records.base.Records.map_field`.
Defaults to an empty dict.
* `attach_filters`: Whether to attach filters based on the field's values. Can be provided as a dict
to be used instead of the mapping (filter value -> target filter name). Defaults to False.
If True, defaults to `mapping` in `vectorbtpro.records.base.Records.field_config`.
* `filter_defaults`: Dictionary with default keyword arguments for `vectorbtpro.records.base.Records.apply_mask`.
Can be provided by target filter name. Defaults to an empty dict.
* `on_conflict`: Overrides global `on_conflict` for both field and filter properties.
Any potential attribute name is prepared by placing underscores between capital letters and
converting to lowercase.
If an attribute with the same name already exists in the class but the name is not listed in the field config:
* it will be overridden if `on_conflict` is 'override'
* it will be ignored if `on_conflict` is 'ignore'
* an error will be raised if `on_conflict` is 'raise'
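For example (a sketch; the target class and its `side` field are hypothetical, and the class
is assumed to define a field config whose dtype contains `side`):
```pycon
>>> @attach_fields(dict(side=dict(attach_filters={0: "buys", 1: "sells"})))
... class MyRecords(Records):
...     pass
```
This would attach a `side` property returning the mapped field, plus `buys` and `sells`
properties returning records filtered by `side == 0` and `side == 1`, respectively.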
"""
def wrapper(cls: tp.Type[tp.T], config: tp.DictLike = None) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "Records")
dtype = cls.field_config.get("dtype", None)
checks.assert_not_none(dtype.fields)
if config is None:
config = {}
def _prepare_attr_name(attr_name: str) -> str:
checks.assert_instance_of(attr_name, str)
attr_name = attr_name.replace("NaN", "Nan")
startswith_ = attr_name.startswith("_")
new_attr_name = ""
for i in range(len(attr_name)):
if attr_name[i].isupper():
if i > 0 and attr_name[i - 1].islower():
new_attr_name += "_"
new_attr_name += attr_name[i]
attr_name = new_attr_name
if not startswith_ and attr_name.startswith("_"):
attr_name = attr_name[1:]
attr_name = attr_name.lower()
if keyword.iskeyword(attr_name):
attr_name += "_"
return attr_name.replace("__", "_")
def _check_attr_name(attr_name, _on_conflict: str = on_conflict) -> None:
if attr_name not in cls.field_config.get("settings", {}):
# Consider only attributes that are not listed in the field config
if hasattr(cls, attr_name):
if _on_conflict.lower() == "raise":
raise ValueError(f"An attribute with the name '{attr_name}' already exists in {cls}")
if _on_conflict.lower() == "ignore":
return
if _on_conflict.lower() == "override":
return
raise ValueError(f"Value '{_on_conflict}' is invalid for on_conflict")
if keyword.iskeyword(attr_name):
raise ValueError(f"Name '{attr_name}' is a keyword and cannot be used as an attribute name")
if dtype is not None:
for field_name in dtype.names:
settings = config.get(field_name, {})
attach = settings.get("attach", True)
if not isinstance(attach, bool):
target_name = attach
attach = True
else:
target_name = field_name
defaults = settings.get("defaults", None)
if defaults is None:
defaults = {}
attach_filters = settings.get("attach_filters", False)
filter_defaults = settings.get("filter_defaults", None)
if filter_defaults is None:
filter_defaults = {}
_on_conflict = settings.get("on_conflict", on_conflict)
if attach:
target_name = _prepare_attr_name(target_name)
_check_attr_name(target_name, _on_conflict)
def new_prop(
self,
_field_name: str = field_name,
_defaults: tp.KwargsLike = defaults,
) -> MappedArray:
return self.get_map_field(_field_name, **_defaults)
new_prop.__name__ = target_name
new_prop.__module__ = cls.__module__
new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
new_prop.__doc__ = f"Mapped array of the field `{field_name}`."
setattr(cls, target_name, cached_property(new_prop))
if attach_filters:
if isinstance(attach_filters, bool):
if not attach_filters:
continue
mapping = cls.field_config.get("settings", {}).get(field_name, {}).get("mapping", None)
else:
mapping = attach_filters
if mapping is None:
raise ValueError(f"Field '{field_name}': Mapping is required to attach filters")
mapping = to_value_mapping(mapping)
for filter_value, target_filter_name in mapping.items():
if target_filter_name is None:
continue
if isinstance(attach_filters, bool):
target_filter_name = field_name + "_" + target_filter_name
target_filter_name = _prepare_attr_name(target_filter_name)
_check_attr_name(target_filter_name, _on_conflict)
if target_filter_name in filter_defaults:
__filter_defaults = filter_defaults[target_filter_name]
else:
__filter_defaults = filter_defaults
def new_filter_prop(
self,
_field_name: str = field_name,
_filter_value: tp.Any = filter_value,
_filter_defaults: tp.KwargsLike = __filter_defaults,
) -> MappedArray:
filter_mask = self.get_field_arr(_field_name) == _filter_value
return self.apply_mask(filter_mask, **_filter_defaults)
new_filter_prop.__name__ = target_filter_name
new_filter_prop.__module__ = cls.__module__
new_filter_prop.__qualname__ = f"{cls.__name__}.{new_filter_prop.__name__}"
new_filter_prop.__doc__ = f"Records filtered by `{field_name} == {filter_value}`."
setattr(cls, target_filter_name, cached_property(new_filter_prop))
return cls
if len(args) == 0:
return wrapper
elif len(args) == 1:
if isinstance(args[0], type):
return wrapper(args[0])
return partial(wrapper, config=args[0])
elif len(args) == 2:
return wrapper(args[0], config=args[1])
raise ValueError("Either class, config, class and config, or keyword arguments must be passed")
def attach_shortcut_properties(config: Config) -> tp.ClassWrapper:
"""Class decorator to attach shortcut properties.
`config` must contain target property names (keys) and settings (values) with the following keys:
* `method_name`: Name of the source method. Defaults to the target name prepended with the prefix `get_`.
* `obj_type`: Type of the returned object. Can be 'array' for 2-dim arrays, 'red_array' for 1-dim arrays,
'records' for record arrays, and 'mapped_array' for mapped arrays. Defaults to 'records'.
* `group_by_aware`: Whether the returned object is aligned based on the current grouping.
Defaults to True.
* `method_kwargs`: Keyword arguments passed to the source method. Defaults to None.
* `decorator`: Defaults to `vectorbtpro.utils.decorators.cached_property` for object types
'records' and 'red_array'. Otherwise, to `vectorbtpro.utils.decorators.cacheable_property`.
* `decorator_kwargs`: Keyword arguments passed to the decorator. By default,
includes options `obj_type` and `group_by_aware`.
* `docstring`: Method docstring.
The class must be a subclass of `vectorbtpro.records.base.Records`."""
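# A minimal config sketch (the class and property names are hypothetical; each entry
# expects a matching `get_<name>` method on the decorated class):
#
# >>> @attach_shortcut_properties(Config(dict(
# ...     winning=dict(obj_type="records"),
# ...     duration=dict(obj_type="mapped_array"),
# ... )))
# ... class MyRecords(Records):
# ...     def get_winning(self, **kwargs): ...
# ...     def get_duration(self, **kwargs): ...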
def wrapper(cls: tp.Type[tp.T]) -> tp.Type[tp.T]:
checks.assert_subclass_of(cls, "Records")
for target_name, settings in config.items():
if target_name.startswith("get_"):
raise ValueError(f"Property names cannot have prefix 'get_' ('{target_name}')")
method_name = settings.get("method_name", "get_" + target_name)
obj_type = settings.get("obj_type", "records")
group_by_aware = settings.get("group_by_aware", True)
method_kwargs = settings.get("method_kwargs", None)
method_kwargs = resolve_dict(method_kwargs)
decorator = settings.get("decorator", None)
if decorator is None:
if obj_type in ("red_array", "records"):
decorator = cached_property
else:
decorator = cacheable_property
decorator_kwargs = merge_dicts(
dict(obj_type=obj_type, group_by_aware=group_by_aware),
settings.get("decorator_kwargs", None),
)
docstring = settings.get("docstring", None)
if docstring is None:
if len(method_kwargs) == 0:
docstring = f"`{cls.__name__}.{method_name}` with default arguments."
else:
docstring = f"`{cls.__name__}.{method_name}` with arguments `{method_kwargs}`."
def new_prop(self, _method_name: str = method_name, _method_kwargs: tp.Kwargs = method_kwargs) -> tp.Any:
return getattr(self, _method_name)(**_method_kwargs)
new_prop.__name__ = target_name
new_prop.__module__ = cls.__module__
new_prop.__qualname__ = f"{cls.__name__}.{new_prop.__name__}"
new_prop.__doc__ = docstring
setattr(cls, new_prop.__name__, decorator(new_prop, **decorator_kwargs))
return cls
return wrapper
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base class for working with mapped arrays.
This class takes the mapped array and the corresponding column and (optionally) index arrays,
and offers features to directly process the mapped array without converting it to pandas;
for example, to compute various statistics by column, such as standard deviation.
Consider the following example:
```pycon
>>> from vectorbtpro import *
>>> a = np.array([10., 11., 12., 13., 14., 15., 16., 17., 18.])
>>> col_arr = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
>>> idx_arr = np.array([0, 1, 2, 0, 1, 2, 0, 1, 2])
>>> wrapper = vbt.ArrayWrapper(index=['x', 'y', 'z'],
... columns=['a', 'b', 'c'], ndim=2, freq='1 day')
>>> ma = vbt.MappedArray(wrapper, a, col_arr, idx_arr=idx_arr)
```
## Reducing
Using `MappedArray`, we can then reduce by column as follows:
* Use already provided reducers such as `MappedArray.mean`:
```pycon
>>> ma.mean()
a 11.0
b 14.0
c 17.0
dtype: float64
```
* Use `MappedArray.to_pd` to map to pandas and then reduce manually (expensive):
```pycon
>>> ma.to_pd().mean()
a 11.0
b 14.0
c 17.0
dtype: float64
```
* Use `MappedArray.reduce` to reduce using a custom function:
```pycon
>>> # Reduce to a scalar
>>> @njit
... def pow_mean_reduce_nb(a, pow):
... return np.mean(a ** pow)
>>> ma.reduce(pow_mean_reduce_nb, 2)
a 121.666667
b 196.666667
c 289.666667
dtype: float64
>>> # Reduce to an array
>>> @njit
... def min_max_reduce_nb(a):
... return np.array([np.min(a), np.max(a)])
>>> ma.reduce(min_max_reduce_nb, returns_array=True,
... wrap_kwargs=dict(name_or_index=['min', 'max']))
a b c
min 10.0 13.0 16.0
max 12.0 15.0 18.0
>>> # Reduce to an array of indices
>>> @njit
... def idxmin_idxmax_reduce_nb(a):
... return np.array([np.argmin(a), np.argmax(a)])
>>> ma.reduce(idxmin_idxmax_reduce_nb, returns_array=True,
... returns_idx=True, wrap_kwargs=dict(name_or_index=['idxmin', 'idxmax']))
a b c
idxmin x x x
idxmax z z z
>>> # Reduce using a meta function to combine multiple mapped arrays
>>> @njit
... def mean_ratio_reduce_meta_nb(idxs, col, a, b):
... return np.mean(a[idxs]) / np.mean(b[idxs])
>>> vbt.MappedArray.reduce(mean_ratio_reduce_meta_nb,
... ma.values - 1, ma.values + 1, col_mapper=ma.col_mapper)
a 0.833333
b 0.866667
c 0.888889
Name: reduce, dtype: float64
```
## Mapping
Use `MappedArray.apply` to apply a function on each column/group:
```pycon
>>> @njit
... def cumsum_apply_nb(a):
... return np.cumsum(a)
>>> ma.apply(cumsum_apply_nb).values
array([10., 21., 33., 13., 27., 42., 16., 33., 51.])
>>> group_by = np.array(['first', 'first', 'second'])
>>> ma.apply(cumsum_apply_nb, group_by=group_by, apply_per_group=True).values
array([10., 21., 33., 46., 60., 75., 16., 33., 51.])
>>> # Apply using a meta function
>>> @njit
... def cumsum_apply_meta_nb(ridxs, col, a):
... return np.cumsum(a[ridxs])
>>> vbt.MappedArray.apply(cumsum_apply_meta_nb, ma.values, col_mapper=ma.col_mapper).values
array([10., 21., 33., 13., 27., 42., 16., 33., 51.])
```
Notice how cumsum resets at each column in the first example and at each group in the second example.
## Conversion
We can unstack any `MappedArray` instance to pandas:
* Given `idx_arr` was provided:
```pycon
>>> ma.to_pd()
a b c
x 10.0 13.0 16.0
y 11.0 14.0 17.0
z 12.0 15.0 18.0
```
!!! note
Will throw a warning if there are multiple values pointing to the same position.
* In case `group_by` was provided, the index can be ignored, or there are position conflicts:
```pycon
>>> ma.to_pd(group_by=np.array(['first', 'first', 'second']), ignore_index=True)
first second
0 10.0 16.0
1 11.0 17.0
2 12.0 18.0
3 13.0 NaN
4 14.0 NaN
5 15.0 NaN
```
## Resolving conflicts
Sometimes, we may encounter multiple values for each index and column combination.
In such a case, we can use `MappedArray.reduce_segments` to aggregate "duplicate" elements.
For example, let's sum up duplicate values for each index and column combination:
```pycon
>>> ma_conf = ma.replace(idx_arr=np.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))
>>> ma_conf.to_pd()
UserWarning: Multiple values are pointing to the same position. Only the latest value is used.
a b c
x 12.0 NaN NaN
y NaN 15.0 NaN
z NaN NaN 18.0
>>> @njit
... def sum_reduce_nb(a):
... return np.sum(a)
>>> ma_no_conf = ma_conf.reduce_segments(
... (ma_conf.idx_arr, ma_conf.col_arr),
... sum_reduce_nb
... )
>>> ma_no_conf.to_pd()
a b c
x 33.0 NaN NaN
y NaN 42.0 NaN
z NaN NaN 51.0
```
## Filtering
Use `MappedArray.apply_mask` to filter elements per column/group:
```pycon
>>> mask = [True, False, True, False, True, False, True, False, True]
>>> filtered_ma = ma.apply_mask(mask)
>>> filtered_ma.count()
a 2
b 1
c 2
dtype: int64
>>> filtered_ma.id_arr
array([0, 2, 4, 6, 8])
```
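We can also keep only the highest or lowest values per column/group with `MappedArray.top_n`
and `MappedArray.bottom_n` (a minimal sketch building on `ma` from above):
```pycon
>>> ma.top_n(1).values
array([12., 15., 18.])
```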
## Grouping
One of the key features of `MappedArray` is that we can perform reducing operations on a group
of columns as if they were a single column. Groups can be specified by `group_by`, which
can be anything from positions or names of column levels, to a NumPy array with actual groups.
There are multiple ways to define grouping:
* When creating `MappedArray`, pass `group_by` to `vectorbtpro.base.wrapping.ArrayWrapper`:
```pycon
>>> group_by = np.array(['first', 'first', 'second'])
>>> grouped_wrapper = wrapper.replace(group_by=group_by)
>>> grouped_ma = vbt.MappedArray(grouped_wrapper, a, col_arr, idx_arr=idx_arr)
>>> grouped_ma.mean()
first 12.5
second 17.0
dtype: float64
```
* Regroup an existing `MappedArray`:
```pycon
>>> ma.regroup(group_by).mean()
first 12.5
second 17.0
dtype: float64
```
* Pass `group_by` directly to the reducing method:
```pycon
>>> ma.mean(group_by=group_by)
first 12.5
second 17.0
dtype: float64
```
In the same way, we can disable or modify any existing grouping:
```pycon
>>> grouped_ma.mean(group_by=False)
a 11.0
b 14.0
c 17.0
dtype: float64
```
!!! note
Grouping applies only to reducing operations; the underlying arrays remain unchanged.
## Operators
`MappedArray` implements arithmetic, comparison, and logical operators. We can perform basic
operations (such as addition) on mapped arrays as if they were NumPy arrays.
```pycon
>>> ma ** 2
>>> ma * np.array([1, 2, 3, 4, 5, 6])
>>> ma + ma
```
!!! note
Ensure that your `MappedArray` operand is on the left if the other operand is an array.
If two `MappedArray` operands have different metadata, the metadata of the first one will be copied,
but at least their `id_arr` and `col_arr` must match.
## Indexing
Like any other class subclassing `vectorbtpro.base.wrapping.Wrapping`, we can use pandas indexing
on a `MappedArray` instance, which forwards the indexing operation to each object with columns:
```pycon
>>> ma['a'].values
array([10., 11., 12.])
>>> grouped_ma['first'].values
array([10., 11., 12., 13., 14., 15.])
```
!!! note
Changing index (time axis) is not supported. The object should be treated as a Series
rather than a DataFrame; for example, use `some_field.iloc[0]` instead of `some_field.iloc[:, 0]`
to get the first column.
Indexing behavior depends solely upon `vectorbtpro.base.wrapping.ArrayWrapper`.
For example, if `group_select` is enabled, indexing will be performed on groups, otherwise on single columns.
otherwise on single columns.
## Caching
`MappedArray` supports caching. If a method or a property requires heavy computation, it's wrapped
with `vectorbtpro.utils.decorators.cached_method` and `vectorbtpro.utils.decorators.cached_property`
respectively. Caching can be disabled globally in `vectorbtpro._settings.caching`.
!!! note
Because of caching, the class is meant to be immutable and all properties are read-only.
To change any attribute, use the `MappedArray.replace` method and pass changes as keyword arguments.
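For example, to attach a value mapping without mutating the original instance
(a minimal sketch; the mapping itself is illustrative):
```pycon
>>> ma2 = ma.replace(mapping={v: "test_" + str(v) for v in np.unique(ma.values)})
```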
## Saving and loading
Like any other class subclassing `vectorbtpro.utils.pickling.Pickleable`, we can save a `MappedArray`
instance to the disk with `MappedArray.save` and load it with `MappedArray.load`.
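For example (a minimal sketch; the file name is illustrative):
```pycon
>>> ma.save('ma')
>>> ma2 = vbt.MappedArray.load('ma')
```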
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `MappedArray.metrics`.
Metrics for mapped arrays are similar to those for `vectorbtpro.generic.accessors.GenericAccessor`.
```pycon
>>> ma.stats(column='a')
Start x
End z
Period 3 days 00:00:00
Count 3
Mean 11.0
Std 1.0
Min 10.0
Median 11.0
Max 12.0
Min Index x
Max Index z
Name: a, dtype: object
```
The main difference unfolds once the mapped array has a mapping:
values are then treated as categorical, and the usual statistics become meaningless to compute.
In this case, `MappedArray.stats` returns the value counts:
```pycon
>>> mapping = {v: "test_" + str(v) for v in np.unique(ma.values)}
>>> ma.stats(column='a', settings=dict(mapping=mapping))
Start x
End z
Period 3 days 00:00:00
Count 3
Value Counts: test_10.0 1
Value Counts: test_11.0 1
Value Counts: test_12.0 1
Value Counts: test_13.0 0
Value Counts: test_14.0 0
Value Counts: test_15.0 0
Value Counts: test_16.0 0
Value Counts: test_17.0 0
Value Counts: test_18.0 0
Name: a, dtype: object
```
`MappedArray.stats` also supports (re-)grouping:
```pycon
>>> grouped_ma.stats(column='first')
Start x
End z
Period 3 days 00:00:00
Count 6
Mean 12.5
Std 1.870829
Min 10.0
Median 12.5
Max 15.0
Min Index x
Max Index z
Name: first, dtype: object
```
## Plots
We can build histograms and boxplots of `MappedArray` directly:
```pycon
>>> ma.boxplot().show()
```
To use scatterplots or any other plots that require an index, convert to pandas first:
```pycon
>>> ma.to_pd().vbt.plot().show()
```
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `MappedArray.subplots`.
`MappedArray` class has a single subplot based on `MappedArray.to_pd` and
`vectorbtpro.generic.accessors.GenericAccessor.plot`.
"""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base.merging import concat_arrays, column_stack_arrays
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.base.reshaping import to_1d_array, to_dict, index_to_series, index_to_frame
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.analyzable import Analyzable
from vectorbtpro.records import nb
from vectorbtpro.records.col_mapper import ColumnMapper
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.utils import checks
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import index_repeating_rows_nb
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import hybrid_method, cached_method
from vectorbtpro.utils.magic_decorators import attach_binary_magic_methods, attach_unary_magic_methods
from vectorbtpro.utils.mapping import to_value_mapping, apply_mapping
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"MappedArray",
]
MappedArrayT = tp.TypeVar("MappedArrayT", bound="MappedArray")
def combine_mapped_with_other(
self: MappedArrayT,
other: tp.Union["MappedArray", tp.ArrayLike],
np_func: tp.Callable[[tp.ArrayLike, tp.ArrayLike], tp.Array1d],
) -> MappedArrayT:
"""Combine `MappedArray` with other compatible object.
If other object is also `MappedArray`, their `id_arr` and `col_arr` must match."""
if isinstance(other, MappedArray):
checks.assert_array_equal(self.id_arr, other.id_arr)
checks.assert_array_equal(self.col_arr, other.col_arr)
other = other.values
return self.replace(mapped_arr=np_func(self.values, other))
@attach_binary_magic_methods(combine_mapped_with_other)
@attach_unary_magic_methods(lambda self, np_func: self.replace(mapped_arr=np_func(self.values)))
class MappedArray(Analyzable):
"""Exposes methods for reducing, converting, and plotting arrays mapped by
`vectorbtpro.records.base.Records` class.
Args:
wrapper (ArrayWrapper): Array wrapper.
See `vectorbtpro.base.wrapping.ArrayWrapper`.
mapped_arr (array_like): A one-dimensional array of mapped record values.
col_arr (array_like): A one-dimensional column array.
Must be of the same size as `mapped_arr`.
id_arr (array_like): A one-dimensional id array. Defaults to simple range.
Must be of the same size as `mapped_arr`.
idx_arr (array_like): A one-dimensional index array. Optional.
Must be of the same size as `mapped_arr`.
mapping (namedtuple, dict or callable): Mapping.
col_mapper (ColumnMapper): Column mapper if already known.
!!! note
It depends upon `wrapper` and `col_arr`, so make sure to invalidate `col_mapper` upon creating
a `MappedArray` instance with a modified `wrapper` or `col_arr`.
`MappedArray.replace` does it automatically.
**kwargs: Custom keyword arguments passed to the config.
Useful if any subclass wants to extend the config.
"""
@hybrid_method
def row_stack(
cls_or_self: tp.MaybeType[MappedArrayT],
*objs: tp.MaybeTuple[MappedArrayT],
wrapper_kwargs: tp.KwargsLike = None,
**kwargs,
) -> MappedArrayT:
"""Stack multiple `MappedArray` instances along rows.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.row_stack` to stack the wrappers.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, MappedArray):
raise TypeError("Each object to be merged must be an instance of MappedArray")
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.row_stack(*[obj.wrapper for obj in objs], **wrapper_kwargs)
if "col_mapper" not in kwargs:
kwargs["col_mapper"] = ColumnMapper.row_stack(
*[obj.col_mapper for obj in objs],
wrapper=kwargs["wrapper"],
)
if "mapped_arr" not in kwargs:
mapped_arrs = []
for col in range(kwargs["wrapper"].shape_2d[1]):
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
if col > 0 and obj.wrapper.shape_2d[1] == 1:
mapped_arrs.append(obj.mapped_arr[col_idxs])
elif col_lens[col] > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
mapped_arrs.append(obj.mapped_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]])
kwargs["mapped_arr"] = concat_arrays(mapped_arrs)
if "col_arr" not in kwargs:
kwargs["col_arr"] = kwargs["col_mapper"].col_arr
if "idx_arr" not in kwargs:
stack_idx_arrs = True
for obj in objs:
if obj.idx_arr is None:
stack_idx_arrs = False
break
if stack_idx_arrs:
idx_arrs = []
for col in range(kwargs["wrapper"].shape_2d[1]):
n_rows_sum = 0
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
if col > 0 and obj.wrapper.shape_2d[1] == 1:
idx_arrs.append(obj.idx_arr[col_idxs] + n_rows_sum)
elif col_lens[col] > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
col_idx_arr = obj.idx_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]]
idx_arrs.append(col_idx_arr + n_rows_sum)
n_rows_sum += obj.wrapper.shape_2d[0]
kwargs["idx_arr"] = concat_arrays(idx_arrs)
if "id_arr" not in kwargs:
id_arrs = []
for col in range(kwargs["wrapper"].shape_2d[1]):
from_id = 0
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
if col > 0 and obj.wrapper.shape_2d[1] == 1:
id_arrs.append(obj.id_arr[col_idxs] + from_id)
elif col_lens[col] > 0:
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
id_arrs.append(obj.id_arr[col_idxs[col_start_idxs[col] : col_end_idxs[col]]] + from_id)
if len(id_arrs) > 0 and len(id_arrs[-1]) > 0:
from_id = id_arrs[-1].max() + 1
kwargs["id_arr"] = concat_arrays(id_arrs)
kwargs = cls.resolve_row_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
@hybrid_method
def column_stack(
cls_or_self: tp.MaybeType[MappedArrayT],
*objs: tp.MaybeTuple[MappedArrayT],
wrapper_kwargs: tp.KwargsLike = None,
get_indexer_kwargs: tp.KwargsLike = None,
**kwargs,
) -> MappedArrayT:
"""Stack multiple `MappedArray` instances along columns.
Uses `vectorbtpro.base.wrapping.ArrayWrapper.column_stack` to stack the wrappers.
`get_indexer_kwargs` are passed to
[pandas.Index.get_indexer](https://pandas.pydata.org/docs/reference/api/pandas.Index.get_indexer.html)
to translate old indices to new ones after the reindexing operation.
!!! note
Will produce a column-sorted array."""
if not isinstance(cls_or_self, type):
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, MappedArray):
raise TypeError("Each object to be merged must be an instance of MappedArray")
if get_indexer_kwargs is None:
get_indexer_kwargs = {}
if "wrapper" not in kwargs:
if wrapper_kwargs is None:
wrapper_kwargs = {}
kwargs["wrapper"] = ArrayWrapper.column_stack(
*[obj.wrapper for obj in objs],
**wrapper_kwargs,
)
if "col_mapper" not in kwargs:
kwargs["col_mapper"] = ColumnMapper.column_stack(
*[obj.col_mapper for obj in objs],
wrapper=kwargs["wrapper"],
)
if "mapped_arr" not in kwargs:
mapped_arrs = []
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
mapped_arrs.append(obj.mapped_arr[col_idxs])
kwargs["mapped_arr"] = concat_arrays(mapped_arrs)
if "col_arr" not in kwargs:
kwargs["col_arr"] = kwargs["col_mapper"].col_arr
if "idx_arr" not in kwargs:
stack_idx_arrs = True
for obj in objs:
if obj.idx_arr is None:
stack_idx_arrs = False
break
if stack_idx_arrs:
idx_arrs = []
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
old_idxs = obj.idx_arr[col_idxs]
if not obj.wrapper.index.equals(kwargs["wrapper"].index):
new_idxs = kwargs["wrapper"].index.get_indexer(
obj.wrapper.index[old_idxs],
**get_indexer_kwargs,
)
else:
new_idxs = old_idxs
idx_arrs.append(new_idxs)
kwargs["idx_arr"] = concat_arrays(idx_arrs)
if "id_arr" not in kwargs:
id_arrs = []
for obj in objs:
col_idxs, col_lens = obj.col_mapper.col_map
if len(col_idxs) > 0:
id_arrs.append(obj.id_arr[col_idxs])
kwargs["id_arr"] = concat_arrays(id_arrs)
kwargs = cls.resolve_column_stack_kwargs(*objs, **kwargs)
kwargs = cls.resolve_stack_kwargs(*objs, **kwargs)
return cls(**kwargs)
def __init__(
self,
wrapper: ArrayWrapper,
mapped_arr: tp.ArrayLike,
col_arr: tp.ArrayLike,
idx_arr: tp.Optional[tp.ArrayLike] = None,
id_arr: tp.Optional[tp.ArrayLike] = None,
mapping: tp.Optional[tp.MappingLike] = None,
col_mapper: tp.Optional[ColumnMapper] = None,
**kwargs,
) -> None:
mapped_arr = np.asarray(mapped_arr)
col_arr = np.asarray(col_arr)
checks.assert_shape_equal(mapped_arr, col_arr, axis=0)
if idx_arr is not None:
idx_arr = np.asarray(idx_arr)
checks.assert_shape_equal(mapped_arr, idx_arr, axis=0)
if col_mapper is None:
col_mapper = ColumnMapper(wrapper, col_arr)
if id_arr is None:
id_arr = col_mapper.new_id_arr
else:
id_arr = np.asarray(id_arr)
checks.assert_shape_equal(mapped_arr, id_arr, axis=0)
Analyzable.__init__(
self,
wrapper,
mapped_arr=mapped_arr,
col_arr=col_arr,
idx_arr=idx_arr,
id_arr=id_arr,
mapping=mapping,
col_mapper=col_mapper,
**kwargs,
)
self._mapped_arr = mapped_arr
self._col_arr = col_arr
self._idx_arr = idx_arr
self._id_arr = id_arr
self._mapping = mapping
self._col_mapper = col_mapper
# Only slices of rows can be selected
self._range_only_select = True
def replace(self: MappedArrayT, **kwargs) -> MappedArrayT:
"""See `vectorbtpro.utils.config.Configured.replace`.
Also, drops `MappedArray.col_mapper` if a different `wrapper` or `col_arr` is passed, so that it gets rebuilt."""
if self.config.get("col_mapper", None) is not None:
if "wrapper" in kwargs:
if self.wrapper is not kwargs.get("wrapper"):
kwargs["col_mapper"] = None
if "col_arr" in kwargs:
if self.col_arr is not kwargs.get("col_arr"):
kwargs["col_mapper"] = None
return Analyzable.replace(self, **kwargs)
def indexing_func_meta(self, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform indexing on `MappedArray` and return metadata."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(
*args,
column_only_select=self.column_only_select,
range_only_select=self.range_only_select,
group_select=self.group_select,
**kwargs,
)
new_indices, new_col_arr = self.col_mapper.select_cols(wrapper_meta["col_idxs"])
new_mapped_arr = self.values[new_indices]
if self.idx_arr is not None:
new_idx_arr = self.idx_arr[new_indices]
else:
new_idx_arr = None
new_id_arr = self.id_arr[new_indices]
if wrapper_meta["rows_changed"] and new_idx_arr is not None:
row_idxs = wrapper_meta["row_idxs"]
mask = (new_idx_arr >= row_idxs.start) & (new_idx_arr < row_idxs.stop)
new_indices = new_indices[mask]
new_mapped_arr = new_mapped_arr[mask]
new_col_arr = new_col_arr[mask]
if new_idx_arr is not None:
new_idx_arr = new_idx_arr[mask] - row_idxs.start
new_id_arr = new_id_arr[mask]
return dict(
wrapper_meta=wrapper_meta,
new_indices=new_indices,
new_mapped_arr=new_mapped_arr,
new_col_arr=new_col_arr,
new_idx_arr=new_idx_arr,
new_id_arr=new_id_arr,
)
def indexing_func(self: MappedArrayT, *args, mapped_meta: tp.DictLike = None, **kwargs) -> MappedArrayT:
"""Perform indexing on `MappedArray`."""
if mapped_meta is None:
mapped_meta = self.indexing_func_meta(*args, **kwargs)
return self.replace(
wrapper=mapped_meta["wrapper_meta"]["new_wrapper"],
mapped_arr=mapped_meta["new_mapped_arr"],
col_arr=mapped_meta["new_col_arr"],
id_arr=mapped_meta["new_id_arr"],
idx_arr=mapped_meta["new_idx_arr"],
)
def resample_meta(self: MappedArrayT, *args, wrapper_meta: tp.DictLike = None, **kwargs) -> dict:
"""Perform resampling on `MappedArray` and return metadata."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
if isinstance(wrapper_meta["resampler"], Resampler):
_resampler = wrapper_meta["resampler"]
else:
_resampler = Resampler.from_pd_resampler(wrapper_meta["resampler"])
if self.idx_arr is not None:
index_map = _resampler.map_to_target_index(return_index=False)
new_idx_arr = index_map[self.idx_arr]
else:
new_idx_arr = None
return dict(wrapper_meta=wrapper_meta, new_idx_arr=new_idx_arr)
def resample(self: MappedArrayT, *args, mapped_meta: tp.DictLike = None, **kwargs) -> MappedArrayT:
"""Perform resampling on `MappedArray`."""
if mapped_meta is None:
mapped_meta = self.resample_meta(*args, **kwargs)
return self.replace(
wrapper=mapped_meta["wrapper_meta"]["new_wrapper"],
idx_arr=mapped_meta["new_idx_arr"],
)
@property
def mapped_arr(self) -> tp.Array1d:
"""Mapped array."""
return self._mapped_arr
@property
def values(self) -> tp.Array1d:
"""Mapped array."""
return self.mapped_arr
def to_readable(
self,
title: str = "Value",
only_values: bool = False,
expand_columns: bool = False,
**kwargs,
) -> tp.SeriesFrame:
"""Get values in a human-readable format."""
values = pd.Series(self.apply_mapping(**kwargs).values, name=title)
if only_values:
return pd.Series(values, name=title)
new_columns = list()
new_columns.append(pd.Series(self.id_arr, name="Id"))
column_index = self.wrapper.columns[self.col_arr]
if expand_columns and isinstance(column_index, pd.MultiIndex):
column_frame = index_to_frame(column_index, reset_index=True)
new_columns.append(column_frame.add_prefix("Column: "))
else:
column_sr = index_to_series(column_index, reset_index=True)
if expand_columns and self.wrapper.ndim == 2 and column_sr.name is not None:
new_columns.append(column_sr.rename(f"Column: {column_sr.name}"))
else:
new_columns.append(column_sr.rename("Column"))
if self.idx_arr is not None:
new_columns.append(pd.Series(self.wrapper.index[self.idx_arr], name="Index"))
new_columns.append(values)
return pd.concat(new_columns, axis=1)
@property
def mapped_readable(self) -> tp.SeriesFrame:
"""`MappedArray.to_readable` with default arguments."""
return self.to_readable()
readable = mapped_readable
def __len__(self) -> int:
return len(self.values)
@property
def col_arr(self) -> tp.Array1d:
"""Column array."""
return self._col_arr
@property
def col_mapper(self) -> ColumnMapper:
"""Column mapper.
See `vectorbtpro.records.col_mapper.ColumnMapper`."""
return self._col_mapper
@property
def idx_arr(self) -> tp.Optional[tp.Array1d]:
"""Index array."""
return self._idx_arr
@property
def id_arr(self) -> tp.Array1d:
"""Id array."""
return self._id_arr
@property
def mapping(self) -> tp.Optional[tp.MappingLike]:
"""Mapping."""
return self._mapping
# ############# Sorting ############# #
@cached_method
def is_sorted(self, incl_id: bool = False, jitted: tp.JittedOption = None) -> bool:
"""Check whether mapped array is sorted."""
if incl_id:
func = jit_reg.resolve_option(nb.is_col_id_sorted_nb, jitted)
return func(self.col_arr, self.id_arr)
func = jit_reg.resolve_option(nb.is_col_sorted_nb, jitted)
return func(self.col_arr)
def sort(
self: MappedArrayT,
incl_id: bool = False,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> MappedArrayT:
"""Sort mapped array by column array (primary) and id array (secondary, optional).
`**kwargs` are passed to `MappedArray.replace`."""
if idx_arr is None:
idx_arr = self.idx_arr
if self.is_sorted(incl_id=incl_id):
return self.replace(idx_arr=idx_arr, **kwargs).regroup(group_by)
if incl_id:
ind = np.lexsort((self.id_arr, self.col_arr)) # expensive!
else:
ind = np.argsort(self.col_arr)
return self.replace(
mapped_arr=self.values[ind],
col_arr=self.col_arr[ind],
id_arr=self.id_arr[ind],
idx_arr=idx_arr[ind] if idx_arr is not None else None,
**kwargs,
).regroup(group_by)
# ############# Filtering ############# #
def apply_mask(
self: MappedArrayT,
mask: tp.Array1d,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> MappedArrayT:
"""Return a new class instance, filtered by mask.
`**kwargs` are passed to `MappedArray.replace`."""
if idx_arr is None:
idx_arr = self.idx_arr
mask_indices = np.flatnonzero(mask)
return self.replace(
mapped_arr=np.take(self.values, mask_indices),
col_arr=np.take(self.col_arr, mask_indices),
id_arr=np.take(self.id_arr, mask_indices),
idx_arr=np.take(idx_arr, mask_indices) if idx_arr is not None else None,
**kwargs,
).regroup(group_by)
def top_n_mask(
self,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
) -> tp.Array1d:
"""Return mask of top N elements in each column/group."""
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.top_n_mapped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return func(self.values, col_map, n)
def bottom_n_mask(
self,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
) -> tp.Array1d:
"""Return mask of bottom N elements in each column/group."""
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.bottom_n_mapped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
return func(self.values, col_map, n)
def top_n(
self: MappedArrayT,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArrayT:
"""Filter top N elements from each column/group."""
return self.apply_mask(self.top_n_mask(n, group_by=group_by, jitted=jitted, chunked=chunked), **kwargs)
def bottom_n(
self: MappedArrayT,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArrayT:
"""Filter bottom N elements from each column/group."""
return self.apply_mask(self.bottom_n_mask(n, group_by=group_by, jitted=jitted, chunked=chunked), **kwargs)
# ############# Mapping ############# #
def resolve_mapping(self, mapping: tp.Union[None, bool, tp.MappingLike] = None) -> tp.Optional[tp.Mapping]:
"""Resolve mapping.
Set `mapping` to False to disable mapping completely."""
if mapping is None or mapping is True:
mapping = self.mapping
if isinstance(mapping, bool):
if not mapping:
return None
if isinstance(mapping, str):
if mapping.lower() == "index":
mapping = self.wrapper.index
elif mapping.lower() == "columns":
mapping = self.wrapper.columns
elif mapping.lower() == "groups":
mapping = self.wrapper.get_columns()
mapping = to_value_mapping(mapping)
return mapping
def apply_mapping(
self: MappedArrayT,
mapping: tp.Union[None, bool, tp.MappingLike] = None,
mapping_kwargs: tp.KwargsLike = None,
**kwargs,
) -> MappedArrayT:
"""Apply mapping on each element."""
mapping = self.resolve_mapping(mapping)
new_mapped_arr = apply_mapping(self.values, mapping, **resolve_dict(mapping_kwargs))
return self.replace(mapped_arr=new_mapped_arr, **kwargs)
def to_index(self, minus_one_to_zero: bool = False) -> tp.Index:
"""Convert to index.
If `minus_one_to_zero` is True, positions of -1 will automatically become 0.
Otherwise, such positions will be replaced by -1, NaT, or NaN depending on the index dtype."""
if np.isin(-1, self.values):
nan_mask = self.values == -1
values = self.values.copy()
values[nan_mask] = 0
if minus_one_to_zero:
return self.wrapper.index[values]
if pd.api.types.is_integer_dtype(self.wrapper.index):
new_values = self.wrapper.index.values[values]
new_values[nan_mask] = -1
return pd.Index(new_values, name=self.wrapper.index.name)
if isinstance(self.wrapper.index, pd.DatetimeIndex):
new_values = self.wrapper.index.values[values]
new_values[nan_mask] = np.datetime64("NaT")
return pd.Index(new_values, name=self.wrapper.index.name)
new_values = self.wrapper.index.values[values]
new_values[nan_mask] = np.nan
return pd.Index(new_values, name=self.wrapper.index.name)
return self.wrapper.index[self.values]
def to_columns(self) -> tp.Index:
"""Convert to columns."""
if np.isin(-1, self.values):
raise ValueError("Cannot get index at position -1")
return self.wrapper.columns[self.values]
@hybrid_method
def apply(
cls_or_self: tp.MaybeType[MappedArrayT],
apply_func_nb: tp.Union[tp.ApplyFunc, tp.ApplyMetaFunc],
*args,
group_by: tp.GroupByLike = None,
apply_per_group: bool = False,
dtype: tp.Optional[tp.DTypeLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
col_mapper: tp.Optional[ColumnMapper] = None,
**kwargs,
) -> MappedArrayT:
"""Apply function on mapped array per column/group. Returns a new mapped array.
Applies per group of columns if `apply_per_group` is True.
See `vectorbtpro.records.nb.apply_nb`.
For details on the meta version, see `vectorbtpro.records.nb.apply_meta_nb`.
`**kwargs` are passed to `MappedArray.replace`."""
if isinstance(cls_or_self, type):
checks.assert_not_none(col_mapper, arg_name="col_mapper")
col_map = col_mapper.get_col_map(group_by=group_by if apply_per_group else False)
func = jit_reg.resolve_option(nb.apply_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(len(col_mapper.col_arr), col_map, apply_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return MappedArray(col_mapper.wrapper, mapped_arr, col_mapper.col_arr, col_mapper=col_mapper, **kwargs)
else:
col_map = cls_or_self.col_mapper.get_col_map(group_by=group_by if apply_per_group else False)
func = jit_reg.resolve_option(nb.apply_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mapped_arr = func(cls_or_self.values, col_map, apply_func_nb, *args)
mapped_arr = np.asarray(mapped_arr, dtype=dtype)
return cls_or_self.replace(mapped_arr=mapped_arr, **kwargs).regroup(group_by)
# ############# Reducing ############# #
def reduce_segments(
self: MappedArrayT,
segment_arr: tp.Union[str, tp.MaybeTuple[tp.Array1d]],
reduce_func_nb: tp.Union[str, tp.ReduceFunc],
*args,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
apply_per_group: bool = False,
dtype: tp.Optional[tp.DTypeLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> MappedArrayT:
"""Reduce each segment of values in mapped array. Returns a new mapped array.
`segment_arr` must be an array of integers increasing per column, each indicating a segment.
It must have the same length as the mapped array. You can also pass a tuple of such arrays.
In this case, each unique combination of values will be considered a single segment.
Can also pass the string "idx" to use the index array.
`reduce_func_nb` can be a string denoting the suffix of a reducing function
from `vectorbtpro.generic.nb`. For example, "sum" will refer to "sum_reduce_nb".
!!! warning
Each segment or combination of segments in `segment_arr` is assumed to be coherent and non-repeating.
That is, `np.array([0, 1, 0])` for a single column annotates three different segments, not two.
See `vectorbtpro.utils.array_.index_repeating_rows_nb`.
!!! hint
Use `MappedArray.sort` to bring the mapped array to the desired order, if required.
Applies per group of columns if `apply_per_group` is True.
See `vectorbtpro.records.nb.reduce_mapped_segments_nb`.
`**kwargs` are passed to `MappedArray.replace`."""
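# A compact sketch equivalent to the "Resolving conflicts" example in the module
# docstring above; the string "sum" resolves to `vectorbtpro.generic.nb.sum_reduce_nb`:
# >>> ma_no_conf = ma_conf.reduce_segments("idx", "sum")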
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
col_map = self.col_mapper.get_col_map(group_by=group_by if apply_per_group else False)
if isinstance(segment_arr, str):
if segment_arr.lower() == "idx":
segment_arr = idx_arr
else:
raise ValueError(f"Invalid segment_arr: '{segment_arr}'")
if isinstance(segment_arr, tuple):
stacked_segment_arr = column_stack_arrays(segment_arr)
segment_arr = index_repeating_rows_nb(stacked_segment_arr)
if isinstance(reduce_func_nb, str):
reduce_func_nb = getattr(generic_nb, reduce_func_nb + "_reduce_nb")
func = jit_reg.resolve_option(nb.reduce_mapped_segments_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
new_mapped_arr, new_col_arr, new_idx_arr, new_id_arr = func(
self.values,
idx_arr,
self.id_arr,
col_map,
segment_arr,
reduce_func_nb,
*args,
)
new_mapped_arr = np.asarray(new_mapped_arr, dtype=dtype)
return self.replace(
mapped_arr=new_mapped_arr,
col_arr=new_col_arr,
idx_arr=new_idx_arr,
id_arr=new_id_arr,
**kwargs,
).regroup(group_by)
@hybrid_method
def reduce(
cls_or_self,
reduce_func_nb: tp.Union[
tp.ReduceFunc, tp.MappedReduceMetaFunc, tp.ReduceToArrayFunc, tp.MappedReduceToArrayMetaFunc
],
*args,
idx_arr: tp.Optional[tp.Array1d] = None,
returns_array: bool = False,
returns_idx: bool = False,
to_index: bool = True,
fill_value: tp.Scalar = np.nan,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
col_mapper: tp.Optional[ColumnMapper] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeriesFrame:
"""Reduce mapped array by column/group.
Set `returns_array` to True if `reduce_func_nb` returns an array.
Set `returns_idx` to True if `reduce_func_nb` returns row index/position. Must pass `idx_arr`.
Set `to_index` to True to return labels instead of positions.
Use `fill_value` to set the default value.
For implementation details, see
* `vectorbtpro.records.nb.reduce_mapped_nb` if `returns_array` is False and `returns_idx` is False
* `vectorbtpro.records.nb.reduce_mapped_to_idx_nb` if `returns_array` is False and `returns_idx` is True
* `vectorbtpro.records.nb.reduce_mapped_to_array_nb` if `returns_array` is True and `returns_idx` is False
* `vectorbtpro.records.nb.reduce_mapped_to_idx_array_nb` if `returns_array` is True and `returns_idx` is True
For implementation details on the meta versions, see
* `vectorbtpro.records.nb.reduce_mapped_meta_nb` if `returns_array` is False and `returns_idx` is False
* `vectorbtpro.records.nb.reduce_mapped_to_idx_meta_nb` if `returns_array` is False and `returns_idx` is True
* `vectorbtpro.records.nb.reduce_mapped_to_array_meta_nb` if `returns_array` is True and `returns_idx` is False
* `vectorbtpro.records.nb.reduce_mapped_to_idx_array_meta_nb` if `returns_array` is True and `returns_idx` is True
"""
if isinstance(cls_or_self, type):
checks.assert_not_none(col_mapper, arg_name="col_mapper")
col_map = col_mapper.get_col_map(group_by=group_by)
if not returns_array:
if not returns_idx:
func = jit_reg.resolve_option(nb.reduce_mapped_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(col_map, fill_value, reduce_func_nb, *args)
else:
checks.assert_not_none(idx_arr, arg_name="idx_arr")
func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(col_map, idx_arr, fill_value, reduce_func_nb, *args)
else:
if not returns_idx:
func = jit_reg.resolve_option(nb.reduce_mapped_to_array_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(col_map, fill_value, reduce_func_nb, *args)
else:
checks.assert_not_none(idx_arr, arg_name="idx_arr")
func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_array_meta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(col_map, idx_arr, fill_value, reduce_func_nb, *args)
wrapper = col_mapper.wrapper
else:
if idx_arr is None:
if cls_or_self.idx_arr is None:
if returns_idx:
raise ValueError("Must pass idx_arr")
idx_arr = cls_or_self.idx_arr
col_map = cls_or_self.col_mapper.get_col_map(group_by=group_by)
if not returns_array:
if not returns_idx:
func = jit_reg.resolve_option(nb.reduce_mapped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.values, col_map, fill_value, reduce_func_nb, *args)
else:
checks.assert_not_none(idx_arr, arg_name="idx_arr")
func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.values, col_map, idx_arr, fill_value, reduce_func_nb, *args)
else:
if not returns_idx:
func = jit_reg.resolve_option(nb.reduce_mapped_to_array_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.values, col_map, fill_value, reduce_func_nb, *args)
else:
checks.assert_not_none(idx_arr, arg_name="idx_arr")
func = jit_reg.resolve_option(nb.reduce_mapped_to_idx_array_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(cls_or_self.values, col_map, idx_arr, fill_value, reduce_func_nb, *args)
wrapper = cls_or_self.wrapper
wrap_kwargs = merge_dicts(
dict(
name_or_index="reduce" if not returns_array else None,
to_index=returns_idx and to_index,
fillna=-1 if returns_idx else None,
dtype=int_ if returns_idx else None,
),
wrap_kwargs,
)
return wrapper.wrap_reduced(out, group_by=group_by, **wrap_kwargs)
@cached_method
def nth(
self,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return n-th element of each column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="nth"), wrap_kwargs)
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
args=ch.ArgsTaker(
None,
)
),
)
return self.reduce(
jit_reg.resolve_option(generic_nb.nth_reduce_nb, jitted),
n,
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def nth_index(
self,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return index of n-th element of each column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="nth_index"), wrap_kwargs)
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
args=ch.ArgsTaker(
None,
)
),
)
return self.reduce(
jit_reg.resolve_option(generic_nb.nth_index_reduce_nb, jitted),
n,
returns_array=False,
returns_idx=True,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def min(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return min by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="min"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.min_reduce_nb, jitted),
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def max(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return max by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="max"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.max_reduce_nb, jitted),
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def mean(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return mean by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="mean"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.mean_reduce_nb, jitted),
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def median(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return median by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="median"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.median_reduce_nb, jitted),
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def std(
self,
ddof: int = 1,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return std by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="std"), wrap_kwargs)
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
args=ch.ArgsTaker(
None,
)
),
)
return self.reduce(
jit_reg.resolve_option(generic_nb.std_reduce_nb, jitted),
ddof,
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def sum(
self,
fill_value: tp.Scalar = 0.0,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return sum by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="sum"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.sum_reduce_nb, jitted),
fill_value=fill_value,
returns_array=False,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def idxmin(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return index of min by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="idxmin"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.argmin_reduce_nb, jitted),
returns_array=False,
returns_idx=True,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def idxmax(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Return index of max by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="idxmax"), wrap_kwargs)
return self.reduce(
jit_reg.resolve_option(generic_nb.argmax_reduce_nb, jitted),
returns_array=False,
returns_idx=True,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
@cached_method
def describe(
self,
percentiles: tp.Optional[tp.ArrayLike] = None,
ddof: int = 1,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Return statistics by column/group."""
if percentiles is not None:
percentiles = to_1d_array(percentiles)
else:
percentiles = np.array([0.25, 0.5, 0.75])
percentiles = percentiles.tolist()
if 0.5 not in percentiles:
percentiles.append(0.5)
percentiles = np.unique(percentiles)
perc_formatted = pd.io.formats.format.format_percentiles(percentiles)
index = pd.Index(["count", "mean", "std", "min", *perc_formatted, "max"])
wrap_kwargs = merge_dicts(dict(name_or_index=index), wrap_kwargs)
chunked = ch.specialize_chunked_option(chunked, arg_take_spec=dict(args=ch.ArgsTaker(None, None)))
out = self.reduce(
jit_reg.resolve_option(generic_nb.describe_reduce_nb, jitted),
percentiles,
ddof,
returns_array=True,
returns_idx=False,
group_by=group_by,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
if wrap_kwargs.get("to_timedelta", False):
out.drop("count", axis=0, inplace=True)
else:
if isinstance(out, pd.DataFrame):
out.loc["count", np.isnan(out.loc["count"])] = 0.0
else:
if np.isnan(out.loc["count"]):
out.loc["count"] = 0.0
return out
@cached_method
def count(self, group_by: tp.GroupByLike = None, wrap_kwargs: tp.KwargsLike = None) -> tp.MaybeSeries:
"""Return number of values by column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="count"), wrap_kwargs)
return self.wrapper.wrap_reduced(
self.col_mapper.get_col_map(group_by=group_by)[1],
group_by=group_by,
**wrap_kwargs,
)
# ############# Value counts ############# #
@cached_method
def value_counts(
self,
axis: int = 1,
idx_arr: tp.Optional[tp.Array1d] = None,
normalize: bool = False,
sort_uniques: bool = True,
sort: bool = False,
ascending: bool = False,
dropna: bool = False,
group_by: tp.GroupByLike = None,
mapping: tp.Union[None, bool, tp.MappingLike] = None,
incl_all_keys: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""See `vectorbtpro.generic.accessors.GenericAccessor.value_counts`."""
checks.assert_in(axis, (-1, 0, 1))
mapping = self.resolve_mapping(mapping)
mapped_codes, mapped_uniques = pd.factorize(self.values, sort=False, use_na_sentinel=False)
if axis == 0:
if idx_arr is None:
idx_arr = self.idx_arr
checks.assert_not_none(idx_arr, arg_name="idx_arr")
func = jit_reg.resolve_option(nb.mapped_value_counts_per_row_nb, jitted)
value_counts = func(mapped_codes, len(mapped_uniques), idx_arr, self.wrapper.shape[0])
elif axis == 1:
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.mapped_value_counts_per_col_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
value_counts = func(mapped_codes, len(mapped_uniques), col_map)
else:
func = jit_reg.resolve_option(nb.mapped_value_counts_nb, jitted)
value_counts = func(mapped_codes, len(mapped_uniques))
if incl_all_keys and mapping is not None:
missing_keys = []
for x in mapping:
if pd.isnull(x) and pd.isnull(mapped_uniques).any():
continue
if x not in mapped_uniques:
missing_keys.append(x)
if axis == 0 or axis == 1:
value_counts = np.vstack((value_counts, np.full((len(missing_keys), value_counts.shape[1]), 0)))
else:
value_counts = concat_arrays((value_counts, np.full(len(missing_keys), 0)))
mapped_uniques = concat_arrays((mapped_uniques, np.array(missing_keys)))
nan_mask = np.isnan(mapped_uniques)
if dropna:
value_counts = value_counts[~nan_mask]
mapped_uniques = mapped_uniques[~nan_mask]
if sort_uniques:
new_indices = mapped_uniques.argsort()
value_counts = value_counts[new_indices]
mapped_uniques = mapped_uniques[new_indices]
if axis == 0 or axis == 1:
value_counts_sum = value_counts.sum(axis=1)
else:
value_counts_sum = value_counts
if normalize:
value_counts = value_counts / value_counts_sum.sum()
if sort:
if ascending:
new_indices = value_counts_sum.argsort()
else:
new_indices = (-value_counts_sum).argsort()
value_counts = value_counts[new_indices]
mapped_uniques = mapped_uniques[new_indices]
if axis == 0:
wrapper = ArrayWrapper.from_obj(value_counts)
value_counts_pd = wrapper.wrap(
value_counts,
index=mapped_uniques,
columns=self.wrapper.index,
**resolve_dict(wrap_kwargs),
)
elif axis == 1:
value_counts_pd = self.wrapper.wrap(
value_counts,
index=mapped_uniques,
group_by=group_by,
**resolve_dict(wrap_kwargs),
)
else:
wrapper = ArrayWrapper.from_obj(value_counts)
value_counts_pd = wrapper.wrap(
value_counts,
index=mapped_uniques,
**merge_dicts(dict(columns=["value_counts"]), wrap_kwargs),
)
if mapping is not None:
value_counts_pd.index = apply_mapping(value_counts_pd.index, mapping, **kwargs)
return value_counts_pd
# ############# Conflicts ############# #
@cached_method
def has_conflicts(
self,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
) -> bool:
"""See `vectorbtpro.records.nb.mapped_has_conflicts_nb`."""
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
col_arr = self.col_mapper.get_col_arr(group_by=group_by)
target_shape = self.wrapper.get_shape_2d(group_by=group_by)
func = jit_reg.resolve_option(nb.mapped_has_conflicts_nb, jitted)
return func(col_arr, idx_arr, target_shape)
def coverage_map(
self,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.records.nb.mapped_coverage_map_nb`."""
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
col_arr = self.col_mapper.get_col_arr(group_by=group_by)
target_shape = self.wrapper.get_shape_2d(group_by=group_by)
func = jit_reg.resolve_option(nb.mapped_coverage_map_nb, jitted)
out = func(col_arr, idx_arr, target_shape)
return self.wrapper.wrap(out, group_by=group_by, **resolve_dict(wrap_kwargs))
# ############# Unstacking ############# #
def to_pd(
self,
idx_arr: tp.Optional[tp.Array1d] = None,
reduce_func_nb: tp.Union[None, str, tp.ReduceFunc] = None,
reduce_args: tp.ArgsLike = None,
dtype: tp.Optional[tp.DTypeLike] = None,
ignore_index: bool = False,
repeat_index: bool = False,
fill_value: float = np.nan,
mapping: tp.Union[None, bool, tp.MappingLike] = False,
mapping_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
silence_warnings: bool = False,
) -> tp.SeriesFrame:
"""Unstack mapped array to a Series/DataFrame.
If `reduce_func_nb` is not None, will use it to reduce conflicting index segments
using `MappedArray.reduce_segments`.
* If `ignore_index`, will ignore the index and place values on top of each other in every column/group.
See `vectorbtpro.records.nb.ignore_unstack_mapped_nb`.
* If `repeat_index`, will repeat any index pointed from multiple values.
Otherwise, in case of positional conflicts, will throw a warning and use the latest value.
See `vectorbtpro.records.nb.repeat_unstack_mapped_nb`.
* Otherwise, see `vectorbtpro.records.nb.unstack_mapped_nb`.
!!! note
If multiple values point to the same position and neither reducing nor repeating is requested,
only the latest value is used and a warning is issued. Set `ignore_index` to True to avoid
positional conflicts altogether.
!!! warning
Mapped arrays represent information in the most memory-friendly format.
Mapping back to pandas may occupy lots of memory if records are sparse."""
if ignore_index:
if self.wrapper.ndim == 1:
return self.wrapper.wrap(
self.values,
index=np.arange(len(self.values)),
group_by=group_by,
**resolve_dict(wrap_kwargs),
)
col_map = self.col_mapper.get_col_map(group_by=group_by)
func = jit_reg.resolve_option(nb.ignore_unstack_mapped_nb, jitted)
out = func(self.values, col_map, fill_value)
mapping = self.resolve_mapping(mapping)
out = apply_mapping(out, mapping, **resolve_dict(mapping_kwargs))
return self.wrapper.wrap(out, index=np.arange(out.shape[0]), group_by=group_by, **resolve_dict(wrap_kwargs))
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
has_conflicts = self.has_conflicts(idx_arr=idx_arr, group_by=group_by)
if has_conflicts and repeat_index:
col_arr = self.col_mapper.get_col_arr(group_by=group_by)
target_shape = self.wrapper.get_shape_2d(group_by=group_by)
func = jit_reg.resolve_option(nb.mapped_coverage_map_nb, jitted)
coverage_map = func(col_arr, idx_arr, target_shape)
repeat_cnt_arr = np.max(coverage_map, axis=1)
func = jit_reg.resolve_option(nb.unstack_index_nb, jitted)
unstacked_index = self.wrapper.index[func(repeat_cnt_arr)]
func = jit_reg.resolve_option(nb.repeat_unstack_mapped_nb, jitted)
out = func(self.values, col_arr, idx_arr, repeat_cnt_arr, target_shape[1], fill_value)
mapping = self.resolve_mapping(mapping)
out = apply_mapping(out, mapping, **resolve_dict(mapping_kwargs))
wrap_kwargs = merge_dicts(dict(index=unstacked_index), wrap_kwargs)
return self.wrapper.wrap(out, group_by=group_by, **wrap_kwargs)
else:
if has_conflicts:
if reduce_func_nb is not None:
if reduce_args is None:
reduce_args = ()
self_ = self.reduce_segments(
"idx",
reduce_func_nb,
*reduce_args,
idx_arr=idx_arr,
group_by=group_by,
dtype=dtype,
jitted=jitted,
chunked=chunked,
)
idx_arr = self_.idx_arr
else:
if not silence_warnings:
warn("Multiple values are pointing to the same position. Only the latest value is used.")
self_ = self
else:
self_ = self
col_arr = self_.col_mapper.get_col_arr(group_by=group_by)
target_shape = self_.wrapper.get_shape_2d(group_by=group_by)
func = jit_reg.resolve_option(nb.unstack_mapped_nb, jitted)
out = func(self_.values, col_arr, idx_arr, target_shape, fill_value)
mapping = self_.resolve_mapping(mapping)
out = apply_mapping(out, mapping, **resolve_dict(mapping_kwargs))
return self_.wrapper.wrap(out, group_by=group_by, **resolve_dict(wrap_kwargs))
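# Illustrative sketch: how positional conflicts behave conceptually when unstacking, mimicked
# with plain NumPy rather than the library API. All names and numbers below are made up
# for demonstration; the real work is done by the Numba-compiled unstacking functions.
def _example_to_pd_concept():
    import numpy as np
    values = np.array([1.0, 2.0, 3.0])
    col_arr = np.array([0, 0, 0])
    idx_arr = np.array([0, 0, 1])  # the first two values conflict at row 0
    # Positional unstacking: later assignments overwrite earlier ones, so the latest value remains
    out = np.full((2, 1), np.nan)
    out[idx_arr, col_arr] = values  # row 0 ends up with 2.0, row 1 with 3.0
    # Ignoring the index: values are simply stacked on top of each other per column
    stacked = values.reshape(-1, 1)
    return out, stacked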
# ############# Masking ############# #
def get_pd_mask(
self,
idx_arr: tp.Optional[tp.Array1d] = None,
group_by: tp.GroupByLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Get mask in form of a Series/DataFrame from row and column indices."""
if idx_arr is None:
if self.idx_arr is None:
raise ValueError("Must pass idx_arr")
idx_arr = self.idx_arr
col_arr = self.col_mapper.get_col_arr(group_by=group_by)
target_shape = self.wrapper.get_shape_2d(group_by=group_by)
out_arr = np.full(target_shape, False)
out_arr[idx_arr, col_arr] = True
return self.wrapper.wrap(out_arr, group_by=group_by, **resolve_dict(wrap_kwargs))
@property
def pd_mask(self) -> tp.SeriesFrame:
"""`MappedArray.get_pd_mask` with default arguments."""
return self.get_pd_mask()
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `MappedArray.stats`.
Merges `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats_defaults` and
`stats` from `vectorbtpro._settings.mapped_array`."""
from vectorbtpro._settings import settings
mapped_array_stats_cfg = settings["mapped_array"]["stats"]
return merge_dicts(Analyzable.stats_defaults.__get__(self), mapped_array_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
count=dict(title="Count", calc_func="count", tags="mapped_array"),
mean=dict(title="Mean", calc_func="mean", inv_check_has_mapping=True, tags=["mapped_array", "describe"]),
std=dict(title="Std", calc_func="std", inv_check_has_mapping=True, tags=["mapped_array", "describe"]),
min=dict(title="Min", calc_func="min", inv_check_has_mapping=True, tags=["mapped_array", "describe"]),
median=dict(
title="Median",
calc_func="median",
inv_check_has_mapping=True,
tags=["mapped_array", "describe"],
),
max=dict(title="Max", calc_func="max", inv_check_has_mapping=True, tags=["mapped_array", "describe"]),
idx_min=dict(
title="Min Index",
calc_func="idxmin",
inv_check_has_mapping=True,
agg_func=None,
tags=["mapped_array", "index"],
),
idx_max=dict(
title="Max Index",
calc_func="idxmax",
inv_check_has_mapping=True,
agg_func=None,
tags=["mapped_array", "index"],
),
value_counts=dict(
title="Value Counts",
calc_func=lambda value_counts: to_dict(value_counts, orient="index_series"),
resolve_value_counts=True,
check_has_mapping=True,
tags=["mapped_array", "value_counts"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def histplot(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.BaseFigure:
"""Plot histogram by column/group."""
return self.to_pd(group_by=group_by, ignore_index=True).vbt.histplot(**kwargs)
def boxplot(self, group_by: tp.GroupByLike = None, **kwargs) -> tp.BaseFigure:
"""Plot box plot by column/group."""
return self.to_pd(group_by=group_by, ignore_index=True).vbt.boxplot(**kwargs)
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `MappedArray.plots`.
Merges `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots_defaults` and
`plots` from `vectorbtpro._settings.mapped_array`."""
from vectorbtpro._settings import settings
mapped_array_plots_cfg = settings["mapped_array"]["plots"]
return merge_dicts(Analyzable.plots_defaults.__get__(self), mapped_array_plots_cfg)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
to_pd_plot=dict(
check_is_not_grouped=True,
plot_func="to_pd.vbt.plot",
pass_trace_names=False,
tags="mapped_array",
)
)
)
@property
def subplots(self) -> Config:
return self._subplots
__pdoc__ = dict()
MappedArray.override_metrics_doc(__pdoc__)
MappedArray.override_subplots_doc(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for records and mapped arrays.
Provides an arsenal of Numba-compiled functions that only accept NumPy arrays and other Numba-compatible types.
!!! note
All functions passed as argument must be Numba-compiled.
Records must retain the order they were created in."""
import numpy as np
from numba import prange
from numba.extending import overload
from numba.np.numpy_support import as_dtype
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.utils import chunking as ch
__all__ = []
# ############# Generation ############# #
@register_jitted(cache=True)
def generate_ids_nb(col_arr: tp.Array1d, n_cols: int) -> tp.Array1d:
"""Generate the monotonically increasing id array based on the column index array."""
col_idxs = np.full(n_cols, 0, dtype=int_)
out = np.empty_like(col_arr)
for c in range(len(col_arr)):
out[c] = col_idxs[col_arr[c]]
col_idxs[col_arr[c]] += 1
return out
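# Illustrative sketch: the per-column ids produced by `generate_ids_nb` for a small,
# made-up column array; the expected output was derived by hand from the function above.
def _example_generate_ids():
    import numpy as np
    col_arr = np.array([0, 1, 0, 1, 1])
    ids = generate_ids_nb(col_arr, 2)
    # ids count up within each column: array([0, 0, 1, 1, 2])
    return ids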
# ############# Indexing ############# #
@register_jitted(cache=True)
def col_lens_nb(col_arr: tp.Array1d, n_cols: int) -> tp.GroupLens:
"""Get column lengths from sorted column array.
!!! note
Requires `col_arr` to be in ascending order. This can be done by sorting."""
col_lens = np.full(n_cols, 0, dtype=int_)
last_col = -1
for c in range(col_arr.shape[0]):
col = col_arr[c]
if col < last_col:
raise ValueError("col_arr must come in ascending order")
last_col = col
col_lens[col] += 1
return col_lens
@register_jitted(cache=True)
def record_col_lens_select_nb(
records: tp.RecordArray,
col_lens: tp.GroupLens,
new_cols: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
"""Perform indexing on sorted records using column lengths.
Returns new records."""
col_end_idxs = np.cumsum(col_lens)
col_start_idxs = col_end_idxs - col_lens
n_values = np.sum(col_lens[new_cols])
indices_out = np.empty(n_values, dtype=int_)
records_arr_out = np.empty(n_values, dtype=records.dtype)
j = 0
for c in range(new_cols.shape[0]):
from_r = col_start_idxs[new_cols[c]]
to_r = col_end_idxs[new_cols[c]]
if from_r == to_r:
continue
col_records = np.copy(records[from_r:to_r])
col_records["col"][:] = c # don't forget to assign new column indices
rang = np.arange(from_r, to_r)
indices_out[j : j + rang.shape[0]] = rang
records_arr_out[j : j + rang.shape[0]] = col_records
j += col_records.shape[0]
return indices_out, records_arr_out
@register_jitted(cache=True)
def col_map_nb(col_arr: tp.Array1d, n_cols: int) -> tp.GroupMap:
"""Build a map between columns and value indices.
Returns an array with indices segmented by column and an array with column lengths.
Works well for unsorted column arrays."""
col_lens_out = np.full(n_cols, 0, dtype=int_)
for c in range(col_arr.shape[0]):
col = col_arr[c]
col_lens_out[col] += 1
col_start_idxs = np.cumsum(col_lens_out) - col_lens_out
col_idxs_out = np.empty((col_arr.shape[0],), dtype=int_)
col_i = np.full(n_cols, 0, dtype=int_)
for c in range(col_arr.shape[0]):
col = col_arr[c]
col_idxs_out[col_start_idxs[col] + col_i[col]] = c
col_i[col] += 1
return col_idxs_out, col_lens_out
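# Illustrative sketch: the column map produced by `col_map_nb` for a small, unsorted
# column array; values are made up for demonstration.
def _example_col_map():
    import numpy as np
    col_arr = np.array([1, 0, 1, 2])
    col_idxs, col_lens = col_map_nb(col_arr, 3)
    # col_lens -> array([1, 2, 1]): one value in column 0, two in column 1, one in column 2
    # col_idxs -> array([1, 0, 2, 3]): value positions segmented by column, column 0 first
    return col_idxs, col_lens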
@register_jitted(cache=True)
def record_col_map_select_nb(
records: tp.RecordArray,
col_map: tp.GroupMap,
new_cols: tp.Array1d,
) -> tp.Tuple[tp.Array1d, tp.RecordArray]:
"""Same as `record_col_lens_select_nb` but using column map `col_map`."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
total_count = np.sum(col_lens[new_cols])
indices_out = np.empty(total_count, dtype=int_)
records_arr_out = np.empty(total_count, dtype=records.dtype)
j = 0
for new_col_i in range(len(new_cols)):
new_col = new_cols[new_col_i]
col_len = col_lens[new_col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[new_col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_records = np.copy(records[idxs])
col_records["col"][:] = new_col_i
indices_out[j : j + col_len] = idxs
records_arr_out[j : j + col_len] = col_records
j += col_len
return indices_out, records_arr_out
# ############# Sorting ############# #
@register_jitted(cache=True)
def is_col_sorted_nb(col_arr: tp.Array1d) -> bool:
"""Check whether the column array is sorted."""
for i in range(len(col_arr) - 1):
if col_arr[i + 1] < col_arr[i]:
return False
return True
@register_jitted(cache=True)
def is_col_id_sorted_nb(col_arr: tp.Array1d, id_arr: tp.Array1d) -> bool:
"""Check whether the column and id arrays are sorted."""
for i in range(len(col_arr) - 1):
if col_arr[i + 1] < col_arr[i]:
return False
if col_arr[i + 1] == col_arr[i] and id_arr[i + 1] < id_arr[i]:
return False
return True
# ############# Filtering ############# #
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
col_map=base_ch.GroupMapSlicer(),
n=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def first_n_nb(col_map: tp.GroupMap, n: int) -> tp.Array1d:
"""Returns the mask of the first N elements."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_idxs.shape[0], False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs[:n]] = True
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
col_map=base_ch.GroupMapSlicer(),
n=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def last_n_nb(col_map: tp.GroupMap, n: int) -> tp.Array1d:
"""Returns the mask of the last N elements."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_idxs.shape[0], False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs[-n:]] = True
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
col_map=base_ch.GroupMapSlicer(),
n=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def random_n_nb(col_map: tp.GroupMap, n: int) -> tp.Array1d:
"""Returns the mask of random N elements."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_idxs.shape[0], False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[np.random.choice(idxs, n, replace=False)] = True
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
n=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def top_n_mapped_nb(mapped_arr: tp.Array1d, col_map: tp.GroupMap, n: int) -> tp.Array1d:
"""Returns the mask of the top N mapped elements."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(mapped_arr.shape[0], False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs[np.argsort(mapped_arr[idxs])[-n:]]] = True
return out
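# Illustrative sketch: selecting the single top value per column with `top_n_mapped_nb`.
# The column map is built with `col_map_nb` defined above; the numbers are made up.
def _example_top_n():
    import numpy as np
    mapped_arr = np.array([3.0, 1.0, 2.0, 5.0, 4.0])
    col_arr = np.array([0, 0, 0, 1, 1])
    col_map = col_map_nb(col_arr, 2)
    mask = top_n_mapped_nb(mapped_arr, col_map, 1)
    # mask -> array([ True, False, False,  True, False]): 3.0 tops column 0, 5.0 tops column 1
    return mask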
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
n=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def bottom_n_mapped_nb(mapped_arr: tp.Array1d, col_map: tp.GroupMap, n: int) -> tp.Array1d:
"""Returns the mask of the bottom N mapped elements."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(mapped_arr.shape[0], False, dtype=np.bool_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs[np.argsort(mapped_arr[idxs])[:n]]] = True
return out
# ############# Mapping ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="records", axis=0),
arg_take_spec=dict(records=ch.ArraySlicer(axis=0), map_func_nb=None, args=ch.ArgsTaker()),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def map_records_nb(records: tp.RecordArray, map_func_nb: tp.RecordsMapFunc, *args) -> tp.Array1d:
"""Map each record to a single value.
`map_func_nb` must accept a single record and `*args`. Must return a single value."""
out = np.empty(records.shape[0], dtype=float_)
for ridx in prange(records.shape[0]):
out[ridx] = map_func_nb(records[ridx], *args)
return out
@register_chunkable(
size=ch.ArgSizer(arg_query="n_values"),
arg_take_spec=dict(n_values=ch.CountAdapter(), map_func_nb=None, args=ch.ArgsTaker()),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def map_records_meta_nb(n_values: int, map_func_nb: tp.MappedReduceMetaFunc, *args) -> tp.Array1d:
"""Meta version of `map_records_nb`.
`map_func_nb` must accept the record index and `*args`. Must return a single value."""
out = np.empty(n_values, dtype=float_)
for ridx in prange(n_values):
out[ridx] = map_func_nb(ridx, *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_nb(arr: tp.Array1d, col_map: tp.GroupMap, apply_func_nb: tp.ApplyFunc, *args) -> tp.Array1d:
"""Apply function on mapped array or records per column.
Returns the same shape as `arr`.
`apply_func_nb` must accept the values of the column and `*args`. Must return an array."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.empty(arr.shape[0], dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs] = apply_func_nb(arr[idxs], *args)
return out
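# Illustrative sketch: demeaning values per column with `apply_nb`. The applied function
# must itself be Numba-compiled, hence the local `njit`; inputs are made up for demonstration.
def _example_apply():
    import numpy as np
    from numba import njit

    @njit
    def demean_nb(col_values):
        return col_values - np.mean(col_values)

    arr = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
    col_map = col_map_nb(np.array([0, 0, 0, 1, 1]), 2)
    out = apply_nb(arr, col_map, demean_nb)
    # out -> array([-1., 0., 1., -5., 5.]): each value minus its column mean, same shape as arr
    return out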
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
n_values=ch.CountAdapter(mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
apply_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def apply_meta_nb(n_values: int, col_map: tp.GroupMap, apply_func_nb: tp.ApplyMetaFunc, *args) -> tp.Array1d:
"""Meta version of `apply_nb`.
`apply_func_nb` must accept the indices, the column index, and `*args`. Must return an array."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.empty(n_values, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[idxs] = apply_func_nb(idxs, col, *args)
return out
# ############# Reducing ############# #
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
id_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
segment_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted
def reduce_mapped_segments_nb(
mapped_arr: tp.Array1d,
idx_arr: tp.Array1d,
id_arr: tp.Array1d,
col_map: tp.GroupMap,
segment_arr: tp.Array1d,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Reduce each segment of values in mapped array.
Uses the last column, index, and id of each segment for the new value.
`reduce_func_nb` must accept the values in the segment and `*args`. Must return a single value.
!!! note
Segments must come in ascending order per column, and `idx_arr` and `id_arr`
must come in ascending order per segment of values."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.empty(len(mapped_arr), dtype=mapped_arr.dtype)
col_arr_out = np.empty(len(mapped_arr), dtype=int_)
idx_arr_out = np.empty(len(mapped_arr), dtype=int_)
id_arr_out = np.empty(len(mapped_arr), dtype=int_)
k = 0
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
segment_start_i = 0
for i in range(len(idxs)):
r = idxs[i]
if i == 0:
prev_r = -1
else:
prev_r = idxs[i - 1]
if i < len(idxs) - 1:
next_r = idxs[i + 1]
else:
next_r = -1
if prev_r != -1:
if segment_arr[r] < segment_arr[prev_r]:
raise ValueError("segment_arr must come in ascending order per column")
elif segment_arr[r] == segment_arr[prev_r]:
if idx_arr[r] < idx_arr[prev_r]:
raise ValueError("idx_arr must come in ascending order per segment")
if id_arr[r] < id_arr[prev_r]:
raise ValueError("id_arr must come in ascending order per segment")
else:
segment_start_i = i
if next_r == -1 or segment_arr[r] != segment_arr[next_r]:
n_values = i - segment_start_i + 1
if n_values > 1:
out[k] = reduce_func_nb(mapped_arr[idxs[segment_start_i : i + 1]], *args)
else:
out[k] = mapped_arr[r]
col_arr_out[k] = col
idx_arr_out[k] = idx_arr[r]
id_arr_out[k] = id_arr[r]
k += 1
return out[:k], col_arr_out[:k], idx_arr_out[:k], id_arr_out[:k]
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_nb(
mapped_arr: tp.Array1d,
col_map: tp.GroupMap,
fill_value: float,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Reduce mapped array by column to a single value.
Faster than `unstack_mapped_nb` and `vbt.*` used together, and also
requires less memory. But does not take advantage of caching.
`reduce_func_nb` must accept the mapped array and `*args`.
Must return a single value."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_lens.shape[0], fill_value, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[col] = reduce_func_nb(mapped_arr[idxs], *args)
return out
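# Illustrative sketch: reducing each column of a mapped array to its mean with
# `reduce_mapped_nb`. The reducer must be Numba-compiled; inputs are made up for demonstration.
def _example_reduce_mapped():
    import numpy as np
    from numba import njit

    @njit
    def mean_reduce_nb(col_values):
        return np.mean(col_values)

    mapped_arr = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
    col_map = col_map_nb(np.array([0, 0, 0, 1, 1]), 2)
    out = reduce_mapped_nb(mapped_arr, col_map, np.nan, mean_reduce_nb)
    # out -> array([ 2., 15.]): one value per column; empty columns would keep the fill value
    return out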
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(col_map=base_ch.GroupMapSlicer(), fill_value=None, reduce_func_nb=None, args=ch.ArgsTaker()),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_meta_nb(
col_map: tp.GroupMap,
fill_value: float,
reduce_func_nb: tp.MappedReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `reduce_mapped_nb`.
`reduce_func_nb` must accept the mapped indices, the column index, and `*args`.
Must return a single value."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_lens.shape[0], fill_value, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[col] = reduce_func_nb(idxs, col, *args)
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_nb(
mapped_arr: tp.Array1d,
col_map: tp.GroupMap,
idx_arr: tp.Array1d,
fill_value: float,
reduce_func_nb: tp.ReduceFunc,
*args,
) -> tp.Array1d:
"""Reduce mapped array by column to an index.
Same as `reduce_mapped_nb` except `idx_arr` must be passed.
!!! note
Must return integers or raise an exception."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_lens.shape[0], fill_value, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(mapped_arr[idxs], *args)
out[col] = idx_arr[idxs][col_out]
return out
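# Illustrative sketch: mapping the position of each column's maximum back to row indices
# with `reduce_mapped_to_idx_nb`. The reducer must return an integer position within the
# column's values; inputs are made up for demonstration.
def _example_reduce_to_idx():
    import numpy as np
    from numba import njit

    @njit
    def argmax_reduce_nb(col_values):
        return np.argmax(col_values)

    mapped_arr = np.array([1.0, 3.0, 2.0, 20.0, 10.0])
    idx_arr = np.array([0, 1, 2, 0, 1])
    col_map = col_map_nb(np.array([0, 0, 0, 1, 1]), 2)
    out = reduce_mapped_to_idx_nb(mapped_arr, col_map, idx_arr, np.nan, argmax_reduce_nb)
    # out -> array([1., 0.]): row index of the max in column 0 is 1, in column 1 it is 0
    return out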
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
col_map=base_ch.GroupMapSlicer(),
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="concat",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_meta_nb(
col_map: tp.GroupMap,
idx_arr: tp.Array1d,
fill_value: float,
reduce_func_nb: tp.MappedReduceMetaFunc,
*args,
) -> tp.Array1d:
"""Meta version of `reduce_mapped_to_idx_nb`.
`reduce_func_nb` is the same as in `reduce_mapped_meta_nb`."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full(col_lens.shape[0], fill_value, dtype=float_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(idxs, col, *args)
out[col] = idx_arr[idxs][col_out]
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_array_nb(
mapped_arr: tp.Array1d,
col_map: tp.GroupMap,
fill_value: float,
reduce_func_nb: tp.ReduceToArrayFunc,
*args,
) -> tp.Array2d:
"""Reduce mapped array by column to an array.
`reduce_func_nb` same as for `reduce_mapped_nb` but must return an array."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len > 0:
col_start_idx = col_start_idxs[col]
col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
break
col_0_out = reduce_func_nb(mapped_arr[midxs0], *args)
out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
for i in range(col_0_out.shape[0]):
out[i, col0] = col_0_out[i]
for col in prange(col0 + 1, col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(mapped_arr[idxs], *args)
for i in range(col_out.shape[0]):
out[i, col] = col_out[i]
return out
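# Illustrative sketch: reducing each column to an array (here its min and max) with
# `reduce_mapped_to_array_nb`, yielding one output row per returned element; inputs are made up.
def _example_reduce_to_array():
    import numpy as np
    from numba import njit

    @njit
    def min_max_reduce_nb(col_values):
        out = np.empty(2)
        out[0] = np.min(col_values)
        out[1] = np.max(col_values)
        return out

    mapped_arr = np.array([1.0, 3.0, 2.0, 20.0, 10.0])
    col_map = col_map_nb(np.array([0, 0, 0, 1, 1]), 2)
    out = reduce_mapped_to_array_nb(mapped_arr, col_map, np.nan, min_max_reduce_nb)
    # out -> array([[ 1., 10.], [ 3., 20.]]): first row holds per-column minima, second the maxima
    return out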
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(col_map=base_ch.GroupMapSlicer(), fill_value=None, reduce_func_nb=None, args=ch.ArgsTaker()),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_array_meta_nb(
col_map: tp.GroupMap,
fill_value: float,
reduce_func_nb: tp.MappedReduceToArrayMetaFunc,
*args,
) -> tp.Array2d:
"""Meta version of `reduce_mapped_to_array_nb`.
`reduce_func_nb` is the same as in `reduce_mapped_meta_nb`."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len > 0:
col_start_idx = col_start_idxs[col]
col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
break
col_0_out = reduce_func_nb(midxs0, col0, *args)
out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
for i in range(col_0_out.shape[0]):
out[i, col0] = col_0_out[i]
for col in prange(col0 + 1, col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(idxs, col, *args)
for i in range(col_out.shape[0]):
out[i, col] = col_out[i]
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
mapped_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
col_map=base_ch.GroupMapSlicer(),
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_array_nb(
mapped_arr: tp.Array1d,
col_map: tp.GroupMap,
idx_arr: tp.Array1d,
fill_value: float,
reduce_func_nb: tp.ReduceToArrayFunc,
*args,
) -> tp.Array2d:
"""Reduce mapped array by column to an index array.
Same as `reduce_mapped_to_array_nb` except `idx_arr` must be passed.
!!! note
Must return integers or raise an exception."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len > 0:
col_start_idx = col_start_idxs[col]
col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
break
col_0_out = reduce_func_nb(mapped_arr[midxs0], *args)
out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
for i in range(col_0_out.shape[0]):
out[i, col0] = idx_arr[midxs0[col_0_out[i]]]
for col in prange(col0 + 1, col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(mapped_arr[idxs], *args)
for i in range(col_out.shape[0]):
out[i, col] = idx_arr[idxs[col_out[i]]]
return out
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
col_map=base_ch.GroupMapSlicer(),
idx_arr=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
fill_value=None,
reduce_func_nb=None,
args=ch.ArgsTaker(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def reduce_mapped_to_idx_array_meta_nb(
col_map: tp.GroupMap,
idx_arr: tp.Array1d,
fill_value: float,
reduce_func_nb: tp.MappedReduceToArrayMetaFunc,
*args,
) -> tp.Array2d:
"""Meta version of `reduce_mapped_to_idx_array_nb`.
`reduce_func_nb` is the same as in `reduce_mapped_meta_nb`."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len > 0:
col_start_idx = col_start_idxs[col]
col0, midxs0 = col, col_idxs[col_start_idx : col_start_idx + col_len]
break
col_0_out = reduce_func_nb(midxs0, col0, *args)
out = np.full((col_0_out.shape[0], col_lens.shape[0]), fill_value, dtype=float_)
for i in range(col_0_out.shape[0]):
out[i, col0] = idx_arr[midxs0[col_0_out[i]]]
for col in prange(col0 + 1, col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
col_out = reduce_func_nb(idxs, col, *args)
for i in range(col_out.shape[0]):
out[i, col] = idx_arr[idxs[col_out[i]]]
return out
# ############# Value counts ############# #
@register_chunkable(
size=base_ch.GroupLensSizer(arg_query="col_map"),
arg_take_spec=dict(
codes=ch.ArraySlicer(axis=0, mapper=records_ch.col_idxs_mapper),
n_uniques=None,
col_map=base_ch.GroupMapSlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mapped_value_counts_per_col_nb(codes: tp.Array1d, n_uniques: int, col_map: tp.GroupMap) -> tp.Array2d:
"""Get value counts per column/group of an already factorized mapped array."""
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full((n_uniques, col_lens.shape[0]), 0, dtype=int_)
for col in prange(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[:, col] = generic_nb.value_counts_1d_nb(codes[idxs], n_uniques)
return out
@register_jitted(cache=True)
def mapped_value_counts_per_row_nb(
mapped_arr: tp.Array1d,
n_uniques: int,
idx_arr: tp.Array1d,
n_rows: int,
) -> tp.Array2d:
"""Get value counts per row of an already factorized mapped array."""
out = np.full((n_uniques, n_rows), 0, dtype=int_)
for c in range(mapped_arr.shape[0]):
out[mapped_arr[c], idx_arr[c]] += 1
return out
@register_jitted(cache=True)
def mapped_value_counts_nb(mapped_arr: tp.Array1d, n_uniques: int) -> tp.Array2d:
"""Get value counts globally of an already factorized mapped array."""
out = np.full(n_uniques, 0, dtype=int_)
for c in range(mapped_arr.shape[0]):
out[mapped_arr[c]] += 1
return out
# ############# Coverage ############# #
@register_jitted(cache=True)
def mapped_has_conflicts_nb(col_arr: tp.Array1d, idx_arr: tp.Array1d, target_shape: tp.Shape) -> bool:
"""Check whether mapped array has positional conflicts."""
temp = np.zeros(target_shape)
for i in range(len(col_arr)):
if temp[idx_arr[i], col_arr[i]] > 0:
return True
temp[idx_arr[i], col_arr[i]] = 1
return False
@register_jitted(cache=True)
def mapped_coverage_map_nb(col_arr: tp.Array1d, idx_arr: tp.Array1d, target_shape: tp.Shape) -> tp.Array2d:
"""Get the coverage map of a mapped array.
Each element holds the number of times its position is referenced (i.e., the number of duplicate pairs in `col_arr` and `idx_arr`).
A count greater than one indicates a positional conflict."""
out = np.zeros(target_shape, dtype=int_)
for i in range(len(col_arr)):
out[idx_arr[i], col_arr[i]] += 1
return out
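# Illustrative sketch: detecting positional conflicts. Two values that share the same
# (row, column) pair produce a coverage count above one; inputs are made up for demonstration.
def _example_coverage():
    import numpy as np
    col_arr = np.array([0, 0, 1])
    idx_arr = np.array([0, 0, 1])
    target_shape = (2, 2)
    has_conflicts = mapped_has_conflicts_nb(col_arr, idx_arr, target_shape)  # -> True
    coverage = mapped_coverage_map_nb(col_arr, idx_arr, target_shape)
    # coverage -> array([[2, 0], [0, 1]]): position (0, 0) is referenced twice
    return has_conflicts, coverage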
# ############# Unstacking ############# #
def _unstack_mapped_nb(
mapped_arr,
col_arr,
idx_arr,
target_shape,
fill_value,
):
nb_enabled = not isinstance(mapped_arr, np.ndarray)
if nb_enabled:
mapped_arr_dtype = as_dtype(mapped_arr.dtype)
fill_value_dtype = as_dtype(fill_value)
else:
mapped_arr_dtype = mapped_arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(mapped_arr_dtype, fill_value_dtype)
def impl(mapped_arr, col_arr, idx_arr, target_shape, fill_value):
out = np.full(target_shape, fill_value, dtype=dtype)
for r in range(mapped_arr.shape[0]):
out[idx_arr[r], col_arr[r]] = mapped_arr[r]
return out
if not nb_enabled:
return impl(mapped_arr, col_arr, idx_arr, target_shape, fill_value)
return impl
overload(_unstack_mapped_nb)(_unstack_mapped_nb)
@register_jitted(cache=True)
def unstack_mapped_nb(
mapped_arr: tp.Array1d,
col_arr: tp.Array1d,
idx_arr: tp.Array1d,
target_shape: tp.Shape,
fill_value: float,
) -> tp.Array2d:
"""Unstack mapped array using index data."""
return _unstack_mapped_nb(mapped_arr, col_arr, idx_arr, target_shape, fill_value)
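# Illustrative sketch: unstacking a mapped array into a dense two-dimensional array with
# `unstack_mapped_nb`; unreferenced positions keep the fill value. Inputs are made up.
def _example_unstack():
    import numpy as np
    mapped_arr = np.array([1.0, 2.0, 3.0])
    col_arr = np.array([0, 0, 1])
    idx_arr = np.array([0, 2, 1])
    out = unstack_mapped_nb(mapped_arr, col_arr, idx_arr, (3, 2), np.nan)
    # out -> array([[ 1., nan], [nan,  3.], [ 2., nan]])
    return out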
def _ignore_unstack_mapped_nb(mapped_arr, col_map, fill_value):
nb_enabled = not isinstance(mapped_arr, np.ndarray)
if nb_enabled:
mapped_arr_dtype = as_dtype(mapped_arr.dtype)
fill_value_dtype = as_dtype(fill_value)
else:
mapped_arr_dtype = mapped_arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(mapped_arr_dtype, fill_value_dtype)
def impl(mapped_arr, col_map, fill_value):
col_idxs, col_lens = col_map
col_start_idxs = np.cumsum(col_lens) - col_lens
out = np.full((np.max(col_lens), col_lens.shape[0]), fill_value, dtype=dtype)
for col in range(col_lens.shape[0]):
col_len = col_lens[col]
if col_len == 0:
continue
col_start_idx = col_start_idxs[col]
idxs = col_idxs[col_start_idx : col_start_idx + col_len]
out[:col_len, col] = mapped_arr[idxs]
return out
if not nb_enabled:
return impl(mapped_arr, col_map, fill_value)
return impl
overload(_ignore_unstack_mapped_nb)(_ignore_unstack_mapped_nb)
@register_jitted(cache=True)
def ignore_unstack_mapped_nb(mapped_arr: tp.Array1d, col_map: tp.GroupMap, fill_value: float) -> tp.Array2d:
"""Unstack mapped array by ignoring index data."""
return _ignore_unstack_mapped_nb(mapped_arr, col_map, fill_value)
@register_jitted(cache=True)
def unstack_index_nb(repeat_cnt_arr: tp.Array1d) -> tp.Array1d:
"""Unstack index using the number of times each element must repeat.
`repeat_cnt_arr` can be created from the coverage map."""
out = np.empty(np.sum(repeat_cnt_arr), dtype=int_)
k = 0
for i in range(len(repeat_cnt_arr)):
out[k : k + repeat_cnt_arr[i]] = i
k += repeat_cnt_arr[i]
return out
def _repeat_unstack_mapped_nb(
mapped_arr,
col_arr,
idx_arr,
repeat_cnt_arr,
n_cols,
fill_value,
):
nb_enabled = not isinstance(mapped_arr, np.ndarray)
if nb_enabled:
mapped_arr_dtype = as_dtype(mapped_arr.dtype)
fill_value_dtype = as_dtype(fill_value)
else:
mapped_arr_dtype = mapped_arr.dtype
fill_value_dtype = np.array(fill_value).dtype
dtype = np.promote_types(mapped_arr_dtype, fill_value_dtype)
def impl(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value):
index_start_arr = np.cumsum(repeat_cnt_arr) - repeat_cnt_arr
out = np.full((np.sum(repeat_cnt_arr), n_cols), fill_value, dtype=dtype)
temp = np.zeros((len(repeat_cnt_arr), n_cols), dtype=int_)
for i in range(len(col_arr)):
out[index_start_arr[idx_arr[i]] + temp[idx_arr[i], col_arr[i]], col_arr[i]] = mapped_arr[i]
temp[idx_arr[i], col_arr[i]] += 1
return out
if not nb_enabled:
return impl(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value)
return impl
overload(_repeat_unstack_mapped_nb)(_repeat_unstack_mapped_nb)
@register_jitted(cache=True)
def repeat_unstack_mapped_nb(
mapped_arr: tp.Array1d,
col_arr: tp.Array1d,
idx_arr: tp.Array1d,
repeat_cnt_arr: tp.Array1d,
n_cols: int,
fill_value: float,
) -> tp.Array2d:
"""Unstack mapped array using repeated index data."""
return _repeat_unstack_mapped_nb(mapped_arr, col_arr, idx_arr, repeat_cnt_arr, n_cols, fill_value)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules that register objects across vectorbtpro."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.registries.ca_registry import *
from vectorbtpro.registries.ch_registry import *
from vectorbtpro.registries.jit_registry import *
from vectorbtpro.registries.pbar_registry import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Global registry for cacheables.
Caching in vectorbt is achieved through a combination of decorators and the registry.
Cacheable decorators such as `vectorbtpro.utils.decorators.cacheable` take a function and wrap
it with another function that behaves like the wrapped function but also takes care of all
caching modalities.
But unlike other implementations such as that of `functools.lru_cache`, the actual caching procedure
doesn't happen nor are the results stored inside the decorators themselves: decorators just register a
so-called "setup" for the wrapped function at the registry (see `CARunSetup`).
## Runnable setups
The actual magic happens within a runnable setup: it takes the function that should be called
and the arguments that should be passed to this function, checks whether the result should be cached,
runs the function, stores the result in the cache, updates the metrics, etc. It then returns the
resulting object to the wrapping function, which in turn returns it to the user. Each setup is
stateful: it stores the cache, the number of hits and misses, and other metadata. Thus, there can be
only one registered setup per cacheable function globally at a time. To avoid creating new setups for the same
function over and over again, each setup can be uniquely identified by its function through hashing:
```pycon
>>> from vectorbtpro import *
>>> my_func = lambda: np.random.uniform(size=1000000)
>>> # Decorator returns a wrapper
>>> my_ca_func = vbt.cached(my_func)
>>> # Wrapper registers a new setup
>>> my_ca_func.get_ca_setup()
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={})
>>> # Another call won't register a new setup but return the existing one
>>> my_ca_func.get_ca_setup()
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={})
>>> # Only one CARunSetup object per wrapper and optionally the instance the wrapper is bound to
>>> hash(my_ca_func.get_ca_setup()) == hash((my_ca_func, None))
True
```
When we call `my_ca_func`, it takes the setup from the registry and calls `CARunSetup.run`.
The caching happens by the setup itself and isn't in any way visible to `my_ca_func`.
To access the cache or any metric of interest, we can ask the setup:
```pycon
>>> my_setup = my_ca_func.get_ca_setup()
>>> # Cache is empty
>>> my_setup.get_stats()
{
'hash': 4792160544297109364,
'string': '>',
'use_cache': True,
'whitelist': False,
'caching_enabled': True,
'hits': 0,
'misses': 0,
'total_size': '0 Bytes',
'total_elapsed': None,
'total_saved': None,
'first_run_time': None,
'last_run_time': None,
'first_hit_time': None,
'last_hit_time': None,
'creation_time': 'now',
'last_update_time': None
}
>>> # The result is cached
>>> my_ca_func()
>>> my_setup.get_stats()
{
'hash': 4792160544297109364,
'string': '>',
'use_cache': True,
'whitelist': False,
'caching_enabled': True,
'hits': 0,
'misses': 1,
'total_size': '8.0 MB',
'total_elapsed': '11.33 milliseconds',
'total_saved': '0 milliseconds',
'first_run_time': 'now',
'last_run_time': 'now',
'first_hit_time': None,
'last_hit_time': None,
'creation_time': 'now',
'last_update_time': None
}
>>> # The cached result is retrieved
>>> my_ca_func()
>>> my_setup.get_stats()
{
'hash': 4792160544297109364,
'string': '>',
'use_cache': True,
'whitelist': False,
'caching_enabled': True,
'hits': 1,
'misses': 1,
'total_size': '8.0 MB',
'total_elapsed': '11.33 milliseconds',
'total_saved': '11.33 milliseconds',
'first_run_time': 'now',
'last_run_time': 'now',
'first_hit_time': 'now',
'last_hit_time': 'now',
'creation_time': 'now',
'last_update_time': None
}
```
## Enabling/disabling caching
To enable or disable caching, we can invoke `CARunSetup.enable_caching` and `CARunSetup.disable_caching`
respectively. This will set `CARunSetup.use_cache` flag to True or False. Even though we expressed
our desire to change caching rules, the final decision also depends on the global settings and whether
the setup is whitelisted in case caching is disabled globally. This decision is available via
`CARunSetup.caching_enabled`:
```pycon
>>> my_setup.disable_caching()
>>> my_setup.caching_enabled
False
>>> my_setup.enable_caching()
>>> my_setup.caching_enabled
True
>>> vbt.settings.caching['disable'] = True
>>> my_setup.caching_enabled
False
>>> my_setup.enable_caching()
UserWarning: This operation has no effect: caching is disabled globally and this setup is not whitelisted
>>> my_setup.enable_caching(force=True)
>>> my_setup.caching_enabled
True
>>> vbt.settings.caching['disable_whitelist'] = True
>>> my_setup.caching_enabled
False
>>> my_setup.enable_caching(force=True)
UserWarning: This operation has no effect: caching and whitelisting are disabled globally
```
To disable registration of new setups completely, use `disable_machinery`:
```pycon
>>> vbt.settings.caching['disable_machinery'] = True
```
## Setup hierarchy
But what if we wanted to change caching rules for an entire instance or class at once?
Even if we changed the setup of every cacheable function declared in the class, how do we
make sure that each future subclass or instance inherits the changes that we applied?
To account for this, vectorbt provides us with a set of setups that both are stateful
and can delegate various operations to their child setups, all the way down to `CARunSetup`.
The setup hierarchy follows the inheritance hierarchy in OOP:
(Diagram: the setup hierarchy, mirroring the OOP inheritance hierarchy.)
For example, calling `B.get_ca_setup().disable_caching()` would disable caching for each current
and future subclass and instance of `B`, but it won't disable caching for `A` or any other superclass of `B`.
In turn, each instance of `B` would then disable caching for each cacheable property and method in
that instance. As we see, the propagation of this operation is happening from top to bottom.
Unbound setups stretch outside of their classes in the diagram because there is no easy way
to derive the class when calling a cacheable decorator; their functions are therefore
considered to be living on their own. When calling
`B.f.get_ca_setup().disable_caching()`, we are disabling caching for the function `B.f`
for each current and future subclass and instance of `B`, while all other functions remain untouched.
But what happens when we enable caching for the class `B` and disable caching for the unbound
function `B.f`? Would the future method `b2.f` be cached or not? Quite easy: it would then
inherit the state from the setup that has been updated more recently.
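For instance, a minimal sketch reusing the hypothetical class `B` and its cacheable `f` from above:
```pycon
>>> B.get_ca_setup().enable_caching()     # updated first
>>> B.f.get_ca_setup().disable_caching()  # updated last, so it wins
>>> b2 = B()
>>> B.f.get_ca_setup(b2).use_cache
False
```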
Here is another illustration of how operations are propagated from parents to children:
(Diagram omitted: operations propagate from parent setups at the top of the hierarchy down to `CARunSetup` instances at the bottom.)
The diagram above depicts the following setup hierarchy:
```pycon
>>> # Populate setups at init
>>> vbt.settings.caching.reset()
>>> vbt.settings.caching['register_lazily'] = False
>>> class A(vbt.Cacheable):
... @vbt.cached_property
... def f1(self): pass
>>> class B(A):
... def f2(self): pass
>>> class C(A):
... @vbt.cached_method
... def f2(self): pass
>>> b1 = B()
>>> c1 = C()
>>> c2 = C()
>>> print(vbt.prettify(A.get_ca_setup().get_setup_hierarchy()))
[
{
"parent": "",
"children": [
{
"parent": "",
"children": [
""
]
}
]
},
{
"parent": "",
"children": [
{
"parent": "",
"children": [
"",
""
]
},
{
"parent": "",
"children": [
"",
""
]
}
]
}
]
>>> print(vbt.prettify(A.f1.get_ca_setup().get_setup_hierarchy()))
[
"",
"",
""
]
>>> print(vbt.prettify(C.f2.get_ca_setup().get_setup_hierarchy()))
[
"",
""
]
```
Let's disable caching for the entire `A` class:
```pycon
>>> A.get_ca_setup().disable_caching()
>>> A.get_ca_setup().use_cache
False
>>> B.get_ca_setup().use_cache
False
>>> C.get_ca_setup().use_cache
False
```
This disabled caching for `A`, its subclasses `B` and `C`, their instances, and any cacheable bound to those instances.
But it didn't touch the unbound setups of `C.f1` and `C.f2`:
```pycon
>>> C.f1.get_ca_setup().use_cache
True
>>> C.f2.get_ca_setup().use_cache
True
```
This is because unbound functions are not children of the classes they are declared in!
Still, any future instance method of `C` won't be cached because it checks which parent
has been updated more recently: the class or the unbound function. In our case,
the class had the more recent update.
```pycon
>>> c3 = C()
>>> C.f2.get_ca_setup(c3).use_cache
False
```
In fact, if we want to disable an entire class but leave one function untouched,
we need to perform two operations in a particular order: 1) disable caching on the class
and 2) enable caching on the unbound function.
```pycon
>>> A.get_ca_setup().disable_caching()
>>> C.f2.get_ca_setup().enable_caching()
>>> c4 = C()
>>> C.f2.get_ca_setup(c4).use_cache
True
```
## Getting statistics
The main advantage of having a central registry of setups is that we can easily find any setup
registered in any part of vectorbt that matches some condition using `CacheableRegistry.match_setups`.
!!! note
By default, all setups are registered lazily - no setup is registered until it's run
or explicitly called. To change this behavior, set `register_lazily` in the global
settings to False.
For example, let's look which setups have been registered so far:
```pycon
>>> vbt.ca_reg.match_setups(kind=None)
{
CAClassSetup(registry=, use_cache=None, whitelist=None, cls=),
CAClassSetup(registry=, use_cache=None, whitelist=None, cls=),
CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=),
CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=),
CAInstanceSetup(registry=, use_cache=None, whitelist=None, instance=),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable= at 0x7fe14e94cae8>, instance=None, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CAUnboundSetup(registry=, use_cache=True, whitelist=False, cacheable=),
CAUnboundSetup(registry=, use_cache=True, whitelist=False, cacheable=)
}
```
Let's get the runnable setup of any property and method called `f2`:
```pycon
>>> vbt.ca_reg.match_setups('f2', kind='runnable')
{
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={}),
CARunSetup(registry=, use_cache=True, whitelist=False, cacheable=, instance=, max_size=None, ignore_args=None, cache={})
}
```
But there is a better way to get the stats: `CAQueryDelegator.get_stats`.
It returns a DataFrame with setup stats as rows:
```pycon
>>> vbt.CAQueryDelegator('f2', kind='runnable').get_stats()
string use_cache whitelist \\
hash
3506416602224216137 True False
-4747092115268118855 True False
-4748466030718995055 True False
caching_enabled hits misses total_size total_elapsed \\
hash
3506416602224216137 True 0 0 0 Bytes None
-4747092115268118855 True 0 0 0 Bytes None
-4748466030718995055 True 0 0 0 Bytes None
total_saved first_run_time last_run_time first_hit_time \\
hash
3506416602224216137 None None None None
-4747092115268118855 None None None None
-4748466030718995055 None None None None
last_hit_time creation_time last_update_time
hash
3506416602224216137 None 9 minutes ago 9 minutes ago
-4747092115268118855 None 9 minutes ago 9 minutes ago
-4748466030718995055 None 9 minutes ago 9 minutes ago
```
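The same delegator accepts further arguments; for instance, a sketch based on the `get_stats`
signature defined later in this module, narrowing down the columns via `include`:
```pycon
>>> stats_df = vbt.CAQueryDelegator('f2', kind='runnable').get_stats(include=['hits', 'misses'])
```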
## Clearing up
Instance and runnable setups hold only weak references to their instances, so deleting an
instance won't keep it alive in memory and will automatically remove its setups.
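For example, a sketch reusing the instances created above: once the last reference to an instance
is dropped and the object is collected, its setups disappear from the registry:
```pycon
>>> import gc
>>> del c1, c2, c3, c4
>>> _ = gc.collect()
>>> vbt.ca_reg.match_setups(vbt.CAQuery(base_cls=C), kind='instance')
set()
```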
To clear all caches:
```pycon
>>> vbt.CAQueryDelegator().clear_cache()
```
## Resetting
To reset global caching flags:
```pycon
>>> vbt.settings.caching.reset()
```
To remove all setups:
```pycon
>>> vbt.CAQueryDelegator(kind=None).deregister()
```
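Caching can also be switched off temporarily. A minimal sketch, assuming the `CachingDisabled`
context manager exported by this module restores the previous caching state on exit:
```pycon
>>> with vbt.CachingDisabled():
...     ...  # caching is disabled within this block
```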
"""
import inspect
import sys
from collections.abc import ValuesView
from datetime import datetime, timezone, timedelta
from functools import wraps
from weakref import ref, ReferenceType
import attr
import humanize
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks, datetime_ as dt
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.caching import Cacheable
from vectorbtpro.utils.decorators import cacheableT, cacheable_property
from vectorbtpro.utils.formatting import ptable
from vectorbtpro.utils.parsing import Regex, hash_args, UnhashableArgsError, get_func_arg_names
from vectorbtpro.utils.profiling import Timer
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"CacheableRegistry",
"ca_reg",
"CAQuery",
"CARule",
"CAQueryDelegator",
"get_cache_stats",
"print_cache_stats",
"clear_cache",
"collect_garbage",
"flush",
"disable_caching",
"enable_caching",
"CachingDisabled",
"with_caching_disabled",
"CachingEnabled",
"with_caching_enabled",
]
__pdoc__ = {}
class _GARBAGE:
"""Sentinel class for garbage."""
def is_cacheable_function(cacheable: tp.Any) -> bool:
"""Check if `cacheable` is a cacheable function."""
return (
callable(cacheable)
and hasattr(cacheable, "is_method")
and not cacheable.is_method
and hasattr(cacheable, "is_cacheable")
and cacheable.is_cacheable
)
def is_cacheable_property(cacheable: tp.Any) -> bool:
"""Check if `cacheable` is a cacheable property."""
return isinstance(cacheable, cacheable_property)
def is_cacheable_method(cacheable: tp.Any) -> bool:
"""Check if `cacheable` is a cacheable method."""
return (
callable(cacheable)
and hasattr(cacheable, "is_method")
and cacheable.is_method
and hasattr(cacheable, "is_cacheable")
and cacheable.is_cacheable
)
def is_bindable_cacheable(cacheable: tp.Any) -> bool:
"""Check if `cacheable` is a cacheable that can be bound to an instance."""
return is_cacheable_property(cacheable) or is_cacheable_method(cacheable)
def is_cacheable(cacheable: tp.Any) -> bool:
"""Check if `cacheable` is a cacheable."""
return is_cacheable_function(cacheable) or is_bindable_cacheable(cacheable)
def get_obj_id(instance: object) -> tp.Tuple[type, int]:
"""Get id of an instance."""
return type(instance), id(instance)
CAQueryT = tp.TypeVar("CAQueryT", bound="CAQuery")
InstanceT = tp.Optional[tp.Union[Cacheable, ReferenceType]]
def _instance_converter(instance: InstanceT) -> InstanceT:
"""Make the reference to the instance weak."""
if instance is not None and instance is not _GARBAGE and not isinstance(instance, ReferenceType):
return ref(instance)
return instance
@define
class CAQuery(DefineMixin):
"""Data class that represents a query for matching and ranking setups."""
cacheable: tp.Optional[tp.Union[tp.Callable, cacheableT, str, Regex]] = define.field(default=None)
"""Cacheable object or its name (case-sensitive)."""
instance: InstanceT = define.field(default=None, converter=_instance_converter)
"""Weak reference to the instance `CAQuery.cacheable` is bound to."""
cls: tp.Optional[tp.TypeLike] = define.field(default=None)
"""Class of the instance or its name (case-sensitive) `CAQuery.cacheable` is bound to."""
base_cls: tp.Optional[tp.TypeLike] = define.field(default=None)
"""Base class of the instance or its name (case-sensitive) `CAQuery.cacheable` is bound to."""
options: tp.Optional[dict] = define.field(default=None)
"""Options to match."""
@classmethod
def parse(cls: tp.Type[CAQueryT], query_like: tp.Any, use_base_cls: bool = True) -> CAQueryT:
"""Parse a query-like object.
!!! note
Not all attribute combinations can be safely parsed by this function.
For example, you cannot combine cacheable together with options.
Usage:
```pycon
>>> vbt.CAQuery.parse(lambda x: x)
CAQuery(cacheable=<function <lambda> at 0x7fd4766c7730>, instance=None, cls=None, base_cls=None, options=None)
>>> vbt.CAQuery.parse("a")
CAQuery(cacheable='a', instance=None, cls=None, base_cls=None, options=None)
>>> vbt.CAQuery.parse("A.a")
CAQuery(cacheable='a', instance=None, cls=None, base_cls='A', options=None)
>>> vbt.CAQuery.parse("A")
CAQuery(cacheable=None, instance=None, cls=None, base_cls='A', options=None)
>>> vbt.CAQuery.parse("A", use_base_cls=False)
CAQuery(cacheable=None, instance=None, cls='A', base_cls=None, options=None)
>>> vbt.CAQuery.parse(vbt.Regex("[A-B]"))
CAQuery(cacheable=None, instance=None, cls=None, base_cls=Regex(pattern='[A-B]', flags=0), options=None)
>>> vbt.CAQuery.parse(dict(my_option=100))
CAQuery(cacheable=None, instance=None, cls=None, base_cls=None, options={'my_option': 100})
```
"""
if query_like is None:
return CAQuery()
if isinstance(query_like, CAQuery):
return query_like
if isinstance(query_like, CABaseSetup):
return query_like.query
if isinstance(query_like, cacheable_property):
return cls(cacheable=query_like)
if isinstance(query_like, str) and query_like[0].islower():
return cls(cacheable=query_like)
if isinstance(query_like, str) and query_like[0].isupper() and "." in query_like:
if use_base_cls:
return cls(cacheable=query_like.split(".")[1], base_cls=query_like.split(".")[0])
return cls(cacheable=query_like.split(".")[1], cls=query_like.split(".")[0])
if isinstance(query_like, str) and query_like[0].isupper():
if use_base_cls:
return cls(base_cls=query_like)
return cls(cls=query_like)
if isinstance(query_like, Regex):
if use_base_cls:
return cls(base_cls=query_like)
return cls(cls=query_like)
if isinstance(query_like, type):
if use_base_cls:
return cls(base_cls=query_like)
return cls(cls=query_like)
if isinstance(query_like, tuple):
if use_base_cls:
return cls(base_cls=query_like)
return cls(cls=query_like)
if isinstance(query_like, dict):
return cls(options=query_like)
if callable(query_like):
return cls(cacheable=query_like)
return cls(instance=query_like)
@property
def instance_obj(self) -> tp.Optional[tp.Union[Cacheable, object]]:
"""Instance object."""
if self.instance is _GARBAGE:
return _GARBAGE
if self.instance is not None and self.instance() is None:
return _GARBAGE
return self.instance() if self.instance is not None else None
def matches_setup(self, setup: "CABaseSetup") -> bool:
"""Return whether the setup matches this query.
Usage:
Let's evaluate various queries:
```pycon
>>> class A(vbt.Cacheable):
... @vbt.cached_method(my_option=True)
... def f(self):
... return None
>>> class B(A):
... pass
>>> @vbt.cached(my_option=False)
... def f():
... return None
>>> a = A()
>>> b = B()
>>> def match_query(query):
... matched = []
... if query.matches_setup(A.f.get_ca_setup()): # unbound method
... matched.append('A.f')
... if query.matches_setup(A.get_ca_setup()): # class
... matched.append('A')
... if query.matches_setup(a.get_ca_setup()): # instance
... matched.append('a')
... if query.matches_setup(A.f.get_ca_setup(a)): # instance method
... matched.append('a.f')
... if query.matches_setup(B.f.get_ca_setup()): # unbound method
... matched.append('B.f')
... if query.matches_setup(B.get_ca_setup()): # class
... matched.append('B')
... if query.matches_setup(b.get_ca_setup()): # instance
... matched.append('b')
... if query.matches_setup(B.f.get_ca_setup(b)): # instance method
... matched.append('b.f')
... if query.matches_setup(f.get_ca_setup()): # function
... matched.append('f')
... return matched
>>> match_query(vbt.CAQuery())
['A.f', 'A', 'a', 'a.f', 'B.f', 'B', 'b', 'b.f', 'f']
>>> match_query(vbt.CAQuery(cacheable=A.f))
['A.f', 'a.f', 'B.f', 'b.f']
>>> match_query(vbt.CAQuery(cacheable=B.f))
['A.f', 'a.f', 'B.f', 'b.f']
>>> match_query(vbt.CAQuery(cls=A))
['A', 'a', 'a.f']
>>> match_query(vbt.CAQuery(cls=B))
['B', 'b', 'b.f']
>>> match_query(vbt.CAQuery(cls=vbt.Regex('[A-B]')))
['A', 'a', 'a.f', 'B', 'b', 'b.f']
>>> match_query(vbt.CAQuery(base_cls=A))
['A', 'a', 'a.f', 'B', 'b', 'b.f']
>>> match_query(vbt.CAQuery(base_cls=B))
['B', 'b', 'b.f']
>>> match_query(vbt.CAQuery(instance=a))
['a', 'a.f']
>>> match_query(vbt.CAQuery(instance=b))
['b', 'b.f']
>>> match_query(vbt.CAQuery(instance=a, cacheable='f'))
['a.f']
>>> match_query(vbt.CAQuery(instance=b, cacheable='f'))
['b.f']
>>> match_query(vbt.CAQuery(options=dict(my_option=True)))
['A.f', 'a.f', 'B.f', 'b.f']
>>> match_query(vbt.CAQuery(options=dict(my_option=False)))
['f']
```
"""
if self.cacheable is not None:
if not isinstance(setup, (CARunSetup, CAUnboundSetup)):
return False
if is_cacheable(self.cacheable):
if setup.cacheable is not self.cacheable and setup.cacheable.func is not self.cacheable.func:
return False
elif callable(self.cacheable):
if setup.cacheable.func is not self.cacheable:
return False
elif isinstance(self.cacheable, str):
if setup.cacheable.name != self.cacheable:
return False
elif isinstance(self.cacheable, Regex):
if not self.cacheable.matches(setup.cacheable.name):
return False
else:
return False
if self.instance_obj is not None:
if not isinstance(setup, (CARunSetup, CAInstanceSetup)):
return False
if setup.instance_obj is not self.instance_obj:
return False
if self.cls is not None:
if not isinstance(setup, (CARunSetup, CAInstanceSetup, CAClassSetup)):
return False
if isinstance(setup, (CARunSetup, CAInstanceSetup)) and setup.instance_obj is _GARBAGE:
return False
if isinstance(setup, (CARunSetup, CAInstanceSetup)) and not checks.is_class(
type(setup.instance_obj),
self.cls,
):
return False
if isinstance(setup, CAClassSetup) and not checks.is_class(setup.cls, self.cls):
return False
if self.base_cls is not None:
if not isinstance(setup, (CARunSetup, CAInstanceSetup, CAClassSetup)):
return False
if isinstance(setup, (CARunSetup, CAInstanceSetup)) and setup.instance_obj is _GARBAGE:
return False
if isinstance(setup, (CARunSetup, CAInstanceSetup)) and not checks.is_subclass_of(
type(setup.instance_obj),
self.base_cls,
):
return False
if isinstance(setup, CAClassSetup) and not checks.is_subclass_of(setup.cls, self.base_cls):
return False
if self.options is not None and len(self.options) > 0:
if not isinstance(setup, (CARunSetup, CAUnboundSetup)):
return False
for k, v in self.options.items():
if k not in setup.cacheable.options or setup.cacheable.options[k] != v:
return False
return True
@property
def hash_key(self) -> tuple:
return (
self.cacheable,
get_obj_id(self.instance_obj) if self.instance_obj is not None else None,
self.cls,
self.base_cls,
tuple(self.options.items()) if self.options is not None else None,
)
@define
class CARule(DefineMixin):
"""Data class that represents a rule that should be enforced on setups that match a query."""
query: CAQuery = define.field()
"""`CAQuery` used in matching."""
enforce_func: tp.Optional[tp.Callable] = define.field()
"""Function to run on the setup if it has been matched."""
kind: tp.Optional[tp.MaybeIterable[str]] = define.field(default=None)
"""Kind of a setup to match."""
exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = define.field(default=None)
"""One or multiple setups to exclude."""
filter_func: tp.Optional[tp.Callable] = define.field(default=None)
"""Function to filter out a setup."""
def matches_setup(self, setup: "CABaseSetup") -> bool:
"""Return whether the setup matches the rule."""
if not self.query.matches_setup(setup):
return False
if self.kind is not None:
kind = self.kind
if isinstance(kind, str):
kind = {kind}
else:
kind = set(kind)
if isinstance(setup, CAClassSetup):
setup_kind = "class"
elif isinstance(setup, CAInstanceSetup):
setup_kind = "instance"
elif isinstance(setup, CAUnboundSetup):
setup_kind = "unbound"
else:
setup_kind = "runnable"
if setup_kind not in kind:
return False
if self.exclude is not None:
exclude = self.exclude
if isinstance(exclude, CABaseSetup):
exclude = {exclude}
else:
exclude = set(exclude)
if setup in exclude:
return False
if self.filter_func is not None:
if not self.filter_func(setup):
return False
return True
def enforce(self, setup: "CABaseSetup") -> None:
"""Run `CARule.enforce_func` on the setup if it has been matched."""
if self.matches_setup(setup):
self.enforce_func(setup)
@property
def hash_key(self) -> tuple:
return (
self.query,
self.enforce_func,
self.kind,
self.exclude if self.exclude is None or isinstance(self.exclude, CABaseSetup) else tuple(self.exclude),
self.filter_func,
)
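# A minimal sketch of defining and registering a rule with the default registry `ca_reg`
# created below (the query and the enforce function here are purely hypothetical):
#
#     my_rule = CARule(CAQuery(base_cls="Portfolio"), lambda setup: setup.disable_caching())
#     ca_reg.register_rule(my_rule)
#
# The rule is then enforced on any setup that calls `CABaseSetup.enforce_rules`, which
# happens when a new setup is created via its `get` class method.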
class CacheableRegistry(Base):
"""Class for registering setups of cacheables."""
def __init__(self) -> None:
self._class_setups = dict()
self._instance_setups = dict()
self._unbound_setups = dict()
self._run_setups = dict()
self._rules = []
@property
def class_setups(self) -> tp.Dict[int, "CAClassSetup"]:
"""Dict of registered `CAClassSetup` instances by their hash."""
return self._class_setups
@property
def instance_setups(self) -> tp.Dict[int, "CAInstanceSetup"]:
"""Dict of registered `CAInstanceSetup` instances by their hash."""
return self._instance_setups
@property
def unbound_setups(self) -> tp.Dict[int, "CAUnboundSetup"]:
"""Dict of registered `CAUnboundSetup` instances by their hash."""
return self._unbound_setups
@property
def run_setups(self) -> tp.Dict[int, "CARunSetup"]:
"""Dict of registered `CARunSetup` instances by their hash."""
return self._run_setups
@property
def setups(self) -> tp.Dict[int, "CABaseSetup"]:
"""Dict of registered `CABaseSetup` instances by their hash."""
return {**self.class_setups, **self.instance_setups, **self.unbound_setups, **self.run_setups}
@property
def rules(self) -> tp.List[CARule]:
"""List of registered `CARule` instances."""
return self._rules
def get_setup_by_hash(self, hash_: int) -> tp.Optional["CABaseSetup"]:
"""Get the setup by its hash."""
if hash_ in self.class_setups:
return self.class_setups[hash_]
if hash_ in self.instance_setups:
return self.instance_setups[hash_]
if hash_ in self.unbound_setups:
return self.unbound_setups[hash_]
if hash_ in self.run_setups:
return self.run_setups[hash_]
return None
def setup_registered(self, setup: "CABaseSetup") -> bool:
"""Return whether the setup is registered."""
return self.get_setup_by_hash(hash(setup)) is not None
def register_setup(self, setup: "CABaseSetup") -> None:
"""Register a new setup of type `CABaseSetup`."""
if isinstance(setup, CARunSetup):
setups = self.run_setups
elif isinstance(setup, CAUnboundSetup):
setups = self.unbound_setups
elif isinstance(setup, CAInstanceSetup):
setups = self.instance_setups
elif isinstance(setup, CAClassSetup):
setups = self.class_setups
else:
raise TypeError(str(type(setup)))
setups[hash(setup)] = setup
def deregister_setup(self, setup: "CABaseSetup") -> None:
"""Deregister a setup of type `CABaseSetup`.
Removes the setup from its respective collection.
To also deregister its children, call the `CASetupDelegatorMixin.deregister` method."""
if isinstance(setup, CARunSetup):
setups = self.run_setups
elif isinstance(setup, CAUnboundSetup):
setups = self.unbound_setups
elif isinstance(setup, CAInstanceSetup):
setups = self.instance_setups
elif isinstance(setup, CAClassSetup):
setups = self.class_setups
else:
raise TypeError(str(type(setup)))
if hash(setup) in setups:
del setups[hash(setup)]
def register_rule(self, rule: CARule) -> None:
"""Register a new rule of type `CARule`."""
self.rules.append(rule)
def deregister_rule(self, rule: CARule) -> None:
"""Deregister a rule of type `CARule`."""
self.rules.remove(rule)
def get_run_setup(
self,
cacheable: cacheableT,
instance: tp.Optional[Cacheable] = None,
) -> tp.Optional["CARunSetup"]:
"""Get a setup of type `CARunSetup` with this cacheable and instance, or return None."""
run_setup = self.run_setups.get(CARunSetup.get_hash(cacheable, instance=instance), None)
if run_setup is not None and run_setup.instance_obj is _GARBAGE:
self.deregister_setup(run_setup)
return None
return run_setup
def get_unbound_setup(self, cacheable: cacheableT) -> tp.Optional["CAUnboundSetup"]:
"""Get a setup of type `CAUnboundSetup` with this cacheable or return None."""
return self.unbound_setups.get(CAUnboundSetup.get_hash(cacheable), None)
def get_instance_setup(self, instance: Cacheable) -> tp.Optional["CAInstanceSetup"]:
"""Get a setup of type `CAInstanceSetup` with this instance or return None."""
instance_setup = self.instance_setups.get(CAInstanceSetup.get_hash(instance), None)
if instance_setup is not None and instance_setup.instance_obj is _GARBAGE:
self.deregister_setup(instance_setup)
return None
return instance_setup
def get_class_setup(self, cls: tp.Type[Cacheable]) -> tp.Optional["CAClassSetup"]:
"""Get a setup of type `CAClassSetup` with this class or return None."""
return self.class_setups.get(CAClassSetup.get_hash(cls), None)
def match_setups(
self,
query_like: tp.MaybeIterable[tp.Any] = None,
collapse: bool = False,
kind: tp.Optional[tp.MaybeIterable[str]] = None,
exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
exclude_children: bool = True,
filter_func: tp.Optional[tp.Callable] = None,
) -> tp.Set["CABaseSetup"]:
"""Match all setups registered in this registry against `query_like`.
`query_like` can be one or more query-like objects that will be parsed using `CAQuery.parse`.
Set `collapse` to True to remove child setups that belong to any matched parent setup.
`kind` can be one or multiple of the following:
* 'class' to only return class setups (instances of `CAClassSetup`)
* 'instance' to only return instance setups (instances of `CAInstanceSetup`)
* 'unbound' to only return unbound setups (instances of `CAUnboundSetup`)
* 'runnable' to only return runnable setups (instances of `CARunSetup`)
Set `exclude` to one or multiple setups to exclude. To not exclude their children,
set `exclude_children` to False.
!!! note
`exclude_children` is applied only when `collapse` is True.
`filter_func` can be used to filter out setups. For example, `lambda setup: setup.caching_enabled`
includes only those setups that have caching enabled. It must take a setup and return a boolean
of whether to include this setup in the final results."""
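# A sketch of `filter_func` in action, collecting only class setups that currently have
# caching enabled (using the default registry `ca_reg`):
#
#     ca_reg.match_setups(kind="class", filter_func=lambda setup: setup.caching_enabled)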
if not checks.is_iterable(query_like) or isinstance(query_like, (str, tuple)):
query_like = [query_like]
query_like = list(map(CAQuery.parse, query_like))
if kind is None:
kind = {"class", "instance", "unbound", "runnable"}
if exclude is None:
exclude = set()
if isinstance(exclude, CABaseSetup):
exclude = {exclude}
else:
exclude = set(exclude)
matches = set()
if not collapse:
if isinstance(kind, str):
if kind.lower() == "class":
setups = set(self.class_setups.values())
elif kind.lower() == "instance":
setups = set(self.instance_setups.values())
elif kind.lower() == "unbound":
setups = set(self.unbound_setups.values())
elif kind.lower() == "runnable":
setups = set(self.run_setups.values())
else:
raise ValueError(f"kind '{kind}' is not supported")
for setup in setups:
if setup not in exclude:
for q in query_like:
if q.matches_setup(setup):
if filter_func is None or filter_func(setup):
matches.add(setup)
break
elif checks.is_iterable(kind):
matches = set.union(
*[
self.match_setups(
query_like,
kind=k,
collapse=collapse,
exclude=exclude,
exclude_children=exclude_children,
filter_func=filter_func,
)
for k in kind
]
)
else:
raise TypeError(f"kind must be either a string or a sequence of strings, not {type(kind)}")
else:
if isinstance(kind, str):
kind = {kind}
else:
kind = set(kind)
collapse_setups = set()
if "class" in kind:
class_matches = set()
for class_setup in self.class_setups.values():
for q in query_like:
if q.matches_setup(class_setup):
if filter_func is None or filter_func(class_setup):
if class_setup not in exclude:
class_matches.add(class_setup)
if class_setup not in exclude or exclude_children:
collapse_setups |= class_setup.child_setups
break
for class_setup in class_matches:
if class_setup not in collapse_setups:
matches.add(class_setup)
if "instance" in kind:
for instance_setup in self.instance_setups.values():
if instance_setup not in collapse_setups:
for q in query_like:
if q.matches_setup(instance_setup):
if filter_func is None or filter_func(instance_setup):
if instance_setup not in exclude:
matches.add(instance_setup)
if instance_setup not in exclude or exclude_children:
collapse_setups |= instance_setup.child_setups
break
if "unbound" in kind:
for unbound_setup in self.unbound_setups.values():
if unbound_setup not in collapse_setups:
for q in query_like:
if q.matches_setup(unbound_setup):
if filter_func is None or filter_func(unbound_setup):
if unbound_setup not in exclude:
matches.add(unbound_setup)
if unbound_setup not in exclude or exclude_children:
collapse_setups |= unbound_setup.child_setups
break
if "runnable" in kind:
for run_setup in self.run_setups.values():
if run_setup not in collapse_setups:
for q in query_like:
if q.matches_setup(run_setup):
if filter_func is None or filter_func(run_setup):
if run_setup not in exclude:
matches.add(run_setup)
break
return matches
ca_reg = CacheableRegistry()
"""Default registry of type `CacheableRegistry`."""
class CAMetrics(Base):
"""Abstract class that exposes various metrics related to caching."""
@property
def hits(self) -> int:
"""Number of hits."""
raise NotImplementedError
@property
def misses(self) -> int:
"""Number of misses."""
raise NotImplementedError
@property
def total_size(self) -> int:
"""Total size of cached objects."""
raise NotImplementedError
@property
def total_elapsed(self) -> tp.Optional[timedelta]:
"""Total time elapsed while running the function."""
raise NotImplementedError
@property
def total_saved(self) -> tp.Optional[timedelta]:
"""Total time saved by using the cache."""
raise NotImplementedError
@property
def first_run_time(self) -> tp.Optional[datetime]:
"""Time of the first run."""
raise NotImplementedError
@property
def last_run_time(self) -> tp.Optional[datetime]:
"""Time of the last run."""
raise NotImplementedError
@property
def first_hit_time(self) -> tp.Optional[datetime]:
"""Time of the first hit."""
raise NotImplementedError
@property
def last_hit_time(self) -> tp.Optional[datetime]:
"""Time of the last hit."""
raise NotImplementedError
@property
def metrics(self) -> dict:
"""Dict with all metrics."""
return dict(
hits=self.hits,
misses=self.misses,
total_size=self.total_size,
total_elapsed=self.total_elapsed,
total_saved=self.total_saved,
first_run_time=self.first_run_time,
last_run_time=self.last_run_time,
first_hit_time=self.first_hit_time,
last_hit_time=self.last_hit_time,
)
@define
class CABaseSetup(CAMetrics, DefineMixin):
"""Base class that exposes properties and methods for cache management."""
registry: CacheableRegistry = define.field(default=ca_reg)
"""Registry of type `CacheableRegistry`."""
use_cache: tp.Optional[bool] = define.field(default=None)
"""Whether caching is enabled."""
whitelist: tp.Optional[bool] = define.field(default=None)
"""Whether to cache even if caching was disabled globally."""
active: bool = define.field(default=True)
"""Whether to register and/or return setup when requested."""
def __attrs_post_init__(self) -> None:
object.__setattr__(self, "_creation_time", datetime.now(timezone.utc))
object.__setattr__(self, "_use_cache_lut", None)
object.__setattr__(self, "_whitelist_lut", None)
@property
def query(self) -> CAQuery:
"""Query to match this setup."""
raise NotImplementedError
@property
def caching_enabled(self) -> tp.Optional[bool]:
"""Whether caching is enabled in this setup.
Caching is disabled when any of the following apply:
* `CABaseSetup.use_cache` is False
* Caching is disabled globally and `CABaseSetup.whitelist` is False
* Caching and whitelisting are disabled globally
Returns None if `CABaseSetup.use_cache` or `CABaseSetup.whitelist` is None."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if self.use_cache is None:
return None
if self.use_cache:
if not caching_cfg["disable"]:
return True
if not caching_cfg["disable_whitelist"]:
if self.whitelist is None:
return None
if self.whitelist:
return True
return False
def register(self) -> None:
"""Register setup using `CacheableRegistry.register_setup`."""
self.registry.register_setup(self)
def deregister(self) -> None:
"""Deregister setup using `CacheableRegistry.deregister_setup`."""
self.registry.deregister_setup(self)
@property
def registered(self) -> bool:
"""Return whether setup is registered."""
return self.registry.setup_registered(self)
def enforce_rules(self) -> None:
"""Enforce registry rules."""
for rule in self.registry.rules:
rule.enforce(self)
def activate(self) -> None:
"""Activate."""
object.__setattr__(self, "active", True)
def deactivate(self) -> None:
"""Deactivate."""
object.__setattr__(self, "active", False)
def enable_whitelist(self) -> None:
"""Enable whitelisting."""
object.__setattr__(self, "whitelist", True)
object.__setattr__(self, "_whitelist_lut", datetime.now(timezone.utc))
def disable_whitelist(self) -> None:
"""Disable whitelisting."""
object.__setattr__(self, "whitelist", False)
object.__setattr__(self, "_whitelist_lut", datetime.now(timezone.utc))
def enable_caching(self, force: bool = False, silence_warnings: tp.Optional[bool] = None) -> None:
"""Enable caching.
Set `force` to True to whitelist this setup."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if silence_warnings is None:
silence_warnings = caching_cfg["silence_warnings"]
object.__setattr__(self, "use_cache", True)
if force:
object.__setattr__(self, "whitelist", True)
else:
if caching_cfg["disable"] and not caching_cfg["disable_whitelist"] and not silence_warnings:
warn("This operation has no effect: caching is disabled globally and this setup is not whitelisted")
if caching_cfg["disable"] and caching_cfg["disable_whitelist"] and not silence_warnings:
warn("This operation has no effect: caching and whitelisting are disabled globally")
object.__setattr__(self, "_use_cache_lut", datetime.now(timezone.utc))
def disable_caching(self, clear_cache: bool = True) -> None:
"""Disable caching.
Set `clear_cache` to True to also clear the cache."""
object.__setattr__(self, "use_cache", False)
if clear_cache:
self.clear_cache()
object.__setattr__(self, "_use_cache_lut", datetime.now(timezone.utc))
@property
def creation_time(self) -> tp.datetime:
"""Time when this setup was created."""
return object.__getattribute__(self, "_creation_time")
@property
def use_cache_lut(self) -> tp.Optional[datetime]:
"""Last time `CABaseSetup.use_cache` was updated."""
return object.__getattribute__(self, "_use_cache_lut")
@property
def whitelist_lut(self) -> tp.Optional[datetime]:
"""Last time `CABaseSetup.whitelist` was updated."""
return object.__getattribute__(self, "_whitelist_lut")
@property
def last_update_time(self) -> tp.Optional[datetime]:
"""Last time either `CABaseSetup.use_cache` or `CABaseSetup.whitelist` was updated."""
if self.use_cache_lut is None:
return self.whitelist_lut
if self.whitelist_lut is None:
return self.use_cache_lut
return max(self.use_cache_lut, self.whitelist_lut)
def clear_cache(self) -> None:
"""Clear the cache."""
raise NotImplementedError
@property
def same_type_setups(self) -> ValuesView:
"""Setups of the same type."""
raise NotImplementedError
@property
def short_str(self) -> str:
"""Convert this setup into a short string."""
raise NotImplementedError
@property
def readable_name(self) -> str:
"""Get a readable name of the object the setup is bound to."""
raise NotImplementedError
@property
def position_among_similar(self) -> tp.Optional[int]:
"""Get position among all similar setups.
Ordered by creation time."""
i = 0
for setup in self.same_type_setups:
if self is setup:
return i
if setup.readable_name == self.readable_name:
i += 1
return None
@property
def readable_str(self) -> str:
"""Convert this setup into a readable string."""
return f"{self.readable_name}:{self.position_among_similar}"
def get_stats(self, readable: bool = True, short_str: bool = False) -> dict:
"""Get stats of the setup as a dict with metrics."""
if short_str:
string = self.short_str
else:
string = str(self)
total_size = self.total_size
total_elapsed = self.total_elapsed
total_saved = self.total_saved
first_run_time = self.first_run_time
last_run_time = self.last_run_time
first_hit_time = self.first_hit_time
last_hit_time = self.last_hit_time
creation_time = self.creation_time
last_update_time = self.last_update_time
if readable:
string = self.readable_str
total_size = humanize.naturalsize(total_size)
if total_elapsed is not None:
minimum_unit = "seconds" if total_elapsed.total_seconds() >= 1 else "milliseconds"
total_elapsed = humanize.precisedelta(total_elapsed, minimum_unit)
if total_saved is not None:
minimum_unit = "seconds" if total_saved.total_seconds() >= 1 else "milliseconds"
total_saved = humanize.precisedelta(total_saved, minimum_unit)
if first_run_time is not None:
first_run_time = humanize.naturaltime(
dt.to_naive_datetime(first_run_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
if last_run_time is not None:
last_run_time = humanize.naturaltime(
dt.to_naive_datetime(last_run_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
if first_hit_time is not None:
first_hit_time = humanize.naturaltime(
dt.to_naive_datetime(first_hit_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
if last_hit_time is not None:
last_hit_time = humanize.naturaltime(
dt.to_naive_datetime(last_hit_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
if creation_time is not None:
creation_time = humanize.naturaltime(
dt.to_naive_datetime(creation_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
if last_update_time is not None:
last_update_time = humanize.naturaltime(
dt.to_naive_datetime(last_update_time),
when=dt.to_naive_datetime(datetime.now(timezone.utc)),
)
return dict(
hash=hash(self),
string=string,
use_cache=self.use_cache,
whitelist=self.whitelist,
caching_enabled=self.caching_enabled,
hits=self.hits,
misses=self.misses,
total_size=total_size,
total_elapsed=total_elapsed,
total_saved=total_saved,
first_run_time=first_run_time,
last_run_time=last_run_time,
first_hit_time=first_hit_time,
last_hit_time=last_hit_time,
creation_time=creation_time,
last_update_time=last_update_time,
)
class CASetupDelegatorMixin(CAMetrics):
"""Mixin class that delegates cache management to child setups."""
@property
def child_setups(self) -> tp.Set[CABaseSetup]:
"""Child setups."""
raise NotImplementedError
def get_setup_hierarchy(self, readable: bool = True, short_str: bool = False) -> tp.List[dict]:
"""Get the setup hierarchy by recursively traversing the child setups."""
results = []
for setup in self.child_setups:
if readable:
setup_obj = setup.readable_str
elif short_str:
setup_obj = setup.short_str
else:
setup_obj = setup
if isinstance(setup, CASetupDelegatorMixin):
results.append(dict(parent=setup_obj, children=setup.get_setup_hierarchy(readable=readable)))
else:
results.append(setup_obj)
return results
def delegate(
self,
func: tp.Callable,
exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
**kwargs,
) -> None:
"""Delegate a function to all child setups.
`func` must take the setup and return nothing. If the setup is an instance of
`CASetupDelegatorMixin`, it must additionally accept `exclude`."""
if exclude is None:
exclude = set()
if isinstance(exclude, CABaseSetup):
exclude = {exclude}
else:
exclude = set(exclude)
for setup in self.child_setups:
if setup not in exclude:
if isinstance(setup, CASetupDelegatorMixin):
func(setup, exclude=exclude, **kwargs)
else:
func(setup, **kwargs)
def deregister(self, **kwargs) -> None:
"""Calls `CABaseSetup.deregister` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.deregister(**_kwargs), **kwargs)
def enable_whitelist(self, **kwargs) -> None:
"""Calls `CABaseSetup.enable_whitelist` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.enable_whitelist(**_kwargs), **kwargs)
def disable_whitelist(self, **kwargs) -> None:
"""Calls `CABaseSetup.disable_whitelist` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.disable_whitelist(**_kwargs), **kwargs)
def enable_caching(self, **kwargs) -> None:
"""Calls `CABaseSetup.enable_caching` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.enable_caching(**_kwargs), **kwargs)
def disable_caching(self, **kwargs) -> None:
"""Calls `CABaseSetup.disable_caching` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.disable_caching(**_kwargs), **kwargs)
def clear_cache(self, **kwargs) -> None:
"""Calls `CABaseSetup.clear_cache` on each child setup."""
self.delegate(lambda setup, **_kwargs: setup.clear_cache(**_kwargs), **kwargs)
@property
def hits(self) -> int:
return sum([setup.hits for setup in self.child_setups])
@property
def misses(self) -> int:
return sum([setup.misses for setup in self.child_setups])
@property
def total_size(self) -> int:
return sum([setup.total_size for setup in self.child_setups])
@property
def total_elapsed(self) -> tp.Optional[timedelta]:
total_elapsed = None
for setup in self.child_setups:
elapsed = setup.total_elapsed
if elapsed is not None:
if total_elapsed is None:
total_elapsed = elapsed
else:
total_elapsed += elapsed
return total_elapsed
@property
def total_saved(self) -> tp.Optional[timedelta]:
total_saved = None
for setup in self.child_setups:
saved = setup.total_saved
if saved is not None:
if total_saved is None:
total_saved = saved
else:
total_saved += saved
return total_saved
@property
def first_run_time(self) -> tp.Optional[datetime]:
first_run_times = []
for setup in self.child_setups:
first_run_time = setup.first_run_time
if first_run_time is not None:
first_run_times.append(first_run_time)
if len(first_run_times) == 0:
return None
return list(sorted(first_run_times))[0]
@property
def last_run_time(self) -> tp.Optional[datetime]:
last_run_times = []
for setup in self.child_setups:
last_run_time = setup.last_run_time
if last_run_time is not None:
last_run_times.append(last_run_time)
if len(last_run_times) == 0:
return None
return list(sorted(last_run_times))[-1]
@property
def first_hit_time(self) -> tp.Optional[datetime]:
first_hit_times = []
for setup in self.child_setups:
first_hit_time = setup.first_hit_time
if first_hit_time is not None:
first_hit_times.append(first_hit_time)
if len(first_hit_times) == 0:
return None
return list(sorted(first_hit_times))[0]
@property
def last_hit_time(self) -> tp.Optional[datetime]:
last_hit_times = []
for setup in self.child_setups:
last_hit_time = setup.last_hit_time
if last_hit_time is not None:
last_hit_times.append(last_hit_time)
if len(last_hit_times) == 0:
return None
return list(sorted(last_hit_times))[-1]
def get_stats(
self,
readable: bool = True,
short_str: bool = False,
index_by_hash: bool = False,
filter_func: tp.Optional[tp.Callable] = None,
include: tp.Optional[tp.MaybeSequence[str]] = None,
exclude: tp.Optional[tp.MaybeSequence[str]] = None,
) -> tp.Optional[tp.Frame]:
"""Get a DataFrame out of stats dicts of child setups."""
if len(self.child_setups) == 0:
return None
df = pd.DataFrame(
[
setup.get_stats(readable=readable, short_str=short_str)
for setup in self.child_setups
if filter_func is None or filter_func(setup)
]
)
if index_by_hash:
df.set_index("hash", inplace=True)
df.index.name = "hash"
else:
df.set_index("string", inplace=True)
df.index.name = "object"
if include is not None:
if isinstance(include, str):
include = [include]
columns = include
else:
columns = df.columns
if exclude is not None:
if isinstance(exclude, str):
exclude = [exclude]
columns = [c for c in columns if c not in exclude]
if len(columns) == 0:
return None
return df[columns].sort_index()
class CABaseDelegatorSetup(CABaseSetup, CASetupDelegatorMixin):
"""Base class acting as a stateful setup that delegates cache management to child setups.
First delegates the work and only then changes its own state."""
@property
def child_setups(self) -> tp.Set[CABaseSetup]:
"""Get child setups that match `CABaseDelegatorSetup.query`."""
return self.registry.match_setups(self.query, collapse=True)
def deregister(self, **kwargs) -> None:
CASetupDelegatorMixin.deregister(self, **kwargs)
CABaseSetup.deregister(self)
def enable_whitelist(self, **kwargs) -> None:
CASetupDelegatorMixin.enable_whitelist(self, **kwargs)
CABaseSetup.enable_whitelist(self)
def disable_whitelist(self, **kwargs) -> None:
CASetupDelegatorMixin.disable_whitelist(self, **kwargs)
CABaseSetup.disable_whitelist(self)
def enable_caching(self, force: bool = False, silence_warnings: tp.Optional[bool] = None, **kwargs) -> None:
CASetupDelegatorMixin.enable_caching(self, force=force, silence_warnings=silence_warnings, **kwargs)
CABaseSetup.enable_caching(self, force=force, silence_warnings=silence_warnings)
def disable_caching(self, clear_cache: bool = True, **kwargs) -> None:
CASetupDelegatorMixin.disable_caching(self, clear_cache=clear_cache, **kwargs)
CABaseSetup.disable_caching(self, clear_cache=False)
def clear_cache(self, **kwargs) -> None:
CASetupDelegatorMixin.clear_cache(self, **kwargs)
def _assert_value_not_none(instance: object, attribute: attr.Attribute, value: tp.Any) -> None:
"""Assert that value is not None."""
if value is None:
raise ValueError("Please provide {}".format(attribute.name))
CAClassSetupT = tp.TypeVar("CAClassSetupT", bound="CAClassSetup")
@define
class CAClassSetup(CABaseDelegatorSetup, DefineMixin):
"""Class that represents a setup of a cacheable class.
The provided class must subclass `vectorbtpro.utils.caching.Cacheable`.
Delegates cache management to its child subclass setups of type `CAClassSetup` and
child instance setups of type `CAInstanceSetup`.
If `use_cache` or `whitelist` is None, inherits a non-empty value from its superclass setups
using the method resolution order (MRO).
!!! note
Unbound setups are not children of class setups. See notes on `CAUnboundSetup`."""
cls: tp.Type[Cacheable] = define.field(default=None, validator=_assert_value_not_none)
"""Cacheable class."""
@staticmethod
def get_hash(cls: tp.Type[Cacheable]) -> int:
return hash((cls,))
@staticmethod
def get_cacheable_superclasses(cls: tp.Type[Cacheable]) -> tp.List[tp.Type[Cacheable]]:
"""Get an ordered list of the cacheable superclasses of a class."""
superclasses = []
for super_cls in inspect.getmro(cls):
if issubclass(super_cls, Cacheable):
if super_cls is not cls:
superclasses.append(super_cls)
return superclasses
@staticmethod
def get_superclass_setups(registry: CacheableRegistry, cls: tp.Type[Cacheable]) -> tp.List["CAClassSetup"]:
"""Setups of type `CAClassSetup` of each class in `CAClassSetup.get_cacheable_superclasses`."""
setups = []
for super_cls in CAClassSetup.get_cacheable_superclasses(cls):
if registry.get_class_setup(super_cls) is not None:
setups.append(super_cls.get_ca_setup())
return setups
@staticmethod
def get_cacheable_subclasses(cls: tp.Type[Cacheable]) -> tp.List[tp.Type[Cacheable]]:
"""Get an ordered list of the cacheable subclasses of a class."""
subclasses = []
for sub_cls in cls.__subclasses__():
if issubclass(sub_cls, Cacheable):
if sub_cls is not cls:
subclasses.append(sub_cls)
subclasses.extend(CAClassSetup.get_cacheable_subclasses(sub_cls))
return subclasses
@staticmethod
def get_subclass_setups(registry: CacheableRegistry, cls: tp.Type[Cacheable]) -> tp.List["CAClassSetup"]:
"""Setups of type `CAClassSetup` of each class in `CAClassSetup.get_cacheable_subclasses`."""
setups = []
for sub_cls in CAClassSetup.get_cacheable_subclasses(cls):
if registry.get_class_setup(sub_cls) is not None:
setups.append(sub_cls.get_ca_setup())
return setups
@staticmethod
def get_unbound_cacheables(cls: tp.Type[Cacheable]) -> tp.Set[cacheableT]:
"""Get a set of the unbound cacheables of a class."""
members = inspect.getmembers(cls, is_bindable_cacheable)
return {attr for attr_name, attr in members}
@staticmethod
def get_unbound_setups(registry: CacheableRegistry, cls: tp.Type[Cacheable]) -> tp.Set["CAUnboundSetup"]:
"""Setups of type `CAUnboundSetup` of each cacheable in `CAClassSetup.get_unbound_cacheables`."""
setups = set()
for cacheable in CAClassSetup.get_unbound_cacheables(cls):
if registry.get_unbound_setup(cacheable) is not None:
setups.add(cacheable.get_ca_setup())
return setups
@classmethod
def get(
cls: tp.Type[CAClassSetupT],
cls_: tp.Type[Cacheable],
registry: CacheableRegistry = ca_reg,
**kwargs,
) -> tp.Optional[CAClassSetupT]:
"""Get setup from `CacheableRegistry` or register a new one.
`**kwargs` are passed to `CAClassSetup.__init__`."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if caching_cfg["disable_machinery"]:
return None
setup = registry.get_class_setup(cls_)
if setup is not None:
if not setup.active:
return None
return setup
instance = cls(cls=cls_, registry=registry, **kwargs)
instance.enforce_rules()
if instance.active:
instance.register()
return instance
def __attrs_post_init__(self) -> None:
CABaseSetup.__attrs_post_init__(self)
checks.assert_subclass_of(self.cls, Cacheable)
use_cache = self.use_cache
whitelist = self.whitelist
if use_cache is None or whitelist is None:
superclass_setups = self.superclass_setups[::-1]
for setup in superclass_setups:
if use_cache is None:
if setup.use_cache is not None:
object.__setattr__(self, "use_cache", setup.use_cache)
if whitelist is None:
if setup.whitelist is not None:
object.__setattr__(self, "whitelist", setup.whitelist)
@property
def query(self) -> CAQuery:
return CAQuery(base_cls=self.cls)
@property
def superclass_setups(self) -> tp.List["CAClassSetup"]:
"""See `CAClassSetup.get_superclass_setups`."""
return self.get_superclass_setups(self.registry, self.cls)
@property
def subclass_setups(self) -> tp.List["CAClassSetup"]:
"""See `CAClassSetup.get_subclass_setups`."""
return self.get_subclass_setups(self.registry, self.cls)
@property
def unbound_setups(self) -> tp.Set["CAUnboundSetup"]:
"""See `CAClassSetup.get_unbound_setups`."""
return self.get_unbound_setups(self.registry, self.cls)
@property
def instance_setups(self) -> tp.Set["CAInstanceSetup"]:
"""Setups of type `CAInstanceSetup` of instances of the class."""
matches = set()
for instance_setup in self.registry.instance_setups.values():
if instance_setup.class_setup is self:
matches.add(instance_setup)
return matches
@property
def any_use_cache_lut(self) -> tp.Optional[datetime]:
"""Last time `CABaseSetup.use_cache` was updated in this class or any of its superclasses."""
max_use_cache_lut = self.use_cache_lut
for setup in self.superclass_setups:
if setup.use_cache_lut is not None:
if max_use_cache_lut is None or setup.use_cache_lut > max_use_cache_lut:
max_use_cache_lut = setup.use_cache_lut
return max_use_cache_lut
@property
def any_whitelist_lut(self) -> tp.Optional[datetime]:
"""Last time `CABaseSetup.whitelist` was updated in this class or any of its superclasses."""
max_whitelist_lut = self.whitelist_lut
for setup in self.superclass_setups:
if setup.whitelist_lut is not None:
if max_whitelist_lut is None or setup.whitelist_lut > max_whitelist_lut:
max_whitelist_lut = setup.whitelist_lut
return max_whitelist_lut
@property
def child_setups(self) -> tp.Set[tp.Union["CAClassSetup", "CAInstanceSetup"]]:
return set(self.subclass_setups) | self.instance_setups
@property
def same_type_setups(self) -> ValuesView:
return self.registry.class_setups.values()
@property
def short_str(self) -> str:
return f"<class {self.cls.__module__}.{self.cls.__name__}>"
@property
def readable_name(self) -> str:
return self.cls.__name__
@property
def hash_key(self) -> tuple:
return (self.cls,)
CAInstanceSetupT = tp.TypeVar("CAInstanceSetupT", bound="CAInstanceSetup")
@define
class CAInstanceSetup(CABaseDelegatorSetup, DefineMixin):
"""Class that represents a setup of an instance that has cacheables bound to it.
The provided instance must be of `vectorbtpro.utils.caching.Cacheable`.
Delegates cache management to its child setups of type `CARunSetup`.
If `use_cache` or `whitelist` is None, inherits a non-empty value from its parent class setup."""
instance: tp.Union[Cacheable, ReferenceType] = define.field(default=None, validator=_assert_value_not_none)
"""Cacheable instance."""
@staticmethod
def get_hash(instance: Cacheable) -> int:
return hash((get_obj_id(instance),))
@classmethod
def get(
cls: tp.Type[CAInstanceSetupT],
instance: Cacheable,
registry: CacheableRegistry = ca_reg,
**kwargs,
) -> tp.Optional[CAInstanceSetupT]:
"""Get setup from `CacheableRegistry` or register a new one.
`**kwargs` are passed to `CAInstanceSetup.__init__`."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if caching_cfg["disable_machinery"]:
return None
setup = registry.get_instance_setup(instance)
if setup is not None:
if not setup.active:
return None
return setup
instance = cls(instance=instance, registry=registry, **kwargs)
instance.enforce_rules()
if instance.active:
instance.register()
return instance
def __attrs_post_init__(self) -> None:
CABaseSetup.__attrs_post_init__(self)
if not isinstance(self.instance, ReferenceType):
checks.assert_instance_of(self.instance, Cacheable)
instance_ref = ref(self.instance, lambda ref: self.registry.deregister_setup(self))
object.__setattr__(self, "instance", instance_ref)
if self.use_cache is None or self.whitelist is None:
class_setup = self.class_setup
if self.use_cache is None:
if class_setup.use_cache is not None:
object.__setattr__(self, "use_cache", class_setup.use_cache)
if self.whitelist is None:
if class_setup.whitelist is not None:
object.__setattr__(self, "whitelist", class_setup.whitelist)
@property
def query(self) -> CAQuery:
return CAQuery(instance=self.instance_obj)
@property
def instance_obj(self) -> tp.Union[Cacheable, object]:
"""Instance object."""
if self.instance() is None:
return _GARBAGE
return self.instance()
@property
def contains_garbage(self) -> bool:
"""Whether instance was destroyed."""
return self.instance_obj is _GARBAGE
@property
def class_setup(self) -> tp.Optional[CAClassSetup]:
"""Setup of type `CAClassSetup` of the cacheable class of the instance."""
if self.contains_garbage:
return None
return CAClassSetup.get(type(self.instance_obj), self.registry)
@property
def unbound_setups(self) -> tp.Set["CAUnboundSetup"]:
"""Setups of type `CAUnboundSetup` of unbound cacheables declared in the class of the instance."""
if self.contains_garbage:
return set()
return self.class_setup.unbound_setups
@property
def run_setups(self) -> tp.Set["CARunSetup"]:
"""Setups of type `CARunSetup` of cacheables bound to the instance."""
if self.contains_garbage:
return set()
matches = set()
for run_setup in self.registry.run_setups.values():
if run_setup.instance_setup is self:
matches.add(run_setup)
return matches
@property
def child_setups(self) -> tp.Set["CARunSetup"]:
return self.run_setups
@property
def same_type_setups(self) -> ValuesView:
return self.registry.instance_setups.values()
@property
def short_str(self) -> str:
if self.contains_garbage:
return "<destroyed instance>"
return f"<instance of {type(self.instance_obj).__module__}.{type(self.instance_obj).__name__}>"
@property
def readable_name(self) -> str:
if self.contains_garbage:
return "_GARBAGE"
return type(self.instance_obj).__name__.lower()
@property
def hash_key(self) -> tuple:
return (get_obj_id(self.instance_obj),)
CAUnboundSetupT = tp.TypeVar("CAUnboundSetupT", bound="CAUnboundSetup")
@define
class CAUnboundSetup(CABaseDelegatorSetup, DefineMixin):
"""Class that represents a setup of an unbound cacheable property or method.
An unbound callable is a callable that was declared in a class but is not bound
to any instance (just yet).
!!! note
Unbound callables are just regular functions - they have no parent setups. Even though they
are formally declared in a class, there is no easy way to get a reference to the class
from the decorator itself. Thus, searching for child setups of a specific class won't return
unbound setups.
Delegates cache management to its child setups of type `CARunSetup`.
One unbound cacheable property or method can be bound to multiple instances, thus there is
one-to-many relationship between `CAUnboundSetup` and `CARunSetup` instances.
!!! hint
Use class attributes instead of instance attributes to access unbound callables."""
cacheable: cacheableT = define.field(default=None, validator=_assert_value_not_none)
"""Cacheable object."""
@staticmethod
def get_hash(cacheable: cacheableT) -> int:
return hash((cacheable,))
@classmethod
def get(
cls: tp.Type[CAUnboundSetupT],
cacheable: cacheableT,
registry: CacheableRegistry = ca_reg,
**kwargs,
) -> tp.Optional[CAUnboundSetupT]:
"""Get setup from `CacheableRegistry` or register a new one.
`**kwargs` are passed to `CAUnboundSetup.__init__`."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if caching_cfg["disable_machinery"]:
return None
setup = registry.get_unbound_setup(cacheable)
if setup is not None:
if not setup.active:
return None
return setup
instance = cls(cacheable=cacheable, registry=registry, **kwargs)
instance.enforce_rules()
if instance.active:
instance.register()
return instance
def __attrs_post_init__(self) -> None:
CABaseSetup.__attrs_post_init__(self)
if not is_bindable_cacheable(self.cacheable):
raise TypeError("cacheable must be either cacheable_property or cacheable_method")
@property
def query(self) -> CAQuery:
return CAQuery(cacheable=self.cacheable)
@property
def run_setups(self) -> tp.Set["CARunSetup"]:
"""Setups of type `CARunSetup` of bound cacheables."""
matches = set()
for run_setup in self.registry.run_setups.values():
if run_setup.unbound_setup is self:
matches.add(run_setup)
return matches
@property
def child_setups(self) -> tp.Set["CARunSetup"]:
return self.run_setups
@property
def same_type_setups(self) -> ValuesView:
return self.registry.unbound_setups.values()
@property
def short_str(self) -> str:
if is_cacheable_property(self.cacheable):
return f"<unbound property {self.cacheable.func.__module__}.{self.cacheable.func.__name__}>"
return f"<unbound method {self.cacheable.func.__module__}.{self.cacheable.func.__name__}>"
@property
def readable_name(self) -> str:
if is_cacheable_property(self.cacheable):
return f"{self.cacheable.func.__name__}"
return f"{self.cacheable.func.__name__}()"
@property
def hash_key(self) -> tuple:
return (self.cacheable,)
CARunSetupT = tp.TypeVar("CARunSetupT", bound="CARunSetup")
@define
class CARunResult(DefineMixin):
"""Class that represents a cached result of a run.
!!! note
Hashed solely by the hash of the arguments `args_hash`."""
args_hash: int = define.field()
"""Hash of the arguments."""
result: tp.Any = define.field()
"""Result of the run."""
timer: Timer = define.field()
"""Timer used to measure the execution time."""
def __attrs_post_init__(self) -> None:
object.__setattr__(self, "_run_time", datetime.now(timezone.utc))
object.__setattr__(self, "_hits", 0)
object.__setattr__(self, "_first_hit_time", None)
object.__setattr__(self, "_last_hit_time", None)
@staticmethod
def get_hash(args_hash: int) -> int:
return hash((args_hash,))
@property
def result_size(self) -> int:
"""Get size of the result in memory."""
return sys.getsizeof(self.result)
@property
def run_time(self) -> datetime:
"""Time of the run."""
return object.__getattribute__(self, "_run_time")
@property
def hits(self) -> int:
"""Number of hits."""
return object.__getattribute__(self, "_hits")
@property
def first_hit_time(self) -> tp.Optional[datetime]:
"""Time of the first hit."""
return object.__getattribute__(self, "_first_hit_time")
@property
def last_hit_time(self) -> tp.Optional[datetime]:
"""Time of the last hit."""
return object.__getattribute__(self, "_last_hit_time")
def hit(self) -> tp.Any:
"""Hit the result."""
hit_time = datetime.now(timezone.utc)
if self.first_hit_time is None:
object.__setattr__(self, "_first_hit_time", hit_time)
object.__setattr__(self, "_last_hit_time", hit_time)
object.__setattr__(self, "_hits", self.hits + 1)
return self.result
@property
def hash_key(self) -> tuple:
return (self.args_hash,)
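# A small sketch of the bookkeeping performed by `CARunResult`: one instance is created per
# unique arguments hash, and every cache hit bumps the hit counter and timestamps. The import
# path of `Timer` is an assumption; this helper exists for illustration only.
def _example_run_result_bookkeeping():
    from vectorbtpro.utils.profiling import Timer  # assumed import path

    with Timer() as timer:
        result = sum(range(1_000))  # stands in for an expensive call
    run_result = CARunResult(args_hash=hash(None), result=result, timer=timer)
    assert run_result.hits == 0
    assert run_result.hit() == result  # returns the cached result...
    assert run_result.hits == 1  # ...and records the first/last hit time
    return run_result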
@define
class CARunSetup(CABaseSetup, DefineMixin):
"""Class that represents a runnable cacheable setup.
Takes care of running functions and caching the results using `CARunSetup.run`.
Accepts as `cacheable` either `vectorbtpro.utils.decorators.cacheable_property`,
`vectorbtpro.utils.decorators.cacheable_method`, or `vectorbtpro.utils.decorators.cacheable`.
Hashed by the callable and optionally the id of the instance it's bound to.
This way, it can be uniquely identified among all setups.
!!! note
Cacheable properties and methods must provide an instance.
Only one instance can exist at a time for each unique combination of `cacheable` and `instance`.
If `use_cache` or `whitelist` is None, inherits a non-empty value either from its parent instance setup
or its parent unbound setup. If both setups have non-empty values, takes the one that has been
updated more recently.
!!! note
Use `CARunSetup.get` class method instead of `CARunSetup.__init__` to create a setup. The class method
first checks whether a setup with the same hash has already been registered, and if so, returns it.
Otherwise, creates and registers a new one. Using `CARunSetup.__init__` will throw an error if there
is a setup with the same hash."""
cacheable: cacheableT = define.field(default=None, validator=_assert_value_not_none)
"""Cacheable object."""
instance: tp.Union[Cacheable, ReferenceType] = define.field(default=None)
"""Cacheable instance."""
max_size: tp.Optional[int] = define.field(default=None)
"""Maximum number of entries in `CARunSetup.cache`."""
ignore_args: tp.Optional[tp.Iterable[tp.AnnArgQuery]] = define.field(default=None)
"""Arguments to ignore when hashing."""
cache: tp.Dict[int, CARunResult] = define.field(factory=dict)
"""Dict of cached `CARunResult` instances by their hash."""
@staticmethod
def get_hash(cacheable: cacheableT, instance: tp.Optional[Cacheable] = None) -> int:
return hash((cacheable, get_obj_id(instance) if instance is not None else None))
@classmethod
def get(
cls: tp.Type[CARunSetupT],
cacheable: cacheableT,
instance: tp.Optional[Cacheable] = None,
registry: CacheableRegistry = ca_reg,
**kwargs,
) -> tp.Optional[CARunSetupT]:
"""Get setup from `CacheableRegistry` or register a new one.
`**kwargs` are passed to `CARunSetup.__init__`."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if caching_cfg["disable_machinery"]:
return None
setup = registry.get_run_setup(cacheable, instance=instance)
if setup is not None:
if not setup.active:
return None
return setup
instance = cls(cacheable=cacheable, instance=instance, registry=registry, **kwargs)
instance.enforce_rules()
if instance.active:
instance.register()
return instance
def __attrs_post_init__(self) -> None:
CABaseSetup.__attrs_post_init__(self)
if not is_cacheable(self.cacheable):
raise TypeError("cacheable must be either cacheable_property, cacheable_method, or cacheable")
if self.instance is None:
if is_cacheable_property(self.cacheable):
raise ValueError("CARunSetup requires an instance for cacheable_property")
elif is_cacheable_method(self.cacheable):
raise ValueError("CARunSetup requires an instance for cacheable_method")
else:
checks.assert_instance_of(self.instance, Cacheable)
if is_cacheable_function(self.cacheable):
raise ValueError("Cacheable functions can't have an instance")
if self.instance is not None and not isinstance(self.instance, ReferenceType):
checks.assert_instance_of(self.instance, Cacheable)
instance_ref = ref(self.instance, lambda ref: self.registry.deregister_setup(self))
object.__setattr__(self, "instance", instance_ref)
if self.use_cache is None or self.whitelist is None:
instance_setup = self.instance_setup
unbound_setup = self.unbound_setup
if self.use_cache is None:
if (
instance_setup is not None
and unbound_setup is not None
and instance_setup.use_cache is not None
and unbound_setup.use_cache is not None
):
if unbound_setup.use_cache_lut is not None and (
instance_setup.class_setup.any_use_cache_lut is None
or unbound_setup.use_cache_lut > instance_setup.class_setup.any_use_cache_lut
):
# Unbound setup was updated more recently than any superclass setup
object.__setattr__(self, "use_cache", unbound_setup.use_cache)
else:
object.__setattr__(self, "use_cache", instance_setup.use_cache)
elif instance_setup is not None and instance_setup.use_cache is not None:
object.__setattr__(self, "use_cache", instance_setup.use_cache)
elif unbound_setup is not None and unbound_setup.use_cache is not None:
object.__setattr__(self, "use_cache", unbound_setup.use_cache)
if self.whitelist is None:
if (
instance_setup is not None
and unbound_setup is not None
and instance_setup.whitelist is not None
and unbound_setup.whitelist is not None
):
if unbound_setup.whitelist_lut is not None and (
instance_setup.class_setup.any_whitelist_lut is None
or unbound_setup.whitelist_lut > instance_setup.class_setup.any_whitelist_lut
):
# Unbound setup was updated more recently than any superclass setup
object.__setattr__(self, "whitelist", unbound_setup.whitelist)
else:
object.__setattr__(self, "whitelist", instance_setup.whitelist)
elif instance_setup is not None and instance_setup.whitelist is not None:
object.__setattr__(self, "whitelist", instance_setup.whitelist)
elif unbound_setup is not None and unbound_setup.whitelist is not None:
object.__setattr__(self, "whitelist", unbound_setup.whitelist)
@property
def query(self) -> CAQuery:
return CAQuery(cacheable=self.cacheable, instance=self.instance_obj)
@property
def instance_obj(self) -> tp.Union[Cacheable, object]:
"""Instance object."""
if self.instance is not None and self.instance() is None:
return _GARBAGE
return self.instance() if self.instance is not None else None
@property
def contains_garbage(self) -> bool:
"""Whether instance was destroyed."""
return self.instance_obj is _GARBAGE
@property
def instance_setup(self) -> tp.Optional[CAInstanceSetup]:
"""Setup of type `CAInstanceSetup` of the instance this cacheable is bound to."""
if self.instance_obj is None or self.contains_garbage:
return None
return CAInstanceSetup.get(self.instance_obj, self.registry)
@property
def unbound_setup(self) -> tp.Optional[CAUnboundSetup]:
"""Setup of type `CAUnboundSetup` of the unbound cacheable."""
return self.registry.get_unbound_setup(self.cacheable)
@property
def hits(self) -> int:
return sum([run_result.hits for run_result in self.cache.values()])
@property
def misses(self) -> int:
return len(self.cache)
@property
def total_size(self) -> int:
return sum([run_result.result_size for run_result in self.cache.values()])
@property
def total_elapsed(self) -> tp.Optional[timedelta]:
total_elapsed = None
for run_result in self.cache.values():
elapsed = run_result.timer.elapsed(readable=False)
if total_elapsed is None:
total_elapsed = elapsed
else:
total_elapsed += elapsed
return total_elapsed
@property
def total_saved(self) -> tp.Optional[timedelta]:
total_saved = None
for run_result in self.cache.values():
saved = run_result.timer.elapsed(readable=False) * run_result.hits
if total_saved is None:
total_saved = saved
else:
total_saved += saved
return total_saved
@property
def first_run_time(self) -> tp.Optional[datetime]:
if len(self.cache) == 0:
return None
return list(self.cache.values())[0].run_time
@property
def last_run_time(self) -> tp.Optional[datetime]:
if len(self.cache) == 0:
return None
return list(self.cache.values())[-1].run_time
@property
def first_hit_time(self) -> tp.Optional[datetime]:
first_hit_times = []
for run_result in self.cache.values():
if run_result.first_hit_time is not None:
first_hit_times.append(run_result.first_hit_time)
if len(first_hit_times) == 0:
return None
return list(sorted(first_hit_times))[0]
@property
def last_hit_time(self) -> tp.Optional[datetime]:
last_hit_times = []
for run_result in self.cache.values():
if run_result.last_hit_time is not None:
last_hit_times.append(run_result.last_hit_time)
if len(last_hit_times) == 0:
return None
return list(sorted(last_hit_times))[-1]
def run_func(self, *args, **kwargs) -> tp.Any:
"""Run the setup's function without caching."""
if self.instance_obj is not None:
return self.cacheable.func(self.instance_obj, *args, **kwargs)
return self.cacheable.func(*args, **kwargs)
def get_args_hash(self, *args, **kwargs) -> tp.Optional[int]:
"""Get the hash of the passed arguments.
`CARunSetup.ignore_args` gets extended with `ignore_args` under `vectorbtpro._settings.caching`.
If no arguments were passed, hashes None."""
if len(args) == 0 and len(kwargs) == 0:
return hash(None)
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
ignore_args = list(caching_cfg["ignore_args"])
if self.ignore_args is not None:
ignore_args.extend(list(self.ignore_args))
return hash_args(
self.cacheable.func,
args if self.instance_obj is None else (get_obj_id(self.instance_obj), *args),
kwargs,
ignore_args=ignore_args,
)
def run_func_and_cache(self, *args, **kwargs) -> tp.Any:
"""Run the setup's function and cache the result.
Hashes the arguments using `CARunSetup.get_args_hash`, runs the function using
`CARunSetup.run_func`, wraps the result using `CARunResult`, and uses the hash
as a key to store the instance of `CARunResult` into `CARunSetup.cache` for later retrieval."""
args_hash = self.get_args_hash(*args, **kwargs)
run_result_hash = CARunResult.get_hash(args_hash)
if run_result_hash in self.cache:
return self.cache[run_result_hash].hit()
if self.max_size is not None and self.max_size <= len(self.cache):
del self.cache[list(self.cache.keys())[0]]
with Timer() as timer:
result = self.run_func(*args, **kwargs)
run_result = CARunResult(args_hash, result, timer=timer)
self.cache[run_result_hash] = run_result
return result
def run(self, *args, **kwargs) -> tp.Any:
"""Run the setup and cache it depending on a range of conditions.
Runs `CARunSetup.run_func` if caching is disabled or arguments are not hashable,
and `CARunSetup.run_func_and_cache` otherwise."""
if self.caching_enabled:
try:
return self.run_func_and_cache(*args, **kwargs)
except UnhashableArgsError:
pass
return self.run_func(*args, **kwargs)
def clear_cache(self) -> None:
"""Clear the cache."""
self.cache.clear()
@property
def same_type_setups(self) -> ValuesView:
return self.registry.run_setups.values()
@property
def short_str(self) -> str:
if self.contains_garbage:
return "<destroyed object>"
if is_cacheable_property(self.cacheable):
return (
f"<instance property {type(self.instance_obj).__module__}.{type(self.instance_obj).__name__}.{self.cacheable.func.__name__}>"
)
if is_cacheable_method(self.cacheable):
return (
f"<instance method {type(self.instance_obj).__module__}.{type(self.instance_obj).__name__}.{self.cacheable.func.__name__}>"
)
return f"<func {self.cacheable.__module__}.{self.cacheable.__name__}>"
@property
def readable_name(self) -> str:
if self.contains_garbage:
return "_GARBAGE"
if is_cacheable_property(self.cacheable):
return f"{type(self.instance_obj).__name__.lower()}.{self.cacheable.func.__name__}"
if is_cacheable_method(self.cacheable):
return f"{type(self.instance_obj).__name__.lower()}.{self.cacheable.func.__name__}()"
return f"{self.cacheable.__name__}()"
@property
def readable_str(self) -> str:
if self.contains_garbage:
return f"_GARBAGE:{self.position_among_similar}"
if is_cacheable_property(self.cacheable):
return (
f"{type(self.instance_obj).__name__.lower()}:"
f"{self.instance_setup.position_among_similar}."
f"{self.cacheable.func.__name__}"
)
if is_cacheable_method(self.cacheable):
return (
f"{type(self.instance_obj).__name__.lower()}:"
f"{self.instance_setup.position_among_similar}."
f"{self.cacheable.func.__name__}()"
)
return f"{self.cacheable.__name__}():{self.position_among_similar}"
@property
def hash_key(self) -> tuple:
return self.cacheable, get_obj_id(self.instance_obj) if self.instance_obj is not None else None
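# A minimal sketch of the life cycle described in the `CARunSetup` docstring: obtain (or
# register) the setup for a bound cacheable, run it twice, and observe one miss followed by
# one hit. The class below and the `use_cache=True` keyword are illustrative assumptions;
# `get` returns None if the caching machinery is disabled globally.
def _example_run_setup_caching():
    from vectorbtpro.utils.caching import Cacheable  # assumed import path
    from vectorbtpro.utils.decorators import cacheable_method

    class MyObject(Cacheable):
        @cacheable_method
        def heavy(self, x):
            return x * 2

    obj = MyObject()
    setup = CARunSetup.get(MyObject.heavy, instance=obj, use_cache=True)
    if setup is not None:
        setup.run(21)  # executes `heavy` and caches the result -> one miss
        setup.run(21)  # returns the cached result -> one hit
        return setup.misses, setup.hits  # (1, 1)
    return None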
class CAQueryDelegator(CASetupDelegatorMixin):
"""Class that delegates any setups that match a query.
`*args`, `collapse`, and `**kwargs` are passed to `CacheableRegistry.match_setups`."""
def __init__(self, *args, registry: CacheableRegistry = ca_reg, collapse: bool = True, **kwargs) -> None:
self._args = args
kwargs["collapse"] = collapse
self._kwargs = kwargs
self._registry = registry
@property
def args(self) -> tp.Args:
"""Arguments."""
return self._args
@property
def kwargs(self) -> tp.Kwargs:
"""Keyword arguments."""
return self._kwargs
@property
def registry(self) -> CacheableRegistry:
"""Registry of type `CacheableRegistry`."""
return self._registry
@property
def child_setups(self) -> tp.Set[CABaseSetup]:
"""Get child setups by matching them using `CacheableRegistry.match_setups`."""
return self.registry.match_setups(*self.args, **self.kwargs)
def get_cache_stats(*args, **kwargs) -> tp.Optional[tp.Frame]:
"""Get cache stats globally or of an object."""
delegator_kwargs = {}
stats_kwargs = {}
if len(kwargs) > 0:
overview_arg_names = get_func_arg_names(CAQueryDelegator.get_stats)
for k in list(kwargs.keys()):
if k in overview_arg_names:
stats_kwargs[k] = kwargs.pop(k)
else:
delegator_kwargs[k] = kwargs.pop(k)
else:
delegator_kwargs = kwargs
return CAQueryDelegator(*args, **delegator_kwargs).get_stats(**stats_kwargs)
def print_cache_stats(*args, **kwargs) -> None:
"""Print cache stats globally or of an object."""
ptable(get_cache_stats(*args, **kwargs))
def clear_cache(*args, **kwargs) -> None:
"""Clear cache globally or of an object."""
return CAQueryDelegator(*args, **kwargs).clear_cache()
def collect_garbage() -> None:
"""Collect garbage."""
import gc
gc.collect()
def flush() -> None:
"""Clear cache and collect garbage."""
clear_cache()
collect_garbage()
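# A short sketch of how the module-level helpers above are typically combined. The `obj`
# argument stands for any cacheable vectorbtpro object and is a placeholder for illustration.
def _example_cache_helpers(obj):
    global_stats = get_cache_stats()  # one row per setup, or None if nothing is cached
    print_cache_stats(obj)  # stats restricted to setups that belong to `obj`
    clear_cache(obj)  # drop cached results of `obj` only
    flush()  # clear all caches and collect garbage
    return global_stats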
def disable_caching(clear_cache: bool = True) -> None:
"""Disable caching globally."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
caching_cfg["disable"] = True
caching_cfg["disable_whitelist"] = True
caching_cfg["disable_machinery"] = True
if clear_cache:
CAQueryDelegator().clear_cache()
def enable_caching() -> None:
"""Enable caching globally."""
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
caching_cfg["disable"] = False
caching_cfg["disable_whitelist"] = False
caching_cfg["disable_machinery"] = False
class CachingDisabled(Base):
"""Context manager to disable caching."""
def __init__(
self,
query_like: tp.Optional[tp.Any] = None,
use_base_cls: bool = True,
kind: tp.Optional[tp.MaybeIterable[str]] = None,
exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
filter_func: tp.Optional[tp.Callable] = None,
registry: CacheableRegistry = ca_reg,
disable_whitelist: bool = True,
disable_machinery: bool = True,
clear_cache: bool = True,
silence_warnings: bool = False,
) -> None:
self._query_like = query_like
self._use_base_cls = use_base_cls
self._kind = kind
self._exclude = exclude
self._filter_func = filter_func
self._registry = registry
self._disable_whitelist = disable_whitelist
self._disable_machinery = disable_machinery
self._clear_cache = clear_cache
self._silence_warnings = silence_warnings
self._rule = None
self._init_settings = None
self._init_setup_settings = None
@property
def query_like(self) -> tp.Optional[tp.Any]:
"""See `CAQuery.parse`."""
return self._query_like
@property
def use_base_cls(self) -> bool:
"""See `CAQuery.parse`."""
return self._use_base_cls
@property
def kind(self) -> tp.Optional[tp.MaybeIterable[str]]:
"""See `CARule.kind`."""
return self._kind
@property
def exclude(self) -> tp.Optional[tp.MaybeIterable["CABaseSetup"]]:
"""See `CARule.exclude`."""
return self._exclude
@property
def filter_func(self) -> tp.Optional[tp.Callable]:
"""See `CARule.filter_func`."""
return self._filter_func
@property
def registry(self) -> CacheableRegistry:
"""Registry of type `CacheableRegistry`."""
return self._registry
@property
def disable_whitelist(self) -> bool:
"""Whether to disable whitelist."""
return self._disable_whitelist
@property
def disable_machinery(self) -> bool:
"""Whether to disable machinery."""
return self._disable_machinery
@property
def clear_cache(self) -> bool:
"""Whether to clear global cache when entering or local cache when disabling caching."""
return self._clear_cache
@property
def silence_warnings(self) -> bool:
"""Whether to silence warnings."""
return self._silence_warnings
@property
def rule(self) -> tp.Optional[CARule]:
"""Rule."""
return self._rule
@property
def init_settings(self) -> tp.Kwargs:
"""Initial caching settings."""
return self._init_settings
@property
def init_setup_settings(self) -> tp.Dict[int, dict]:
"""Initial setup settings."""
return self._init_setup_settings
def __enter__(self) -> tp.Self:
if self.query_like is None:
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
self._init_settings = dict(
disable=caching_cfg["disable"],
disable_whitelist=caching_cfg["disable_whitelist"],
disable_machinery=caching_cfg["disable_machinery"],
)
caching_cfg["disable"] = True
caching_cfg["disable_whitelist"] = self.disable_whitelist
caching_cfg["disable_machinery"] = self.disable_machinery
if self.clear_cache:
clear_cache()
else:
def _enforce_func(setup):
if self.disable_machinery:
setup.deactivate()
if self.disable_whitelist:
setup.disable_whitelist()
setup.disable_caching(clear_cache=self.clear_cache)
query = CAQuery.parse(self.query_like, use_base_cls=self.use_base_cls)
rule = CARule(
query,
_enforce_func,
kind=self.kind,
exclude=self.exclude,
filter_func=self.filter_func,
)
self._rule = rule
self.registry.register_rule(rule)
init_setup_settings = dict()
for setup_hash, setup in self.registry.setups.items():
init_setup_settings[setup_hash] = dict(
active=setup.active,
whitelist=setup.whitelist,
use_cache=setup.use_cache,
)
rule.enforce(setup)
self._init_setup_settings = init_setup_settings
return self
def __exit__(self, *args) -> None:
if self.query_like is None:
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
caching_cfg["disable"] = self.init_settings["disable"]
caching_cfg["disable_whitelist"] = self.init_settings["disable_whitelist"]
caching_cfg["disable_machinery"] = self.init_settings["disable_machinery"]
else:
self.registry.deregister_rule(self.rule)
for setup_hash, setup_settings in self.init_setup_settings.items():
if setup_hash in self.registry.setups:
setup = self.registry.setups[setup_hash]
if self.disable_machinery and setup_settings["active"]:
setup.activate()
if self.disable_whitelist and setup_settings["whitelist"]:
setup.enable_whitelist()
if setup_settings["use_cache"]:
setup.enable_caching(silence_warnings=self.silence_warnings)
def with_caching_disabled(*args, **caching_disabled_kwargs) -> tp.Callable:
"""Decorator to run a function with `CachingDisabled`."""
def decorator(func: tp.Callable) -> tp.Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> tp.Any:
with CachingDisabled(**caching_disabled_kwargs):
return func(*args, **kwargs)
return wrapper
if len(args) == 0:
return decorator
elif len(args) == 1:
return decorator(args[0])
raise ValueError("Either function or keyword arguments must be passed")
class CachingEnabled(Base):
"""Context manager to enable caching."""
def __init__(
self,
query_like: tp.Optional[tp.Any] = None,
use_base_cls: bool = True,
kind: tp.Optional[tp.MaybeIterable[str]] = None,
exclude: tp.Optional[tp.MaybeIterable["CABaseSetup"]] = None,
filter_func: tp.Optional[tp.Callable] = None,
registry: CacheableRegistry = ca_reg,
enable_whitelist: bool = True,
enable_machinery: bool = True,
clear_cache: bool = True,
silence_warnings: bool = False,
) -> None:
self._query_like = query_like
self._use_base_cls = use_base_cls
self._kind = kind
self._exclude = exclude
self._filter_func = filter_func
self._registry = registry
self._enable_whitelist = enable_whitelist
self._enable_machinery = enable_machinery
self._clear_cache = clear_cache
self._silence_warnings = silence_warnings
self._rule = None
self._init_settings = None
self._init_setup_settings = None
@property
def query_like(self) -> tp.Optional[tp.Any]:
"""See `CAQuery.parse`."""
return self._query_like
@property
def use_base_cls(self) -> bool:
"""See `CAQuery.parse`."""
return self._use_base_cls
@property
def kind(self) -> tp.Optional[tp.MaybeIterable[str]]:
"""See `CARule.kind`."""
return self._kind
@property
def exclude(self) -> tp.Optional[tp.MaybeIterable["CABaseSetup"]]:
"""See `CARule.exclude`."""
return self._exclude
@property
def filter_func(self) -> tp.Optional[tp.Callable]:
"""See `CARule.filter_func`."""
return self._filter_func
@property
def registry(self) -> CacheableRegistry:
"""Registry of type `CacheableRegistry`."""
return self._registry
@property
def enable_whitelist(self) -> bool:
"""Whether to enable whitelist."""
return self._enable_whitelist
@property
def enable_machinery(self) -> bool:
"""Whether to enable machinery."""
return self._enable_machinery
@property
def clear_cache(self) -> bool:
"""Whether to clear global cache when exiting or local cache when disabling caching."""
return self._clear_cache
@property
def silence_warnings(self) -> bool:
"""Whether to silence warnings."""
return self._silence_warnings
@property
def rule(self) -> tp.Optional[CARule]:
"""Rule."""
return self._rule
@property
def init_settings(self) -> tp.Kwargs:
"""Initial caching settings."""
return self._init_settings
@property
def init_setup_settings(self) -> tp.Dict[int, dict]:
"""Initial setup settings."""
return self._init_setup_settings
def __enter__(self) -> tp.Self:
if self.query_like is None:
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
self._init_settings = dict(
disable=caching_cfg["disable"],
disable_whitelist=caching_cfg["disable_whitelist"],
disable_machinery=caching_cfg["disable_machinery"],
)
caching_cfg["disable"] = False
caching_cfg["disable_whitelist"] = not self.enable_whitelist
caching_cfg["disable_machinery"] = not self.enable_machinery
else:
def _enforce_func(setup):
if self.enable_machinery:
setup.activate()
if self.enable_whitelist:
setup.enable_whitelist()
setup.enable_caching(silence_warnings=self.silence_warnings)
query = CAQuery.parse(self.query_like, use_base_cls=self.use_base_cls)
rule = CARule(
query,
_enforce_func,
kind=self.kind,
exclude=self.exclude,
filter_func=self.filter_func,
)
self._rule = rule
self.registry.register_rule(rule)
init_setup_settings = dict()
for setup_hash, setup in self.registry.setups.items():
init_setup_settings[setup_hash] = dict(
active=setup.active,
whitelist=setup.whitelist,
use_cache=setup.use_cache,
)
rule.enforce(setup)
self._init_setup_settings = init_setup_settings
return self
def __exit__(self, *args) -> None:
if self.query_like is None:
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
caching_cfg["disable"] = self.init_settings["disable"]
caching_cfg["disable_whitelist"] = self.init_settings["disable_whitelist"]
caching_cfg["disable_machinery"] = self.init_settings["disable_machinery"]
if self.clear_cache:
clear_cache()
else:
self.registry.deregister_rule(self.rule)
for setup_hash, setup_settings in self.init_setup_settings.items():
if setup_hash in self.registry.setups:
setup = self.registry.setups[setup_hash]
if self.enable_machinery and not setup_settings["active"]:
setup.deactivate()
if self.enable_whitelist and not setup_settings["whitelist"]:
setup.disable_whitelist()
if not setup_settings["use_cache"]:
setup.disable_caching(clear_cache=self.clear_cache)
def with_caching_enabled(*args, **caching_enabled_kwargs) -> tp.Callable:
"""Decorator to run a function with `CachingEnabled`."""
def decorator(func: tp.Callable) -> tp.Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> tp.Any:
with CachingEnabled(**caching_enabled_kwargs):
return func(*args, **kwargs)
return wrapper
if len(args) == 0:
return decorator
elif len(args) == 1:
return decorator(args[0])
raise ValueError("Either function or keyword arguments must be passed")
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Global registry for chunkables."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.chunking import chunked, resolve_chunked, resolve_chunked_option
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.template import RepEval
__all__ = [
"ChunkableRegistry",
"ch_reg",
"register_chunkable",
]
@define
class ChunkedSetup(DefineMixin):
"""Class that represents a chunkable setup.
!!! note
Hashed solely by `setup_id`."""
setup_id: tp.Hashable = define.field()
"""Setup id."""
func: tp.Callable = define.field()
"""Chunkable function."""
options: tp.DictLike = define.field(default=None)
"""Options dictionary."""
tags: tp.SetLike = define.field(default=None)
"""Set of tags."""
@staticmethod
def get_hash(setup_id: tp.Hashable) -> int:
return hash((setup_id,))
@property
def hash_key(self) -> tuple:
return (self.setup_id,)
class ChunkableRegistry(Base):
"""Class for registering chunkable functions."""
def __init__(self) -> None:
self._setups = {}
@property
def setups(self) -> tp.Dict[tp.Hashable, ChunkedSetup]:
"""Dict of registered `ChunkedSetup` instances by `ChunkedSetup.setup_id`."""
return self._setups
def register(
self,
func: tp.Callable,
setup_id: tp.Optional[tp.Hashable] = None,
options: tp.DictLike = None,
tags: tp.SetLike = None,
) -> None:
"""Register a new setup."""
if setup_id is None:
setup_id = func.__module__ + "." + func.__name__
setup = ChunkedSetup(setup_id=setup_id, func=func, options=options, tags=tags)
self.setups[setup_id] = setup
def match_setups(self, expression: tp.Optional[str] = None, context: tp.KwargsLike = None) -> tp.Set[ChunkedSetup]:
"""Match setups against an expression with each setup being a context."""
matched_setups = set()
for setup in self.setups.values():
if expression is None:
result = True
else:
result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context))
checks.assert_instance_of(result, bool)
if result:
matched_setups.add(setup)
return matched_setups
def get_setup(self, setup_id_or_func: tp.Union[tp.Hashable, tp.Callable]) -> tp.Optional[ChunkedSetup]:
"""Get setup by its id or function.
`setup_id_or_func` can be an identifier or a function.
If it's a function, will build the identifier using its module and name."""
if hasattr(setup_id_or_func, "py_func"):
nb_setup_id = setup_id_or_func.__module__ + "." + setup_id_or_func.__name__
if nb_setup_id in self.setups:
setup_id = nb_setup_id
else:
setup_id = setup_id_or_func.py_func.__module__ + "." + setup_id_or_func.py_func.__name__
elif callable(setup_id_or_func):
setup_id = setup_id_or_func.__module__ + "." + setup_id_or_func.__name__
else:
setup_id = setup_id_or_func
if setup_id not in self.setups:
return None
return self.setups[setup_id]
def decorate(
self,
setup_id_or_func: tp.Union[tp.Hashable, tp.Callable],
target_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> tp.Callable:
"""Decorate the setup's function using the `vectorbtpro.utils.chunking.chunked` decorator.
Finds setup using `ChunkableRegistry.get_setup`.
Merges setup's options with `options`.
Specify `target_func` to apply the found setup on another function."""
setup = self.get_setup(setup_id_or_func)
if setup is None:
raise KeyError(f"Setup for {setup_id_or_func} not registered")
if target_func is not None:
func = target_func
elif callable(setup_id_or_func):
func = setup_id_or_func
else:
func = setup.func
return chunked(func, **merge_dicts(setup.options, kwargs))
def resolve_option(
self,
setup_id_or_func: tp.Union[tp.Hashable, tp.Callable],
option: tp.ChunkedOption,
target_func: tp.Optional[tp.Callable] = None,
**kwargs,
) -> tp.Callable:
"""Same as `ChunkableRegistry.decorate` but using `vectorbtpro.utils.chunking.resolve_chunked`."""
setup = self.get_setup(setup_id_or_func)
if setup is None:
if callable(setup_id_or_func):
option = resolve_chunked_option(option=option)
if option is None:
return setup_id_or_func
raise KeyError(f"Setup for {setup_id_or_func} not registered")
if target_func is not None:
func = target_func
elif callable(setup_id_or_func):
func = setup_id_or_func
else:
func = setup.func
return resolve_chunked(func, option=option, **merge_dicts(setup.options, kwargs))
ch_reg = ChunkableRegistry()
"""Default registry of type `ChunkableRegistry`."""
def register_chunkable(
func: tp.Optional[tp.Callable] = None,
setup_id: tp.Optional[tp.Hashable] = None,
registry: ChunkableRegistry = ch_reg,
tags: tp.SetLike = None,
return_wrapped: bool = False,
**options,
) -> tp.Callable:
"""Register a new chunkable function.
If `return_wrapped` is True, wraps with the `vectorbtpro.utils.chunking.chunked` decorator.
Otherwise, leaves the function as-is (preferred).
Options are merged in the following order:
* `options` in `vectorbtpro._settings.chunking`
* `setup_options.{setup_id}` in `vectorbtpro._settings.chunking`
* `options`
* `override_options` in `vectorbtpro._settings.chunking`
* `override_setup_options.{setup_id}` in `vectorbtpro._settings.chunking`
!!! note
Calling the `register_chunkable` decorator before (or below) the `vectorbtpro.registries.jit_registry.register_jitted`
decorator with `return_wrapped` set to True won't work. Doing the same after (or above)
`vectorbtpro.registries.jit_registry.register_jitted` will work for calling the function from Python but not from Numba.
Generally, avoid wrapping right away and use `ChunkableRegistry.decorate` to perform decoration."""
def decorator(_func: tp.Callable) -> tp.Callable:
nonlocal setup_id, options
from vectorbtpro._settings import settings
chunking_cfg = settings["chunking"]
if setup_id is None:
setup_id = _func.__module__ + "." + _func.__name__
options = merge_dicts(
chunking_cfg.get("options", None),
chunking_cfg.get("setup_options", {}).get(setup_id, None),
options,
chunking_cfg.get("override_options", None),
chunking_cfg.get("override_setup_options", {}).get(setup_id, None),
)
registry.register(func=_func, setup_id=setup_id, options=options, tags=tags)
if return_wrapped:
return chunked(_func, **options)
return _func
if func is None:
return decorator
return decorator(func)
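# A minimal sketch of the registration flow described above: register the raw function
# (the preferred, non-wrapping mode) and let the registry apply the `chunked` decorator on
# demand. The sizer/slicer classes and option names used here follow
# `vectorbtpro.utils.chunking` but are assumptions made for illustration only.
def _example_register_chunkable():
    import numpy as np
    from vectorbtpro.utils.chunking import LenSizer, ChunkSlicer  # assumed names

    @register_chunkable(
        size=LenSizer(arg_query="a"),
        arg_take_spec=dict(a=ChunkSlicer()),
        merge_func=np.concatenate,
    )
    def double(a):
        return np.asarray(a) * 2

    chunked_double = ch_reg.decorate(double, n_chunks=2)  # decoration happens here, not above
    return chunked_double(np.arange(10))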
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Global registry for jittables.
Jitting is a process of just-in-time compiling functions to make their execution faster.
A jitter is a decorator that wraps a regular Python function and returns the decorated function.
Depending upon a jitter, this decorated function has the same or at least a similar signature
to the function that has been decorated. Jitters take various jitter-specific options
to change the behavior of execution; that is, a single regular Python function can be
decorated by multiple jitter instances (for example, one jitter for decorating a function
with `numba.jit` and another jitter for doing the same with `parallel=True` flag).
In addition to jitters, vectorbt introduces the concept of tasks. One task can be
executed by multiple jitter types (such as NumPy, Numba, and JAX). For example, one
can create a task that converts price into returns and implement it using both NumPy and Numba.
Those implementations are registered by `JITRegistry` as `JitableSetup` instances, are stored
in `JITRegistry.jitable_setups`, and can be uniquely identified by the task id and jitter type.
Note that `JitableSetup` instances contain only information on how to decorate a function.
The decorated function itself and the jitter that has been used are registered as a `JittedSetup`
instance and stored in `JITRegistry.jitted_setups`. It acts as a cache to quickly retrieve an
already decorated function and to avoid recompilation.
Let's implement a task that takes a sum over an array using both NumPy and Numba:
```pycon
>>> from vectorbtpro import *
>>> @vbt.register_jitted(task_id_or_func='sum')
... def sum_np(a):
... return a.sum()
>>> @vbt.register_jitted(task_id_or_func='sum')
... def sum_nb(a):
... out = 0.
... for i in range(a.shape[0]):
... out += a[i]
... return out
```
We can see that two new jitable setups were registered:
```pycon
>>> vbt.jit_reg.jitable_setups['sum']
{'np': JitableSetup(task_id='sum', jitter_id='np', py_func=<function sum_np at 0x...>, jitter_kwargs={}, tags=None),
'nb': JitableSetup(task_id='sum', jitter_id='nb', py_func=<function sum_nb at 0x...>, jitter_kwargs={}, tags=None)}
```
Moreover, two jitted setups were registered for our decorated functions:
```pycon
>>> from vectorbtpro.registries.jit_registry import JitableSetup
>>> hash_np = JitableSetup.get_hash('sum', 'np')
>>> vbt.jit_reg.jitted_setups[hash_np]
{3527539: JittedSetup(jitter=<vectorbtpro.utils.jitting.NumPyJitter object at 0x...>, jitted_func=<function sum_np at 0x...>)}
>>> hash_nb = JitableSetup.get_hash('sum', 'nb')
>>> vbt.jit_reg.jitted_setups[hash_nb]
{6326224984503844995: JittedSetup(jitter=<vectorbtpro.utils.jitting.NumbaJitter object at 0x...>, jitted_func=CPUDispatcher(<function sum_nb at 0x...>))}
```
These setups contain decorated functions with the options passed during the registration.
When we call `JITRegistry.resolve` without any additional keyword arguments,
`JITRegistry` returns exactly these functions:
```pycon
>>> jitted_func = vbt.jit_reg.resolve('sum', jitter='nb')
>>> jitted_func
CPUDispatcher(<function sum_nb at 0x...>)
>>> jitted_func.targetoptions
{'nopython': True, 'nogil': True, 'parallel': False, 'boundscheck': False}
```
Once we pass any other option, the Python function will be re-decorated, and another `JittedSetup`
instance will be registered:
```pycon
>>> jitted_func = vbt.jit_reg.resolve('sum', jitter='nb', nopython=False)
>>> jitted_func
CPUDispatcher(<function sum_nb at 0x...>)
>>> jitted_func.targetoptions
{'nopython': False, 'nogil': True, 'parallel': False, 'boundscheck': False}
>>> vbt.jit_reg.jitted_setups[hash_nb]
{6326224984503844995: JittedSetup(jitter=<vectorbtpro.utils.jitting.NumbaJitter object at 0x...>, jitted_func=CPUDispatcher(<function sum_nb at 0x...>)),
-2979374923679407948: JittedSetup(jitter=<vectorbtpro.utils.jitting.NumbaJitter object at 0x...>, jitted_func=CPUDispatcher(<function sum_nb at 0x...>))}
```
## Templates
Templates can be used to, based on the current context, dynamically select the jitter or
keyword arguments for jitting. For example, let's pick the Numba jitter whenever it has been
registered for a given task:
```pycon
>>> vbt.jit_reg.resolve('sum', jitter=vbt.RepEval("'nb' if 'nb' in task_setups else None"))
CPUDispatcher(<function sum_nb at 0x...>)
```
## Disabling
In the case we want to disable jitting, we can simply pass `disable=True` to `JITRegistry.resolve`:
```pycon
>>> py_func = vbt.jit_reg.resolve('sum', jitter='nb', disable=True)
>>> py_func
<function sum_nb at 0x...>
```
We can also disable jitting globally:
```pycon
>>> vbt.settings.jitting['disable'] = True
>>> vbt.jit_reg.resolve('sum', jitter='nb')
<function sum_nb at 0x...>
```
!!! hint
If we don't plan to use any additional options and we have only one jitter registered per task,
we can also disable resolution to increase performance.
!!! warning
Disabling jitting globally only applies to functions resolved using `JITRegistry.resolve`.
Any decorated function that is being called directly will be executed as usual.
## Jitted option
Since most functions that call other jitted functions in vectorbt have a `jitted` argument,
you can pass `jitted` as a dictionary with options, as a string denoting the jitter, or False
to disable jitting (see `vectorbtpro.utils.jitting.resolve_jitted_option`):
```pycon
>>> def sum_arr(arr, jitted=None):
... func = vbt.jit_reg.resolve_option('sum', jitted)
... return func(arr)
>>> arr = np.random.uniform(size=1000000)
>>> %timeit sum_arr(arr, jitted='np')
319 µs ± 3.35 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> %timeit sum_arr(arr, jitted='nb')
1.09 ms ± 4.13 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
>>> %timeit sum_arr(arr, jitted=dict(jitter='nb', disable=True))
133 ms ± 2.32 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
!!! hint
A good rule of thumb is: whenever a caller function accepts a `jitted` argument,
the jitted functions it calls are most probably resolved using `JITRegistry.resolve_option`.
## Changing options upon registration
Options are usually specified upon registration using `register_jitted`:
```pycon
>>> from numba import prange
>>> @vbt.register_jitted(parallel=True, tags={'can_parallel'})
... def sum_parallel_nb(a):
... out = np.empty(a.shape[1])
... for col in prange(a.shape[1]):
... total = 0.
... for i in range(a.shape[0]):
... total += a[i, col]
... out[col] = total
... return out
>>> sum_parallel_nb.targetoptions
{'nopython': True, 'nogil': True, 'parallel': True, 'boundscheck': False}
```
But what if we wanted to change the registration options of vectorbt's own jitable functions,
such as `vectorbtpro.generic.nb.base.diff_nb`? For example, let's disable caching for all Numba functions.
```pycon
>>> vbt.settings.jitting.jitters['nb']['override_options'] = dict(cache=False)
```
Since all functions have already been registered, the above statement has no effect:
```pycon
>>> vbt.jit_reg.jitable_setups['vectorbtpro.generic.nb.base.diff_nb']['nb'].jitter_kwargs
{'cache': True}
```
In order for them to be applied, we need to save the settings to a file and
load them before all functions are imported:
```pycon
>>> vbt.settings.save('my_settings')
```
Let's restart the runtime and instruct vectorbt to load the file with settings before anything else:
```pycon
>>> import os
>>> os.environ['VBT_SETTINGS_PATH'] = "my_settings"
>>> from vectorbtpro import *
>>> vbt.jit_reg.jitable_setups['vectorbtpro.generic.nb.base.diff_nb']['nb'].jitter_kwargs
{'cache': False}
```
We can also change the registration options for some specific tasks, and even replace Python functions.
For example, we can change the implementation in the deepest places of the core.
Let's change the default `ddof` from 0 to 1 in `vectorbtpro.generic.nb.base.nanstd_1d_nb` and disable caching with Numba:
```pycon
>>> vbt.nb.nanstd_1d_nb(np.array([1, 2, 3]))
0.816496580927726
>>> def new_nanstd_1d_nb(arr, ddof=1):
... return np.sqrt(vbt.nb.nanvar_1d_nb(arr, ddof=ddof))
>>> vbt.settings.jitting.jitters['nb']['tasks']['vectorbtpro.generic.nb.base.nanstd_1d_nb'] = dict(
... replace_py_func=new_nanstd_1d_nb,
... override_options=dict(
... cache=False
... )
... )
>>> vbt.settings.save('my_settings')
```
After restarting the runtime:
```pycon
>>> import os
>>> os.environ['VBT_SETTINGS_PATH'] = "my_settings"
>>> vbt.nb.nanstd_1d_nb(np.array([1, 2, 3]))
1.0
```
!!! note
All of the above examples require saving the setting to a file, restarting the runtime,
setting the path to the file to an environment variable, and only then importing vectorbtpro.
## Changing options upon resolution
Another approach but without the need to restart the runtime is by changing the options
upon resolution using `JITRegistry.resolve_option`:
```pycon
>>> # On specific Numba function
>>> vbt.settings.jitting.jitters['nb']['tasks']['vectorbtpro.generic.nb.base.diff_nb'] = dict(
... resolve_kwargs=dict(
... nogil=False
... )
... )
>>> # disabled
>>> vbt.jit_reg.resolve('vectorbtpro.generic.nb.base.diff_nb', jitter='nb').targetoptions
{'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False}
>>> # still enabled
>>> vbt.jit_reg.resolve('sum', jitter='nb').targetoptions
{'nopython': True, 'nogil': True, 'parallel': False, 'boundscheck': False}
>>> # On each Numba function
>>> vbt.settings.jitting.jitters['nb']['resolve_kwargs'] = dict(nogil=False)
>>> # disabled
>>> vbt.jit_reg.resolve('vectorbtpro.generic.nb.base.diff_nb', jitter='nb').targetoptions
{'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False}
>>> # disabled
>>> vbt.jit_reg.resolve('sum', jitter='nb').targetoptions
{'nopython': True, 'nogil': False, 'parallel': False, 'boundscheck': False}
```
## Building custom jitters
Let's build a custom jitter on top of `vectorbtpro.utils.jitting.NumbaJitter` that converts
any argument that contains a Pandas object to a 2-dimensional NumPy array prior to execution:
```pycon
>>> from functools import wraps
>>> from vectorbtpro.utils.jitting import NumbaJitter
>>> class SafeNumbaJitter(NumbaJitter):
... def decorate(self, py_func, tags=None):
... if self.wrapping_disabled:
... return py_func
...
... @wraps(py_func)
... def wrapper(*args, **kwargs):
... new_args = ()
... for arg in args:
... if isinstance(arg, pd.Series):
... arg = np.expand_dims(arg.values, 1)
... elif isinstance(arg, pd.DataFrame):
... arg = arg.values
... new_args += (arg,)
... new_kwargs = dict()
... for k, v in kwargs.items():
... if isinstance(v, pd.Series):
... v = np.expand_dims(v.values, 1)
... elif isinstance(v, pd.DataFrame):
... v = v.values
... new_kwargs[k] = v
... return NumbaJitter.decorate(self, py_func, tags=tags)(*new_args, **new_kwargs)
... return wrapper
```
After we have defined our jitter class, we need to register it globally:
```pycon
>>> vbt.settings.jitting.jitters['safe_nb'] = dict(cls=SafeNumbaJitter)
```
Finally, we can execute any Numba function by specifying our new jitter:
```pycon
>>> func = vbt.jit_reg.resolve(
... task_id_or_func=vbt.generic.nb.diff_nb,
... jitter='safe_nb',
... allow_new=True
... )
>>> func(pd.DataFrame([[1, 2], [3, 4]]))
array([[nan, nan],
[ 2., 2.]])
```
Whereas executing the same func using the vanilla Numba jitter causes an error:
```pycon
>>> func = vbt.jit_reg.resolve(task_id_or_func=vbt.generic.nb.diff_nb)
>>> func(pd.DataFrame([[1, 2], [3, 4]]))
Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
```
!!! note
Make sure to pass a function as `task_id_or_func` if the jitted function hasn't been registered yet.
This jitter cannot be used for decorating Numba functions that should be called
from other Numba functions since the conversion operation is done using Python.
"""
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, atomic_dict
from vectorbtpro.utils.jitting import (
Jitter,
resolve_jitted_kwargs,
resolve_jitter_type,
resolve_jitter,
get_id_of_jitter_type,
get_func_suffix,
)
from vectorbtpro.utils.template import RepEval, substitute_templates, CustomTemplate
__all__ = [
"JITRegistry",
"jit_reg",
"register_jitted",
]
def get_func_full_name(func: tp.Callable) -> str:
"""Get full name of the func to be used as task id."""
return func.__module__ + "." + func.__name__
@define
class JitableSetup(DefineMixin):
"""Class that represents a jitable setup.
!!! note
Hashed solely by `task_id` and `jitter_id`."""
task_id: tp.Hashable = define.field()
"""Task id."""
jitter_id: tp.Hashable = define.field()
"""Jitter id."""
py_func: tp.Callable = define.field()
"""Python function to be jitted."""
jitter_kwargs: tp.KwargsLike = define.field(default=None)
"""Keyword arguments passed to `vectorbtpro.utils.jitting.resolve_jitter`."""
tags: tp.SetLike = define.field(default=None)
"""Set of tags."""
@staticmethod
def get_hash(task_id: tp.Hashable, jitter_id: tp.Hashable) -> int:
return hash((task_id, jitter_id))
@property
def hash_key(self) -> tuple:
return (self.task_id, self.jitter_id)
@define
class JittedSetup(DefineMixin):
"""Class that represents a jitted setup.
!!! note
Hashed solely by sorted config of `jitter`. That is, two jitters with the same config
will yield the same hash and the function won't be re-decorated."""
jitter: Jitter = define.field()
"""Jitter that decorated the function."""
jitted_func: tp.Callable = define.field()
"""Decorated function."""
@staticmethod
def get_hash(jitter: Jitter) -> int:
return hash(tuple(sorted(jitter.config.items())))
@property
def hash_key(self) -> tuple:
return tuple(sorted(self.jitter.config.items()))
class JITRegistry(Base):
"""Class for registering jitted functions."""
def __init__(self) -> None:
self._jitable_setups = {}
self._jitted_setups = {}
@property
def jitable_setups(self) -> tp.Dict[tp.Hashable, tp.Dict[tp.Hashable, JitableSetup]]:
"""Dict of registered `JitableSetup` instances by `task_id` and `jitter_id`."""
return self._jitable_setups
@property
def jitted_setups(self) -> tp.Dict[int, tp.Dict[int, JittedSetup]]:
"""Nested dict of registered `JittedSetup` instances by hash of their `JitableSetup` instance."""
return self._jitted_setups
def register_jitable_setup(
self,
task_id: tp.Hashable,
jitter_id: tp.Hashable,
py_func: tp.Callable,
jitter_kwargs: tp.KwargsLike = None,
tags: tp.Optional[set] = None,
) -> JitableSetup:
"""Register a jitable setup."""
jitable_setup = JitableSetup(
task_id=task_id,
jitter_id=jitter_id,
py_func=py_func,
jitter_kwargs=jitter_kwargs,
tags=tags,
)
if task_id not in self.jitable_setups:
self.jitable_setups[task_id] = dict()
if jitter_id not in self.jitable_setups[task_id]:
self.jitable_setups[task_id][jitter_id] = jitable_setup
return jitable_setup
def register_jitted_setup(
self,
jitable_setup: JitableSetup,
jitter: Jitter,
jitted_func: tp.Callable,
) -> JittedSetup:
"""Register a jitted setup."""
jitable_setup_hash = hash(jitable_setup)
jitted_setup = JittedSetup(jitter=jitter, jitted_func=jitted_func)
jitted_setup_hash = hash(jitted_setup)
if jitable_setup_hash not in self.jitted_setups:
self.jitted_setups[jitable_setup_hash] = dict()
if jitted_setup_hash not in self.jitted_setups[jitable_setup_hash]:
self.jitted_setups[jitable_setup_hash][jitted_setup_hash] = jitted_setup
return jitted_setup
def decorate_and_register(
self,
task_id: tp.Hashable,
py_func: tp.Callable,
jitter: tp.Optional[tp.JitterLike] = None,
jitter_kwargs: tp.KwargsLike = None,
tags: tp.Optional[set] = None,
):
"""Decorate a jitable function and register both jitable and jitted setups."""
if jitter_kwargs is None:
jitter_kwargs = {}
jitter = resolve_jitter(jitter=jitter, py_func=py_func, **jitter_kwargs)
jitter_id = get_id_of_jitter_type(type(jitter))
if jitter_id is None:
raise ValueError("Jitter id cannot be None: is jitter registered globally?")
jitable_setup = self.register_jitable_setup(task_id, jitter_id, py_func, jitter_kwargs=jitter_kwargs, tags=tags)
jitted_func = jitter.decorate(py_func, tags=tags)
self.register_jitted_setup(jitable_setup, jitter, jitted_func)
return jitted_func
def match_jitable_setups(
self,
expression: tp.Optional[str] = None,
context: tp.KwargsLike = None,
) -> tp.Set[JitableSetup]:
"""Match jitable setups against an expression with each setup being a context."""
matched_setups = set()
for setups_by_jitter_id in self.jitable_setups.values():
for setup in setups_by_jitter_id.values():
if expression is None:
result = True
else:
result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context))
checks.assert_instance_of(result, bool)
if result:
matched_setups.add(setup)
return matched_setups
def match_jitted_setups(
self,
jitable_setup: JitableSetup,
expression: tp.Optional[str] = None,
context: tp.KwargsLike = None,
) -> tp.Set[JittedSetup]:
"""Match jitted setups of a jitable setup against an expression with each setup a context."""
matched_setups = set()
for setup in self.jitted_setups[hash(jitable_setup)].values():
if expression is None:
result = True
else:
result = RepEval(expression).substitute(context=merge_dicts(setup.asdict(), context))
checks.assert_instance_of(result, bool)
if result:
matched_setups.add(setup)
return matched_setups
def resolve(
self,
task_id_or_func: tp.Union[tp.Hashable, tp.Callable],
jitter: tp.Optional[tp.Union[tp.JitterLike, CustomTemplate]] = None,
disable: tp.Optional[tp.Union[bool, CustomTemplate]] = None,
disable_resolution: tp.Optional[bool] = None,
allow_new: tp.Optional[bool] = None,
register_new: tp.Optional[bool] = None,
return_missing_task: bool = False,
template_context: tp.KwargsLike = None,
tags: tp.Optional[set] = None,
**jitter_kwargs,
) -> tp.Union[tp.Hashable, tp.Callable]:
"""Resolve jitted function for the given task id.
For details on the format of `task_id_or_func`, see `register_jitted`.
Jitter keyword arguments are merged in the following order:
* `jitable_setup.jitter_kwargs`
* `jitter.your_jitter.resolve_kwargs` in `vectorbtpro._settings.jitting`
* `jitter.your_jitter.tasks.your_task.resolve_kwargs` in `vectorbtpro._settings.jitting`
* `jitter_kwargs`
Templates are substituted in `jitter`, `disable`, and `jitter_kwargs`.
Set `disable` to True to return the Python function without decoration.
If `disable_resolution` is enabled globally, `task_id_or_func` is returned unchanged.
!!! note
`disable` is only being used by `JITRegistry`, not `vectorbtpro.utils.jitting`.
!!! note
If there are more than one jitted setups registered for a single task id,
make sure to provide a jitter.
If no jitted setup of type `JittedSetup` was found and `allow_new` is True,
decorates and returns the function supplied as `task_id_or_func` (otherwise throws an error).
Set `return_missing_task` to True to return `task_id_or_func` if it cannot be found
in `JITRegistry.jitable_setups`.
"""
from vectorbtpro._settings import settings
jitting_cfg = settings["jitting"]
if disable_resolution is None:
disable_resolution = jitting_cfg["disable_resolution"]
if disable_resolution:
return task_id_or_func
if allow_new is None:
allow_new = jitting_cfg["allow_new"]
if register_new is None:
register_new = jitting_cfg["register_new"]
if hasattr(task_id_or_func, "py_func"):
py_func = task_id_or_func.py_func
task_id = get_func_full_name(py_func)
elif callable(task_id_or_func):
py_func = task_id_or_func
task_id = get_func_full_name(py_func)
else:
py_func = None
task_id = task_id_or_func
if task_id not in self.jitable_setups:
if not allow_new:
if return_missing_task:
return task_id_or_func
raise KeyError(f"Task id '{task_id}' not registered")
task_setups = self.jitable_setups.get(task_id, dict())
template_context = merge_dicts(
jitting_cfg["template_context"],
template_context,
dict(task_id=task_id, py_func=py_func, task_setups=atomic_dict(task_setups)),
)
jitter = substitute_templates(jitter, template_context, eval_id="jitter")
if jitter is None and py_func is not None:
jitter = get_func_suffix(py_func)
if jitter is None:
if len(task_setups) > 1:
raise ValueError(
f"There are multiple registered setups for task id '{task_id}'. Please specify the jitter."
)
elif len(task_setups) == 0:
raise ValueError(f"There are no registered setups for task id '{task_id}'")
jitable_setup = list(task_setups.values())[0]
jitter = jitable_setup.jitter_id
jitter_id = jitable_setup.jitter_id
else:
jitter_type = resolve_jitter_type(jitter=jitter)
jitter_id = get_id_of_jitter_type(jitter_type)
if jitter_id not in task_setups:
if not allow_new:
raise KeyError(f"Jitable setup with task id '{task_id}' and jitter id '{jitter_id}' not registered")
jitable_setup = None
else:
jitable_setup = task_setups[jitter_id]
if jitter_id is None:
raise ValueError("Jitter id cannot be None: is jitter registered globally?")
if jitable_setup is None and py_func is None:
raise ValueError(f"Unable to find Python function for task id '{task_id}' and jitter id '{jitter_id}'")
template_context = merge_dicts(
template_context,
dict(jitter_id=jitter_id, jitter=jitter, jitable_setup=jitable_setup),
)
disable = substitute_templates(disable, template_context, eval_id="disable")
if disable is None:
disable = jitting_cfg["disable"]
if disable:
if jitable_setup is None:
return py_func
return jitable_setup.py_func
if not isinstance(jitter, Jitter):
jitter_cfg = jitting_cfg["jitters"].get(jitter_id, {})
setup_cfg = jitter_cfg.get("tasks", {}).get(task_id, {})
jitter_kwargs = merge_dicts(
jitable_setup.jitter_kwargs if jitable_setup is not None else None,
jitter_cfg.get("resolve_kwargs", None),
setup_cfg.get("resolve_kwargs", None),
jitter_kwargs,
)
jitter_kwargs = substitute_templates(jitter_kwargs, template_context, eval_id="jitter_kwargs")
jitter = resolve_jitter(jitter=jitter, **jitter_kwargs)
if jitable_setup is not None:
jitable_hash = hash(jitable_setup)
jitted_hash = JittedSetup.get_hash(jitter)
if jitable_hash in self.jitted_setups and jitted_hash in self.jitted_setups[jitable_hash]:
return self.jitted_setups[jitable_hash][jitted_hash].jitted_func
else:
if register_new:
return self.decorate_and_register(
task_id=task_id,
py_func=py_func,
jitter=jitter,
jitter_kwargs=jitter_kwargs,
tags=tags,
)
return jitter.decorate(py_func, tags=tags)
jitted_func = jitter.decorate(jitable_setup.py_func, tags=jitable_setup.tags)
self.register_jitted_setup(jitable_setup, jitter, jitted_func)
return jitted_func
def resolve_option(
self,
task_id: tp.Union[tp.Hashable, tp.Callable],
option: tp.JittedOption,
**kwargs,
) -> tp.Union[tp.Hashable, tp.Callable]:
"""Resolve `option` using `vectorbtpro.utils.jitting.resolve_jitted_option` and call `JITRegistry.resolve`."""
kwargs = resolve_jitted_kwargs(option=option, **kwargs)
if kwargs is None:
kwargs = dict(disable=True)
return self.resolve(task_id, **kwargs)
jit_reg = JITRegistry()
"""Default registry of type `JITRegistry`."""
def register_jitted(
py_func: tp.Optional[tp.Callable] = None,
task_id_or_func: tp.Optional[tp.Union[tp.Hashable, tp.Callable]] = None,
registry: JITRegistry = jit_reg,
tags: tp.Optional[set] = None,
**options,
) -> tp.Callable:
"""Decorate and register a jitable function using `JITRegistry.decorate_and_register`.
If `task_id_or_func` is a callable, gets replaced by the callable's module name and function name.
Additionally, the function name may contain a suffix pointing at the jitter (such as `_nb`).
Options are merged in the following order:
* `jitters.{jitter_id}.options` in `vectorbtpro._settings.jitting`
* `jitters.{jitter_id}.tasks.{task_id}.options` in `vectorbtpro._settings.jitting`
* `options`
* `jitters.{jitter_id}.override_options` in `vectorbtpro._settings.jitting`
* `jitters.{jitter_id}.tasks.{task_id}.override_options` in `vectorbtpro._settings.jitting`
`py_func` can also be overridden using `jitters.your_jitter.tasks.your_task.replace_py_func`
in `vectorbtpro._settings.jitting`."""
def decorator(_py_func: tp.Callable) -> tp.Callable:
nonlocal options
from vectorbtpro._settings import settings
jitting_cfg = settings["jitting"]
if task_id_or_func is None:
task_id = get_func_full_name(_py_func)
elif hasattr(task_id_or_func, "py_func"):
task_id = get_func_full_name(task_id_or_func.py_func)
elif callable(task_id_or_func):
task_id = get_func_full_name(task_id_or_func)
else:
task_id = task_id_or_func
jitter = options.pop("jitter", None)
jitter_type = resolve_jitter_type(jitter=jitter, py_func=_py_func)
jitter_id = get_id_of_jitter_type(jitter_type)
jitter_cfg = jitting_cfg["jitters"].get(jitter_id, {})
setup_cfg = jitter_cfg.get("tasks", {}).get(task_id, {})
options = merge_dicts(
jitter_cfg.get("options", None),
setup_cfg.get("options", None),
options,
jitter_cfg.get("override_options", None),
setup_cfg.get("override_options", None),
)
if setup_cfg.get("replace_py_func", None) is not None:
_py_func = setup_cfg["replace_py_func"]
if task_id_or_func is None:
task_id = get_func_full_name(_py_func)
return registry.decorate_and_register(
task_id=task_id,
py_func=_py_func,
jitter=jitter,
jitter_kwargs=options,
tags=tags,
)
if py_func is None:
return decorator
return decorator(py_func)
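# A small sketch tying `register_jitted` to `JITRegistry.resolve`: the `_nb` suffix lets the
# registry infer the Numba jitter, and the decorated function can later be passed as the task
# id itself. The function below is an assumption made for illustration only.
def _example_register_and_resolve():
    import numpy as np

    @register_jitted(cache=False)  # options are merged with the settings as described above
    def cumsum_1d_nb(a):
        out = np.empty_like(a)
        total = 0.0
        for i in range(a.shape[0]):
            total += a[i]
            out[i] = total
        return out

    # Returns the jitted function, or the pure-Python one when jitting is disabled
    func = jit_reg.resolve(cumsum_1d_nb, jitter="nb")
    return func(np.arange(5, dtype=np.float64))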
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Global registry for progress bars."""
import uuid
from vectorbtpro import _typing as tp
from vectorbtpro.utils.base import Base
if tp.TYPE_CHECKING:
from vectorbtpro.utils.pbar import ProgressBar as ProgressBarT
else:
ProgressBarT = "ProgressBar"
__all__ = [
"PBarRegistry",
"pbar_reg",
]
class PBarRegistry(Base):
"""Class for registering `vectorbtpro.utils.pbar.ProgressBar` instances."""
@classmethod
def generate_bar_id(cls) -> tp.Hashable:
"""Generate a unique bar id."""
return str(uuid.uuid4())
def __init__(self):
self._instances = {}
@property
def instances(self) -> tp.Dict[tp.Hashable, ProgressBarT]:
"""Dict of registered instances by their bar id."""
return self._instances
def register_instance(self, instance: ProgressBarT) -> None:
"""Register an instance."""
self.instances[instance.bar_id] = instance
def deregister_instance(self, instance: ProgressBarT) -> None:
"""Deregister an instance."""
if instance.bar_id in self.instances:
del self.instances[instance.bar_id]
def has_conflict(self, instance: ProgressBarT) -> bool:
"""Return whether there is an (active) instance with the same bar id."""
if instance.bar_id is None:
return False
for inst in self.instances.values():
if inst is not instance and inst.bar_id == instance.bar_id and inst.active:
return True
return False
def get_last_active_instance(self) -> tp.Optional[ProgressBarT]:
"""Get the last active instance."""
max_open_time = None
last_active_instance = None
for inst in self.instances.values():
if inst.active:
if max_open_time is None or inst.open_time > max_open_time:
max_open_time = inst.open_time
last_active_instance = inst
return last_active_instance
def get_first_pending_instance(self) -> tp.Optional[ProgressBarT]:
"""Get the first pending instance."""
last_active_instance = self.get_last_active_instance()
if last_active_instance is None:
return None
min_open_time = None
first_pending_instance = None
for inst in self.instances.values():
if inst.pending and inst.open_time > last_active_instance.open_time:
if min_open_time is None or inst.open_time < min_open_time:
min_open_time = inst.open_time
first_pending_instance = inst
return first_pending_instance
def get_pending_instance(self, instance: ProgressBarT) -> tp.Optional[ProgressBarT]:
"""Get the pending instance.
If the bar id is not None, searches for the same id in the dictionary."""
if instance.bar_id is not None:
for inst in self.instances.values():
if inst is not instance and inst.pending:
if inst.bar_id == instance.bar_id:
return inst
return None
last_active_instance = self.get_last_active_instance()
if last_active_instance is None:
return None
min_open_time = None
first_pending_instance = None
for inst in self.instances.values():
if inst.pending and inst.open_time > last_active_instance.open_time:
if min_open_time is None or inst.open_time < min_open_time:
min_open_time = inst.open_time
first_pending_instance = inst
return first_pending_instance
def get_parent_instances(self, instance: ProgressBarT) -> tp.List[ProgressBarT]:
"""Get the (active) parent instances of an instance."""
parent_instances = []
for inst in self.instances.values():
if inst is not instance and inst.active:
if inst.open_time < instance.open_time:
parent_instances.append(inst)
return parent_instances
def get_parent_instance(self, instance: ProgressBarT) -> tp.Optional[ProgressBarT]:
"""Get the (active) parent instance of an instance."""
max_open_time = None
parent_instance = None
for inst in self.get_parent_instances(instance):
if max_open_time is None or inst.open_time > max_open_time:
max_open_time = inst.open_time
parent_instance = inst
return parent_instance
def get_child_instances(self, instance: ProgressBarT) -> tp.List[ProgressBarT]:
"""Get child (active or pending) instances of an instance."""
child_instances = []
for inst in self.instances.values():
if inst is not instance and (inst.active or inst.pending):
if inst.open_time > instance.open_time:
child_instances.append(inst)
return child_instances
def clear_instances(self) -> None:
"""Clear instances."""
self.instances.clear()
pbar_reg = PBarRegistry()
"""Default registry of type `PBarRegistry`."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with returns.
Offers common financial risk and performance metrics as found in [empyrical](https://github.com/quantopian/empyrical),
an adapter for quantstats, and other features based on returns."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.returns.accessors import *
from vectorbtpro.returns.nb import *
from vectorbtpro.returns.qs_adapter import *
__exclude_from__all__ = [
"enums",
]
__import_if_installed__ = dict()
__import_if_installed__["qs_adapter"] = "quantstats"
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for returns.
Methods can be accessed as follows:
* `ReturnsSRAccessor` -> `pd.Series.vbt.returns.*`
* `ReturnsDFAccessor` -> `pd.DataFrame.vbt.returns.*`
!!! note
The underlying Series/DataFrame must already be a return series.
To convert price to returns, use `ReturnsAccessor.from_value`.
Grouping is only supported by the methods that accept the `group_by` argument.
Accessors do not utilize caching.
There are three options to compute returns and get the accessor:
```pycon
>>> from vectorbtpro import *
>>> price = pd.Series([1.1, 1.2, 1.3, 1.2, 1.1])
>>> # 1. pd.Series.pct_change
>>> rets = price.pct_change()
>>> ret_acc = rets.vbt.returns(freq='d')
>>> # 2. vectorbtpro.generic.accessors.GenericAccessor.to_returns
>>> rets = price.vbt.to_returns()
>>> ret_acc = rets.vbt.returns(freq='d')
>>> # 3. vectorbtpro.returns.accessors.ReturnsAccessor.from_value
>>> ret_acc = pd.Series.vbt.returns.from_value(price, freq='d')
>>> # vectorbtpro.returns.accessors.ReturnsAccessor.total
>>> ret_acc.total()
0.0
```
The accessors extend `vectorbtpro.generic.accessors`.
```pycon
>>> # inherited from GenericAccessor
>>> ret_acc.max()
0.09090909090909083
```
## Defaults
`vectorbtpro.returns.accessors.ReturnsAccessor` accepts `defaults` dictionary where you can pass
defaults for arguments used throughout the accessor, such as
* `start_value`: The starting value.
* `window`: Window length.
* `minp`: Minimum number of observations in a window required to have a value.
* `ddof`: Delta Degrees of Freedom.
* `risk_free`: Constant risk-free return throughout the period.
* `levy_alpha`: Scaling relation (Levy stability exponent).
* `required_return`: Minimum acceptance return of the investor.
* `cutoff`: Decimal representing the percentage cutoff for the bottom percentile of returns.
* `periods`: Number of observations for annualization. Can be an integer or "dt_periods".
Defaults as well as `bm_returns` and `year_freq` can be set globally using settings:
```pycon
>>> benchmark = pd.Series([1.05, 1.1, 1.15, 1.1, 1.05])
>>> bm_returns = benchmark.vbt.to_returns()
>>> vbt.settings.returns['bm_returns'] = bm_returns
```
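Defaults can also be passed per instance; an illustrative sketch reusing `rets` from above:
```pycon
>>> ret_acc = rets.vbt.returns(freq='d', defaults=dict(risk_free=0.001))
>>> ret_acc.defaults['risk_free']
0.001
```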
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `ReturnsAccessor.metrics`.
```pycon
>>> ret_acc.stats()
Start 0
End 4
Duration 5 days 00:00:00
Total Return [%] 0
Benchmark Return [%] 0
Annualized Return [%] 0
Annualized Volatility [%] 184.643
Sharpe Ratio 0.691185
Calmar Ratio 0
Max Drawdown [%] 15.3846
Omega Ratio 1.08727
Sortino Ratio 1.17805
Skew 0.00151002
Kurtosis -5.94737
Tail Ratio 1.08985
Common Sense Ratio 1.08985
Value at Risk -0.0823718
Alpha 0.78789
Beta 1.83864
dtype: object
```
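Individual metrics can also be selected and parametrized; a sketch where the metric names follow
the output above and `settings` is assumed to forward matching keyword arguments such as `risk_free`:
```pycon
>>> sr_stats = ret_acc.stats(metrics=['sharpe_ratio', 'sortino_ratio'], settings=dict(risk_free=0.01))
```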
!!! note
`ReturnsAccessor.stats` does not support grouping.
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `ReturnsAccessor.subplots`.
`ReturnsAccessor` class has a single subplot based on `ReturnsAccessor.plot_cumulative`:
```pycon
>>> ret_acc.plots().show()
```
"""
import numpy as np
import pandas as pd
from pandas.tseries.offsets import BaseOffset
from vectorbtpro import _typing as tp
from vectorbtpro.accessors import register_vbt_accessor, register_df_vbt_accessor, register_sr_vbt_accessor
from vectorbtpro.base.reshaping import to_1d_array, to_2d_array, broadcast_array_to, broadcast_to
from vectorbtpro.base.wrapping import ArrayWrapper, Wrapping
from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor
from vectorbtpro.generic.drawdowns import Drawdowns
from vectorbtpro.generic.sim_range import SimRangeMixin
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.returns import nb
from vectorbtpro.utils import checks, chunking as ch, datetime_ as dt
from vectorbtpro.utils.config import resolve_dict, merge_dicts, HybridConfig, Config
from vectorbtpro.utils.decorators import hybrid_property, hybrid_method
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"ReturnsAccessor",
"ReturnsSRAccessor",
"ReturnsDFAccessor",
]
__pdoc__ = {}
ReturnsAccessorT = tp.TypeVar("ReturnsAccessorT", bound="ReturnsAccessor")
@register_vbt_accessor("returns")
class ReturnsAccessor(GenericAccessor, SimRangeMixin):
"""Accessor on top of return series. For both, Series and DataFrames.
Accessible via `pd.Series.vbt.returns` and `pd.DataFrame.vbt.returns`.
Args:
obj (pd.Series or pd.DataFrame): Pandas object representing returns.
bm_returns (array_like): Pandas object representing benchmark returns.
log_returns (bool): Whether returns and benchmark returns are provided as log returns.
year_freq (any): Year frequency for annualization purposes.
defaults (dict): Defaults that override `defaults` in `vectorbtpro._settings.returns`.
sim_start (int, datetime_like, or array_like): Simulation start per column.
sim_end (int, datetime_like, or array_like): Simulation end per column.
**kwargs: Keyword arguments that are passed down to `vectorbtpro.generic.accessors.GenericAccessor`."""
@classmethod
def from_value(
cls: tp.Type[ReturnsAccessorT],
value: tp.ArrayLike,
init_value: tp.ArrayLike = np.nan,
log_returns: bool = False,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrapper_kwargs: tp.KwargsLike = None,
return_values: bool = False,
**kwargs,
) -> tp.Union[ReturnsAccessorT, tp.SeriesFrame]:
"""Returns a new `ReturnsAccessor` instance with returns calculated from `value`."""
if wrapper_kwargs is None:
wrapper_kwargs = {}
if not checks.is_any_array(value):
value = np.asarray(value)
if wrapper is None:
wrapper = ArrayWrapper.from_obj(value, **wrapper_kwargs)
elif len(wrapper_kwargs) > 0:
wrapper = wrapper.replace(**wrapper_kwargs)
value = to_2d_array(value)
init_value = broadcast_array_to(init_value, value.shape[1])
sim_start = cls.resolve_sim_start(sim_start=sim_start, wrapper=wrapper, group_by=False)
sim_end = cls.resolve_sim_end(sim_end=sim_end, wrapper=wrapper, group_by=False)
func = jit_reg.resolve_option(nb.returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
returns = func(
value,
init_value=init_value,
log_returns=log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
if return_values:
return wrapper.wrap(returns, group_by=False)
return cls(wrapper, returns, sim_start=sim_start, sim_end=sim_end, **kwargs)
@classmethod
def resolve_row_stack_kwargs(
cls: tp.Type[ReturnsAccessorT],
*objs: tp.MaybeTuple[ReturnsAccessorT],
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `ReturnsAccessor` after stacking along rows."""
kwargs = GenericAccessor.resolve_row_stack_kwargs(*objs, **kwargs)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, ReturnsAccessor):
raise TypeError("Each object to be merged must be an instance of ReturnsAccessor")
if "bm_returns" not in kwargs:
bm_returns = []
stack_bm_returns = True
for obj in objs:
if obj.config["bm_returns"] is not None:
bm_returns.append(obj.config["bm_returns"])
else:
stack_bm_returns = False
break
if stack_bm_returns:
kwargs["bm_returns"] = kwargs["wrapper"].row_stack_arrs(
*bm_returns,
group_by=False,
wrap=False,
)
if "sim_start" not in kwargs:
kwargs["sim_start"] = cls.row_stack_sim_start(kwargs["wrapper"], *objs)
if "sim_end" not in kwargs:
kwargs["sim_end"] = cls.row_stack_sim_end(kwargs["wrapper"], *objs)
return kwargs
@classmethod
def resolve_column_stack_kwargs(
cls: tp.Type[ReturnsAccessorT],
*objs: tp.MaybeTuple[ReturnsAccessorT],
reindex_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments for initializing `ReturnsAccessor` after stacking along columns."""
kwargs = GenericAccessor.resolve_column_stack_kwargs(*objs, reindex_kwargs=reindex_kwargs, **kwargs)
kwargs.pop("reindex_kwargs", None)
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, ReturnsAccessor):
raise TypeError("Each object to be merged must be an instance of ReturnsAccessor")
if "bm_returns" not in kwargs:
bm_returns = []
stack_bm_returns = True
for obj in objs:
if obj.bm_returns is not None:
bm_returns.append(obj.bm_returns)
else:
stack_bm_returns = False
break
if stack_bm_returns:
kwargs["bm_returns"] = kwargs["wrapper"].column_stack_arrs(
*bm_returns,
reindex_kwargs=reindex_kwargs,
group_by=False,
wrap=False,
)
if "sim_start" not in kwargs:
kwargs["sim_start"] = cls.column_stack_sim_start(kwargs["wrapper"], *objs)
if "sim_end" not in kwargs:
kwargs["sim_end"] = cls.column_stack_sim_end(kwargs["wrapper"], *objs)
return kwargs
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
log_returns: bool = False,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
**kwargs,
) -> None:
GenericAccessor.__init__(
self,
wrapper,
obj=obj,
bm_returns=bm_returns,
log_returns=log_returns,
year_freq=year_freq,
defaults=defaults,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
SimRangeMixin.__init__(self, sim_start=sim_start, sim_end=sim_end)
self._bm_returns = bm_returns
self._log_returns = log_returns
self._year_freq = year_freq
self._defaults = defaults
@hybrid_property
def sr_accessor_cls(cls_or_self) -> tp.Type["ReturnsSRAccessor"]:
"""Accessor class for `pd.Series`."""
return ReturnsSRAccessor
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["ReturnsDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return ReturnsDFAccessor
def indexing_func(
self: ReturnsAccessorT,
*args,
wrapper_meta: tp.DictLike = None,
**kwargs,
) -> ReturnsAccessorT:
"""Perform indexing on `ReturnsAccessor`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.indexing_func_meta(*args, **kwargs)
new_obj = wrapper_meta["new_wrapper"].wrap(
self.to_2d_array()[wrapper_meta["row_idxs"], :][:, wrapper_meta["col_idxs"]],
group_by=False,
)
if self._bm_returns is not None:
new_bm_returns = ArrayWrapper.select_from_flex_array(
self._bm_returns,
row_idxs=wrapper_meta["row_idxs"],
col_idxs=wrapper_meta["col_idxs"],
rows_changed=wrapper_meta["rows_changed"],
columns_changed=wrapper_meta["columns_changed"],
)
else:
new_bm_returns = None
new_sim_start = self.sim_start_indexing_func(wrapper_meta)
new_sim_end = self.sim_end_indexing_func(wrapper_meta)
if checks.is_series(new_obj):
return self.replace(
cls_=self.sr_accessor_cls,
wrapper=wrapper_meta["new_wrapper"],
obj=new_obj,
bm_returns=new_bm_returns,
sim_start=new_sim_start,
sim_end=new_sim_end,
)
return self.replace(
cls_=self.df_accessor_cls,
wrapper=wrapper_meta["new_wrapper"],
obj=new_obj,
bm_returns=new_bm_returns,
sim_start=new_sim_start,
sim_end=new_sim_end,
)
# ############# Properties ############# #
@property
def bm_returns(self) -> tp.Optional[tp.SeriesFrame]:
"""Benchmark returns."""
from vectorbtpro._settings import settings
returns_cfg = settings["returns"]
bm_returns = self._bm_returns
if bm_returns is None:
bm_returns = returns_cfg["bm_returns"]
if bm_returns is not None:
bm_returns = self.wrapper.wrap(bm_returns, group_by=False)
return bm_returns
def get_bm_returns_acc(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
) -> tp.Optional[ReturnsAccessorT]:
"""Get accessor for benchmark returns."""
if bm_returns is None:
bm_returns = self.bm_returns
if bm_returns is None:
return None
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
return self.replace(
obj=bm_returns,
bm_returns=None,
sim_start=sim_start,
sim_end=sim_end,
)
@property
def bm_returns_acc(self) -> tp.Optional[ReturnsAccessorT]:
"""`ReturnsAccessor.get_bm_returns_acc` with default arguments."""
return self.get_bm_returns_acc()
@property
def log_returns(self) -> bool:
"""Whether returns and benchmark returns are provided as log returns."""
return self._log_returns
@classmethod
def auto_detect_ann_factor(cls, index: pd.DatetimeIndex) -> tp.Optional[float]:
"""Auto-detect annualization factor from a datetime index."""
checks.assert_instance_of(index, pd.DatetimeIndex, arg_name="index")
if len(index) == 1:
return None
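# Shift all dates by the distance from the first date to its next year start, extrapolate the
# number of index points to the whole-year span ending at the year boundary after the last date,
# and divide by the number of years in that span to obtain the number of points per year.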
offset = index[0] + pd.offsets.YearBegin() - index[0]
first_date = index[0] + offset
last_date = index[-1] + offset
next_year_date = last_date + pd.offsets.YearBegin()
ratio = (last_date.value - first_date.value) / (next_year_date.value - first_date.value)
ann_factor = len(index) / ratio
ann_factor /= next_year_date.year - first_date.year
return ann_factor
@classmethod
def parse_ann_factor(cls, index: pd.DatetimeIndex, method_name: str = "max") -> tp.Optional[float]:
"""Parse annualization factor from a datetime index."""
checks.assert_instance_of(index, pd.DatetimeIndex, arg_name="index")
if len(index) == 1:
return None
offset = index[0] + pd.offsets.YearBegin() - index[0]
shifted_index = index + offset
years = shifted_index.year
full_years = years[years < years.max()]
if len(full_years) == 0:
return None
return getattr(full_years.value_counts(), method_name.lower())()
@classmethod
def ann_factor_to_year_freq(
cls,
ann_factor: float,
freq: tp.PandasFrequency,
method_name: tp.Optional[str] = None,
) -> tp.PandasFrequency:
"""Convert annualization factor into year frequency."""
if method_name not in (None, False):
if method_name is True:
ann_factor = round(ann_factor)
else:
ann_factor = getattr(np, method_name.lower())(ann_factor)
if checks.is_float(ann_factor) and float.is_integer(ann_factor):
ann_factor = int(ann_factor)
if checks.is_float(ann_factor) and isinstance(freq, BaseOffset):
freq = dt.offset_to_timedelta(freq)
return ann_factor * freq
@classmethod
def year_freq_depends_on_index(cls, year_freq: tp.FrequencyLike) -> bool:
"""Return whether frequency depends on index."""
if isinstance(year_freq, str):
year_freq = " ".join(year_freq.strip().split())
if year_freq == "auto" or year_freq.startswith("auto_"):
return True
if year_freq.startswith("index_"):
return True
return False
@hybrid_method
def get_year_freq(
cls_or_self,
year_freq: tp.Optional[tp.FrequencyLike] = None,
index: tp.Optional[tp.Index] = None,
freq: tp.Optional[tp.PandasFrequency] = None,
) -> tp.Optional[tp.PandasFrequency]:
"""Resolve year frequency.
If `year_freq` is "auto", uses `ReturnsAccessor.auto_detect_ann_factor`. If `year_freq`
is "auto_[method_name]`, also applies the method `np.[method_name]` to the annualization factor,
mostly to round it. If `year_freq` is "index_[method_name]", uses `ReturnsAccessor.parse_ann_factor`
to determine the annualization factor by applying the method to `pd.DatetimeIndex.year`."""
if not isinstance(cls_or_self, type):
if year_freq is None:
year_freq = cls_or_self._year_freq
if year_freq is None:
from vectorbtpro._settings import settings
returns_cfg = settings["returns"]
year_freq = returns_cfg["year_freq"]
if year_freq is None:
return None
if isinstance(year_freq, str):
year_freq = " ".join(year_freq.strip().split())
if cls_or_self.year_freq_depends_on_index(year_freq):
if not isinstance(cls_or_self, type):
if index is None:
index = cls_or_self.wrapper.index
if freq is None:
freq = cls_or_self.wrapper.freq
if index is None or not isinstance(index, pd.DatetimeIndex) or freq is None:
return None
if year_freq == "auto" or year_freq.startswith("auto_"):
ann_factor = cls_or_self.auto_detect_ann_factor(index)
if year_freq == "auto":
method_name = None
else:
method_name = year_freq.replace("auto_", "")
year_freq = cls_or_self.ann_factor_to_year_freq(
ann_factor,
dt.to_freq(freq),
method_name=method_name,
)
else:
method_name = year_freq.replace("index_", "")
ann_factor = cls_or_self.parse_ann_factor(index, method_name=method_name)
year_freq = cls_or_self.ann_factor_to_year_freq(
ann_factor,
dt.to_freq(freq),
method_name=None,
)
return dt.to_freq(year_freq)
@property
def year_freq(self) -> tp.Optional[tp.PandasFrequency]:
"""Year frequency."""
return self.get_year_freq()
@hybrid_method
def get_ann_factor(
cls_or_self,
year_freq: tp.Optional[tp.FrequencyLike] = None,
freq: tp.Optional[tp.FrequencyLike] = None,
raise_error: bool = False,
) -> tp.Optional[float]:
"""Get the annualization factor from the year and data frequency."""
if isinstance(cls_or_self, type):
from vectorbtpro._settings import settings
returns_cfg = settings["returns"]
wrapping_cfg = settings["wrapping"]
if year_freq is None:
year_freq = returns_cfg["year_freq"]
if freq is None:
freq = wrapping_cfg["freq"]
if freq is not None and dt.freq_depends_on_index(freq):
freq = None
else:
if year_freq is None:
year_freq = cls_or_self.year_freq
if freq is None:
freq = cls_or_self.wrapper.freq
if year_freq is None:
if not raise_error:
return None
raise ValueError(
"Year frequency is None. "
"Pass it as `year_freq` or define it globally under `settings.returns`. "
"To determine year frequency automatically, use 'auto'."
)
if freq is None:
if not raise_error:
return None
raise ValueError(
"Index frequency is None. "
"Pass it as `freq` or define it globally under `settings.wrapping`. "
"To determine frequency automatically, use 'auto'."
)
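# The annualization factor is simply the ratio of the two frequencies as timedeltas,
# e.g. year_freq="365 days" with freq="1D" yields an annualization factor of 365.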
return dt.to_timedelta(year_freq, approximate=True) / dt.to_timedelta(freq, approximate=True)
@property
def ann_factor(self) -> float:
"""Annualization factor."""
return self.get_ann_factor(raise_error=True)
@hybrid_method
def get_periods(
cls_or_self,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
wrapper: tp.Optional[ArrayWrapper] = None,
group_by: tp.GroupByLike = None,
) -> tp.Optional[tp.ArrayLike]:
"""Prepare periods."""
if not isinstance(cls_or_self, type) and periods is None:
periods = cls_or_self.defaults["periods"]
if isinstance(periods, str) and periods.lower() == "dt_periods":
if not isinstance(cls_or_self, type):
if wrapper is None:
wrapper = cls_or_self.wrapper
else:
checks.assert_not_none(wrapper, arg_name="wrapper")
sim_start = cls_or_self.resolve_sim_start(
sim_start=sim_start,
allow_none=True,
wrapper=wrapper,
group_by=group_by,
)
sim_end = cls_or_self.resolve_sim_end(
sim_end=sim_end,
allow_none=True,
wrapper=wrapper,
group_by=group_by,
)
if sim_start is not None or sim_end is not None:
if sim_start is None:
sim_start = cls_or_self.resolve_sim_start(
sim_start=sim_start,
allow_none=False,
wrapper=wrapper,
group_by=group_by,
)
if sim_end is None:
sim_end = cls_or_self.resolve_sim_end(
sim_end=sim_end,
allow_none=False,
wrapper=wrapper,
group_by=group_by,
)
periods = []
for i in range(len(sim_start)):
sim_index = wrapper.index[sim_start[i] : sim_end[i]]
if len(sim_index) == 0:
periods.append(0)
else:
periods.append(wrapper.index_acc.get_dt_periods(index=sim_index))
periods = np.asarray(periods)
else:
periods = wrapper.dt_periods
return periods
@property
def periods(self) -> tp.Optional[tp.ArrayLike]:
"""Periods."""
return self.get_periods()
def deannualize(self, value: float) -> float:
"""Deannualize a value."""
return np.power(1 + value, 1.0 / self.ann_factor) - 1.0
@property
def defaults(self) -> tp.Kwargs:
"""Defaults for `ReturnsAccessor`.
Merges `defaults` from `vectorbtpro._settings.returns` with `defaults`
from `ReturnsAccessor.__init__`."""
from vectorbtpro._settings import settings
returns_defaults_cfg = settings["returns"]["defaults"]
return merge_dicts(returns_defaults_cfg, self._defaults)
# ############# Transforming ############# #
def mirror(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Mirror returns.
See `vectorbtpro.returns.nb.mirror_returns_nb`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.mirror_returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
mirrored_returns = func(
self.to_2d_array(),
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(mirrored_returns, group_by=False, **resolve_dict(wrap_kwargs))
def cumulative(
self,
start_value: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Cumulative returns.
See `vectorbtpro.returns.nb.cumulative_returns_nb`."""
if start_value is None:
start_value = self.defaults["start_value"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.cumulative_returns_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
cumulative = func(
self.to_2d_array(),
start_value=start_value,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(cumulative, group_by=False, **resolve_dict(wrap_kwargs))
def resample(
self: ReturnsAccessorT,
*args,
fill_with_zero: bool = True,
wrapper_meta: tp.DictLike = None,
**kwargs,
) -> ReturnsAccessorT:
"""Perform resampling on `ReturnsAccessor`."""
if wrapper_meta is None:
wrapper_meta = self.wrapper.resample_meta(*args, **kwargs)
new_wrapper = wrapper_meta["new_wrapper"]
new_obj = self.resample_apply(
wrapper_meta["resampler"],
nb.total_return_1d_nb,
self.log_returns,
)
if fill_with_zero:
new_obj = new_obj.vbt.fillna(0.0)
if self._bm_returns is not None:
new_bm_returns = self.bm_returns.vbt.resample_apply(
wrapper_meta["resampler"],
nb.total_return_1d_nb,
self.log_returns,
)
if fill_with_zero:
new_bm_returns = new_bm_returns.vbt.fillna(0.0)
else:
new_bm_returns = None
new_sim_start = self.resample_sim_start(new_wrapper)
new_sim_end = self.resample_sim_end(new_wrapper)
return self.replace(
wrapper=wrapper_meta["new_wrapper"],
obj=new_obj,
bm_returns=new_bm_returns,
sim_start=new_sim_start,
sim_end=new_sim_end,
)
def resample_returns(
self,
rule: tp.AnyRuleLike,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Resample returns to a custom frequency, date offset, or index."""
checks.assert_instance_of(self.obj.index, dt.PandasDatetimeIndex)
func = jit_reg.resolve_option(nb.total_return_1d_nb, jitted)
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
args=ch.ArgsTaker(
None,
)
),
)
return self.resample_apply(
rule,
func,
self.log_returns,
jitted=jitted,
chunked=chunked,
**kwargs,
)
def daily(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Daily returns."""
return self.resample_returns("1D", jitted=jitted, chunked=chunked, **kwargs)
def annual(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Annual returns."""
return self.resample_returns(self.year_freq, jitted=jitted, chunked=chunked, **kwargs)
# ############# Metrics ############# #
def final_value(
self,
start_value: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Final value.
See `vectorbtpro.returns.nb.final_value_nb`."""
if start_value is None:
start_value = self.defaults["start_value"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.final_value_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
start_value=start_value,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="final_value"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_final_value(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
start_value: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Rolling final value.
See `vectorbtpro.returns.nb.rolling_final_value_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if start_value is None:
start_value = self.defaults["start_value"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_final_value_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
start_value=start_value,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def total(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Total return.
See `vectorbtpro.returns.nb.total_return_nb`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.total_return_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="total_return"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_total(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Rolling total return.
See `vectorbtpro.returns.nb.rolling_total_return_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_total_return_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def annualized(
self,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Annualized return.
See `vectorbtpro.returns.nb.annualized_return_nb`."""
periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.annualized_return_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
self.ann_factor,
periods=periods,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="annualized_return"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_annualized(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling annualized return.
See `vectorbtpro.returns.nb.rolling_annualized_return_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_annualized_return_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
self.ann_factor,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def annualized_volatility(
self,
levy_alpha: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Annualized volatility.
See `vectorbtpro.returns.nb.annualized_volatility_nb`."""
if levy_alpha is None:
levy_alpha = self.defaults["levy_alpha"]
if ddof is None:
ddof = self.defaults["ddof"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.annualized_volatility_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
self.ann_factor,
levy_alpha=levy_alpha,
ddof=ddof,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="annualized_volatility"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_annualized_volatility(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
levy_alpha: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling annualized volatility.
See `vectorbtpro.returns.nb.rolling_annualized_volatility_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if levy_alpha is None:
levy_alpha = self.defaults["levy_alpha"]
if ddof is None:
ddof = self.defaults["ddof"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_annualized_volatility_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
self.ann_factor,
levy_alpha=levy_alpha,
ddof=ddof,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def calmar_ratio(
self,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Calmar ratio.
See `vectorbtpro.returns.nb.calmar_ratio_nb`."""
periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.calmar_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
self.ann_factor,
periods=periods,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="calmar_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_calmar_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Calmar ratio.
See `vectorbtpro.returns.nb.rolling_calmar_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_calmar_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
self.ann_factor,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def omega_ratio(
self,
risk_free: tp.Optional[float] = None,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Omega ratio.
See `vectorbtpro.returns.nb.omega_ratio_nb`."""
if risk_free is None:
risk_free = self.defaults["risk_free"]
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
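# `required_return` is interpreted as an annual threshold, so it is first converted to a
# per-period value before being subtracted (together with `risk_free`) from the returns.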
func = jit_reg.resolve_option(nb.deannualized_return_nb, jitted)
required_return = func(required_return, self.ann_factor)
func = jit_reg.resolve_option(nb.omega_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - risk_free - required_return,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="omega_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_omega_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
risk_free: tp.Optional[float] = None,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Omega ratio.
See `vectorbtpro.returns.nb.rolling_omega_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if risk_free is None:
risk_free = self.defaults["risk_free"]
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.deannualized_return_nb, jitted)
required_return = func(required_return, self.ann_factor)
func = jit_reg.resolve_option(nb.rolling_omega_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - risk_free - required_return,
window,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def sharpe_ratio(
self,
annualized: bool = True,
risk_free: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Sharpe ratio.
See `vectorbtpro.returns.nb.sharpe_ratio_nb`."""
if risk_free is None:
risk_free = self.defaults["risk_free"]
if ddof is None:
ddof = self.defaults["ddof"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.sharpe_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
if annualized:
ann_factor = self.ann_factor
else:
ann_factor = 1
out = func(
self.to_2d_array() - risk_free,
ann_factor,
ddof=ddof,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="sharpe_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_sharpe_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
annualized: bool = True,
risk_free: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
stream_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Sharpe ratio.
See `vectorbtpro.returns.nb.rolling_sharpe_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if risk_free is None:
risk_free = self.defaults["risk_free"]
if ddof is None:
ddof = self.defaults["ddof"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_sharpe_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
if annualized:
ann_factor = self.ann_factor
else:
ann_factor = 1
out = func(
self.to_2d_array() - risk_free,
window,
ann_factor,
ddof=ddof,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
stream_mode=stream_mode,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def sharpe_ratio_std(
self,
risk_free: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
bias: bool = True,
wrap_kwargs: tp.KwargsLike = None,
):
"""Standard deviation of the sharpe ratio estimation."""
from scipy import stats as scipy_stats
returns = to_2d_array(self.obj)
nanmask = np.isnan(returns)
if nanmask.any():
returns = returns.copy()
returns[nanmask] = 0.0
n = len(returns)
skew = scipy_stats.skew(returns, axis=0, bias=bias)
kurtosis = scipy_stats.kurtosis(returns, axis=0, bias=bias)
sr = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof))
out = np.sqrt((1 + (0.5 * sr**2) - (skew * sr) + (((kurtosis - 3) / 4) * sr**2)) / (n - 1))
wrap_kwargs = merge_dicts(dict(name_or_index="sharpe_ratio_std"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def prob_sharpe_ratio(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
risk_free: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
bias: bool = True,
wrap_kwargs: tp.KwargsLike = None,
):
"""Probabilistic Sharpe Ratio (PSR)."""
from scipy import stats as scipy_stats
if bm_returns is None:
bm_returns = self.bm_returns
if bm_returns is not None:
bm_sr = to_1d_array(
self.replace(obj=bm_returns, bm_returns=None).sharpe_ratio(
annualized=False,
risk_free=risk_free,
ddof=ddof,
)
)
else:
bm_sr = 0
sr = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof))
sr_std = to_1d_array(self.sharpe_ratio_std(risk_free=risk_free, ddof=ddof, bias=bias))
out = scipy_stats.norm.cdf((sr - bm_sr) / sr_std)
wrap_kwargs = merge_dicts(dict(name_or_index="prob_sharpe_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def deflated_sharpe_ratio(
self,
risk_free: tp.Optional[float] = None,
ddof: tp.Optional[int] = None,
bias: bool = True,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Deflated Sharpe Ratio (DSR).
Expresses the probability that the reported strategy has a genuinely positive Sharpe ratio."""
from scipy import stats as scipy_stats
if risk_free is None:
risk_free = self.defaults["risk_free"]
if ddof is None:
ddof = self.defaults["ddof"]
sharpe_ratio = to_1d_array(self.sharpe_ratio(annualized=False, risk_free=risk_free, ddof=ddof))
var_sharpe = np.nanvar(sharpe_ratio, ddof=ddof)
returns = to_2d_array(self.obj)
nanmask = np.isnan(returns)
if nanmask.any():
returns = returns.copy()
returns[nanmask] = 0.0
skew = scipy_stats.skew(returns, axis=0, bias=bias)
kurtosis = scipy_stats.kurtosis(returns, axis=0, bias=bias)
SR0 = sharpe_ratio + np.sqrt(var_sharpe) * (
(1 - np.euler_gamma) * scipy_stats.norm.ppf(1 - 1 / self.wrapper.shape_2d[1])
+ np.euler_gamma * scipy_stats.norm.ppf(1 - 1 / (self.wrapper.shape_2d[1] * np.e))
)
out = scipy_stats.norm.cdf(
((sharpe_ratio - SR0) * np.sqrt(self.wrapper.shape_2d[0] - 1))
/ np.sqrt(1 - skew * sharpe_ratio + ((kurtosis - 1) / 4) * sharpe_ratio**2)
)
wrap_kwargs = merge_dicts(dict(name_or_index="deflated_sharpe_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def downside_risk(
self,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Downside risk.
See `vectorbtpro.returns.nb.downside_risk_nb`."""
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.downside_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - required_return,
self.ann_factor,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="downside_risk"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_downside_risk(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling downside risk.
See `vectorbtpro.returns.nb.rolling_downside_risk_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_downside_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - required_return,
window,
self.ann_factor,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def sortino_ratio(
self,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Sortino ratio.
See `vectorbtpro.returns.nb.sortino_ratio_nb`."""
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.sortino_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - required_return,
self.ann_factor,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="sortino_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_sortino_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
required_return: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Sortino ratio.
See `vectorbtpro.returns.nb.rolling_sortino_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if required_return is None:
required_return = self.defaults["required_return"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_sortino_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - required_return,
window,
self.ann_factor,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def information_ratio(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Information ratio.
See `vectorbtpro.returns.nb.information_ratio_nb`."""
if ddof is None:
ddof = self.defaults["ddof"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.information_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - to_2d_array(bm_returns),
ddof=ddof,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="information_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_information_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling information ratio.
See `vectorbtpro.returns.nb.rolling_information_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if ddof is None:
ddof = self.defaults["ddof"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_information_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - to_2d_array(bm_returns),
window,
ddof=ddof,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def beta(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Beta.
See `vectorbtpro.returns.nb.beta_nb`."""
if ddof is None:
ddof = self.defaults["ddof"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.beta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
ddof=ddof,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="beta"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_beta(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
ddof: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling beta.
See `vectorbtpro.returns.nb.rolling_beta_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if ddof is None:
ddof = self.defaults["ddof"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_beta_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
window,
ddof=ddof,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def alpha(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
risk_free: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Alpha.
See `vectorbtpro.returns.nb.alpha_nb`."""
if risk_free is None:
risk_free = self.defaults["risk_free"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.alpha_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - risk_free,
to_2d_array(bm_returns) - risk_free,
self.ann_factor,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="alpha"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_alpha(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
risk_free: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling alpha.
See `vectorbtpro.returns.nb.rolling_alpha_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if risk_free is None:
risk_free = self.defaults["risk_free"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_alpha_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array() - risk_free,
to_2d_array(bm_returns) - risk_free,
window,
self.ann_factor,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def tail_ratio(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Tail ratio.
See `vectorbtpro.returns.nb.tail_ratio_nb`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.tail_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
wrap_kwargs = merge_dicts(dict(name_or_index="tail_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_tail_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling tail ratio.
See `vectorbtpro.returns.nb.rolling_tail_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_tail_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def profit_factor(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Profit factor.
See `vectorbtpro.returns.nb.profit_factor_nb`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.profit_factor_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="profit_factor"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_profit_factor(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling profit factor.
See `vectorbtpro.returns.nb.rolling_profit_factor_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_profit_factor_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def common_sense_ratio(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Common Sense Ratio (CSR).
See `vectorbtpro.returns.nb.common_sense_ratio_nb`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.common_sense_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="common_sense_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_common_sense_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Common Sense Ratio (CSR).
See `vectorbtpro.returns.nb.rolling_common_sense_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_common_sense_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def value_at_risk(
self,
cutoff: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Value at Risk (VaR).
See `vectorbtpro.returns.nb.value_at_risk_nb`."""
if cutoff is None:
cutoff = self.defaults["cutoff"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.value_at_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
cutoff=cutoff,
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
wrap_kwargs = merge_dicts(dict(name_or_index="value_at_risk"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_value_at_risk(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
cutoff: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Value at Risk (VaR).
See `vectorbtpro.returns.nb.rolling_value_at_risk_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if cutoff is None:
cutoff = self.defaults["cutoff"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_value_at_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
cutoff=cutoff,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def cond_value_at_risk(
self,
cutoff: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Conditional Value at Risk (CVaR).
See `vectorbtpro.returns.nb.cond_value_at_risk_nb`."""
if cutoff is None:
cutoff = self.defaults["cutoff"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.cond_value_at_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
cutoff=cutoff,
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
wrap_kwargs = merge_dicts(dict(name_or_index="cond_value_at_risk"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
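# A minimal usage sketch of the VaR family (illustrative; the default `cutoff` is
# taken from `ReturnsAccessor.defaults`, commonly 0.05): `value_at_risk` reports the
# cutoff quantile of returns, `cond_value_at_risk` the expected return below it.
#
# >>> rets = pd.Series(np.random.normal(0, 0.01, size=1000))
# >>> rets.vbt.returns.value_at_risk(cutoff=0.05)
# >>> rets.vbt.returns.cond_value_at_risk(cutoff=0.05)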
def rolling_cond_value_at_risk(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
cutoff: tp.Optional[float] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
noarr_mode: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Conditional Value at Risk (CVaR).
See `vectorbtpro.returns.nb.rolling_cond_value_at_risk_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if cutoff is None:
cutoff = self.defaults["cutoff"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_cond_value_at_risk_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
cutoff=cutoff,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
noarr_mode=noarr_mode,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def capture_ratio(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Capture ratio.
See `vectorbtpro.returns.nb.capture_ratio_nb`."""
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
self.ann_factor,
periods=periods,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="capture_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
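# A minimal usage sketch (illustrative): the capture ratio relates the strategy's
# annualized return to the benchmark's, while `up_capture_ratio` and
# `down_capture_ratio` below restrict the comparison to periods where the benchmark
# moved up or down, respectively.
#
# >>> index = pd.date_range("2020-01-01", periods=252, freq="D")
# >>> rets = pd.Series(np.random.uniform(-0.02, 0.02, size=252), index=index)
# >>> bm_rets = pd.Series(np.random.uniform(-0.02, 0.02, size=252), index=index)
# >>> rets.vbt.returns.capture_ratio(bm_returns=bm_rets)
# >>> rets.vbt.returns.up_capture_ratio(bm_returns=bm_rets)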
def rolling_capture_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling capture ratio.
See `vectorbtpro.returns.nb.rolling_capture_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
window,
self.ann_factor,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def up_capture_ratio(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Up-market capture ratio.
See `vectorbtpro.returns.nb.up_capture_ratio_nb`."""
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.up_capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
self.ann_factor,
periods=periods,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="up_capture_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_up_capture_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling up-market capture ratio.
See `vectorbtpro.returns.nb.rolling_up_capture_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_up_capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
window,
self.ann_factor,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def down_capture_ratio(
self,
bm_returns: tp.Optional[tp.ArrayLike] = None,
periods: tp.Union[None, str, tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Up-market capture ratio.
See `vectorbtpro.returns.nb.down_capture_ratio_nb`."""
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
periods = self.get_periods(periods=periods, sim_start=sim_start, sim_end=sim_end)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.down_capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
self.ann_factor,
periods=periods,
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="down_capture_ratio"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_down_capture_ratio(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling down-market capture ratio.
See `vectorbtpro.returns.nb.rolling_down_capture_ratio_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
if bm_returns is None:
bm_returns = self.bm_returns
checks.assert_not_none(bm_returns, arg_name="bm_returns")
bm_returns = broadcast_to(bm_returns, self.obj)
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_down_capture_ratio_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
to_2d_array(bm_returns),
window,
self.ann_factor,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
def drawdown(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Relative decline from a peak."""
return self.cumulative(
start_value=1,
sim_start=sim_start,
sim_end=sim_end,
jitted=jitted,
chunked=chunked,
).vbt.drawdown(
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
)
def max_drawdown(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Maximum Drawdown (MDD).
See `vectorbtpro.returns.nb.max_drawdown_nb`.
Yields the same result as `max_drawdown` of `ReturnsAccessor.drawdowns`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.max_drawdown_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
log_returns=self.log_returns,
sim_start=sim_start,
sim_end=sim_end,
)
wrap_kwargs = merge_dicts(dict(name_or_index="max_drawdown"), wrap_kwargs)
return self.wrapper.wrap_reduced(out, group_by=False, **wrap_kwargs)
def rolling_max_drawdown(
self,
window: tp.Optional[int] = None,
*,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""Rolling Maximum Drawdown (MDD).
See `vectorbtpro.returns.nb.rolling_max_drawdown_nb`."""
if window is None:
window = self.defaults["window"]
if minp is None:
minp = self.defaults["minp"]
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
func = jit_reg.resolve_option(nb.rolling_max_drawdown_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
out = func(
self.to_2d_array(),
window,
log_returns=self.log_returns,
minp=minp,
sim_start=sim_start,
sim_end=sim_end,
)
return self.wrapper.wrap(out, group_by=False, **resolve_dict(wrap_kwargs))
@property
def drawdowns(self) -> Drawdowns:
"""`ReturnsAccessor.get_drawdowns` with default arguments."""
return self.get_drawdowns()
def get_drawdowns(
self,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> Drawdowns:
"""Generate drawdown records of cumulative returns.
See `vectorbtpro.generic.drawdowns.Drawdowns`."""
sim_start = self.resolve_sim_start(sim_start=sim_start, group_by=False)
sim_end = self.resolve_sim_end(sim_end=sim_end, group_by=False)
return Drawdowns.from_price(
self.cumulative(
start_value=1.0,
sim_start=sim_start,
sim_end=sim_end,
jitted=jitted,
),
sim_start=sim_start,
sim_end=sim_end,
wrapper=self.wrapper,
**kwargs,
)
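# A minimal usage sketch (illustrative): `drawdown` gives the running decline from
# the peak of cumulative returns, `max_drawdown` reduces it to one value per column,
# and `drawdowns` exposes the individual declines as records for further analysis.
#
# >>> rets = pd.Series([0.1, -0.2, 0.05, 0.1, -0.1])
# >>> rets.vbt.returns.drawdown()
# >>> rets.vbt.returns.max_drawdown()
# >>> rets.vbt.returns.drawdowns.records_readable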
@property
def qs(self) -> "QSAdapter":
"""Quantstats adapter."""
from vectorbtpro.returns.qs_adapter import QSAdapter
return QSAdapter(self)
# ############# Resolution ############# #
def resolve_self(
self: ReturnsAccessorT,
cond_kwargs: tp.KwargsLike = None,
custom_arg_names: tp.Optional[tp.Set[str]] = None,
impacts_caching: bool = True,
silence_warnings: bool = False,
) -> ReturnsAccessorT:
"""Resolve self.
See `vectorbtpro.base.wrapping.Wrapping.resolve_self`.
Creates a copy of this instance if `year_freq` is different in `cond_kwargs`."""
if cond_kwargs is None:
cond_kwargs = {}
if custom_arg_names is None:
custom_arg_names = set()
reself = Wrapping.resolve_self(
self,
cond_kwargs=cond_kwargs,
custom_arg_names=custom_arg_names,
impacts_caching=impacts_caching,
silence_warnings=silence_warnings,
)
if "year_freq" in cond_kwargs:
self_copy = reself.replace(year_freq=cond_kwargs["year_freq"])
if self_copy.year_freq != reself.year_freq:
if not silence_warnings:
warn(
f"Changing the year frequency will create a copy of this object. "
f"Consider setting it upon object creation to re-use existing cache."
)
for alias in reself.self_aliases:
if alias not in custom_arg_names:
cond_kwargs[alias] = self_copy
cond_kwargs["year_freq"] = self_copy.year_freq
if impacts_caching:
cond_kwargs["use_caching"] = False
return self_copy
return reself
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `ReturnsAccessor.stats`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.stats_defaults`,
defaults from `ReturnsAccessor.defaults` (acting as `settings`), and
`stats` from `vectorbtpro._settings.returns`"""
from vectorbtpro._settings import settings
returns_stats_cfg = settings["returns"]["stats"]
return merge_dicts(
GenericAccessor.stats_defaults.__get__(self),
dict(settings=self.defaults),
dict(settings=dict(year_freq=self.year_freq)),
returns_stats_cfg,
)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func="sim_start_index",
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func="sim_end_index",
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func="sim_duration",
apply_to_timedelta=True,
tags="wrapper",
),
total_return=dict(
title="Total Return [%]",
calc_func="total",
post_calc_func=lambda self, out, settings: out * 100,
tags="returns",
),
bm_return=dict(
title="Benchmark Return [%]",
calc_func="bm_returns_acc.total",
post_calc_func=lambda self, out, settings: out * 100,
check_has_bm_returns=True,
tags="returns",
),
ann_return=dict(
title="Annualized Return [%]",
calc_func="annualized",
post_calc_func=lambda self, out, settings: out * 100,
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
ann_volatility=dict(
title="Annualized Volatility [%]",
calc_func="annualized_volatility",
post_calc_func=lambda self, out, settings: out * 100,
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
max_dd=dict(
title="Max Drawdown [%]",
calc_func="drawdowns.get_max_drawdown",
post_calc_func=lambda self, out, settings: -out * 100,
tags=["returns", "drawdowns"],
),
max_dd_duration=dict(
title="Max Drawdown Duration",
calc_func="drawdowns.get_max_duration",
fill_wrap_kwargs=True,
tags=["returns", "drawdowns", "duration"],
),
sharpe_ratio=dict(
title="Sharpe Ratio",
calc_func="sharpe_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
calmar_ratio=dict(
title="Calmar Ratio",
calc_func="calmar_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
omega_ratio=dict(
title="Omega Ratio",
calc_func="omega_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
sortino_ratio=dict(
title="Sortino Ratio",
calc_func="sortino_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
skew=dict(
title="Skew",
calc_func="obj.skew",
tags="returns",
),
kurtosis=dict(
title="Kurtosis",
calc_func="obj.kurtosis",
tags="returns",
),
tail_ratio=dict(
title="Tail Ratio",
calc_func="tail_ratio",
tags="returns",
),
common_sense_ratio=dict(
title="Common Sense Ratio",
calc_func="common_sense_ratio",
check_has_freq=True,
check_has_year_freq=True,
tags="returns",
),
value_at_risk=dict(
title="Value at Risk",
calc_func="value_at_risk",
tags="returns",
),
alpha=dict(
title="Alpha",
calc_func="alpha",
check_has_freq=True,
check_has_year_freq=True,
check_has_bm_returns=True,
tags="returns",
),
beta=dict(
title="Beta",
calc_func="beta",
check_has_bm_returns=True,
tags="returns",
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
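# A minimal usage sketch (illustrative): the metrics above are evaluated by
# `ReturnsAccessor.stats`; benchmark-dependent entries such as alpha, beta, and the
# benchmark return require `bm_returns` to be set on the accessor.
#
# >>> index = pd.date_range("2020-01-01", periods=252, freq="D")
# >>> rets = pd.Series(np.random.uniform(-0.02, 0.02, size=252), index=index)
# >>> rets.vbt.returns.stats()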
# ############# Plotting ############# #
def plot_cumulative(
self,
column: tp.Optional[tp.Label] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
start_value: float = 1,
sim_start: tp.Optional[tp.ArrayLike] = None,
sim_end: tp.Optional[tp.ArrayLike] = None,
fit_sim_range: bool = True,
fill_to_benchmark: bool = False,
main_kwargs: tp.KwargsLike = None,
bm_kwargs: tp.KwargsLike = None,
pct_scale: bool = False,
hline_shape_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
xref: str = "x",
yref: str = "y",
fig: tp.Optional[tp.BaseFigure] = None,
**layout_kwargs,
) -> tp.BaseFigure:
"""Plot cumulative returns.
Args:
column (str): Name of the column to plot.
bm_returns (array_like): Benchmark returns to compare against.
Will broadcast per element.
start_value (float): The starting value.
sim_start (int, datetime_like, or array_like): Simulation start row or index (inclusive).
sim_end (int, datetime_like, or array_like): Simulation end row or index (exclusive).
fit_sim_range (bool): Whether to fit figure to simulation range.
fill_to_benchmark (bool): Whether to fill between main and benchmark, or between main and `start_value`.
main_kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot` for main.
bm_kwargs (dict): Keyword arguments passed to `vectorbtpro.generic.accessors.GenericSRAccessor.plot` for benchmark.
pct_scale (bool): Whether to use the percentage scale for the y-axis.
hline_shape_kwargs (dict): Keyword arguments passed to `plotly.graph_objects.Figure.add_shape` for `start_value` line.
add_trace_kwargs (dict): Keyword arguments passed to `add_trace`.
xref (str): X coordinate axis.
yref (str): Y coordinate axis.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.
Usage:
```pycon
>>> np.random.seed(0)
>>> rets = pd.Series(np.random.uniform(-0.05, 0.05, size=100))
>>> bm_returns = pd.Series(np.random.uniform(-0.05, 0.05, size=100))
>>> rets.vbt.returns.plot_cumulative(bm_returns=bm_returns).show()
```
"""
from vectorbtpro.utils.figure import make_figure, get_domain
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
xaxis = "xaxis" + xref[1:]
yaxis = "yaxis" + yref[1:]
def_layout_kwargs = {xaxis: {}, yaxis: {}}
if pct_scale:
start_value = 0
def_layout_kwargs[yaxis]["tickformat"] = ".2%"
if fig is None:
fig = make_figure()
fig.update_layout(**def_layout_kwargs)
fig.update_layout(**layout_kwargs)
x_domain = get_domain(xref, fig)
y_domain = get_domain(yref, fig)
if bm_returns is None:
bm_returns = self.bm_returns
fill_to_benchmark = fill_to_benchmark and bm_returns is not None
if bm_returns is not None:
# Plot benchmark
bm_returns = broadcast_to(bm_returns, self.obj)
bm_returns = self.select_col_from_obj(bm_returns, column=column, group_by=False)
if bm_kwargs is None:
bm_kwargs = {}
bm_kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["gray"],
),
name="Benchmark",
)
),
bm_kwargs,
)
bm_cumulative_returns = bm_returns.vbt.returns.cumulative(
start_value=start_value,
sim_start=sim_start,
sim_end=sim_end,
)
bm_cumulative_returns.vbt.lineplot(**bm_kwargs, add_trace_kwargs=add_trace_kwargs, fig=fig)
else:
bm_cumulative_returns = None
if main_kwargs is None:
main_kwargs = {}
cumulative_returns = self.cumulative(
start_value=start_value,
sim_start=sim_start,
sim_end=sim_end,
)
cumulative_returns = self.select_col_from_obj(cumulative_returns, column=column, group_by=False)
main_kwargs = merge_dicts(
dict(
trace_kwargs=dict(
line=dict(
color=plotting_cfg["color_schema"]["purple"],
),
),
other_trace_kwargs="hidden",
),
main_kwargs,
)
if fill_to_benchmark:
cumulative_returns.vbt.plot_against(
bm_cumulative_returns, add_trace_kwargs=add_trace_kwargs, fig=fig, **main_kwargs
)
else:
cumulative_returns.vbt.plot_against(start_value, add_trace_kwargs=add_trace_kwargs, fig=fig, **main_kwargs)
if hline_shape_kwargs is None:
hline_shape_kwargs = {}
fig.add_shape(
**merge_dicts(
dict(
type="line",
xref="paper",
yref=yref,
x0=x_domain[0],
y0=start_value,
x1=x_domain[1],
y1=start_value,
line=dict(
color="gray",
dash="dash",
),
),
hline_shape_kwargs,
)
)
if fit_sim_range:
fig = self.fit_fig_to_sim_range(
fig,
column=column,
sim_start=sim_start,
sim_end=sim_end,
group_by=False,
xref=xref,
)
return fig
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `ReturnsAccessor.plots`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults`,
defaults from `ReturnsAccessor.defaults` (acting as `settings`), and
`plots` from `vectorbtpro._settings.returns`"""
from vectorbtpro._settings import settings
returns_plots_cfg = settings["returns"]["plots"]
return merge_dicts(
GenericAccessor.plots_defaults.__get__(self),
dict(settings=self.defaults),
dict(settings=dict(year_freq=self.year_freq)),
returns_plots_cfg,
)
_subplots: tp.ClassVar[Config] = HybridConfig(
dict(
plot_cumulative=dict(
title="Cumulative Returns",
yaxis_kwargs=dict(title="Cumulative returns"),
plot_func="plot_cumulative",
pass_hline_shape_kwargs=True,
pass_add_trace_kwargs=True,
pass_xref=True,
pass_yref=True,
tags="returns",
)
)
)
@property
def subplots(self) -> Config:
return self._subplots
ReturnsAccessor.override_metrics_doc(__pdoc__)
ReturnsAccessor.override_subplots_doc(__pdoc__)
@register_sr_vbt_accessor("returns")
class ReturnsSRAccessor(ReturnsAccessor, GenericSRAccessor):
"""Accessor on top of return series. For Series only.
Accessible via `pd.Series.vbt.returns`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
_full_init: bool = True,
**kwargs,
) -> None:
GenericSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
ReturnsAccessor.__init__(
self,
wrapper,
obj=obj,
bm_returns=bm_returns,
year_freq=year_freq,
defaults=defaults,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
@register_df_vbt_accessor("returns")
class ReturnsDFAccessor(ReturnsAccessor, GenericDFAccessor):
"""Accessor on top of return series. For DataFrames only.
Accessible via `pd.DataFrame.vbt.returns`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
bm_returns: tp.Optional[tp.ArrayLike] = None,
year_freq: tp.Optional[tp.FrequencyLike] = None,
defaults: tp.KwargsLike = None,
sim_start: tp.Optional[tp.Array1d] = None,
sim_end: tp.Optional[tp.Array1d] = None,
_full_init: bool = True,
**kwargs,
) -> None:
GenericDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
ReturnsAccessor.__init__(
self,
wrapper,
obj=obj,
bm_returns=bm_returns,
year_freq=year_freq,
defaults=defaults,
sim_start=sim_start,
sim_end=sim_end,
**kwargs,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for returns."""
from vectorbtpro import _typing as tp
__pdoc__all__ = __all__ = [
"RollSharpeAIS",
"RollSharpeAOS",
]
__pdoc__ = {}
# ############# States ############# #
class RollSharpeAIS(tp.NamedTuple):
i: int
ret: float
pre_window_ret: float
cumsum: float
cumsum_sq: float
nancnt: int
window: int
minp: tp.Optional[int]
ddof: int
ann_factor: float
__pdoc__[
"RollSharpeAIS"
] = """A named tuple representing the input state of
`vectorbtpro.returns.nb.rolling_sharpe_ratio_acc_nb`."""
class RollSharpeAOS(tp.NamedTuple):
cumsum: float
cumsum_sq: float
nancnt: int
value: float
__pdoc__[
"RollSharpeAOS"
] = """A named tuple representing the output state of
`vectorbtpro.returns.nb.rolling_sharpe_ratio_acc_nb`."""
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for returns.
Provides an arsenal of Numba-compiled functions that are used by accessors and for measuring
portfolio performance. These only accept NumPy arrays and other Numba-compatible types.
!!! note
vectorbt treats matrices as first-class citizens and expects input arrays to be
2-dim, unless a function has the suffix `_1d` or is meant to be used as input to another function.
Data is processed along index (axis 0).
All functions passed as argument must be Numba-compiled."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro._settings import settings
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb
from vectorbtpro.base.reshaping import to_1d_array_nb
from vectorbtpro.generic import nb as generic_nb, enums as generic_enums
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.returns.enums import RollSharpeAIS, RollSharpeAOS
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.math_ import add_nb
__all__ = []
_inf_to_nan = settings["returns"]["inf_to_nan"]
_nan_to_zero = settings["returns"]["nan_to_zero"]
# ############# Metrics ############# #
@register_jitted(cache=True)
def get_return_nb(
input_value: float,
output_value: float,
log_returns: bool = False,
inf_to_nan: bool = _inf_to_nan,
nan_to_zero: bool = _nan_to_zero,
) -> float:
"""Calculate return from input and output value."""
if input_value == 0:
if output_value == 0:
r = 0.0
else:
r = np.inf * np.sign(output_value)
else:
return_value = add_nb(output_value, -input_value) / input_value
if log_returns:
r = np.log1p(return_value)
else:
r = return_value
if inf_to_nan and np.isinf(r):
r = np.nan
if nan_to_zero and np.isnan(r):
r = 0.0
return r
@register_jitted(cache=True)
def returns_1d_nb(
arr: tp.Array1d,
init_value: float = np.nan,
log_returns: bool = False,
) -> tp.Array1d:
"""Calculate returns."""
out = np.empty(arr.shape, dtype=float_)
if np.isnan(init_value) and arr.shape[0] > 0:
input_value = arr[0]
else:
input_value = init_value
for i in range(arr.shape[0]):
output_value = arr[i]
out[i] = get_return_nb(input_value, output_value, log_returns=log_returns)
input_value = output_value
return out
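# A small numeric sketch (illustrative): the first return is computed against
# `init_value`, or against the first element when `init_value` is NaN.
#
# >>> returns_1d_nb(np.array([100.0, 110.0, 99.0]))  # roughly [0.0, 0.1, -0.1]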
@register_chunkable(
size=ch.ArraySizer(arg_query="arr", axis=1),
arg_take_spec=dict(
arr=ch.ArraySlicer(axis=1),
init_value=base_ch.FlexArraySlicer(),
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def returns_nb(
arr: tp.Array2d,
init_value: tp.FlexArray1dLike = np.nan,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""2-dim version of `returns_1d_nb`."""
init_value_ = to_1d_array_nb(np.asarray(init_value))
out = np.full(arr.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=arr.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(arr.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_init_value = flex_select_1d_pc_nb(init_value_, col)
out[_sim_start:_sim_end, col] = returns_1d_nb(
arr[_sim_start:_sim_end, col],
init_value=_init_value,
log_returns=log_returns,
)
return out
@register_jitted(cache=True)
def mirror_returns_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> tp.Array1d:
"""Calculate mirrored returns.
A mirrored return is an inverse, or negative return. For log returns, it negates each return.
For simple returns, it uses the formula $\frac{1}{1 + R_t} - 1$."""
out = np.empty(returns.shape, dtype=float_)
for i in range(returns.shape[0]):
if log_returns:
out[i] = -returns[i]
else:
if returns[i] <= -1:
out[i] = np.inf
else:
out[i] = (1 / (1 + returns[i])) - 1
return out
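# A small numeric sketch (illustrative): a simple return of +25% mirrors to
# 1 / 1.25 - 1 = -20%, the move that would exactly undo it, and vice versa.
#
# >>> mirror_returns_1d_nb(np.array([0.25, -0.2]))  # roughly [-0.2, 0.25]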
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def mirror_returns_nb(
returns: tp.Array2d,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""2-dim version of `mirror_returns_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = mirror_returns_1d_nb(
returns[_sim_start:_sim_end, col],
log_returns=log_returns,
)
return out
@register_jitted(cache=True)
def cumulative_returns_1d_nb(
returns: tp.Array1d,
start_value: float = 1.0,
log_returns: bool = False,
) -> tp.Array1d:
"""Cumulative returns."""
out = np.empty_like(returns, dtype=float_)
if log_returns:
cumsum = 0
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
cumsum += returns[i]
if start_value == 0:
out[i] = cumsum
else:
out[i] = np.exp(cumsum) * start_value
else:
cumprod = 1
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
cumprod *= 1 + returns[i]
if start_value == 0:
out[i] = cumprod - 1
else:
out[i] = cumprod * start_value
return out
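# A small numeric sketch (illustrative): with `start_value=1` each simple return
# compounds the running product; with `start_value=0` the cumulative return itself
# is returned.
#
# >>> cumulative_returns_1d_nb(np.array([0.1, -0.1]))  # roughly [1.1, 0.99]
# >>> cumulative_returns_1d_nb(np.array([0.1, -0.1]), start_value=0.0)  # roughly [0.1, -0.01]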
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
start_value=None,
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cumulative_returns_nb(
returns: tp.Array2d,
start_value: float = 1.0,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""2-dim version of `cumulative_returns_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = cumulative_returns_1d_nb(
returns[_sim_start:_sim_end, col],
start_value=start_value,
log_returns=log_returns,
)
return out
@register_jitted(cache=True)
def final_value_1d_nb(
returns: tp.Array1d,
start_value: float = 1.0,
log_returns: bool = False,
) -> float:
"""Final value."""
if log_returns:
cumsum = 0
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
cumsum += returns[i]
if start_value == 0:
return cumsum
return np.exp(cumsum) * start_value
else:
cumprod = 1
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
cumprod *= 1 + returns[i]
if start_value == 0:
return cumprod - 1
return cumprod * start_value
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
start_value=None,
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def final_value_nb(
returns: tp.Array2d,
start_value: float = 1.0,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `final_value_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = final_value_1d_nb(
returns[_sim_start:_sim_end, col],
start_value=start_value,
log_returns=log_returns,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
start_value=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_final_value_nb(
returns: tp.Array2d,
window: int,
start_value: float = 1.0,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `final_value_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
final_value_1d_nb,
start_value,
log_returns,
)
return out
@register_jitted(cache=True)
def total_return_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> float:
"""Total return."""
return final_value_1d_nb(returns, start_value=0.0, log_returns=log_returns)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def total_return_nb(
returns: tp.Array2d,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `total_return_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = total_return_1d_nb(returns[_sim_start:_sim_end, col], log_returns=log_returns)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_total_return_nb(
returns: tp.Array2d,
window: int,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `total_return_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
total_return_1d_nb,
log_returns,
)
return out
@register_jitted(cache=True)
def annualized_return_1d_nb(
returns: tp.Array1d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[float] = None,
) -> float:
"""Annualized total return.
This is equivalent to the compound annual growth rate (CAGR)."""
if periods is None:
periods = returns.shape[0]
final_value = final_value_1d_nb(returns, log_returns=log_returns)
if periods == 0:
return np.nan
return final_value ** (ann_factor / periods) - 1
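# A small numeric sketch (illustrative): two periods of +10% compound to a final
# value of 1.21; with `ann_factor=4` (four periods per "year") over 2 periods, the
# annualized return is 1.21 ** (4 / 2) - 1 = 0.4641.
#
# >>> annualized_return_1d_nb(np.array([0.1, 0.1]), 4.0)  # roughly 0.4641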
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
log_returns=None,
periods=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def annualized_return_nb(
returns: tp.Array2d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `annualized_return_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
if periods is None:
period_ = sim_end_ - sim_start_
else:
period_ = to_1d_array_nb(np.asarray(periods).astype(int_))
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_period = flex_select_1d_pc_nb(period_, col)
out[col] = annualized_return_1d_nb(
returns[_sim_start:_sim_end, col],
ann_factor,
log_returns=log_returns,
periods=_period,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_annualized_return_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `annualized_return_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
annualized_return_1d_nb,
ann_factor,
log_returns,
)
return out
@register_jitted(cache=True)
def annualized_volatility_1d_nb(
returns: tp.Array1d,
ann_factor: float,
levy_alpha: float = 2.0,
ddof: int = 0,
) -> float:
"""Annualized volatility of a strategy."""
return generic_nb.nanstd_1d_nb(returns, ddof=ddof) * ann_factor ** (1.0 / levy_alpha)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
levy_alpha=None,
ddof=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def annualized_volatility_nb(
returns: tp.Array2d,
ann_factor: float,
levy_alpha: float = 2.0,
ddof: int = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `annualized_volatility_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = annualized_volatility_1d_nb(
returns[_sim_start:_sim_end, col],
ann_factor,
levy_alpha=levy_alpha,
ddof=ddof,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
levy_alpha=None,
ddof=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_annualized_volatility_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
levy_alpha: float = 2.0,
ddof: int = 0,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `annualized_volatility_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
annualized_volatility_1d_nb,
ann_factor,
levy_alpha,
ddof,
)
return out
@register_jitted(cache=True)
def max_drawdown_1d_nb(returns: tp.Array1d, log_returns: bool = False) -> float:
"""Total maximum drawdown (MDD)."""
cum_ret = np.nan
value_max = 1.0
out = 0.0
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
if np.isnan(cum_ret):
cum_ret = 1.0
if log_returns:
ret = np.exp(returns[i]) - 1
else:
ret = returns[i]
cum_ret *= ret + 1.0
if cum_ret > value_max:
value_max = cum_ret
elif cum_ret < value_max:
dd = cum_ret / value_max - 1
if dd < out:
out = dd
if np.isnan(cum_ret):
return np.nan
return out
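# A small numeric sketch (illustrative): cumulative values 1.1, 0.55, 0.66 peak at
# 1.1, so the deepest decline is 0.55 / 1.1 - 1 = -0.5.
#
# >>> max_drawdown_1d_nb(np.array([0.1, -0.5, 0.2]))  # roughly -0.5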
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
log_returns=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def max_drawdown_nb(
returns: tp.Array2d,
log_returns: bool = False,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `max_drawdown_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = max_drawdown_1d_nb(returns[_sim_start:_sim_end, col], log_returns=log_returns)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_max_drawdown_nb(
returns: tp.Array2d,
window: int,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `max_drawdown_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
max_drawdown_1d_nb,
log_returns,
)
return out
@register_jitted(cache=True)
def calmar_ratio_1d_nb(
returns: tp.Array1d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[float] = None,
) -> float:
"""Calmar ratio, or drawdown ratio, of a strategy."""
max_drawdown = max_drawdown_1d_nb(returns, log_returns=log_returns)
if max_drawdown == 0:
return np.nan
annualized_return = annualized_return_1d_nb(
returns,
ann_factor,
log_returns=log_returns,
periods=periods,
)
return annualized_return / np.abs(max_drawdown)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
log_returns=None,
periods=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def calmar_ratio_nb(
returns: tp.Array2d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `calmar_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
if periods is None:
period_ = sim_end_ - sim_start_
else:
period_ = to_1d_array_nb(np.asarray(periods).astype(int_))
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_period = flex_select_1d_pc_nb(period_, col)
out[col] = calmar_ratio_1d_nb(
returns[_sim_start:_sim_end, col],
ann_factor,
log_returns=log_returns,
periods=_period,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_calmar_ratio_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `calmar_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
calmar_ratio_1d_nb,
ann_factor,
log_returns,
)
return out
@register_jitted(cache=True)
def deannualized_return_nb(ret: float, ann_factor: float) -> float:
"""Deannualized return."""
if ann_factor == 1:
return ret
if ann_factor <= -1:
return np.nan
return (1 + ret) ** (1.0 / ann_factor) - 1
@register_jitted(cache=True)
def omega_ratio_1d_nb(returns: tp.Array1d) -> float:
"""Omega ratio of a strategy."""
numer = 0.0
denom = 0.0
for i in range(returns.shape[0]):
ret = returns[i]
if ret > 0:
numer += ret
elif ret < 0:
denom -= ret
if denom == 0:
if numer == 0:
return np.nan
return np.inf
return numer / denom
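# A small numeric sketch (illustrative): gains sum to 0.3 and losses sum to 0.1,
# so the ratio of upside to downside is 3.
#
# >>> omega_ratio_1d_nb(np.array([0.1, -0.05, 0.2, -0.05]))  # roughly 3.0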
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def omega_ratio_nb(
returns: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `omega_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = omega_ratio_1d_nb(returns[_sim_start:_sim_end, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_omega_ratio_nb(
returns: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `omega_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
omega_ratio_1d_nb,
)
return out
@register_jitted(cache=True)
def sharpe_ratio_1d_nb(
returns: tp.Array1d,
ann_factor: float,
ddof: int = 0,
) -> float:
"""Sharpe ratio of a strategy."""
mean = np.nanmean(returns)
std = generic_nb.nanstd_1d_nb(returns, ddof=ddof)
if std == 0:
if mean == 0:
return np.nan
return np.inf
return mean / std * np.sqrt(ann_factor)
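# A minimal usage sketch, assuming a daily annualization factor of 252: with returns
# [0.02, 0.0, 0.02, 0.0], mean/std equals 0.01 / 0.01 = 1, so the result is sqrt(252).
# >>> sharpe_ratio_1d_nb(np.array([0.02, 0.0, 0.02, 0.0]), 252)
# 15.87...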
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
ddof=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def sharpe_ratio_nb(
returns: tp.Array2d,
ann_factor: float,
ddof: int = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `sharpe_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = sharpe_ratio_1d_nb(
returns[_sim_start:_sim_end, col],
ann_factor,
ddof=ddof,
)
return out
@register_jitted(cache=True)
def rolling_sharpe_ratio_acc_nb(in_state: RollSharpeAIS) -> RollSharpeAOS:
"""Accumulator of `rolling_sharpe_ratio_stream_nb`.
Takes a state of type `vectorbtpro.returns.enums.RollSharpeAIS` and returns
a state of type `vectorbtpro.returns.enums.RollSharpeAOS`."""
mean_in_state = generic_enums.RollMeanAIS(
i=in_state.i,
value=in_state.ret,
pre_window_value=in_state.pre_window_ret,
cumsum=in_state.cumsum,
nancnt=in_state.nancnt,
window=in_state.window,
minp=in_state.minp,
)
mean_out_state = generic_nb.rolling_mean_acc_nb(mean_in_state)
std_in_state = generic_enums.RollStdAIS(
i=in_state.i,
value=in_state.ret,
pre_window_value=in_state.pre_window_ret,
cumsum=in_state.cumsum,
cumsum_sq=in_state.cumsum_sq,
nancnt=in_state.nancnt,
window=in_state.window,
minp=in_state.minp,
ddof=in_state.ddof,
)
std_out_state = generic_nb.rolling_std_acc_nb(std_in_state)
mean = mean_out_state.value
std = std_out_state.value
if std == 0:
sharpe = np.nan
else:
sharpe = mean / std * np.sqrt(in_state.ann_factor)
return RollSharpeAOS(
cumsum=std_out_state.cumsum,
cumsum_sq=std_out_state.cumsum_sq,
nancnt=std_out_state.nancnt,
value=sharpe,
)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
ddof=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def rolling_sharpe_ratio_stream_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
ddof: int = 0,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling Sharpe ratio in a streaming fashion.
Uses `rolling_sharpe_ratio_acc_nb` at each iteration."""
if window is None:
window = returns.shape[0]
if minp is None:
minp = window
out = np.full(returns.shape, np.nan, dtype=float_)
if returns.shape[0] == 0:
return out
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
cumsum = 0.0
cumsum_sq = 0.0
nancnt = 0
for i in range(_sim_start, _sim_end):
in_state = RollSharpeAIS(
i=i - _sim_start,
ret=returns[i, col],
pre_window_ret=returns[i - window, col] if i - window >= 0 else np.nan,
cumsum=cumsum,
cumsum_sq=cumsum_sq,
nancnt=nancnt,
window=window,
minp=minp,
ddof=ddof,
ann_factor=ann_factor,
)
out_state = rolling_sharpe_ratio_acc_nb(in_state)
cumsum = out_state.cumsum
cumsum_sq = out_state.cumsum_sq
nancnt = out_state.nancnt
out[i, col] = out_state.value
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
ddof=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
stream_mode=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_sharpe_ratio_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
ddof: int = 0,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
stream_mode: bool = True,
) -> tp.Array2d:
"""Rolling version of `sharpe_ratio_1d_nb`."""
if stream_mode:
return rolling_sharpe_ratio_stream_nb(
returns,
window,
ann_factor,
minp=minp,
ddof=ddof,
sim_start=sim_start,
sim_end=sim_end,
)
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
sharpe_ratio_1d_nb,
ann_factor,
ddof,
)
return out
@register_jitted(cache=True)
def downside_risk_1d_nb(returns: tp.Array1d, ann_factor: float) -> float:
"""Downside deviation below a threshold."""
cnt = 0
adj_ret_sqrd_sum = np.nan
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
cnt += 1
if np.isnan(adj_ret_sqrd_sum):
adj_ret_sqrd_sum = 0.0
if returns[i] <= 0:
adj_ret_sqrd_sum += returns[i] ** 2
if cnt == 0:
return np.nan
adj_ret_sqrd_mean = adj_ret_sqrd_sum / cnt
return np.sqrt(adj_ret_sqrd_mean) * np.sqrt(ann_factor)
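# A minimal usage sketch: only returns at or below zero add to the squared sum, while
# the count includes every non-NaN return, i.e. sqrt((0.0004 + 0.0001) / 4) here.
# >>> downside_risk_1d_nb(np.array([0.03, -0.02, 0.02, -0.01]), 1)
# 0.0111...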
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def downside_risk_nb(
returns: tp.Array2d,
ann_factor: float,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `downside_risk_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = downside_risk_1d_nb(returns[_sim_start:_sim_end, col], ann_factor)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_downside_risk_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `downside_risk_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
downside_risk_1d_nb,
ann_factor,
)
return out
@register_jitted(cache=True)
def sortino_ratio_1d_nb(returns: tp.Array1d, ann_factor: float) -> float:
"""Sortino ratio of a strategy."""
avg_annualized_return = np.nanmean(returns) * ann_factor
downside_risk = downside_risk_1d_nb(returns, ann_factor)
if downside_risk == 0:
if avg_annualized_return == 0:
return np.nan
return np.inf
return avg_annualized_return / downside_risk
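# A minimal usage sketch, reusing the array from the downside-risk example above:
# mean * ann_factor = 0.005 divided by a downside deviation of roughly 0.0112.
# >>> sortino_ratio_1d_nb(np.array([0.03, -0.02, 0.02, -0.01]), 1)
# 0.447...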
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ann_factor=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def sortino_ratio_nb(
returns: tp.Array2d,
ann_factor: float,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `sortino_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = sortino_ratio_1d_nb(returns[_sim_start:_sim_end, col], ann_factor)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_sortino_ratio_nb(
returns: tp.Array2d,
window: int,
ann_factor: float,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `sortino_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
sortino_ratio_1d_nb,
ann_factor,
)
return out
@register_jitted(cache=True)
def information_ratio_1d_nb(returns: tp.Array1d, ddof: int = 0) -> float:
"""Information ratio of a strategy."""
mean = np.nanmean(returns)
std = generic_nb.nanstd_1d_nb(returns, ddof=ddof)
if std == 0:
if mean == 0:
return np.nan
return np.inf
return mean / std
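# A minimal usage sketch: the ratio is mean over standard deviation of the given
# returns without annualization; callers presumably pass active (excess) returns.
# >>> information_ratio_1d_nb(np.array([0.02, 0.0, 0.02, 0.0]))  # 0.01 / 0.01
# 1.0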
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
ddof=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def information_ratio_nb(
returns: tp.Array2d,
ddof: int = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `information_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = information_ratio_1d_nb(returns[_sim_start:_sim_end, col], ddof=ddof)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
ddof=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_information_ratio_nb(
returns: tp.Array2d,
window: int,
ddof: int = 0,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `information_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
information_ratio_1d_nb,
ddof,
)
return out
@register_jitted(cache=True)
def beta_1d_nb(
returns: tp.Array1d,
bm_returns: tp.Array1d,
ddof: int = 0,
) -> float:
"""Beta."""
cov = generic_nb.nancov_1d_nb(returns, bm_returns, ddof=ddof)
var = generic_nb.nanvar_1d_nb(bm_returns, ddof=ddof)
if var == 0:
if cov == 0:
return np.nan
return np.inf
return cov / var
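# A minimal usage sketch: covariance with the benchmark divided by the benchmark
# variance, so a strategy that doubles every benchmark move has a beta of 2.
# >>> bm = np.array([0.01, -0.01, 0.02, -0.02])
# >>> beta_1d_nb(2 * bm, bm)
# 2.0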
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
ddof=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def beta_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
ddof: int = 0,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `beta_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = beta_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
ddof=ddof,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
window=None,
ddof=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_beta_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
window: int,
ddof: int = 0,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `beta_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
window,
minp,
beta_1d_nb,
ddof,
)
return out
@register_jitted(cache=True)
def alpha_1d_nb(
returns: tp.Array1d,
bm_returns: tp.Array1d,
ann_factor: float,
) -> float:
"""Annualized alpha."""
beta = beta_1d_nb(returns, bm_returns)
return (np.nanmean(returns) - beta * np.nanmean(bm_returns) + 1) ** ann_factor - 1
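# A minimal usage sketch: with identical strategy and benchmark returns, beta is 1 and
# the base of the power collapses to 1, so alpha is 0 (up to floating-point error).
# >>> rets = np.array([0.01, -0.02, 0.03])
# >>> alpha_1d_nb(rets, rets, 252)
# 0.0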
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
ann_factor=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def alpha_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
ann_factor: float,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `alpha_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = alpha_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
ann_factor,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_alpha_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
window: int,
ann_factor: float,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `alpha_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
window,
minp,
alpha_1d_nb,
ann_factor,
)
return out
@register_jitted(cache=True)
def tail_ratio_1d_nb(returns: tp.Array1d) -> float:
"""Ratio between the right (95%) and left tail (5%)."""
perc_95 = np.abs(np.nanpercentile(returns, 95))
perc_5 = np.abs(np.nanpercentile(returns, 5))
if perc_5 == 0:
if perc_95 == 0:
return np.nan
return np.inf
return perc_95 / perc_5
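# A minimal usage sketch: a perfectly symmetric return distribution has equally heavy
# tails, so the ratio is 1; values above 1 indicate a heavier right tail.
# >>> tail_ratio_1d_nb(np.array([-0.10, -0.05, 0.0, 0.05, 0.10]))
# 1.0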
@register_jitted(cache=True)
def tail_ratio_noarr_1d_nb(returns: tp.Array1d) -> float:
"""`tail_ratio_1d_nb` that does not allocate any arrays."""
perc_95 = np.abs(generic_nb.nanpercentile_noarr_1d_nb(returns, 95))
perc_5 = np.abs(generic_nb.nanpercentile_noarr_1d_nb(returns, 5))
if perc_5 == 0:
if perc_95 == 0:
return np.nan
return np.inf
return perc_95 / perc_5
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def tail_ratio_nb(
returns: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array1d:
"""2-dim version of `tail_ratio_1d_nb` and `tail_ratio_noarr_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[col] = tail_ratio_noarr_1d_nb(returns[_sim_start:_sim_end, col])
else:
out[col] = tail_ratio_1d_nb(returns[_sim_start:_sim_end, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_tail_ratio_nb(
returns: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array2d:
"""Rolling version of `tail_ratio_1d_nb` and `tail_ratio_noarr_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
tail_ratio_noarr_1d_nb,
)
else:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
tail_ratio_1d_nb,
)
return out
@register_jitted(cache=True)
def profit_factor_1d_nb(returns: tp.Array1d) -> float:
"""Profit factor."""
numer = 0
denom = 0
for i in range(returns.shape[0]):
if not np.isnan(returns[i]):
if returns[i] > 0:
numer += returns[i]
elif returns[i] < 0:
denom += abs(returns[i])
if denom == 0:
if numer == 0:
return np.nan
return np.inf
return numer / denom
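# A minimal usage sketch: gross gains divided by gross losses, computed on returns.
# >>> profit_factor_1d_nb(np.array([0.10, -0.05, 0.02, -0.01]))  # 0.12 / 0.06
# 2.0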
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def profit_factor_nb(
returns: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `profit_factor_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = profit_factor_1d_nb(returns[_sim_start:_sim_end, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_profit_factor_nb(
returns: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `profit_factor_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
profit_factor_1d_nb,
)
return out
@register_jitted(cache=True)
def common_sense_ratio_1d_nb(returns: tp.Array1d) -> float:
"""Common Sense Ratio."""
tail_ratio = tail_ratio_1d_nb(returns)
profit_factor = profit_factor_1d_nb(returns)
return tail_ratio * profit_factor
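# A minimal usage sketch: the product of the tail ratio and the profit factor, both
# equal to 1 for this symmetric example.
# >>> common_sense_ratio_1d_nb(np.array([-0.10, -0.05, 0.0, 0.05, 0.10]))
# 1.0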
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def common_sense_ratio_nb(
returns: tp.Array2d,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `common_sense_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[col] = common_sense_ratio_1d_nb(returns[_sim_start:_sim_end, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_common_sense_ratio_nb(
returns: tp.Array2d,
window: int,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `common_sense_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
common_sense_ratio_1d_nb,
)
return out
@register_jitted(cache=True)
def value_at_risk_1d_nb(returns: tp.Array1d, cutoff: float = 0.05) -> float:
"""Value at risk (VaR) of a returns stream."""
return np.nanpercentile(returns, 100 * cutoff)
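# A minimal usage sketch: with the default cutoff of 0.05, VaR is simply the 5th
# percentile of the return distribution (roughly -0.09 for this symmetric example).
# >>> value_at_risk_1d_nb(np.array([-0.10, -0.05, 0.0, 0.05, 0.10]))
# -0.09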
@register_jitted(cache=True)
def value_at_risk_noarr_1d_nb(returns: tp.Array1d, cutoff: float = 0.05) -> float:
"""`value_at_risk_1d_nb` that does not allocate any arrays."""
return generic_nb.nanpercentile_noarr_1d_nb(returns, 100 * cutoff)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
cutoff=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def value_at_risk_nb(
returns: tp.Array2d,
cutoff: float = 0.05,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array1d:
"""2-dim version of `value_at_risk_1d_nb` and `value_at_risk_noarr_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[col] = value_at_risk_noarr_1d_nb(returns[_sim_start:_sim_end, col], cutoff=cutoff)
else:
out[col] = value_at_risk_1d_nb(returns[_sim_start:_sim_end, col], cutoff=cutoff)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
cutoff=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_value_at_risk_nb(
returns: tp.Array2d,
window: int,
cutoff: float = 0.05,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array2d:
"""Rolling version of `value_at_risk_1d_nb` and `value_at_risk_noarr_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
value_at_risk_noarr_1d_nb,
cutoff,
)
else:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
value_at_risk_1d_nb,
cutoff,
)
return out
@register_jitted(cache=True)
def cond_value_at_risk_1d_nb(returns: tp.Array1d, cutoff: float = 0.05) -> float:
"""Conditional value at risk (CVaR) of a returns stream."""
cutoff_index = int((len(returns) - 1) * cutoff)
return np.mean(np.partition(returns, cutoff_index)[: cutoff_index + 1])
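# A minimal usage sketch: CVaR averages the worst returns up to the cutoff index,
# here int((5 - 1) * 0.05) = 0, i.e. just the single smallest return.
# >>> cond_value_at_risk_1d_nb(np.array([-0.10, -0.05, 0.0, 0.05, 0.10]))
# -0.1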
@register_jitted(cache=True)
def cond_value_at_risk_noarr_1d_nb(returns: tp.Array1d, cutoff: float = 0.05) -> float:
"""`cond_value_at_risk_1d_nb` that does not allocate any arrays."""
return generic_nb.nanpartition_mean_noarr_1d_nb(returns, cutoff * 100)
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
cutoff=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def cond_value_at_risk_nb(
returns: tp.Array2d,
cutoff: float = 0.05,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array1d:
"""2-dim version of `cond_value_at_risk_1d_nb` and `cond_value_at_risk_noarr_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[col] = cond_value_at_risk_noarr_1d_nb(returns[_sim_start:_sim_end, col], cutoff=cutoff)
else:
out[col] = cond_value_at_risk_1d_nb(returns[_sim_start:_sim_end, col], cutoff=cutoff)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
window=None,
cutoff=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
noarr_mode=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_cond_value_at_risk_nb(
returns: tp.Array2d,
window: int,
cutoff: float = 0.05,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
noarr_mode: bool = True,
) -> tp.Array2d:
"""Rolling version of `cond_value_at_risk_1d_nb` and `cond_value_at_risk_noarr_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
if noarr_mode:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
cond_value_at_risk_noarr_1d_nb,
cutoff,
)
else:
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_1d_nb(
returns[_sim_start:_sim_end, col],
window,
minp,
cond_value_at_risk_1d_nb,
cutoff,
)
return out
@register_jitted(cache=True)
def capture_ratio_1d_nb(
returns: tp.Array1d,
bm_returns: tp.Array1d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[float] = None,
) -> float:
"""Capture ratio."""
annualized_return1 = annualized_return_1d_nb(
returns,
ann_factor,
log_returns=log_returns,
periods=periods,
)
annualized_return2 = annualized_return_1d_nb(
bm_returns,
ann_factor,
log_returns=log_returns,
periods=periods,
)
if annualized_return2 == 0:
if annualized_return1 == 0:
return np.nan
return np.inf
return annualized_return1 / annualized_return2
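# A minimal usage sketch: the ratio of the strategy's annualized return to the
# benchmark's, so identical return series capture each other exactly.
# >>> rets = np.array([0.01, 0.02, 0.03])
# >>> capture_ratio_1d_nb(rets, rets, 252)
# 1.0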
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
ann_factor=None,
log_returns=None,
periods=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `capture_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
if periods is None:
period_ = sim_end_ - sim_start_
else:
period_ = to_1d_array_nb(np.asarray(periods).astype(int_))
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_period = flex_select_1d_pc_nb(period_, col)
out[col] = capture_ratio_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
ann_factor,
log_returns=log_returns,
periods=_period,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
window: int,
ann_factor: float,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `capture_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
window,
minp,
capture_ratio_1d_nb,
ann_factor,
log_returns,
)
return out
@register_jitted(cache=True)
def up_capture_ratio_1d_nb(
returns: tp.Array1d,
bm_returns: tp.Array1d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[float] = None,
) -> float:
"""Capture ratio for periods when the benchmark return is positive."""
if periods is None:
periods = returns.shape[0]
def _annualized_pos_return(a):
ann_ret = np.nan
ret_cnt = 0
for i in range(a.shape[0]):
if not np.isnan(a[i]):
if log_returns:
_a = np.exp(a[i]) - 1
else:
_a = a[i]
if np.isnan(ann_ret):
ann_ret = 1.0
if _a > 0:
ann_ret *= _a + 1.0
ret_cnt += 1
if ret_cnt == 0:
return np.nan
if periods == 0:
return np.nan
return ann_ret ** (ann_factor / periods) - 1
annualized_return = _annualized_pos_return(returns)
annualized_bm_return = _annualized_pos_return(bm_returns)
if annualized_bm_return == 0:
if annualized_return == 0:
return np.nan
return np.inf
return annualized_return / annualized_bm_return
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
ann_factor=None,
log_returns=None,
periods=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def up_capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `up_capture_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
if periods is None:
period_ = sim_end_ - sim_start_
else:
period_ = to_1d_array_nb(np.asarray(periods).astype(int_))
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_period = flex_select_1d_pc_nb(period_, col)
out[col] = up_capture_ratio_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
ann_factor,
log_returns=log_returns,
periods=_period,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_up_capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
window: int,
ann_factor: float,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `up_capture_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
window,
minp,
up_capture_ratio_1d_nb,
ann_factor,
log_returns,
)
return out
@register_jitted(cache=True)
def down_capture_ratio_1d_nb(
returns: tp.Array1d,
bm_returns: tp.Array1d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[float] = None,
) -> float:
"""Capture ratio for periods when the benchmark return is negative."""
if periods is None:
periods = returns.shape[0]
def _annualized_neg_return(a):
ann_ret = np.nan
ret_cnt = 0
for i in range(a.shape[0]):
if not np.isnan(a[i]):
if log_returns:
_a = np.exp(a[i]) - 1
else:
_a = a[i]
if np.isnan(ann_ret):
ann_ret = 1.0
if _a < 0:
ann_ret *= _a + 1.0
ret_cnt += 1
if ret_cnt == 0:
return np.nan
if periods == 0:
return np.nan
return ann_ret ** (ann_factor / periods) - 1
annualized_return = _annualized_neg_return(returns)
annualized_bm_return = _annualized_neg_return(bm_returns)
if annualized_bm_return == 0:
if annualized_return == 0:
return np.nan
return np.inf
return annualized_return / annualized_bm_return
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
ann_factor=None,
log_returns=None,
periods=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def down_capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
ann_factor: float,
log_returns: bool = False,
periods: tp.Optional[tp.FlexArray1dLike] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array1d:
"""2-dim version of `down_capture_ratio_1d_nb`."""
out = np.full(returns.shape[1], np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
if periods is None:
period_ = sim_end_ - sim_start_
else:
period_ = to_1d_array_nb(np.asarray(periods).astype(int_))
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
_period = flex_select_1d_pc_nb(period_, col)
out[col] = down_capture_ratio_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
ann_factor,
log_returns=log_returns,
periods=_period,
)
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="returns", axis=1),
arg_take_spec=dict(
returns=ch.ArraySlicer(axis=1),
bm_returns=ch.ArraySlicer(axis=1),
window=None,
ann_factor=None,
log_returns=None,
minp=None,
sim_start=base_ch.FlexArraySlicer(),
sim_end=base_ch.FlexArraySlicer(),
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rolling_down_capture_ratio_nb(
returns: tp.Array2d,
bm_returns: tp.Array2d,
window: int,
ann_factor: float,
log_returns: bool = False,
minp: tp.Optional[int] = None,
sim_start: tp.Optional[tp.FlexArray1dLike] = None,
sim_end: tp.Optional[tp.FlexArray1dLike] = None,
) -> tp.Array2d:
"""Rolling version of `down_capture_ratio_1d_nb`."""
out = np.full(returns.shape, np.nan, dtype=float_)
sim_start_, sim_end_ = generic_nb.prepare_sim_range_nb(
sim_shape=returns.shape,
sim_start=sim_start,
sim_end=sim_end,
)
for col in prange(returns.shape[1]):
_sim_start = sim_start_[col]
_sim_end = sim_end_[col]
if _sim_start >= _sim_end:
continue
out[_sim_start:_sim_end, col] = generic_nb.rolling_reduce_two_1d_nb(
returns[_sim_start:_sim_end, col],
bm_returns[_sim_start:_sim_end, col],
window,
minp,
down_capture_ratio_1d_nb,
ann_factor,
log_returns,
)
return out
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Adapter class for QuantStats.
!!! note
Accessors do not utilize caching.
We can access the adapter from `ReturnsAccessor`:
```pycon
>>> from vectorbtpro import *
>>> import quantstats as qs
>>> np.random.seed(42)
>>> rets = pd.Series(np.random.uniform(-0.1, 0.1, size=(100,)))
>>> bm_returns = pd.Series(np.random.uniform(-0.1, 0.1, size=(100,)))
>>> rets.vbt.returns.qs.r_squared(benchmark=bm_returns)
0.0011582111228735541
```
Which is the same as:
```pycon
>>> qs.stats.r_squared(rets, bm_returns)
```
So why not just use `qs.stats` directly?
First, we can define all parameters such as benchmark returns once and avoid passing them repeatedly
to every function. Second, vectorbt automatically translates parameters passed to `ReturnsAccessor`
for use in quantstats.
```pycon
>>> # Defaults that vectorbt understands
>>> ret_acc = rets.vbt.returns(
... bm_returns=bm_returns,
... freq='d',
... year_freq='365d',
... defaults=dict(risk_free=0.001)
... )
>>> ret_acc.qs.r_squared()
0.0011582111228735541
>>> ret_acc.qs.sharpe()
-1.9158923252075455
>>> # Defaults that only quantstats understands
>>> qs_defaults = dict(
... benchmark=bm_returns,
... periods=365,
... rf=0.001
... )
>>> ret_acc_qs = rets.vbt.returns.qs(defaults=qs_defaults)
>>> ret_acc_qs.r_squared()
0.0011582111228735541
>>> ret_acc_qs.sharpe()
-1.9158923252075455
```
The adapter automatically passes the accessor's returns to each quantstats function.
It also merges the defaults defined in the settings, the defaults passed to `ReturnsAccessor`,
and the defaults passed to `QSAdapter` itself, and matches them with the argument names listed
in the function's signature.
For example, the `periods` argument defaults to the annualization factor
`ReturnsAccessor.ann_factor`, which itself is based on the `freq` argument. This makes the results
produced by quantstats and vectorbt at least somewhat similar.
```pycon
>>> vbt.settings.wrapping['freq'] = 'h'
>>> vbt.settings.returns['year_freq'] = '365d'
>>> rets.vbt.returns.sharpe_ratio() # ReturnsAccessor
-9.38160953971508
>>> rets.vbt.returns.qs.sharpe() # quantstats via QSAdapter
-9.38160953971508
```
We can still override any argument by overriding its default or by passing it directly to the function:
```pycon
>>> rets.vbt.returns.qs(defaults=dict(periods=252)).sharpe()
-1.5912029345745982
>>> rets.vbt.returns.qs.sharpe(periods=252)
-1.5912029345745982
>>> qs.stats.sharpe(rets)
-1.5912029345745982
```
"""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("quantstats")
from inspect import getmembers, isfunction, signature, Parameter
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.returns.accessors import ReturnsAccessor
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts, Configured
from vectorbtpro.utils.parsing import get_func_arg_names, has_variable_kwargs
__all__ = [
"QSAdapter",
]
def attach_qs_methods(cls: tp.Type[tp.T], replace_signature: bool = True) -> tp.Type[tp.T]:
"""Class decorator to attach quantstats methods."""
import quantstats as qs
checks.assert_subclass_of(cls, "QSAdapter")
for module_name in ["utils", "stats", "plots", "reports"]:
for qs_func_name, qs_func in getmembers(getattr(qs, module_name), isfunction):
if not qs_func_name.startswith("_") and checks.func_accepts_arg(qs_func, "returns"):
if module_name == "plots":
new_method_name = "plot_" + qs_func_name
elif module_name == "reports":
new_method_name = qs_func_name + "_report"
else:
new_method_name = qs_func_name
def new_method(
self,
*,
_func: tp.Callable = qs_func,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Any:
func_arg_names = get_func_arg_names(_func)
has_var_kwargs = has_variable_kwargs(_func)
defaults = self.defaults
if has_var_kwargs:
pass_kwargs = dict(kwargs)
else:
pass_kwargs = {}
for arg_name in func_arg_names:
if arg_name not in kwargs:
if arg_name in defaults:
pass_kwargs[arg_name] = defaults[arg_name]
elif arg_name == "benchmark":
if self.returns_acc.bm_returns is not None:
pass_kwargs["benchmark"] = self.returns_acc.bm_returns
elif arg_name == "periods":
pass_kwargs["periods"] = int(self.returns_acc.ann_factor)
elif arg_name == "periods_per_year":
pass_kwargs["periods_per_year"] = int(self.returns_acc.ann_factor)
elif not has_var_kwargs:
pass_kwargs[arg_name] = kwargs[arg_name]
returns = self.returns_acc.select_col_from_obj(
self.returns_acc.obj,
column=column,
wrapper=self.returns_acc.wrapper.regroup(False),
)
if returns.name is None:
returns = returns.rename("Strategy")
else:
returns = returns.rename(str(returns.name))
null_mask = returns.isnull()
if "benchmark" in pass_kwargs and pass_kwargs["benchmark"] is not None:
benchmark = pass_kwargs["benchmark"]
benchmark = self.returns_acc.select_col_from_obj(
benchmark,
column=column,
wrapper=self.returns_acc.wrapper.regroup(False),
)
if benchmark.name is None:
benchmark = benchmark.rename("Benchmark")
else:
benchmark = benchmark.rename(str(benchmark.name))
bm_null_mask = benchmark.isnull()
null_mask = null_mask | bm_null_mask
benchmark = benchmark.loc[~null_mask]
if isinstance(benchmark.index, pd.DatetimeIndex):
if benchmark.index.tz is not None:
benchmark = benchmark.tz_convert("utc")
if benchmark.index.tz is not None:
benchmark = benchmark.tz_localize(None)
pass_kwargs["benchmark"] = benchmark
returns = returns.loc[~null_mask]
if isinstance(returns.index, pd.DatetimeIndex):
if returns.index.tz is not None:
returns = returns.tz_convert("utc")
if returns.index.tz is not None:
returns = returns.tz_localize(None)
signature(_func).bind(returns=returns, **pass_kwargs)
return _func(returns=returns, **pass_kwargs)
if replace_signature:
# Replace the new method's signature with that of the source quantstats function
source_sig = signature(qs_func)
new_method_params = tuple(signature(new_method).parameters.values())
self_arg = new_method_params[0]
column_arg = new_method_params[2]
other_args = [
(
p.replace(kind=Parameter.KEYWORD_ONLY)
if p.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
else p
)
for p in list(source_sig.parameters.values())[1:]
]
source_sig = source_sig.replace(parameters=(self_arg, column_arg) + tuple(other_args))
new_method.__signature__ = source_sig
new_method.__name__ = new_method_name
new_method.__module__ = cls.__module__
new_method.__qualname__ = f"{cls.__name__}.{new_method.__name__}"
new_method.__doc__ = f"See `quantstats.{module_name}.{qs_func_name}`."
setattr(cls, new_method_name, new_method)
return cls
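# An illustrative note on the naming scheme above (assuming the usual quantstats layout):
# a plotting function such as `qs.plots.drawdown` would be attached as `plot_drawdown`,
# a report such as `qs.reports.html` as `html_report`, while functions from `qs.stats`
# and `qs.utils` keep their original names.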
QSAdapterT = tp.TypeVar("QSAdapterT", bound="QSAdapter")
@attach_qs_methods
class QSAdapter(Configured):
"""Adapter class for quantstats."""
def __init__(self, returns_acc: ReturnsAccessor, defaults: tp.KwargsLike = None, **kwargs) -> None:
checks.assert_instance_of(returns_acc, ReturnsAccessor)
self._returns_acc = returns_acc
self._defaults = defaults
Configured.__init__(self, returns_acc=returns_acc, defaults=defaults, **kwargs)
def __call__(self: QSAdapterT, **kwargs) -> QSAdapterT:
"""Allows passing arguments to the initializer."""
return self.replace(**kwargs)
@property
def returns_acc(self) -> ReturnsAccessor:
"""Returns accessor."""
return self._returns_acc
@property
def defaults_mapping(self) -> tp.Dict:
"""Common argument names in quantstats mapped to `ReturnsAccessor.defaults`."""
return dict(rf="risk_free", rolling_period="window")
@property
def defaults(self) -> tp.Kwargs:
"""Defaults for `QSAdapter`.
Merges `defaults` from `vectorbtpro._settings.qs_adapter`, `returns_acc.defaults`
(with adapted naming), and `defaults` from `QSAdapter.__init__`."""
from vectorbtpro._settings import settings
qs_adapter_defaults_cfg = settings["qs_adapter"]["defaults"]
mapped_defaults = dict()
for k, v in self.defaults_mapping.items():
if v in self.returns_acc.defaults:
mapped_defaults[k] = self.returns_acc.defaults[v]
return merge_dicts(qs_adapter_defaults_cfg, mapped_defaults, self._defaults)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom signal generators built with the signal factory.
You can access all the generators via `vbt.*`."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.signals.generators.ohlcstcx import *
from vectorbtpro.signals.generators.ohlcstx import *
from vectorbtpro.signals.generators.rand import *
from vectorbtpro.signals.generators.randnx import *
from vectorbtpro.signals.generators.randx import *
from vectorbtpro.signals.generators.rprob import *
from vectorbtpro.signals.generators.rprobcx import *
from vectorbtpro.signals.generators.rprobnx import *
from vectorbtpro.signals.generators.rprobx import *
from vectorbtpro.signals.generators.stcx import *
from vectorbtpro.signals.generators.stx import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `OHLCSTCX`."""
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.ohlcstx import ohlcstx_config, ohlcstx_func_config, _bind_ohlcstx_plot
__all__ = [
"OHLCSTCX",
]
__pdoc__ = {}
OHLCSTCX = SignalFactory(
**ohlcstx_config.merge_with(
dict(
class_name="OHLCSTCX",
module_name=__name__,
short_name="ohlcstcx",
mode="chain",
)
),
).with_place_func(
**ohlcstx_func_config,
)
class _OHLCSTCX(OHLCSTCX):
"""Exit signal generator based on OHLC and stop values.
Generates a chain of `new_entries` and `exits` based on `entries` and
`vectorbtpro.signals.nb.ohlc_stop_place_nb`.
See `OHLCSTX` for notes on parameters."""
plot = _bind_ohlcstx_plot(OHLCSTCX, "new_entries")
setattr(OHLCSTCX, "__doc__", _OHLCSTCX.__doc__)
setattr(OHLCSTCX, "plot", _OHLCSTCX.plot)
OHLCSTCX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `OHLCSTX`."""
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.enums import StopType
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import ohlc_stop_place_nb
from vectorbtpro.utils.config import ReadonlyConfig, merge_dicts
__all__ = [
"OHLCSTX",
]
__pdoc__ = {}
ohlcstx_config = ReadonlyConfig(
dict(
class_name="OHLCSTX",
module_name=__name__,
short_name="ohlcstx",
mode="exits",
input_names=["entry_price", "open", "high", "low", "close"],
in_output_names=["stop_price", "stop_type"],
param_names=["sl_stop", "tsl_th", "tsl_stop", "tp_stop", "reverse"],
attr_settings=dict(
stop_type=dict(dtype=StopType),
),
)
)
"""Factory config for `OHLCSTX`."""
ohlcstx_func_config = ReadonlyConfig(
dict(
exit_place_func_nb=ohlc_stop_place_nb,
exit_settings=dict(
pass_inputs=["entry_price", "open", "high", "low", "close"],
pass_in_outputs=["stop_price", "stop_type"],
pass_params=["sl_stop", "tsl_th", "tsl_stop", "tp_stop", "reverse"],
pass_kwargs=["is_entry_open"],
),
in_output_settings=dict(
stop_price=dict(dtype=float_),
stop_type=dict(dtype=int_),
),
param_settings=dict(
sl_stop=flex_elem_param_config,
tsl_th=flex_elem_param_config,
tsl_stop=flex_elem_param_config,
tp_stop=flex_elem_param_config,
reverse=flex_elem_param_config,
),
open=np.nan,
high=np.nan,
low=np.nan,
close=np.nan,
stop_price=np.nan,
stop_type=-1,
sl_stop=np.nan,
tsl_th=np.nan,
tsl_stop=np.nan,
tp_stop=np.nan,
reverse=False,
is_entry_open=False,
)
)
"""Exit function config for `OHLCSTX`."""
OHLCSTX = SignalFactory(**ohlcstx_config).with_place_func(**ohlcstx_func_config)
def _bind_ohlcstx_plot(base_cls: type, entries_attr: str) -> tp.Callable:
base_cls_plot = base_cls.plot
def plot(
self,
column: tp.Optional[tp.Label] = None,
ohlc_kwargs: tp.KwargsLike = None,
entry_price_kwargs: tp.KwargsLike = None,
entry_trace_kwargs: tp.KwargsLike = None,
exit_trace_kwargs: tp.KwargsLike = None,
add_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
_base_cls_plot: tp.Callable = base_cls_plot,
**layout_kwargs,
) -> tp.BaseFigure:
self_col = self.select_col(column=column, group_by=False)
if ohlc_kwargs is None:
ohlc_kwargs = {}
if entry_price_kwargs is None:
entry_price_kwargs = {}
if add_trace_kwargs is None:
add_trace_kwargs = {}
open_any = not self_col.open.isnull().all()
high_any = not self_col.high.isnull().all()
low_any = not self_col.low.isnull().all()
close_any = not self_col.close.isnull().all()
if open_any and high_any and low_any and close_any:
ohlc_df = pd.concat((self_col.open, self_col.high, self_col.low, self_col.close), axis=1)
ohlc_df.columns = ["Open", "High", "Low", "Close"]
ohlc_kwargs = merge_dicts(layout_kwargs, dict(ohlc_trace_kwargs=dict(opacity=0.5)), ohlc_kwargs)
fig = ohlc_df.vbt.ohlcv.plot(fig=fig, **ohlc_kwargs)
else:
entry_price_kwargs = merge_dicts(layout_kwargs, entry_price_kwargs)
fig = self_col.entry_price.rename("Entry price").vbt.lineplot(fig=fig, **entry_price_kwargs)
_base_cls_plot(
self_col,
entry_y=self_col.entry_price,
exit_y=self_col.stop_price,
exit_types=self_col.stop_type_readable,
entry_trace_kwargs=entry_trace_kwargs,
exit_trace_kwargs=exit_trace_kwargs,
add_trace_kwargs=add_trace_kwargs,
fig=fig,
**layout_kwargs,
)
return fig
plot.__doc__ = """Plot OHLC, `{0}.{1}` and `{0}.exits`.
Args:
ohlc_kwargs (dict): Keyword arguments passed to
`vectorbtpro.ohlcv.accessors.OHLCVDFAccessor.plot`.
entry_trace_kwargs (dict): Keyword arguments passed to
`vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_entries` for `{0}.{1}`.
exit_trace_kwargs (dict): Keyword arguments passed to
`vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_exits` for `{0}.exits`.
fig (Figure or FigureWidget): Figure to add traces to.
**layout_kwargs: Keyword arguments for layout.""".format(
base_cls.__name__,
entries_attr,
)
if entries_attr == "entries":
plot.__doc__ += """
Usage:
```pycon
>>> ohlcstx.iloc[:, 0].plot().show()
```
"""
return plot
class _OHLCSTX(OHLCSTX):
"""Exit signal generator based on OHLC and stop values.
Generates `exits` based on `entries` and `vectorbtpro.signals.nb.ohlc_stop_place_nb`.
!!! hint
All parameters can be either a single value (per frame) or a NumPy array (per row, column,
or element). To generate multiple combinations, pass them as lists.
!!! warning
Searches for an exit after each entry. If two entries come one after another, no exit can be placed.
Consider either cleaning up entry signals prior to passing, or using `OHLCSTCX`.
Usage:
Test each stop type:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, False, False, False, False, False])
>>> price = pd.DataFrame({
... 'open': [10, 11, 12, 11, 10, 9],
... 'high': [11, 12, 13, 12, 11, 10],
... 'low': [9, 10, 11, 10, 9, 8],
... 'close': [10, 11, 12, 11, 10, 9]
... })
>>> ohlcstx = vbt.OHLCSTX.run(
... entries,
... price['open'],
... price['open'],
... price['high'],
... price['low'],
... price['close'],
... sl_stop=[0.1, np.nan, np.nan, np.nan],
... tsl_th=[np.nan, np.nan, 0.2, np.nan],
... tsl_stop=[np.nan, 0.1, 0.3, np.nan],
... tp_stop=[np.nan, np.nan, np.nan, 0.1],
... is_entry_open=True
... )
>>> ohlcstx.entries
ohlcstx_sl_stop 0.1 NaN NaN NaN
ohlcstx_tsl_th NaN NaN 0.2 NaN
ohlcstx_tsl_stop NaN 0.1 0.3 NaN
ohlcstx_tp_stop NaN NaN NaN 0.1
0 True True True True
1 False False False False
2 False False False False
3 False False False False
4 False False False False
5 False False False False
>>> ohlcstx.exits
ohlcstx_sl_stop 0.1 NaN NaN NaN
ohlcstx_tsl_th NaN NaN 0.2 NaN
ohlcstx_tsl_stop NaN 0.1 0.3 NaN
ohlcstx_tp_stop NaN NaN NaN 0.1
0 False False False False
1 False False False True
2 False False False False
3 False True False False
4 True False True False
5 False False False False
>>> ohlcstx.stop_price
ohlcstx_sl_stop 0.1 NaN NaN NaN
ohlcstx_tsl_th NaN NaN 0.2 NaN
ohlcstx_tsl_stop NaN 0.1 0.3 NaN
ohlcstx_tp_stop NaN NaN NaN 0.1
0 NaN NaN NaN NaN
1 NaN NaN NaN 11.0
2 NaN NaN NaN NaN
3 NaN 11.7 NaN NaN
4 9.0 NaN 9.1 NaN
5 NaN NaN NaN NaN
>>> ohlcstx.stop_type_readable
ohlcstx_sl_stop 0.1 NaN NaN NaN
ohlcstx_tsl_th NaN NaN 0.2 NaN
ohlcstx_tsl_stop NaN 0.1 0.3 NaN
ohlcstx_tp_stop NaN NaN NaN 0.1
0 None None None None
1 None None None TP
2 None None None None
3 None TSL None None
4 SL None TTP None
5 None None None None
```
"""
plot = _bind_ohlcstx_plot(OHLCSTX, "entries")
setattr(OHLCSTX, "__doc__", _OHLCSTX.__doc__)
setattr(OHLCSTX, "plot", _OHLCSTX.plot)
OHLCSTX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RAND`."""
import numpy as np
from vectorbtpro.indicators.configs import flex_col_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_place_nb
__all__ = [
"RAND",
]
__pdoc__ = {}
RAND = SignalFactory(
class_name="RAND",
module_name=__name__,
short_name="rand",
mode="entries",
param_names=["n"],
).with_place_func(
entry_place_func_nb=rand_place_nb,
entry_settings=dict(
pass_params=["n"],
),
param_settings=dict(
n=flex_col_param_config,
),
seed=None,
)
class _RAND(RAND):
"""Random entry signal generator based on the number of signals.
Generates `entries` based on `vectorbtpro.signals.nb.rand_place_nb`.
!!! hint
Parameter `n` can be either a single value (per frame) or a NumPy array (per column).
To generate multiple combinations, pass it as a list.
Usage:
Test three different entry counts:
```pycon
>>> from vectorbtpro import *
>>> rand = vbt.RAND.run(input_shape=(6,), n=[1, 2, 3], seed=42)
>>> rand.entries
rand_n 1 2 3
0 True True True
1 False False True
2 False False False
3 False True False
4 False False True
5 False False False
```
Entry count can also be set per column:
```pycon
>>> rand = vbt.RAND.run(input_shape=(8, 2), n=[np.array([1, 2]), 3], seed=42)
>>> rand.entries
rand_n 1 2 3 3
0 1 0 1
0 False False True False
1 True False False False
2 False False False True
3 False True True False
4 False False False False
5 False False False True
6 False False True False
7 False True False True
```
"""
pass
setattr(RAND, "__doc__", _RAND.__doc__)
RAND.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RANDNX`."""
from vectorbtpro.indicators.configs import flex_col_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_enex_apply_nb
__all__ = [
"RANDNX",
]
__pdoc__ = {}
RANDNX = SignalFactory(
class_name="RANDNX",
module_name=__name__,
short_name="randnx",
mode="both",
param_names=["n"],
).with_apply_func(
rand_enex_apply_nb,
require_input_shape=True,
param_settings=dict(
n=flex_col_param_config,
),
kwargs_as_args=["entry_wait", "exit_wait"],
entry_wait=1,
exit_wait=1,
seed=None,
)
class _RANDNX(RANDNX):
"""Random entry and exit signal generator based on the number of signals.
Generates `entries` and `exits` based on `vectorbtpro.signals.nb.rand_enex_apply_nb`.
See `RAND` for notes on parameters.
Usage:
Test three different entry and exit counts:
```pycon
>>> from vectorbtpro import *
>>> randnx = vbt.RANDNX.run(
... input_shape=(6,),
... n=[1, 2, 3],
... seed=42)
>>> randnx.entries
randnx_n 1 2 3
0 True True True
1 False False False
2 False True True
3 False False False
4 False False True
5 False False False
>>> randnx.exits
randnx_n 1 2 3
0 False False False
1 True True True
2 False False False
3 False True True
4 False False False
5 False False True
```
"""
pass
setattr(RANDNX, "__doc__", _RANDNX.__doc__)
RANDNX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RANDX`."""
import numpy as np
import pandas as pd
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_place_nb
__all__ = [
"RANDX",
]
__pdoc__ = {}
RANDX = SignalFactory(
class_name="RANDX",
module_name=__name__,
short_name="randx",
mode="exits",
).with_place_func(
exit_place_func_nb=rand_place_nb,
exit_settings=dict(
pass_kwargs=dict(n=np.array([1])),
),
seed=None,
)
class _RANDX(RANDX):
"""Random exit signal generator based on the number of signals.
Generates `exits` based on `entries` and `vectorbtpro.signals.nb.rand_place_nb`.
See `RAND` for notes on parameters.
Usage:
Generate an exit for each entry:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, False, False, True, False, False])
>>> randx = vbt.RANDX.run(entries, seed=42)
>>> randx.exits
0 False
1 False
2 True
3 False
4 True
5 False
dtype: bool
```
"""
pass
setattr(RANDX, "__doc__", _RANDX.__doc__)
RANDX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RPROB`."""
import numpy as np
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb
__all__ = [
"RPROB",
]
__pdoc__ = {}
RPROB = SignalFactory(
class_name="RPROB",
module_name=__name__,
short_name="rprob",
mode="entries",
param_names=["prob"],
).with_place_func(
entry_place_func_nb=rand_by_prob_place_nb,
entry_settings=dict(
pass_params=["prob"],
pass_kwargs=["pick_first"],
),
param_settings=dict(
prob=flex_elem_param_config,
),
seed=None,
)
class _RPROB(RPROB):
"""Random entry signal generator based on probabilities.
Generates `entries` based on `vectorbtpro.signals.nb.rand_by_prob_place_nb`.
!!! hint
All parameters can be either a single value (per frame) or a NumPy array (per row, column,
or element). To generate multiple combinations, pass them as lists.
Usage:
Generate three columns with different entry probabilities:
```pycon
>>> from vectorbtpro import *
>>> rprob = vbt.RPROB.run(input_shape=(5,), prob=[0., 0.5, 1.], seed=42)
>>> rprob.entries
rprob_prob 0.0 0.5 1.0
0 False True True
1 False True True
2 False False True
3 False False True
4 False False True
```
Probability can also be set per row, column, or element:
```pycon
>>> rprob = vbt.RPROB.run(input_shape=(5,), prob=np.array([0., 0., 1., 1., 1.]), seed=42)
>>> rprob.entries
0 False
1 False
2 True
3 True
4 True
Name: array_0, dtype: bool
```
"""
pass
setattr(RPROB, "__doc__", _RPROB.__doc__)
RPROB.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RPROBCX`."""
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.rprobx import rprobx_config, rprobx_func_config
__all__ = [
"RPROBCX",
]
__pdoc__ = {}
RPROBCX = SignalFactory(
**rprobx_config.merge_with(
dict(
class_name="RPROBCX",
module_name=__name__,
short_name="rprobcx",
mode="chain",
)
),
).with_place_func(**rprobx_func_config)
class _RPROBCX(RPROBCX):
"""Random exit signal generator based on probabilities.
Generates chain of `new_entries` and `exits` based on `entries` and
`vectorbtpro.signals.nb.rand_by_prob_place_nb`.
See `RPROB` for notes on parameters."""
pass
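# A hedged usage sketch appended to the docstring; the probability value is illustrative
# and the printed output is omitted rather than guessed.
_RPROBCX.__doc__ += """
Usage:
Chain each entry with an exit placed at 100% probability:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, True, True, False, False, False])
>>> rprobcx = vbt.RPROBCX.run(entries, prob=1., seed=42)
>>> rprobcx.new_entries
>>> rprobcx.exits
```
"""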
setattr(RPROBCX, "__doc__", _RPROBCX.__doc__)
RPROBCX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RPROBNX`."""
import numpy as np
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb
__all__ = [
"RPROBNX",
]
__pdoc__ = {}
RPROBNX = SignalFactory(
class_name="RPROBNX",
module_name=__name__,
short_name="rprobnx",
mode="both",
param_names=["entry_prob", "exit_prob"],
).with_place_func(
entry_place_func_nb=rand_by_prob_place_nb,
entry_settings=dict(
pass_params=["entry_prob"],
pass_kwargs=["pick_first"],
),
exit_place_func_nb=rand_by_prob_place_nb,
exit_settings=dict(
pass_params=["exit_prob"],
pass_kwargs=["pick_first"],
),
param_settings=dict(
entry_prob=flex_elem_param_config,
exit_prob=flex_elem_param_config,
),
seed=None,
)
class _RPROBNX(RPROBNX):
"""Random entry and exit signal generator based on probabilities.
Generates `entries` and `exits` based on `vectorbtpro.signals.nb.rand_by_prob_place_nb`.
See `RPROB` for notes on parameters.
Usage:
Test all probability combinations:
```pycon
>>> from vectorbtpro import *
>>> rprobnx = vbt.RPROBNX.run(
... input_shape=(5,),
... entry_prob=[0.5, 1.],
... exit_prob=[0.5, 1.],
... param_product=True,
... seed=42)
>>> rprobnx.entries
rprobnx_entry_prob 0.5 0.5 1.0 1.0
rprobnx_exit_prob 0.5 1.0 0.5 1.0
0 True True True True
1 False False False False
2 False False False True
3 False False False False
4 False False True True
>>> rprobnx.exits
rprobnx_entry_prob 0.5 0.5 1.0 1.0
rprobnx_exit_prob 0.5 1.0 0.5 1.0
0 False False False False
1 False True False True
2 False False False False
3 False False True True
4 True False False False
```
Probabilities can also be set per row, column, or element:
```pycon
>>> entry_prob1 = np.array([1., 0., 1., 0., 1.])
>>> entry_prob2 = np.array([0., 1., 0., 1., 0.])
>>> rprobnx = vbt.RPROBNX.run(
... input_shape=(5,),
... entry_prob=[entry_prob1, entry_prob2],
... exit_prob=1.,
... seed=42)
>>> rprobnx.entries
rprobnx_entry_prob array_0 array_1
rprobnx_exit_prob 1.0 1.0
0 True False
1 False True
2 True False
3 False True
4 True False
>>> rprobnx.exits
rprobnx_entry_prob array_0 array_1
rprobnx_exit_prob 1.0 1.0
0 False False
1 True False
2 False True
3 True False
4 False True
```
"""
pass
setattr(RPROBNX, "__doc__", _RPROBNX.__doc__)
RPROBNX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `RPROBX`."""
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import rand_by_prob_place_nb
from vectorbtpro.utils.config import ReadonlyConfig
__all__ = [
"RPROBX",
]
__pdoc__ = {}
rprobx_config = ReadonlyConfig(
dict(
class_name="RPROBX",
module_name=__name__,
short_name="rprobx",
mode="exits",
param_names=["prob"],
),
)
"""Factory config for `RPROBX`."""
rprobx_func_config = ReadonlyConfig(
dict(
exit_place_func_nb=rand_by_prob_place_nb,
exit_settings=dict(
pass_params=["prob"],
pass_kwargs=["pick_first"],
),
param_settings=dict(
prob=flex_elem_param_config,
),
seed=None,
)
)
"""Exit function config for `RPROBX`."""
RPROBX = SignalFactory(**rprobx_config).with_place_func(**rprobx_func_config)
class _RPROBX(RPROBX):
"""Random exit signal generator based on probabilities.
Generates `exits` based on `entries` and `vectorbtpro.signals.nb.rand_by_prob_place_nb`.
See `RPROB` for notes on parameters."""
pass
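# A hedged usage sketch appended to the docstring; the probabilities are illustrative
# and the output is omitted rather than guessed.
_RPROBX.__doc__ += """
Usage:
Place an exit after each entry with 50% and 100% probability:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, False, False, True, False, False])
>>> rprobx = vbt.RPROBX.run(entries, prob=[0.5, 1.], seed=42)
>>> rprobx.exits
```
"""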
setattr(RPROBX, "__doc__", _RPROBX.__doc__)
RPROBX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `STCX`."""
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.generators.stx import stx_config, stx_func_config
__all__ = [
"STCX",
]
__pdoc__ = {}
STCX = SignalFactory(
**stx_config.merge_with(
dict(
class_name="STCX",
module_name=__name__,
short_name="stcx",
mode="chain",
)
)
).with_place_func(**stx_func_config)
class _STCX(STCX):
"""Exit signal generator based on stop values.
Generates chain of `new_entries` and `exits` based on `entries` and
`vectorbtpro.signals.nb.stop_place_nb`.
See `STX` for notes on parameters."""
pass
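# A hedged usage sketch appended to the docstring; the price series and stop value are
# illustrative and the output is omitted rather than guessed.
_STCX.__doc__ += """
Usage:
Chain entries with exits placed by a 10% trailing stop, using the same series as
entry price and evaluated price:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, True, False, False, False, False])
>>> ts = pd.Series([10., 11., 12., 11., 10., 9.])
>>> stcx = vbt.STCX.run(entries, ts, ts, ts, stop=-0.1, trailing=True)
>>> stcx.new_entries
>>> stcx.exits
```
"""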
setattr(STCX, "__doc__", _STCX.__doc__)
STCX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Module with `STX`."""
import numpy as np
from vectorbtpro.indicators.configs import flex_elem_param_config
from vectorbtpro.signals.factory import SignalFactory
from vectorbtpro.signals.nb import stop_place_nb
from vectorbtpro.utils.config import ReadonlyConfig
__all__ = [
"STX",
]
__pdoc__ = {}
stx_config = ReadonlyConfig(
dict(
class_name="STX",
module_name=__name__,
short_name="stx",
mode="exits",
input_names=["entry_ts", "ts", "follow_ts"],
in_output_names=["stop_ts"],
param_names=["stop", "trailing"],
)
)
"""Factory config for `STX`."""
stx_func_config = ReadonlyConfig(
dict(
exit_place_func_nb=stop_place_nb,
exit_settings=dict(
pass_inputs=["entry_ts", "ts", "follow_ts"],
pass_in_outputs=["stop_ts"],
pass_params=["stop", "trailing"],
),
param_settings=dict(
stop=flex_elem_param_config,
trailing=flex_elem_param_config,
),
trailing=False,
ts=np.nan,
follow_ts=np.nan,
stop_ts=np.nan,
)
)
"""Exit function config for `STX`."""
STX = SignalFactory(**stx_config).with_place_func(**stx_func_config)
class _STX(STX):
"""Exit signal generator based on stop values.
Generates `exits` based on `entries` and `vectorbtpro.signals.nb.stop_place_nb`.
!!! hint
All parameters can be either a single value (per frame) or a NumPy array (per row, column,
or element). To generate multiple combinations, pass them as lists."""
pass
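# A hedged usage sketch appended to the docstring; the price series and stop value are
# illustrative and the output is omitted rather than guessed.
_STX.__doc__ += """
Usage:
Place an exit once the price falls 10% below the entry price:
```pycon
>>> from vectorbtpro import *
>>> entries = pd.Series([True, False, False, False, False, False])
>>> ts = pd.Series([10., 11., 12., 11., 10., 9.])
>>> stx = vbt.STX.run(entries, ts, ts, ts, stop=-0.1)
>>> stx.exits
```
"""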
setattr(STX, "__doc__", _STX.__doc__)
STX.fix_docstrings(__pdoc__)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Modules for working with signals."""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.signals.accessors import *
from vectorbtpro.signals.factory import *
from vectorbtpro.signals.generators import *
from vectorbtpro.signals.nb import *
__exclude_from__all__ = [
"enums",
]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom Pandas accessors for signals.
Methods can be accessed as follows:
* `SignalsSRAccessor` -> `pd.Series.vbt.signals.*`
* `SignalsDFAccessor` -> `pd.DataFrame.vbt.signals.*`
```pycon
>>> from vectorbtpro import *
>>> # vectorbtpro.signals.accessors.SignalsAccessor.pos_rank
>>> pd.Series([False, True, True, True, False]).vbt.signals.pos_rank()
0 -1
1 0
2 1
3 2
4 -1
dtype: int64
```
The accessors extend `vectorbtpro.generic.accessors`.
!!! note
The underlying Series/DataFrame must already be a signal series and have boolean data type.
Grouping is only supported by the methods that accept the `group_by` argument.
Accessors do not utilize caching.
Run the following to set up the examples below:
```pycon
>>> mask = pd.DataFrame({
... 'a': [True, False, False, False, False],
... 'b': [True, False, True, False, True],
... 'c': [True, True, True, False, False]
... }, index=pd.date_range("2020", periods=5))
>>> mask
a b c
2020-01-01 True True True
2020-01-02 False False True
2020-01-03 False True True
2020-01-04 False False False
2020-01-05 False True False
```
## Stats
!!! hint
See `vectorbtpro.generic.stats_builder.StatsBuilderMixin.stats` and `SignalsAccessor.metrics`.
```pycon
>>> mask.vbt.signals.stats(column='a')
Start 2020-01-01 00:00:00
End 2020-01-05 00:00:00
Period 5 days 00:00:00
Total 1
Rate [%] 20.0
First Index 2020-01-01 00:00:00
Last Index 2020-01-01 00:00:00
Norm Avg Index [-1, 1] -1.0
Distance: Min NaT
Distance: Median NaT
Distance: Max NaT
Total Partitions 1
Partition Rate [%] 100.0
Partition Length: Min 1 days 00:00:00
Partition Length: Median 1 days 00:00:00
Partition Length: Max 1 days 00:00:00
Partition Distance: Min NaT
Partition Distance: Median NaT
Partition Distance: Max NaT
Name: a, dtype: object
```
We can pass another signal array to compare this array with:
```pycon
>>> mask.vbt.signals.stats(column='a', settings=dict(target=mask['b']))
Start 2020-01-01 00:00:00
End 2020-01-05 00:00:00
Period 5 days 00:00:00
Total 1
Rate [%] 20.0
Total Overlapping 1
Overlapping Rate [%] 33.333333
First Index 2020-01-01 00:00:00
Last Index 2020-01-01 00:00:00
Norm Avg Index [-1, 1] -1.0
Distance -> Target: Min 0 days 00:00:00
Distance -> Target: Median 2 days 00:00:00
Distance -> Target: Max 4 days 00:00:00
Total Partitions 1
Partition Rate [%] 100.0
Partition Length: Min 1 days 00:00:00
Partition Length: Median 1 days 00:00:00
Partition Length: Max 1 days 00:00:00
Partition Distance: Min NaT
Partition Distance: Median NaT
Partition Distance: Max NaT
Name: a, dtype: object
```
We can also return duration as a floating-point number rather than a timedelta:
```pycon
>>> mask.vbt.signals.stats(column='a', settings=dict(to_timedelta=False))
Start 2020-01-01 00:00:00
End 2020-01-05 00:00:00
Period 5
Total 1
Rate [%] 20.0
First Index 2020-01-01 00:00:00
Last Index 2020-01-01 00:00:00
Norm Avg Index [-1, 1] -1.0
Distance: Min NaN
Distance: Median NaN
Distance: Max NaN
Total Partitions 1
Partition Rate [%] 100.0
Partition Length: Min 1.0
Partition Length: Median 1.0
Partition Length: Max 1.0
Partition Distance: Min NaN
Partition Distance: Median NaN
Partition Distance: Max NaN
Name: a, dtype: object
```
`SignalsAccessor.stats` also supports (re-)grouping:
```pycon
>>> mask.vbt.signals.stats(column=0, group_by=[0, 0, 1])
Start 2020-01-01 00:00:00
End 2020-01-05 00:00:00
Period 5 days 00:00:00
Total 4
Rate [%] 40.0
First Index 2020-01-01 00:00:00
Last Index 2020-01-05 00:00:00
Norm Avg Index [-1, 1] -0.25
Distance: Min 2 days 00:00:00
Distance: Median 2 days 00:00:00
Distance: Max 2 days 00:00:00
Total Partitions 4
Partition Rate [%] 100.0
Partition Length: Min 1 days 00:00:00
Partition Length: Median 1 days 00:00:00
Partition Length: Max 1 days 00:00:00
Partition Distance: Min 2 days 00:00:00
Partition Distance: Median 2 days 00:00:00
Partition Distance: Max 2 days 00:00:00
Name: 0, dtype: object
```
## Plots
!!! hint
See `vectorbtpro.generic.plots_builder.PlotsBuilderMixin.plots` and `SignalsAccessor.subplots`.
This class inherits subplots from `vectorbtpro.generic.accessors.GenericAccessor`.
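For example, a minimal sketch (the subplots are inherited, so the exact figure contents
depend on the generic defaults):
```pycon
>>> mask['a'].vbt.signals.plots().show()
```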
"""
from functools import partialmethod
import numpy as np
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.accessors import register_vbt_accessor, register_df_vbt_accessor, register_sr_vbt_accessor
from vectorbtpro.base import chunking as base_ch, reshaping, indexes
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor
from vectorbtpro.generic.ranges import Ranges
from vectorbtpro.records.mapped_array import MappedArray
from vectorbtpro.registries.ch_registry import ch_reg
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.signals import nb, enums
from vectorbtpro.utils import checks
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.colors import adjust_lightness
from vectorbtpro.utils.config import resolve_dict, merge_dicts, Config, HybridConfig
from vectorbtpro.utils.decorators import hybrid_method, hybrid_property
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.random_ import set_seed_nb
from vectorbtpro.utils.template import RepEval, substitute_templates
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"SignalsAccessor",
"SignalsSRAccessor",
"SignalsDFAccessor",
]
__pdoc__ = {}
@register_vbt_accessor("signals")
class SignalsAccessor(GenericAccessor):
"""Accessor on top of signal series. For both, Series and DataFrames.
Accessible via `pd.Series.vbt.signals` and `pd.DataFrame.vbt.signals`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
GenericAccessor.__init__(self, wrapper, obj=obj, **kwargs)
checks.assert_dtype(self._obj, np.bool_)
@hybrid_property
def sr_accessor_cls(cls_or_self) -> tp.Type["SignalsSRAccessor"]:
"""Accessor class for `pd.Series`."""
return SignalsSRAccessor
@hybrid_property
def df_accessor_cls(cls_or_self) -> tp.Type["SignalsDFAccessor"]:
"""Accessor class for `pd.DataFrame`."""
return SignalsDFAccessor
# ############# Overriding ############# #
@classmethod
def empty(cls, *args, fill_value: bool = False, **kwargs) -> tp.SeriesFrame:
"""`vectorbtpro.base.accessors.BaseAccessor.empty` with `fill_value=False`."""
return GenericAccessor.empty(*args, fill_value=fill_value, dtype=np.bool_, **kwargs)
@classmethod
def empty_like(cls, *args, fill_value: bool = False, **kwargs) -> tp.SeriesFrame:
"""`vectorbtpro.base.accessors.BaseAccessor.empty_like` with `fill_value=False`."""
return GenericAccessor.empty_like(*args, fill_value=fill_value, dtype=np.bool_, **kwargs)
bshift = partialmethod(GenericAccessor.bshift, fill_value=False)
fshift = partialmethod(GenericAccessor.fshift, fill_value=False)
ago = partialmethod(GenericAccessor.ago, fill_value=False)
realign = partialmethod(GenericAccessor.realign, nan_value=False)
# ############# Generation ############# #
@classmethod
def generate(
cls,
shape: tp.Union[tp.ShapeLike, ArrayWrapper],
place_func_nb: tp.PlaceFunc,
*args,
place_args: tp.ArgsLike = None,
only_once: bool = True,
wait: int = 1,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.signals.nb.generate_nb`.
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper`
(will be used as `wrapper`).
Arguments to `place_func_nb` can be passed either as `*args` or `place_args` (but not both!).
Usage:
* Generate random signals manually:
```pycon
>>> @njit
... def place_func_nb(c):
... i = np.random.choice(len(c.out))
... c.out[i] = True
... return i
>>> vbt.pd_acc.signals.generate(
... (5, 3),
... place_func_nb,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
a b c
2020-01-01 True False False
2020-01-02 False True False
2020-01-03 False False True
2020-01-04 False False False
2020-01-05 False False False
```
"""
if isinstance(shape, ArrayWrapper):
wrapper = shape
shape = wrapper.shape
if len(args) > 0 and place_args is not None:
raise ValueError("Must provide either *args or place_args, not both")
if place_args is None:
place_args = args
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
shape_2d = cls.resolve_shape(shape)
if len(broadcast_named_args) > 0:
broadcast_named_args = reshaping.broadcast(broadcast_named_args, to_shape=shape_2d, **broadcast_kwargs)
template_context = merge_dicts(
broadcast_named_args,
dict(shape=shape, shape_2d=shape_2d, wait=wait),
template_context,
)
place_args = substitute_templates(place_args, template_context, eval_id="place_args")
func = jit_reg.resolve_option(nb.generate_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
result = func(
target_shape=shape_2d,
place_func_nb=place_func_nb,
place_args=place_args,
only_once=only_once,
wait=wait,
)
if wrapper is None:
wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim)
if wrap_kwargs is None:
wrap_kwargs = resolve_dict(wrap_kwargs)
return wrapper.wrap(result, **wrap_kwargs)
@classmethod
def generate_both(
cls,
shape: tp.Union[tp.ShapeLike, ArrayWrapper],
entry_place_func_nb: tp.PlaceFunc,
exit_place_func_nb: tp.PlaceFunc,
*args,
entry_place_args: tp.ArgsLike = None,
exit_place_args: tp.ArgsLike = None,
entry_wait: int = 1,
exit_wait: int = 1,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]:
"""See `vectorbtpro.signals.nb.generate_enex_nb`.
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper`
(will be used as `wrapper`).
Arguments to `entry_place_func_nb` can be passed either as `*args` or `entry_place_args` while
arguments to `exit_place_func_nb` can be passed either as `*args` or `exit_place_args` (but not both!).
Usage:
* Generate entry and exit signals one after another:
```pycon
>>> @njit
... def place_func_nb(c):
... c.out[0] = True
... return 0
>>> en, ex = vbt.pd_acc.signals.generate_both(
... (5, 3),
... entry_place_func_nb=place_func_nb,
... exit_place_func_nb=place_func_nb,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
>>> en
a b c
2020-01-01 True True True
2020-01-02 False False False
2020-01-03 True True True
2020-01-04 False False False
2020-01-05 True True True
>>> ex
a b c
2020-01-01 False False False
2020-01-02 True True True
2020-01-03 False False False
2020-01-04 True True True
2020-01-05 False False False
```
* Generate three entries and one exit one after another:
```pycon
>>> @njit
... def entry_place_func_nb(c, n):
... c.out[:n] = True
... return n - 1
>>> @njit
... def exit_place_func_nb(c, n):
... c.out[:n] = True
... return n - 1
>>> en, ex = vbt.pd_acc.signals.generate_both(
... (5, 3),
... entry_place_func_nb=entry_place_func_nb,
... entry_place_args=(3,),
... exit_place_func_nb=exit_place_func_nb,
... exit_place_args=(1,),
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
>>> en
a b c
2020-01-01 True True True
2020-01-02 True True True
2020-01-03 True True True
2020-01-04 False False False
2020-01-05 True True True
>>> ex
a b c
2020-01-01 False False False
2020-01-02 False False False
2020-01-03 False False False
2020-01-04 True True True
2020-01-05 False False False
```
"""
if isinstance(shape, ArrayWrapper):
wrapper = shape
shape = wrapper.shape
if len(args) > 0 and entry_place_args is not None:
raise ValueError("Must provide either *args or entry_place_args, not both")
if len(args) > 0 and exit_place_args is not None:
raise ValueError("Must provide either *args or exit_place_args, not both")
if entry_place_args is None:
entry_place_args = args
if exit_place_args is None:
exit_place_args = args
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
shape_2d = cls.resolve_shape(shape)
if len(broadcast_named_args) > 0:
broadcast_named_args = reshaping.broadcast(
broadcast_named_args,
to_shape=shape_2d,
**broadcast_kwargs,
)
template_context = merge_dicts(
broadcast_named_args,
dict(
shape=shape,
shape_2d=shape_2d,
entry_wait=entry_wait,
exit_wait=exit_wait,
),
template_context,
)
entry_place_args = substitute_templates(entry_place_args, template_context, eval_id="entry_place_args")
exit_place_args = substitute_templates(exit_place_args, template_context, eval_id="exit_place_args")
func = jit_reg.resolve_option(nb.generate_enex_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
result1, result2 = func(
target_shape=shape_2d,
entry_place_func_nb=entry_place_func_nb,
entry_place_args=entry_place_args,
exit_place_func_nb=exit_place_func_nb,
exit_place_args=exit_place_args,
entry_wait=entry_wait,
exit_wait=exit_wait,
)
if wrapper is None:
wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim)
if wrap_kwargs is None:
wrap_kwargs = resolve_dict(wrap_kwargs)
return wrapper.wrap(result1, **wrap_kwargs), wrapper.wrap(result2, **wrap_kwargs)
def generate_exits(
self,
exit_place_func_nb: tp.PlaceFunc,
*args,
exit_place_args: tp.ArgsLike = None,
wait: int = 1,
until_next: bool = True,
skip_until_exit: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.signals.nb.generate_ex_nb`.
Usage:
* Generate an exit just before the next entry:
```pycon
>>> @njit
... def exit_place_func_nb(c):
... c.out[-1] = True
... return len(c.out) - 1
>>> mask.vbt.signals.generate_exits(exit_place_func_nb)
a b c
2020-01-01 False False False
2020-01-02 False True False
2020-01-03 False False False
2020-01-04 False True False
2020-01-05 True False True
```
"""
if len(args) > 0 and exit_place_args is not None:
raise ValueError("Must provide either *args or exit_place_args, not both")
if exit_place_args is None:
exit_place_args = args
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
obj = self.obj
if len(broadcast_named_args) > 0:
broadcast_named_args = {"obj": obj, **broadcast_named_args}
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
obj = broadcast_named_args["obj"]
else:
wrapper = self.wrapper
obj = reshaping.to_2d_array(obj)
template_context = merge_dicts(
broadcast_named_args,
dict(wait=wait, until_next=until_next, skip_until_exit=skip_until_exit),
template_context,
)
exit_place_args = substitute_templates(exit_place_args, template_context, eval_id="exit_place_args")
func = jit_reg.resolve_option(nb.generate_ex_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
exits = func(
entries=obj,
exit_place_func_nb=exit_place_func_nb,
exit_place_args=exit_place_args,
wait=wait,
until_next=until_next,
skip_until_exit=skip_until_exit,
)
return wrapper.wrap(exits, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Cleaning ############# #
@hybrid_method
def clean(
cls_or_self,
*objs,
force_first: bool = True,
keep_conflicts: bool = False,
reverse_order: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Clean signals.
If one array is passed, see `SignalsAccessor.first`. If two arrays are passed
(entries and exits), see `vectorbtpro.signals.nb.clean_enex_nb`."""
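# A hedged usage sketch (illustrative arrays, output omitted): clean an overlapping
# entries/exits pair via the instance method.
#   >>> entries = pd.Series([True, True, True, False, False])
#   >>> exits = pd.Series([False, False, True, True, True])
#   >>> clean_entries, clean_exits = entries.vbt.signals.clean(exits)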
if broadcast_kwargs is None:
broadcast_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
if not isinstance(cls_or_self, type):
objs = (cls_or_self.obj, *objs)
if len(objs) == 1:
obj = objs[0]
if not isinstance(obj, (pd.Series, pd.DataFrame)):
obj = ArrayWrapper.from_obj(obj).wrap(obj)
return obj.vbt.signals.first(wrap_kwargs=wrap_kwargs, jitted=jitted, chunked=chunked)
if len(objs) == 2:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
broadcasted_args, wrapper = reshaping.broadcast(
dict(entries=objs[0], exits=objs[1]),
return_wrapper=True,
**broadcast_kwargs,
)
func = jit_reg.resolve_option(nb.clean_enex_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
entries_out, exits_out = func(
entries=broadcasted_args["entries"],
exits=broadcasted_args["exits"],
force_first=force_first,
keep_conflicts=keep_conflicts,
reverse_order=reverse_order,
)
return (
wrapper.wrap(entries_out, group_by=False, **wrap_kwargs),
wrapper.wrap(exits_out, group_by=False, **wrap_kwargs),
)
raise ValueError("This method accepts either one or two arrays")
# ############# Random signals ############# #
@classmethod
def generate_random(
cls,
shape: tp.Union[tp.ShapeLike, ArrayWrapper],
n: tp.Optional[tp.ArrayLike] = None,
prob: tp.Optional[tp.ArrayLike] = None,
pick_first: bool = False,
seed: tp.Optional[int] = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.SeriesFrame:
"""Generate signals randomly.
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper`
(will be used as `wrapper`).
If `n` is set, uses `vectorbtpro.signals.nb.rand_place_nb`.
If `prob` is set, uses `vectorbtpro.signals.nb.rand_by_prob_place_nb`.
For arguments, see `SignalsAccessor.generate`.
`n` must be either a scalar or an array that will broadcast to the number of columns.
`prob` must be either a single number or an array that will broadcast to match `shape`.
Specify `seed` to make output deterministic.
Usage:
* For each column, generate a variable number of signals:
```pycon
>>> vbt.pd_acc.signals.generate_random(
... (5, 3),
... n=[0, 1, 2],
... seed=42,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
a b c
2020-01-01 False False False
2020-01-02 False False False
2020-01-03 False False True
2020-01-04 False True False
2020-01-05 False False True
```
* For each column and time step, pick a signal with 50% probability:
```pycon
>>> vbt.pd_acc.signals.generate_random(
... (5, 3),
... prob=0.5,
... seed=42,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
a b c
2020-01-01 True True True
2020-01-02 False True False
2020-01-03 False False False
2020-01-04 False False True
2020-01-05 True False True
```
"""
if isinstance(shape, ArrayWrapper):
if "wrapper" in kwargs:
raise ValueError("Must provide wrapper either via shape or wrapper, not both")
kwargs["wrapper"] = shape
shape = kwargs["wrapper"].shape
shape_2d = cls.resolve_shape(shape)
if n is not None and prob is not None:
raise ValueError("Must provide either n or prob, not both")
if seed is not None:
set_seed_nb(seed)
if n is not None:
n = reshaping.broadcast_array_to(n, shape_2d[1])
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(),
),
),
)
return cls.generate(
shape,
jit_reg.resolve_option(nb.rand_place_nb, jitted),
n,
jitted=jitted,
chunked=chunked,
**kwargs,
)
if prob is not None:
prob = reshaping.to_2d_array(reshaping.broadcast_array_to(prob, shape))
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
None,
None,
),
),
)
return cls.generate(
shape,
jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted),
prob,
pick_first,
jitted=jitted,
chunked=chunked,
**kwargs,
)
raise ValueError("Must provide at least n or prob")
@classmethod
def generate_random_both(
cls,
shape: tp.Union[tp.ShapeLike, ArrayWrapper],
n: tp.Optional[tp.ArrayLike] = None,
entry_prob: tp.Optional[tp.ArrayLike] = None,
exit_prob: tp.Optional[tp.ArrayLike] = None,
seed: tp.Optional[int] = None,
entry_wait: int = 1,
exit_wait: int = 1,
entry_pick_first: bool = True,
exit_pick_first: bool = True,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrapper: tp.Optional[ArrayWrapper] = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.Tuple[tp.SeriesFrame, tp.SeriesFrame]:
"""Generate chain of entry and exit signals randomly.
`shape` can be a shape-like tuple or an instance of `vectorbtpro.base.wrapping.ArrayWrapper`
(will be used as `wrapper`).
If `n` is set, uses `vectorbtpro.signals.nb.generate_rand_enex_nb`.
If `entry_prob` and `exit_prob` are set, uses `SignalsAccessor.generate_both` with
`vectorbtpro.signals.nb.rand_by_prob_place_nb`.
Usage:
* For each column, generate two entries and exits randomly:
```pycon
>>> en, ex = vbt.pd_acc.signals.generate_random_both(
... (5, 3),
... n=2,
... seed=42,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
>>> en
a b c
2020-01-01 False False True
2020-01-02 True True False
2020-01-03 False False False
2020-01-04 True True True
2020-01-05 False False False
>>> ex
a b c
2020-01-01 False False False
2020-01-02 False False True
2020-01-03 True True False
2020-01-04 False False False
2020-01-05 True True True
```
* For each column and time step, pick entry with 50% probability and exit right after:
```pycon
>>> en, ex = vbt.pd_acc.signals.generate_random_both(
... (5, 3),
... entry_prob=0.5,
... exit_prob=1.,
... seed=42,
... wrap_kwargs=dict(
... index=mask.index,
... columns=mask.columns
... )
... )
>>> en
a b c
2020-01-01 True True True
2020-01-02 False False False
2020-01-03 False False False
2020-01-04 False False True
2020-01-05 True False False
>>> ex
a b c
2020-01-01 False False False
2020-01-02 True True True
2020-01-03 False False False
2020-01-04 False False False
2020-01-05 False False True
```
"""
if isinstance(shape, ArrayWrapper):
wrapper = shape
shape = wrapper.shape
shape_2d = cls.resolve_shape(shape)
if n is not None and (entry_prob is not None or exit_prob is not None):
raise ValueError("Must provide either n or any of the entry_prob and exit_prob, not both")
if seed is not None:
set_seed_nb(seed)
if n is not None:
n = reshaping.broadcast_array_to(n, shape_2d[1])
func = jit_reg.resolve_option(nb.generate_rand_enex_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
entries, exits = func(shape_2d, n, entry_wait, exit_wait)
if wrapper is None:
wrapper = ArrayWrapper.from_shape(shape, ndim=cls.ndim)
if wrap_kwargs is None:
wrap_kwargs = resolve_dict(wrap_kwargs)
return wrapper.wrap(entries, **wrap_kwargs), wrapper.wrap(exits, **wrap_kwargs)
elif entry_prob is not None and exit_prob is not None:
entry_prob = reshaping.to_2d_array(reshaping.broadcast_array_to(entry_prob, shape))
exit_prob = reshaping.to_2d_array(reshaping.broadcast_array_to(exit_prob, shape))
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
entry_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
None,
),
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
None,
),
),
)
return cls.generate_both(
shape,
entry_place_func_nb=jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted),
entry_place_args=(entry_prob, entry_pick_first),
exit_place_func_nb=jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted),
exit_place_args=(exit_prob, exit_pick_first),
entry_wait=entry_wait,
exit_wait=exit_wait,
jitted=jitted,
chunked=chunked,
wrapper=wrapper,
wrap_kwargs=wrap_kwargs,
)
raise ValueError("Must provide at least n, or entry_prob and exit_prob")
def generate_random_exits(
self,
prob: tp.Optional[tp.ArrayLike] = None,
seed: tp.Optional[int] = None,
wait: int = 1,
until_next: bool = True,
skip_until_exit: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Generate exit signals randomly.
If `prob` is None, uses `vectorbtpro.signals.nb.rand_place_nb`.
Otherwise, uses `vectorbtpro.signals.nb.rand_by_prob_place_nb`.
Uses `SignalsAccessor.generate_exits`.
Specify `seed` to make output deterministic.
Usage:
* After each entry in `mask`, generate exactly one exit:
```pycon
>>> mask.vbt.signals.generate_random_exits(seed=42)
a b c
2020-01-01 False False False
2020-01-02 False True False
2020-01-03 False False False
2020-01-04 True True False
2020-01-05 False False True
```
* After each entry in `mask` and at each time step, generate exit with 50% probability:
```pycon
>>> mask.vbt.signals.generate_random_exits(prob=0.5, seed=42)
a b c
2020-01-01 False False False
2020-01-02 True False False
2020-01-03 False False False
2020-01-04 False False False
2020-01-05 False False True
```
"""
if seed is not None:
set_seed_nb(seed)
if prob is not None:
broadcast_kwargs = merge_dicts(
dict(keep_flex=dict(obj=False, prob=True)),
broadcast_kwargs,
)
broadcasted_args = reshaping.broadcast(
dict(obj=self.obj, prob=prob),
**broadcast_kwargs,
)
obj = broadcasted_args["obj"]
prob = broadcasted_args["prob"]
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
None,
)
),
)
return obj.vbt.signals.generate_exits(
jit_reg.resolve_option(nb.rand_by_prob_place_nb, jitted),
prob,
True,
wait=wait,
until_next=until_next,
skip_until_exit=skip_until_exit,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
n = reshaping.broadcast_array_to(1, self.wrapper.shape_2d[1])
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(),
)
),
)
return self.generate_exits(
jit_reg.resolve_option(nb.rand_place_nb, jitted),
n,
wait=wait,
until_next=until_next,
skip_until_exit=skip_until_exit,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
# ############# Stop signals ############# #
def generate_stop_exits(
self,
entry_ts: tp.ArrayLike,
ts: tp.ArrayLike = np.nan,
follow_ts: tp.ArrayLike = np.nan,
stop: tp.ArrayLike = np.nan,
trailing: tp.ArrayLike = False,
out_dict: tp.Optional[tp.Dict[str, tp.ArrayLike]] = None,
entry_wait: int = 1,
exit_wait: int = 1,
until_next: bool = True,
skip_until_exit: bool = False,
chain: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Generate exits based on when `ts` hits the stop.
For arguments, see `vectorbtpro.signals.nb.stop_place_nb`.
If `chain` is True, uses `SignalsAccessor.generate_both`.
Otherwise, uses `SignalsAccessor.generate_exits`.
Use `out_dict` as a dict to pass `stop_ts` array. You can also set `out_dict` to {}
to produce this array automatically and still have access to it.
All array-like arguments including stops and `out_dict` will broadcast using
`vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`.
!!! hint
Default arguments will generate an exit signal strictly between two entry signals.
If both entry signals are too close to each other, no exit will be generated.
To ignore all entries that come between an entry and its exit,
set `until_next` to False and `skip_until_exit` to True.
To remove all entries that come between an entry and its exit,
set `chain` to True. This will return two arrays: new entries and exits.
Usage:
* Regular stop loss:
```pycon
>>> ts = pd.Series([1, 2, 3, 2, 1])
>>> mask.vbt.signals.generate_stop_exits(ts, stop=-0.1)
a b c
2020-01-01 False False False
2020-01-02 False False False
2020-01-03 False False False
2020-01-04 False True True
2020-01-05 False False False
```
* Trailing stop loss:
```pycon
>>> mask.vbt.signals.generate_stop_exits(ts, stop=-0.1, trailing=True)
a b c
2020-01-01 False False False
2020-01-02 False False False
2020-01-03 False False False
2020-01-04 True True True
2020-01-05 False False False
```
* Testing multiple take profit stops:
```pycon
>>> mask.vbt.signals.generate_stop_exits(ts, stop=vbt.Param([1.0, 1.5]))
stop 1.0 1.5
a b c a b c
2020-01-01 False False False False False False
2020-01-02 True True False False False False
2020-01-03 False False False True False False
2020-01-04 False False False False False False
2020-01-05 False False False False False False
```
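* A chained run (a sketch with output omitted) removes the in-between entries and returns
two arrays, similar to the chained example in `SignalsAccessor.generate_ohlc_stop_exits`:
```pycon
>>> new_entries, exits = mask.vbt.signals.generate_stop_exits(
...     ts, stop=-0.1, trailing=True, chain=True)
```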
"""
if wrap_kwargs is None:
wrap_kwargs = {}
entries = self.obj
if out_dict is None:
out_dict_passed = False
out_dict = {}
else:
out_dict_passed = True
stop_ts = out_dict.get("stop_ts", np.nan if out_dict_passed else None)
broadcastable_args = dict(
entries=entries,
entry_ts=entry_ts,
ts=ts,
follow_ts=follow_ts,
stop=stop,
trailing=trailing,
stop_ts=stop_ts,
)
broadcast_kwargs = merge_dicts(
dict(
keep_flex=dict(entries=False, stop_ts=False, _def=True),
require_kwargs=dict(requirements="W"),
),
broadcast_kwargs,
)
broadcasted_args = reshaping.broadcast(broadcastable_args, **broadcast_kwargs)
entries = broadcasted_args["entries"]
stop_ts = broadcasted_args["stop_ts"]
if stop_ts is None:
stop_ts = np.empty_like(entries, dtype=float_)
stop_ts = reshaping.to_2d_array(stop_ts)
entries_arr = reshaping.to_2d_array(entries)
wrapper = ArrayWrapper.from_obj(entries)
if chain:
if checks.is_series(entries):
cls = self.sr_accessor_cls
else:
cls = self.df_accessor_cls
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
entry_place_args=ch.ArgsTaker(
ch.ArraySlicer(axis=1),
),
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
),
),
)
out_dict["stop_ts"] = wrapper.wrap(stop_ts, group_by=False, **wrap_kwargs)
return cls.generate_both(
entries.shape,
entry_place_func_nb=jit_reg.resolve_option(nb.first_place_nb, jitted),
entry_place_args=(entries_arr,),
exit_place_func_nb=jit_reg.resolve_option(nb.stop_place_nb, jitted),
exit_place_args=(
broadcasted_args["entry_ts"],
broadcasted_args["ts"],
broadcasted_args["follow_ts"],
stop_ts,
broadcasted_args["stop"],
broadcasted_args["trailing"],
),
entry_wait=entry_wait,
exit_wait=exit_wait,
wrapper=wrapper,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
else:
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
)
),
)
if skip_until_exit and until_next:
warn("skip_until_exit=True only has an effect when until_next=False")
out_dict["stop_ts"] = wrapper.wrap(stop_ts, group_by=False, **wrap_kwargs)
return entries.vbt.signals.generate_exits(
jit_reg.resolve_option(nb.stop_place_nb, jitted),
broadcasted_args["entry_ts"],
broadcasted_args["ts"],
broadcasted_args["follow_ts"],
stop_ts,
broadcasted_args["stop"],
broadcasted_args["trailing"],
wait=exit_wait,
until_next=until_next,
skip_until_exit=skip_until_exit,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
def generate_ohlc_stop_exits(
self,
entry_price: tp.ArrayLike,
open: tp.ArrayLike = np.nan,
high: tp.ArrayLike = np.nan,
low: tp.ArrayLike = np.nan,
close: tp.ArrayLike = np.nan,
sl_stop: tp.ArrayLike = np.nan,
tsl_th: tp.ArrayLike = np.nan,
tsl_stop: tp.ArrayLike = np.nan,
tp_stop: tp.ArrayLike = np.nan,
reverse: tp.ArrayLike = False,
is_entry_open: bool = False,
out_dict: tp.Optional[tp.Dict[str, tp.ArrayLike]] = None,
entry_wait: int = 1,
exit_wait: int = 1,
until_next: bool = True,
skip_until_exit: bool = False,
chain: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Generate exits based on when the price hits (trailing) stop loss or take profit.
Use `out_dict` as a dict to pass `stop_price` and `stop_type` arrays. You can also
set `out_dict` to {} to produce these arrays automatically and still have access to them.
For arguments, see `vectorbtpro.signals.nb.ohlc_stop_place_nb`.
If `chain` is True, uses `SignalsAccessor.generate_both`.
Otherwise, uses `SignalsAccessor.generate_exits`.
All array-like arguments including stops and `out_dict` will broadcast using
`vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`.
!!! hint
Default arguments will generate an exit signal strictly between two entry signals.
If both entry signals are too close to each other, no exit will be generated.
To ignore all entries that come between an entry and its exit,
set `until_next` to False and `skip_until_exit` to True.
To remove all entries that come between an entry and its exit,
set `chain` to True. This will return two arrays: new entries and exits.
Usage:
* Generate exits for TSL and TP of 10%:
```pycon
>>> price = pd.DataFrame({
... 'open': [10, 11, 12, 11, 10],
... 'high': [11, 12, 13, 12, 11],
... 'low': [9, 10, 11, 10, 9],
... 'close': [10, 11, 12, 11, 10]
... })
>>> out_dict = {}
>>> exits = mask.vbt.signals.generate_ohlc_stop_exits(
... price["open"],
... price['open'],
... price['high'],
... price['low'],
... price['close'],
... tsl_stop=0.1,
... tp_stop=0.1,
... is_entry_open=True,
... out_dict=out_dict,
... )
>>> exits
a b c
2020-01-01 False False False
2020-01-02 True True False
2020-01-03 False False False
2020-01-04 False True True
2020-01-05 False False False
>>> out_dict['stop_price']
a b c
2020-01-01 NaN NaN NaN
2020-01-02 11.0 11.0 NaN
2020-01-03 NaN NaN NaN
2020-01-04 NaN 10.8 10.8
2020-01-05 NaN NaN NaN
>>> out_dict['stop_type'].vbt(mapping=vbt.sig_enums.StopType).apply_mapping()
a b c
2020-01-01 None None None
2020-01-02 TP TP None
2020-01-03 None None None
2020-01-04 None TSL TSL
2020-01-05 None None None
```
Notice how the first two entry signals in the third column have no exit signal -
there is no room between them for an exit signal.
* To find an exit for the first entry and ignore all entries that are in-between them,
we can pass `until_next=False` and `skip_until_exit=True`:
```pycon
>>> out_dict = {}
>>> exits = mask.vbt.signals.generate_ohlc_stop_exits(
... price['open'],
... price['open'],
... price['high'],
... price['low'],
... price['close'],
... tsl_stop=0.1,
... tp_stop=0.1,
... is_entry_open=True,
... out_dict=out_dict,
... until_next=False,
... skip_until_exit=True
... )
>>> exits
a b c
2020-01-01 False False False
2020-01-02 True True True
2020-01-03 False False False
2020-01-04 False True True
2020-01-05 False False False
>>> out_dict['stop_price']
a b c
2020-01-01 NaN NaN NaN
2020-01-02 11.0 11.0 11.0
2020-01-03 NaN NaN NaN
2020-01-04 NaN 10.8 10.8
2020-01-05 NaN NaN NaN
>>> out_dict['stop_type'].vbt(mapping=vbt.sig_enums.StopType).apply_mapping()
a b c
2020-01-01 None None None
2020-01-02 TP TP TP
2020-01-03 None None None
2020-01-04 None TSL TSL
2020-01-05 None None None
```
Now, the first signal in the third column gets executed regardless of the entries that come next,
which is very similar to the logic that is implemented in `vectorbtpro.portfolio.base.Portfolio.from_signals`.
* To automatically remove all ignored entry signals, pass `chain=True`.
This will return a new entries array:
```pycon
>>> out_dict = {}
>>> new_entries, exits = mask.vbt.signals.generate_ohlc_stop_exits(
... price['open'],
... price['open'],
... price['high'],
... price['low'],
... price['close'],
... tsl_stop=0.1,
... tp_stop=0.1,
... is_entry_open=True,
... out_dict=out_dict,
... chain=True
... )
>>> new_entries
a b c
2020-01-01 True True True
2020-01-02 False False False << removed entry in the third column
2020-01-03 False True True
2020-01-04 False False False
2020-01-05 False True False
>>> exits
a b c
2020-01-01 False False False
2020-01-02 True True True
2020-01-03 False False False
2020-01-04 False True True
2020-01-05 False False False
```
!!! warning
The last two examples above make entries dependent upon exits - this only makes sense
if you have no other exit arrays to combine this stop exit array with.
* Test multiple parameter combinations:
```pycon
>>> exits = mask.vbt.signals.generate_ohlc_stop_exits(
... price['open'],
... price['open'],
... price['high'],
... price['low'],
... price['close'],
... sl_stop=vbt.Param([False, 0.1]),
... tsl_stop=vbt.Param([False, 0.1]),
... is_entry_open=True
... )
>>> exits
sl_stop False 0.1 \\
tsl_stop False 0.1 False
a b c a b c a b c
2020-01-01 False False False False False False False False False
2020-01-02 False False False False False False False False False
2020-01-03 False False False False False False False False False
2020-01-04 False False False True True True False True True
2020-01-05 False False False False False False True False False
sl_stop
tsl_stop 0.1
a b c
2020-01-01 False False False
2020-01-02 False False False
2020-01-03 False False False
2020-01-04 True True True
2020-01-05 False False False
```
"""
if wrap_kwargs is None:
wrap_kwargs = {}
entries = self.obj
if out_dict is None:
out_dict_passed = False
out_dict = {}
else:
out_dict_passed = True
stop_price = out_dict.get("stop_price", np.nan if out_dict_passed else None)
stop_type = out_dict.get("stop_type", -1 if out_dict_passed else None)
broadcastable_args = dict(
entries=entries,
entry_price=entry_price,
open=open,
high=high,
low=low,
close=close,
sl_stop=sl_stop,
tsl_th=tsl_th,
tsl_stop=tsl_stop,
tp_stop=tp_stop,
reverse=reverse,
stop_price=stop_price,
stop_type=stop_type,
)
broadcast_kwargs = merge_dicts(
dict(
keep_flex=dict(entries=False, stop_price=False, stop_type=False, _def=True),
require_kwargs=dict(requirements="W"),
),
broadcast_kwargs,
)
broadcasted_args = reshaping.broadcast(broadcastable_args, **broadcast_kwargs)
entries = broadcasted_args["entries"]
stop_price = broadcasted_args["stop_price"]
if stop_price is None:
stop_price = np.empty_like(entries, dtype=float_)
stop_price = reshaping.to_2d_array(stop_price)
stop_type = broadcasted_args["stop_type"]
if stop_type is None:
stop_type = np.empty_like(entries, dtype=int_)
stop_type = reshaping.to_2d_array(stop_type)
entries_arr = reshaping.to_2d_array(entries)
wrapper = ArrayWrapper.from_obj(entries)
if chain:
if checks.is_series(entries):
cls = self.sr_accessor_cls
else:
cls = self.df_accessor_cls
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
entry_place_args=ch.ArgsTaker(
ch.ArraySlicer(axis=1),
),
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
ch.ArraySlicer(axis=1),
ch.ArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
None,
),
),
)
new_entries, exits = cls.generate_both(
entries.shape,
entry_place_func_nb=jit_reg.resolve_option(nb.first_place_nb, jitted),
entry_place_args=(entries_arr,),
exit_place_func_nb=jit_reg.resolve_option(nb.ohlc_stop_place_nb, jitted),
exit_place_args=(
broadcasted_args["entry_price"],
broadcasted_args["open"],
broadcasted_args["high"],
broadcasted_args["low"],
broadcasted_args["close"],
stop_price,
stop_type,
broadcasted_args["sl_stop"],
broadcasted_args["tsl_th"],
broadcasted_args["tsl_stop"],
broadcasted_args["tp_stop"],
broadcasted_args["reverse"],
is_entry_open,
),
entry_wait=entry_wait,
exit_wait=exit_wait,
wrapper=wrapper,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
out_dict["stop_price"] = wrapper.wrap(stop_price, group_by=False, **wrap_kwargs)
out_dict["stop_type"] = wrapper.wrap(stop_type, group_by=False, **wrap_kwargs)
return new_entries, exits
else:
if skip_until_exit and until_next:
warn("skip_until_exit=True has only effect when until_next=False")
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
exit_place_args=ch.ArgsTaker(
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
ch.ArraySlicer(axis=1),
ch.ArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
base_ch.FlexArraySlicer(axis=1),
None,
)
),
)
exits = entries.vbt.signals.generate_exits(
jit_reg.resolve_option(nb.ohlc_stop_place_nb, jitted),
broadcasted_args["entry_price"],
broadcasted_args["open"],
broadcasted_args["high"],
broadcasted_args["low"],
broadcasted_args["close"],
stop_price,
stop_type,
broadcasted_args["sl_stop"],
broadcasted_args["tsl_th"],
broadcasted_args["tsl_stop"],
broadcasted_args["tp_stop"],
broadcasted_args["reverse"],
is_entry_open,
wait=exit_wait,
until_next=until_next,
skip_until_exit=skip_until_exit,
jitted=jitted,
chunked=chunked,
wrap_kwargs=wrap_kwargs,
**kwargs,
)
out_dict["stop_price"] = wrapper.wrap(stop_price, group_by=False, **wrap_kwargs)
out_dict["stop_type"] = wrapper.wrap(stop_type, group_by=False, **wrap_kwargs)
return exits
# ############# Ranking ############# #
def rank(
self,
rank_func_nb: tp.RankFunc,
*args,
rank_args: tp.ArgsLike = None,
reset_by: tp.Optional[tp.ArrayLike] = None,
after_false: bool = False,
after_reset: bool = False,
reset_wait: int = 1,
as_mapped: bool = False,
broadcast_named_args: tp.KwargsLike = None,
broadcast_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
"""See `vectorbtpro.signals.nb.rank_nb`.
Arguments to `rank_func_nb` can be passed either as `*args` or `rank_args` (but not both!).
Will broadcast with `reset_by` using `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`.
Set `as_mapped` to True to return an instance of `vectorbtpro.records.mapped_array.MappedArray`."""
if len(args) > 0 and rank_args is not None:
raise ValueError("Must provide either *args or rank_args, not both")
if rank_args is None:
rank_args = args
if broadcast_named_args is None:
broadcast_named_args = {}
if broadcast_kwargs is None:
broadcast_kwargs = {}
if template_context is None:
template_context = {}
if wrap_kwargs is None:
wrap_kwargs = {}
if reset_by is not None:
broadcast_named_args = {"obj": self.obj, "reset_by": reset_by, **broadcast_named_args}
else:
broadcast_named_args = {"obj": self.obj, **broadcast_named_args}
if len(broadcast_named_args) > 1:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
broadcast_named_args, wrapper = reshaping.broadcast(
broadcast_named_args,
return_wrapper=True,
**broadcast_kwargs,
)
else:
wrapper = self.wrapper
obj = reshaping.to_2d_array(broadcast_named_args["obj"])
if reset_by is not None:
reset_by = reshaping.to_2d_array(broadcast_named_args["reset_by"])
template_context = merge_dicts(
dict(
obj=obj,
reset_by=reset_by,
after_false=after_false,
after_reset=after_reset,
reset_wait=reset_wait,
),
template_context,
)
rank_args = substitute_templates(rank_args, template_context, eval_id="rank_args")
func = jit_reg.resolve_option(nb.rank_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
rank = func(
mask=obj,
rank_func_nb=rank_func_nb,
rank_args=rank_args,
reset_by=reset_by,
after_false=after_false,
after_reset=after_reset,
reset_wait=reset_wait,
)
rank_wrapped = wrapper.wrap(rank, group_by=False, **wrap_kwargs)
if as_mapped:
rank_wrapped = rank_wrapped.replace(-1, np.nan)
return rank_wrapped.vbt.to_mapped(dropna=True, dtype=int_, **kwargs)
return rank_wrapped
def pos_rank(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
allow_gaps: bool = False,
**kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
"""Get signal position ranks.
Uses `SignalsAccessor.rank` with `vectorbtpro.signals.nb.sig_pos_rank_nb`.
Usage:
* Rank each True value in each partition in `mask`:
```pycon
>>> mask.vbt.signals.pos_rank()
a b c
2020-01-01 0 0 0
2020-01-02 -1 -1 1
2020-01-03 -1 0 2
2020-01-04 -1 -1 -1
2020-01-05 -1 0 -1
>>> mask.vbt.signals.pos_rank(after_false=True)
a b c
2020-01-01 -1 -1 -1
2020-01-02 -1 -1 -1
2020-01-03 -1 0 -1
2020-01-04 -1 -1 -1
2020-01-05 -1 0 -1
>>> mask.vbt.signals.pos_rank(allow_gaps=True)
a b c
2020-01-01 0 0 0
2020-01-02 -1 -1 1
2020-01-03 -1 1 2
2020-01-04 -1 -1 -1
2020-01-05 -1 2 -1
>>> mask.vbt.signals.pos_rank(reset_by=~mask, allow_gaps=True)
a b c
2020-01-01 0 0 0
2020-01-02 -1 -1 1
2020-01-03 -1 0 2
2020-01-04 -1 -1 -1
2020-01-05 -1 0 -1
```
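* Position ranks are plain integers, so they can be turned into filters. For example,
select only the first signal of each partition (which is exactly what `SignalsAccessor.first` does):
```pycon
>>> mask.vbt.signals.pos_rank() == 0
a b c
2020-01-01 True True True
2020-01-02 False False False
2020-01-03 False True False
2020-01-04 False False False
2020-01-05 False True False
```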
"""
chunked = ch.specialize_chunked_option(
chunked,
arg_take_spec=dict(
rank_args=ch.ArgsTaker(
None,
)
),
)
return self.rank(
rank_func_nb=jit_reg.resolve_option(nb.sig_pos_rank_nb, jitted),
rank_args=(allow_gaps,),
jitted=jitted,
chunked=chunked,
**kwargs,
)
def pos_rank_after(
self,
reset_by: tp.ArrayLike,
after_reset: bool = True,
allow_gaps: bool = True,
**kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
"""Get signal position ranks after each signal in `reset_by`.
!!! note
`allow_gaps` is enabled by default."""
return self.pos_rank(reset_by=reset_by, after_reset=after_reset, allow_gaps=allow_gaps, **kwargs)
def partition_pos_rank(
self,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> tp.Union[tp.SeriesFrame, MappedArray]:
"""Get partition position ranks.
Uses `SignalsAccessor.rank` with `vectorbtpro.signals.nb.part_pos_rank_nb`.
Usage:
* Rank each partition of True values in `mask`:
```pycon
>>> mask.vbt.signals.partition_pos_rank()
a b c
2020-01-01 0 0 0
2020-01-02 -1 -1 0
2020-01-03 -1 1 0
2020-01-04 -1 -1 -1
2020-01-05 -1 2 -1
>>> mask.vbt.signals.partition_pos_rank(after_false=True)
a b c
2020-01-01 -1 -1 -1
2020-01-02 -1 -1 -1
2020-01-03 -1 0 -1
2020-01-04 -1 -1 -1
2020-01-05 -1 1 -1
>>> mask.vbt.signals.partition_pos_rank(reset_by=mask)
a b c
2020-01-01 0 0 0
2020-01-02 -1 -1 0
2020-01-03 -1 0 0
2020-01-04 -1 -1 -1
2020-01-05 -1 0 -1
```
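* Partition ranks can serve as filters as well; e.g., keep only the signals that belong
to the first partition of each column:
```pycon
>>> mask.vbt.signals.partition_pos_rank() == 0
a b c
2020-01-01 True True True
2020-01-02 False False True
2020-01-03 False False True
2020-01-04 False False False
2020-01-05 False False False
```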
"""
return self.rank(
jit_reg.resolve_option(nb.part_pos_rank_nb, jitted),
jitted=jitted,
chunked=chunked,
**kwargs,
)
def partition_pos_rank_after(self, reset_by: tp.ArrayLike, **kwargs) -> tp.Union[tp.SeriesFrame, MappedArray]:
"""Get partition position ranks after each signal in `reset_by`."""
return self.partition_pos_rank(reset_by=reset_by, after_reset=True, **kwargs)
def first(
self,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank == 0`.
Uses `SignalsAccessor.pos_rank`."""
pos_rank = self.pos_rank(**kwargs).values
return self.wrapper.wrap(pos_rank == 0, group_by=False, **resolve_dict(wrap_kwargs))
def first_after(
self,
reset_by: tp.ArrayLike,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank == 0`.
Uses `SignalsAccessor.pos_rank_after`."""
pos_rank = self.pos_rank_after(reset_by, **kwargs).values
return self.wrapper.wrap(pos_rank == 0, group_by=False, **resolve_dict(wrap_kwargs))
def nth(
self,
n: int,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank == n`.
Uses `SignalsAccessor.pos_rank`."""
pos_rank = self.pos_rank(**kwargs).values
return self.wrapper.wrap(pos_rank == n, group_by=False, **resolve_dict(wrap_kwargs))
def nth_after(
self,
n: int,
reset_by: tp.ArrayLike,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank == n`.
Uses `SignalsAccessor.pos_rank_after`."""
pos_rank = self.pos_rank_after(reset_by, **kwargs).values
return self.wrapper.wrap(pos_rank == n, group_by=False, **resolve_dict(wrap_kwargs))
def from_nth(
self,
n: int,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank >= n`.
Uses `SignalsAccessor.pos_rank`."""
pos_rank = self.pos_rank(**kwargs).values
return self.wrapper.wrap(pos_rank >= n, group_by=False, **resolve_dict(wrap_kwargs))
def from_nth_after(
self,
n: int,
reset_by: tp.ArrayLike,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank >= n`.
Uses `SignalsAccessor.pos_rank_after`."""
pos_rank = self.pos_rank_after(reset_by, **kwargs).values
return self.wrapper.wrap(pos_rank >= n, group_by=False, **resolve_dict(wrap_kwargs))
def to_nth(
self,
n: int,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank < n`.
Uses `SignalsAccessor.pos_rank`."""
pos_rank = self.pos_rank(**kwargs).values
return self.wrapper.wrap(pos_rank < n, group_by=False, **resolve_dict(wrap_kwargs))
def to_nth_after(
self,
n: int,
reset_by: tp.ArrayLike,
wrap_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.SeriesFrame:
"""Select signals that satisfy the condition `pos_rank < n`.
Uses `SignalsAccessor.pos_rank_after`."""
pos_rank = self.pos_rank_after(reset_by, **kwargs).values
return self.wrapper.wrap(pos_rank < n, group_by=False, **resolve_dict(wrap_kwargs))
def pos_rank_mapped(self, group_by: tp.GroupByLike = None, **kwargs) -> MappedArray:
"""Get a mapped array of signal position ranks.
Uses `SignalsAccessor.pos_rank`."""
return self.pos_rank(as_mapped=True, group_by=group_by, **kwargs)
def partition_pos_rank_mapped(self, group_by: tp.GroupByLike = None, **kwargs) -> MappedArray:
"""Get a mapped array of partition position ranks.
Uses `SignalsAccessor.partition_pos_rank`."""
return self.partition_pos_rank(as_mapped=True, group_by=group_by, **kwargs)
# ############# Distance ############# #
def distance_from_last(
self,
nth: int = 1,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""See `vectorbtpro.signals.nb.distance_from_last_nb`.
Usage:
* Get the distance to the last signal:
```pycon
>>> mask.vbt.signals.distance_from_last()
a b c
2020-01-01 -1 -1 -1
2020-01-02 1 1 1
2020-01-03 2 2 1
2020-01-04 3 1 1
2020-01-05 4 2 2
```
* Get the distance to the second last signal:
```pycon
>>> mask.vbt.signals.distance_from_last(nth=2)
a b c
2020-01-01 -1 -1 -1
2020-01-02 -1 -1 1
2020-01-03 -1 2 1
2020-01-04 -1 3 2
2020-01-05 -1 2 3
```
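* Distances can be turned into filters too; e.g., drop any signal that comes exactly one bar
after the previous one (note that distances are measured on the original mask and are not
recomputed after removals):
```pycon
>>> mask & (mask.vbt.signals.distance_from_last() != 1)
a b c
2020-01-01 True True True
2020-01-02 False False False
2020-01-03 False True False
2020-01-04 False False False
2020-01-05 False True False
```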
"""
func = jit_reg.resolve_option(nb.distance_from_last_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
distance_from_last = func(self.to_2d_array(), nth=nth)
return self.wrapper.wrap(distance_from_last, group_by=False, **resolve_dict(wrap_kwargs))
# ############# Conversion ############# #
def to_mapped(
self,
group_by: tp.GroupByLike = None,
**kwargs,
) -> MappedArray:
"""Convert this object into an instance of `vectorbtpro.records.mapped_array.MappedArray`."""
mapped_arr = self.to_2d_array().flatten(order="F")
col_arr = np.repeat(np.arange(self.wrapper.shape_2d[1]), self.wrapper.shape_2d[0])
idx_arr = np.tile(np.arange(self.wrapper.shape_2d[0]), self.wrapper.shape_2d[1])
new_mapped_arr = mapped_arr[mapped_arr]
new_col_arr = col_arr[mapped_arr]
new_idx_arr = idx_arr[mapped_arr]
return MappedArray(
wrapper=self.wrapper,
mapped_arr=new_mapped_arr,
col_arr=new_col_arr,
idx_arr=new_idx_arr,
**kwargs,
).regroup(group_by)
# ############# Relation ############# #
def get_relation_str(self, relation: tp.Union[int, str]) -> str:
"""Get direction string for `relation`."""
if isinstance(relation, str):
relation = map_enum_fields(relation, enums.SignalRelation)
if relation == enums.SignalRelation.OneOne:
return ">-<"
if relation == enums.SignalRelation.OneMany:
return "->"
if relation == enums.SignalRelation.ManyOne:
return "<-"
if relation == enums.SignalRelation.ManyMany:
return "<->"
raise ValueError(f"Invalid relation: {relation}")
# ############# Ranges ############# #
def delta_ranges(
self,
delta: tp.Union[str, int, tp.FrequencyLike],
group_by: tp.GroupByLike = None,
**kwargs,
) -> Ranges:
"""Build a record array of the type `vectorbtpro.generic.ranges.Ranges`
from a delta applied after each signal (or before if delta is negative)."""
return Ranges.from_delta(self.to_mapped(), delta=delta, **kwargs).regroup(group_by)
def between_ranges(
self,
target: tp.Optional[tp.ArrayLike] = None,
relation: tp.Union[int, str] = "onemany",
incl_open: bool = False,
broadcast_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
attach_target: bool = False,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> Ranges:
"""Wrap the result of `vectorbtpro.signals.nb.between_ranges_nb`
with `vectorbtpro.generic.ranges.Ranges`.
If `target` is specified, see `vectorbtpro.signals.nb.between_two_ranges_nb`.
Both will broadcast using `vectorbtpro.base.reshaping.broadcast` and `broadcast_kwargs`.
Usage:
* One array:
```pycon
>>> mask_sr = pd.Series([True, False, False, True, False, True, True])
>>> ranges = mask_sr.vbt.signals.between_ranges()
>>> ranges.readable
Range Id Column Start Index End Index Status
0 0 0 0 3 Closed
1 1 0 3 5 Closed
2 2 0 5 6 Closed
>>> ranges.duration.values
array([3, 2, 1])
```
* Two arrays, traversing the signals of the first array:
```pycon
>>> mask_sr1 = pd.Series([True, True, True, False, False])
>>> mask_sr2 = pd.Series([False, False, True, False, True])
>>> ranges = mask_sr1.vbt.signals.between_ranges(target=mask_sr2)
>>> ranges.readable
Range Id Column Start Index End Index Status
0 0 0 2 2 Closed
1 1 0 2 4 Closed
>>> ranges.duration.values
array([0, 2])
```
* Two arrays, traversing the signals of the second array:
```pycon
>>> ranges = mask_sr1.vbt.signals.between_ranges(target=mask_sr2, relation="manyone")
>>> ranges.readable
Range Id Column Start Index End Index Status
0 0 0 0 2 Closed
1 1 0 1 2 Closed
2 2 0 2 2 Closed
>>> ranges.duration.values
array([0, 2])
```
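* `Ranges.duration` is a mapped array, so the usual reductions apply; e.g., the average
distance between neighboring signals from the first example:
```pycon
>>> mask_sr.vbt.signals.between_ranges().duration.mean()
2.0
```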
"""
if broadcast_kwargs is None:
broadcast_kwargs = {}
if isinstance(relation, str):
relation = map_enum_fields(relation, enums.SignalRelation)
if target is None:
func = jit_reg.resolve_option(nb.between_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
range_records = func(self.to_2d_array(), incl_open=incl_open)
wrapper = self.wrapper
to_attach = self.obj
else:
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
broadcasted_args, wrapper = reshaping.broadcast(
dict(obj=self.obj, target=target),
return_wrapper=True,
**broadcast_kwargs,
)
func = jit_reg.resolve_option(nb.between_two_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
range_records = func(
broadcasted_args["obj"],
broadcasted_args["target"],
relation=relation,
incl_open=incl_open,
)
to_attach = broadcasted_args["target"] if attach_target else broadcasted_args["obj"]
kwargs = merge_dicts(dict(close=to_attach), kwargs)
return Ranges.from_records(wrapper, range_records, **kwargs).regroup(group_by)
def partition_ranges(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> Ranges:
"""Wrap the result of `vectorbtpro.signals.nb.partition_ranges_nb`
with `vectorbtpro.generic.ranges.Ranges`.
If `use_end_idxs` is True, uses the index of the last signal in each partition as `idx_arr`.
Otherwise, uses the index of the first signal.
Usage:
```pycon
>>> mask_sr = pd.Series([True, True, True, False, True, True])
>>> mask_sr.vbt.signals.partition_ranges().readable
Range Id Column Start Timestamp End Timestamp Status
0 0 0 0 3 Closed
1 1 0 4 5 Open
```
"""
func = jit_reg.resolve_option(nb.partition_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
range_records = func(self.to_2d_array())
kwargs = merge_dicts(dict(close=self.obj), kwargs)
return Ranges.from_records(self.wrapper, range_records, **kwargs).regroup(group_by)
def between_partition_ranges(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
**kwargs,
) -> Ranges:
"""Wrap the result of `vectorbtpro.signals.nb.between_partition_ranges_nb`
with `vectorbtpro.generic.ranges.Ranges`.
Usage:
```pycon
>>> mask_sr = pd.Series([True, False, False, True, False, True, True])
>>> mask_sr.vbt.signals.between_partition_ranges().readable
Range Id Column Start Timestamp End Timestamp Status
0 0 0 0 3 Closed
1 1 0 3 5 Closed
```
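The duration of these ranges measures the distance between neighboring partitions:
```pycon
>>> mask_sr.vbt.signals.between_partition_ranges().duration.values
array([3, 2])
```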
"""
func = jit_reg.resolve_option(nb.between_partition_ranges_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
range_records = func(self.to_2d_array())
kwargs = merge_dicts(dict(close=self.obj), kwargs)
return Ranges.from_records(self.wrapper, range_records, **kwargs).regroup(group_by)
# ############# Raveling ############# #
@classmethod
def index_from_unravel(
cls,
range_: tp.Array1d,
row_idxs: tp.Array1d,
index: tp.Index,
signal_index_type: str = "range",
signal_index_name: str = "signal",
):
"""Get index from an unraveling operation."""
if signal_index_type.lower() == "range":
return pd.Index(range_, name=signal_index_name)
if signal_index_type.lower() in ("position", "positions"):
return pd.Index(row_idxs, name=signal_index_name)
if signal_index_type.lower() in ("label", "labels"):
if -1 in row_idxs:
raise ValueError("Some columns have no signals. Use other signal index types.")
return pd.Index(index[row_idxs], name=signal_index_name)
raise ValueError(f"Invalid signal_index_type: '{signal_index_type}'")
def unravel(
self,
incl_empty_cols: bool = True,
force_signal_index: bool = False,
signal_index_type: str = "range",
signal_index_name: str = "signal",
jitted: tp.JittedOption = None,
clean_index_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Unravel signals.
See `vectorbtpro.signals.nb.unravel_nb`.
Argument `signal_index_type` takes the following values:
* "range": Basic signal counter in a column
* "position(s)": Integer position (row) of signal in a column
* "label(s)": Label of signal in a column
Usage:
```pycon
>>> mask.vbt.signals.unravel()
signal 0 0 1 2 0 1 2
a b b b c c c
2020-01-01 True True False False True False False
2020-01-02 False False False False False True False
2020-01-03 False False True False False False True
2020-01-04 False False False False False False False
2020-01-05 False False False True False False False
```
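Use integer row positions instead of a simple per-column counter:
```pycon
>>> mask.vbt.signals.unravel(signal_index_type="position")
signal 0 0 2 4 0 1 2
a b b b c c c
2020-01-01 True True False False True False False
2020-01-02 False False False False False True False
2020-01-03 False False True False False False True
2020-01-04 False False False False False False False
2020-01-05 False False False False False True False
```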
"""
if clean_index_kwargs is None:
clean_index_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
func = jit_reg.resolve_option(nb.unravel_nb, jitted)
new_mask, range_, row_idxs, col_idxs = func(self.to_2d_array(), incl_empty_cols=incl_empty_cols)
if new_mask.shape == self.wrapper.shape_2d and incl_empty_cols and not force_signal_index:
return self.wrapper.wrap(new_mask)
if not incl_empty_cols and (row_idxs == -1).all():
raise ValueError("No columns left")
signal_index = self.index_from_unravel(
range_,
row_idxs,
self.wrapper.index,
signal_index_type=signal_index_type,
signal_index_name=signal_index_name,
)
new_columns = indexes.stack_indexes((signal_index, self.wrapper.columns[col_idxs]), **clean_index_kwargs)
return self.wrapper.wrap(new_mask, columns=new_columns, group_by=False, **wrap_kwargs)
@hybrid_method
def unravel_between(
cls_or_self,
*objs,
relation: tp.Union[int, str] = "onemany",
incl_open_source: bool = False,
incl_open_target: bool = False,
incl_empty_cols: bool = True,
broadcast_kwargs: tp.KwargsLike = None,
force_signal_index: bool = False,
signal_index_type: str = "pair_range",
signal_index_name: str = "signal",
jitted: tp.JittedOption = None,
clean_index_kwargs: tp.KwargsLike = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeTuple[tp.SeriesFrame]:
"""Unravel signal pairs.
If one array is passed, see `vectorbtpro.signals.nb.unravel_between_nb`.
If two arrays are passed, see `vectorbtpro.signals.nb.unravel_between_two_nb`.
Argument `signal_index_type` takes the following values:
* "pair_range": Basic pair counter in a column
* "range": Basic signal counter in a column
* "source_range": Basic signal counter in a source column
* "target_range": Basic signal counter in a target column
* "position(s)": Integer position (row) of signal in a column
* "source_position(s)": Integer position (row) of signal in a source column
* "target_position(s)": Integer position (row) of signal in a target column
* "label(s)": Label of signal in a column
* "source_label(s)": Label of signal in a source column
* "target_label(s)": Label of signal in a target column
Usage:
* One mask:
```pycon
>>> mask.vbt.signals.unravel_between()
signal -1 0 1 0 1
a b b c c
2020-01-01 False True False True False
2020-01-02 False False False True True
2020-01-03 False True True False True
2020-01-04 False False False False False
2020-01-05 False False True False False
>>> mask.vbt.signals.unravel_between(signal_index_type="position")
source_signal -1 0 2 0 1
target_signal -1 2 4 1 2
a b b c c
2020-01-01 False True False True False
2020-01-02 False False False True True
2020-01-03 False True True False True
2020-01-04 False False False False False
2020-01-05 False False True False False
```
* Two masks:
```pycon
>>> source_mask = pd.Series([True, True, False, False, True, True])
>>> target_mask = pd.Series([False, False, True, True, False, False])
>>> new_source_mask, new_target_mask = vbt.pd_acc.signals.unravel_between(
... source_mask,
... target_mask
... )
>>> new_source_mask
signal 0 1
0 False False
1 True True
2 False False
3 False False
4 False False
5 False False
>>> new_target_mask
signal 0 1
0 False False
1 False False
2 True False
3 False True
4 False False
5 False False
>>> new_source_mask, new_target_mask = vbt.pd_acc.signals.unravel_between(
... source_mask,
... target_mask,
... relation="chain"
... )
>>> new_source_mask
signal 0 1
0 True False
1 False False
2 False False
3 False False
4 False True
5 False False
>>> new_target_mask
signal 0 1
0 False False
1 False False
2 True True
3 False False
4 False False
5 False False
```
"""
if broadcast_kwargs is None:
broadcast_kwargs = {}
if clean_index_kwargs is None:
clean_index_kwargs = {}
if wrap_kwargs is None:
wrap_kwargs = {}
if isinstance(relation, str):
relation = map_enum_fields(relation, enums.SignalRelation)
signal_index_type = signal_index_type.lower()
if not isinstance(cls_or_self, type):
objs = (cls_or_self.obj, *objs)
def _build_new_columns(
source_range,
target_range,
source_idxs,
target_idxs,
col_idxs,
):
indexes_to_stack = []
if signal_index_type == "pair_range":
one_points = np.concatenate((np.array([0]), col_idxs[1:] - col_idxs[:-1]))
basic_range = np.arange(len(col_idxs))
range_points = np.where(one_points == 1, basic_range, one_points)
signal_range = basic_range - np.maximum.accumulate(range_points)
signal_range[(source_range == -1) & (target_range == -1)] = -1
indexes_to_stack.append(pd.Index(signal_range, name=signal_index_name))
else:
if not signal_index_type.startswith("target_"):
indexes_to_stack.append(
cls_or_self.index_from_unravel(
source_range,
source_idxs,
wrapper.index,
signal_index_type=signal_index_type.replace("source_", ""),
signal_index_name="source_" + signal_index_name,
)
)
if not signal_index_type.startswith("source_"):
indexes_to_stack.append(
cls_or_self.index_from_unravel(
target_range,
target_idxs,
wrapper.index,
signal_index_type=signal_index_type.replace("target_", ""),
signal_index_name="target_" + signal_index_name,
)
)
if len(indexes_to_stack) == 1:
indexes_to_stack[0] = indexes_to_stack[0].rename(signal_index_name)
return indexes.stack_indexes((*indexes_to_stack, wrapper.columns[col_idxs]), **clean_index_kwargs)
if len(objs) == 1:
obj = objs[0]
wrapper = ArrayWrapper.from_obj(obj)
if not isinstance(obj, (pd.Series, pd.DataFrame)):
obj = wrapper.wrap(obj)
func = jit_reg.resolve_option(nb.unravel_between_nb, jitted)
new_mask, source_range, target_range, source_idxs, target_idxs, col_idxs = func(
reshaping.to_2d_array(obj),
incl_open_source=incl_open_source,
incl_empty_cols=incl_empty_cols,
)
if new_mask.shape == wrapper.shape_2d and incl_empty_cols and not force_signal_index:
return wrapper.wrap(new_mask)
if not incl_empty_cols and (source_idxs == -1).all():
raise ValueError("No columns left")
new_columns = _build_new_columns(
source_range,
target_range,
source_idxs,
target_idxs,
col_idxs,
)
return wrapper.wrap(new_mask, columns=new_columns, group_by=False, **wrap_kwargs)
if len(objs) == 2:
source = objs[0]
target = objs[1]
broadcast_kwargs = merge_dicts(dict(to_pd=False, min_ndim=2), broadcast_kwargs)
broadcasted_args, wrapper = reshaping.broadcast(
dict(source=source, target=target),
return_wrapper=True,
**broadcast_kwargs,
)
func = jit_reg.resolve_option(nb.unravel_between_two_nb, jitted)
new_source_mask, new_target_mask, source_range, target_range, source_idxs, target_idxs, col_idxs = func(
broadcasted_args["source"],
broadcasted_args["target"],
relation=relation,
incl_open_source=incl_open_source,
incl_open_target=incl_open_target,
incl_empty_cols=incl_empty_cols,
)
if new_source_mask.shape == wrapper.shape_2d and incl_empty_cols and not force_signal_index:
return wrapper.wrap(new_source_mask), wrapper.wrap(new_target_mask)
if not incl_empty_cols and (source_idxs == -1).all() and (target_idxs == -1).all():
raise ValueError("No columns left")
new_columns = _build_new_columns(
source_range,
target_range,
source_idxs,
target_idxs,
col_idxs,
)
new_source_mask = wrapper.wrap(new_source_mask, columns=new_columns, group_by=False, **wrap_kwargs)
new_target_mask = wrapper.wrap(new_target_mask, columns=new_columns, group_by=False, **wrap_kwargs)
return new_source_mask, new_target_mask
raise ValueError("This method accepts either one or two arrays")
def ravel(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.SeriesFrame:
"""Ravel signals.
See `vectorbtpro.signals.nb.ravel_nb`.
Usage:
```pycon
>>> unravel_mask = mask.vbt.signals.unravel()
>>> original_mask = unravel_mask.vbt.signals.ravel(group_by=vbt.ExceptLevel("signal"))
>>> original_mask
a b c
2020-01-01 True True True
2020-01-02 False False True
2020-01-03 False True True
2020-01-04 False False False
2020-01-05 False True False
```
"""
if wrap_kwargs is None:
wrap_kwargs = {}
group_map = self.wrapper.grouper.get_group_map(group_by=group_by)
func = jit_reg.resolve_option(nb.ravel_nb, jitted)
new_mask = func(self.to_2d_array(), group_map)
return self.wrapper.wrap(new_mask, group_by=group_by, **wrap_kwargs)
# ############# Index ############# #
def nth_index(
self,
n: int,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""See `vectorbtpro.signals.nb.nth_index_nb`.
Usage:
```pycon
>>> mask.vbt.signals.nth_index(0)
a 2020-01-01
b 2020-01-01
c 2020-01-01
Name: nth_index, dtype: datetime64[ns]
>>> mask.vbt.signals.nth_index(2)
a NaT
b 2020-01-05
c 2020-01-03
Name: nth_index, dtype: datetime64[ns]
>>> mask.vbt.signals.nth_index(-1)
a 2020-01-01
b 2020-01-05
c 2020-01-03
Name: nth_index, dtype: datetime64[ns]
>>> mask.vbt.signals.nth_index(-1, group_by=True)
Timestamp('2020-01-05 00:00:00')
```
"""
if self.is_frame() and self.wrapper.grouper.is_grouped(group_by=group_by):
squeezed = self.squeeze_grouped(
jit_reg.resolve_option(generic_nb.any_reduce_nb, jitted),
group_by=group_by,
jitted=jitted,
chunked=chunked,
)
arr = reshaping.to_2d_array(squeezed)
else:
arr = self.to_2d_array()
func = jit_reg.resolve_option(nb.nth_index_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
nth_index = func(arr, n)
wrap_kwargs = merge_dicts(dict(name_or_index="nth_index", to_index=True), wrap_kwargs)
return self.wrapper.wrap_reduced(nth_index, group_by=group_by, **wrap_kwargs)
def norm_avg_index(
self,
group_by: tp.GroupByLike = None,
jitted: tp.JittedOption = None,
chunked: tp.ChunkedOption = None,
wrap_kwargs: tp.KwargsLike = None,
) -> tp.MaybeSeries:
"""See `vectorbtpro.signals.nb.norm_avg_index_nb`.
Normalized average index measures the average signal location relative to the middle of the column.
This way, we can quickly see where the majority of signals are located.
Common values are:
* -1.0: only the first signal is set
* 1.0: only the last signal is set
* 0.0: symmetric distribution around the middle
* [-1.0, 0.0): average signal is on the left
* (0.0, 1.0]: average signal is on the right
Usage:
```pycon
>>> pd.Series([True, False, False, False]).vbt.signals.norm_avg_index()
-1.0
>>> pd.Series([False, False, False, True]).vbt.signals.norm_avg_index()
1.0
>>> pd.Series([True, False, False, True]).vbt.signals.norm_avg_index()
0.0
```
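The same works per column of a frame, e.g., for the `mask` used throughout these examples:
```pycon
>>> mask.vbt.signals.norm_avg_index()
a -1.0
b 0.0
c -0.5
Name: norm_avg_index, dtype: float64
```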
"""
if self.is_frame() and self.wrapper.grouper.is_grouped(group_by=group_by):
group_lens = self.wrapper.grouper.get_group_lens(group_by=group_by)
func = jit_reg.resolve_option(nb.norm_avg_index_grouped_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
norm_avg_index = func(self.to_2d_array(), group_lens)
else:
func = jit_reg.resolve_option(nb.norm_avg_index_nb, jitted)
func = ch_reg.resolve_option(func, chunked)
norm_avg_index = func(self.to_2d_array())
wrap_kwargs = merge_dicts(dict(name_or_index="norm_avg_index"), wrap_kwargs)
return self.wrapper.wrap_reduced(norm_avg_index, group_by=group_by, **wrap_kwargs)
def index_mapped(self, group_by: tp.GroupByLike = None, **kwargs) -> MappedArray:
"""Get a mapped array of indices.
See `vectorbtpro.generic.accessors.GenericAccessor.to_mapped`.
Only True values will be considered."""
indices = np.arange(len(self.wrapper.index), dtype=float_)[:, None]
indices = np.tile(indices, (1, len(self.wrapper.columns)))
indices = reshaping.soft_to_ndim(indices, self.wrapper.ndim)
indices[~self.obj.values] = np.nan
return self.wrapper.wrap(indices).vbt.to_mapped(dropna=True, dtype=int_, group_by=group_by, **kwargs)
def total(self, wrap_kwargs: tp.KwargsLike = None, group_by: tp.GroupByLike = None) -> tp.MaybeSeries:
"""Total number of True values in each column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="total"), wrap_kwargs)
return self.sum(group_by=group_by, wrap_kwargs=wrap_kwargs)
def rate(self, wrap_kwargs: tp.KwargsLike = None, group_by: tp.GroupByLike = None, **kwargs) -> tp.MaybeSeries:
"""`SignalsAccessor.total` divided by the total index length in each column/group."""
total = reshaping.to_1d_array(self.total(group_by=group_by, **kwargs))
wrap_kwargs = merge_dicts(dict(name_or_index="rate"), wrap_kwargs)
total_steps = self.wrapper.grouper.get_group_lens(group_by=group_by) * self.wrapper.shape[0]
return self.wrapper.wrap_reduced(total / total_steps, group_by=group_by, **wrap_kwargs)
def total_partitions(
self,
wrap_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""Total number of partitions of True values in each column/group."""
wrap_kwargs = merge_dicts(dict(name_or_index="total_partitions"), wrap_kwargs)
return self.partition_ranges(**kwargs).count(group_by=group_by, wrap_kwargs=wrap_kwargs)
def partition_rate(
self,
wrap_kwargs: tp.KwargsLike = None,
group_by: tp.GroupByLike = None,
**kwargs,
) -> tp.MaybeSeries:
"""`SignalsAccessor.total_partitions` divided by `SignalsAccessor.total` in each column/group."""
total_partitions = reshaping.to_1d_array(self.total_partitions(group_by=group_by, **kwargs))
total = reshaping.to_1d_array(self.total(group_by=group_by, **kwargs))
wrap_kwargs = merge_dicts(dict(name_or_index="partition_rate"), wrap_kwargs)
return self.wrapper.wrap_reduced(total_partitions / total, group_by=group_by, **wrap_kwargs)
# ############# Stats ############# #
@property
def stats_defaults(self) -> tp.Kwargs:
"""Defaults for `SignalsAccessor.stats`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.stats_defaults` and
`stats` from `vectorbtpro._settings.signals`."""
from vectorbtpro._settings import settings
signals_stats_cfg = settings["signals"]["stats"]
return merge_dicts(GenericAccessor.stats_defaults.__get__(self), signals_stats_cfg)
_metrics: tp.ClassVar[Config] = HybridConfig(
dict(
start_index=dict(
title="Start Index",
calc_func=lambda self: self.wrapper.index[0],
agg_func=None,
tags="wrapper",
),
end_index=dict(
title="End Index",
calc_func=lambda self: self.wrapper.index[-1],
agg_func=None,
tags="wrapper",
),
total_duration=dict(
title="Total Duration",
calc_func=lambda self: len(self.wrapper.index),
apply_to_timedelta=True,
agg_func=None,
tags="wrapper",
),
total=dict(title="Total", calc_func="total", tags="signals"),
rate=dict(
title="Rate [%]",
calc_func="rate",
post_calc_func=lambda self, out, settings: out * 100,
tags="signals",
),
total_overlapping=dict(
title="Total Overlapping",
calc_func=lambda self, target, group_by: (self & target).vbt.signals.total(group_by=group_by),
check_silent_has_target=True,
tags=["signals", "target"],
),
overlapping_rate=dict(
title="Overlapping Rate [%]",
calc_func=lambda self, target, group_by: (self & target).vbt.signals.total(group_by=group_by)
/ (self | target).vbt.signals.total(group_by=group_by),
post_calc_func=lambda self, out, settings: out * 100,
check_silent_has_target=True,
tags=["signals", "target"],
),
first_index=dict(
title="First Index",
calc_func="nth_index",
n=0,
wrap_kwargs=dict(to_index=True),
tags=["signals", "index"],
),
last_index=dict(
title="Last Index",
calc_func="nth_index",
n=-1,
wrap_kwargs=dict(to_index=True),
tags=["signals", "index"],
),
norm_avg_index=dict(title="Norm Avg Index [-1, 1]", calc_func="norm_avg_index", tags=["signals", "index"]),
distance=dict(
title=RepEval(
"f'Distance {self.get_relation_str(relation)} {target_name}' if target is not None else 'Distance'"
),
calc_func="between_ranges.duration",
post_calc_func=lambda self, out, settings: {
"Min": out.min(),
"Median": out.median(),
"Max": out.max(),
},
apply_to_timedelta=True,
tags=RepEval("['signals', 'distance', 'target'] if target is not None else ['signals', 'distance']"),
),
total_partitions=dict(
title="Total Partitions",
calc_func="total_partitions",
tags=["signals", "partitions"],
),
partition_rate=dict(
title="Partition Rate [%]",
calc_func="partition_rate",
post_calc_func=lambda self, out, settings: out * 100,
tags=["signals", "partitions"],
),
partition_len=dict(
title="Partition Length",
calc_func="partition_ranges.duration",
post_calc_func=lambda self, out, settings: {
"Min": out.min(),
"Median": out.median(),
"Max": out.max(),
},
apply_to_timedelta=True,
tags=["signals", "partitions", "distance"],
),
partition_distance=dict(
title="Partition Distance",
calc_func="between_partition_ranges.duration",
post_calc_func=lambda self, out, settings: {
"Min": out.min(),
"Median": out.median(),
"Max": out.max(),
},
apply_to_timedelta=True,
tags=["signals", "partitions", "distance"],
),
)
)
@property
def metrics(self) -> Config:
return self._metrics
# ############# Plotting ############# #
def plot(
self,
yref: str = "y",
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot signals.
Args:
yref (str): Y coordinate axis.
column (hashable): Column to plot.
**kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.lineplot`.
Usage:
```pycon
>>> mask[['a', 'c']].vbt.signals.plot().show()
```
"""
if column is not None:
_self = self.select_col(column=column)
else:
_self = self
default_kwargs = dict(trace_kwargs=dict(line=dict(shape="hv")))
default_kwargs["yaxis" + yref[1:]] = dict(tickmode="array", tickvals=[0, 1], ticktext=["false", "true"])
return _self.obj.vbt.lineplot(**merge_dicts(default_kwargs, kwargs))
def plot_as_markers(
self,
y: tp.Optional[tp.ArrayLike] = None,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot Series as markers.
Args:
y (array_like): Y-axis values to plot markers on.
column (hashable): Column to plot.
**kwargs: Keyword arguments passed to `vectorbtpro.generic.accessors.GenericAccessor.scatterplot`.
Usage:
```pycon
>>> ts = pd.Series([1, 2, 3, 2, 1], index=mask.index)
>>> fig = ts.vbt.lineplot()
>>> mask['b'].vbt.signals.plot_as_entries(y=ts, fig=fig)
>>> (~mask['b']).vbt.signals.plot_as_exits(y=ts, fig=fig).show()
```
"""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
obj = self.obj
if isinstance(obj, pd.DataFrame):
obj = self.select_col_from_obj(obj, column=column)
if y is None:
y = pd.Series.vbt.empty_like(obj, 1)
else:
y = reshaping.to_pd_array(y)
if isinstance(y, pd.DataFrame):
y = self.select_col_from_obj(y, column=column)
obj, y = reshaping.broadcast(obj, y, columns_from="keep")
obj = obj.fillna(False).astype(np.bool_)
if y.name is None:
y = y.rename("Y")
def_kwargs = dict(
trace_kwargs=dict(
marker=dict(
symbol="circle",
color=plotting_cfg["contrast_color_schema"]["blue"],
size=7,
),
name=obj.name,
)
)
kwargs = merge_dicts(def_kwargs, kwargs)
if "marker_color" in kwargs["trace_kwargs"]:
marker_color = kwargs["trace_kwargs"]["marker_color"]
else:
marker_color = kwargs["trace_kwargs"]["marker"]["color"]
if isinstance(marker_color, str) and "rgba" not in marker_color:
line_color = adjust_lightness(marker_color)
else:
line_color = marker_color
kwargs = merge_dicts(
dict(
trace_kwargs=dict(
marker=dict(
line=dict(width=1, color=line_color),
),
),
),
kwargs,
)
return y[obj].vbt.scatterplot(**kwargs)
def plot_as_entries(
self,
y: tp.Optional[tp.ArrayLike] = None,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot signals as entry markers.
See `SignalsSRAccessor.plot_as_markers`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
return self.plot_as_markers(
y=y,
column=column,
**merge_dicts(
dict(
trace_kwargs=dict(
marker=dict(
symbol="triangle-up",
color=plotting_cfg["contrast_color_schema"]["green"],
size=8,
),
name="Entries",
)
),
kwargs,
),
)
def plot_as_exits(
self,
y: tp.Optional[tp.ArrayLike] = None,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot signals as exit markers.
See `SignalsSRAccessor.plot_as_markers`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
return self.plot_as_markers(
y=y,
column=column,
**merge_dicts(
dict(
trace_kwargs=dict(
marker=dict(
symbol="triangle-down",
color=plotting_cfg["contrast_color_schema"]["red"],
size=8,
),
name="Exits",
)
),
kwargs,
),
)
def plot_as_entry_marks(
self,
y: tp.Optional[tp.ArrayLike] = None,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot signals as marked entry markers.
See `SignalsSRAccessor.plot_as_markers`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
return self.plot_as_markers(
y=y,
column=column,
**merge_dicts(
dict(
trace_kwargs=dict(
marker=dict(
symbol="circle",
color="rgba(0, 0, 0, 0)",
size=20,
line=dict(
color=plotting_cfg["contrast_color_schema"]["green"],
width=2,
),
),
name="Entry marks",
)
),
kwargs,
),
)
def plot_as_exit_marks(
self,
y: tp.Optional[tp.ArrayLike] = None,
column: tp.Optional[tp.Label] = None,
**kwargs,
) -> tp.Union[tp.BaseFigure, tp.TraceUpdater]:
"""Plot signals as marked exit markers.
See `SignalsSRAccessor.plot_as_markers`."""
from vectorbtpro._settings import settings
plotting_cfg = settings["plotting"]
return self.plot_as_markers(
y=y,
column=column,
**merge_dicts(
dict(
trace_kwargs=dict(
marker=dict(
symbol="circle",
color="rgba(0, 0, 0, 0)",
size=20,
line=dict(
color=plotting_cfg["contrast_color_schema"]["red"],
width=2,
),
),
name="Exit marks",
)
),
kwargs,
),
)
@property
def plots_defaults(self) -> tp.Kwargs:
"""Defaults for `SignalsAccessor.plots`.
Merges `vectorbtpro.generic.accessors.GenericAccessor.plots_defaults` and
`plots` from `vectorbtpro._settings.signals`."""
from vectorbtpro._settings import settings
signals_plots_cfg = settings["signals"]["plots"]
return merge_dicts(GenericAccessor.plots_defaults.__get__(self), signals_plots_cfg)
@property
def subplots(self) -> Config:
return self._subplots
SignalsAccessor.override_metrics_doc(__pdoc__)
SignalsAccessor.override_subplots_doc(__pdoc__)
@register_sr_vbt_accessor("signals")
class SignalsSRAccessor(SignalsAccessor, GenericSRAccessor):
"""Accessor on top of signal series. For Series only.
Accessible via `pd.Series.vbt.signals`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
GenericSRAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
SignalsAccessor.__init__(self, wrapper, obj=obj, **kwargs)
@register_df_vbt_accessor("signals")
class SignalsDFAccessor(SignalsAccessor, GenericDFAccessor):
"""Accessor on top of signal series. For DataFrames only.
Accessible via `pd.DataFrame.vbt.signals`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
_full_init: bool = True,
**kwargs,
) -> None:
GenericDFAccessor.__init__(self, wrapper, obj=obj, _full_init=False, **kwargs)
if _full_init:
SignalsAccessor.__init__(self, wrapper, obj=obj, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Named tuples and enumerated types for signals.
Defines enums and other schemas for `vectorbtpro.signals`."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.formatting import prettify
__pdoc__all__ = __all__ = [
"StopType",
"SignalRelation",
"FactoryMode",
"GenEnContext",
"GenExContext",
"GenEnExContext",
"RankContext",
]
__pdoc__ = {}
# ############# Enums ############# #
class StopTypeT(tp.NamedTuple):
SL: int = 0
TSL: int = 1
TTP: int = 2
TP: int = 3
TD: int = 4
DT: int = 5
StopType = StopTypeT()
"""_"""
__pdoc__[
"StopType"
] = f"""Stop type.
```python
{prettify(StopType)}
```
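Stop types are stored as integers, for instance in the `stop_type` array filled by
`vectorbtpro.signals.accessors.SignalsAccessor.generate_ohlc_stop_exits`. A sketch of mapping
them back to their field names, assuming a frame `stop_type_df` holding such values:
```pycon
>>> stop_type_df.vbt(mapping=StopType).apply_mapping()
```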
"""
class SignalRelationT(tp.NamedTuple):
OneOne: int = 0
OneMany: int = 1
ManyOne: int = 2
ManyMany: int = 3
Chain: int = 4
AnyChain: int = 5
SignalRelation = SignalRelationT()
"""_"""
__pdoc__[
"SignalRelation"
] = f"""SignalRelation between two masks.
```python
{prettify(SignalRelation)}
```
Attributes:
OneOne: One source signal maps to exactly one succeeding target signal.
OneMany: One source signal can map to one or more succeeding target signals.
ManyOne: One or more source signals can map to exactly one succeeding target signal.
ManyMany: One or more source signals can map to one or more succeeding target signals.
Chain: First source signal maps to the first target signal after it and vice versa.
AnyChain: First signal maps to the first opposite signal after it and vice versa.
"""
class FactoryModeT(tp.NamedTuple):
Entries: int = 0
Exits: int = 1
Both: int = 2
Chain: int = 3
FactoryMode = FactoryModeT()
"""_"""
__pdoc__[
"FactoryMode"
] = f"""Factory mode.
```python
{prettify(FactoryMode)}
```
Attributes:
Entries: Generate entries only using `generate_func_nb`.
Takes no input signal arrays.
Produces one output signal array - `entries`.
Such generators often have no suffix.
Exits: Generate exits only using `generate_ex_func_nb`.
Takes one input signal array - `entries`.
Produces one output signal array - `exits`.
Such generators often have suffix 'X'.
Both: Generate both entries and exits using `generate_enex_func_nb`.
Takes no input signal arrays.
Produces two output signal arrays - `entries` and `exits`.
Such generators often have suffix 'NX'.
Chain: Generate chain of entries and exits using `generate_enex_func_nb`.
Takes one input signal array - `entries`.
Produces two output signal arrays - `new_entries` and `exits`.
Such generators often have suffix 'CX'.
"""
# ############# Named tuples ############# #
class GenEnContext(tp.NamedTuple):
target_shape: tp.Shape
only_once: bool
wait: int
entries_out: tp.Array2d
out: tp.Array1d
from_i: int
to_i: int
col: int
__pdoc__["GenEnContext"] = "Context of an entry signal generator."
__pdoc__["GenEnContext.target_shape"] = "Target shape."
__pdoc__["GenEnContext.only_once"] = "Whether to run the placement function only once."
__pdoc__["GenEnContext.wait"] = "Number of ticks to wait before placing the next entry."
__pdoc__["GenEnContext.entries_out"] = "Output array with entries."
__pdoc__["GenEnContext.out"] = "Current segment of the output array with entries."
__pdoc__["GenEnContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenEnContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenEnContext.col"] = "Column of the segment."
class GenExContext(tp.NamedTuple):
entries: tp.Array2d
until_next: bool
skip_until_exit: bool
exits_out: tp.Array2d
out: tp.Array1d
wait: int
from_i: int
to_i: int
col: int
__pdoc__["GenExContext"] = "Context of an exit signal generator."
__pdoc__["GenExContext.entries"] = "Input array with entries."
__pdoc__["GenExContext.until_next"] = "Whether to place signals up to the next entry signal."
__pdoc__["GenExContext.skip_until_exit"] = "Whether to skip processing entry signals until the next exit."
__pdoc__["GenExContext.exits_out"] = "Output array with exits."
__pdoc__["GenExContext.out"] = "Current segment of the output array with exits."
__pdoc__["GenExContext.wait"] = "Number of ticks to wait before placing exits."
__pdoc__["GenExContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenExContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenExContext.col"] = "Column of the segment."
class GenEnExContext(tp.NamedTuple):
target_shape: tp.Shape
entry_wait: int
exit_wait: int
entries_out: tp.Array2d
exits_out: tp.Array2d
entries_turn: bool
wait: int
out: tp.Array1d
from_i: int
to_i: int
col: int
__pdoc__["GenExContext"] = "Context of an entry/exit signal generator."
__pdoc__["GenExContext.target_shape"] = "Target shape."
__pdoc__["GenExContext.entry_wait"] = "Number of ticks to wait before placing entries."
__pdoc__["GenExContext.exit_wait"] = "Number of ticks to wait before placing exits."
__pdoc__["GenExContext.entries_out"] = "Output array with entries."
__pdoc__["GenExContext.exits_out"] = "Output array with exits."
__pdoc__["GenExContext.entries_turn"] = "Whether the current turn is generating an entry."
__pdoc__["GenExContext.out"] = "Current segment of the output array with entries/exits."
__pdoc__["GenExContext.wait"] = "Number of ticks to wait before placing entries/exits."
__pdoc__["GenExContext.from_i"] = "Start index of the segment (inclusive)."
__pdoc__["GenExContext.to_i"] = "End index of the segment (exclusive)."
__pdoc__["GenExContext.col"] = "Column of the segment."
class RankContext(tp.NamedTuple):
mask: tp.Array2d
reset_by: tp.Optional[tp.Array2d]
after_false: bool
after_reset: bool
reset_wait: int
col: int
i: int
last_false_i: int
last_reset_i: int
all_sig_cnt: int
all_part_cnt: int
all_sig_in_part_cnt: int
nonres_sig_cnt: int
nonres_part_cnt: int
nonres_sig_in_part_cnt: int
sig_cnt: int
part_cnt: int
sig_in_part_cnt: int
__pdoc__["RankContext"] = "Context of a ranker."
__pdoc__["RankContext.mask"] = "Source mask."
__pdoc__["RankContext.reset_by"] = "Resetting mask."
__pdoc__["RankContext.after_false"] = (
"""Whether to disregard the first partition of True values if there is no False value before them."""
)
__pdoc__["RankContext.after_reset"] = (
"""Whether to disregard the first partition of True values coming before the first reset signal."""
)
__pdoc__["RankContext.reset_wait"] = """Number of ticks to wait before resetting the current partition."""
__pdoc__["RankContext.col"] = "Current column."
__pdoc__["RankContext.i"] = "Current row."
__pdoc__["RankContext.last_false_i"] = "Row of the last False value in the main mask."
__pdoc__[
"RankContext.last_reset_i"
] = """Row of the last True value in the resetting mask.
Doesn't take into account `reset_wait`."""
__pdoc__["RankContext.all_sig_cnt"] = """Number of all signals encountered including this."""
__pdoc__["RankContext.all_part_cnt"] = """Number of all partitions encountered including this."""
__pdoc__["RankContext.all_sig_in_part_cnt"] = (
"""Number of signals encountered in the current partition including this."""
)
__pdoc__["RankContext.nonres_sig_cnt"] = """Number of non-resetting signals encountered including this."""
__pdoc__["RankContext.nonres_part_cnt"] = """Number of non-resetting partitions encountered including this."""
__pdoc__["RankContext.nonres_sig_in_part_cnt"] = (
"""Number of signals encountered in the current non-resetting partition including this."""
)
__pdoc__["RankContext.sig_cnt"] = """Number of valid and resetting signals encountered including this."""
__pdoc__["RankContext.part_cnt"] = """Number of valid and resetting partitions encountered including this."""
__pdoc__["RankContext.sig_in_part_cnt"] = (
"""Number of signals encountered in the current valid and resetting partition including this."""
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Factory for building signal generators.
The signal factory class `SignalFactory` extends `vectorbtpro.indicators.factory.IndicatorFactory`
to offer a convenient way to create signal generators of any complexity. By providing it with information
such as entry and exit functions and the names of inputs, parameters, and outputs, it will create a
stand-alone class capable of generating signals for an arbitrary combination of inputs and parameters.
"""
import inspect
import numpy as np
from numba import njit
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import combining
from vectorbtpro.indicators.factory import IndicatorFactory, IndicatorBase, CacheOutputT
from vectorbtpro.registries.jit_registry import jit_reg
from vectorbtpro.signals.enums import FactoryMode
from vectorbtpro.signals.nb import generate_nb, generate_ex_nb, generate_enex_nb, first_place_nb
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.enum_ import map_enum_fields
from vectorbtpro.utils.params import to_typed_list
__all__ = [
"SignalFactory",
]
class SignalFactory(IndicatorFactory):
"""A factory for building signal generators.
Extends `vectorbtpro.indicators.factory.IndicatorFactory` with place functions.
Generates a fixed number of outputs (depending upon `mode`).
If you need to generate other outputs, use in-place outputs (via `in_output_names`).
See `vectorbtpro.signals.enums.FactoryMode` for supported generation modes.
Other arguments are passed to `vectorbtpro.indicators.factory.IndicatorFactory`.
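For example, an exit generator that places a given number of random exits after each entry
could be assembled from the built-in `vectorbtpro.signals.nb.rand_place_nb`. A minimal sketch;
the `pass_params` settings key and the `run` call follow the conventions described in
`SignalFactory.with_place_func` and are assumptions rather than a complete recipe:
```pycon
>>> from vectorbtpro.signals.nb import rand_place_nb
>>> RandExits = vbt.SignalFactory(mode="exits", param_names=["n"]).with_place_func(
...     exit_place_func_nb=rand_place_nb,
...     exit_settings=dict(pass_params=["n"]),
... )
>>> exits = RandExits.run(entries, n=1).exits  # `entries` is a hypothetical boolean mask
```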
"""
def __init__(
self,
*args,
mode: tp.Union[str, int] = FactoryMode.Both,
input_names: tp.Optional[tp.Sequence[str]] = None,
attr_settings: tp.KwargsLike = None,
**kwargs,
) -> None:
mode = map_enum_fields(mode, FactoryMode)
if input_names is None:
input_names = []
else:
input_names = list(input_names)
if attr_settings is None:
attr_settings = {}
if "entries" in input_names:
raise ValueError("entries cannot be used in input_names")
if "exits" in input_names:
raise ValueError("exits cannot be used in input_names")
if mode == FactoryMode.Entries:
output_names = ["entries"]
elif mode == FactoryMode.Exits:
input_names = ["entries"] + input_names
output_names = ["exits"]
elif mode == FactoryMode.Both:
output_names = ["entries", "exits"]
else:
input_names = ["entries"] + input_names
output_names = ["new_entries", "exits"]
if "entries" in input_names:
attr_settings["entries"] = dict(dtype=np.bool_)
for output_name in output_names:
attr_settings[output_name] = dict(dtype=np.bool_)
IndicatorFactory.__init__(
self,
*args,
mode=mode,
input_names=input_names,
output_names=output_names,
attr_settings=attr_settings,
**kwargs,
)
self._mode = mode
def plot(
_self,
column: tp.Optional[tp.Label] = None,
entry_y: tp.Union[None, str, tp.ArrayLike] = None,
exit_y: tp.Union[None, str, tp.ArrayLike] = None,
entry_types: tp.Optional[tp.ArrayLike] = None,
exit_types: tp.Optional[tp.ArrayLike] = None,
entry_trace_kwargs: tp.KwargsLike = None,
exit_trace_kwargs: tp.KwargsLike = None,
fig: tp.Optional[tp.BaseFigure] = None,
**kwargs,
) -> tp.BaseFigure:
self_col = _self.select_col(column=column, group_by=False)
if entry_y is not None and isinstance(entry_y, str):
entry_y = getattr(self_col, entry_y)
if exit_y is not None and isinstance(exit_y, str):
exit_y = getattr(self_col, exit_y)
if entry_trace_kwargs is None:
entry_trace_kwargs = {}
if exit_trace_kwargs is None:
exit_trace_kwargs = {}
entry_trace_kwargs = merge_dicts(
dict(name="New Entries" if mode == FactoryMode.Chain else "Entries"),
entry_trace_kwargs,
)
exit_trace_kwargs = merge_dicts(dict(name="Exits"), exit_trace_kwargs)
if entry_types is not None:
entry_types = np.asarray(entry_types)
entry_trace_kwargs = merge_dicts(
dict(customdata=entry_types, hovertemplate="(%{x}, %{y}) Type: %{customdata}"),
entry_trace_kwargs,
)
if exit_types is not None:
exit_types = np.asarray(exit_types)
exit_trace_kwargs = merge_dicts(
dict(customdata=exit_types, hovertemplate="(%{x}, %{y}) Type: %{customdata}"),
exit_trace_kwargs,
)
if mode == FactoryMode.Entries:
fig = self_col.entries.vbt.signals.plot_as_entries(
y=entry_y,
trace_kwargs=entry_trace_kwargs,
fig=fig,
**kwargs,
)
elif mode == FactoryMode.Exits:
fig = self_col.entries.vbt.signals.plot_as_entries(
y=entry_y,
trace_kwargs=entry_trace_kwargs,
fig=fig,
**kwargs,
)
fig = self_col.exits.vbt.signals.plot_as_exits(
y=exit_y,
trace_kwargs=exit_trace_kwargs,
fig=fig,
**kwargs,
)
elif mode == FactoryMode.Both:
fig = self_col.entries.vbt.signals.plot_as_entries(
y=entry_y,
trace_kwargs=entry_trace_kwargs,
fig=fig,
**kwargs,
)
fig = self_col.exits.vbt.signals.plot_as_exits(
y=exit_y,
trace_kwargs=exit_trace_kwargs,
fig=fig,
**kwargs,
)
else:
fig = self_col.new_entries.vbt.signals.plot_as_entries(
y=entry_y,
trace_kwargs=entry_trace_kwargs,
fig=fig,
**kwargs,
)
fig = self_col.exits.vbt.signals.plot_as_exits(
y=exit_y,
trace_kwargs=exit_trace_kwargs,
fig=fig,
**kwargs,
)
return fig
plot.__doc__ = """Plot `{0}.{1}` and `{0}.exits`.
Args:
entry_y (array_like): Y-axis values to plot entry markers on.
exit_y (array_like): Y-axis values to plot exit markers on.
entry_types (array_like): Entry types in string format.
exit_types (array_like): Exit types in string format.
entry_trace_kwargs (dict): Keyword arguments passed to
`vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_entries` for `{0}.{1}`.
exit_trace_kwargs (dict): Keyword arguments passed to
`vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_exits` for `{0}.exits`.
fig (Figure or FigureWidget): Figure to add traces to.
**kwargs: Keyword arguments passed to `vectorbtpro.signals.accessors.SignalsSRAccessor.plot_as_markers`.
""".format(
self.class_name,
"new_entries" if mode == FactoryMode.Chain else "entries",
)
setattr(self.Indicator, "plot", plot)
@property
def mode(self):
"""Factory mode."""
return self._mode
def with_place_func(
self,
entry_place_func_nb: tp.Optional[tp.PlaceFunc] = None,
exit_place_func_nb: tp.Optional[tp.PlaceFunc] = None,
generate_func_nb: tp.Optional[tp.Callable] = None,
generate_ex_func_nb: tp.Optional[tp.Callable] = None,
generate_enex_func_nb: tp.Optional[tp.Callable] = None,
cache_func: tp.Callable = None,
entry_settings: tp.KwargsLike = None,
exit_settings: tp.KwargsLike = None,
cache_settings: tp.KwargsLike = None,
jit_kwargs: tp.KwargsLike = None,
jitted: tp.JittedOption = None,
**kwargs,
) -> tp.Type[IndicatorBase]:
"""Build signal generator class around entry and exit placement functions.
A placement function is simply a function that places signals.
There are two types: an entry placement function and an exit placement function.
Each placement function takes broadcast time series, broadcast in-place output time series,
broadcast parameter arrays, and other arguments, and returns an array of indices
corresponding to chosen signals. See `vectorbtpro.signals.nb.generate_nb`.
Args:
entry_place_func_nb (callable): `place_func_nb` that returns indices of entries.
Defaults to `vectorbtpro.signals.nb.first_place_nb` for `FactoryMode.Chain`.
exit_place_func_nb (callable): `place_func_nb` that returns indices of exits.
generate_func_nb (callable): Entry generation function.
Defaults to `vectorbtpro.signals.nb.generate_nb`.
generate_ex_func_nb (callable): Exit generation function.
Defaults to `vectorbtpro.signals.nb.generate_ex_nb`.
generate_enex_func_nb (callable): Entry and exit generation function.
Defaults to `vectorbtpro.signals.nb.generate_enex_nb`.
cache_func (callable): A caching function to preprocess data beforehand.
All returned objects will be passed as last arguments to placement functions.
entry_settings (dict): Settings dict for `entry_place_func_nb`.
exit_settings (dict): Settings dict for `exit_place_func_nb`.
cache_settings (dict): Settings dict for `cache_func`.
jit_kwargs (dict): Keyword arguments passed to `@njit` decorator of the parameter selection function.
By default, has `nogil` set to True.
jitted (any): See `vectorbtpro.utils.jitting.resolve_jitted_option`.
Gets applied to generation functions only. If the respective generation
function is not jitted, then the apply function won't be jitted either.
**kwargs: Keyword arguments passed to `IndicatorFactory.with_custom_func`.
!!! note
Placement functions must be Numba-compiled.
Which inputs, parameters and arguments to pass to each function must be
explicitly indicated in the function's settings dict. By default, nothing is passed.
Passing keyword arguments directly to the placement functions is not supported.
Use `pass_kwargs` in a settings dict to pass keyword arguments as positional.
Settings dict of each function can have the following keys:
Attributes:
pass_inputs (list of str): Input names to pass to the placement function.
Defaults to []. Order matters. Each name must be in `input_names`.
pass_in_outputs (list of str): In-place output names to pass to the placement function.
Defaults to []. Order matters. Each name must be in `in_output_names`.
pass_params (list of str): Parameter names to pass to the placement function.
Defaults to []. Order matters. Each name must be in `param_names`.
pass_kwargs (dict, list of str or list of tuple): Keyword arguments from `kwargs` dict to
pass as positional arguments to the placement function.
Defaults to []. Order matters.
If any element is a tuple, must contain the name and the default value.
If any element is a string, the default value is None.
Built-in keys include:
* `input_shape`: Input shape if no input time series passed.
Default is provided by the pipeline if `pass_input_shape` is True.
* `wait`: Number of ticks to wait before placing signals.
Default is 1.
* `until_next`: Whether to place signals up to the next entry signal.
Default is True. Applied in `generate_ex_func_nb` only.
* `skip_until_exit`: Whether to skip processing entry signals until the next exit.
Default is False. Applied in `generate_ex_func_nb` only.
* `pick_first`: Whether to stop as soon as the first exit signal is found.
Default is False with `FactoryMode.Entries`, otherwise True.
* `temp_idx_arr`: Empty integer array used to temporarily store indices.
Default is an automatically generated array of shape `input_shape[0]`.
You can also pass `temp_idx_arr1`, `temp_idx_arr2`, etc. to generate multiple.
pass_cache (bool): Whether to pass cache from `cache_func` to the placement function.
Defaults to False. Cache is passed unpacked.
The following arguments can be passed to `run` and `run_combs` methods:
Args:
*args: Can be used instead of `place_args`.
place_args (tuple): Arguments passed to any placement function (depending on the mode).
entry_place_args (tuple): Arguments passed to the entry placement function.
exit_place_args (tuple): Arguments passed to the exit placement function.
entry_args (tuple): Alias for `entry_place_args`.
exit_args (tuple): Alias for `exit_place_args`.
cache_args (tuple): Arguments passed to the cache function.
entry_kwargs (dict): Settings for the entry placement function. Also contains arguments
passed as positional if in `pass_kwargs`.
exit_kwargs (dict): Settings for the exit placement function. Also contains arguments
passed as positional if in `pass_kwargs`.
cache_kwargs (dict): Settings for the cache function. Also contains arguments
passed as positional if in `pass_kwargs`.
return_cache (bool): Whether to return only cache.
use_cache (any): Cache to use.
**kwargs: Default keyword arguments (depending on the mode).
For more arguments, see `vectorbtpro.indicators.factory.IndicatorBase.run_pipeline`.
Usage:
* The simplest signal indicator that places True at the very first index:
```pycon
>>> from vectorbtpro import *
>>> @njit
... def entry_place_func_nb(c):
... c.out[0] = True
... return 0
>>> @njit
... def exit_place_func_nb(c):
... c.out[0] = True
... return 0
>>> MySignals = vbt.SignalFactory().with_place_func(
... entry_place_func_nb=entry_place_func_nb,
... exit_place_func_nb=exit_place_func_nb,
... entry_kwargs=dict(wait=1),
... exit_kwargs=dict(wait=1)
... )
>>> my_sig = MySignals.run(input_shape=(3, 3))
>>> my_sig.entries
0 1 2
0 True True True
1 False False False
2 True True True
>>> my_sig.exits
0 1 2
0 False False False
1 True True True
2 False False False
```
* Take the first entry and place an exit after waiting `n` ticks. Find the next entry and repeat.
Test three different `n` values.
```pycon
>>> from vectorbtpro.signals.factory import SignalFactory
>>> @njit
... def wait_place_nb(c, n):
... if n < len(c.out):
... c.out[n] = True
... return n
... return -1
>>> # Build signal generator
>>> MySignals = SignalFactory(
... mode='chain',
... param_names=['n']
... ).with_place_func(
... exit_place_func_nb=wait_place_nb,
... exit_settings=dict(
... pass_params=['n']
... )
... )
>>> # Run signal generator
>>> entries = [True, True, True, True, True]
>>> my_sig = MySignals.run(entries, [0, 1, 2])
>>> my_sig.entries # input entries
custom_n 0 1 2
0 True True True
1 True True True
2 True True True
3 True True True
4 True True True
>>> my_sig.new_entries # output entries
custom_n 0 1 2
0 True True True
1 False False False
2 True False False
3 False True False
4 True False True
>>> my_sig.exits # output exits
custom_n 0 1 2
0 False False False
1 True False False
2 False True False
3 True False True
4 False False False
```
* To combine multiple iterative signals, you would need to create a custom placement function.
Here is an example of combining two random generators using an "OR" rule (the first signal wins):
```pycon
>>> from vectorbtpro.indicators.configs import flex_elem_param_config
>>> from vectorbtpro.signals.factory import SignalFactory
>>> from vectorbtpro.signals.nb import rand_by_prob_place_nb
>>> # Enum to distinguish random generators
>>> RandType = namedtuple('RandType', ['R1', 'R2'])(0, 1)
>>> # Define exit placement function
>>> @njit
... def rand_exit_place_nb(c, rand_type, prob1, prob2):
... for out_i in range(len(c.out)):
... if np.random.uniform(0, 1) < prob1:
... c.out[out_i] = True
... rand_type[c.from_i + out_i] = RandType.R1
... return out_i
... if np.random.uniform(0, 1) < prob2:
... c.out[out_i] = True
... rand_type[c.from_i + out_i] = RandType.R2
... return out_i
... return -1
>>> # Build signal generator
>>> MySignals = SignalFactory(
... mode='chain',
... in_output_names=['rand_type'],
... param_names=['prob1', 'prob2'],
... attr_settings=dict(
... rand_type=dict(dtype=RandType) # creates rand_type_readable
... )
... ).with_place_func(
... exit_place_func_nb=rand_exit_place_nb,
... exit_settings=dict(
... pass_in_outputs=['rand_type'],
... pass_params=['prob1', 'prob2']
... ),
... param_settings=dict(
... prob1=flex_elem_param_config, # param per frame/row/col/element
... prob2=flex_elem_param_config
... ),
... rand_type=-1 # fill with this value
... )
>>> # Run signal generator
>>> entries = [True, True, True, True, True]
>>> my_sig = MySignals.run(entries, [0., 1.], [0., 1.], param_product=True)
>>> my_sig.new_entries
custom_prob1 0.0 1.0
custom_prob2 0.0 1.0 0.0 1.0
0 True True True True
1 False False False False
2 False True True True
3 False False False False
4 False True True True
>>> my_sig.exits
custom_prob1 0.0 1.0
custom_prob2 0.0 1.0 0.0 1.0
0 False False False False
1 False True True True
2 False False False False
3 False True True True
4 False False False False
>>> my_sig.rand_type_readable
custom_prob1 0.0 1.0
custom_prob2 0.0 1.0 0.0 1.0
0
1 R2 R1 R1
2
3 R2 R1 R1
4
```
"""
Indicator = self.Indicator
setattr(Indicator, "entry_place_func_nb", entry_place_func_nb)
setattr(Indicator, "exit_place_func_nb", exit_place_func_nb)
module_name = self.module_name
mode = self.mode
input_names = self.input_names
param_names = self.param_names
in_output_names = self.in_output_names
if generate_func_nb is None:
generate_func_nb = generate_nb
if generate_ex_func_nb is None:
generate_ex_func_nb = generate_ex_nb
if generate_enex_func_nb is None:
generate_enex_func_nb = generate_enex_nb
if jitted is not None:
generate_func_nb = jit_reg.resolve_option(generate_func_nb, jitted)
generate_ex_func_nb = jit_reg.resolve_option(generate_ex_func_nb, jitted)
generate_enex_func_nb = jit_reg.resolve_option(generate_enex_func_nb, jitted)
default_chain_entry_func = True
if mode == FactoryMode.Entries:
jit_apply_func = checks.is_numba_func(generate_func_nb)
require_input_shape = len(input_names) == 0
checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb")
if exit_place_func_nb is not None:
raise ValueError("exit_place_func_nb cannot be used with FactoryMode.Entries")
elif mode == FactoryMode.Exits:
jit_apply_func = checks.is_numba_func(generate_ex_func_nb)
require_input_shape = False
if entry_place_func_nb is not None:
raise ValueError("entry_place_func_nb cannot be used with FactoryMode.Exits")
checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb")
elif mode == FactoryMode.Both:
jit_apply_func = checks.is_numba_func(generate_enex_func_nb)
require_input_shape = len(input_names) == 0
checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb")
checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb")
else:
jit_apply_func = checks.is_numba_func(generate_enex_func_nb)
require_input_shape = False
if entry_place_func_nb is None:
entry_place_func_nb = first_place_nb
else:
default_chain_entry_func = False
if entry_settings is None:
entry_settings = {}
entry_settings = merge_dicts(dict(pass_inputs=["entries"]), entry_settings)
checks.assert_not_none(entry_place_func_nb, arg_name="entry_place_func_nb")
checks.assert_not_none(exit_place_func_nb, arg_name="exit_place_func_nb")
require_input_shape = kwargs.pop("require_input_shape", require_input_shape)
if entry_settings is None:
entry_settings = {}
if exit_settings is None:
exit_settings = {}
if cache_settings is None:
cache_settings = {}
valid_keys = ["pass_inputs", "pass_in_outputs", "pass_params", "pass_kwargs", "pass_cache"]
checks.assert_dict_valid(entry_settings, valid_keys)
checks.assert_dict_valid(exit_settings, valid_keys)
checks.assert_dict_valid(cache_settings, valid_keys)
# Get input names for each function
def _get_func_names(func_settings: tp.Kwargs, setting: str, all_names: tp.Sequence[str]) -> tp.List[str]:
func_input_names = func_settings.get(setting, None)
if func_input_names is None:
return []
else:
for name in func_input_names:
checks.assert_in(name, all_names)
return func_input_names
entry_input_names = _get_func_names(entry_settings, "pass_inputs", input_names)
exit_input_names = _get_func_names(exit_settings, "pass_inputs", input_names)
cache_input_names = _get_func_names(cache_settings, "pass_inputs", input_names)
entry_in_output_names = _get_func_names(entry_settings, "pass_in_outputs", in_output_names)
exit_in_output_names = _get_func_names(exit_settings, "pass_in_outputs", in_output_names)
cache_in_output_names = _get_func_names(cache_settings, "pass_in_outputs", in_output_names)
entry_param_names = _get_func_names(entry_settings, "pass_params", param_names)
exit_param_names = _get_func_names(exit_settings, "pass_params", param_names)
cache_param_names = _get_func_names(cache_settings, "pass_params", param_names)
# Build a function that selects a parameter tuple
if mode == FactoryMode.Entries:
_0 = "i"
_0 += ", target_shape"
if len(entry_input_names) > 0:
_0 += ", " + ", ".join(entry_input_names)
if len(entry_in_output_names) > 0:
_0 += ", " + ", ".join(entry_in_output_names)
if len(entry_param_names) > 0:
_0 += ", " + ", ".join(entry_param_names)
_0 += ", entry_args"
_0 += ", only_once"
_0 += ", wait"
_1 = "target_shape=target_shape"
_1 += ", place_func_nb=entry_place_func_nb"
_1 += ", place_args=("
if len(entry_input_names) > 0:
_1 += ", ".join(entry_input_names) + ", "
if len(entry_in_output_names) > 0:
_1 += ", ".join(map(lambda x: x + "[i]", entry_in_output_names)) + ", "
if len(entry_param_names) > 0:
_1 += ", ".join(map(lambda x: x + "[i]", entry_param_names)) + ", "
_1 += "*entry_args,)"
_1 += ", only_once=only_once"
_1 += ", wait=wait"
func_str = "def apply_func({0}):\n return generate_func_nb({1})".format(_0, _1)
scope = {"generate_func_nb": generate_func_nb, "entry_place_func_nb": entry_place_func_nb}
elif mode == FactoryMode.Exits:
_0 = "i"
_0 += ", entries"
if len(exit_input_names) > 0:
_0 += ", " + ", ".join(exit_input_names)
if len(exit_in_output_names) > 0:
_0 += ", " + ", ".join(exit_in_output_names)
if len(exit_param_names) > 0:
_0 += ", " + ", ".join(exit_param_names)
_0 += ", exit_args"
_0 += ", wait"
_0 += ", until_next"
_0 += ", skip_until_exit"
_1 = "entries=entries"
_1 += ", exit_place_func_nb=exit_place_func_nb"
_1 += ", exit_place_args=("
if len(exit_input_names) > 0:
_1 += ", ".join(exit_input_names) + ", "
if len(exit_in_output_names) > 0:
_1 += ", ".join(map(lambda x: x + "[i]", exit_in_output_names)) + ", "
if len(exit_param_names) > 0:
_1 += ", ".join(map(lambda x: x + "[i]", exit_param_names)) + ", "
_1 += "*exit_args,)"
_1 += ", wait=wait"
_1 += ", until_next=until_next"
_1 += ", skip_until_exit=skip_until_exit"
func_str = "def apply_func({0}):\n return generate_ex_func_nb({1})".format(_0, _1)
scope = {"generate_ex_func_nb": generate_ex_func_nb, "exit_place_func_nb": exit_place_func_nb}
else:
_0 = "i"
_0 += ", target_shape"
if len(entry_input_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_input_names))
if len(entry_in_output_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_in_output_names))
if len(entry_param_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_entry_" + x, entry_param_names))
_0 += ", entry_args"
if len(exit_input_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_input_names))
if len(exit_in_output_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_in_output_names))
if len(exit_param_names) > 0:
_0 += ", " + ", ".join(map(lambda x: "_exit_" + x, exit_param_names))
_0 += ", exit_args"
_0 += ", entry_wait"
_0 += ", exit_wait"
_1 = "target_shape=target_shape"
_1 += ", entry_place_func_nb=entry_place_func_nb"
_1 += ", entry_place_args=("
if len(entry_input_names) > 0:
_1 += ", ".join(map(lambda x: "_entry_" + x, entry_input_names)) + ", "
if len(entry_in_output_names) > 0:
_1 += ", ".join(map(lambda x: "_entry_" + x + "[i]", entry_in_output_names)) + ", "
if len(entry_param_names) > 0:
_1 += ", ".join(map(lambda x: "_entry_" + x + "[i]", entry_param_names)) + ", "
_1 += "*entry_args,)"
_1 += ", exit_place_func_nb=exit_place_func_nb"
_1 += ", exit_place_args=("
if len(exit_input_names) > 0:
_1 += ", ".join(map(lambda x: "_exit_" + x, exit_input_names)) + ", "
if len(exit_in_output_names) > 0:
_1 += ", ".join(map(lambda x: "_exit_" + x + "[i]", exit_in_output_names)) + ", "
if len(exit_param_names) > 0:
_1 += ", ".join(map(lambda x: "_exit_" + x + "[i]", exit_param_names)) + ", "
_1 += "*exit_args,)"
_1 += ", entry_wait=entry_wait"
_1 += ", exit_wait=exit_wait"
func_str = "def apply_func({0}):\n return generate_enex_func_nb({1})".format(_0, _1)
scope = {
"generate_enex_func_nb": generate_enex_func_nb,
"entry_place_func_nb": entry_place_func_nb,
"exit_place_func_nb": exit_place_func_nb,
}
filename = inspect.getfile(lambda: None)
code = compile(func_str, filename, "single")
exec(code, scope)
apply_func = scope["apply_func"]
if module_name is not None:
apply_func.__module__ = module_name
if jit_apply_func:
jit_kwargs = merge_dicts(dict(nogil=True), jit_kwargs)
apply_func = njit(apply_func, **jit_kwargs)
setattr(Indicator, "apply_func", apply_func)
def custom_func(
input_list: tp.List[tp.AnyArray],
in_output_list: tp.List[tp.List[tp.AnyArray]],
param_list: tp.List[tp.List[tp.ParamValue]],
*args,
input_shape: tp.Optional[tp.Shape] = None,
place_args: tp.ArgsLike = None,
entry_place_args: tp.ArgsLike = None,
exit_place_args: tp.ArgsLike = None,
entry_args: tp.ArgsLike = None,
exit_args: tp.ArgsLike = None,
cache_args: tp.ArgsLike = None,
entry_kwargs: tp.KwargsLike = None,
exit_kwargs: tp.KwargsLike = None,
cache_kwargs: tp.KwargsLike = None,
return_cache: bool = False,
use_cache: tp.Optional[CacheOutputT] = None,
execute_kwargs: tp.KwargsLike = None,
**_kwargs,
) -> tp.Union[CacheOutputT, tp.Array2d, tp.List[tp.Array2d]]:
# Get arguments
if len(input_list) == 0:
if input_shape is None:
raise ValueError("Pass input_shape if no input time series were passed")
else:
input_shape = input_list[0].shape
if len(args) > 0 and place_args is not None:
raise ValueError("Must provide either *args or place_args, not both")
if place_args is None:
place_args = args
if (
mode == FactoryMode.Entries
or mode == FactoryMode.Both
or (mode == FactoryMode.Chain and not default_chain_entry_func)
):
if len(place_args) > 0 and entry_place_args is not None:
raise ValueError("Must provide either place_args or entry_place_args, not both")
if entry_place_args is None:
entry_place_args = place_args
else:
if entry_place_args is None:
entry_place_args = ()
if mode in (FactoryMode.Exits, FactoryMode.Both, FactoryMode.Chain):
if len(place_args) > 0 and exit_place_args is not None:
raise ValueError("Must provide either place_args or exit_place_args, not both")
if exit_place_args is None:
exit_place_args = place_args
else:
if exit_place_args is None:
exit_place_args = ()
if len(entry_place_args) > 0 and entry_args is not None:
raise ValueError("Must provide either entry_place_args or entry_args, not both")
if entry_args is None:
entry_args = entry_place_args
if len(exit_place_args) > 0 and exit_args is not None:
raise ValueError("Must provide either exit_place_args or exit_args, not both")
if exit_args is None:
exit_args = exit_place_args
if cache_args is None:
cache_args = ()
if (
mode == FactoryMode.Entries
or mode == FactoryMode.Both
or (mode == FactoryMode.Chain and not default_chain_entry_func)
):
entry_kwargs = merge_dicts(_kwargs, entry_kwargs)
else:
if entry_kwargs is None:
entry_kwargs = {}
if mode in (FactoryMode.Exits, FactoryMode.Both, FactoryMode.Chain):
exit_kwargs = merge_dicts(_kwargs, exit_kwargs)
else:
if exit_kwargs is None:
exit_kwargs = {}
if cache_kwargs is None:
cache_kwargs = {}
kwargs_defaults = dict(
input_shape=input_shape,
only_once=mode == FactoryMode.Entries,
wait=1,
until_next=True,
skip_until_exit=False,
pick_first=mode != FactoryMode.Entries,
)
entry_kwargs = merge_dicts(kwargs_defaults, entry_kwargs)
exit_kwargs = merge_dicts(kwargs_defaults, exit_kwargs)
cache_kwargs = merge_dicts(kwargs_defaults, cache_kwargs)
only_once = entry_kwargs["only_once"]
entry_wait = entry_kwargs["wait"]
exit_wait = exit_kwargs["wait"]
until_next = exit_kwargs["until_next"]
skip_until_exit = exit_kwargs["skip_until_exit"]
# Distribute arguments across functions
entry_input_list = []
exit_input_list = []
cache_input_list = []
for input_name in entry_input_names:
entry_input_list.append(input_list[input_names.index(input_name)])
for input_name in exit_input_names:
exit_input_list.append(input_list[input_names.index(input_name)])
for input_name in cache_input_names:
cache_input_list.append(input_list[input_names.index(input_name)])
entry_in_output_list = []
exit_in_output_list = []
cache_in_output_list = []
for in_output_name in entry_in_output_names:
entry_in_output_list.append(in_output_list[in_output_names.index(in_output_name)])
for in_output_name in exit_in_output_names:
exit_in_output_list.append(in_output_list[in_output_names.index(in_output_name)])
for in_output_name in cache_in_output_names:
cache_in_output_list.append(in_output_list[in_output_names.index(in_output_name)])
entry_param_list = []
exit_param_list = []
cache_param_list = []
for param_name in entry_param_names:
entry_param_list.append(param_list[param_names.index(param_name)])
for param_name in exit_param_names:
exit_param_list.append(param_list[param_names.index(param_name)])
for param_name in cache_param_names:
cache_param_list.append(param_list[param_names.index(param_name)])
n_params = len(param_list[0]) if len(param_list) > 0 else 1
def _build_more_args(func_settings: tp.Kwargs, func_kwargs: tp.Kwargs) -> tp.Args:
pass_kwargs = func_settings.get("pass_kwargs", [])
if isinstance(pass_kwargs, dict):
pass_kwargs = list(pass_kwargs.items())
more_args = ()
for key in pass_kwargs:
value = None
if isinstance(key, tuple):
key, value = key
else:
if key.startswith("temp_idx_arr"):
value = np.empty((input_shape[0],), dtype=int_)
value = func_kwargs.get(key, value)
more_args += (value,)
return more_args
entry_more_args = _build_more_args(entry_settings, entry_kwargs)
exit_more_args = _build_more_args(exit_settings, exit_kwargs)
cache_more_args = _build_more_args(cache_settings, cache_kwargs)
# Caching
cache = use_cache
if cache is None and cache_func is not None:
_cache_in_output_list = cache_in_output_list
_cache_param_list = cache_param_list
if checks.is_numba_func(cache_func):
_cache_in_output_list = list(map(to_typed_list, cache_in_output_list))
_cache_param_list = list(map(to_typed_list, cache_param_list))
cache = cache_func(
*cache_input_list,
*_cache_in_output_list,
*_cache_param_list,
*cache_args,
*cache_more_args,
)
if return_cache:
return cache
if cache is None:
cache = ()
if not isinstance(cache, tuple):
cache = (cache,)
entry_cache = ()
exit_cache = ()
if entry_settings.get("pass_cache", False):
entry_cache = cache
if exit_settings.get("pass_cache", False):
exit_cache = cache
# Apply and concatenate
if mode == FactoryMode.Entries:
_entry_in_output_list = list(map(to_typed_list, entry_in_output_list))
_entry_param_list = list(map(to_typed_list, entry_param_list))
return combining.apply_and_concat(
n_params,
apply_func,
input_shape,
*entry_input_list,
*_entry_in_output_list,
*_entry_param_list,
entry_args + entry_more_args + entry_cache,
only_once,
entry_wait,
n_outputs=1,
jitted_loop=jit_apply_func,
execute_kwargs=execute_kwargs,
)
elif mode == FactoryMode.Exits:
_exit_in_output_list = list(map(to_typed_list, exit_in_output_list))
_exit_param_list = list(map(to_typed_list, exit_param_list))
return combining.apply_and_concat(
n_params,
apply_func,
input_list[0],
*exit_input_list,
*_exit_in_output_list,
*_exit_param_list,
exit_args + exit_more_args + exit_cache,
exit_wait,
until_next,
skip_until_exit,
n_outputs=1,
jitted_loop=jit_apply_func,
execute_kwargs=execute_kwargs,
)
else:
_entry_in_output_list = list(map(to_typed_list, entry_in_output_list))
_entry_param_list = list(map(to_typed_list, entry_param_list))
_exit_in_output_list = list(map(to_typed_list, exit_in_output_list))
_exit_param_list = list(map(to_typed_list, exit_param_list))
return combining.apply_and_concat(
n_params,
apply_func,
input_shape,
*entry_input_list,
*_entry_in_output_list,
*_entry_param_list,
entry_args + entry_more_args + entry_cache,
*exit_input_list,
*_exit_in_output_list,
*_exit_param_list,
exit_args + exit_more_args + exit_cache,
entry_wait,
exit_wait,
n_outputs=2,
jitted_loop=jit_apply_func,
execute_kwargs=execute_kwargs,
)
return self.with_custom_func(
custom_func,
pass_packed=True,
require_input_shape=require_input_shape,
**kwargs,
)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Numba-compiled functions for signals.
Provides an arsenal of Numba-compiled functions that are used by accessors
and in many other parts of the backtesting pipeline, such as technical indicators.
These only accept NumPy arrays and other Numba-compatible types.
!!! note
vectorbt treats matrices as first-class citizens and expects input arrays to be
2-dim, unless function has suffix `_1d` or is meant to be input to another function.
Data is processed along index (axis 0).
All functions passed as argument must be Numba-compiled."""
import numpy as np
from numba import prange
from vectorbtpro import _typing as tp
from vectorbtpro._dtypes import *
from vectorbtpro.base import chunking as base_ch
from vectorbtpro.base.flex_indexing import flex_select_1d_pc_nb, flex_select_nb
from vectorbtpro.generic import nb as generic_nb
from vectorbtpro.generic.enums import range_dt, RangeStatus
from vectorbtpro.records import chunking as records_ch
from vectorbtpro.registries.ch_registry import register_chunkable
from vectorbtpro.registries.jit_registry import register_jitted
from vectorbtpro.signals.enums import *
from vectorbtpro.utils import chunking as ch
from vectorbtpro.utils.array_ import uniform_summing_to_one_nb, rescale_float_to_int_nb, rescale_nb
from vectorbtpro.utils.template import Rep
__all__ = []
# ############# Generation ############# #
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
place_func_nb=None,
place_args=ch.ArgsTaker(),
only_once=None,
wait=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def generate_nb(
target_shape: tp.Shape,
place_func_nb: tp.PlaceFunc,
place_args: tp.Args = (),
only_once: bool = True,
wait: int = 1,
) -> tp.Array2d:
"""Create a boolean matrix of `target_shape` and place signals using `place_func_nb`.
Args:
target_shape (array): Target shape.
place_func_nb (callable): Signal placement function.
`place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnContext`,
and return the index of the last signal (-1 to break the loop).
place_args (tuple): Arguments passed to `place_func_nb`.
only_once (bool): Whether to run the placement function only once.
wait (int): Number of ticks to wait before placing the next entry.
!!! note
The first argument is always a 1-dimensional boolean array that contains only those
elements where signals can be placed. The range and column indices only describe which
range this array maps to.
"""
if wait < 0:
raise ValueError("wait must be zero or greater")
out = np.full(target_shape, False, dtype=np.bool_)
for col in prange(target_shape[1]):
from_i = 0
while from_i <= target_shape[0] - 1:
c = GenEnContext(
target_shape=target_shape,
only_once=only_once,
wait=wait,
entries_out=out,
out=out[from_i:, col],
from_i=from_i,
to_i=target_shape[0],
col=col,
)
_last_i = place_func_nb(c, *place_args)
if _last_i == -1:
break
last_i = from_i + _last_i
if last_i < from_i or last_i >= target_shape[0]:
raise ValueError("Last index is out of bounds")
if not out[last_i, col]:
out[last_i, col] = True
if only_once:
break
from_i = last_i + wait
return out
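# A minimal usage sketch with a toy placement function that marks the first index of each window:
# >>> from numba import njit
# >>> @njit
# ... def place_first_nb(c):
# ...     c.out[0] = True
# ...     return 0
# >>> generate_nb((4, 1), place_first_nb, only_once=False, wait=2)
# array([[ True],
#        [False],
#        [ True],
#        [False]])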
@register_chunkable(
size=ch.ArraySizer(arg_query="entries", axis=1),
arg_take_spec=dict(
entries=ch.ArraySlicer(axis=1),
exit_place_func_nb=None,
exit_place_args=ch.ArgsTaker(),
wait=None,
until_next=None,
skip_until_exit=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def generate_ex_nb(
entries: tp.Array2d,
exit_place_func_nb: tp.PlaceFunc,
exit_place_args: tp.Args = (),
wait: int = 1,
until_next: bool = True,
skip_until_exit: bool = False,
) -> tp.Array2d:
"""Place exit signals using `exit_place_func_nb` after each signal in `entries`.
Args:
entries (array): Boolean array with entry signals.
exit_place_func_nb (callable): Exit place function.
`exit_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenExContext`,
and return the index of the last signal (-1 to break the loop).
exit_place_args (tuple): Arguments passed to `exit_place_func_nb`.
wait (int): Number of ticks to wait before placing exits.
!!! note
Setting `wait` to 0 or False may result in two signals at one bar.
until_next (bool): Whether to place signals up to the next entry signal.
!!! note
Setting it to False makes it difficult to tell which exit belongs to which entry.
skip_until_exit (bool): Whether to skip processing entry signals until the next exit.
Has an effect only when `until_next` is disabled.
!!! note
Setting it to True makes it impossible to tell which exit belongs to which entry.
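Usage:
* Place an exit at the first bar after each entry using a toy placement function
(a minimal sketch assuming the defaults `wait=1` and `until_next=True`):
```pycon
>>> import numpy as np
>>> from numba import njit
>>> from vectorbtpro.signals.nb import generate_ex_nb
>>> @njit
... def exit_place_func_nb(c):
...     c.out[0] = True
...     return 0
>>> entries = np.array([[True], [False], [True], [False], [False]])
>>> generate_ex_nb(entries, exit_place_func_nb)
array([[False],
       [ True],
       [False],
       [ True],
       [False]])
```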
"""
if wait < 0:
raise ValueError("wait must be zero or greater")
out = np.full_like(entries, False)
def _place_exits(from_i, to_i, col, last_exit_i):
if from_i > -1:
if skip_until_exit and from_i <= last_exit_i:
return last_exit_i
from_i += wait
if not until_next:
to_i = entries.shape[0]
if to_i > from_i:
c = GenExContext(
entries=out,
until_next=until_next,
skip_until_exit=skip_until_exit,
exits_out=out,
out=out[from_i:to_i, col],
wait=wait,
from_i=from_i,
to_i=to_i,
col=col,
)
_last_exit_i = exit_place_func_nb(c, *exit_place_args)
if _last_exit_i != -1:
last_exit_i = from_i + _last_exit_i
if last_exit_i < from_i or last_exit_i >= entries.shape[0]:
raise ValueError("Last index is out of bounds")
if not out[last_exit_i, col]:
out[last_exit_i, col] = True
elif skip_until_exit:
last_exit_i = -1
return last_exit_i
for col in prange(entries.shape[1]):
from_i = -1
last_exit_i = -1
should_stop = False
for i in range(entries.shape[0]):
if entries[i, col]:
last_exit_i = _place_exits(from_i, i, col, last_exit_i)
if skip_until_exit and last_exit_i == -1 and from_i != -1:
should_stop = True
break
from_i = i
if should_stop:
continue
last_exit_i = _place_exits(from_i, entries.shape[0], col, last_exit_i)
return out
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
entry_place_func_nb=None,
entry_place_args=ch.ArgsTaker(),
exit_place_func_nb=None,
exit_place_args=ch.ArgsTaker(),
entry_wait=None,
exit_wait=None,
),
merge_func="column_stack",
)
@register_jitted
def generate_enex_nb(
target_shape: tp.Shape,
entry_place_func_nb: tp.PlaceFunc,
entry_place_args: tp.Args,
exit_place_func_nb: tp.PlaceFunc,
exit_place_args: tp.Args,
entry_wait: int = 1,
exit_wait: int = 1,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""Place entry signals using `entry_place_func_nb` and exit signals using
`exit_place_func_nb` one after another.
Args:
target_shape (array): Target shape.
entry_place_func_nb (callable): Entry place function.
`entry_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnExContext`,
and return the index of the last signal (-1 to break the loop).
entry_place_args (tuple): Arguments unpacked and passed to `entry_place_func_nb`.
exit_place_func_nb (callable): Exit place function.
`exit_place_func_nb` must accept a context of type `vectorbtpro.signals.enums.GenEnExContext`,
and return the index of the last signal (-1 to break the loop).
exit_place_args (tuple): Arguments unpacked and passed to `exit_place_func_nb`.
entry_wait (int): Number of ticks to wait before placing entries.
!!! note
Setting `entry_wait` to 0 or False assumes that both entry and exit can be processed
within the same bar, and exit can be processed before entry.
exit_wait (int): Number of ticks to wait before placing exits.
!!! note
Setting `exit_wait` to 0 or False assumes that both entry and exit can be processed
within the same bar, and entry can be processed before exit.
"""
if entry_wait < 0:
raise ValueError("entry_wait must be zero or greater")
if exit_wait < 0:
raise ValueError("exit_wait must be zero or greater")
if entry_wait == 0 and exit_wait == 0:
raise ValueError("entry_wait and exit_wait cannot be both 0")
entries = np.full(target_shape, False)
exits = np.full(target_shape, False)
def _place_signals(entries_turn, out, from_i, col, wait, place_func_nb, args):
to_i = target_shape[0]
if to_i > from_i:
c = GenEnExContext(
target_shape=target_shape,
entry_wait=entry_wait,
exit_wait=exit_wait,
entries_out=entries,
exits_out=exits,
entries_turn=entries_turn,
out=out[from_i:to_i, col],
wait=wait if from_i > 0 else 0,
from_i=from_i,
to_i=to_i,
col=col,
)
_last_i = place_func_nb(c, *args)
if _last_i == -1:
return -1
last_i = from_i + _last_i
if last_i < from_i or last_i >= target_shape[0]:
raise ValueError("Last index is out of bounds")
if not out[last_i, col]:
out[last_i, col] = True
return last_i
return -1
for col in range(target_shape[1]):
from_i = 0
entries_turn = True
first_signal = True
while from_i != -1:
if entries_turn:
if not first_signal:
from_i += entry_wait
from_i = _place_signals(
entries_turn,
entries,
from_i,
col,
entry_wait,
entry_place_func_nb,
entry_place_args,
)
entries_turn = False
else:
from_i += exit_wait
from_i = _place_signals(
entries_turn,
exits,
from_i,
col,
exit_wait,
exit_place_func_nb,
exit_place_args,
)
entries_turn = True
first_signal = False
return entries, exits
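# A minimal sketch: alternate entries and exits by placing each signal at the first bar
# of its window (toy placement function, default waits):
# >>> from numba import njit
# >>> @njit
# ... def place_first_nb(c):
# ...     c.out[0] = True
# ...     return 0
# >>> en, ex = generate_enex_nb((5, 1), place_first_nb, (), place_first_nb, ())
# >>> en[:, 0]
# array([ True, False,  True, False,  True])
# >>> ex[:, 0]
# array([False,  True, False,  True, False])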
# ############# Random signals ############# #
@register_jitted
def rand_place_nb(c: tp.Union[GenEnContext, GenExContext, GenEnExContext], n: tp.FlexArray1d) -> int:
"""`place_func_nb` to randomly pick `n` values.
`n` uses flexible indexing."""
size = min(c.to_i - c.from_i, flex_select_1d_pc_nb(n, c.col))
k = 0
last_i = -1
while k < size:
i = np.random.choice(len(c.out))
if not c.out[i]:
c.out[i] = True
k += 1
if i > last_i:
last_i = i
return last_i
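# For example, to place exactly two random signals per column via `generate_nb`
# (`n` is passed as an array because of flexible indexing):
# >>> mask = generate_nb((10, 2), rand_place_nb, (np.array([2]),))
# >>> mask.sum(axis=0)
# array([2, 2])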
@register_jitted
def rand_by_prob_place_nb(
c: tp.Union[GenEnContext, GenExContext, GenEnExContext],
prob: tp.FlexArray2d,
pick_first: bool = False,
) -> int:
"""`place_func_nb` to randomly place signals with probability `prob`.
`prob` uses flexible indexing."""
last_i = -1
for i in range(c.from_i, c.to_i):
if np.random.uniform(0, 1) < flex_select_nb(prob, i, c.col):
c.out[i - c.from_i] = True
last_i = i - c.from_i
if pick_first:
break
return last_i
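# For example, to place one exit after each entry with a 50% chance per bar via
# `generate_ex_nb` (`prob` must be 2-dim because of flexible indexing;
# `pick_first=True` stops at the first placed exit):
# >>> entries = np.array([[True], [False], [False], [True], [False]])
# >>> prob = np.full((5, 1), 0.5)
# >>> exits = generate_ex_nb(entries, rand_by_prob_place_nb, (prob, True))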
@register_chunkable(
size=ch.ShapeSizer(arg_query="target_shape", axis=1),
arg_take_spec=dict(
target_shape=ch.ShapeSlicer(axis=1),
n=base_ch.FlexArraySlicer(),
entry_wait=None,
exit_wait=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def generate_rand_enex_nb(
target_shape: tp.Shape,
n: tp.FlexArray1d,
entry_wait: int = 1,
exit_wait: int = 1,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""Pick a number of entries and the same number of exits one after another.
Respects `entry_wait` and `exit_wait` constraints through a number of tricks.
Tries to mimic a uniform distribution as much as possible.
The idea is the following: with constraints, there is some fixed amount of total
space required between the first entry and the last exit. Upscale this space so that
the distribution of entries and exits resembles a uniform distribution. This means
randomizing the position of the first entry, the last exit, and all signals in between.
`n` uses flexible indexing and thus must be at least a 0-dim array."""
entries = np.full(target_shape, False)
exits = np.full(target_shape, False)
if entry_wait == 0 and exit_wait == 0:
raise ValueError("entry_wait and exit_wait cannot be both 0")
if entry_wait == 1 and exit_wait == 1:
# Basic case
both = generate_nb(target_shape, rand_place_nb, (n * 2,), only_once=True, wait=1)
for col in prange(both.shape[1]):
both_idxs = np.flatnonzero(both[:, col])
entries[both_idxs[0::2], col] = True
exits[both_idxs[1::2], col] = True
else:
for col in prange(target_shape[1]):
_n = flex_select_1d_pc_nb(n, col)
if _n == 1:
entry_idx = np.random.randint(0, target_shape[0] - exit_wait)
entries[entry_idx, col] = True
else:
# Minimum range between two entries
min_range = entry_wait + exit_wait
# Minimum total range between first and last entry
min_total_range = min_range * (_n - 1)
if target_shape[0] < min_total_range + exit_wait + 1:
raise ValueError("Cannot take a larger sample than population")
# We should decide how much space should be allocated before the first and after the last entry
# Maximum space outside of min_total_range
max_free_space = target_shape[0] - min_total_range - 1
# If min_total_range is tiny compared to max_free_space, limit it,
# otherwise we would have a huge space before the first and after the last entry
# Limit it such that the distribution of entries mimics a uniform one
free_space = min(max_free_space, 3 * target_shape[0] // (_n + 1))
# What about the last exit? It requires exit_wait space
free_space -= exit_wait
# Now we need to distribute free space among three ranges:
# 1) before first, 2) between first and last added to min_total_range, 3) after last
# We do 2) such that min_total_range can freely expand to maximum
# We allocate twice as much for 3) as for 1) because an exit is missing
rand_floats = uniform_summing_to_one_nb(6)
chosen_spaces = rescale_float_to_int_nb(rand_floats, (0, free_space), free_space)
first_idx = chosen_spaces[0]
last_idx = target_shape[0] - np.sum(chosen_spaces[-2:]) - exit_wait - 1
# Selected range between first and last entry
total_range = last_idx - first_idx
# Maximum range between two entries within total_range
max_range = total_range - (_n - 2) * min_range
# Select random ranges within total_range
rand_floats = uniform_summing_to_one_nb(_n - 1)
chosen_ranges = rescale_float_to_int_nb(rand_floats, (min_range, max_range), total_range)
# Translate them into entries
entry_idxs = np.empty(_n, dtype=int_)
entry_idxs[0] = first_idx
entry_idxs[1:] = chosen_ranges
entry_idxs = np.cumsum(entry_idxs)
entries[entry_idxs, col] = True
# Generate exits
for col in range(target_shape[1]):
entry_idxs = np.flatnonzero(entries[:, col])
for j in range(len(entry_idxs)):
entry_i = entry_idxs[j] + exit_wait
if j < len(entry_idxs) - 1:
exit_i = entry_idxs[j + 1] - entry_wait
else:
exit_i = entries.shape[0] - 1
i = np.random.randint(exit_i - entry_i + 1)
exits[entry_i + i, col] = True
return entries, exits
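# For example, to request two entry/exit pairs per column
# (`n` is passed as an array because of flexible indexing):
# >>> en, ex = generate_rand_enex_nb((10, 1), np.array([2]))
# >>> en.sum(axis=0), ex.sum(axis=0)
# (array([2]), array([2]))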
def rand_enex_apply_nb(
target_shape: tp.Shape,
n: tp.FlexArray1d,
entry_wait: int = 1,
exit_wait: int = 1,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""`apply_func_nb` that calls `generate_rand_enex_nb`."""
return generate_rand_enex_nb(target_shape, n, entry_wait=entry_wait, exit_wait=exit_wait)
# ############# Stop signals ############# #
@register_jitted
def first_place_nb(c: tp.Union[GenEnContext, GenExContext, GenEnExContext], mask: tp.Array2d) -> int:
"""`place_func_nb` that keeps only the first signal in `mask`."""
last_i = -1
for i in range(c.from_i, c.to_i):
if mask[i, c.col]:
c.out[i - c.from_i] = True
last_i = i - c.from_i
break
return last_i
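# `first_place_nb` is also the default entry placement function of the signal factory
# in chain mode. A minimal sketch via `generate_nb`:
# >>> mask = np.array([[False], [True], [True], [False]])
# >>> generate_nb(mask.shape, first_place_nb, (mask,), only_once=True)
# array([[False],
#        [ True],
#        [False],
#        [False]])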
@register_jitted
def stop_place_nb(
c: tp.Union[GenExContext, GenEnExContext],
entry_ts: tp.FlexArray2d,
ts: tp.FlexArray2d,
follow_ts: tp.FlexArray2d,
stop_ts_out: tp.Array2d,
stop: tp.FlexArray2d,
trailing: tp.FlexArray2d,
) -> int:
"""`place_func_nb` that places an exit signal whenever a threshold is being hit.
!!! note
Waiting time cannot be higher than 1.
If waiting time is 0, `entry_ts` should be the first value in the bar.
If waiting time is 1, `entry_ts` should be the last value in the bar.
Args:
c (GenExContext or GenEnExContext): Signal context.
entry_ts (array of float): Entry price.
Utilizes flexible indexing.
ts (array of float): Price to compare the stop value against.
Utilizes flexible indexing. If NaN, defaults to `entry_ts`.
follow_ts (array of float): Following price.
Utilizes flexible indexing. If NaN, defaults to `ts`. Applied only if the stop is trailing.
stop_ts_out (array of float): Array where hit price of each exit will be stored.
Must be of the full shape.
stop (array of float): Stop value.
Utilizes flexible indexing. Set an element to `np.nan` to disable it.
trailing (array of bool): Whether the stop is trailing.
Utilizes flexible indexing. Set an element to False to disable it.
"""
if c.wait > 1:
raise ValueError("Wait must be either 0 or 1")
init_i = c.from_i - c.wait
init_entry_ts = flex_select_nb(entry_ts, init_i, c.col)
init_stop = flex_select_nb(stop, init_i, c.col)
if init_stop == 0:
init_stop = np.nan
init_trailing = flex_select_nb(trailing, init_i, c.col)
max_high = min_low = init_entry_ts
last_i = -1
for i in range(c.from_i, c.to_i):
curr_entry_ts = flex_select_nb(entry_ts, i, c.col)
curr_ts = flex_select_nb(ts, i, c.col)
curr_follow_ts = flex_select_nb(follow_ts, i, c.col)
if np.isnan(curr_ts):
curr_ts = curr_entry_ts
if np.isnan(curr_follow_ts):
if not np.isnan(curr_entry_ts):
if init_stop >= 0:
curr_follow_ts = min(curr_entry_ts, curr_ts)
else:
curr_follow_ts = max(curr_entry_ts, curr_ts)
else:
curr_follow_ts = curr_ts
if not np.isnan(init_stop):
if init_trailing:
if init_stop >= 0:
# Trailing stop buy
curr_stop_price = min_low * (1 + abs(init_stop))
else:
# Trailing stop sell
curr_stop_price = max_high * (1 - abs(init_stop))
else:
curr_stop_price = init_entry_ts * (1 + init_stop)
# Check if stop price is within bar
if not np.isnan(init_stop):
if init_stop >= 0:
exit_signal = curr_ts >= curr_stop_price
else:
exit_signal = curr_ts <= curr_stop_price
if exit_signal:
stop_ts_out[i, c.col] = curr_stop_price
c.out[i - c.from_i] = True
last_i = i - c.from_i
break
# Keep track of lowest low and highest high if trailing
if init_trailing:
if curr_follow_ts < min_low:
min_low = curr_follow_ts
elif curr_follow_ts > max_high:
max_high = curr_follow_ts
return last_i
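# A minimal sketch: a 20% (non-trailing) stop below the entry price evaluated through
# `generate_ex_nb`. All price and stop inputs must be 2-dim because of flexible indexing,
# and `stop_ts_out` must have the full target shape:
# >>> price = np.array([[10., 11., 9., 8., 12.]]).T
# >>> entries = np.array([[True], [False], [False], [False], [False]])
# >>> stop_ts_out = np.full_like(price, np.nan)
# >>> exits = generate_ex_nb(entries, stop_place_nb, (
# ...     price, price, price, stop_ts_out, np.array([[-0.2]]), np.array([[False]])))
# >>> np.flatnonzero(exits[:, 0])
# array([3])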
@register_jitted
def ohlc_stop_place_nb(
c: tp.Union[GenExContext, GenEnExContext],
entry_price: tp.FlexArray2d,
open: tp.FlexArray2d,
high: tp.FlexArray2d,
low: tp.FlexArray2d,
close: tp.FlexArray2d,
stop_price_out: tp.Array2d,
stop_type_out: tp.Array2d,
sl_stop: tp.FlexArray2d,
tsl_th: tp.FlexArray2d,
tsl_stop: tp.FlexArray2d,
tp_stop: tp.FlexArray2d,
reverse: tp.FlexArray2d,
is_entry_open: bool = False,
) -> int:
"""`place_func_nb` that places an exit signal whenever a threshold is being hit using OHLC.
Compared to `stop_place_nb`, takes into account the whole bar, can check for both
(trailing) stop loss and take profit simultaneously, and tracks hit price and stop type.
!!! note
Waiting time cannot be higher than 1.
Args:
c (GenExContext or GenEnExContext): Signal context.
entry_price (array of float): Entry price.
Utilizes flexible indexing.
open (array of float): Open price.
Utilizes flexible indexing. If NaN and `is_entry_open` is True, defaults to entry price.
high (array of float): High price.
Utilizes flexible indexing. If NaN, gets calculated from open and close.
low (array of float): Low price.
Utilizes flexible indexing. If NaN, gets calculated from open and close.
close (array of float): Close price.
Utilizes flexible indexing. If NaN and `is_entry_open` is False, defaults to entry price.
stop_price_out (array of float): Array where hit price of each exit will be stored.
Must be of the full shape.
stop_type_out (array of int): Array where stop type of each exit will be stored.
Must be of the full shape. See `vectorbtpro.signals.enums.StopType` for possible values.
sl_stop (array of float): Stop loss as a percentage.
Utilizes flexible indexing. Set an element to `np.nan` to disable.
tsl_th (array of float): Take profit threshold as a percentage for the trailing stop loss.
Utilizes flexible indexing. Set an element to `np.nan` to disable.
tsl_stop (array of float): Trailing stop loss as a percentage.
Utilizes flexible indexing. Set an element to `np.nan` to disable.
tp_stop (array of float): Take profit as a percentage.
Utilizes flexible indexing. Set an element to `np.nan` to disable.
reverse (array of bool): Whether to do the opposite, i.e. prices are followed downwards.
Utilizes flexible indexing.
is_entry_open (bool): Whether entry price comes right at or before open.
If True, uses high and low of the entry bar. Otherwise, uses only close.
"""
if c.wait > 1:
raise ValueError("Wait must be either 0 or 1")
init_i = c.from_i - c.wait
init_entry_price = flex_select_nb(entry_price, init_i, c.col)
init_sl_stop = abs(flex_select_nb(sl_stop, init_i, c.col))
init_tp_stop = abs(flex_select_nb(tp_stop, init_i, c.col))
init_tsl_th = abs(flex_select_nb(tsl_th, init_i, c.col))
init_tsl_stop = abs(flex_select_nb(tsl_stop, init_i, c.col))
init_reverse = flex_select_nb(reverse, init_i, c.col)
last_high = last_low = init_entry_price
last_i = -1
for i in range(c.from_i - c.wait, c.to_i):
# Resolve current bar
_entry_price = flex_select_nb(entry_price, i, c.col)
_open = flex_select_nb(open, i, c.col)
_high = flex_select_nb(high, i, c.col)
_low = flex_select_nb(low, i, c.col)
_close = flex_select_nb(close, i, c.col)
if np.isnan(_open) and not np.isnan(_entry_price) and is_entry_open:
_open = _entry_price
if np.isnan(_close) and not np.isnan(_entry_price) and not is_entry_open:
_close = _entry_price
if np.isnan(_high):
if np.isnan(_open):
_high = _close
elif np.isnan(_close):
_high = _open
else:
_high = max(_open, _close)
if np.isnan(_low):
if np.isnan(_open):
_low = _close
elif np.isnan(_close):
_low = _open
else:
_low = min(_open, _close)
if i > init_i or is_entry_open:
curr_high = _high
curr_low = _low
else:
curr_high = curr_low = _close
if i >= c.from_i:
# Calculate stop prices
if not np.isnan(init_sl_stop):
if init_reverse:
curr_sl_stop_price = init_entry_price * (1 + init_sl_stop)
else:
curr_sl_stop_price = init_entry_price * (1 - init_sl_stop)
if not np.isnan(init_tsl_stop):
if np.isnan(init_tsl_th):
if init_reverse:
curr_tsl_stop_price = last_low * (1 + init_tsl_stop)
else:
curr_tsl_stop_price = last_high * (1 - init_tsl_stop)
else:
if init_reverse:
if last_low <= init_entry_price * (1 - init_tsl_th):
curr_tsl_stop_price = last_low * (1 + init_tsl_stop)
else:
curr_tsl_stop_price = np.nan
else:
if last_high >= init_entry_price * (1 + init_tsl_th):
curr_tsl_stop_price = last_high * (1 - init_tsl_stop)
else:
curr_tsl_stop_price = np.nan
if not np.isnan(init_tp_stop):
if init_reverse:
curr_tp_stop_price = init_entry_price * (1 - init_tp_stop)
else:
curr_tp_stop_price = init_entry_price * (1 + init_tp_stop)
# Check if stop price is within bar
exit_signal = False
if not np.isnan(init_sl_stop):
# SL hit?
stop_price = np.nan
if not init_reverse:
if _open <= curr_sl_stop_price:
stop_price = _open
if curr_low <= curr_sl_stop_price:
stop_price = curr_sl_stop_price
else:
if _open >= curr_sl_stop_price:
stop_price = _open
if curr_high >= curr_sl_stop_price:
stop_price = curr_sl_stop_price
if not np.isnan(stop_price):
stop_price_out[i, c.col] = stop_price
stop_type_out[i, c.col] = StopType.SL
exit_signal = True
if not exit_signal and not np.isnan(init_tsl_stop):
# TSL/TTP hit?
stop_price = np.nan
if not init_reverse:
if _open <= curr_tsl_stop_price:
stop_price = _open
if curr_low <= curr_tsl_stop_price:
stop_price = curr_tsl_stop_price
else:
if _open >= curr_tsl_stop_price:
stop_price = _open
if curr_high >= curr_tsl_stop_price:
stop_price = curr_tsl_stop_price
if not np.isnan(stop_price):
stop_price_out[i, c.col] = stop_price
if np.isnan(init_tsl_th):
stop_type_out[i, c.col] = StopType.TSL
else:
stop_type_out[i, c.col] = StopType.TTP
exit_signal = True
if not exit_signal and not np.isnan(init_tp_stop):
# TP hit?
stop_price = np.nan
if not init_reverse:
if _open >= curr_tp_stop_price:
stop_price = _open
if curr_high >= curr_tp_stop_price:
stop_price = curr_tp_stop_price
else:
if _open <= curr_tp_stop_price:
stop_price = _open
if curr_low <= curr_tp_stop_price:
stop_price = curr_tp_stop_price
if not np.isnan(stop_price):
stop_price_out[i, c.col] = stop_price
stop_type_out[i, c.col] = StopType.TP
exit_signal = True
if exit_signal:
c.out[i - c.from_i] = True
last_i = i - c.from_i
break
if i > init_i or is_entry_open:
# Keep track of the lowest low and the highest high
if curr_low < last_low:
last_low = curr_low
if curr_high > last_high:
last_high = curr_high
return last_i
# ############# Ranking ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(
mask=ch.ArraySlicer(axis=1),
rank_func_nb=None,
rank_args=ch.ArgsTaker(),
reset_by=None,
after_false=None,
after_reset=None,
reset_wait=None,
),
merge_func="column_stack",
)
@register_jitted(tags={"can_parallel"})
def rank_nb(
mask: tp.Array2d,
rank_func_nb: tp.RankFunc,
rank_args: tp.Args = (),
reset_by: tp.Optional[tp.Array2d] = None,
after_false: bool = False,
after_reset: bool = False,
reset_wait: int = 1,
) -> tp.Array2d:
"""Rank each signal using `rank_func_nb`.
Applies `rank_func_nb` on each True value. The function must accept a context of type
`vectorbtpro.signals.enums.RankContext` and must return -1 for no rank, otherwise 0 or greater.
Setting `after_false` to True will disregard the first partition of True values
if there is no False value before them. Setting `after_reset` to True will disregard
the first partition of True values coming before the first reset signal. Setting `reset_wait`
to 0 will treat the signal at the same position as the reset signal as the first signal in
the next partition. Setting it to 1 will treat it as the last signal in the previous partition."""
out = np.full(mask.shape, -1, dtype=int_)
for col in prange(mask.shape[1]):
in_partition = False
false_seen = not after_false
reset_seen = reset_by is None
last_false_i = -1
last_reset_i = -1
all_sig_cnt = 0
all_part_cnt = 0
all_sig_in_part_cnt = 0
nonres_sig_cnt = 0
nonres_part_cnt = 0
nonres_sig_in_part_cnt = 0
sig_cnt = 0
part_cnt = 0
sig_in_part_cnt = 0
for i in range(mask.shape[0]):
if reset_by is not None and reset_by[i, col]:
last_reset_i = i
if last_reset_i > -1 and i - last_reset_i == reset_wait:
reset_seen = True
sig_cnt = 0
part_cnt = 0
sig_in_part_cnt = 0
if mask[i, col]:
all_sig_cnt += 1
if i == 0 or not mask[i - 1, col]:
all_part_cnt += 1
all_sig_in_part_cnt += 1
if not (after_false and not false_seen) and not (after_reset and not reset_seen):
nonres_sig_cnt += 1
sig_cnt += 1
if not in_partition:
nonres_part_cnt += 1
part_cnt += 1
elif last_reset_i > -1 and i - last_reset_i == reset_wait:
part_cnt += 1
nonres_sig_in_part_cnt += 1
sig_in_part_cnt += 1
in_partition = True
c = RankContext(
mask=mask,
reset_by=reset_by,
after_false=after_false,
after_reset=after_reset,
reset_wait=reset_wait,
col=col,
i=i,
last_false_i=last_false_i,
last_reset_i=last_reset_i,
all_sig_cnt=all_sig_cnt,
all_part_cnt=all_part_cnt,
all_sig_in_part_cnt=all_sig_in_part_cnt,
nonres_sig_cnt=nonres_sig_cnt,
nonres_part_cnt=nonres_part_cnt,
nonres_sig_in_part_cnt=nonres_sig_in_part_cnt,
sig_cnt=sig_cnt,
part_cnt=part_cnt,
sig_in_part_cnt=sig_in_part_cnt,
)
out[i, col] = rank_func_nb(c, *rank_args)
else:
all_sig_in_part_cnt = 0
nonres_sig_in_part_cnt = 0
sig_in_part_cnt = 0
last_false_i = i
in_partition = False
false_seen = True
return out
@register_jitted
def sig_pos_rank_nb(c: RankContext, allow_gaps: bool) -> int:
"""`rank_func_nb` that returns the rank of each signal by its position in the partition
if `allow_gaps` is False, otherwise globally.
Resets at each reset signal."""
if allow_gaps:
return c.sig_cnt - 1
return c.sig_in_part_cnt - 1
@register_jitted
def part_pos_rank_nb(c: RankContext) -> int:
"""`rank_func_nb` that returns the rank of each partition by its position in the series.
Resets at each reset signal."""
return c.part_cnt - 1
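# A minimal sketch: rank signals by their position within each partition of consecutive
# True values (zero-based, -1 where there is no signal):
# >>> mask = np.array([[True], [True], [False], [True], [True]])
# >>> rank_nb(mask, sig_pos_rank_nb, (False,))
# array([[ 0],
#        [ 1],
#        [-1],
#        [ 0],
#        [ 1]])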
# ############# Distance ############# #
@register_jitted(cache=True)
def distance_from_last_1d_nb(mask: tp.Array1d, nth: int = 1) -> tp.Array1d:
"""Distance from the last n-th True value to the current value.
Unless `nth` is zero, the current True value isn't counted as one of the last True values."""
if nth < 0:
raise ValueError("nth must be at least 0")
out = np.empty(mask.shape, dtype=int_)
last_indices = np.empty(mask.shape, dtype=int_)
k = 0
for i in range(mask.shape[0]):
if nth == 0:
if mask[i]:
last_indices[k] = i
k += 1
if k - 1 < 0:
out[i] = -1
else:
out[i] = i - last_indices[k - 1]
elif nth == 1:
if k - nth < 0:
out[i] = -1
else:
out[i] = i - last_indices[k - nth]
if mask[i]:
last_indices[k] = i
k += 1
else:
if mask[i]:
last_indices[k] = i
k += 1
if k - nth < 0:
out[i] = -1
else:
out[i] = i - last_indices[k - nth]
return out
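# A minimal sketch: with the default `nth=1`, the distance is measured to the previous
# True value, so a True bar reports the distance to the True value before it:
# >>> distance_from_last_1d_nb(np.array([False, True, False, False, True]))
# array([-1, -1,  1,  2,  3])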
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1), nth=None),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def distance_from_last_nb(mask: tp.Array2d, nth: int = 1) -> tp.Array2d:
"""2-dim version of `distance_from_last_1d_nb`."""
out = np.empty(mask.shape, dtype=int_)
for col in prange(mask.shape[1]):
out[:, col] = distance_from_last_1d_nb(mask[:, col], nth=nth)
return out
# ############# Cleaning ############# #
@register_jitted(cache=True)
def clean_enex_1d_nb(
entries: tp.Array1d,
exits: tp.Array1d,
force_first: bool = True,
keep_conflicts: bool = False,
reverse_order: bool = False,
) -> tp.Tuple[tp.Array1d, tp.Array1d]:
"""Clean entry and exit arrays by picking the first signal out of each.
Set `force_first` to True to force placing the first entry/exit before the first exit/entry.
Set `keep_conflicts` to True to process signals at the same timestamp sequentially instead of removing them.
Set `reverse_order` to True to reverse the order of signals."""
entries_out = np.full(entries.shape, False, dtype=np.bool_)
exits_out = np.full(exits.shape, False, dtype=np.bool_)
def _process_entry(i, phase):
if ((not force_first or not reverse_order) and phase == -1) or phase == 1:
phase = 0
entries_out[i] = True
return phase
def _process_exit(i, phase):
if ((not force_first or reverse_order) and phase == -1) or phase == 0:
phase = 1
exits_out[i] = True
return phase
phase = -1
for i in range(entries.shape[0]):
if entries[i] and exits[i]:
if keep_conflicts:
if not reverse_order:
phase = _process_entry(i, phase)
phase = _process_exit(i, phase)
else:
phase = _process_exit(i, phase)
phase = _process_entry(i, phase)
elif entries[i]:
phase = _process_entry(i, phase)
elif exits[i]:
phase = _process_exit(i, phase)
return entries_out, exits_out
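# A minimal sketch: with the defaults, only the first signal of each alternating
# entry/exit run is kept:
# >>> entries = np.array([True, True, False, False, True])
# >>> exits = np.array([False, False, True, True, False])
# >>> new_entries, new_exits = clean_enex_1d_nb(entries, exits)
# >>> np.flatnonzero(new_entries), np.flatnonzero(new_exits)
# (array([0, 4]), array([2]))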
@register_chunkable(
size=ch.ArraySizer(arg_query="entries", axis=1),
arg_take_spec=dict(
entries=ch.ArraySlicer(axis=1),
exits=ch.ArraySlicer(axis=1),
force_first=None,
keep_conflicts=None,
reverse_order=None,
),
merge_func="column_stack",
)
@register_jitted(cache=True, tags={"can_parallel"})
def clean_enex_nb(
entries: tp.Array2d,
exits: tp.Array2d,
force_first: bool = True,
keep_conflicts: bool = False,
reverse_order: bool = False,
) -> tp.Tuple[tp.Array2d, tp.Array2d]:
"""2-dim version of `clean_enex_1d_nb`."""
entries_out = np.empty(entries.shape, dtype=np.bool_)
exits_out = np.empty(exits.shape, dtype=np.bool_)
for col in prange(entries.shape[1]):
entries_out[:, col], exits_out[:, col] = clean_enex_1d_nb(
entries[:, col],
exits[:, col],
force_first=force_first,
keep_conflicts=keep_conflicts,
reverse_order=reverse_order,
)
return entries_out, exits_out
# ############# Relation ############# #
@register_jitted(cache=True)
def relation_idxs_1d_nb(
source_mask: tp.Array1d,
target_mask: tp.Array1d,
relation: int = SignalRelation.OneMany,
) -> tp.Tuple[tp.Array1d, tp.Array1d, tp.Array1d, tp.Array1d]:
"""Get index pairs of True values between a source and target mask.
For `relation`, see `vectorbtpro.signals.enums.SignalRelation`.
!!! note
        If both True values happen at the same time, the source signal is assumed to come first."""
if relation == SignalRelation.Chain or relation == SignalRelation.AnyChain:
max_signals = source_mask.shape[0] * 2
else:
max_signals = source_mask.shape[0]
source_range_out = np.full(max_signals, -1, dtype=int_)
target_range_out = np.full(max_signals, -1, dtype=int_)
source_idxs_out = np.full(max_signals, -1, dtype=int_)
target_idxs_out = np.full(max_signals, -1, dtype=int_)
source_j = -1
target_j = -1
k = 0
if relation == SignalRelation.OneOne:
fill_k = -1
for i in range(source_mask.shape[0]):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if source_mask[i]:
source_range_out[k] = source_j
source_idxs_out[k] = i
if fill_k == -1:
fill_k = k
k += 1
if target_mask[i]:
if fill_k == -1:
target_range_out[k] = target_j
target_idxs_out[k] = i
k += 1
else:
target_range_out[fill_k] = target_j
target_idxs_out[fill_k] = i
fill_k += 1
if fill_k == k:
fill_k = -1
elif relation == SignalRelation.OneMany:
source_idx = -1
source_placed = True
for i in range(source_mask.shape[0]):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if source_mask[i]:
if not source_placed:
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
k += 1
source_idx = i
source_placed = False
if target_mask[i]:
if source_idx == -1:
target_range_out[k] = target_j
target_idxs_out[k] = i
k += 1
else:
source_range_out[k] = source_j
target_range_out[k] = target_j
source_idxs_out[k] = source_idx
target_idxs_out[k] = i
k += 1
source_placed = True
if not source_placed:
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
k += 1
elif relation == SignalRelation.ManyOne:
target_idx = -1
target_placed = True
for i in range(source_mask.shape[0] - 1, -1, -1):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if target_mask[i]:
if not target_placed:
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
k += 1
target_idx = i
target_placed = False
if source_mask[i]:
if target_idx == -1:
source_range_out[k] = source_j
source_idxs_out[k] = i
k += 1
else:
source_range_out[k] = source_j
target_range_out[k] = target_j
source_idxs_out[k] = i
target_idxs_out[k] = target_idx
k += 1
target_placed = True
if not target_placed:
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
k += 1
source_range_out[:k] = source_range_out[:k][::-1]
target_range_out[:k] = target_range_out[:k][::-1]
for _k in range(k):
source_range_out[_k] = source_j - source_range_out[_k]
target_range_out[_k] = target_j - target_range_out[_k]
source_idxs_out[:k] = source_idxs_out[:k][::-1]
target_idxs_out[:k] = target_idxs_out[:k][::-1]
elif relation == SignalRelation.ManyMany:
source_idx = -1
from_k = -1
for i in range(source_mask.shape[0]):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if source_mask[i]:
source_idx = i
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
if from_k == -1:
from_k = k
k += 1
if target_mask[i]:
if from_k == -1:
if source_idx != -1:
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
target_range_out[k] = target_j
target_idxs_out[k] = i
k += 1
else:
for _k in range(from_k, k):
target_range_out[_k] = target_j
target_idxs_out[_k] = i
from_k = -1
elif relation == SignalRelation.Chain:
source_idx = -1
target_idx = -1
any_placed = False
for i in range(source_mask.shape[0]):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if source_mask[i]:
if source_idx == -1:
source_idx = i
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
any_placed = True
if target_idx != -1:
target_idx = -1
k += 1
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
if target_mask[i]:
if target_idx == -1 and source_idx != -1:
target_idx = i
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
source_idx = -1
k += 1
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
any_placed = True
if any_placed:
k += 1
elif relation == SignalRelation.AnyChain:
source_idx = -1
target_idx = -1
any_placed = False
for i in range(source_mask.shape[0]):
if source_mask[i]:
source_j += 1
if target_mask[i]:
target_j += 1
if source_mask[i]:
if source_idx == -1:
source_idx = i
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
any_placed = True
if target_idx != -1:
target_idx = -1
k += 1
source_range_out[k] = source_j
source_idxs_out[k] = source_idx
if target_mask[i]:
if target_idx == -1:
target_idx = i
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
any_placed = True
if source_idx != -1:
source_idx = -1
k += 1
target_range_out[k] = target_j
target_idxs_out[k] = target_idx
if any_placed:
k += 1
else:
raise ValueError("Invalid SignalRelation option")
return source_range_out[:k], target_range_out[:k], source_idxs_out[:k], target_idxs_out[:k]
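# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# Under `SignalRelation.OneMany`, one source signal may be paired with several
# subsequent target signals. `_example_relation_idxs` is a hypothetical helper
# introduced only for illustration.
def _example_relation_idxs():
    import numpy as np

    source = np.array([True, False, False, True, False])
    target = np.array([False, True, True, False, True])
    _, _, source_idxs, target_idxs = relation_idxs_1d_nb(source, target, relation=SignalRelation.OneMany)
    # Expected pairs (traced by hand): source rows [0, 0, 3] map to target rows [1, 2, 4]
    return source_idxs, target_idxs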
# ############# Ranges ############# #
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1), incl_open=None),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def between_ranges_nb(mask: tp.Array2d, incl_open: bool = False) -> tp.RecordArray:
"""Create a record of type `vectorbtpro.generic.enums.range_dt` for each range between two signals in `mask`."""
new_records = np.empty(mask.shape, dtype=range_dt)
counts = np.full(mask.shape[1], 0, dtype=int_)
for col in prange(mask.shape[1]):
from_i = -1
for i in range(mask.shape[0]):
if mask[i, col]:
if from_i > -1:
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = from_i
new_records["end_idx"][r, col] = i
new_records["status"][r, col] = RangeStatus.Closed
counts[col] += 1
from_i = i
if incl_open and from_i < mask.shape[0] - 1:
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = from_i
new_records["end_idx"][r, col] = mask.shape[0] - 1
new_records["status"][r, col] = RangeStatus.Open
counts[col] += 1
return generic_nb.repartition_nb(new_records, counts)
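# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# `between_ranges_nb` expects a 2-dim mask and emits one `range_dt` record per pair
# of consecutive signals in each column. `_example_between_ranges` is a hypothetical
# helper introduced only for illustration.
def _example_between_ranges():
    import numpy as np

    mask = np.array([False, True, False, True, False, False, True])[:, None]
    records = between_ranges_nb(mask)
    # Two closed ranges for the single column: rows 1 -> 3 and rows 3 -> 6
    return records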
@register_chunkable(
size=ch.ArraySizer(arg_query="source_mask", axis=1),
arg_take_spec=dict(
source_mask=ch.ArraySlicer(axis=1),
target_mask=ch.ArraySlicer(axis=1),
relation=None,
incl_open=None,
),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def between_two_ranges_nb(
source_mask: tp.Array2d,
target_mask: tp.Array2d,
relation: int = SignalRelation.OneMany,
incl_open: bool = False,
) -> tp.RecordArray:
"""Create a record of type `vectorbtpro.generic.enums.range_dt` for each range
between a source and target mask.
Index pairs are resolved with `relation_idxs_1d_nb`."""
new_records = np.empty(source_mask.shape, dtype=range_dt)
counts = np.full(source_mask.shape[1], 0, dtype=int_)
for col in prange(source_mask.shape[1]):
        _, _, source_idxs, target_idxs = relation_idxs_1d_nb(
            source_mask[:, col],
            target_mask[:, col],
            relation=relation,
        )
        for i in range(len(source_idxs)):
            if source_idxs[i] != -1 and target_idxs[i] != -1:
                r = counts[col]
                new_records["id"][r, col] = r
                new_records["col"][r, col] = col
                new_records["start_idx"][r, col] = source_idxs[i]
                new_records["end_idx"][r, col] = target_idxs[i]
                new_records["status"][r, col] = RangeStatus.Closed
                counts[col] += 1
            elif source_idxs[i] != -1 and target_idxs[i] == -1 and incl_open:
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = source_idxs[i]
new_records["end_idx"][r, col] = source_mask.shape[0] - 1
new_records["status"][r, col] = RangeStatus.Open
counts[col] += 1
return generic_nb.repartition_nb(new_records, counts)
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1)),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def partition_ranges_nb(mask: tp.Array2d) -> tp.RecordArray:
"""Create a record of type `vectorbtpro.generic.enums.range_dt` for each partition of signals in `mask`."""
new_records = np.empty(mask.shape, dtype=range_dt)
counts = np.full(mask.shape[1], 0, dtype=int_)
for col in prange(mask.shape[1]):
is_partition = False
from_i = -1
for i in range(mask.shape[0]):
if mask[i, col]:
if not is_partition:
from_i = i
is_partition = True
elif is_partition:
to_i = i
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = from_i
new_records["end_idx"][r, col] = to_i
new_records["status"][r, col] = RangeStatus.Closed
counts[col] += 1
is_partition = False
if i == mask.shape[0] - 1:
if is_partition:
to_i = mask.shape[0] - 1
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = from_i
new_records["end_idx"][r, col] = to_i
new_records["status"][r, col] = RangeStatus.Open
counts[col] += 1
return generic_nb.repartition_nb(new_records, counts)
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1)),
merge_func=records_ch.merge_records,
merge_kwargs=dict(chunk_meta=Rep("chunk_meta")),
)
@register_jitted(cache=True, tags={"can_parallel"})
def between_partition_ranges_nb(mask: tp.Array2d) -> tp.RecordArray:
"""Create a record of type `vectorbtpro.generic.enums.range_dt` for each range between two partitions in `mask`."""
new_records = np.empty(mask.shape, dtype=range_dt)
counts = np.full(mask.shape[1], 0, dtype=int_)
for col in prange(mask.shape[1]):
is_partition = False
from_i = -1
for i in range(mask.shape[0]):
if mask[i, col]:
if not is_partition and from_i != -1:
to_i = i
r = counts[col]
new_records["id"][r, col] = r
new_records["col"][r, col] = col
new_records["start_idx"][r, col] = from_i
new_records["end_idx"][r, col] = to_i
new_records["status"][r, col] = RangeStatus.Closed
counts[col] += 1
is_partition = True
from_i = i
else:
is_partition = False
return generic_nb.repartition_nb(new_records, counts)
# ############# Raveling ############# #
@register_jitted(cache=True)
def unravel_nb(
mask: tp.Array2d,
incl_empty_cols: bool = True,
) -> tp.Tuple[
tp.Array2d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
]:
"""Unravel each True value in a mask to a separate column.
Returns the new mask, the index of each True value in its column, the row index of each
True value in its column, and the column index of each True value in the original mask."""
true_idxs = np.flatnonzero(mask.transpose())
start_idxs = np.full(mask.shape[1], -1, dtype=int_)
end_idxs = np.full(mask.shape[1], 0, dtype=int_)
for i in range(len(true_idxs)):
col = true_idxs[i] // mask.shape[0]
if i == 0:
prev_col = -1
else:
prev_col = true_idxs[i - 1] // mask.shape[0]
if col != prev_col:
start_idxs[col] = i
end_idxs[col] = i + 1
n_cols = (end_idxs - start_idxs).sum()
new_mask = np.full((mask.shape[0], n_cols), False, dtype=np.bool_)
range_ = np.full(n_cols, -1, dtype=int_)
row_idxs = np.full(n_cols, -1, dtype=int_)
col_idxs = np.empty(n_cols, dtype=int_)
k = 0
for i in range(len(start_idxs)):
start_idx = start_idxs[i]
end_idx = end_idxs[i]
col_filled = False
if start_idx != -1:
for j in range(start_idx, end_idx):
new_mask[true_idxs[j] % mask.shape[0], k] = True
range_[k] = j - start_idx
row_idxs[k] = true_idxs[j] % mask.shape[0]
col_idxs[k] = true_idxs[j] // mask.shape[0]
k += 1
col_filled = True
if not col_filled and incl_empty_cols:
if k == 0:
col_idxs[k] = 0
else:
col_idxs[k] = col_idxs[k - 1] + 1
k += 1
return new_mask[:, :k], range_[:k], row_idxs[:k], col_idxs[:k]
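# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# Each True value is moved into its own column of the new mask, while `col_idxs`
# records which original column it came from. `_example_unravel` is a hypothetical
# helper introduced only for illustration.
def _example_unravel():
    import numpy as np

    mask = np.array([
        [True, False],
        [True, True],
        [False, False],
    ])
    new_mask, rank, row_idxs, col_idxs = unravel_nb(mask)
    # new_mask has 3 columns (one per True value); col_idxs maps them back to [0, 0, 1]
    return new_mask, rank, row_idxs, col_idxs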
@register_jitted(cache=True)
def unravel_between_nb(
mask: tp.Array2d,
incl_open_source: bool = False,
incl_empty_cols: bool = True,
) -> tp.Tuple[
tp.Array2d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
]:
"""Unravel each pair of successive True values in a mask to a separate column.
Returns the new mask, the index of each source True value in its column, the index of
each target True value in its column, the row index of each source True value in the original
mask, the row index of each target True value in the original mask, and the column index of
each True value in the original mask."""
true_idxs = np.flatnonzero(mask.transpose())
start_idxs = np.full(mask.shape[1], -1, dtype=int_)
end_idxs = np.full(mask.shape[1], 0, dtype=int_)
for i in range(len(true_idxs)):
col = true_idxs[i] // mask.shape[0]
if i == 0:
prev_col = -1
else:
prev_col = true_idxs[i - 1] // mask.shape[0]
if col != prev_col:
start_idxs[col] = i
end_idxs[col] = i + 1
n_cols = (end_idxs - start_idxs).sum()
new_mask = np.full((mask.shape[0], n_cols), False, dtype=np.bool_)
source_range = np.full(n_cols, -1, dtype=int_)
target_range = np.full(n_cols, -1, dtype=int_)
source_idxs = np.full(n_cols, -1, dtype=int_)
target_idxs = np.full(n_cols, -1, dtype=int_)
col_idxs = np.empty(n_cols, dtype=int_)
k = 0
for i in range(len(start_idxs)):
start_idx = start_idxs[i]
end_idx = end_idxs[i]
col_filled = False
if start_idx != -1:
for j in range(start_idx, end_idx):
if j == end_idx - 1 and not incl_open_source:
continue
new_mask[true_idxs[j] % mask.shape[0], k] = True
source_range[k] = j - start_idx
source_idxs[k] = true_idxs[j] % mask.shape[0]
if j < end_idx - 1:
new_mask[true_idxs[j + 1] % mask.shape[0], k] = True
target_range[k] = j + 1 - start_idx
target_idxs[k] = true_idxs[j + 1] % mask.shape[0]
col_idxs[k] = true_idxs[j] // mask.shape[0]
k += 1
col_filled = True
if not col_filled and incl_empty_cols:
if k == 0:
col_idxs[k] = 0
else:
col_idxs[k] = col_idxs[k - 1] + 1
k += 1
return (
new_mask[:, :k],
source_range[:k],
target_range[:k],
source_idxs[:k],
target_idxs[:k],
col_idxs[:k],
)
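# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# Each pair of successive True values becomes its own column holding exactly two
# True values (source and target). `_example_unravel_between` is a hypothetical
# helper introduced only for illustration.
def _example_unravel_between():
    import numpy as np

    mask = np.array([True, False, True, True])[:, None]
    new_mask, src_rank, tgt_rank, src_idxs, tgt_idxs, col_idxs = unravel_between_nb(mask)
    # Expected: 2 columns, pairing rows (0, 2) and (2, 3); the final True value at row 3
    # has no following signal and is skipped since incl_open_source is False by default
    return new_mask, src_idxs, tgt_idxs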
@register_jitted(cache=True)
def unravel_between_two_nb(
source_mask: tp.Array2d,
target_mask: tp.Array2d,
relation: int = SignalRelation.OneMany,
incl_open_source: bool = False,
incl_open_target: bool = False,
incl_empty_cols: bool = True,
) -> tp.Tuple[
tp.Array2d,
tp.Array2d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
tp.Array1d,
]:
"""Unravel each pair of successive True values between a source and target mask to a separate column.
Index pairs are resolved with `relation_idxs_1d_nb`.
Returns the new source mask, the new target mask, the index of each source True value in its column,
the index of each target True value in its column, the row index of each True value in each
original mask, and the column index of each True value in both original masks."""
if relation == SignalRelation.Chain or relation == SignalRelation.AnyChain:
max_signals = source_mask.shape[0] * 2
else:
max_signals = source_mask.shape[0]
source_range_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_)
target_range_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_)
source_idxs_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_)
target_idxs_2d = np.empty((max_signals, source_mask.shape[1]), dtype=int_)
counts = np.empty(source_mask.shape[1], dtype=int_)
n_cols = 0
for col in range(source_mask.shape[1]):
source_range_col, target_range_col, source_idxs_col, target_idxs_col = relation_idxs_1d_nb(
source_mask[:, col],
target_mask[:, col],
relation=relation,
)
n_idxs = len(source_idxs_col)
source_range_2d[:n_idxs, col] = source_range_col
target_range_2d[:n_idxs, col] = target_range_col
source_idxs_2d[:n_idxs, col] = source_idxs_col
target_idxs_2d[:n_idxs, col] = target_idxs_col
counts[col] = n_idxs
if n_idxs == 0:
n_cols += 1
else:
n_cols += n_idxs
new_source_mask = np.full((source_mask.shape[0], n_cols), False, dtype=np.bool_)
new_target_mask = np.full((source_mask.shape[0], n_cols), False, dtype=np.bool_)
source_range = np.full(n_cols, -1, dtype=int_)
target_range = np.full(n_cols, -1, dtype=int_)
source_idxs = np.full(n_cols, -1, dtype=int_)
target_idxs = np.full(n_cols, -1, dtype=int_)
col_idxs = np.empty(n_cols, dtype=int_)
k = 0
for c in range(len(counts)):
col_filled = False
if counts[c] > 0:
for j in range(counts[c]):
source_idx = source_idxs_2d[j, c]
target_idx = target_idxs_2d[j, c]
if source_idx != -1 and target_idx != -1:
new_source_mask[source_idx, k] = True
new_target_mask[target_idx, k] = True
source_range[k] = source_range_2d[j, c]
target_range[k] = target_range_2d[j, c]
source_idxs[k] = source_idx
target_idxs[k] = target_idx
col_idxs[k] = c
k += 1
col_filled = True
elif source_idx != -1 and incl_open_source:
new_source_mask[source_idx, k] = True
source_range[k] = source_range_2d[j, c]
source_idxs[k] = source_idx
col_idxs[k] = c
k += 1
col_filled = True
elif target_idx != -1 and incl_open_target:
new_target_mask[target_idx, k] = True
target_range[k] = target_range_2d[j, c]
target_idxs[k] = target_idx
col_idxs[k] = c
k += 1
col_filled = True
if not col_filled and incl_empty_cols:
col_idxs[k] = c
k += 1
return (
new_source_mask[:, :k],
new_target_mask[:, :k],
source_range[:k],
target_range[:k],
source_idxs[:k],
target_idxs[:k],
col_idxs[:k],
)
@register_jitted(cache=True, tags={"can_parallel"})
def ravel_nb(mask: tp.Array2d, group_map: tp.GroupMap) -> tp.Array2d:
"""Ravel True values of each group into a separate column."""
group_idxs, group_lens = group_map
group_start_idxs = np.cumsum(group_lens) - group_lens
out = np.full((mask.shape[0], len(group_lens)), False, dtype=np.bool_)
for group in prange(len(group_lens)):
group_len = group_lens[group]
start_idx = group_start_idxs[group]
col_idxs = group_idxs[start_idx : start_idx + group_len]
for col in col_idxs:
for i in range(mask.shape[0]):
if mask[i, col]:
out[i, group] = True
return out
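# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# `group_map` is a tuple of (column indices, group lengths); all True values of the
# columns in a group are OR-ed into a single output column. `_example_ravel` is a
# hypothetical helper introduced only for illustration.
def _example_ravel():
    import numpy as np

    mask = np.array([
        [True, False, False],
        [False, True, False],
        [False, False, True],
    ])
    group_idxs = np.array([0, 1, 2])
    group_lens = np.array([2, 1])  # first group: columns 0 and 1, second group: column 2
    out = ravel_nb(mask, (group_idxs, group_lens))
    # Expected: column 0 is True at rows 0 and 1, column 1 is True at row 2
    return out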
# ############# Index ############# #
@register_jitted(cache=True)
def nth_index_1d_nb(mask: tp.Array1d, n: int) -> int:
"""Get the index of the n-th True value.
!!! note
`n` starts with 0 and can be negative."""
if n >= 0:
found = -1
for i in range(mask.shape[0]):
if mask[i]:
found += 1
if found == n:
return i
else:
found = 0
for i in range(mask.shape[0] - 1, -1, -1):
if mask[i]:
found -= 1
if found == n:
return i
return -1
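# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# `n` is zero-based and may be negative to count from the end; -1 is returned when
# there aren't enough True values. `_example_nth_index` is a hypothetical helper
# introduced only for illustration.
def _example_nth_index():
    import numpy as np

    mask = np.array([False, True, False, True, True])
    first = nth_index_1d_nb(mask, 0)    # -> 1
    second = nth_index_1d_nb(mask, 1)   # -> 3
    last = nth_index_1d_nb(mask, -1)    # -> 4
    missing = nth_index_1d_nb(mask, 5)  # -> -1
    return first, second, last, missing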
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1), n=None),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def nth_index_nb(mask: tp.Array2d, n: int) -> tp.Array1d:
"""2-dim version of `nth_index_1d_nb`."""
out = np.empty(mask.shape[1], dtype=int_)
for col in prange(mask.shape[1]):
out[col] = nth_index_1d_nb(mask[:, col], n)
return out
@register_jitted(cache=True)
def norm_avg_index_1d_nb(mask: tp.Array1d) -> float:
"""Get mean index normalized to (-1, 1)."""
mean_index = np.mean(np.flatnonzero(mask))
return rescale_nb(mean_index, (0, len(mask) - 1), (-1, 1))
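# --- Illustrative usage sketch (editor addition; not part of the original nb.py) ---
# The mean position of True values is linearly rescaled so that the first row maps
# to -1 and the last row to 1. `_example_norm_avg_index` is a hypothetical helper
# introduced only for illustration.
def _example_norm_avg_index():
    import numpy as np

    early = np.array([True, False, False, False, False])
    centered = np.array([True, False, False, False, True])
    late = np.array([False, False, False, False, True])
    # Expected: -1.0 (all signals at the start), 0.0 (balanced), 1.0 (all at the end)
    return (
        norm_avg_index_1d_nb(early),
        norm_avg_index_1d_nb(centered),
        norm_avg_index_1d_nb(late),
    )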
@register_chunkable(
size=ch.ArraySizer(arg_query="mask", axis=1),
arg_take_spec=dict(mask=ch.ArraySlicer(axis=1)),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def norm_avg_index_nb(mask: tp.Array2d) -> tp.Array1d:
"""2-dim version of `norm_avg_index_1d_nb`."""
out = np.empty(mask.shape[1], dtype=float_)
for col in prange(mask.shape[1]):
out[col] = norm_avg_index_1d_nb(mask[:, col])
return out
@register_chunkable(
size=ch.ArraySizer(arg_query="group_lens", axis=0),
arg_take_spec=dict(
mask=base_ch.array_gl_slicer,
group_lens=ch.ArraySlicer(axis=0),
),
merge_func="concat",
)
@register_jitted(cache=True, tags={"can_parallel"})
def norm_avg_index_grouped_nb(mask: tp.Array2d, group_lens: tp.GroupLens) -> tp.Array1d:
"""Grouped version of `norm_avg_index_nb`."""
out = np.empty(len(group_lens), dtype=float_)
group_end_idxs = np.cumsum(group_lens)
group_start_idxs = group_end_idxs - group_lens
for group in prange(len(group_lens)):
from_col = group_start_idxs[group]
to_col = group_end_idxs[group]
temp_sum = 0
temp_cnt = 0
for col in range(from_col, to_col):
for i in range(mask.shape[0]):
if mask[i, col]:
temp_sum += i
temp_cnt += 1
out[group] = rescale_nb(temp_sum / temp_cnt, (0, mask.shape[0] - 1), (-1, 1))
return out
{
"data": {
"histogram2dcontour": [
{
"type": "histogram2dcontour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"choropleth": [
{
"type": "choropleth",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"histogram2d": [
{
"type": "histogram2d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"heatmap": [
{
"type": "heatmap",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"heatmapgl": [
{
"type": "heatmapgl",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"contourcarpet": [
{
"type": "contourcarpet",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"contour": [
{
"type": "contour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"surface": [
{
"type": "surface",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"mesh3d": [
{
"type": "mesh3d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"scatter": [
{
"marker": {
"line": {
"color": "#313439"
}
},
"type": "scatter"
}
],
"parcoords": [
{
"type": "parcoords",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolargl": [
{
"type": "scatterpolargl",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"bar": [
{
"error_x": {
"color": "#d6dfef"
},
"error_y": {
"color": "#d6dfef"
},
"marker": {
"line": {
"color": "#232428",
"width": 0.5
}
},
"type": "bar"
}
],
"scattergeo": [
{
"type": "scattergeo",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolar": [
{
"type": "scatterpolar",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"histogram": [
{
"type": "histogram",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattergl": [
{
"marker": {
"line": {
"color": "#313439"
}
},
"type": "scattergl"
}
],
"scatter3d": [
{
"type": "scatter3d",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattermapbox": [
{
"type": "scattermapbox",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterternary": [
{
"type": "scatterternary",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattercarpet": [
{
"type": "scattercarpet",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#A2B1C6",
"gridcolor": "#313439",
"linecolor": "#313439",
"minorgridcolor": "#313439",
"startlinecolor": "#A2B1C6"
},
"baxis": {
"endlinecolor": "#A2B1C6",
"gridcolor": "#313439",
"linecolor": "#313439",
"minorgridcolor": "#313439",
"startlinecolor": "#A2B1C6"
},
"type": "carpet"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#313439"
},
"line": {
"color": "#232428"
}
},
"header": {
"fill": {
"color": "#2a3f5f"
},
"line": {
"color": "#232428"
}
},
"type": "table"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#232428",
"width": 0.5
}
},
"type": "barpolar"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
]
},
"layout": {
"colorway": [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#dc3912",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf"
],
"font": {
"color": "#d6dfef"
},
"hovermode": "closest",
"hoverlabel": {
"align": "left"
},
"paper_bgcolor": "#232428",
"plot_bgcolor": "#232428",
"polar": {
"bgcolor": "#232428",
"angularaxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": ""
},
"radialaxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": ""
}
},
"ternary": {
"bgcolor": "#232428",
"aaxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": ""
},
"baxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": ""
},
"caxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": ""
}
},
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"sequential": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"sequentialminus": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
]
},
"xaxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "#313439",
"automargin": true,
"zerolinewidth": 2
},
"yaxis": {
"gridcolor": "#313439",
"linecolor": "#313439",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "#313439",
"automargin": true,
"zerolinewidth": 2
},
"scene": {
"xaxis": {
"backgroundcolor": "#232428",
"gridcolor": "#313439",
"linecolor": "#313439",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#aec0d6",
"gridwidth": 2
},
"yaxis": {
"backgroundcolor": "#232428",
"gridcolor": "#313439",
"linecolor": "#313439",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#aec0d6",
"gridwidth": 2
},
"zaxis": {
"backgroundcolor": "#232428",
"gridcolor": "#313439",
"linecolor": "#313439",
"showbackground": true,
"ticks": "",
"zerolinecolor": "#aec0d6",
"gridwidth": 2
}
},
"shapedefaults": {
"line": {
"color": "#d6dfef"
}
},
"annotationdefaults": {
"arrowcolor": "#d6dfef",
"arrowhead": 0,
"arrowwidth": 1
},
"geo": {
"bgcolor": "#232428",
"landcolor": "#232428",
"subunitcolor": "#313439",
"showland": true,
"showlakes": true,
"lakecolor": "#232428"
},
"title": {
"x": 0.05
},
"updatemenudefaults": {
"bgcolor": "#313439",
"borderwidth": 0
},
"sliderdefaults": {
"bgcolor": "#aec0d6",
"borderwidth": 1,
"bordercolor": "#232428",
"tickwidth": 0
},
"mapbox": {
"style": "dark"
}
}
}
{
"data": {
"histogram2dcontour": [
{
"type": "histogram2dcontour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"choropleth": [
{
"type": "choropleth",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"histogram2d": [
{
"type": "histogram2d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"heatmap": [
{
"type": "heatmap",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"heatmapgl": [
{
"type": "heatmapgl",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"contourcarpet": [
{
"type": "contourcarpet",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"contour": [
{
"type": "contour",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"surface": [
{
"type": "surface",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
},
"colorscale": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
]
}
],
"mesh3d": [
{
"type": "mesh3d",
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
],
"scatter": [
{
"type": "scatter",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"parcoords": [
{
"type": "parcoords",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolargl": [
{
"type": "scatterpolargl",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"bar": [
{
"error_x": {
"color": "#2a3f5f"
},
"error_y": {
"color": "#2a3f5f"
},
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
}
},
"type": "bar"
}
],
"scattergeo": [
{
"type": "scattergeo",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterpolar": [
{
"type": "scatterpolar",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"histogram": [
{
"type": "histogram",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattergl": [
{
"type": "scattergl",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatter3d": [
{
"type": "scatter3d",
"line": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattermapbox": [
{
"type": "scattermapbox",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scatterternary": [
{
"type": "scatterternary",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"scattercarpet": [
{
"type": "scattercarpet",
"marker": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
}
}
],
"carpet": [
{
"aaxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"baxis": {
"endlinecolor": "#2a3f5f",
"gridcolor": "white",
"linecolor": "white",
"minorgridcolor": "white",
"startlinecolor": "#2a3f5f"
},
"type": "carpet"
}
],
"table": [
{
"cells": {
"fill": {
"color": "#EBF0F8"
},
"line": {
"color": "white"
}
},
"header": {
"fill": {
"color": "#C8D4E3"
},
"line": {
"color": "white"
}
},
"type": "table"
}
],
"barpolar": [
{
"marker": {
"line": {
"color": "#E5ECF6",
"width": 0.5
}
},
"type": "barpolar"
}
],
"pie": [
{
"automargin": true,
"type": "pie"
}
]
},
"layout": {
"colorway": [
"#1f77b4",
"#ff7f0e",
"#2ca02c",
"#dc3912",
"#9467bd",
"#8c564b",
"#e377c2",
"#7f7f7f",
"#bcbd22",
"#17becf"
],
"font": {
"color": "#2a3f5f"
},
"hovermode": "closest",
"hoverlabel": {
"align": "left"
},
"paper_bgcolor": "white",
"plot_bgcolor": "#E5ECF6",
"polar": {
"bgcolor": "#E5ECF6",
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"ternary": {
"bgcolor": "#E5ECF6",
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"sequential": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"sequentialminus": [
[
0.0,
"#0d0887"
],
[
0.1111111111111111,
"#46039f"
],
[
0.2222222222222222,
"#7201a8"
],
[
0.3333333333333333,
"#9c179e"
],
[
0.4444444444444444,
"#bd3786"
],
[
0.5555555555555556,
"#d8576b"
],
[
0.6666666666666666,
"#ed7953"
],
[
0.7777777777777778,
"#fb9f3a"
],
[
0.8888888888888888,
"#fdca26"
],
[
1.0,
"#f0f921"
]
],
"diverging": [
[
0,
"#8e0152"
],
[
0.1,
"#c51b7d"
],
[
0.2,
"#de77ae"
],
[
0.3,
"#f1b6da"
],
[
0.4,
"#fde0ef"
],
[
0.5,
"#f7f7f7"
],
[
0.6,
"#e6f5d0"
],
[
0.7,
"#b8e186"
],
[
0.8,
"#7fbc41"
],
[
0.9,
"#4d9221"
],
[
1,
"#276419"
]
]
},
"xaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"automargin": true,
"zerolinewidth": 2
},
"yaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"automargin": true,
"zerolinewidth": 2
},
"scene": {
"xaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
},
"yaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
},
"zaxis": {
"backgroundcolor": "#E5ECF6",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
}
},
"shapedefaults": {
"line": {
"color": "#2a3f5f"
}
},
"annotationdefaults": {
"arrowcolor": "#2a3f5f",
"arrowhead": 0,
"arrowwidth": 1
},
"geo": {
"bgcolor": "white",
"landcolor": "#E5ECF6",
"subunitcolor": "white",
"showland": true,
"showlakes": true,
"lakecolor": "white"
},
"title": {
"x": 0.05
},
"mapbox": {
"style": "light"
}
}
}
{
"layout": {
"colorway": [
"rgb(76,114,176)",
"rgb(221,132,82)",
"rgb(85,168,104)",
"rgb(196,78,82)",
"rgb(129,114,179)",
"rgb(147,120,96)",
"rgb(218,139,195)",
"rgb(140,140,140)",
"rgb(204,185,116)",
"rgb(100,181,205)"
],
"font": {
"color": "rgb(36,36,36)"
},
"hovermode": "closest",
"hoverlabel": {
"align": "left"
},
"paper_bgcolor": "white",
"plot_bgcolor": "rgb(234,234,242)",
"polar": {
"bgcolor": "rgb(234,234,242)",
"angularaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"radialaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"ternary": {
"bgcolor": "rgb(234,234,242)",
"aaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"baxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
},
"caxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": ""
}
},
"coloraxis": {
"colorbar": {
"outlinewidth": 0,
"ticks": ""
}
},
"colorscale": {
"sequential": [
[
0.0,
"rgb(2,4,25)"
],
[
0.06274509803921569,
"rgb(24,15,41)"
],
[
0.12549019607843137,
"rgb(47,23,57)"
],
[
0.18823529411764706,
"rgb(71,28,72)"
],
[
0.25098039215686274,
"rgb(97,30,82)"
],
[
0.3137254901960784,
"rgb(123,30,89)"
],
[
0.3764705882352941,
"rgb(150,27,91)"
],
[
0.4392156862745098,
"rgb(177,22,88)"
],
[
0.5019607843137255,
"rgb(203,26,79)"
],
[
0.5647058823529412,
"rgb(223,47,67)"
],
[
0.6274509803921569,
"rgb(236,76,61)"
],
[
0.6901960784313725,
"rgb(242,107,73)"
],
[
0.7529411764705882,
"rgb(244,135,95)"
],
[
0.8156862745098039,
"rgb(245,162,122)"
],
[
0.8784313725490196,
"rgb(246,188,153)"
],
[
0.9411764705882353,
"rgb(247,212,187)"
],
[
1.0,
"rgb(250,234,220)"
]
],
"sequentialminus": [
[
0.0,
"rgb(2,4,25)"
],
[
0.06274509803921569,
"rgb(24,15,41)"
],
[
0.12549019607843137,
"rgb(47,23,57)"
],
[
0.18823529411764706,
"rgb(71,28,72)"
],
[
0.25098039215686274,
"rgb(97,30,82)"
],
[
0.3137254901960784,
"rgb(123,30,89)"
],
[
0.3764705882352941,
"rgb(150,27,91)"
],
[
0.4392156862745098,
"rgb(177,22,88)"
],
[
0.5019607843137255,
"rgb(203,26,79)"
],
[
0.5647058823529412,
"rgb(223,47,67)"
],
[
0.6274509803921569,
"rgb(236,76,61)"
],
[
0.6901960784313725,
"rgb(242,107,73)"
],
[
0.7529411764705882,
"rgb(244,135,95)"
],
[
0.8156862745098039,
"rgb(245,162,122)"
],
[
0.8784313725490196,
"rgb(246,188,153)"
],
[
0.9411764705882353,
"rgb(247,212,187)"
],
[
1.0,
"rgb(250,234,220)"
]
]
},
"xaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"automargin": true,
"zerolinewidth": 2
},
"yaxis": {
"gridcolor": "white",
"linecolor": "white",
"ticks": "",
"title": {
"standoff": 15
},
"zerolinecolor": "white",
"automargin": true,
"zerolinewidth": 2
},
"scene": {
"xaxis": {
"backgroundcolor": "rgb(234,234,242)",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
},
"yaxis": {
"backgroundcolor": "rgb(234,234,242)",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
},
"zaxis": {
"backgroundcolor": "rgb(234,234,242)",
"gridcolor": "white",
"linecolor": "white",
"showbackground": true,
"ticks": "",
"zerolinecolor": "white",
"gridwidth": 2
}
},
"shapedefaults": {
"line": {
"color": "rgb(67,103,167)"
}
},
"annotationdefaults": {
"arrowcolor": "rgb(67,103,167)",
"arrowhead": 0,
"arrowwidth": 1
},
"geo": {
"bgcolor": "white",
"landcolor": "rgb(234,234,242)",
"subunitcolor": "white",
"showland": true,
"showlakes": true,
"lakecolor": "white"
},
"title": {
"x": 0.05
},
"mapbox": {
"style": "light"
}
}
}
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Utilities for working with knowledge.
Run the following snippet to build the toy dataset used throughout the examples:
```pycon
>>> dataset = [
... {"s": "ABC", "b": True, "d2": {"c": "red", "l": [1, 2]}},
... {"s": "BCD", "b": True, "d2": {"c": "blue", "l": [3, 4]}},
... {"s": "CDE", "b": False, "d2": {"c": "green", "l": [5, 6]}},
... {"s": "DEF", "b": False, "d2": {"c": "yellow", "l": [7, 8]}},
... {"s": "EFG", "b": False, "d2": {"c": "black", "l": [9, 10]}, "xyz": 123}
... ]
>>> asset = vbt.KnowledgeAsset(dataset)
```
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vectorbtpro.utils.knowledge.asset_pipelines import *
from vectorbtpro.utils.knowledge.base_asset_funcs import *
from vectorbtpro.utils.knowledge.base_assets import *
from vectorbtpro.utils.knowledge.chatting import *
from vectorbtpro.utils.knowledge.custom_asset_funcs import *
from vectorbtpro.utils.knowledge.custom_assets import *
from vectorbtpro.utils.knowledge.formatting import *
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Asset pipeline classes.
See `vectorbtpro.utils.knowledge` for the toy dataset."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts
from vectorbtpro.utils.eval_ import evaluate
from vectorbtpro.utils.execution import Task
from vectorbtpro.utils.knowledge.base_asset_funcs import AssetFunc
from vectorbtpro.utils.module_ import package_shortcut_config
from vectorbtpro.utils.parsing import get_func_arg_names
__all__ = [
"AssetPipeline",
"BasicAssetPipeline",
"ComplexAssetPipeline",
]
class AssetPipeline(Base):
"""Abstract class representing an asset pipeline."""
@classmethod
def resolve_task(
cls,
func: tp.AssetFuncLike,
*args,
prepare: bool = True,
prepare_once: bool = True,
cond_kwargs: tp.KwargsLike = None,
asset_func_meta: tp.Union[None, dict, list] = None,
**kwargs,
) -> tp.Task:
"""Resolve a task."""
if isinstance(func, tuple):
func = Task.from_tuple(func)
if isinstance(func, Task):
args = (*func.args, *args)
kwargs = merge_dicts(func.kwargs, kwargs)
func = func.func
if isinstance(func, str):
from vectorbtpro.utils.knowledge import base_asset_funcs, custom_asset_funcs
base_keys = dir(base_asset_funcs)
custom_keys = dir(custom_asset_funcs)
base_values = [getattr(base_asset_funcs, k) for k in base_keys]
custom_values = [getattr(custom_asset_funcs, k) for k in custom_keys]
module_items = dict(zip(base_keys + custom_keys, base_values + custom_values))
if (
func in module_items
and isinstance(module_items[func], type)
and issubclass(module_items[func], AssetFunc)
):
func = module_items[func]
elif func.title() + "AssetFunc" in module_items:
func = module_items[func.title() + "AssetFunc"]
else:
found_func = False
for k, v in module_items.items():
if isinstance(v, type) and issubclass(v, AssetFunc):
if v._short_name is not None:
if func.lower() == v._short_name.lower():
func = v
found_func = True
break
if not found_func:
raise ValueError(f"Function '{func}' not found")
if isinstance(func, AssetFunc):
raise TypeError("Function must be a subclass of AssetFunc, not an instance")
if isinstance(func, type) and issubclass(func, AssetFunc):
_asset_func_meta = {}
for var_name, var_type in func.__annotations__.items():
if var_name.startswith("_") and tp.get_origin(var_type) is tp.ClassVar:
_asset_func_meta[var_name] = getattr(func, var_name)
if asset_func_meta is not None:
if isinstance(asset_func_meta, dict):
asset_func_meta.update(_asset_func_meta)
else:
asset_func_meta.append(_asset_func_meta)
if prepare:
if prepare_once:
if cond_kwargs is None:
cond_kwargs = {}
if len(cond_kwargs) > 0:
prepare_arg_names = get_func_arg_names(func.prepare)
for k, v in cond_kwargs.items():
if k in prepare_arg_names:
kwargs[k] = v
args, kwargs = func.prepare(*args, **kwargs)
func = func.call
else:
func = func.prepare_and_call
else:
func = func.call
if not callable(func):
raise TypeError("Function must be callable")
return Task(func, *args, **kwargs)
def run(self, d: tp.Any) -> tp.Any:
"""Run the pipeline on a data item."""
raise NotImplementedError
def __call__(self, d: tp.Any) -> tp.Any:
return self.run(d)
class BasicAssetPipeline(AssetPipeline):
"""Class representing a basic asset pipeline.
Builds a composite function out of all functions.
Usage:
```pycon
>>> asset_pipeline = vbt.BasicAssetPipeline()
>>> asset_pipeline.append("flatten")
>>> asset_pipeline.append("query", len)
>>> asset_pipeline.append("get")
>>> asset_pipeline(dataset[0])
5
```
"""
def __init__(self, *args, **kwargs) -> None:
if len(args) == 0:
tasks = []
else:
tasks = args[0]
args = args[1:]
if not isinstance(tasks, list):
tasks = [tasks]
self._tasks = [self.resolve_task(task, *args, **kwargs) for task in tasks]
@property
def tasks(self) -> tp.List[tp.Task]:
"""Tasks."""
return self._tasks
def append(self, func: tp.AssetFuncLike, *args, **kwargs) -> None:
"""Append a task to the pipeline."""
self.tasks.append(self.resolve_task(func, *args, **kwargs))
@classmethod
def compose_tasks(cls, tasks: tp.List[tp.Task]) -> tp.Callable:
"""Compose multiple tasks into one."""
def composed(d):
result = d
for func, args, kwargs in tasks:
result = func(result, *args, **kwargs)
return result
return composed
def run(self, d: tp.Any) -> tp.Any:
return self.compose_tasks(list(self.tasks))(d)
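# --- Illustrative sketch (editor addition; not part of the original module) ---
# Tasks can also be provided up front as strings, tuples, or `Task` objects; each
# one goes through `AssetPipeline.resolve_task`. This is only a sketch mirroring
# the docstring example above, assuming the toy dataset from the package docstring:
#
# >>> asset_pipeline = vbt.BasicAssetPipeline(["flatten", ("query", len), "get"])
# >>> asset_pipeline(dataset[0])
# 5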
class ComplexAssetPipeline(AssetPipeline):
"""Class representing a complex asset pipeline.
Takes an expression string and a context. Resolves functions inside the expression.
Expression is evaluated with `vectorbtpro.utils.eval_.evaluate`.
Usage:
```pycon
>>> asset_pipeline = vbt.ComplexAssetPipeline("query(flatten(d), len)")
>>> asset_pipeline(dataset[0])
5
```
"""
@classmethod
def resolve_expression_and_context(
cls,
expression: str,
context: tp.KwargsLike = None,
prepare: bool = True,
prepare_once: bool = True,
**resolve_task_kwargs,
) -> tp.Tuple[str, tp.Kwargs]:
"""Resolve an expression and a context.
        Parses an expression string, extracts function calls with their arguments,
        removes the first positional argument from each function call, and creates a new context."""
import importlib
import builtins
import ast
import sys
if context is None:
context = {}
for k, v in package_shortcut_config.items():
if k not in context:
try:
context[k] = importlib.import_module(v)
except ImportError:
pass
tree = ast.parse(expression)
builtin_functions = set(dir(builtins))
imported_functions = set()
imported_modules = set()
defined_functions = set()
func_context = {}
class _FunctionAnalyzer(ast.NodeVisitor):
def visit_Import(self, node):
for alias in node.names:
name = alias.asname if alias.asname else alias.name.split(".")[0]
imported_modules.add(name)
self.generic_visit(node)
def visit_ImportFrom(self, node):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
imported_functions.add(name)
self.generic_visit(node)
def visit_FunctionDef(self, node):
defined_functions.add(node.name)
self.generic_visit(node)
analyzer = _FunctionAnalyzer()
analyzer.visit(tree)
class _NodeMixin:
def get_func_name(self, func):
attrs = []
while isinstance(func, ast.Attribute):
attrs.append(func.attr)
func = func.value
if isinstance(func, ast.Name):
attrs.append(func.id)
return ".".join(reversed(attrs)) if attrs else ""
def is_function_assigned(self, func):
func_name = self.get_func_name(func)
if "." in func_name:
func_name = func_name.split(".")[0]
return (
func_name in context
or func_name in builtin_functions
or func_name in imported_functions
or func_name in imported_modules
or func_name in defined_functions
)
class _FunctionCallVisitor(ast.NodeVisitor, _NodeMixin):
def process_argument(self, arg):
if isinstance(arg, ast.Constant):
return arg.value
elif isinstance(arg, ast.Name):
var_name = arg.id
if var_name in context:
return context[var_name]
elif var_name in builtin_functions:
return getattr(builtins, var_name)
else:
raise ValueError(f"Variable '{var_name}' is not defined in the context")
elif isinstance(arg, ast.List):
return [self.process_argument(elem) for elem in arg.elts]
elif isinstance(arg, ast.Tuple):
return tuple(self.process_argument(elem) for elem in arg.elts)
elif isinstance(arg, ast.Dict):
return {self.process_argument(k): self.process_argument(v) for k, v in zip(arg.keys, arg.values)}
elif isinstance(arg, ast.Set):
return {self.process_argument(elem) for elem in arg.elts}
elif isinstance(arg, ast.Call):
if self.is_function_assigned(arg.func):
return self.get_func_name(arg.func)
raise ValueError(f"Unsupported or dynamic argument: {ast.dump(arg)}")
def visit_Call(self, node):
self.generic_visit(node)
func_name = self.get_func_name(node.func)
pos_args = []
for arg in node.args[1:]:
arg_value = self.process_argument(arg)
pos_args.append(arg_value)
kw_args = {}
for kw in node.keywords:
if kw.arg is None:
raise ValueError(f"Dynamic keyword argument names are not allowed in '{func_name}'")
kw_name = kw.arg
kw_value = self.process_argument(kw.value)
kw_args[kw_name] = kw_value
if not self.is_function_assigned(node.func):
task = cls.resolve_task(
func_name,
*pos_args,
**kw_args,
prepare=prepare,
prepare_once=prepare_once,
**resolve_task_kwargs,
)
if prepare and prepare_once:
def func(d, _task=task):
return _task.func(d, *_task.args, **_task.kwargs)
else:
func = task.func
func_context[func_name] = func
visitor = _FunctionCallVisitor()
visitor.visit(tree)
if prepare and prepare_once:
class _ArgumentPruner(ast.NodeTransformer, _NodeMixin):
def visit_Call(self, node: ast.Call):
if not self.is_function_assigned(node.func):
if node.args:
node.args = [node.args[0]]
else:
node.args = []
node.keywords = []
self.generic_visit(node)
return node
pruner = _ArgumentPruner()
modified_tree = pruner.visit(tree)
ast.fix_missing_locations(modified_tree)
if sys.version_info >= (3, 9):
new_expression = ast.unparse(modified_tree)
else:
import astor
new_expression = astor.to_source(modified_tree).strip()
else:
new_expression = expression
new_context = merge_dicts(func_context, context)
return new_expression, new_context
def __init__(
self,
expression: str,
context: tp.KwargsLike = None,
prepare_once: bool = True,
**resolve_task_kwargs,
) -> None:
self._expression, self._context = self.resolve_expression_and_context(
expression,
context=context,
prepare_once=prepare_once,
**resolve_task_kwargs,
)
@property
def expression(self) -> str:
"""Expression."""
return self._expression
@property
def context(self) -> tp.Kwargs:
"""Context."""
return self._context
def run(self, d: tp.Any) -> tp.Any:
"""Run the pipeline on a data item."""
context = merge_dicts({"d": d, "x": d}, self.context)
return evaluate(self.expression, context=context)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base asset function classes."""
import attr
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks, search_
from vectorbtpro.utils.attr_ import MISSING
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts
from vectorbtpro.utils.config import reorder_dict, reorder_list
from vectorbtpro.utils.execution import NoResult
from vectorbtpro.utils.formatting import dump
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.template import CustomTemplate, RepEval, RepFunc, substitute_templates
__all__ = [
"AssetFunc",
]
class AssetFunc(Base):
"""Abstract class representing an asset function."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the function to be used in expressions."""
    _wrap: tp.ClassVar[tp.Optional[bool]] = None
"""Whether the results are meant to be wrapped with `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset`."""
@classmethod
def prepare(cls, *args, **kwargs) -> tp.ArgsKwargs:
"""Prepare positional and keyword arguments."""
return args, kwargs
@classmethod
def call(cls, d: tp.Any, *args, **kwargs) -> tp.Any:
"""Call the function."""
raise NotImplementedError
@classmethod
def prepare_and_call(cls, d: tp.Any, *args, **kwargs):
"""Prepare arguments and call the function."""
args, kwargs = cls.prepare(*args, **kwargs)
return cls.call(d, *args, **kwargs)
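# --- Illustrative sketch (editor addition; not part of the original module) ---
# A minimal hypothetical subclass showing the `AssetFunc` contract: `prepare`
# normalizes arguments once, `call` transforms a single data item. The class name,
# short name, and behavior below are invented purely for illustration.
class _ExampleUpperAssetFunc(AssetFunc):
    """Upper-case all string values of a dict data item (illustration only)."""

    _short_name: tp.ClassVar[tp.Optional[str]] = "example_upper"
    _wrap: tp.ClassVar[tp.Optional[bool]] = True

    @classmethod
    def call(cls, d: tp.Any, *args, **kwargs) -> tp.Any:
        if isinstance(d, dict):
            return {k: v.upper() if isinstance(v, str) else v for k, v in d.items()}
        return d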
class GetAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.get`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "get"
    _wrap: tp.ClassVar[tp.Optional[bool]] = False
@classmethod
def prepare(
cls,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
source: tp.Optional[tp.CustomTemplateLike] = None,
template_context: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
keep_path = asset_cls.resolve_setting(keep_path, "keep_path")
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
if path is not None:
if isinstance(path, list):
path = [search_.resolve_pathlike_key(p) for p in path]
else:
path = search_.resolve_pathlike_key(path)
if source is not None:
if isinstance(source, str):
source = RepEval(source)
elif checks.is_function(source):
if checks.is_builtin_func(source):
source = RepFunc(lambda _source=source: _source)
else:
source = RepFunc(source)
elif not isinstance(source, CustomTemplate):
raise TypeError(f"Source must be a string, function, or template, not {type(source)}")
return (), {
**dict(
path=path,
keep_path=keep_path,
skip_missing=skip_missing,
source=source,
template_context=template_context,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
keep_path: bool = False,
skip_missing: bool = False,
source: tp.Optional[tp.CustomTemplate] = None,
template_context: tp.KwargsLike = None,
) -> tp.Any:
x = d
if path is not None:
if isinstance(path, list):
xs = []
for p in path:
try:
xs.append(search_.get_pathlike_key(x, p, keep_path=True))
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if len(xs) == 0:
return NoResult
x = merge_dicts(*xs)
else:
try:
x = search_.get_pathlike_key(x, path, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
return NoResult
if source is not None:
_template_context = flat_merge_dicts(
{
"d": d,
"x": x,
**(x if isinstance(x, dict) else {}),
},
template_context,
)
new_d = source.substitute(_template_context, eval_id="source")
if checks.is_function(new_d):
new_d = new_d(x)
else:
new_d = x
return new_d
class SetAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.set`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "set"
    _wrap: tp.ClassVar[tp.Optional[bool]] = True
@classmethod
def prepare(
cls,
value: tp.Any,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
if checks.is_function(value):
if checks.is_builtin_func(value):
value = RepFunc(lambda _value=value: _value)
else:
value = RepFunc(value)
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
else:
paths = [None]
return (), {
**dict(
value=value,
paths=paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
template_context=template_context,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
value: tp.Any,
paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
template_context: tp.KwargsLike = None,
) -> tp.Any:
prev_keys = []
for p in paths:
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p[:-1])
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
_template_context = flat_merge_dicts(
{
"d": d,
"x": x,
**(x if isinstance(x, dict) else {}),
},
template_context,
)
v = value.substitute(_template_context, eval_id="value")
if checks.is_function(v):
v = v(x)
d = search_.set_pathlike_key(d, p, v, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
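# A minimal usage sketch for `SetAssetFunc`; outputs are indicative. The value can be a
# function (wrapped into `RepFunc` by `prepare`) that receives the parent container of
# the target path:
#
#   >>> SetAssetFunc.prepare_and_call({"a": {"b": 1}}, lambda x: x["b"] + 1, path="a.b")
#   {'a': {'b': 2}}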
class RemoveAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.remove`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "remove"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
return (), {
**dict(
paths=paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
) -> tp.Any:
prev_keys = []
for p in paths:
try:
d = search_.remove_pathlike_key(d, p, make_copy=make_copy, prev_keys=prev_keys)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
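# A minimal usage sketch for `RemoveAssetFunc`; outputs are indicative:
#
#   >>> RemoveAssetFunc.prepare_and_call({"a": 1, "b": 2, "c": 3}, "a")
#   {'b': 2, 'c': 3}
#   >>> RemoveAssetFunc.prepare_and_call({"a": 1, "b": 2, "c": 3}, ["a", "c"])
#   {'b': 2}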
class MoveAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.move`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "move"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
path: tp.Union[tp.PathMoveDict, tp.MaybeList[tp.PathLikeKey]],
new_path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if new_path is None:
checks.assert_instance_of(path, dict, arg_name="path")
new_path = list(path.values())
path = list(path.keys())
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
if isinstance(new_path, list):
new_paths = [search_.resolve_pathlike_key(p) for p in new_path]
else:
new_paths = [search_.resolve_pathlike_key(new_path)]
if len(paths) != len(new_paths):
raise ValueError("Number of new paths must match number of paths")
return (), {
**dict(
paths=paths,
new_paths=new_paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
paths: tp.List[tp.PathLikeKey],
new_paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
) -> tp.Any:
prev_keys = []
for i, p in enumerate(paths):
try:
x = search_.get_pathlike_key(d, p)
d = search_.remove_pathlike_key(d, p, make_copy=make_copy, prev_keys=prev_keys)
d = search_.set_pathlike_key(d, new_paths[i], x, make_copy=make_copy, prev_keys=prev_keys)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
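# A minimal usage sketch for `MoveAssetFunc`; outputs are indicative. Passing a dict
# maps old paths to new paths in a single call:
#
#   >>> MoveAssetFunc.prepare_and_call({"a": 1, "b": 2}, {"a": "c"})
#   {'b': 2, 'c': 1}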
class RenameAssetFunc(MoveAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.rename`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "rename"
@classmethod
def prepare(
cls,
path: tp.Union[tp.PathRenameDict, tp.MaybeList[tp.PathLikeKey]],
new_token: tp.Optional[tp.MaybeList[tp.PathKeyToken]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if new_token is None:
checks.assert_instance_of(path, dict, arg_name="path")
new_token = list(path.values())
path = list(path.keys())
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
if isinstance(new_token, list):
new_tokens = [search_.resolve_pathlike_key(t) for t in new_token]
else:
new_tokens = [search_.resolve_pathlike_key(new_token)]
if len(paths) != len(new_tokens):
raise ValueError("Number of new tokens must match number of paths")
new_paths = []
for i in range(len(paths)):
if len(new_tokens[i]) != 1:
raise ValueError("Exactly one token must be provided for each path")
new_paths.append(paths[i][:-1] + new_tokens[i])
return (), {
**dict(
paths=paths,
new_paths=new_paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
),
**kwargs,
}
class ReorderAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.reorder`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "reorder"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
new_order: tp.Union[str, tp.PathKeyTokens],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
if isinstance(new_order, str):
if new_order.lower() in ("asc", "ascending"):
new_order = lambda x: (
sorted(x)
if isinstance(x, dict)
else sorted(
range(len(x)),
key=x.__getitem__,
)
)
elif new_order.lower() in ("desc", "descending"):
new_order = lambda x: (
                    sorted(x, reverse=True)
if isinstance(x, dict)
else sorted(
range(len(x)),
key=x.__getitem__,
reverse=True,
)
)
if isinstance(new_order, str):
new_order = RepEval(new_order)
elif checks.is_function(new_order):
if checks.is_builtin_func(new_order):
new_order = RepFunc(lambda _new_order=new_order: _new_order)
else:
new_order = RepFunc(new_order)
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
else:
paths = [None]
return (), {
**dict(
new_order=new_order,
paths=paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
template_context=template_context,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
new_order: tp.Union[tp.PathKeyTokens, tp.CustomTemplate],
paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
template_context: tp.KwargsLike = None,
) -> tp.Any:
prev_keys = []
for p in paths:
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if isinstance(new_order, CustomTemplate):
_template_context = flat_merge_dicts(
{
"d": d,
"x": x,
**(x if isinstance(x, dict) else {}),
},
template_context,
)
_new_order = new_order.substitute(_template_context, eval_id="new_order")
if checks.is_function(_new_order):
_new_order = _new_order(x)
else:
_new_order = new_order
if isinstance(x, dict):
x = reorder_dict(x, _new_order, skip_missing=skip_missing)
else:
if checks.is_namedtuple(x):
x = type(x)(*reorder_list(x, _new_order, skip_missing=skip_missing))
else:
x = type(x)(reorder_list(x, _new_order, skip_missing=skip_missing))
d = search_.set_pathlike_key(d, p, x, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
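# A minimal usage sketch for `ReorderAssetFunc`; outputs are indicative. The new order
# can be "asc"/"desc", an explicit sequence of keys (or indices), a template, or a function:
#
#   >>> ReorderAssetFunc.prepare_and_call([3, 1, 2], "asc")
#   [1, 2, 3]
#   >>> ReorderAssetFunc.prepare_and_call({"b": 2, "a": 1}, ["a", "b"])
#   {'a': 1, 'b': 2}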
class QueryAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.query`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "query"
_wrap: tp.ClassVar[tp.Optional[str]] = False
@classmethod
def prepare(
cls,
expression: tp.CustomTemplateLike,
template_context: tp.KwargsLike = None,
return_type: tp.Optional[str] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
return_type = asset_cls.resolve_setting(return_type, "return_type")
if isinstance(expression, str):
expression = RepEval(expression)
elif checks.is_function(expression):
if checks.is_builtin_func(expression):
expression = RepFunc(lambda _expression=expression: _expression)
else:
expression = RepFunc(expression)
elif not isinstance(expression, CustomTemplate):
raise TypeError(f"Expression must be a string, function, or template, not {type(expression)}")
return (), {
**dict(
expression=expression,
template_context=template_context,
return_type=return_type,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
expression: tp.CustomTemplate,
template_context: tp.KwargsLike = None,
return_type: str = "item",
) -> tp.Any:
_template_context = flat_merge_dicts(
{
"d": d,
"x": d,
**search_.search_config,
**(d if isinstance(d, dict) else {}),
},
template_context,
)
new_d = expression.substitute(_template_context, eval_id="expression")
if checks.is_function(new_d):
new_d = new_d(d)
if return_type.lower() == "item":
as_filter = True
elif return_type.lower() == "bool":
as_filter = False
else:
raise ValueError(f"Invalid return type: '{return_type}'")
if as_filter and isinstance(new_d, bool):
if new_d:
return d
return NoResult
return new_d
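# A minimal usage sketch for `QueryAssetFunc`; outputs are indicative. With the default
# return type the expression acts as a filter (returning the item or `NoResult`), while
# `return_type="bool"` returns the raw boolean:
#
#   >>> QueryAssetFunc.prepare_and_call({"s": "ABC", "b": True}, "s == 'ABC'")
#   {'s': 'ABC', 'b': True}
#   >>> QueryAssetFunc.prepare_and_call({"s": "ABC", "b": True}, "not b", return_type="bool")
#   False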
class FindAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.find`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "find"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
target: tp.MaybeList[tp.Any],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: tp.Optional[bool] = None,
find_all: tp.Optional[bool] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
source: tp.Optional[tp.CustomTemplateLike] = None,
in_dumps: tp.Optional[bool] = None,
dump_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
return_type: tp.Optional[str] = None,
return_path: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
per_path = asset_cls.resolve_setting(per_path, "per_path")
find_all = asset_cls.resolve_setting(find_all, "find_all")
keep_path = asset_cls.resolve_setting(keep_path, "keep_path")
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
in_dumps = asset_cls.resolve_setting(in_dumps, "in_dumps")
dump_kwargs = asset_cls.resolve_setting(dump_kwargs, "dump_kwargs", merge=True)
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
return_type = asset_cls.resolve_setting(return_type, "return_type")
return_path = asset_cls.resolve_setting(return_path, "return_path")
if path is not None:
if isinstance(path, list):
path = [search_.resolve_pathlike_key(p) for p in path]
else:
path = search_.resolve_pathlike_key(path)
if per_path:
if not isinstance(target, list):
target = [target]
if isinstance(path, list):
target *= len(path)
if not isinstance(path, list):
path = [path]
if isinstance(target, list):
path *= len(target)
if len(target) != len(path):
raise ValueError("Number of targets must match number of paths")
if source is not None:
if isinstance(source, str):
source = RepEval(source)
elif checks.is_function(source):
if checks.is_builtin_func(source):
source = RepFunc(lambda _source=source: _source)
else:
source = RepFunc(source)
elif not isinstance(source, CustomTemplate):
raise TypeError(f"Source must be a string, function, or template, not {type(source)}")
dump_kwargs = DumpAssetFunc.resolve_dump_kwargs(**dump_kwargs)
contains_arg_names = set(get_func_arg_names(search_.contains_in_obj))
search_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in contains_arg_names}
if "excl_types" not in search_kwargs:
search_kwargs["excl_types"] = (tuple, set, frozenset)
return (), {
**dict(
target=target,
path=path,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
source=source,
in_dumps=in_dumps,
dump_kwargs=dump_kwargs,
search_kwargs=search_kwargs,
template_context=template_context,
return_type=return_type,
return_path=return_path,
),
**kwargs,
}
@classmethod
def match_func(
cls,
k: tp.Optional[tp.Hashable],
d: tp.Any,
target: tp.MaybeList[tp.Any],
find_all: bool = False,
**kwargs,
) -> bool:
"""Match function for `FindAssetFunc.call`.
Uses `vectorbtpro.utils.search_.find` with `return_type="bool"` for text,
and equality checks for other types.
Target can be a function taking the value and returning a boolean. Target can also be an
instance of `vectorbtpro.utils.search_.Not` for negation."""
if not isinstance(target, list):
targets = [target]
else:
targets = target
for target in targets:
if isinstance(target, search_.Not):
target = target.value
negation = True
else:
negation = False
if checks.is_function(target):
if target(d):
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif d is target:
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif d is None and target is None:
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif checks.is_bool(d) and checks.is_bool(target):
if d == target:
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif checks.is_number(d) and checks.is_number(target):
if d == target:
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif isinstance(d, str) and isinstance(target, str):
if search_.find(target, d, return_type="bool", **kwargs):
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
elif type(d) is type(target):
try:
if d == target:
if (negation and find_all) or (not negation and not find_all):
return not negation
continue
except Exception:
pass
if (negation and not find_all) or (not negation and find_all):
return negation
if find_all:
return True
return False
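    # A minimal sketch of how `match_func` treats different target types; outputs are
    # indicative. Strings are matched via `search_.find`, functions are applied to the
    # value, and `search_.Not` (assumed to wrap the negated value) inverts the match:
    #
    #   >>> FindAssetFunc.match_func(None, "hello world", "world")
    #   True
    #   >>> FindAssetFunc.match_func(None, 42, lambda v: v > 10)
    #   True
    #   >>> FindAssetFunc.match_func(None, "abc", search_.Not("abc"))
    #   False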
@classmethod
def call(
cls,
d: tp.Any,
target: tp.MaybeList[tp.Any],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: bool = True,
find_all: bool = False,
keep_path: bool = False,
skip_missing: bool = False,
source: tp.Optional[tp.CustomTemplate] = None,
in_dumps: bool = False,
dump_kwargs: tp.KwargsLike = None,
search_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
return_type: str = "item",
return_path: bool = False,
**kwargs,
) -> tp.Any:
if dump_kwargs is None:
dump_kwargs = {}
if search_kwargs is None:
search_kwargs = {}
if per_path:
new_path_dct = {}
new_list = []
for i, p in enumerate(path):
x = d
try:
x = search_.get_pathlike_key(x, p, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if source is not None:
_template_context = flat_merge_dicts(
{
"d": d,
"x": x,
**(x if isinstance(x, dict) else {}),
},
template_context,
)
_x = source.substitute(_template_context, eval_id="source")
if checks.is_function(_x):
x = _x(x)
else:
x = _x
if not isinstance(x, str) and in_dumps:
x = dump(x, **dump_kwargs)
t = target[i]
if return_type.lower() in ("item", "bool"):
if isinstance(t, search_.Not):
t = t.value
negation = True
else:
negation = False
if search_.contains_in_obj(
x,
cls.match_func,
target=t,
find_all=find_all,
**search_kwargs,
**kwargs,
):
if negation:
if find_all:
return NoResult if return_type.lower() == "item" else False
continue
else:
if not find_all:
return d if return_type.lower() == "item" else True
continue
else:
if negation:
if not find_all:
return d if return_type.lower() == "item" else True
continue
else:
if find_all:
return NoResult if return_type.lower() == "item" else False
continue
else:
path_dct = search_.find_in_obj(
x,
cls.match_func,
target=t,
find_all=find_all,
**search_kwargs,
**kwargs,
)
if len(path_dct) == 0:
if find_all:
return {} if return_path else []
continue
if isinstance(t, search_.Not):
raise TypeError("Target cannot be negated here")
if not isinstance(t, str):
raise ValueError("Target must be string")
for k, v in path_dct.items():
if not isinstance(v, str):
raise ValueError("Matched value must be string")
_return_type = "bool" if return_type.lower() == "field" else return_type
matches = search_.find(t, v, return_type=_return_type, **kwargs)
if return_path:
if k not in new_path_dct:
new_path_dct[k] = []
if return_type.lower() == "field":
if matches:
new_path_dct[k].append(v)
else:
new_path_dct[k].extend(matches)
else:
if return_type.lower() == "field":
if matches:
new_list.append(v)
else:
new_list.extend(matches)
if return_type.lower() in ("item", "bool"):
if find_all:
return d if return_type.lower() == "item" else True
return NoResult if return_type.lower() == "item" else False
else:
if return_path:
return new_path_dct
return new_list
else:
x = d
if path is not None:
if isinstance(path, list):
xs = []
for p in path:
try:
xs.append(search_.get_pathlike_key(x, p, keep_path=True))
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
if len(xs) == 0:
if return_type.lower() == "item":
return NoResult
if return_type.lower() == "bool":
return False
return {} if return_path else []
x = merge_dicts(*xs)
else:
try:
x = search_.get_pathlike_key(x, path, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
if return_type.lower() == "item":
return NoResult
if return_type.lower() == "bool":
return False
return {} if return_path else []
if source is not None:
_template_context = flat_merge_dicts(
{
"d": d,
"x": x,
**(x if isinstance(x, dict) else {}),
},
template_context,
)
_x = source.substitute(_template_context, eval_id="source")
if checks.is_function(_x):
x = _x(x)
else:
x = _x
if not isinstance(x, str) and in_dumps:
x = dump(x, **dump_kwargs)
if return_type.lower() == "item":
if search_.contains_in_obj(
x,
cls.match_func,
target=target,
find_all=find_all,
**search_kwargs,
**kwargs,
):
return d
return NoResult
elif return_type.lower() == "bool":
return search_.contains_in_obj(
x,
cls.match_func,
target=target,
find_all=find_all,
**search_kwargs,
**kwargs,
)
else:
path_dct = search_.find_in_obj(
x,
cls.match_func,
target=target,
find_all=find_all,
**search_kwargs,
**kwargs,
)
if len(path_dct) == 0:
return {} if return_path else []
if not isinstance(target, list):
targets = [target]
else:
targets = target
new_path_dct = {}
new_list = []
for target in targets:
if isinstance(target, search_.Not):
raise TypeError("Target cannot be negated here")
if not isinstance(target, str):
raise ValueError("Target must be string")
for k, v in path_dct.items():
if not isinstance(v, str):
raise ValueError("Matched value must be string")
_return_type = "bool" if return_type.lower() == "field" else return_type
matches = search_.find(target, v, return_type=_return_type, **kwargs)
if return_path:
if k not in new_path_dct:
new_path_dct[k] = []
if return_type.lower() == "field":
if matches:
new_path_dct[k].append(v)
else:
new_path_dct[k].extend(matches)
else:
if return_type.lower() == "field":
if matches:
new_list.append(v)
else:
new_list.extend(matches)
if return_path:
return new_path_dct
return new_list
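# A minimal usage sketch for `FindAssetFunc`; outputs are indicative and assume the
# default settings (per-path search, return type "item"):
#
#   >>> FindAssetFunc.prepare_and_call({"s": "ABC", "b": True}, "BC", path="s")
#   {'s': 'ABC', 'b': True}
#   >>> FindAssetFunc.prepare_and_call({"s": "ABC", "b": True}, "XY", path="s", return_type="bool")
#   False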
class FindReplaceAssetFunc(FindAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.find_replace`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "find_replace"
@classmethod
def prepare(
cls,
target: tp.Union[dict, tp.MaybeList[tp.Any]],
replacement: tp.Optional[tp.MaybeList[tp.Any]] = None,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: tp.Optional[bool] = None,
find_all: tp.Optional[bool] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
per_path = asset_cls.resolve_setting(per_path, "per_path")
find_all = asset_cls.resolve_setting(find_all, "find_all")
keep_path = asset_cls.resolve_setting(keep_path, "keep_path")
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if replacement is None:
            checks.assert_instance_of(target, dict, arg_name="target")
replacement = list(target.values())
target = list(target.keys())
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
if isinstance(target, list):
paths *= len(target)
elif isinstance(replacement, list):
paths *= len(replacement)
else:
paths = [None]
if isinstance(target, list):
paths *= len(target)
elif isinstance(replacement, list):
paths *= len(replacement)
if per_path:
if not isinstance(target, list):
target = [target] * len(paths)
if not isinstance(replacement, list):
replacement = [replacement] * len(paths)
            if not (len(target) == len(replacement) == len(paths)):
raise ValueError("Number of targets and replacements must match number of paths")
find_arg_names = set(get_func_arg_names(search_.find_in_obj))
find_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in find_arg_names}
if "excl_types" not in find_kwargs:
find_kwargs["excl_types"] = (tuple, set, frozenset)
return (), {
**dict(
target=target,
replacement=replacement,
paths=paths,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
find_kwargs=find_kwargs,
),
**kwargs,
}
@classmethod
def replace_func(
cls,
k: tp.Optional[tp.Hashable],
d: tp.Any,
target: tp.MaybeList[tp.Any],
replacement: tp.MaybeList[tp.Any],
**kwargs,
) -> tp.Any:
"""Replace function for `FindReplaceAssetFunc.call`.
Uses `vectorbtpro.utils.search_.replace` for text and returns replacement for other types.
Target can be a function taking the value and returning a boolean.
Replacement can also be a function taking the value and returning a new value."""
if not isinstance(target, list):
targets = [target]
else:
targets = target
if not isinstance(replacement, list):
replacements = [replacement]
if len(targets) > 1 and len(replacements) == 1:
replacements *= len(targets)
else:
replacements = replacement
if len(targets) != len(replacements):
raise ValueError("Number of targets must match number of replacements")
for i, target in enumerate(targets):
if isinstance(target, search_.Not):
raise TypeError("Target cannot be negated here")
replacement = replacements[i]
if checks.is_function(replacement):
replacement = replacement(d)
if checks.is_function(target):
if target(d):
return replacement
elif d is target:
return replacement
elif d is None and target is None:
return replacement
elif checks.is_bool(d) and checks.is_bool(target):
if d == target:
return replacement
elif checks.is_number(d) and checks.is_number(target):
if d == target:
return replacement
elif isinstance(d, str) and isinstance(target, str):
d = search_.replace(target, replacement, d, **kwargs)
elif type(d) is type(target):
try:
if d == target:
return replacement
except Exception:
pass
return d
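    # A minimal sketch of `replace_func`; outputs are indicative. Strings are rewritten
    # via `search_.replace`, other matching values are swapped for the replacement, and
    # non-matching values are returned unchanged:
    #
    #   >>> FindReplaceAssetFunc.replace_func(None, "hello world", "world", "there")
    #   'hello there'
    #   >>> FindReplaceAssetFunc.replace_func(None, 5, 5, 50)
    #   50
    #   >>> FindReplaceAssetFunc.replace_func(None, "abc", "x", "y")
    #   'abc'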
@classmethod
def call(
cls,
d: tp.Any,
target: tp.MaybeList[tp.Any],
replacement: tp.MaybeList[tp.Any],
paths: tp.List[tp.PathLikeKey],
per_path: bool = True,
find_all: bool = False,
keep_path: bool = False,
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
find_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Any:
if find_kwargs is None:
find_kwargs = {}
prev_keys = []
found_all = True
if find_all:
for i, p in enumerate(paths):
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
path_dct = search_.find_in_obj(
x,
cls.match_func,
target=target[i] if per_path else target,
find_all=find_all,
**find_kwargs,
**kwargs,
)
if len(path_dct) == 0:
found_all = False
break
if found_all:
for i, p in enumerate(paths):
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
path_dct = search_.find_in_obj(
x,
cls.match_func,
target=target[i] if per_path else target,
find_all=find_all,
**find_kwargs,
**kwargs,
)
for k, v in path_dct.items():
if p is not None and not keep_path:
new_p = search_.combine_pathlike_keys(p, k, minimize=True)
else:
new_p = k
v = cls.replace_func(
k,
v,
target[i] if per_path else target,
replacement[i] if per_path else replacement,
**kwargs,
)
d = search_.set_pathlike_key(d, new_p, v, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
class FindRemoveAssetFunc(FindAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.find_remove`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "find_remove"
@classmethod
def prepare(
cls,
target: tp.Union[dict, tp.MaybeList[tp.Any]],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: tp.Optional[bool] = None,
find_all: tp.Optional[bool] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
per_path = asset_cls.resolve_setting(per_path, "per_path")
find_all = asset_cls.resolve_setting(find_all, "find_all")
keep_path = asset_cls.resolve_setting(keep_path, "keep_path")
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
if isinstance(target, list):
paths *= len(target)
else:
paths = [None]
if isinstance(target, list):
paths *= len(target)
if per_path:
if not isinstance(target, list):
target = [target] * len(paths)
if len(target) != len(paths):
raise ValueError("Number of targets must match number of paths")
find_arg_names = set(get_func_arg_names(search_.find_in_obj))
find_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in find_arg_names}
if "excl_types" not in find_kwargs:
find_kwargs["excl_types"] = (tuple, set, frozenset)
return (), {
**dict(
target=target,
paths=paths,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
find_kwargs=find_kwargs,
),
**kwargs,
}
@classmethod
def is_empty_func(cls, d: tp.Any) -> bool:
"""Return whether object is empty."""
if d is None:
return True
if checks.is_collection(d) and len(d) == 0:
return True
return False
@classmethod
def call(
cls,
d: tp.Any,
target: tp.MaybeList[tp.Any],
paths: tp.List[tp.PathLikeKey],
per_path: bool = True,
find_all: bool = False,
keep_path: bool = False,
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
find_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.Any:
if find_kwargs is None:
find_kwargs = {}
prev_keys = []
new_p_v_map = {}
for i, p in enumerate(paths):
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p, keep_path=keep_path)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
path_dct = search_.find_in_obj(
x,
cls.match_func,
target=target[i] if per_path else target,
find_all=find_all,
**find_kwargs,
**kwargs,
)
if len(path_dct) == 0 and find_all:
new_p_v_map = {}
break
for k, v in path_dct.items():
if p is not None and not keep_path:
new_p = search_.combine_pathlike_keys(p, k, minimize=True)
else:
new_p = k
new_p_v_map[new_p] = v
for new_p, v in new_p_v_map.items():
d = search_.remove_pathlike_key(d, new_p, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
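# A minimal usage sketch for `FindRemoveAssetFunc`; outputs are indicative. Any leaf
# matched by `match_func` is removed from the data item:
#
#   >>> FindRemoveAssetFunc.prepare_and_call({"a": None, "b": 1}, None)
#   {'b': 1}
#   >>> FindRemoveAssetFunc.prepare_and_call({"s": "ABC", "b": True}, "AB")
#   {'b': True}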
class FlattenAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.flatten`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "flatten"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
else:
paths = [None]
if "excl_types" not in kwargs:
kwargs["excl_types"] = (tuple, set, frozenset)
return (), {
**dict(
paths=paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
**kwargs,
) -> tp.Any:
prev_keys = []
for p in paths:
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
x = search_.flatten_obj(x, **kwargs)
d = search_.set_pathlike_key(d, p, x, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
class UnflattenAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.unflatten`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "unflatten"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
skip_missing = asset_cls.resolve_setting(skip_missing, "skip_missing")
make_copy = asset_cls.resolve_setting(make_copy, "make_copy")
changed_only = asset_cls.resolve_setting(changed_only, "changed_only")
if path is not None:
if isinstance(path, list):
paths = [search_.resolve_pathlike_key(p) for p in path]
else:
paths = [search_.resolve_pathlike_key(path)]
else:
paths = [None]
return (), {
**dict(
paths=paths,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
paths: tp.List[tp.PathLikeKey],
skip_missing: bool = False,
make_copy: bool = True,
changed_only: bool = False,
**kwargs,
) -> tp.Any:
prev_keys = []
for p in paths:
x = d
if p is not None:
try:
x = search_.get_pathlike_key(x, p)
except (KeyError, IndexError, AttributeError) as e:
if not skip_missing:
raise e
continue
x = search_.unflatten_obj(x, **kwargs)
d = search_.set_pathlike_key(d, p, x, make_copy=make_copy, prev_keys=prev_keys)
if not changed_only or len(prev_keys) > 0:
return d
return NoResult
class DumpAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.dump`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "dump"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def resolve_dump_kwargs(
cls,
dump_engine: tp.Optional[str] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.Kwargs:
"""Resolve keyword arguments related to dumping."""
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
dump_engine = asset_cls.resolve_setting(dump_engine, "dump_engine")
kwargs = asset_cls.resolve_setting(kwargs, f"dump_engine_kwargs.{dump_engine}", default={}, merge=True)
return {"dump_engine": dump_engine, **kwargs}
@classmethod
def prepare(
cls,
source: tp.Optional[tp.CustomTemplateLike] = None,
dump_engine: tp.Optional[str] = None,
template_context: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
dump_kwargs = cls.resolve_dump_kwargs(dump_engine=dump_engine, **kwargs)
if source is not None:
if isinstance(source, str):
source = RepEval(source)
elif checks.is_function(source):
if checks.is_builtin_func(source):
source = RepFunc(lambda _source=source: _source)
else:
source = RepFunc(source)
elif not isinstance(source, CustomTemplate):
raise TypeError(f"Source must be a string, function, or template, not {type(source)}")
return (), {
**dict(
source=source,
template_context=template_context,
),
**dump_kwargs,
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
source: tp.Optional[CustomTemplate] = None,
dump_engine: str = "nestedtext",
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.Any:
from vectorbtpro.utils.knowledge.chatting import StoreDocument, EmbeddedDocument, ScoredDocument
if source is not None:
_template_context = flat_merge_dicts(
{
"d": d,
"x": d,
**(d if isinstance(d, dict) else {}),
},
template_context,
)
new_d = source.substitute(_template_context, eval_id="source")
if checks.is_function(new_d):
new_d = new_d(d)
else:
new_d = d
if isinstance(new_d, StoreDocument):
return new_d.get_content()
if isinstance(new_d, (EmbeddedDocument, ScoredDocument)):
return new_d.document.get_content()
return dump(new_d, dump_engine=dump_engine, **kwargs)
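# A minimal usage sketch for `DumpAssetFunc`; the exact output depends on the chosen
# dump engine and its settings, so the string below is only indicative of YAML-style output:
#
#   >>> DumpAssetFunc.prepare_and_call({"a": 1, "b": "x"}, dump_engine="yaml")
#   'a: 1\nb: x\n'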
class ToDocsAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.to_documents`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "to_docs"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
document_cls: tp.Optional[tp.Type[tp.StoreDocument]] = None,
template_context: tp.Union[tp.KwargsLike, tp.CustomTemplate] = None,
**document_kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
document_cls = asset_cls.resolve_setting(document_cls, "document_cls")
if document_cls is None:
from vectorbtpro.utils.knowledge.chatting import TextDocument
document_cls = TextDocument
template_context = asset_cls.resolve_setting(template_context, "template_context", merge=True)
template_context = flat_merge_dicts({"asset_cls": asset_cls}, template_context)
document_kwargs = {}
for k, v in document_cls.fields_dict.items():
if v.default is not MISSING:
if k in document_kwargs or asset_cls.has_setting(k, sub_path="document_kwargs"):
document_kwargs[k] = asset_cls.resolve_setting(
document_kwargs.get(k, None),
k,
sub_path="document_kwargs",
merge=isinstance(v.default, attr.Factory) and v.default.factory is dict,
)
if k == "template_context":
document_kwargs[k] = merge_dicts(template_context, document_kwargs[k])
if k == "dump_kwargs":
document_kwargs[k] = DumpAssetFunc.resolve_dump_kwargs(**document_kwargs[k])
document_kwargs = substitute_templates(
document_kwargs, template_context, eval_id="document_kwargs", strict=False
)
return (), {
**dict(document_cls=document_cls),
**document_kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
document_cls: tp.Optional[tp.Type[tp.StoreDocument]] = None,
template_context: tp.KwargsLike = None,
**document_kwargs,
) -> tp.Any:
if document_cls is None:
from vectorbtpro.utils.knowledge.chatting import TextDocument
document_cls = TextDocument
_template_context = flat_merge_dicts(
{
"d": d,
"x": d,
**(d if isinstance(d, dict) else {}),
},
template_context,
)
document_kwargs = substitute_templates(document_kwargs, _template_context, eval_id="document_kwargs")
return document_cls.from_data(d, template_context=_template_context, **document_kwargs)
class SplitTextAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.split_text`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "split_text"
_wrap: tp.ClassVar[tp.Optional[str]] = True
@classmethod
def prepare(
cls,
text_path: tp.Optional[tp.PathLikeKey] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**split_text_kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
from vectorbtpro.utils.knowledge.chatting import resolve_text_splitter
text_path = asset_cls.resolve_setting(text_path, "text_path", sub_path="document_kwargs")
split_text_kwargs = asset_cls.resolve_setting(
split_text_kwargs, "split_text_kwargs", sub_path="document_kwargs", merge=True
)
text_splitter = split_text_kwargs.pop("text_splitter", None)
text_splitter = resolve_text_splitter(text_splitter=text_splitter)
if isinstance(text_splitter, type):
text_splitter = text_splitter(**split_text_kwargs)
elif split_text_kwargs:
text_splitter = text_splitter.replace(**split_text_kwargs)
return (), {
**dict(
text_path=text_path,
text_splitter=text_splitter,
),
}
@classmethod
def call(
cls,
d: tp.Any,
text_path: tp.Optional[tp.PathLikeKey] = None,
**split_text_kwargs,
) -> tp.Any:
from vectorbtpro.utils.knowledge.chatting import TextDocument
document = TextDocument("", d, text_path=text_path, split_text_kwargs=split_text_kwargs)
return [document_chunk.data for document_chunk in document.split()]
# ############# Reduce classes ############# #
class ReduceAssetFunc(AssetFunc):
"""Abstract asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.reduce`."""
_wrap: tp.ClassVar[tp.Optional[str]] = False
_initializer: tp.ClassVar[tp.Optional[tp.Any]] = None
@classmethod
def call(cls, d1: tp.Any, d2: tp.Any, *args, **kwargs) -> tp.Any:
raise NotImplementedError
@classmethod
def prepare_and_call(cls, d1: tp.Any, d2: tp.Any, *args, **kwargs):
args, kwargs = cls.prepare(*args, **kwargs)
return cls.call(d1, d2, *args, **kwargs)
class CollectAssetFunc(ReduceAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.collect`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "collect"
_initializer: tp.ClassVar[tp.Optional[tp.Any]] = {}
@classmethod
def prepare(
cls,
sort_keys: tp.Optional[bool] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
if asset_cls is None:
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
asset_cls = KnowledgeAsset
sort_keys = asset_cls.resolve_setting(sort_keys, "sort_keys")
return (), {**dict(sort_keys=sort_keys), **kwargs}
@classmethod
def sort_key(cls, k: tp.Any) -> tuple:
"""Function for sorting keys."""
return (0, k) if isinstance(k, str) else (1, k)
@classmethod
def call(cls, d1: tp.Any, d2: tp.Any, sort_keys: bool = False) -> tp.Any:
if isinstance(d1, list):
d1 = {i: v for i, v in enumerate(d1)}
if isinstance(d2, list):
d2 = {i: v for i, v in enumerate(d2)}
if not isinstance(d1, dict) or not isinstance(d2, dict):
raise TypeError(f"Data items must be either dicts or lists, not {type(d1)} and {type(d2)}")
new_d1 = dict(d1)
for k1 in d1:
if k1 not in new_d1:
new_d1[k1] = [d1[k1]]
if k1 in d2:
new_d1[k1].append(d2[k1])
for k2 in d2:
if k2 not in new_d1:
new_d1[k2] = [d2[k2]]
if sort_keys:
return dict(sorted(new_d1.items(), key=lambda x: cls.sort_key(x[0])))
return new_d1
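# A minimal usage sketch for `CollectAssetFunc`; outputs are indicative. Starting from
# the empty-dict initializer, values under the same key are collected into lists:
#
#   >>> CollectAssetFunc.call({}, {"a": 1})
#   {'a': [1]}
#   >>> CollectAssetFunc.call({"a": [1]}, {"a": 2, "b": 3})
#   {'a': [1, 2], 'b': [3]}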
class MergeDictsAssetFunc(ReduceAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.merge_dicts`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "merge_dicts"
_wrap: tp.ClassVar[tp.Optional[str]] = True
_initializer: tp.ClassVar[tp.Optional[tp.Any]] = {}
@classmethod
def call(cls, d1: tp.Any, d2: tp.Any, **kwargs) -> tp.Any:
if not isinstance(d1, dict) or not isinstance(d2, dict):
raise TypeError(f"Data items must be dicts, not {type(d1)} and {type(d2)}")
return merge_dicts(d1, d2, **kwargs)
class MergeListsAssetFunc(ReduceAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.merge_lists`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "merge_lists"
_wrap: tp.ClassVar[tp.Optional[str]] = True
_initializer: tp.ClassVar[tp.Optional[tp.Any]] = []
@classmethod
def call(cls, d1: tp.Any, d2: tp.Any) -> tp.Any:
if not isinstance(d1, list) or not isinstance(d2, list):
raise TypeError(f"Data items must be lists, not {type(d1)} and {type(d2)}")
return d1 + d2
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Base asset classes.
See `vectorbtpro.utils.knowledge` for the toy dataset."""
import hashlib
import json
import re
from collections.abc import MutableSequence
from pathlib import Path
import pandas as pd
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Configured
from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts
from vectorbtpro.utils.decorators import hybrid_method
from vectorbtpro.utils.execution import Task, execute, NoResult
from vectorbtpro.utils.knowledge.chatting import RankContextable
from vectorbtpro.utils.module_ import get_caller_qualname
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.path_ import dir_tree_from_paths, remove_dir, check_mkdir
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.pickling import decompress, dumps, load_bytes, save, load
from vectorbtpro.utils.search_ import flatten_obj, unflatten_obj
from vectorbtpro.utils.template import CustomTemplate, RepEval, RepFunc
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"AssetCacheManager",
"KnowledgeAsset",
]
asset_cache: tp.Dict[tp.Hashable, "KnowledgeAsset"] = {}
"""Asset cache."""
class AssetCacheManager(Configured):
"""Class for managing knowledge asset cache.
For defaults, see `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = "knowledge"
_specializable: tp.ClassVar[bool] = False
_extendable: tp.ClassVar[bool] = False
def __init__(
self,
persist_cache: tp.Optional[bool] = None,
cache_dir: tp.Optional[tp.PathLike] = None,
cache_mkdir_kwargs: tp.KwargsLike = None,
clear_cache: tp.Optional[bool] = None,
max_cache_count: tp.Optional[int] = None,
save_cache_kwargs: tp.KwargsLike = None,
load_cache_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
persist_cache=persist_cache,
cache_dir=cache_dir,
cache_mkdir_kwargs=cache_mkdir_kwargs,
clear_cache=clear_cache,
max_cache_count=max_cache_count,
save_cache_kwargs=save_cache_kwargs,
load_cache_kwargs=load_cache_kwargs,
template_context=template_context,
**kwargs,
)
persist_cache = self.resolve_setting(persist_cache, "cache")
cache_dir = self.resolve_setting(cache_dir, "asset_cache_dir")
cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True)
clear_cache = self.resolve_setting(clear_cache, "clear_cache")
max_cache_count = self.resolve_setting(max_cache_count, "max_cache_count")
save_cache_kwargs = self.resolve_setting(save_cache_kwargs, "save_cache_kwargs", merge=True)
load_cache_kwargs = self.resolve_setting(load_cache_kwargs, "load_cache_kwargs", merge=True)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
if isinstance(cache_dir, CustomTemplate):
asset_cache_dir = cache_dir
cache_dir = self.get_setting("cache_dir")
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
asset_cache_dir = asset_cache_dir.substitute(template_context, eval_id="asset_cache_dir")
cache_dir = asset_cache_dir
cache_dir = Path(cache_dir)
if cache_dir.exists():
if clear_cache:
remove_dir(cache_dir, missing_ok=True, with_contents=True)
check_mkdir(cache_dir, **cache_mkdir_kwargs)
self._persist_cache = persist_cache
self._cache_dir = cache_dir
self._max_cache_count = max_cache_count
self._save_cache_kwargs = save_cache_kwargs
self._load_cache_kwargs = load_cache_kwargs
self._template_context = template_context
@property
def persist_cache(self) -> bool:
"""Whether to persist cache on disk."""
return self._persist_cache
@property
def cache_dir(self) -> tp.Path:
"""Cache directory."""
return self._cache_dir
@property
def max_cache_count(self) -> tp.Optional[int]:
"""Maximum number of assets to be cached.
Keeps only the most recent assets."""
return self._max_cache_count
@property
def save_cache_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.save`."""
return self._save_cache_kwargs
@property
def load_cache_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.load`."""
return self._load_cache_kwargs
@classmethod
def generate_cache_key(cls, **kwargs) -> str:
"""Generate a cache key based on the current VBT version, settings, and keyword arguments."""
from vectorbtpro._version import __version__
bytes_ = b""
bytes_ += dumps(kwargs)
bytes_ += dumps(cls.get_settings())
bytes_ += dumps(__version__)
return hashlib.md5(bytes_).hexdigest()
def load_asset(self, cache_key: str) -> tp.Optional[tp.MaybeKnowledgeAsset]:
"""Load the knowledge asset under a cache key."""
if cache_key in asset_cache:
return asset_cache[cache_key]
asset_cache_file = self.cache_dir / cache_key
if asset_cache_file.exists():
return load(asset_cache_file, **self.load_cache_kwargs)
def cleanup_cache_dir(self) -> None:
"""Keep only the most recent assets."""
if not self.max_cache_count:
return
files = [f for f in self.cache_dir.iterdir() if f.is_file()]
if len(files) <= self.max_cache_count:
return
files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
files_to_delete = files[self.max_cache_count:]
for file_path in files_to_delete:
file_path.unlink(missing_ok=True)
def save_asset(self, asset: tp.MaybeKnowledgeAsset, cache_key: str) -> tp.Optional[tp.Path]:
"""Save a knowledge asset under a cache key."""
asset_cache[cache_key] = asset
if self.persist_cache:
asset_cache_file = self.cache_dir / cache_key
path = save(asset, path=asset_cache_file, **self.save_cache_kwargs)
self.cleanup_cache_dir()
return path
KnowledgeAssetT = tp.TypeVar("KnowledgeAssetT", bound="KnowledgeAsset")
class MetaKnowledgeAsset(type(Configured), type(MutableSequence)):
"""Metaclass for `KnowledgeAsset`."""
pass
class KnowledgeAsset(RankContextable, Configured, MutableSequence, metaclass=MetaKnowledgeAsset):
"""Class for working with a knowledge asset.
This class behaves like a mutable sequence.
For defaults, see `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = "knowledge"
@hybrid_method
def combine(
cls_or_self: tp.MaybeType[KnowledgeAssetT],
*objs: tp.MaybeTuple[KnowledgeAssetT],
**kwargs,
) -> KnowledgeAssetT:
"""Combine multiple `KnowledgeAsset` instances into one.
Usage:
```pycon
>>> asset1 = asset[[0, 1]]
>>> asset2 = asset[[2, 3]]
>>> asset1.combine(asset2).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}]
```
"""
if not isinstance(cls_or_self, type) and len(objs) == 0:
if isinstance(cls_or_self[0], list):
return cls_or_self.merge_lists(**kwargs)
if isinstance(cls_or_self[0], dict):
return cls_or_self.merge_dicts(**kwargs)
raise ValueError("Cannot determine type of data items. Use merge_lists or merge_dicts.")
elif not isinstance(cls_or_self, type) and len(objs) > 0:
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, KnowledgeAsset):
raise TypeError("Each object to be combined must be an instance of KnowledgeAsset")
new_data = []
new_single_item = True
for obj in objs:
new_data.extend(obj.data)
if not obj.single_item:
new_single_item = False
kwargs = cls_or_self.resolve_merge_kwargs(
*[obj.config for obj in objs],
single_item=new_single_item,
data=new_data,
**kwargs,
)
return cls(**kwargs)
@hybrid_method
def merge(
cls_or_self: tp.MaybeType[KnowledgeAssetT],
*objs: tp.MaybeTuple[KnowledgeAssetT],
flatten_kwargs: tp.KwargsLike = None,
**kwargs,
) -> KnowledgeAssetT:
"""Either merge multiple `KnowledgeAsset` instances into one if called as a class method or instance
method with at least one additional object, or merge data items of a single instance if called
as an instance method with no additional objects.
Usage:
```pycon
>>> asset1 = asset.select(["s"])
>>> asset2 = asset.select(["b", "d2"])
>>> asset1.merge(asset2).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}}]
```
"""
if not isinstance(cls_or_self, type) and len(objs) == 0:
if isinstance(cls_or_self[0], list):
return cls_or_self.merge_lists(**kwargs)
if isinstance(cls_or_self[0], dict):
return cls_or_self.merge_dicts(**kwargs)
raise ValueError("Cannot determine type of data items. Use merge_lists or merge_dicts.")
elif not isinstance(cls_or_self, type) and len(objs) > 0:
objs = (cls_or_self, *objs)
cls = type(cls_or_self)
else:
cls = cls_or_self
if len(objs) == 1:
objs = objs[0]
objs = list(objs)
for obj in objs:
if not checks.is_instance_of(obj, KnowledgeAsset):
raise TypeError("Each object to be merged must be an instance of KnowledgeAsset")
if flatten_kwargs is None:
flatten_kwargs = {}
if "annotate_all" not in flatten_kwargs:
flatten_kwargs["annotate_all"] = True
if "excl_types" not in flatten_kwargs:
flatten_kwargs["excl_types"] = (tuple, set, frozenset)
max_items = 1
new_single_item = True
for obj in objs:
obj_data = obj.data
if len(obj_data) > max_items:
max_items = len(obj_data)
if not obj.single_item:
new_single_item = False
flat_data = []
for obj in objs:
obj_data = obj.data
if len(obj_data) == 1:
                obj_data = obj_data * max_items
flat_obj_data = list(map(lambda x: flatten_obj(x, **flatten_kwargs), obj_data))
flat_data.append(flat_obj_data)
new_data = []
for flat_dcts in zip(*flat_data):
merged_flat_dct = flat_merge_dicts(*flat_dcts)
new_data.append(unflatten_obj(merged_flat_dct))
kwargs = cls_or_self.resolve_merge_kwargs(
*[obj.config for obj in objs],
single_item=new_single_item,
data=new_data,
**kwargs,
)
return cls(**kwargs)
@classmethod
def from_json_file(
cls: tp.Type[KnowledgeAssetT],
path: tp.PathLike,
compression: tp.Union[None, bool, str] = None,
decompress_kwargs: tp.KwargsLike = None,
**kwargs,
) -> KnowledgeAssetT:
"""Build `KnowledgeAsset` from a JSON file."""
bytes_ = load_bytes(path, compression=compression, decompress_kwargs=decompress_kwargs)
json_str = bytes_.decode("utf-8")
return cls(data=json.loads(json_str), **kwargs)
@classmethod
def from_json_bytes(
cls: tp.Type[KnowledgeAssetT],
bytes_: bytes,
compression: tp.Union[None, bool, str] = None,
decompress_kwargs: tp.KwargsLike = None,
**kwargs,
) -> KnowledgeAssetT:
"""Build `KnowledgeAsset` from JSON bytes."""
if decompress_kwargs is None:
decompress_kwargs = {}
bytes_ = decompress(bytes_, compression=compression, **decompress_kwargs)
json_str = bytes_.decode("utf-8")
return cls(data=json.loads(json_str), **kwargs)
def __init__(self, data: tp.Optional[tp.List[tp.Any]] = None, single_item: bool = True, **kwargs) -> None:
if data is None:
data = []
if not isinstance(data, list):
data = [data]
else:
data = list(data)
if len(data) > 1:
single_item = False
Configured.__init__(
self,
data=data,
single_item=single_item,
**kwargs,
)
self._data = data
self._single_item = single_item
@property
def data(self) -> tp.List[tp.Any]:
"""Data."""
return self._data
@property
def single_item(self) -> bool:
"""Whether this instance holds a single item."""
return self._single_item
def modify_data(self, data: tp.List[tp.Any]) -> None:
"""Modify data in place."""
if len(data) > 1:
single_item = False
else:
single_item = self.single_item
self._data = data
self._single_item = single_item
self.update_config(data=data, single_item=single_item)
# ############# Item methods ############# #
def get_items(self, index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]]) -> tp.Any:
"""Get one or more data items."""
if checks.is_complex_iterable(index):
if all(checks.is_bool(i) for i in index):
index = list(index)
if len(index) != len(self.data):
raise IndexError("Boolean index must have the same length as data")
return self.replace(data=[item for item, flag in zip(self.data, index) if flag])
if all(checks.is_int(i) for i in index):
return self.replace(data=[self.data[i] for i in index])
raise TypeError("Index must contain all integers or all booleans")
if isinstance(index, slice):
return self.replace(data=self.data[index])
return self.data[index]
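    # Indexing sketch (hedged): an integer returns the raw item, while slices, integer lists,
    # and boolean masks return a new asset.
    #
    # >>> asset = KnowledgeAsset([{"a": 1}, {"a": 2}, {"a": 3}])
    # >>> asset.get_items(0)
    # {'a': 1}
    # >>> asset.get_items([0, 2]).get()
    # [{'a': 1}, {'a': 3}]
    # >>> asset.get_items([True, False, True]).get()
    # [{'a': 1}, {'a': 3}]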
def set_items(
self: KnowledgeAssetT,
index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]],
value: tp.Any,
inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
"""Set one or more data items.
Returns a new `KnowledgeAsset` instance if `inplace` is False."""
new_data = list(self.data)
if checks.is_complex_iterable(index):
index = list(index)
if all(checks.is_bool(i) for i in index):
if len(index) != len(new_data):
raise IndexError("Boolean index must have the same length as data")
if checks.is_complex_iterable(value):
value = list(value)
if len(value) == len(index):
for i, (b, v) in enumerate(zip(index, value)):
if b:
new_data[i] = v
else:
num_true = sum(index)
if len(value) != num_true:
raise ValueError(f"Attempting to assign {len(value)} values to {num_true} targets")
it = iter(value)
for i, b in enumerate(index):
if b:
new_data[i] = next(it)
else:
for i, b in enumerate(index):
if b:
new_data[i] = value
elif all(checks.is_int(i) for i in index):
if checks.is_complex_iterable(value):
value = list(value)
if len(value) != len(index):
raise ValueError(f"Attempting to assign {len(value)} values to {len(index)} targets")
for i, v in zip(index, value):
new_data[i] = v
else:
for i in index:
new_data[i] = value
else:
raise TypeError("Index must contain all integers or all booleans")
else:
new_data[index] = value
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def delete_items(
self: KnowledgeAssetT,
index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]],
inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
"""Delete one or more data items.
Returns a new `KnowledgeAsset` instance if `inplace` is False."""
new_data = list(self.data)
        if checks.is_complex_iterable(index):
            index = list(index)
            if all(checks.is_bool(i) for i in index):
if len(index) != len(new_data):
raise IndexError("Boolean index must have the same length as data")
new_data = [item for item, flag in zip(new_data, index) if not flag]
elif all(checks.is_int(i) for i in index):
indices_to_remove = set(index)
max_index = len(new_data) - 1
for i in indices_to_remove:
if not -len(new_data) <= i <= max_index:
raise IndexError(f"Index {i} out of range")
new_data = [item for i, item in enumerate(new_data) if i not in indices_to_remove]
else:
raise TypeError("Index must contain all integers or all booleans")
else:
del new_data[index]
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
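    # Assignment and deletion sketch (hedged; plain integers are used as items for brevity):
    # boolean masks must match the data length, and integer indices accept either a single
    # value or one value per index.
    #
    # >>> asset = KnowledgeAsset([1, 2, 3])
    # >>> asset.set_items([0, 2], 0).get()
    # [0, 2, 0]
    # >>> asset.delete_items([False, True, False]).get()
    # [1, 3]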
def append_item(
self: KnowledgeAssetT,
d: tp.Any,
inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
"""Append a new data item.
Returns a new `KnowledgeAsset` instance if `inplace` is False."""
new_data = list(self.data)
new_data.append(d)
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def extend_items(
self: KnowledgeAssetT,
data: tp.Iterable[tp.Any],
inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
"""Extend by new data items.
Returns a new `KnowledgeAsset` instance if `inplace` is False."""
new_data = list(self.data)
new_data.extend(data)
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
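    # Append/extend sketch (hedged): both return a new asset unless `inplace=True`.
    #
    # >>> asset = KnowledgeAsset([{"a": 1}])
    # >>> asset.append_item({"a": 2}).get()
    # [{'a': 1}, {'a': 2}]
    # >>> asset.extend_items([{"a": 2}, {"a": 3}], inplace=True)
    # >>> len(asset)
    # 3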
def remove_empty(self, inplace: bool = False) -> tp.Optional[KnowledgeAssetT]:
"""Remove empty data items."""
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc
new_data = [d for d in self.data if not FindRemoveAssetFunc.is_empty_func(d)]
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def unique(
self: KnowledgeAssetT,
*args,
keep: str = "first",
inplace: bool = False,
**kwargs,
) -> tp.Optional[KnowledgeAssetT]:
"""De-duplicate based on `KnowledgeAsset.get` called on `*args` and `**kwargs`.
Returns a new `KnowledgeAsset` instance if `inplace` is False.
Usage:
```pycon
>>> asset.unique("b").get()
[{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
```
"""
keys = self.get(*args, **kwargs)
if keep.lower() == "first":
seen = set()
new_data = []
for key, item in zip(keys, self.data):
if key not in seen:
seen.add(key)
new_data.append(item)
elif keep.lower() == "last":
seen = set()
new_data_reversed = []
for key, item in zip(reversed(keys), reversed(self.data)):
if key not in seen:
seen.add(key)
new_data_reversed.append(item)
new_data = list(reversed(new_data_reversed))
else:
raise ValueError(f"Invalid keep option: '{keep}'")
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def sort(
self: KnowledgeAssetT,
*args,
keys: tp.Optional[tp.Iterable[tp.Key]] = None,
ascending: bool = True,
inplace: bool = False,
**kwargs,
) -> tp.Optional[KnowledgeAssetT]:
"""Sort based on `KnowledgeAsset.get` called on `*args` and `**kwargs`.
Returns a new `KnowledgeAsset` instance if `inplace` is False.
Usage:
```pycon
>>> asset.sort("d2.c").get()
[{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}]
```
"""
if keys is None:
keys = self.get(*args, **kwargs)
new_data = [x for _, x in sorted(zip(keys, self.data), key=lambda x: x[0], reverse=not ascending)]
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def shuffle(
self: KnowledgeAssetT,
seed: tp.Optional[int] = None,
inplace: bool = False,
) -> tp.Optional[KnowledgeAssetT]:
"""Shuffle data items."""
import random
if seed is not None:
random.seed(seed)
new_data = list(self.data)
random.shuffle(new_data)
if inplace:
self.modify_data(new_data)
return None
return self.replace(data=new_data)
def sample(
self,
k: tp.Optional[int] = None,
seed: tp.Optional[int] = None,
wrap: bool = True,
) -> tp.Any:
"""Pick a random sample of data items."""
import random
if k is None:
k = 1
single_item = True
else:
single_item = False
if seed is not None:
random.seed(seed)
new_data = random.sample(self.data, min(len(self.data), k))
if wrap:
return self.replace(data=new_data, single_item=single_item)
if single_item:
return new_data[0]
return new_data
def print_sample(self, k: tp.Optional[int] = None, seed: tp.Optional[int] = None, **kwargs) -> None:
"""Print a random sample.
Keyword arguments are passed to `KnowledgeAsset.print`."""
self.sample(k=k, seed=seed).print(**kwargs)
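    # Sampling sketch (hedged): a fixed seed makes both methods deterministic, and `wrap=False`
    # returns raw items instead of a new asset.
    #
    # >>> asset = KnowledgeAsset([1, 2, 3, 4, 5])
    # >>> shuffled = asset.shuffle(seed=42)       # new asset with the same items, reordered
    # >>> asset.sample(k=2, seed=42, wrap=False)  # list of 2 randomly picked items
    # >>> asset.sample(seed=42, wrap=False)       # single random item (k defaults to 1)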
# ############# Collection methods ############# #
def __len__(self) -> int:
return len(self.data)
# ############# Sequence methods ############# #
def __getitem__(self, index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]]) -> tp.Any:
return self.get_items(index)
# ############# MutableSequence methods ############# #
def insert(self, index: int, value: tp.Any) -> None:
new_data = list(self.data)
new_data.insert(index, value)
self.modify_data(new_data)
def __setitem__(self, index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]], value: tp.Any) -> None:
self.set_items(index, value, inplace=True)
def __delitem__(self, index: tp.Union[int, slice, tp.Iterable[tp.Union[bool, int]]]) -> None:
self.delete_items(index, inplace=True)
def __add__(self: KnowledgeAssetT, other: tp.Any) -> KnowledgeAssetT:
if not isinstance(other, KnowledgeAsset):
other = KnowledgeAsset(other)
mro_self = self.__class__.mro()
mro_other = other.__class__.mro()
common_bases = set(mro_self).intersection(mro_other)
for cls in mro_self:
if cls in common_bases:
new_type = cls
break
else:
new_type = KnowledgeAsset
return new_type.combine(self, other)
def __iadd__(self: KnowledgeAssetT, other: tp.Any) -> KnowledgeAssetT:
if isinstance(other, KnowledgeAsset):
other = other.data
self.extend_items(other, inplace=True)
return self
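    # Sequence-protocol sketch (hedged): the asset behaves like a mutable sequence; `+` combines
    # two assets via `combine`, while `+=` extends the data in place.
    #
    # >>> asset = KnowledgeAsset([1, 2, 3])
    # >>> len(asset), asset[0], asset[1:].get()
    # (3, 1, [2, 3])
    # >>> asset += [4]
    # >>> len(asset)
    # 4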
# ############# Apply methods ############# #
def apply(
self,
func: tp.MaybeList[tp.Union[tp.AssetFuncLike, tp.AssetPipeline]],
*args,
execute_kwargs: tp.KwargsLike = None,
wrap: tp.Optional[bool] = None,
single_item: tp.Optional[bool] = None,
return_iterator: bool = False,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Apply a function to each data item.
Function can be either a callable, a tuple of function and its arguments,
a `vectorbtpro.utils.execution.Task` instance, a subclass of
`vectorbtpro.utils.knowledge.base_asset_funcs.AssetFunc` or its prefix or full name.
Moreover, function can be a list of the above. In such a case, `BasicAssetPipeline` will be used.
If function is a valid expression, `ComplexAssetPipeline` will be used.
Uses `vectorbtpro.utils.execution.execute` for execution.
If `wrap` is True, returns a new `KnowledgeAsset` instance, otherwise raw output.
Usage:
```pycon
>>> asset.apply(["flatten", ("query", len)])
[5, 5, 5, 5, 6]
>>> asset.apply("query(flatten(d), len)")
[5, 5, 5, 5, 6]
```
"""
from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline, BasicAssetPipeline, ComplexAssetPipeline
execute_kwargs = self.resolve_setting(execute_kwargs, "execute_kwargs", merge=True)
asset_func_meta = {}
if isinstance(func, list):
func, args, kwargs = (
BasicAssetPipeline(
func,
*args,
cond_kwargs=dict(asset_cls=type(self)),
asset_func_meta=asset_func_meta,
**kwargs,
),
(),
{},
)
elif isinstance(func, str) and not func.isidentifier():
if len(args) > 0:
raise ValueError("No more positional arguments can be applied to ComplexAssetPipeline")
func, args, kwargs = (
ComplexAssetPipeline(
func,
context=kwargs.get("template_context", None),
cond_kwargs=dict(asset_cls=type(self)),
asset_func_meta=asset_func_meta,
**kwargs,
),
(),
{},
)
elif not isinstance(func, AssetPipeline):
func, args, kwargs = AssetPipeline.resolve_task(
func,
*args,
cond_kwargs=dict(asset_cls=type(self)),
asset_func_meta=asset_func_meta,
**kwargs,
)
else:
if len(args) > 0:
raise ValueError("No more positional arguments can be applied to AssetPipeline")
if len(kwargs) > 0:
raise ValueError("No more keyword arguments can be applied to AssetPipeline")
prefix = get_caller_qualname().split(".")[-1]
if "_short_name" in asset_func_meta:
prefix += f"[{asset_func_meta['_short_name']}]"
elif isinstance(func, type):
prefix += f"[{func.__name__}]"
else:
prefix += f"[{type(func).__name__}]"
execute_kwargs = merge_dicts(
dict(
show_progress=False if self.single_item else None,
pbar_kwargs=dict(
bar_id=get_caller_qualname(),
prefix=prefix,
),
),
execute_kwargs,
)
def _get_task_generator():
for i, d in enumerate(self.data):
_kwargs = dict(kwargs)
if "template_context" in _kwargs:
_kwargs["template_context"] = flat_merge_dicts(
{"i": i},
_kwargs["template_context"],
)
if return_iterator:
yield Task(func, d, *args, **_kwargs).execute()
else:
yield Task(func, d, *args, **_kwargs)
tasks = _get_task_generator()
if return_iterator:
return tasks
new_data = execute(tasks, size=len(self.data), **execute_kwargs)
if new_data is NoResult:
new_data = []
if wrap is None and asset_func_meta.get("_wrap", None) is not None:
wrap = asset_func_meta["_wrap"]
if wrap is None:
wrap = True
if single_item is None:
single_item = self.single_item
if wrap:
return self.replace(data=new_data, single_item=single_item)
if single_item:
if len(new_data) == 1:
return new_data[0]
if len(new_data) == 0:
return None
return new_data
def get(
self: KnowledgeAssetT,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
source: tp.Optional[tp.CustomTemplateLike] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Get data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.GetAssetFunc`.
        Use argument `path` to specify what part of the data item should be retrieved. For example, "x.y[0].z"
to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented
as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically
becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True
and path is missing in the data item, will skip the data item.
Use argument `source` instead of `path` or in addition to `path` to also preprocess the source.
It can be a string or function (will become a template), or any custom template. In this template,
the index of the data item is represented by "i", the data item itself is represented by "d",
the data item under the path is represented by "x" while its fields are represented by their names.
Usage:
```pycon
>>> asset.get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.get("d2.l[0]")
[1, 3, 5, 7, 9]
>>> asset.get("d2.l", source=lambda x: sum(x))
[3, 7, 11, 15, 19]
>>> asset.get("d2.l[0]", keep_path=True)
[{'d2': {'l': {0: 1}}},
{'d2': {'l': {0: 3}}},
{'d2': {'l': {0: 5}}},
{'d2': {'l': {0: 7}}},
{'d2': {'l': {0: 9}}}]
>>> asset.get(["d2.l[0]", "d2.l[1]"])
[{'d2': {'l': {0: 1, 1: 2}}},
{'d2': {'l': {0: 3, 1: 4}}},
{'d2': {'l': {0: 5, 1: 6}}},
{'d2': {'l': {0: 7, 1: 8}}},
{'d2': {'l': {0: 9, 1: 10}}}]
>>> asset.get("xyz", skip_missing=True)
[123]
```
"""
if path is None and source is None:
if self.single_item:
if len(self.data) == 1:
return self.data[0]
if len(self.data) == 0:
return None
return self.data
return self.apply(
"get",
path=path,
keep_path=keep_path,
skip_missing=skip_missing,
source=source,
template_context=template_context,
**kwargs,
)
def select(self: KnowledgeAssetT, *args, **kwargs) -> KnowledgeAssetT:
"""Call `KnowledgeAsset.get` and return a new `KnowledgeAsset` instance."""
return self.get(*args, wrap=True, **kwargs)
def set(
self: KnowledgeAssetT,
value: tp.Any,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Set data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.SetAssetFunc`.
Argument `value` can be any value, function (will become a template), or a template. In this template,
the index of the data item is represented by "i", the data item itself is represented by "d",
the data item under the path is represented by "x" while its fields are represented by their names.
Use argument `path` to specify what part of the data item should be set. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Usage:
```pycon
>>> asset.set(lambda d: sum(d["d2"]["l"])).get()
[3, 7, 11, 15, 19]
>>> asset.set(lambda d: sum(d["d2"]["l"]), path="d2.sum").get()
>>> asset.set(lambda x: sum(x["l"]), path="d2.sum").get()
>>> asset.set(lambda l: sum(l), path="d2.sum").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2], 'sum': 3}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4], 'sum': 7}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6], 'sum': 11}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8], 'sum': 15}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10], 'sum': 19}, 'xyz': 123}]
>>> asset.set(lambda l: sum(l), path="d2.l").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': 3}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': 7}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': 11}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': 15}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': 19}, 'xyz': 123}]
```
"""
return self.apply(
"set",
value=value,
path=path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
template_context=template_context,
**kwargs,
)
def remove(
self: KnowledgeAssetT,
path: tp.MaybeList[tp.PathLikeKey],
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Remove data items or parts of them.
If `path` is an integer, removes the entire data item at that index.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.RemoveAssetFunc`.
        Use argument `path` to specify what part of the data item should be removed. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Usage:
```pycon
>>> asset.remove("d2.l[0]").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [10]}, 'xyz': 123}]
>>> asset.remove("xyz", skip_missing=True).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}}]
```
"""
return self.apply(
"remove",
path=path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def move(
self: KnowledgeAssetT,
path: tp.Union[tp.PathMoveDict, tp.MaybeList[tp.PathLikeKey]],
new_path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Move data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.MoveAssetFunc`.
        Use argument `path` to specify what part of the data item should be moved. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
        Use argument `new_path` to specify the path that the part of the data item should be moved to.
        Multiple new paths can be provided. If None, `path` must be a dictionary mapping old paths to new paths.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Usage:
```pycon
>>> asset.move("d2.l", "l").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red'}, 'l': [1, 2]},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue'}, 'l': [3, 4]},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green'}, 'l': [5, 6]},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow'}, 'l': [7, 8]},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black'}, 'xyz': 123, 'l': [9, 10]}]
>>> asset.move({"d2.c": "c", "b": "d2.b"}).get()
>>> asset.move(["d2.c", "b"], ["c", "d2.b"]).get()
[{'s': 'ABC', 'd2': {'l': [1, 2], 'b': True}, 'c': 'red'},
{'s': 'BCD', 'd2': {'l': [3, 4], 'b': True}, 'c': 'blue'},
{'s': 'CDE', 'd2': {'l': [5, 6], 'b': False}, 'c': 'green'},
{'s': 'DEF', 'd2': {'l': [7, 8], 'b': False}, 'c': 'yellow'},
{'s': 'EFG', 'd2': {'l': [9, 10], 'b': False}, 'xyz': 123, 'c': 'black'}]
```
"""
return self.apply(
"move",
path=path,
new_path=new_path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def rename(
self: KnowledgeAssetT,
path: tp.Union[tp.PathRenameDict, tp.MaybeList[tp.PathLikeKey]],
new_token: tp.Optional[tp.MaybeList[tp.PathKeyToken]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Rename data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.RenameAssetFunc`.
Same as `KnowledgeAsset.move` but must specify new token instead of new path.
Usage:
```pycon
>>> asset.rename("d2.l", "x").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'x': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'x': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'x': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'x': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'x': [9, 10]}, 'xyz': 123}]
>>> asset.rename("xyz", "zyx", skip_missing=True, changed_only=True).get()
[{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'zyx': 123}]
```
"""
return self.apply(
"rename",
path=path,
new_token=new_token,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def reorder(
self: KnowledgeAssetT,
new_order: tp.Union[str, tp.PathKeyTokens],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Reorder data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.ReorderAssetFunc`.
Can change order in dicts based on `vectorbtpro.utils.config.reorder_dict` and
sequences based on `vectorbtpro.utils.config.reorder_list`.
Argument `new_order` can be a sequence of tokens. To not reorder a subset of keys, they can
be replaced by an ellipsis (`...`). For example, `["a", ..., "z"]` puts the token "a" at the start
and the token "z" at the end while other tokens are left in the original order. If `new_order` is
a string, it can be "asc"/"ascending" or "desc"/"descending". Other than that, it can be a string or
        function (will become a template), or any custom template. In this template, the index of the
        data item is represented by "i", the data item itself is represented by "d",
the data item under the path is represented by "x" while its fields are represented by their names.
        Use argument `path` to specify what part of the data item should be reordered. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Usage:
```pycon
>>> asset.reorder(["xyz", ...], skip_missing=True).get()
>>> asset.reorder(lambda x: ["xyz", ...] if "xyz" in x else [...]).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'xyz': 123, 's': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}}]
>>> asset.reorder("descending", path="d2.l").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [2, 1]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [4, 3]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [6, 5]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [8, 7]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [10, 9]}, 'xyz': 123}]
```
"""
return self.apply(
"reorder",
new_order=new_order,
path=path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
template_context=template_context,
**kwargs,
)
def query(
self: KnowledgeAssetT,
expression: tp.CustomTemplateLike,
query_engine: tp.Optional[str] = None,
template_context: tp.KwargsLike = None,
return_type: tp.Optional[str] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Query using an engine and return the queried data item(s).
Following engines are supported:
* "jmespath": Evaluation with `jmespath` package
* "jsonpath", "jsonpath-ng" or "jsonpath_ng": Evaluation with `jsonpath-ng` package
* "jsonpath.ext", "jsonpath-ng.ext" or "jsonpath_ng.ext": Evaluation with extended `jsonpath-ng` package
* None or "template": Evaluation of each data item as a template. The index of the data item is
represented by "i", the data item itself is represented by "d", the data item under the path
is represented by "x" while its fields are represented by their names.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.QueryAssetFunc`.
* "pandas": Same as above but variables being columns
If `return_type` is "item", returns the data item when matched. If `return_type` is "bool",
returns True when matched.
Templates can also use the functions defined in `vectorbtpro.utils.search_.search_config`.
They work on single values and sequences alike.
Keyword arguments are passed to the respective search/parse/evaluation function.
Usage:
```pycon
>>> asset.query("d['s'] == 'ABC'")
>>> asset.query("x['s'] == 'ABC'")
>>> asset.query("s == 'ABC'")
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}]
>>> asset.query("x['s'] == 'ABC'", return_type="bool")
[True, False, False, False, False]
>>> asset.query("find('BC', s)")
>>> asset.query(lambda s: "BC" in s)
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
>>> asset.query("[?contains(s, 'BC')].s", query_engine="jmespath")
['ABC', 'BCD']
>>> asset.query("[].d2.c", query_engine="jmespath")
['red', 'blue', 'green', 'yellow', 'black']
>>> asset.query("[?d2.c != `blue`].d2.l", query_engine="jmespath")
[[1, 2], [5, 6], [7, 8], [9, 10]]
>>> asset.query("$[*].d2.c", query_engine="jsonpath.ext")
['red', 'blue', 'green', 'yellow', 'black']
>>> asset.query("$[?(@.b == true)].s", query_engine="jsonpath.ext")
['ABC', 'BCD']
>>> asset.query("s[b]", query_engine="pandas")
['ABC', 'BCD']
```
"""
query_engine = self.resolve_setting(query_engine, "query_engine")
template_context = self.resolve_setting(template_context, "template_context", merge=True)
return_type = self.resolve_setting(return_type, "return_type")
if query_engine is None or query_engine.lower() == "template":
new_obj = self.apply(
"query",
expression=expression,
template_context=template_context,
return_type=return_type,
**kwargs,
)
elif query_engine.lower() == "jmespath":
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("jmespath")
import jmespath
new_obj = jmespath.search(expression, self.data, **kwargs)
elif query_engine.lower() in ("jsonpath", "jsonpath-ng", "jsonpath_ng"):
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("jsonpath_ng")
import jsonpath_ng
jsonpath_expr = jsonpath_ng.parse(expression)
new_obj = [match.value for match in jsonpath_expr.find(self.data, **kwargs)]
elif query_engine.lower() in ("jsonpath.ext", "jsonpath-ng.ext", "jsonpath_ng.ext"):
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("jsonpath_ng")
import jsonpath_ng.ext
jsonpath_expr = jsonpath_ng.ext.parse(expression)
new_obj = [match.value for match in jsonpath_expr.find(self.data, **kwargs)]
elif query_engine.lower() == "pandas":
if isinstance(expression, str):
expression = RepEval(expression)
elif checks.is_function(expression):
if checks.is_builtin_func(expression):
expression = RepFunc(lambda _expression=expression: _expression)
else:
expression = RepFunc(expression)
elif not isinstance(expression, CustomTemplate):
raise TypeError(f"Expression must be a template")
df = pd.DataFrame.from_records(self.data)
_template_context = flat_merge_dicts(
{
"d": df,
"x": df,
**df.to_dict(orient="series"),
},
template_context,
)
result = expression.substitute(_template_context, eval_id="expression", **kwargs)
if checks.is_function(result):
result = result(df)
if return_type.lower() == "item":
as_filter = True
elif return_type.lower() == "bool":
as_filter = False
else:
raise ValueError(f"Invalid return type: '{return_type}'")
if as_filter and isinstance(result, pd.Series) and result.dtype == "bool":
result = df[result]
if isinstance(result, pd.Series):
new_obj = result.tolist()
elif isinstance(result, pd.DataFrame):
new_obj = result.to_dict(orient="records")
else:
new_obj = result
else:
raise ValueError(f"Invalid query engine: '{query_engine}'")
return new_obj
def filter(self: KnowledgeAssetT, *args, **kwargs) -> KnowledgeAssetT:
"""Call `KnowledgeAsset.query` and return a new `KnowledgeAsset` instance."""
return self.query(*args, wrap=True, **kwargs)
def find(
self: KnowledgeAssetT,
target: tp.MaybeList[tp.Any],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: tp.Optional[bool] = None,
find_all: tp.Optional[bool] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
source: tp.Optional[tp.CustomTemplateLike] = None,
in_dumps: tp.Optional[bool] = None,
dump_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
return_type: tp.Optional[str] = None,
return_path: tp.Optional[bool] = None,
merge_matches: tp.Optional[bool] = None,
merge_fields: tp.Optional[bool] = None,
unique_matches: tp.Optional[bool] = None,
unique_fields: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Find occurrences and return a new `KnowledgeAsset` instance.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FindAssetFunc`.
Uses `vectorbtpro.utils.search_.contains_in_obj` (keyword arguments are passed here)
to find any occurrences in each data item if `return_type` is "item" (returns the data item when matched),
`return_type` is "field" (returns the field), or `return_type` is "bool" (returns True when matched).
For all other return types, uses `vectorbtpro.utils.search_.find_in_obj` and `vectorbtpro.utils.search_.find`.
Target can be one or multiple data items. If there are multiple targets and `find_all` is True,
the match function will return True only if all targets have been found.
Use argument `path` to specify what part of the data item should be searched. For example, "x.y[0].z"
to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented
as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically
becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True
and path is missing in the data item, will skip the data item. If `per_path` is True, will consider
targets to be provided per path.
Use argument `source` instead of `path` or in addition to `path` to also preprocess the source.
It can be a string or function (will become a template), or any custom template. In this template,
the index of the data item is represented by "i", the data item itself is represented by "d",
the data item under the path is represented by "x" while its fields are represented by their names.
Set `in_dumps` to True to convert the entire data item to string and search in that string.
Will use `vectorbtpro.utils.formatting.dump` with `**dump_kwargs`.
Disable `merge_matches` and `merge_fields` to keep empty lists when searching for matches and
fields respectively. Disable `unique_matches` and `unique_fields` to keep duplicate matches
and fields respectively.
Usage:
```pycon
>>> asset.find("BC").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
>>> asset.find("BC", return_type="bool").get()
[True, True, False, False, False]
>>> asset.find(vbt.Not("BC")).get()
[{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.find("bc", ignore_case=True).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
>>> asset.find("bl", path="d2.c").get()
[{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.find(5, path="d2.l[0]").get()
[{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}]
>>> asset.find(True, path="d2.l", source=lambda x: sum(x) >= 10).get()
[{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.find(["A", "B", "C"]).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}]
>>> asset.find(["A", "B", "C"], find_all=True).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}}]
>>> asset.find(r"[ABC]+", mode="regex").get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}}]
>>> asset.find("yenlow", mode="fuzzy").get()
[{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}]
>>> asset.find("yenlow", mode="fuzzy", return_type="match").get()
'yellow'
>>> asset.find("yenlow", mode="fuzzy", return_type="match", merge_matches=False).get()
[[], [], [], ['yellow'], []]
>>> asset.find("yenlow", mode="fuzzy", return_type="match", return_path=True).get()
[{}, {}, {}, {('d2', 'c'): ['yellow']}, {}]
>>> asset.find("xyz", in_dumps=True).get()
[{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
```
"""
found_asset = self.apply(
"find",
target=target,
path=path,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
source=source,
in_dumps=in_dumps,
dump_kwargs=dump_kwargs,
template_context=template_context,
return_type=return_type,
return_path=return_path,
**kwargs,
)
return_type = self.resolve_setting(return_type, "return_type")
merge_matches = self.resolve_setting(merge_matches, "merge_matches")
merge_fields = self.resolve_setting(merge_fields, "merge_fields")
unique_matches = self.resolve_setting(unique_matches, "unique_matches")
unique_fields = self.resolve_setting(unique_fields, "unique_fields")
if (
((merge_matches and return_type.lower() == "match") or (merge_fields and return_type.lower() == "field"))
and isinstance(found_asset, KnowledgeAsset)
and len(found_asset) > 0
and isinstance(found_asset[0], list)
):
found_asset = found_asset.merge()
if (
((unique_matches and return_type.lower() == "match") or (unique_fields and return_type.lower() == "field"))
and isinstance(found_asset, KnowledgeAsset)
and len(found_asset) > 0
and isinstance(found_asset[0], str)
):
found_asset = found_asset.unique()
return found_asset
def find_code(
self,
target: tp.Optional[tp.MaybeIterable[tp.Any]] = None,
language: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
in_blocks: tp.Optional[bool] = None,
escape_target: bool = True,
escape_language: bool = True,
return_type: tp.Optional[str] = "match",
flags: int = 0,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Find code using `KnowledgeAsset.find`.
For defaults, see `code` in `vectorbtpro._settings.knowledge`."""
language = self.resolve_setting(language, "language", sub_path="code")
in_blocks = self.resolve_setting(in_blocks, "in_blocks", sub_path="code")
if target is not None:
if not isinstance(target, (str, list)):
target = list(target)
if language is not None:
if not isinstance(language, (str, list)):
language = list(language)
if escape_language:
if isinstance(language, list):
language = list(map(re.escape, language))
else:
language = re.escape(language)
if isinstance(language, list):
language = rf"(?:{'|'.join(language)})"
opt_language = r"[\w+-]+"
opt_title = r"(?:\s+[^\n`]+)?"
if target is not None:
if not isinstance(target, list):
targets = [target]
single_target = True
else:
targets = target
single_target = False
new_target = []
for t in targets:
if escape_target:
t = re.escape(t)
if in_blocks:
if language is not None and not isinstance(language, bool):
new_t = rf"""
```{language}{opt_title}\n
(?:(?!```)[\s\S])*?
{t}
(?:(?!```)[\s\S])*?
```\s*$
"""
elif language is not None and isinstance(language, bool) and language:
new_t = rf"""
```{opt_language}{opt_title}\n
(?:(?!```)[\s\S])*?
{t}
(?:(?!```)[\s\S])*?
```\s*$
"""
else:
new_t = rf"""
```(?:{opt_language}{opt_title})?\n
(?:(?!```)[\s\S])*?
{t}
(?:(?!```)[\s\S])*?
```\s*$
"""
else:
new_t = rf"(? tp.MaybeKnowledgeAsset:
"""Find and replace occurrences and return a new `KnowledgeAsset` instance.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FindReplaceAssetFunc`.
Uses `vectorbtpro.utils.search_.find_in_obj` (keyword arguments are passed here) to find
occurrences in each data item. Then, uses `vectorbtpro.utils.search_.replace_in_obj` to replace them.
Target can be one or multiple of data items, either as a list or a dictionary. If there are multiple
targets and `find_all` is True, the match function will return True only if all targets have been found.
Use argument `path` to specify what part of the data item should be searched. For example, "x.y[0].z"
to navigate nested dictionaries/lists. If `keep_path` is True, the data item will be represented
as a nested dictionary with path as keys. If multiple paths are provided, `keep_path` automatically
becomes True, and they will be merged into one nested dictionary. If `skip_missing` is True
and path is missing in the data item, will skip the data item. If `per_path` is True, will consider
targets and replacements to be provided per path.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Usage:
```pycon
>>> asset.find_replace("BC", "XY").get()
[{'s': 'AXY', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'XYD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.find_replace("BC", "XY", changed_only=True).get()
[{'s': 'AXY', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'XYD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}}]
>>> asset.find_replace(r"(D)E(F)", r"\1X\2", mode="regex", changed_only=True).get()
[{'s': 'DXF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}}]
>>> asset.find_replace(True, False, changed_only=True).get()
[{'s': 'ABC', 'b': False, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': False, 'd2': {'c': 'blue', 'l': [3, 4]}}]
>>> asset.find_replace(3, 30, path="d2.l", changed_only=True).get()
[{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [30, 4]}}]
>>> asset.find_replace({1: 10, 4: 40}, path="d2.l", changed_only=True).get()
>>> asset.find_replace({1: 10, 4: 40}, path=["d2.l[0]", "d2.l[1]"], changed_only=True).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [10, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 40]}}]
>>> asset.find_replace({1: 10, 4: 40}, find_all=True, changed_only=True).get()
[]
>>> asset.find_replace({1: 10, 2: 20}, find_all=True, changed_only=True).get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [10, 20]}}]
>>> asset.find_replace("a", "X", path=["s", "d2.c"], ignore_case=True, changed_only=True).get()
[{'s': 'XBC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'blXck', 'l': [9, 10]}, 'xyz': 123}]
>>> asset.find_replace(123, 456, path="xyz", skip_missing=True, changed_only=True).get()
[{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 456}]
```
"""
return self.apply(
"find_replace",
target=target,
replacement=replacement,
path=path,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def find_remove(
self: KnowledgeAssetT,
target: tp.Union[dict, tp.MaybeList[tp.Any]],
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
per_path: tp.Optional[bool] = None,
find_all: tp.Optional[bool] = None,
keep_path: tp.Optional[bool] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Find and remove occurrences and return a new `KnowledgeAsset` instance.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FindRemoveAssetFunc`.
Similar to `KnowledgeAsset.find_replace`."""
return self.apply(
"find_remove",
target=target,
path=path,
per_path=per_path,
find_all=find_all,
keep_path=keep_path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def find_remove_empty(self: KnowledgeAssetT, **kwargs) -> tp.MaybeKnowledgeAsset:
"""Find and remove empty objects."""
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc
return self.find_remove(FindRemoveAssetFunc.is_empty_func, **kwargs)
def flatten(
self: KnowledgeAssetT,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Flatten data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.FlattenAssetFunc`.
        Use argument `path` to specify what part of the data item should be flattened. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Keyword arguments are passed to `vectorbtpro.utils.search_.flatten_obj`.
Usage:
```pycon
>>> asset.flatten().get()
[{'s': 'ABC',
'b': True,
('d2', 'c'): 'red',
('d2', 'l', 0): 1,
('d2', 'l', 1): 2},
...
{'s': 'EFG',
'b': False,
('d2', 'c'): 'black',
('d2', 'l', 0): 9,
('d2', 'l', 1): 10,
'xyz': 123}]
```
"""
return self.apply(
"flatten",
path=path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def unflatten(
self: KnowledgeAssetT,
path: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
skip_missing: tp.Optional[bool] = None,
make_copy: tp.Optional[bool] = None,
changed_only: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Unflatten data items or parts of them.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.UnflattenAssetFunc`.
        Use argument `path` to specify what part of the data item should be unflattened. For example, "x.y[0].z"
to navigate nested dictionaries/lists. Multiple paths can be provided. If `skip_missing` is True and
path is missing in the data item, will skip the data item.
Set `make_copy` to True to not modify original data.
Set `changed_only` to True to keep only the data items that have been changed.
Keyword arguments are passed to `vectorbtpro.utils.search_.unflatten_obj`.
Usage:
```pycon
>>> asset.flatten().unflatten().get()
[{'s': 'ABC', 'b': True, 'd2': {'c': 'red', 'l': [1, 2]}},
{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'CDE', 'b': False, 'd2': {'c': 'green', 'l': [5, 6]}},
{'s': 'DEF', 'b': False, 'd2': {'c': 'yellow', 'l': [7, 8]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
```
"""
return self.apply(
"unflatten",
path=path,
skip_missing=skip_missing,
make_copy=make_copy,
changed_only=changed_only,
**kwargs,
)
def dump(
self: KnowledgeAssetT,
source: tp.Optional[tp.CustomTemplateLike] = None,
dump_engine: tp.Optional[str] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Dump data items.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.DumpAssetFunc`.
Following engines are supported:
* "repr": Dumping with `repr`
* "prettify": Dumping with `vectorbtpro.utils.formatting.prettify`
* "nestedtext": Dumping with NestedText (https://pypi.org/project/nestedtext/)
* "yaml": Dumping with YAML
* "toml": Dumping with TOML (https://pypi.org/project/toml/)
* "json": Dumping with JSON
Use argument `source` to also preprocess the source. It can be a string or function
(will become a template), or any custom template. In this template, the index of the data item
is represented by "i", the data item itself is represented by "d" while its fields are
represented by their names.
Keyword arguments are passed to the respective engine.
Usage:
```pycon
>>> print(asset.dump(source="{i: d}", default_flow_style=True).join())
{0: {s: ABC, b: true, d2: {c: red, l: [1, 2]}}}
{1: {s: BCD, b: true, d2: {c: blue, l: [3, 4]}}}
{2: {s: CDE, b: false, d2: {c: green, l: [5, 6]}}}
{3: {s: DEF, b: false, d2: {c: yellow, l: [7, 8]}}}
{4: {s: EFG, b: false, d2: {c: black, l: [9, 10]}, xyz: 123}}
```
"""
return self.apply(
"dump",
source=source,
dump_engine=dump_engine,
template_context=template_context,
**kwargs,
)
def dump_all(
self,
source: tp.Optional[tp.CustomTemplateLike] = None,
dump_engine: tp.Optional[str] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> str:
"""Dump data list as a single data item.
See `KnowledgeAsset.dump` for arguments."""
from vectorbtpro.utils.knowledge.base_asset_funcs import DumpAssetFunc
return DumpAssetFunc.prepare_and_call(
self.data,
source=source,
dump_engine=dump_engine,
template_context=template_context,
**kwargs,
)
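    # Dump sketch (hedged): `dump` serializes each data item separately (the default engine comes
    # from settings; YAML is assumed here), while `dump_all` serializes the whole list at once.
    #
    # >>> asset = KnowledgeAsset([{"a": 1}, {"a": 2}])
    # >>> print(asset.dump().join())  # one dump per item, joined into a single string
    # >>> print(asset.dump_all())     # the entire list dumped as one document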
def to_documents(self, **kwargs) -> tp.MaybeKnowledgeAsset:
"""Convert to documents of type `vectorbtpro.utils.knowledge.chatting.TextDocument`.
Document-related keyword arguments may contain templates. In such templates,
the index of the data item is represented by "i", the data item itself is represented by "d",
the data item under the path is represented by "x" while its fields are represented by their names."""
return self.apply("to_docs", **kwargs)
def split_text(
self,
text_path: tp.Optional[tp.PathLikeKey] = None,
merge_chunks: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Split text.
Uses `KnowledgeAsset.apply` on `vectorbtpro.utils.knowledge.base_asset_funcs.SplitTextAssetFunc`.
Use argument `text_path` to specify a path to the content.
If `merge_chunks` is True, merges all chunks into a single list.
Uses `vectorbtpro.utils.knowledge.chatting.split_text` with `**split_text_kwargs` for text splitting."""
split_asset = self.apply(
"split_text",
text_path=text_path,
**kwargs,
)
merge_chunks = self.resolve_setting(merge_chunks, "merge_chunks")
if (
merge_chunks
and isinstance(split_asset, KnowledgeAsset)
and len(split_asset) > 0
and isinstance(split_asset[0], list)
):
split_asset = split_asset.merge()
return split_asset
# ############# Reduce methods ############# #
@classmethod
def get_keys_and_groups(
cls,
by: tp.List[tp.Any],
uniform_groups: bool = False,
) -> tp.Tuple[tp.List[tp.Any], tp.List[tp.List[int]]]:
"""get keys and groups."""
keys = []
groups = []
if uniform_groups:
for i, item in enumerate(by):
if len(keys) > 0 and (keys[-1] is item or keys[-1] == item):
groups[-1].append(i)
else:
keys.append(item)
groups.append([i])
else:
groups = []
representatives = []
for idx, item in enumerate(by):
found = False
for rep_idx, rep_obj in enumerate(representatives):
if item is rep_obj or item == rep_obj:
groups[rep_idx].append(idx)
found = True
break
if not found:
representatives.append(item)
keys.append(by[idx])
groups.append([idx])
return keys, groups
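    # Grouping sketch (hedged): keys preserve first-seen order and each group holds the
    # positional indices of equal values.
    #
    # >>> KnowledgeAsset.get_keys_and_groups(["x", "y", "x"])
    # (['x', 'y'], [[0, 2], [1]])
    # >>> KnowledgeAsset.get_keys_and_groups(["x", "x", "y", "x"], uniform_groups=True)
    # (['x', 'y', 'x'], [[0, 1], [2], [3]])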
def reduce(
self: KnowledgeAssetT,
func: tp.CustomTemplateLike,
*args,
initializer: tp.Optional[tp.Any] = None,
by: tp.Optional[tp.PathLikeKey] = None,
template_context: tp.KwargsLike = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
wrap: tp.Optional[bool] = None,
return_iterator: bool = False,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Reduce data items.
Function can be a callable, a tuple of function and its arguments,
a `vectorbtpro.utils.execution.Task` instance, a subclass of
`vectorbtpro.utils.knowledge.base_asset_funcs.AssetFunc` or its prefix or full name.
It can also be an expression or a template. In this template, the index of the data item is
represented by "i", the data items themselves are represented by "d1" and "d2" or "x1" and "x2".
If an initializer is provided, the first set of values will be `d1=initializer` and
`d2=self.data[0]`. If not, it will be `d1=self.data[0]` and `d2=self.data[1]`.
If `by` is provided, see `KnowledgeAsset.groupby_reduce`.
If `wrap` is True, returns a new `KnowledgeAsset` instance, otherwise raw output.
Usage:
```pycon
>>> asset.reduce(lambda d1, d2: vbt.merge_dicts(d1, d2))
>>> asset.reduce(vbt.merge_dicts)
>>> asset.reduce("{**d1, **d2}")
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}
>>> asset.reduce("{**d1, **d2}", by="b")
[{'s': 'BCD', 'b': True, 'd2': {'c': 'blue', 'l': [3, 4]}},
{'s': 'EFG', 'b': False, 'd2': {'c': 'black', 'l': [9, 10]}, 'xyz': 123}]
```
"""
if by is not None:
return self.groupby_reduce(
func,
*args,
by=by,
initializer=initializer,
template_context=template_context,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
wrap=wrap,
**kwargs,
)
show_progress = self.resolve_setting(show_progress, "show_progress")
pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
asset_func_meta = {}
if isinstance(func, str) and not func.isidentifier():
func = RepEval(func)
elif not isinstance(func, CustomTemplate):
from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline
func, args, kwargs = AssetPipeline.resolve_task(
func,
*args,
cond_kwargs=dict(asset_cls=type(self)),
asset_func_meta=asset_func_meta,
**kwargs,
)
it = iter(self.data)
if initializer is None and asset_func_meta.get("_initializer", None) is not None:
initializer = asset_func_meta["_initializer"]
if initializer is None:
d1 = next(it)
total = len(self.data) - 1
if total == 0:
raise ValueError("Must provide initializer")
else:
d1 = initializer
total = len(self.data)
def _get_d1_generator(d1):
for i, d2 in enumerate(it):
if isinstance(func, CustomTemplate):
_template_context = flat_merge_dicts(
{
"i": i,
"d1": d1,
"d2": d2,
"x1": d1,
"x2": d2,
},
template_context,
)
_d1 = func.substitute(_template_context, eval_id="func", **kwargs)
if checks.is_function(_d1):
d1 = _d1(d1, d2, *args)
else:
d1 = _d1
else:
_kwargs = dict(kwargs)
if "template_context" in _kwargs:
_kwargs["template_context"] = flat_merge_dicts(
{"i": i},
_kwargs["template_context"],
)
d1 = func(d1, d2, *args, **_kwargs)
yield d1
d1s = _get_d1_generator(d1)
if return_iterator:
return d1s
if show_progress is None:
show_progress = total > 1
prefix = get_caller_qualname().split(".")[-1]
if "_short_name" in asset_func_meta:
prefix += f"[{asset_func_meta['_short_name']}]"
elif isinstance(func, type):
prefix += f"[{func.__name__}]"
else:
prefix += f"[{type(func).__name__}]"
pbar_kwargs = flat_merge_dicts(
dict(
bar_id=get_caller_qualname(),
prefix=prefix,
),
pbar_kwargs,
)
with ProgressBar(total=total, show_progress=show_progress, **pbar_kwargs) as pbar:
for d1 in d1s:
pbar.update()
if wrap is None and asset_func_meta.get("_wrap", None) is not None:
wrap = asset_func_meta["_wrap"]
if wrap is None:
wrap = False
if wrap:
if not isinstance(d1, list):
d1 = [d1]
return self.replace(data=d1, single_item=True)
return d1
def groupby_reduce(
self: KnowledgeAssetT,
func: tp.CustomTemplateLike,
*args,
by: tp.Optional[tp.PathLikeKey] = None,
uniform_groups: tp.Optional[bool] = None,
get_kwargs: tp.KwargsLike = None,
execute_kwargs: tp.KwargsLike = None,
return_group_keys: bool = False,
**kwargs,
) -> tp.Union[KnowledgeAssetT, dict, list]:
"""Group data items by keys and reduce.
If `by` is provided, uses it as `path` in `KnowledgeAsset.get`, groups by unique values,
and runs `KnowledgeAsset.reduce` on each group.
Set `uniform_groups` to True to only group unique values that are located adjacent to each other.
Variable arguments are passed to each call of `KnowledgeAsset.reduce`."""
uniform_groups = self.resolve_setting(uniform_groups, "uniform_groups")
execute_kwargs = self.resolve_setting(execute_kwargs, "execute_kwargs", merge=True)
if get_kwargs is None:
get_kwargs = {}
by = self.get(path=by, **get_kwargs)
keys, groups = self.get_keys_and_groups(by, uniform_groups=uniform_groups)
if len(groups) == 0:
raise ValueError("Groups are empty")
tasks = []
for i, group in enumerate(groups):
group_instance = self.get_items(group)
tasks.append(Task(group_instance.reduce, func, *args, **kwargs))
prefix = get_caller_qualname().split(".")[-1]
execute_kwargs = merge_dicts(
dict(
show_progress=False if len(groups) == 1 else None,
pbar_kwargs=dict(
bar_id=get_caller_qualname(),
prefix=prefix,
),
),
execute_kwargs,
)
results = execute(tasks, size=len(groups), **execute_kwargs)
if return_group_keys:
return dict(zip(keys, results))
if len(results) > 0 and isinstance(results[0], type(self)):
return type(self).combine(results)
return results
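    # Group-reduce sketch (hedged), reusing the example asset from the docstrings above: group
    # by the "b" field and merge each group into one dict, keeping the group keys.
    #
    # >>> asset.groupby_reduce("{**d1, **d2}", by="b", return_group_keys=True)
    # {True: {...}, False: {...}}  # one merged dict per unique value of "b"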
def merge_dicts(self: KnowledgeAssetT, **kwargs) -> tp.MaybeKnowledgeAsset:
"""Merge (dict) date items into a single dict.
Final keyword arguments are passed to `vectorbtpro.utils.config.merge_dicts`."""
return self.reduce("merge_dicts", **kwargs)
def merge_lists(self: KnowledgeAssetT, **kwargs) -> tp.MaybeKnowledgeAsset:
"""Merge (list) date items into a single list."""
return self.reduce("merge_lists", **kwargs)
def collect(
self: KnowledgeAssetT,
sort_keys: tp.Optional[bool] = None,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Collect values of each key in each data item."""
return self.reduce("collect", sort_keys=sort_keys, **kwargs)
@classmethod
    def describe_lengths(cls, lengths: list, **describe_kwargs) -> dict:
"""Describe values representing lengths."""
len_describe_dict = pd.Series(lengths).describe(**describe_kwargs).to_dict()
del len_describe_dict["count"]
del len_describe_dict["std"]
return {"len_" + k: int(v) if k != "mean" else v for k, v in len_describe_dict.items()}
def describe(
self: KnowledgeAssetT,
ignore_empty: tp.Optional[bool] = None,
describe_kwargs: tp.KwargsLike = None,
wrap: bool = False,
**kwargs,
) -> tp.Union[KnowledgeAssetT, dict]:
"""Collect and describe each key in each data item."""
ignore_empty = self.resolve_setting(ignore_empty, "ignore_empty")
describe_kwargs = self.resolve_setting(describe_kwargs, "describe_kwargs", merge=True)
collected = self.collect(**kwargs)
description = {}
for k, v in list(collected.items()):
all_types = []
valid_types = []
valid_x = None
new_v = []
for x in v:
if not ignore_empty or x:
new_v.append(x)
if x is not None:
valid_x = x
if type(x) not in valid_types:
valid_types.append(type(x))
if type(x) not in all_types:
all_types.append(type(x))
v = new_v
description[k] = {}
description[k]["types"] = list(map(lambda x: x.__name__, all_types))
describe_sr = pd.Series(v)
if describe_sr.dtype == object and len(valid_types) == 1 and checks.is_complex_collection(valid_x):
describe_dict = {"count": len(v)}
else:
describe_dict = describe_sr.describe(**describe_kwargs).to_dict()
if "count" in describe_dict:
describe_dict["count"] = int(describe_dict["count"])
if "unique" in describe_dict:
describe_dict["unique"] = int(describe_dict["unique"])
if pd.api.types.is_integer_dtype(describe_sr.dtype):
new_describe_dict = {}
for _k, _v in describe_dict.items():
if _k not in {"mean", "std"}:
new_describe_dict[_k] = int(_v)
else:
new_describe_dict[_k] = _v
describe_dict = new_describe_dict
if "unique" in describe_dict and describe_dict["unique"] == describe_dict["count"]:
del describe_dict["top"]
del describe_dict["freq"]
if "unique" in describe_dict and describe_dict["count"] == 1:
del describe_dict["unique"]
description[k].update(describe_dict)
if len(valid_types) == 1 and checks.is_collection(valid_x):
lengths = [len(_v) for _v in v if _v is not None]
description[k].update(self.describe_lengths(lengths, **describe_kwargs))
if wrap:
return self.replace(data=[description], single_item=True)
return description
def print_schema(self, **kwargs) -> None:
"""Print schema.
Keyword arguments are split between `KnowledgeAsset.describe` and
`vectorbtpro.utils.path_.dir_tree_from_paths`.
Usage:
```pycon
>>> asset.print_schema()
/
├── s [5/5, str]
├── b [2/5, bool]
├── d2 [5/5, dict]
│ ├── c [5/5, str]
│ └── l
│ ├── 0 [5/5, int]
│ └── 1 [5/5, int]
└── xyz [1/5, int]
2 directories, 6 files
```
"""
dir_tree_arg_names = set(get_func_arg_names(dir_tree_from_paths))
dir_tree_kwargs = {k: kwargs.pop(k) for k in list(kwargs.keys()) if k in dir_tree_arg_names}
orig_describe_dict = self.describe(wrap=False, **kwargs)
flat_describe_dict = self.flatten(
skip_missing=True,
make_copy=True,
changed_only=False,
).describe(wrap=False, **kwargs)
describe_dict = flat_merge_dicts(orig_describe_dict, flat_describe_dict)
paths = []
path_names = []
for k, v in describe_dict.items():
if k is None:
k = "."
if not isinstance(k, tuple):
k = (k,)
path = Path(*map(str, k))
path_name = path.name
path_name += " [" + str(v["count"]) + "/" + str(len(self.data))
path_name += ", " + ", ".join(v["types"]) + "]"
path_names.append(path_name)
paths.append(path)
if "root_name" not in dir_tree_kwargs:
dir_tree_kwargs["root_name"] = "/"
if "sort" not in dir_tree_kwargs:
dir_tree_kwargs["sort"] = False
if "path_names" not in dir_tree_kwargs:
dir_tree_kwargs["path_names"] = path_names
if "length_limit" not in dir_tree_kwargs:
dir_tree_kwargs["length_limit"] = None
print(dir_tree_from_paths(paths, **dir_tree_kwargs))
def join(self, separator: tp.Optional[str] = None) -> str:
"""Join the list of string data items."""
if len(self.data) == 0:
return ""
if len(self.data) == 1:
return self.data[0]
if separator is None:
use_empty_separator = True
use_comma_separator = True
for d in self.data:
if not d.endswith(("\n", "\t", " ")):
use_empty_separator = False
if not d.endswith(("}", "]")):
use_comma_separator = False
if not use_empty_separator and not use_comma_separator:
break
if use_empty_separator:
separator = ""
elif use_comma_separator:
separator = ", "
else:
separator = "\n\n"
joined = separator.join(self.data)
if joined.startswith("{") and joined.endswith("}"):
return "[" + joined + "]"
return joined
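    # Editor's note (illustrative sketch, not part of the original source): with
    # `data=['{"a": 1}', '{"b": 2}']` every item ends with "}", so the items are joined with
    # ", " and the result is wrapped in brackets, producing '[{"a": 1}, {"b": 2}]'; items that
    # end with neither whitespace nor "}"/"]" fall back to the "\n\n" separator.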
def embed(
self,
to_documents_kwargs: tp.KwargsLike = None,
wrap_documents: tp.Optional[bool] = None,
**kwargs,
) -> tp.Optional[tp.MaybeKnowledgeAsset]:
"""Embed documents.
First, converts to `vectorbtpro.utils.knowledge.chatting.TextDocument` format using
`KnowledgeAsset.to_documents` and `**to_documents_kwargs`. Then, uses
        `vectorbtpro.utils.knowledge.chatting.embed_documents` with `**kwargs` for the actual embedding."""
from vectorbtpro.utils.knowledge.chatting import StoreDocument, EmbeddedDocument, embed_documents
if self.data and not isinstance(self.data[0], StoreDocument):
if to_documents_kwargs is None:
to_documents_kwargs = {}
documents = self.to_documents(**to_documents_kwargs)
if wrap_documents is None:
wrap_documents = False
else:
documents = self.data
if wrap_documents is None:
wrap_documents = True
embedded_documents = embed_documents(documents, **kwargs)
if embedded_documents is None:
return None
if not wrap_documents:
def _unwrap(document):
if isinstance(document, EmbeddedDocument):
return document.replace(
document=_unwrap(document.document),
child_documents=[_unwrap(d) for d in document.child_documents],
)
if isinstance(document, StoreDocument):
return document.data
return document
embedded_documents = list(map(_unwrap, embedded_documents))
return self.replace(data=embedded_documents)
def rank(
self,
query: str,
to_documents_kwargs: tp.KwargsLike = None,
wrap_documents: tp.Optional[bool] = None,
cache_documents: bool = False,
cache_key: tp.Optional[str] = None,
asset_cache_manager: tp.Optional[tp.MaybeType[AssetCacheManager]] = None,
asset_cache_manager_kwargs: tp.KwargsLike = None,
silence_warnings: bool = False,
**kwargs,
) -> tp.MaybeKnowledgeAsset:
"""Rank documents by their similarity to a query.
First, converts to `vectorbtpro.utils.knowledge.chatting.TextDocument` format using
`KnowledgeAsset.to_documents` and `**to_documents_kwargs`. Then, uses
`vectorbtpro.utils.knowledge.chatting.rank_documents` with `**kwargs` for actual ranking.
If `cache_documents` is True and `cache_key` is not None, will use an asset cache manager
to store the generated text documents in a local and/or disk cache after conversion.
Running the same method again will use the cached documents."""
from vectorbtpro.utils.knowledge.chatting import StoreDocument, ScoredDocument, rank_documents
if cache_documents:
if asset_cache_manager is None:
asset_cache_manager = AssetCacheManager
if asset_cache_manager_kwargs is None:
asset_cache_manager_kwargs = {}
if isinstance(asset_cache_manager, type):
checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs)
else:
checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
if asset_cache_manager_kwargs:
asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs)
documents = None
if cache_documents and cache_key is not None:
documents = asset_cache_manager.load_asset(cache_key)
if documents is not None:
if wrap_documents is None:
wrap_documents = False
else:
if not silence_warnings:
warn("Caching documents...")
if documents is None:
if self.data and not isinstance(self.data[0], StoreDocument):
if to_documents_kwargs is None:
to_documents_kwargs = {}
documents = self.to_documents(**to_documents_kwargs)
if cache_documents and cache_key is not None and isinstance(documents, KnowledgeAsset):
asset_cache_manager.save_asset(documents, cache_key)
if wrap_documents is None:
wrap_documents = False
else:
documents = self.data
if wrap_documents is None:
wrap_documents = True
ranked_documents = rank_documents(query=query, documents=documents, **kwargs)
if not wrap_documents:
def _unwrap(document):
if isinstance(document, ScoredDocument):
return document.replace(
document=_unwrap(document.document),
child_documents=[_unwrap(d) for d in document.child_documents],
)
if isinstance(document, StoreDocument):
return document.data
return document
ranked_documents = list(map(_unwrap, ranked_documents))
return self.replace(data=ranked_documents)
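    # Editor's note (illustrative sketch, not part of the original source; "my_docs" is a
    # hypothetical cache key): asset.rank("How to compute drawdowns?", cache_documents=True,
    # cache_key="my_docs") converts the data items to text documents on the first call, stores
    # them via the asset cache manager, and reuses the cached documents on later calls with the
    # same key, so only the actual ranking is repeated.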
def to_context(
self,
*args,
dump_all: tp.Optional[bool] = None,
separator: tp.Optional[str] = None,
**kwargs,
) -> str:
"""Convert to a context.
If `dump_all` is True, calls `KnowledgeAsset.dump_all` with `*args` and `**kwargs`.
Otherwise, calls `KnowledgeAsset.dump`.
Finally, calls `KnowledgeAsset.join` with `separator`."""
from vectorbtpro.utils.knowledge.chatting import StoreDocument, EmbeddedDocument, ScoredDocument
if dump_all is None:
dump_all = (
len(self.data) > 1
and not isinstance(self.data[0], (StoreDocument, EmbeddedDocument, ScoredDocument))
and separator is None
)
if dump_all:
dumped = self.dump_all(*args, **kwargs)
else:
dumped = self.dump(*args, **kwargs)
if isinstance(dumped, str):
return dumped
if not isinstance(dumped, KnowledgeAsset):
dumped = self.replace(data=dumped)
return dumped.join(separator=separator)
def print(self, *args, **kwargs) -> None:
"""Convert to a context and print.
Uses `KnowledgeAsset.to_context`."""
print(self.to_context(*args, **kwargs))
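# Editor's note (illustrative sketch, not part of the original source): a typical flow is
# `asset.print()`, which is shorthand for `print(asset.to_context())`; `to_context` dumps the
# data items (via `dump_all` when there are multiple plain items and no separator is given,
# otherwise via `dump`) and joins the dumped strings with `KnowledgeAsset.join`.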
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for chatting.
See `vectorbtpro.utils.knowledge` for the toy dataset."""
import hashlib
import inspect
import re
import sys
from pathlib import Path
from collections.abc import MutableMapping
import numpy as np
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.attr_ import DefineMixin, define
from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts, Configured, HasSettings, ExtSettingsPath
from vectorbtpro.utils.decorators import memoized_method, hybrid_method
from vectorbtpro.utils.knowledge.formatting import ContentFormatter, HTMLFileFormatter, resolve_formatter
from vectorbtpro.utils.parsing import get_func_arg_names, get_func_kwargs, get_forward_args
from vectorbtpro.utils.template import CustomTemplate, SafeSub, RepFunc
from vectorbtpro.utils.warnings_ import warn
try:
if not tp.TYPE_CHECKING:
raise ImportError
from tiktoken import Encoding as EncodingT
except ImportError:
EncodingT = "Encoding"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from openai import OpenAI as OpenAIT, Stream as StreamT
from openai.types.chat.chat_completion import ChatCompletion as ChatCompletionT
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunkT
except ImportError:
OpenAIT = "OpenAI"
StreamT = "Stream"
ChatCompletionT = "ChatCompletion"
ChatCompletionChunkT = "ChatCompletionChunk"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from litellm import ModelResponse as ModelResponseT, CustomStreamWrapper as CustomStreamWrapperT
except ImportError:
ModelResponseT = "ModelResponse"
CustomStreamWrapperT = "CustomStreamWrapper"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from llama_index.core.embeddings import BaseEmbedding as BaseEmbeddingT
from llama_index.core.llms import LLM as LLMT, ChatMessage as ChatMessageT, ChatResponse as ChatResponseT
from llama_index.core.node_parser import NodeParser as NodeParserT
except ImportError:
BaseEmbeddingT = "BaseEmbedding"
LLMT = "LLM"
ChatMessageT = "ChatMessage"
ChatResponseT = "ChatResponse"
NodeParserT = "NodeParser"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from IPython.display import DisplayHandle as DisplayHandleT
except ImportError:
DisplayHandleT = "DisplayHandle"
try:
if not tp.TYPE_CHECKING:
raise ImportError
from lmdbm import Lmdb as LmdbT
except ImportError:
LmdbT = "Lmdb"
__all__ = [
"Tokenizer",
"TikTokenizer",
"tokenize",
"detokenize",
"Embeddings",
"OpenAIEmbeddings",
"LiteLLMEmbeddings",
"LlamaIndexEmbeddings",
"embed",
"Completions",
"OpenAICompletions",
"LiteLLMCompletions",
"LlamaIndexCompletions",
"complete",
"TextSplitter",
"TokenSplitter",
"SegmentSplitter",
"LlamaIndexSplitter",
"split_text",
"StoreObject",
"StoreDocument",
"TextDocument",
"StoreEmbedding",
"ObjectStore",
"DictStore",
"MemoryStore",
"FileStore",
"LMDBStore",
"EmbeddedDocument",
"ScoredDocument",
"DocumentRanker",
"embed_documents",
"rank_documents",
"Rankable",
"Contextable",
"RankContextable",
]
# ############# Tokenizers ############# #
class Tokenizer(Configured):
"""Abstract class for tokenizers.
For defaults, see `knowledge.chat.tokenizer_config` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the class."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.tokenizer_config"]
def __init__(self, template_context: tp.KwargsLike = None, **kwargs) -> None:
Configured.__init__(self, template_context=template_context, **kwargs)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
self._template_context = template_context
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
def encode(self, text: str) -> tp.Tokens:
"""Encode text into a list of tokens."""
raise NotImplementedError
def decode(self, tokens: tp.Tokens) -> str:
"""Decode a list of tokens into text."""
raise NotImplementedError
@memoized_method
def encode_single(self, text: str) -> tp.Token:
"""Encode text into a single token."""
tokens = self.encode(text)
if len(tokens) > 1:
raise ValueError("Text contains multiple tokens")
return tokens[0]
@memoized_method
def decode_single(self, token: tp.Token) -> str:
"""Decode a single token into text."""
return self.decode([token])
def count_tokens(self, text: str) -> int:
"""Count tokens in a text."""
return len(self.encode(text))
def count_tokens_in_messages(self, messages: tp.ChatMessages) -> int:
"""Count tokens in messages."""
raise NotImplementedError
class TikTokenizer(Tokenizer):
"""Tokenizer class for tiktoken.
Encoding can be a model name, an encoding name, or an encoding object for tokenization.
For defaults, see `chat.tokenizer_configs.tiktoken` in `vectorbtpro._settings.knowledge`."""
_short_name = "tiktoken"
_settings_path: tp.SettingsPath = "knowledge.chat.tokenizer_configs.tiktoken"
def __init__(
self,
encoding: tp.Union[None, str, EncodingT] = None,
model: tp.Optional[str] = None,
tokens_per_message: tp.Optional[int] = None,
tokens_per_name: tp.Optional[int] = None,
**kwargs,
) -> None:
Tokenizer.__init__(
self,
encoding=encoding,
model=model,
tokens_per_message=tokens_per_message,
tokens_per_name=tokens_per_name,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("tiktoken")
from tiktoken import Encoding, get_encoding, encoding_for_model
encoding = self.resolve_setting(encoding, "encoding")
model = self.resolve_setting(model, "model")
tokens_per_message = self.resolve_setting(tokens_per_message, "tokens_per_message")
tokens_per_name = self.resolve_setting(tokens_per_name, "tokens_per_name")
if isinstance(encoding, str):
if encoding.startswith("model_or_"):
try:
if model is None:
raise KeyError
encoding = encoding_for_model(model)
except KeyError:
encoding = encoding[len("model_or_") :]
encoding = get_encoding(encoding) if "k_base" in encoding else encoding_for_model(encoding)
            else:
encoding = get_encoding(encoding) if "k_base" in encoding else encoding_for_model(encoding)
checks.assert_instance_of(encoding, Encoding, arg_name="encoding")
self._encoding = encoding
self._tokens_per_message = tokens_per_message
self._tokens_per_name = tokens_per_name
@property
def encoding(self) -> EncodingT:
"""Encoding."""
return self._encoding
@property
def tokens_per_message(self) -> int:
"""Tokens per message."""
return self._tokens_per_message
@property
def tokens_per_name(self) -> int:
"""Tokens per name."""
return self._tokens_per_name
def encode(self, text: str) -> tp.Tokens:
return self.encoding.encode(text)
def decode(self, tokens: tp.Tokens) -> str:
return self.encoding.decode(tokens)
def count_tokens_in_messages(self, messages: tp.ChatMessages) -> int:
num_tokens = 0
for message in messages:
num_tokens += self.tokens_per_message
for key, value in message.items():
num_tokens += self.count_tokens(value)
if key == "name":
num_tokens += self.tokens_per_name
num_tokens += 3
return num_tokens
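    # Editor's note (illustrative sketch, not part of the original source; tokens_per_message=3
    # and tokens_per_name=1 are assumed, commonly used defaults): a single message
    # {"role": "user", "content": "hi"} costs 3 (per message) + count_tokens("user")
    # + count_tokens("hi") tokens, plus a final 3 tokens added after all messages, since every
    # key's value is counted and "name" keys add tokens_per_name extra.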
def resolve_tokenizer(tokenizer: tp.TokenizerLike = None) -> tp.MaybeType[Tokenizer]:
"""Resolve a subclass or an instance of `Tokenizer`.
The following values are supported:
* "tiktoken" (`TikTokenizer`)
* A subclass or an instance of `Tokenizer`
"""
if tokenizer is None:
from vectorbtpro._settings import settings
chat_cfg = settings["knowledge"]["chat"]
tokenizer = chat_cfg["tokenizer"]
if isinstance(tokenizer, str):
curr_module = sys.modules[__name__]
found_tokenizer = None
for name, cls in inspect.getmembers(curr_module, inspect.isclass):
if name.endswith("Tokenizer"):
_short_name = getattr(cls, "_short_name", None)
if _short_name is not None and _short_name.lower() == tokenizer.lower():
found_tokenizer = cls
break
if found_tokenizer is None:
raise ValueError(f"Invalid tokenizer: '{tokenizer}'")
tokenizer = found_tokenizer
if isinstance(tokenizer, type):
checks.assert_subclass_of(tokenizer, Tokenizer, arg_name="tokenizer")
else:
checks.assert_instance_of(tokenizer, Tokenizer, arg_name="tokenizer")
return tokenizer
def tokenize(text: str, tokenizer: tp.TokenizerLike = None, **kwargs) -> tp.Tokens:
"""Tokenize text.
Resolves `tokenizer` with `resolve_tokenizer`. Keyword arguments are passed to either
initialize a class or replace an instance of `Tokenizer`."""
tokenizer = resolve_tokenizer(tokenizer=tokenizer)
if isinstance(tokenizer, type):
tokenizer = tokenizer(**kwargs)
elif kwargs:
tokenizer = tokenizer.replace(**kwargs)
return tokenizer.encode(text)
def detokenize(tokens: tp.Tokens, tokenizer: tp.TokenizerLike = None, **kwargs) -> str:
"""Detokenize text.
Resolves `tokenizer` with `resolve_tokenizer`. Keyword arguments are passed to either
initialize a class or replace an instance of `Tokenizer`."""
tokenizer = resolve_tokenizer(tokenizer=tokenizer)
if isinstance(tokenizer, type):
tokenizer = tokenizer(**kwargs)
elif kwargs:
tokenizer = tokenizer.replace(**kwargs)
return tokenizer.decode(tokens)
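# Editor's note (illustrative sketch, not part of the original source; assumes tiktoken is
# installed): a round trip through the tiktoken tokenizer might look like
#
# >>> tokens = tokenize("hello world", tokenizer="tiktoken", encoding="cl100k_base")
# >>> detokenize(tokens, tokenizer="tiktoken", encoding="cl100k_base")
# 'hello world'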
# ############# Embeddings ############# #
class Embeddings(Configured):
"""Abstract class for embedding providers.
For defaults, see `knowledge.chat.embeddings_config` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the class."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.embeddings_config"]
def __init__(
self,
batch_size: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
batch_size=batch_size,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
template_context=template_context,
**kwargs,
)
batch_size = self.resolve_setting(batch_size, "batch_size")
show_progress = self.resolve_setting(show_progress, "show_progress")
pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
self._batch_size = batch_size
self._show_progress = show_progress
self._pbar_kwargs = pbar_kwargs
self._template_context = template_context
@property
def batch_size(self) -> tp.Optional[int]:
"""Batch size.
Set to None to disable batching."""
return self._batch_size
@property
def show_progress(self) -> tp.Optional[bool]:
"""Whether to show progress bar."""
return self._show_progress
@property
def pbar_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`."""
return self._pbar_kwargs
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
@property
def model(self) -> tp.Optional[str]:
"""Model."""
return None
def get_embedding(self, query: str) -> tp.List[float]:
"""Get embedding for a query."""
raise NotImplementedError
def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
"""Get embeddings for one batch of queries."""
return [self.get_embedding(query) for query in batch]
def iter_embedding_batches(self, queries: tp.List[str]) -> tp.Iterator[tp.List[tp.List[float]]]:
"""Get iterator of embedding batches."""
from vectorbtpro.utils.pbar import ProgressBar
if self.batch_size is not None:
batches = [queries[i : i + self.batch_size] for i in range(0, len(queries), self.batch_size)]
else:
batches = [queries]
pbar_kwargs = merge_dicts(dict(prefix="get_embeddings"), self.pbar_kwargs)
with ProgressBar(total=len(queries), show_progress=self.show_progress, **pbar_kwargs) as pbar:
for batch in batches:
yield self.get_embedding_batch(batch)
pbar.update(len(batch))
def get_embeddings(self, queries: tp.List[str]) -> tp.List[tp.List[float]]:
"""Get embeddings for multiple queries."""
return [embedding for batch in self.iter_embedding_batches(queries) for embedding in batch]
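    # Editor's note (illustrative sketch, not part of the original source): with batch_size=2
    # and queries=["a", "b", "c", "d", "e"], `iter_embedding_batches` slices the queries into
    # [["a", "b"], ["c", "d"], ["e"]], calls `get_embedding_batch` on each slice, and advances
    # the progress bar by the batch length; `get_embeddings` then flattens the per-batch results
    # back into one list of five embedding vectors.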
class OpenAIEmbeddings(Embeddings):
"""Embeddings class for OpenAI.
For defaults, see `chat.embeddings_configs.openai` in `vectorbtpro._settings.knowledge`."""
_short_name = "openai"
_settings_path: tp.SettingsPath = "knowledge.chat.embeddings_configs.openai"
def __init__(
self,
model: tp.Optional[str] = None,
batch_size: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Embeddings.__init__(
self,
model=model,
batch_size=batch_size,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
template_context=template_context,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("openai")
from openai import OpenAI
openai_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_model = openai_config.pop("model", None)
if model is None:
model = def_model
if model is None:
raise ValueError("Must provide a model")
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(openai_config.keys()):
if k in init_kwargs:
openai_config.pop(k)
client_arg_names = set(get_func_arg_names(OpenAI.__init__))
client_kwargs = {}
embeddings_kwargs = {}
for k, v in openai_config.items():
if k in client_arg_names:
client_kwargs[k] = v
else:
embeddings_kwargs[k] = v
client = OpenAI(**client_kwargs)
self._model = model
self._client = client
self._embeddings_kwargs = embeddings_kwargs
@property
def model(self) -> str:
return self._model
@property
def client(self) -> OpenAIT:
"""Client."""
return self._client
@property
def embeddings_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `openai.resources.embeddings.Embeddings.create`."""
return self._embeddings_kwargs
def get_embedding(self, query: str) -> tp.List[float]:
response = self.client.embeddings.create(input=query, model=self.model, **self.embeddings_kwargs)
return response.data[0].embedding
def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
response = self.client.embeddings.create(input=batch, model=self.model, **self.embeddings_kwargs)
return [embedding.embedding for embedding in response.data]
class LiteLLMEmbeddings(Embeddings):
"""Embeddings class for LiteLLM.
For defaults, see `chat.embeddings_configs.litellm` in `vectorbtpro._settings.knowledge`."""
_short_name = "litellm"
_settings_path: tp.SettingsPath = "knowledge.chat.embeddings_configs.litellm"
def __init__(
self,
model: tp.Optional[str] = None,
batch_size: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Embeddings.__init__(
self,
model=model,
batch_size=batch_size,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
template_context=template_context,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("litellm")
litellm_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_model = litellm_config.pop("model", None)
if model is None:
model = def_model
if model is None:
raise ValueError("Must provide a model")
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(litellm_config.keys()):
if k in init_kwargs:
litellm_config.pop(k)
self._model = model
self._embedding_kwargs = litellm_config
@property
def model(self) -> str:
return self._model
@property
def embedding_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `litellm.embedding`."""
return self._embedding_kwargs
def get_embedding(self, query: str) -> tp.List[float]:
from litellm import embedding
response = embedding(self.model, input=query, **self.embedding_kwargs)
return response.data[0]["embedding"]
def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
from litellm import embedding
response = embedding(self.model, input=batch, **self.embedding_kwargs)
return [embedding["embedding"] for embedding in response.data]
class LlamaIndexEmbeddings(Embeddings):
"""Embeddings class for LlamaIndex.
For defaults, see `chat.embeddings_configs.llama_index` in `vectorbtpro._settings.knowledge`."""
_short_name = "llama_index"
_settings_path: tp.SettingsPath = "knowledge.chat.embeddings_configs.llama_index"
def __init__(
self,
embedding: tp.Union[None, str, tp.MaybeType[BaseEmbeddingT]] = None,
batch_size: tp.Optional[int] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Embeddings.__init__(
self,
embedding=embedding,
batch_size=batch_size,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
template_context=template_context,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("llama_index")
from llama_index.core.embeddings import BaseEmbedding
llama_index_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_embedding = llama_index_config.pop("embedding", None)
if embedding is None:
embedding = def_embedding
if embedding is None:
raise ValueError("Must provide an embedding name or path")
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(llama_index_config.keys()):
if k in init_kwargs:
llama_index_config.pop(k)
if isinstance(embedding, str):
import llama_index.embeddings
from vectorbtpro.utils.module_ import search_package
def _match_func(k, v):
if isinstance(v, type) and issubclass(v, BaseEmbedding):
if "." in embedding:
if k.endswith(embedding):
return True
else:
if k.split(".")[-1].lower() == embedding.lower():
return True
if k.split(".")[-1].replace("Embedding", "").lower() == embedding.lower().replace("_", ""):
return True
return False
found_embedding = search_package(
llama_index.embeddings,
_match_func,
path_attrs=True,
return_first=True,
)
if found_embedding is None:
raise ValueError(f"Embedding '{embedding}' not found")
embedding = found_embedding
if isinstance(embedding, type):
checks.assert_subclass_of(embedding, BaseEmbedding, arg_name="embedding")
embedding_name = embedding.__name__.replace("Embedding", "").lower()
module_name = embedding.__module__
else:
checks.assert_instance_of(embedding, BaseEmbedding, arg_name="embedding")
embedding_name = type(embedding).__name__.replace("Embedding", "").lower()
module_name = type(embedding).__module__
embedding_configs = llama_index_config.pop("embedding_configs", {})
if embedding_name in embedding_configs:
llama_index_config = merge_dicts(llama_index_config, embedding_configs[embedding_name])
elif module_name in embedding_configs:
llama_index_config = merge_dicts(llama_index_config, embedding_configs[module_name])
if isinstance(embedding, type):
embedding = embedding(**llama_index_config)
elif len(kwargs) > 0:
raise ValueError("Cannot apply config to already initialized embedding")
model_name = llama_index_config.get("model_name", None)
if model_name is None:
func_kwargs = get_func_kwargs(type(embedding).__init__)
model_name = func_kwargs.get("model_name", None)
self._model = model_name
self._embedding = embedding
@property
def model(self) -> tp.Optional[str]:
return self._model
@property
def embedding(self) -> BaseEmbeddingT:
"""Embedding."""
return self._embedding
def get_embedding(self, query: str) -> tp.List[float]:
return self.embedding.get_text_embedding(query)
def get_embedding_batch(self, batch: tp.List[str]) -> tp.List[tp.List[float]]:
return [embedding for embedding in self.embedding.get_text_embedding_batch(batch)]
def resolve_embeddings(embeddings: tp.EmbeddingsLike = None) -> tp.MaybeType[Embeddings]:
"""Resolve a subclass or an instance of `Embeddings`.
The following values are supported:
* "openai" (`OpenAIEmbeddings`)
* "litellm" (`LiteLLMEmbeddings`)
* "llama_index" (`LlamaIndexEmbeddings`)
* "auto": Any installed from above, in the same order
* A subclass or an instance of `Embeddings`
"""
if embeddings is None:
from vectorbtpro._settings import settings
chat_cfg = settings["knowledge"]["chat"]
embeddings = chat_cfg["embeddings"]
if isinstance(embeddings, str):
if embeddings.lower() == "auto":
from vectorbtpro.utils.module_ import check_installed
if check_installed("openai"):
embeddings = "openai"
elif check_installed("litellm"):
embeddings = "litellm"
elif check_installed("llama_index"):
embeddings = "llama_index"
else:
raise ValueError("No packages for embeddings installed")
curr_module = sys.modules[__name__]
found_embeddings = None
for name, cls in inspect.getmembers(curr_module, inspect.isclass):
if name.endswith("Embeddings"):
_short_name = getattr(cls, "_short_name", None)
if _short_name is not None and _short_name.lower() == embeddings.lower():
found_embeddings = cls
break
if found_embeddings is None:
raise ValueError(f"Invalid embeddings: '{embeddings}'")
embeddings = found_embeddings
if isinstance(embeddings, type):
checks.assert_subclass_of(embeddings, Embeddings, arg_name="embeddings")
else:
checks.assert_instance_of(embeddings, Embeddings, arg_name="embeddings")
return embeddings
def embed(query: tp.MaybeList[str], embeddings: tp.EmbeddingsLike = None, **kwargs) -> tp.MaybeList[tp.List[float]]:
"""Get embedding(s) for one or more queries.
Resolves `embeddings` with `resolve_embeddings`. Keyword arguments are passed to either
initialize a class or replace an instance of `Embeddings`."""
embeddings = resolve_embeddings(embeddings=embeddings)
if isinstance(embeddings, type):
embeddings = embeddings(**kwargs)
elif kwargs:
embeddings = embeddings.replace(**kwargs)
if isinstance(query, str):
return embeddings.get_embedding(query)
return embeddings.get_embeddings(query)
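# Editor's note (illustrative sketch, not part of the original source; assumes the openai
# package is installed, OPENAI_API_KEY is set, and "text-embedding-3-small" is an available
# model):
#
# >>> vector = embed("What is a drawdown?", embeddings="openai", model="text-embedding-3-small")
# >>> vectors = embed(["query one", "query two"], embeddings="openai", model="text-embedding-3-small")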
# ############# Completions ############# #
class Completions(Configured):
"""Abstract class for completion providers.
For argument descriptions, see their properties, like `Completions.chat_history`.
For defaults, see `knowledge.chat.completions_config` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the class."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.completions_config"]
def __init__(
self,
context: str = "",
chat_history: tp.Optional[tp.ChatHistory] = None,
stream: tp.Optional[bool] = None,
max_tokens: tp.Optional[int] = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
system_prompt: tp.Optional[str] = None,
system_as_user: tp.Optional[bool] = None,
context_prompt: tp.Optional[str] = None,
formatter: tp.ContentFormatterLike = None,
formatter_kwargs: tp.KwargsLike = None,
minimal_format: tp.Optional[bool] = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
context=context,
chat_history=chat_history,
stream=stream,
max_tokens=max_tokens,
tokenizer=tokenizer,
tokenizer_kwargs=tokenizer_kwargs,
system_prompt=system_prompt,
system_as_user=system_as_user,
context_prompt=context_prompt,
formatter=formatter,
formatter_kwargs=formatter_kwargs,
minimal_format=minimal_format,
silence_warnings=silence_warnings,
template_context=template_context,
**kwargs,
)
if chat_history is None:
chat_history = []
stream = self.resolve_setting(stream, "stream")
max_tokens_set = max_tokens is not None
max_tokens = self.resolve_setting(max_tokens, "max_tokens")
tokenizer = self.resolve_setting(tokenizer, "tokenizer", default=None)
tokenizer_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True)
system_prompt = self.resolve_setting(system_prompt, "system_prompt")
system_as_user = self.resolve_setting(system_as_user, "system_as_user")
context_prompt = self.resolve_setting(context_prompt, "context_prompt")
formatter = self.resolve_setting(formatter, "formatter", default=None)
formatter_kwargs = self.resolve_setting(formatter_kwargs, "formatter_kwargs", default=None, merge=True)
minimal_format = self.resolve_setting(minimal_format, "minimal_format", default=None)
silence_warnings = self.resolve_setting(silence_warnings, "silence_warnings")
template_context = self.resolve_setting(template_context, "template_context", merge=True)
tokenizer = resolve_tokenizer(tokenizer)
formatter = resolve_formatter(formatter)
self._context = context
self._chat_history = chat_history
self._stream = stream
self._max_tokens_set = max_tokens_set
self._max_tokens = max_tokens
self._tokenizer = tokenizer
self._tokenizer_kwargs = tokenizer_kwargs
self._system_prompt = system_prompt
self._system_as_user = system_as_user
self._context_prompt = context_prompt
self._formatter = formatter
self._formatter_kwargs = formatter_kwargs
self._minimal_format = minimal_format
self._silence_warnings = silence_warnings
self._template_context = template_context
@property
def context(self) -> str:
"""Context.
Becomes a user message."""
return self._context
@property
def chat_history(self) -> tp.ChatHistory:
"""Chat history.
        Must be a list of dictionaries with proper roles.
After generating a response, the output will be appended to this sequence as an assistant message."""
return self._chat_history
@property
def stream(self) -> bool:
"""Whether to stream the response.
When streaming, appends chunks one by one and displays the intermediate result.
Otherwise, displays the entire message."""
return self._stream
@property
    def max_tokens_set(self) -> bool:
"""Whether the user provided `max_tokens`."""
return self._max_tokens_set
@property
def max_tokens(self) -> tp.Optional[int]:
"""Maximum number of tokens in messages."""
return self._max_tokens
@property
def tokenizer(self) -> tp.MaybeType[Tokenizer]:
"""A subclass or an instance of `Tokenizer`.
Resolved with `resolve_tokenizer`."""
return self._tokenizer
@property
def tokenizer_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `Completions.tokenizer`.
Used either to initialize a class or replace an instance of `Tokenizer`."""
return self._tokenizer_kwargs
@property
def system_prompt(self) -> str:
"""System prompt.
Precedes the context prompt."""
return self._system_prompt
@property
def system_as_user(self) -> bool:
"""Whether to use the user role for the system message.
Mainly for experimental models where the system role is not available."""
return self._system_as_user
@property
def context_prompt(self) -> str:
"""Context prompt.
A prompt template requiring the variable "context". The prompt can be either a custom template,
        or a string or function that will become one. Once the prompt is evaluated, it becomes a user message."""
return self._context_prompt
@property
def formatter(self) -> tp.MaybeType[ContentFormatter]:
"""A subclass or an instance of `vectorbtpro.utils.knowledge.formatting.ContentFormatter`.
Resolved with `vectorbtpro.utils.knowledge.formatting.resolve_formatter`."""
return self._formatter
@property
def formatter_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `Completions.formatter`.
Used either to initialize a class or replace an instance of
`vectorbtpro.utils.knowledge.formatting.ContentFormatter`."""
return self._formatter_kwargs
@property
def minimal_format(self) -> bool:
"""Whether input is minimally-formatted."""
return self._minimal_format
@property
def silence_warnings(self) -> bool:
"""Whether to silence warnings."""
return self._silence_warnings
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
@property
def model(self) -> tp.Optional[str]:
"""Model."""
return None
def get_chat_response(self, messages: tp.ChatMessages, **kwargs) -> tp.Any:
"""Get chat response to messages."""
raise NotImplementedError
def get_message_content(self, response: tp.Any) -> tp.Optional[str]:
"""Get content from a chat response."""
raise NotImplementedError
def get_stream_response(self, messages: tp.ChatMessages, **kwargs) -> tp.Any:
"""Get streaming response to messages."""
raise NotImplementedError
def get_delta_content(self, response: tp.Any) -> tp.Optional[str]:
"""Get content from a streaming response chunk."""
raise NotImplementedError
def prepare_messages(self, message: str) -> tp.ChatMessages:
"""Prepare messages for a completion."""
context = self.context
chat_history = self.chat_history
max_tokens_set = self.max_tokens_set
max_tokens = self.max_tokens
tokenizer = self.tokenizer
tokenizer_kwargs = self.tokenizer_kwargs
system_prompt = self.system_prompt
system_as_user = self.system_as_user
context_prompt = self.context_prompt
template_context = self.template_context
silence_warnings = self.silence_warnings
if isinstance(tokenizer, type):
tokenizer_kwargs = dict(tokenizer_kwargs)
tokenizer_kwargs["template_context"] = merge_dicts(
template_context, tokenizer_kwargs.get("template_context", None)
)
if issubclass(tokenizer, TikTokenizer) and "model" not in tokenizer_kwargs:
tokenizer_kwargs["model"] = self.model
tokenizer = tokenizer(**tokenizer_kwargs)
elif tokenizer_kwargs:
tokenizer = tokenizer.replace(**tokenizer_kwargs)
if context:
if isinstance(context_prompt, str):
context_prompt = SafeSub(context_prompt)
elif checks.is_function(context_prompt):
context_prompt = RepFunc(context_prompt)
elif not isinstance(context_prompt, CustomTemplate):
raise TypeError(f"Context prompt must be a string, function, or template")
if max_tokens is not None:
empty_context_prompt = context_prompt.substitute(
flat_merge_dicts(dict(context=""), template_context),
eval_id="context_prompt",
)
empty_messages = [
dict(role="user" if system_as_user else "system", content=system_prompt),
dict(role="user", content=empty_context_prompt),
*chat_history,
dict(role="user", content=message),
]
num_tokens = tokenizer.count_tokens_in_messages(empty_messages)
max_context_tokens = max(0, max_tokens - num_tokens)
encoded_context = tokenizer.encode(context)
if len(encoded_context) > max_context_tokens:
context = tokenizer.decode(encoded_context[:max_context_tokens])
if not max_tokens_set and not silence_warnings:
warn(
f"Context is too long ({len(encoded_context)}). "
f"Truncating to {max_context_tokens} tokens."
)
template_context = flat_merge_dicts(dict(context=context), template_context)
context_prompt = context_prompt.substitute(template_context, eval_id="context_prompt")
return [
dict(role="user" if system_as_user else "system", content=system_prompt),
dict(role="user", content=context_prompt),
*chat_history,
dict(role="user", content=message),
]
else:
return [
dict(role="user" if system_as_user else "system", content=system_prompt),
*chat_history,
dict(role="user", content=message),
]
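    # Editor's note (illustrative sketch, not part of the original source): with a non-empty
    # context, the prepared messages take the shape [system prompt (role "system", or "user" if
    # system_as_user), substituted context prompt (role "user"), *chat_history, the new message
    # (role "user")]; when max_tokens is set, the context is first truncated so that the total
    # token count of all messages stays within the budget.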
def get_completion(
self,
message: str,
return_response: bool = False,
) -> tp.ChatOutput:
"""Get completion for a message."""
chat_history = self.chat_history
stream = self.stream
formatter = self.formatter
formatter_kwargs = self.formatter_kwargs
template_context = self.template_context
messages = self.prepare_messages(message)
if self.stream:
response = self.get_stream_response(messages)
else:
response = self.get_chat_response(messages)
if isinstance(formatter, type):
formatter_kwargs = dict(formatter_kwargs)
if "minimal_format" not in formatter_kwargs:
formatter_kwargs["minimal_format"] = self.minimal_format
formatter_kwargs["template_context"] = merge_dicts(
template_context, formatter_kwargs.get("template_context", None)
)
if issubclass(formatter, HTMLFileFormatter):
if "page_title" not in formatter_kwargs:
formatter_kwargs["page_title"] = message
if "cache_dir" not in formatter_kwargs:
chat_dir = self.get_setting("chat_dir", default=None)
if isinstance(chat_dir, CustomTemplate):
cache_dir = self.get_setting("cache_dir", default=None)
if cache_dir is not None:
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = self.get_setting("release_dir", default=None)
if release_dir is not None:
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
chat_dir = chat_dir.substitute(template_context, eval_id="chat_dir")
chat_dir = Path(chat_dir) / "html"
formatter_kwargs["dir_path"] = chat_dir
formatter = formatter(**formatter_kwargs)
elif formatter_kwargs:
formatter = formatter.replace(**formatter_kwargs)
if stream:
with formatter:
for i, response_chunk in enumerate(response):
new_content = self.get_delta_content(response_chunk)
if new_content is not None:
formatter.append(new_content)
content = formatter.content
else:
content = self.get_message_content(response)
if content is None:
content = ""
formatter.append_once(content)
chat_history.append(dict(role="user", content=message))
chat_history.append(dict(role="assistant", content=content))
if isinstance(formatter, HTMLFileFormatter) and formatter.file_handle is not None:
file_path = Path(formatter.file_handle.name)
else:
file_path = None
if return_response:
return file_path, response
return file_path
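    # Editor's note (illustrative sketch, not part of the original source): when stream=True,
    # each chunk's delta content is appended to the formatter as it arrives and the formatter's
    # accumulated content becomes the final answer; otherwise the full message content is
    # appended once. In both cases the user message and the assistant answer are appended to
    # chat_history, and an HTML file formatter additionally returns the path of the written file.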
class OpenAICompletions(Completions):
"""Completions class for OpenAI.
Keyword arguments are distributed between the client call and the completion call.
For defaults, see `chat.completions_configs.openai` in `vectorbtpro._settings.knowledge`."""
_short_name = "openai"
_settings_path: tp.SettingsPath = "knowledge.chat.completions_configs.openai"
def __init__(
self,
context: str = "",
chat_history: tp.Optional[tp.ChatHistory] = None,
stream: tp.Optional[bool] = None,
max_tokens: tp.Optional[int] = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
system_prompt: tp.Optional[str] = None,
system_as_user: tp.Optional[bool] = None,
context_prompt: tp.Optional[str] = None,
formatter: tp.ContentFormatterLike = None,
formatter_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
model: tp.Optional[str] = None,
**kwargs,
) -> None:
Completions.__init__(
self,
context=context,
chat_history=chat_history,
stream=stream,
max_tokens=max_tokens,
tokenizer=tokenizer,
tokenizer_kwargs=tokenizer_kwargs,
system_prompt=system_prompt,
system_as_user=system_as_user,
context_prompt=context_prompt,
formatter=formatter,
formatter_kwargs=formatter_kwargs,
silence_warnings=silence_warnings,
template_context=template_context,
model=model,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("openai")
from openai import OpenAI
openai_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_model = openai_config.pop("model", None)
if model is None:
model = def_model
if model is None:
raise ValueError("Must provide a model")
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(openai_config.keys()):
if k in init_kwargs:
openai_config.pop(k)
client_arg_names = set(get_func_arg_names(OpenAI.__init__))
client_kwargs = {}
completion_kwargs = {}
for k, v in openai_config.items():
if k in client_arg_names:
client_kwargs[k] = v
else:
completion_kwargs[k] = v
client = OpenAI(**client_kwargs)
self._model = model
self._client = client
self._completion_kwargs = completion_kwargs
@property
def model(self) -> str:
return self._model
@property
def client(self) -> OpenAIT:
"""Client."""
return self._client
@property
def completion_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `openai.resources.chat.completions_configs.Completions.create`."""
return self._completion_kwargs
def get_chat_response(self, messages: tp.ChatMessages) -> ChatCompletionT:
return self.client.chat.completions.create(
messages=messages,
model=self.model,
stream=False,
**self.completion_kwargs,
)
def get_message_content(self, response: ChatCompletionT) -> tp.Optional[str]:
return response.choices[0].message.content
def get_stream_response(self, messages: tp.ChatMessages) -> StreamT:
return self.client.chat.completions.create(
messages=messages,
model=self.model,
stream=True,
**self.completion_kwargs,
)
def get_delta_content(self, response_chunk: ChatCompletionChunkT) -> tp.Optional[str]:
return response_chunk.choices[0].delta.content
class LiteLLMCompletions(Completions):
"""Completions class for LiteLLM.
Keyword arguments are passed to the completion call.
For defaults, see `chat.completions_configs.litellm` in `vectorbtpro._settings.knowledge`."""
_short_name = "litellm"
_settings_path: tp.SettingsPath = "knowledge.chat.completions_configs.litellm"
def __init__(
self,
context: str = "",
chat_history: tp.Optional[tp.ChatHistory] = None,
stream: tp.Optional[bool] = None,
max_tokens: tp.Optional[int] = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
system_prompt: tp.Optional[str] = None,
system_as_user: tp.Optional[bool] = None,
context_prompt: tp.Optional[str] = None,
formatter: tp.ContentFormatterLike = None,
formatter_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
model: tp.Optional[str] = None,
**kwargs,
) -> None:
Completions.__init__(
self,
context=context,
chat_history=chat_history,
stream=stream,
max_tokens=max_tokens,
tokenizer=tokenizer,
tokenizer_kwargs=tokenizer_kwargs,
system_prompt=system_prompt,
system_as_user=system_as_user,
context_prompt=context_prompt,
formatter=formatter,
formatter_kwargs=formatter_kwargs,
silence_warnings=silence_warnings,
template_context=template_context,
model=model,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("litellm")
completion_kwargs = merge_dicts(self.get_settings(inherit=False), kwargs)
def_model = completion_kwargs.pop("model", None)
if model is None:
model = def_model
if model is None:
raise ValueError("Must provide a model")
self._model = model
self._completion_kwargs = completion_kwargs
@property
def model(self) -> str:
return self._model
@property
def completion_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `litellm.completion`."""
return self._completion_kwargs
def get_chat_response(self, messages: tp.ChatMessages) -> ModelResponseT:
from litellm import completion
return completion(
messages=messages,
model=self.model,
stream=False,
**self.completion_kwargs,
)
def get_message_content(self, response: ModelResponseT) -> tp.Optional[str]:
return response.choices[0].message.content
def get_stream_response(self, messages: tp.ChatMessages) -> CustomStreamWrapperT:
from litellm import completion
return completion(
messages=messages,
model=self.model,
stream=True,
**self.completion_kwargs,
)
def get_delta_content(self, response_chunk: ModelResponseT) -> tp.Optional[str]:
return response_chunk.choices[0].delta.content
class LlamaIndexCompletions(Completions):
"""Completions class for LlamaIndex.
LLM can be provided via `llm`, which can be either the name of the class (case doesn't matter),
    the path to the class or its suffix (case matters), or a subclass or an instance of
`llama_index.core.llms.LLM`.
Keyword arguments are passed to the resolved LLM.
For defaults, see `chat.completions_configs.llama_index` in `vectorbtpro._settings.knowledge`."""
_short_name = "llama_index"
_settings_path: tp.SettingsPath = "knowledge.chat.completions_configs.llama_index"
def __init__(
self,
context: str = "",
chat_history: tp.Optional[tp.ChatHistory] = None,
stream: tp.Optional[bool] = None,
max_tokens: tp.Optional[int] = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
system_prompt: tp.Optional[str] = None,
system_as_user: tp.Optional[bool] = None,
context_prompt: tp.Optional[str] = None,
formatter: tp.ContentFormatterLike = None,
formatter_kwargs: tp.KwargsLike = None,
silence_warnings: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
llm: tp.Union[None, str, tp.MaybeType[LLMT]] = None,
**kwargs,
) -> None:
Completions.__init__(
self,
context=context,
chat_history=chat_history,
stream=stream,
max_tokens=max_tokens,
tokenizer=tokenizer,
tokenizer_kwargs=tokenizer_kwargs,
system_prompt=system_prompt,
system_as_user=system_as_user,
context_prompt=context_prompt,
formatter=formatter,
formatter_kwargs=formatter_kwargs,
silence_warnings=silence_warnings,
template_context=template_context,
llm=llm,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("llama_index")
from llama_index.core.llms import LLM
llama_index_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_llm = llama_index_config.pop("llm", None)
if llm is None:
llm = def_llm
if llm is None:
raise ValueError("Must provide an LLM name or path")
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(llama_index_config.keys()):
if k in init_kwargs:
llama_index_config.pop(k)
if isinstance(llm, str):
import llama_index.llms
from vectorbtpro.utils.module_ import search_package
def _match_func(k, v):
if isinstance(v, type) and issubclass(v, LLM):
if "." in llm:
if k.endswith(llm):
return True
else:
if k.split(".")[-1].lower() == llm.lower():
return True
if k.split(".")[-1].replace("LLM", "").lower() == llm.lower().replace("_", ""):
return True
return False
found_llm = search_package(
llama_index.llms,
_match_func,
path_attrs=True,
return_first=True,
)
if found_llm is None:
raise ValueError(f"LLM '{llm}' not found")
llm = found_llm
if isinstance(llm, type):
checks.assert_subclass_of(llm, LLM, arg_name="llm")
llm_name = llm.__name__.replace("LLM", "").lower()
module_name = llm.__module__
else:
checks.assert_instance_of(llm, LLM, arg_name="llm")
llm_name = type(llm).__name__.replace("LLM", "").lower()
module_name = type(llm).__module__
llm_configs = llama_index_config.pop("llm_configs", {})
if llm_name in llm_configs:
llama_index_config = merge_dicts(llama_index_config, llm_configs[llm_name])
elif module_name in llm_configs:
llama_index_config = merge_dicts(llama_index_config, llm_configs[module_name])
if isinstance(llm, type):
llm = llm(**llama_index_config)
elif len(kwargs) > 0:
raise ValueError("Cannot apply config to already initialized LLM")
model = llama_index_config.get("model", None)
if model is None:
func_kwargs = get_func_kwargs(type(llm).__init__)
model = func_kwargs.get("model", None)
self._model = model
self._llm = llm
@property
def model(self) -> tp.Optional[str]:
return self._model
@property
def llm(self) -> LLMT:
"""LLM."""
return self._llm
def get_chat_response(self, messages: tp.ChatMessages) -> ChatResponseT:
from llama_index.core.llms import ChatMessage
return self.llm.chat(list(map(lambda x: ChatMessage(**dict(x)), messages)))
def get_message_content(self, response: ChatResponseT) -> tp.Optional[str]:
return response.message.content
def get_stream_response(self, messages: tp.ChatMessages) -> tp.Iterator[ChatResponseT]:
from llama_index.core.llms import ChatMessage
return self.llm.stream_chat(list(map(lambda x: ChatMessage(**dict(x)), messages)))
def get_delta_content(self, response_chunk: ChatResponseT) -> tp.Optional[str]:
return response_chunk.delta
def resolve_completions(completions: tp.CompletionsLike = None) -> tp.MaybeType[Completions]:
"""Resolve a subclass or an instance of `Completions`.
The following values are supported:
* "openai" (`OpenAICompletions`)
* "litellm" (`LiteLLMCompletions`)
* "llama_index" (`LlamaIndexCompletions`)
* "auto": Any installed from above, in the same order
* A subclass or an instance of `Completions`
"""
if completions is None:
from vectorbtpro._settings import settings
chat_cfg = settings["knowledge"]["chat"]
completions = chat_cfg["completions"]
if isinstance(completions, str):
if completions.lower() == "auto":
from vectorbtpro.utils.module_ import check_installed
if check_installed("openai"):
completions = "openai"
elif check_installed("litellm"):
completions = "litellm"
elif check_installed("llama_index"):
completions = "llama_index"
else:
raise ValueError("No packages for completions installed")
curr_module = sys.modules[__name__]
found_completions = None
for name, cls in inspect.getmembers(curr_module, inspect.isclass):
if name.endswith("Completions"):
_short_name = getattr(cls, "_short_name", None)
if _short_name is not None and _short_name.lower() == completions.lower():
found_completions = cls
break
if found_completions is None:
raise ValueError(f"Invalid completions: '{completions}'")
completions = found_completions
if isinstance(completions, type):
checks.assert_subclass_of(completions, Completions, arg_name="completions")
else:
checks.assert_instance_of(completions, Completions, arg_name="completions")
return completions
def complete(message: str, completions: tp.CompletionsLike = None, **kwargs) -> tp.ChatOutput:
"""Get completion for a message.
Resolves `completions` with `resolve_completions`. Keyword arguments are passed to either
initialize a class or replace an instance of `Completions`."""
completions = resolve_completions(completions=completions)
if isinstance(completions, type):
completions = completions(**kwargs)
elif kwargs:
completions = completions.replace(**kwargs)
return completions.get_completion(message)
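# Editor's note (illustrative sketch, not part of the original source; assumes the openai
# package is installed, OPENAI_API_KEY is set, and "gpt-4o" is an available model):
#
# >>> complete("Summarize this context in one sentence.", completions="openai",
# ...          model="gpt-4o", context="VBT is a backtesting library.", stream=False)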
# ############# Splitting ############# #
class TextSplitter(Configured):
"""Abstract class for text splitters.
For defaults, see `knowledge.chat.text_splitter_config` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the class."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.text_splitter_config"]
def __init__(
self,
chunk_template: tp.Optional[tp.CustomTemplateLike] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
chunk_template=chunk_template,
template_context=template_context,
**kwargs,
)
chunk_template = self.resolve_setting(chunk_template, "chunk_template")
template_context = self.resolve_setting(template_context, "template_context", merge=True)
self._chunk_template = chunk_template
self._template_context = template_context
@property
    def chunk_template(self) -> tp.CustomTemplateLike:
"""Chunk template.
Can use the following context: `chunk_idx`, `chunk_start`, `chunk_end`, `chunk_text`, and `text`."""
return self._chunk_template
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
def split(self, text: str) -> tp.TSRangeChunks:
"""Split text and yield start character and end character position of each chunk."""
raise NotImplementedError
def split_text(self, text: str) -> tp.TSTextChunks:
"""Split text and return text chunks."""
for chunk_idx, (chunk_start, chunk_end) in enumerate(self.split(text)):
chunk_text = text[chunk_start:chunk_end]
chunk_template = self.chunk_template
if isinstance(chunk_template, str):
chunk_template = SafeSub(chunk_template)
elif checks.is_function(chunk_template):
chunk_template = RepFunc(chunk_template)
elif not isinstance(chunk_template, CustomTemplate):
raise TypeError(f"Chunk template must be a string, function, or template")
template_context = flat_merge_dicts(
dict(
chunk_idx=chunk_idx,
chunk_start=chunk_start,
chunk_end=chunk_end,
chunk_text=chunk_text,
text=text,
),
self.template_context,
)
yield chunk_template.substitute(template_context, eval_id="chunk_template")
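    # Editor's note (illustrative sketch, not part of the original source; assumes `SafeSub`
    # substitutes string.Template-style placeholders): a chunk_template such as
    # "Chunk $chunk_idx ($chunk_start-$chunk_end):\n$chunk_text" would prepend each yielded
    # chunk's index and character range to its text.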
class TokenSplitter(TextSplitter):
"""Splitter class for tokens.
For defaults, see `chat.text_splitter_configs.token` in `vectorbtpro._settings.knowledge`."""
_short_name = "token"
_settings_path: tp.SettingsPath = "knowledge.chat.text_splitter_configs.token"
def __init__(
self,
chunk_size: tp.Optional[int] = None,
chunk_overlap: tp.Union[None, int, float] = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
**kwargs,
) -> None:
TextSplitter.__init__(
self,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
tokenizer=tokenizer,
tokenizer_kwargs=tokenizer_kwargs,
**kwargs,
)
chunk_size = self.resolve_setting(chunk_size, "chunk_size")
chunk_overlap = self.resolve_setting(chunk_overlap, "chunk_overlap")
tokenizer = self.resolve_setting(tokenizer, "tokenizer", default=None)
tokenizer_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True)
tokenizer = resolve_tokenizer(tokenizer)
if isinstance(tokenizer, type):
tokenizer_kwargs = dict(tokenizer_kwargs)
tokenizer_kwargs["template_context"] = merge_dicts(
self.template_context, tokenizer_kwargs.get("template_context", None)
)
tokenizer = tokenizer(**tokenizer_kwargs)
elif tokenizer_kwargs:
tokenizer = tokenizer.replace(**tokenizer_kwargs)
if checks.is_float(chunk_overlap):
if 0 <= abs(chunk_overlap) <= 1:
chunk_overlap = chunk_overlap * chunk_size
elif not chunk_overlap.is_integer():
raise TypeError("Floating number for chunk_overlap must be between 0 and 1")
chunk_overlap = int(chunk_overlap)
if chunk_overlap >= chunk_size:
raise ValueError("Chunk overlap must be less than the chunk size")
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._tokenizer = tokenizer
@property
def chunk_size(self) -> int:
"""Maximum number of tokens per chunk."""
return self._chunk_size
@property
def chunk_overlap(self) -> int:
"""Number of overlapping tokens between chunks.
        Can also be provided as a floating number relative to `TokenSplitter.chunk_size`."""
return self._chunk_overlap
@property
def tokenizer(self) -> Tokenizer:
"""An instance of `Tokenizer`."""
return self._tokenizer
def split_into_tokens(self, text: str) -> tp.TSRangeChunks:
"""Split text into tokens."""
tokens = self.tokenizer.encode(text)
last_end = 0
for token in tokens:
_text = self.tokenizer.decode_single(token)
start = last_end
end = start + len(_text)
yield start, end
last_end = end
def split(self, text: str) -> tp.TSRangeChunks:
tokens = list(self.split_into_tokens(text))
total_tokens = len(tokens)
if not tokens:
return
token_count = 0
while token_count < total_tokens:
chunk_tokens = tokens[token_count:token_count + self.chunk_size]
chunk_start = chunk_tokens[0][0]
chunk_end = chunk_tokens[-1][1]
yield chunk_start, chunk_end
if token_count + self.chunk_size >= total_tokens:
break
token_count += self.chunk_size - self.chunk_overlap
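    # Editor's note (illustrative sketch, not part of the original source): with chunk_size=100
    # and chunk_overlap=20, the window over the token list advances by 100 - 20 = 80 tokens per
    # chunk, so chunks cover tokens 0-99, 80-179, 160-259, and so on until the last window
    # reaches the end of the text; chunk_overlap=0.2 resolves to the same 20 tokens relative
    # to chunk_size.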
class SegmentSplitter(TokenSplitter):
"""Splitter class for segments based on separators.
If a segment is too big, the next separator within the same layer is used to split the segment
into smaller segments. If a segment is still too big and no segments have been added to the
chunk yet, or if the number of tokens is less than the minimum count, the next layer is used.
To split into tokens, set any separator to None. To split into characters, use an empty string.
For defaults, see `chat.text_splitter_configs.segment` in `vectorbtpro._settings.knowledge`."""
_short_name = "segment"
_settings_path: tp.SettingsPath = "knowledge.chat.text_splitter_configs.segment"
def __init__(
self,
separators: tp.MaybeList[tp.MaybeList[tp.Optional[str]]] = None,
min_chunk_size: tp.Union[None, int, float] = None,
fixed_overlap: tp.Optional[bool] = None,
**kwargs,
) -> None:
TokenSplitter.__init__(
self,
separators=separators,
min_chunk_size=min_chunk_size,
fixed_overlap=fixed_overlap,
**kwargs,
)
separators = self.resolve_setting(separators, "separators")
min_chunk_size = self.resolve_setting(min_chunk_size, "min_chunk_size")
fixed_overlap = self.resolve_setting(fixed_overlap, "fixed_overlap")
if not isinstance(separators, list):
separators = [separators]
else:
separators = list(separators)
for layer in range(len(separators)):
if not isinstance(separators[layer], list):
separators[layer] = [separators[layer]]
else:
separators[layer] = list(separators[layer])
if checks.is_float(min_chunk_size):
if 0 <= abs(min_chunk_size) <= 1:
min_chunk_size = min_chunk_size * self.chunk_size
elif not min_chunk_size.is_integer():
raise TypeError("Floating number for min_chunk_size must be between 0 and 1")
min_chunk_size = int(min_chunk_size)
self._separators = separators
self._min_chunk_size = min_chunk_size
self._fixed_overlap = fixed_overlap
@property
def separators(self) -> tp.List[tp.List[tp.Optional[str]]]:
"""Nested list of separators grouped into layers."""
return self._separators
@property
def min_chunk_size(self) -> int:
"""Minimum number of tokens per chunk.
Can also be provided as a floating number relative to `SegmentSplitter.chunk_size`."""
return self._min_chunk_size
@property
def fixed_overlap(self) -> bool:
"""Whether overlap should be fixed."""
return self._fixed_overlap
def split_into_segments(self, text: str, separator: tp.Optional[str] = None) -> tp.TSSegmentChunks:
"""Split text into segments."""
if not separator:
if separator is None:
for start, end in self.split_into_tokens(text):
yield start, end, False
else:
for i in range(len(text)):
yield i, i + 1, False
else:
last_end = 0
for match in re.finditer(separator, text):
start, end = match.span()
if start > last_end:
_text = text[last_end:start]
yield last_end, start, False
_text = text[start:end]
yield start, end, True
last_end = end
if last_end < len(text):
_text = text[last_end:]
yield last_end, len(text), False
def split(self, text: str) -> tp.TSRangeChunks:
if not text:
yield 0, 0
return None
total_tokens = self.tokenizer.count_tokens(text)
if total_tokens <= self.chunk_size:
yield 0, len(text)
return None
layer = 0
chunk_start = 0
chunk_continue = 0
chunk_tokens = []
stable_token_count = 0
stable_char_count = 0
remaining_text = text
overlap_segments = []
token_offset_map = {}
while remaining_text:
if layer == 0:
if chunk_continue:
curr_start = chunk_continue
else:
curr_start = chunk_start
curr_text = remaining_text
curr_segments = list(overlap_segments)
curr_tokens = list(chunk_tokens)
curr_stable_token_count = stable_token_count
curr_stable_char_count = stable_char_count
sep_curr_segments = None
sep_curr_tokens = None
sep_curr_stable_token_count = None
sep_curr_stable_char_count = None
for separator in self.separators[layer]:
segments = self.split_into_segments(curr_text, separator=separator)
curr_text = ""
finished = False
for segment in segments:
segment_start = curr_start + segment[0]
segment_end = curr_start + segment[1]
segment_is_separator = segment[2]
if not curr_tokens:
segment_text = text[segment_start:segment_end]
new_curr_tokens = self.tokenizer.encode(segment_text)
new_curr_stable_token_count = 0
new_curr_stable_char_count = 0
elif not curr_stable_token_count:
chunk_text = text[chunk_start:segment_end]
new_curr_tokens = self.tokenizer.encode(chunk_text)
new_curr_stable_token_count = 0
new_curr_stable_char_count = 0
min_token_count = min(len(curr_tokens), len(new_curr_tokens))
for i in range(min_token_count):
if curr_tokens[i] == new_curr_tokens[i]:
new_curr_stable_token_count += 1
new_curr_stable_char_count += len(self.tokenizer.decode_single(curr_tokens[i]))
else:
break
else:
stable_tokens = curr_tokens[:curr_stable_token_count]
unstable_start = chunk_start + curr_stable_char_count
partial_text = text[unstable_start:segment_end]
partial_tokens = self.tokenizer.encode(partial_text)
new_curr_tokens = stable_tokens + partial_tokens
new_curr_stable_token_count = curr_stable_token_count
new_curr_stable_char_count = curr_stable_char_count
min_token_count = min(len(curr_tokens), len(new_curr_tokens))
for i in range(curr_stable_token_count, min_token_count):
if curr_tokens[i] == new_curr_tokens[i]:
new_curr_stable_token_count += 1
new_curr_stable_char_count += len(self.tokenizer.decode_single(curr_tokens[i]))
else:
break
if len(new_curr_tokens) > self.chunk_size:
if segment_is_separator:
if (
sep_curr_segments
and len(sep_curr_tokens) >= self.min_chunk_size
and not (self.chunk_overlap and len(sep_curr_tokens) <= self.chunk_overlap)
):
curr_segments = list(sep_curr_segments)
curr_tokens = list(sep_curr_tokens)
curr_stable_token_count = sep_curr_stable_token_count
curr_stable_char_count = sep_curr_stable_char_count
segment_start = curr_segments[-1][0]
segment_end = curr_segments[-1][1]
curr_text = text[segment_start:segment_end]
curr_start = segment_start
finished = False
break
else:
curr_segments.append((segment_start, segment_end, segment_is_separator))
token_offset_map[segment_start] = len(curr_tokens)
curr_tokens = new_curr_tokens
curr_stable_token_count = new_curr_stable_token_count
curr_stable_char_count = new_curr_stable_char_count
if segment_is_separator:
sep_curr_segments = list(curr_segments)
sep_curr_tokens = list(curr_tokens)
sep_curr_stable_token_count = curr_stable_token_count
sep_curr_stable_char_count = curr_stable_char_count
finished = True
if finished:
break
if (
curr_segments
and len(curr_tokens) >= self.min_chunk_size
and not (self.chunk_overlap and len(curr_tokens) <= self.chunk_overlap)
):
chunk_start = curr_segments[0][0]
chunk_end = curr_segments[-1][1]
yield chunk_start, chunk_end
if chunk_end == len(text):
break
if self.chunk_overlap:
fixed_overlap = True
if not self.fixed_overlap:
for segment in curr_segments:
if not segment[2]:
token_offset = token_offset_map[segment[0]]
if token_offset > curr_stable_token_count:
break
if len(curr_tokens) - token_offset <= self.chunk_overlap:
chunk_tokens = curr_tokens[token_offset:]
new_chunk_start = segment[0]
chunk_offset = new_chunk_start - chunk_start
chunk_start = new_chunk_start
chunk_continue = chunk_end
fixed_overlap = False
break
if fixed_overlap:
chunk_tokens = curr_tokens[-self.chunk_overlap:]
token_offset = len(curr_tokens) - len(chunk_tokens)
new_chunk_start = chunk_end - len(self.tokenizer.decode(chunk_tokens))
chunk_offset = new_chunk_start - chunk_start
chunk_start = new_chunk_start
chunk_continue = chunk_end
stable_token_count = max(0, curr_stable_token_count - token_offset)
stable_char_count = max(0, curr_stable_char_count - chunk_offset)
overlap_segments = [(chunk_start, chunk_end, False)]
token_offset_map[chunk_start] = 0
else:
chunk_tokens = []
chunk_start = chunk_end
chunk_continue = 0
stable_token_count = 0
stable_char_count = 0
overlap_segments = []
token_offset_map = {}
if chunk_continue:
remaining_text = text[chunk_continue:]
else:
remaining_text = text[chunk_start:]
layer = 0
else:
layer += 1
if layer == len(self.separators):
if curr_segments and curr_segments[-1][1] == len(text):
chunk_start = curr_segments[0][0]
chunk_end = curr_segments[-1][1]
yield chunk_start, chunk_end
break
remaining_tokens = self.tokenizer.encode(remaining_text)
if len(remaining_tokens) > self.chunk_size:
raise ValueError(
"Total number of tokens in the last chunk is greater than the chunk size. "
"Increase chunk_size or the separator granularity."
)
yield curr_start, len(text)
break
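# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal sketch of `SegmentSplitter` with three separator layers: first split
# on blank lines and newlines, then fall back to whitespace, and finally to raw
# tokens (None). Separators are regular expressions; the values below are
# illustrative assumptions, not the defaults from settings.
def _example_segment_splitter() -> None:
    splitter = SegmentSplitter(
        separators=[["\n\n", "\n"], [r"\s+"], [None]],
        chunk_size=200,
        chunk_overlap=20,
        min_chunk_size=0.5,  # relative to chunk_size
    )
    text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    chunks = [text[start:end] for start, end in splitter.split(text)]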
class LlamaIndexSplitter(TextSplitter):
"""Splitter class based on a node parser from LlamaIndex.
For defaults, see `chat.text_splitter_configs.llama_index` in `vectorbtpro._settings.knowledge`."""
_short_name = "llama_index"
_settings_path: tp.SettingsPath = "knowledge.chat.text_splitter_configs.llama_index"
def __init__(
self,
node_parser: tp.Union[None, str, NodeParserT] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
TextSplitter.__init__(self, template_context=template_context, **kwargs)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("llama_index")
from llama_index.core.node_parser import NodeParser
llama_index_config = merge_dicts(self.get_settings(inherit=False), kwargs)
def_node_parser = llama_index_config.pop("node_parser", None)
if node_parser is None:
node_parser = def_node_parser
init_kwargs = get_func_kwargs(type(self).__init__)
for k in list(llama_index_config.keys()):
if k in init_kwargs:
llama_index_config.pop(k)
if isinstance(node_parser, str):
import llama_index.core.node_parser
from vectorbtpro.utils.module_ import search_package
def _match_func(k, v):
if isinstance(v, type) and issubclass(v, NodeParser):
if "." in node_parser:
if k.endswith(node_parser):
return True
else:
if k.split(".")[-1].lower() == node_parser.lower():
return True
if k.split(".")[-1].replace("Splitter", "").replace(
"NodeParser", ""
).lower() == node_parser.lower().replace("_", ""):
return True
return False
found_node_parser = search_package(
llama_index.core.node_parser,
_match_func,
path_attrs=True,
return_first=True,
)
if found_node_parser is None:
raise ValueError(f"Node parser '{node_parser}' not found")
node_parser = found_node_parser
if isinstance(node_parser, type):
checks.assert_subclass_of(node_parser, NodeParser, arg_name="node_parser")
node_parser_name = node_parser.__name__.replace("Splitter", "").replace("NodeParser", "").lower()
module_name = node_parser.__module__
else:
checks.assert_instance_of(node_parser, NodeParser, arg_name="node_parser")
node_parser_name = type(node_parser).__name__.replace("Splitter", "").replace("NodeParser", "").lower()
module_name = type(node_parser).__module__
node_parser_configs = llama_index_config.pop("node_parser_configs", {})
if node_parser_name in node_parser_configs:
llama_index_config = merge_dicts(llama_index_config, node_parser_configs[node_parser_name])
elif module_name in node_parser_configs:
llama_index_config = merge_dicts(llama_index_config, node_parser_configs[module_name])
if isinstance(node_parser, type):
node_parser = node_parser(**llama_index_config)
elif len(kwargs) > 0:
raise ValueError("Cannot apply config to already initialized node parser")
model_name = llama_index_config.get("model_name", None)
if model_name is None:
func_kwargs = get_func_kwargs(type(node_parser).__init__)
model_name = func_kwargs.get("model_name", None)
self._model = model_name
self._node_parser = node_parser
@property
def node_parser(self) -> NodeParserT:
"""An instance of `llama_index.core.node_parser.interface.NodeParser`."""
return self._node_parser
def split(self, text: str) -> tp.TSRangeChunks:
for text_chunk in self.split_text(text):
start = text.find(text_chunk)
if start == -1:
end = -1
else:
end = start + len(text_chunk)
yield start, end
def split_text(self, text: str) -> tp.TSTextChunks:
from llama_index.core.schema import Document
nodes = self.node_parser.get_nodes_from_documents([Document(text=text)])
for node in nodes:
yield node.text
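# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal sketch of `LlamaIndexSplitter`. It requires the optional
# `llama_index` dependency; the string "sentence" is matched against node
# parser class names (here assumed to resolve to `SentenceSplitter`).
def _example_llama_index_splitter() -> None:
    splitter = LlamaIndexSplitter(node_parser="sentence")
    for chunk in splitter.split_text("First sentence. Second sentence."):
        print(chunk)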
def resolve_text_splitter(text_splitter: tp.TextSplitterLike = None) -> tp.MaybeType[TextSplitter]:
"""Resolve a subclass or an instance of `TextSplitter`.
The following values are supported:
* "token" (`TokenSplitter`)
* "segment" (`SegmentSplitter`)
* "llama_index" (`LlamaIndexSplitter`)
* A subclass or an instance of `TextSplitter`
"""
if text_splitter is None:
from vectorbtpro._settings import settings
chat_cfg = settings["knowledge"]["chat"]
text_splitter = chat_cfg["text_splitter"]
if isinstance(text_splitter, str):
curr_module = sys.modules[__name__]
found_text_splitter = None
for name, cls in inspect.getmembers(curr_module, inspect.isclass):
if name.endswith("Splitter"):
_short_name = getattr(cls, "_short_name", None)
if _short_name is not None and _short_name.lower() == text_splitter.lower():
found_text_splitter = cls
break
if found_text_splitter is None:
raise ValueError(f"Invalid text splitter: '{text_splitter}'")
text_splitter = found_text_splitter
if isinstance(text_splitter, type):
checks.assert_subclass_of(text_splitter, TextSplitter, arg_name="text_splitter")
else:
checks.assert_instance_of(text_splitter, TextSplitter, arg_name="text_splitter")
return text_splitter
def split_text(text: str, text_splitter: tp.TextSplitterLike = None, **kwargs) -> tp.List[str]:
"""Split text.
Resolves `text_splitter` with `resolve_text_splitter`. Keyword arguments are passed to either
initialize a class or replace an instance of `TextSplitter`."""
text_splitter = resolve_text_splitter(text_splitter=text_splitter)
if isinstance(text_splitter, type):
text_splitter = text_splitter(**kwargs)
elif kwargs:
text_splitter = text_splitter.replace(**kwargs)
return list(text_splitter.split_text(text))
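# --- Usage sketch (added for illustration; not part of the original module) ---
# The module-level `split_text` resolves the splitter by its short name and
# forwards keyword arguments to its constructor (or `replace` for instances).
def _example_split_text() -> None:
    chunks = split_text(
        "Some long text...",
        text_splitter="token",  # resolved via `resolve_text_splitter`
        chunk_size=200,
        chunk_overlap=20,
    )
    print(len(chunks))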
# ############# Storing ############# #
StoreObjectT = tp.TypeVar("StoreObjectT", bound="StoreObject")
@define
class StoreObject(DefineMixin):
"""Class for objects to be managed by a store."""
id_: str = define.field()
"""Object identifier."""
@property
def hash_key(self) -> tuple:
return (self.id_,)
StoreDocumentT = tp.TypeVar("StoreDocumentT", bound="StoreDocument")
@define
class StoreDocument(StoreObject, DefineMixin):
"""Abstract class for documents to be stored."""
data: tp.Any = define.field()
"""Data."""
template_context: tp.KwargsLike = define.field(factory=dict)
"""Context used to substitute templates."""
@classmethod
def id_from_data(cls, data: tp.Any) -> str:
"""Generate a unique identifier from data."""
from vectorbtpro.utils.pickling import dumps
return hashlib.md5(dumps(data)).hexdigest()
@classmethod
def from_data(
cls: tp.Type[StoreDocumentT],
data: tp.Any,
id_: tp.Optional[str] = None,
**kwargs,
) -> StoreDocumentT:
"""Create an instance of `StoreDocument` from data."""
if id_ is None:
id_ = cls.id_from_data(data)
return cls(id_, data, **kwargs)
def __attrs_post_init__(self):
if self.id_ is None:
new_id = self.id_from_data(self.data)
object.__setattr__(self, "id_", new_id)
def get_content(self, for_embed: bool = False) -> tp.Optional[str]:
"""Get content.
Returns None if there's no content."""
raise NotImplementedError
def split(self: StoreDocumentT) -> tp.List[StoreDocumentT]:
"""Split document into multiple documents."""
raise NotImplementedError
def __str__(self) -> str:
return self.get_content()
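# --- Usage sketch (added for illustration; not part of the original module) ---
# Document identifiers default to the MD5 hash of the pickled data, so the same
# payload should map to the same id within a session.
def _example_document_id() -> None:
    id_1 = StoreDocument.id_from_data({"symbol": "BTC", "text": "..."})
    id_2 = StoreDocument.id_from_data({"symbol": "BTC", "text": "..."})
    assert id_1 == id_2  # deterministic for identical payloads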
TextDocumentT = tp.TypeVar("TextDocumentT", bound="TextDocument")
def def_metadata_template(metadata_content: str) -> str:
"""Default metadata template"""
if metadata_content.endswith("\n"):
return "---\n{metadata_content}---\n\n".format(metadata_content=metadata_content)
return "---\n{metadata_content}\n---\n\n".format(metadata_content=metadata_content)
@define
class TextDocument(StoreDocument, DefineMixin):
"""Class for text documents."""
text_path: tp.Optional[tp.PathLikeKey] = define.field(default=None)
"""Path to the text field."""
split_text_kwargs: tp.KwargsLike = define.field(factory=dict)
"""Keyword arguments passed to `split_text`."""
excl_metadata: tp.Union[bool, tp.MaybeList[tp.PathLikeKey]] = define.field(default=False)
"""Whether to exclude metadata and which fields to exclude.
If False, metadata becomes everything except text."""
excl_embed_metadata: tp.Union[None, bool, tp.MaybeList[tp.PathLikeKey]] = define.field(default=None)
"""Whether to exclude metadata and which fields to exclude for embeddings.
If None, becomes `TextDocument.excl_metadata`."""
skip_missing: bool = define.field(default=True)
"""Set missing text or metadata to None rather than raise an error."""
dump_kwargs: tp.KwargsLike = define.field(factory=dict)
"""Keyword arguments passed to `vectorbtpro.utils.formatting.dump`."""
metadata_template: tp.CustomTemplateLike = define.field(
default=RepFunc(def_metadata_template, eval_id="metadata_template")
)
"""Metadata template.
Must be suitable for formatting via the `format()` method."""
content_template: tp.CustomTemplateLike = define.field(
default=SafeSub("${metadata_content}${text}", eval_id="content_template")
)
"""Content template.
Must be suitable for formatting via the `format()` method."""
def get_text(self) -> tp.Optional[str]:
"""Get text.
Returns None if no text."""
from vectorbtpro.utils.search_ import get_pathlike_key
if self.data is None:
return None
if isinstance(self.data, str):
return self.data
if self.text_path is not None:
try:
text = get_pathlike_key(self.data, self.text_path, keep_path=False)
except (KeyError, IndexError, AttributeError) as e:
if not self.skip_missing:
raise e
return None
if text is None:
return None
if not isinstance(text, str):
raise TypeError(f"Text field must be a string, not {type(text)}")
return text
raise TypeError(f"If text path is not provided, data item must be a string, not {type(self.data)}")
def get_metadata(self, for_embed: bool = False) -> tp.Optional[tp.Any]:
"""Get metadata.
Returns None if no metadata."""
from vectorbtpro.utils.search_ import remove_pathlike_key
if self.data is None or isinstance(self.data, str) or self.text_path is None:
return None
prev_keys = []
data = self.data
try:
data = remove_pathlike_key(data, self.text_path, make_copy=True, prev_keys=prev_keys)
except (KeyError, IndexError, AttributeError) as e:
if not self.skip_missing:
raise e
excl_metadata = self.excl_metadata
if for_embed:
excl_embed_metadata = self.excl_embed_metadata
if excl_embed_metadata is None:
excl_embed_metadata = excl_metadata
excl_metadata = excl_embed_metadata
if isinstance(excl_metadata, bool):
if excl_metadata:
return None
return data
if not excl_metadata:
return data
if not isinstance(excl_metadata, list):
excl_metadata = [excl_metadata]
for p in excl_metadata:
try:
data = remove_pathlike_key(data, p, make_copy=True, prev_keys=prev_keys)
except (KeyError, IndexError, AttributeError) as e:
continue
return data
def get_metadata_content(self, for_embed: bool = False) -> tp.Optional[str]:
"""Get metadata content.
Returns None if no metadata."""
from vectorbtpro.utils.formatting import dump
metadata = self.get_metadata(for_embed=for_embed)
if metadata is None:
return None
return dump(metadata, **self.dump_kwargs)
def get_content(self, for_embed: bool = False) -> tp.Optional[str]:
text = self.get_text()
metadata_content = self.get_metadata_content(for_embed=for_embed)
if text is None and metadata_content is None:
return None
if text is None:
text = ""
if metadata_content is None:
metadata_content = ""
if metadata_content:
metadata_template = self.metadata_template
if isinstance(metadata_template, str):
metadata_template = SafeSub(metadata_template)
elif checks.is_function(metadata_template):
metadata_template = RepFunc(metadata_template)
elif not isinstance(metadata_template, CustomTemplate):
raise TypeError(f"Metadata template must be a string, function, or template")
template_context = flat_merge_dicts(
dict(metadata_content=metadata_content),
self.template_context,
)
metadata_content = metadata_template.substitute(template_context, eval_id="metadata_template")
content_template = self.content_template
if isinstance(content_template, str):
content_template = SafeSub(content_template)
elif checks.is_function(content_template):
content_template = RepFunc(content_template)
elif not isinstance(content_template, CustomTemplate):
raise TypeError(f"Content template must be a string, function, or template")
template_context = flat_merge_dicts(
dict(metadata_content=metadata_content, text=text),
self.template_context,
)
return content_template.substitute(template_context, eval_id="content_template")
def split(self: TextDocumentT) -> tp.List[TextDocumentT]:
from vectorbtpro.utils.search_ import set_pathlike_key
text = self.get_text()
if text is None:
return [self]
text_chunks = split_text(text, **self.split_text_kwargs)
document_chunks = []
for text_chunk in text_chunks:
if not isinstance(self.data, str) and self.text_path is not None:
data_chunk = set_pathlike_key(
self.data,
self.text_path,
text_chunk,
make_copy=True,
)
else:
data_chunk = text_chunk
document_chunks.append(self.replace(data=data_chunk, id_=None))
return document_chunks
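# --- Usage sketch (added for illustration; not part of the original module) ---
# A `TextDocument` wrapping a dict: `text_path` selects the text field, the
# remaining fields become metadata and are rendered as a front-matter block
# above the text by the default templates. Field names below are hypothetical.
def _example_text_document() -> None:
    doc = TextDocument.from_data(
        {"symbol": "BTC", "text": "Bitcoin is a digital currency."},
        text_path="text",
    )
    print(doc.get_text())     # "Bitcoin is a digital currency."
    print(doc.get_content())  # roughly: "---\nsymbol: BTC\n---\n\nBitcoin is a digital currency."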
@define
class StoreEmbedding(StoreObject, DefineMixin):
"""Class for embeddings to be stored."""
parent_id: tp.Optional[str] = define.field(default=None)
"""Parent object identifier."""
child_ids: tp.List[str] = define.field(factory=list)
"""Child object identifiers."""
embedding: tp.Optional[tp.List[float]] = define.field(default=None, repr=lambda x: f"List[{len(x)}]" if x else None)
"""Embedding."""
class MetaObjectStore(type(Configured), type(MutableMapping)):
"""Metaclass for `ObjectStore`."""
pass
class ObjectStore(Configured, MutableMapping, metaclass=MetaObjectStore):
"""Abstract class for managing an object store.
For defaults, see `knowledge.chat.obj_store_config` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = None
"""Short name of the class."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.obj_store_config"]
def __init__(
self,
store_id: tp.Optional[str] = None,
purge_on_open: tp.Optional[bool] = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
store_id=store_id,
purge_on_open=purge_on_open,
template_context=template_context,
**kwargs,
)
store_id = self.resolve_setting(store_id, "store_id")
purge_on_open = self.resolve_setting(purge_on_open, "purge_on_open")
template_context = self.resolve_setting(template_context, "template_context", merge=True)
self._store_id = store_id
self._purge_on_open = purge_on_open
self._template_context = template_context
self._opened = False
self._enter_calls = 0
@property
def store_id(self) -> str:
"""Store id."""
return self._store_id
@property
def purge_on_open(self) -> bool:
"""Whether to purge on open."""
return self._purge_on_open
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
@property
def opened(self) -> bool:
"""Whether the store has been opened."""
return self._opened
@property
def enter_calls(self) -> int:
"""Number of enter calls."""
return self._enter_calls
@property
def mirror_store_id(self) -> tp.Optional[str]:
"""Mirror store id."""
return None
def open(self) -> None:
"""Open the store."""
if self.opened:
self.close()
if self.purge_on_open:
self.purge()
self._opened = True
def check_opened(self) -> None:
"""Check the store is opened."""
if not self.opened:
raise Exception(f"{type(self)} must be opened first")
def commit(self) -> None:
"""Commit changes."""
pass
def close(self) -> None:
"""Close the store."""
self.commit()
self._opened = False
def purge(self) -> None:
"""Purge the store."""
self.close()
def __getitem__(self, id_: str) -> StoreObjectT:
raise NotImplementedError
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
raise NotImplementedError
def __delitem__(self, id_: str) -> None:
raise NotImplementedError
def __iter__(self) -> tp.Iterator[str]:
raise NotImplementedError
def __len__(self) -> int:
raise NotImplementedError
def __enter__(self) -> tp.Self:
if not self.opened:
self.open()
self._enter_calls += 1
return self
def __exit__(self, *args) -> None:
if self.enter_calls == 1:
self.close()
self._close_on_exit = False
self._enter_calls -= 1
if self.enter_calls < 0:
self._enter_calls = 0
class DictStore(ObjectStore):
"""Store class based on a dictionary.
For defaults, see `chat.obj_store_configs.dict` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "dict"
_settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.dict"
def __init__(self, **kwargs) -> None:
ObjectStore.__init__(self, **kwargs)
self._store = {}
@property
def store(self) -> tp.Dict[str, StoreObjectT]:
"""Store dictionary."""
return self._store
def purge(self) -> None:
ObjectStore.purge(self)
self.store.clear()
def __getitem__(self, id_: str) -> StoreObjectT:
return self.store[id_]
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
self.store[id_] = obj
def __delitem__(self, id_: str) -> None:
del self.store[id_]
def __iter__(self) -> tp.Iterator[str]:
return iter(self.store)
def __len__(self) -> int:
return len(self.store)
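# --- Usage sketch (added for illustration; not part of the original module) ---
# `DictStore` is a plain in-process mapping; opening and closing it via the
# context manager protocol inherited from `ObjectStore` costs nothing extra.
def _example_dict_store() -> None:
    with DictStore(store_id="example") as store:
        doc = TextDocument.from_data("Hello world")
        store[doc.id_] = doc
        assert doc.id_ in store
        assert len(store) == 1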
memory_store: tp.Dict[str, tp.Dict[str, StoreObjectT]] = {}
"""Object store by store id for `MemoryStore`."""
class MemoryStore(DictStore):
"""Store class based in memory.
Commits changes to `memory_store`.
For defaults, see `chat.obj_store_configs.memory` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "memory"
_settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.memory"
def __init__(self, **kwargs) -> None:
DictStore.__init__(self, **kwargs)
@property
def store(self) -> tp.Dict[str, StoreObjectT]:
"""Store dictionary."""
return self._store
def store_exists(self) -> bool:
"""Whether store exists."""
return self.store_id in memory_store
def open(self) -> None:
DictStore.open(self)
if self.store_exists():
self._store = dict(memory_store[self.store_id])
def commit(self) -> None:
DictStore.commit(self)
memory_store[self.store_id] = dict(self.store)
def purge(self) -> None:
DictStore.purge(self)
if self.store_exists():
del memory_store[self.store_id]
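# --- Usage sketch (added for illustration; not part of the original module) ---
# `MemoryStore` commits its contents to the module-level `memory_store` dict,
# so the data survives closing the instance (but not the process).
def _example_memory_store() -> None:
    with MemoryStore(store_id="example") as store:
        store["doc-1"] = TextDocument.from_data("Hello", id_="doc-1")
    # Closing commits; the contents now live in the module-level `memory_store`.
    assert "doc-1" in memory_store["example"]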
class FileStore(DictStore):
"""Store class based on files.
Either commits changes to a single file (with the store id being the file name), or commits the initial
changes to the base file and any further changes to patch file(s) (with the store id being the directory name).
For defaults, see `chat.obj_store_configs.file` in `vectorbtpro._settings.knowledge`."""
_short_name = "file"
_settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.file"
def __init__(
self,
dir_path: tp.Optional[tp.PathLike] = None,
compression: tp.Union[None, bool, str] = None,
save_kwargs: tp.KwargsLike = None,
load_kwargs: tp.KwargsLike = None,
use_patching: tp.Optional[bool] = None,
consolidate: tp.Optional[bool] = None,
**kwargs,
) -> None:
DictStore.__init__(
self,
dir_path=dir_path,
compression=compression,
save_kwargs=save_kwargs,
load_kwargs=load_kwargs,
use_patching=use_patching,
consolidate=consolidate,
**kwargs,
)
dir_path = self.resolve_setting(dir_path, "dir_path")
template_context = self.template_context
if isinstance(dir_path, CustomTemplate):
cache_dir = self.get_setting("cache_dir", default=None)
if cache_dir is not None:
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = self.get_setting("release_dir", default=None)
if release_dir is not None:
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
dir_path = dir_path.substitute(template_context, eval_id="dir_path")
compression = self.resolve_setting(compression, "compression")
save_kwargs = self.resolve_setting(save_kwargs, "save_kwargs", merge=True)
load_kwargs = self.resolve_setting(load_kwargs, "load_kwargs", merge=True)
use_patching = self.resolve_setting(use_patching, "use_patching")
consolidate = self.resolve_setting(consolidate, "consolidate")
self._dir_path = dir_path
self._compression = compression
self._save_kwargs = save_kwargs
self._load_kwargs = load_kwargs
self._use_patching = use_patching
self._consolidate = consolidate
self._store_changes = {}
self._new_keys = set()
@property
def dir_path(self) -> tp.Optional[tp.Path]:
"""Path to the directory."""
return self._dir_path
@property
def compression(self) -> tp.CompressionLike:
"""Compression."""
return self._compression
@property
def save_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.save`."""
return self._save_kwargs
@property
def load_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.load`."""
return self._load_kwargs
@property
def use_patching(self) -> bool:
"""Whether to use directory with patch files or create a single file."""
return self._use_patching
@property
def consolidate(self) -> bool:
"""Whether to consolidate patch files."""
return self._consolidate
@property
def store_changes(self) -> tp.Dict[str, StoreObjectT]:
"""Store with new or modified objects only."""
return self._store_changes
@property
def new_keys(self) -> tp.Set[str]:
"""Keys that haven't been added to the store."""
return self._new_keys
def reset_state(self) -> None:
"""Reset state."""
self._consolidate = False
self._store_changes = {}
self._new_keys = set()
@property
def store_path(self) -> tp.Path:
"""Path to the directory with patch files or a single file."""
dir_path = self.dir_path
if dir_path is None:
dir_path = "."
dir_path = Path(dir_path)
return dir_path / self.store_id
@property
def mirror_store_id(self) -> str:
return str(self.store_path.resolve())
def get_next_patch_path(self) -> tp.Path:
"""Get path to the next patch file to be saved."""
indices = []
for file in self.store_path.glob("patch_*"):
indices.append(int(file.stem.split("_")[1]))
next_index = max(indices) + 1 if indices else 0
return self.store_path / f"patch_{next_index}"
def open(self) -> None:
DictStore.open(self)
if self.store_path.exists():
from vectorbtpro.utils.pickling import load
if self.store_path.is_dir():
store = {}
store.update(
load(
path=self.store_path / "base",
compression=self.compression,
**self.load_kwargs,
)
)
patch_paths = sorted(self.store_path.glob("patch_*"), key=lambda f: int(f.stem.split("_")[1]))
for patch_path in patch_paths:
store.update(
load(
path=patch_path,
compression=self.compression,
**self.load_kwargs,
)
)
else:
store = load(
path=self.store_path,
compression=self.compression,
**self.load_kwargs,
)
self._store = store
self.reset_state()
def commit(self) -> tp.Optional[tp.Path]:
DictStore.commit(self)
from vectorbtpro.utils.pickling import save
file_path = None
if self.use_patching:
base_path = self.store_path / "base"
if self.consolidate:
self.purge()
file_path = save(
self.store,
path=base_path,
compression=self.compression,
**self.save_kwargs,
)
elif self.store_changes:
if self.store_path.exists() and self.store_path.is_file():
self.purge()
if not base_path.exists():
file_path = save(
self.store_changes,
path=base_path,
compression=self.compression,
**self.save_kwargs,
)
else:
file_path = save(
self.store_changes,
path=self.get_next_patch_path(),
compression=self.compression,
**self.save_kwargs,
)
else:
if self.consolidate or self.store_changes:
if self.store_path.exists() and self.store_path.is_dir():
self.purge()
file_path = save(
self.store,
path=self.store_path,
compression=self.compression,
**self.save_kwargs,
)
self.reset_state()
return file_path
def close(self) -> None:
DictStore.close(self)
self.reset_state()
def purge(self) -> None:
DictStore.purge(self)
from vectorbtpro.utils.path_ import remove_file, remove_dir
if self.store_path.exists():
if self.store_path.is_dir():
remove_dir(self.store_path, with_contents=True)
else:
remove_file(self.store_path)
self.reset_state()
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
if obj.id_ not in self:
self.new_keys.add(obj.id_)
self.store_changes[obj.id_] = obj
DictStore.__setitem__(self, id_, obj)
def __delitem__(self, id_: str) -> None:
if id_ in self.new_keys:
del self.store_changes[id_]
self.new_keys.remove(id_)
else:
if id_ in self.store_changes:
del self.store_changes[id_]
DictStore.__delitem__(self, id_)
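# --- Usage sketch (added for illustration; not part of the original module) ---
# In single-file mode (use_patching=False) each commit rewrites one pickle file
# named after the store id; with use_patching=True the store id becomes a
# directory holding a "base" file plus incremental "patch_*" files that `open`
# replays in order. The `dir_path` below is an assumption.
def _example_file_store() -> None:
    store = FileStore(
        store_id="example_store",
        dir_path=".",
        use_patching=False,
        purge_on_open=False,
    )
    with store:
        store["doc-1"] = TextDocument.from_data("Hello", id_="doc-1")
    # Closing commits the store to "./example_store"; re-opening loads it back.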
class LMDBStore(ObjectStore):
"""Store class based on LMDB (Lightning Memory-Mapped Database).
Uses [lmdbm](https://pypi.org/project/lmdbm/) package.
For defaults, see `chat.obj_store_configs.lmdb` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "lmdb"
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.lmdb"
def __init__(
self,
dir_path: tp.Optional[tp.PathLike] = None,
mkdir_kwargs: tp.KwargsLike = None,
dumps_kwargs: tp.KwargsLike = None,
loads_kwargs: tp.KwargsLike = None,
**kwargs,
) -> None:
ObjectStore.__init__(
self,
dir_path=dir_path,
mkdir_kwargs=mkdir_kwargs,
dumps_kwargs=dumps_kwargs,
loads_kwargs=loads_kwargs,
**kwargs,
)
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("lmdbm")
dir_path = self.resolve_setting(dir_path, "dir_path")
template_context = self.template_context
if isinstance(dir_path, CustomTemplate):
cache_dir = self.get_setting("cache_dir", default=None)
if cache_dir is not None:
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = self.get_setting("release_dir", default=None)
if release_dir is not None:
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
dir_path = dir_path.substitute(template_context, eval_id="dir_path")
mkdir_kwargs = self.resolve_setting(mkdir_kwargs, "mkdir_kwargs", merge=True)
dumps_kwargs = self.resolve_setting(dumps_kwargs, "dumps_kwargs", merge=True)
loads_kwargs = self.resolve_setting(loads_kwargs, "loads_kwargs", merge=True)
open_kwargs = merge_dicts(self.get_settings(inherit=False), kwargs)
for arg_name in get_func_arg_names(ObjectStore.__init__) + get_func_arg_names(type(self).__init__):
if arg_name in open_kwargs:
del open_kwargs[arg_name]
if "mirror" in open_kwargs:
del open_kwargs["mirror"]
self._dir_path = dir_path
self._mkdir_kwargs = mkdir_kwargs
self._dumps_kwargs = dumps_kwargs
self._loads_kwargs = loads_kwargs
self._open_kwargs = open_kwargs
self._db = None
@property
def dir_path(self) -> tp.Optional[tp.Path]:
"""Path to the directory."""
return self._dir_path
@property
def mkdir_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.path_.check_mkdir`."""
return self._mkdir_kwargs
@property
def dumps_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.dumps`."""
return self._dumps_kwargs
@property
def loads_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pickling.loads`."""
return self._loads_kwargs
@property
def open_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `lmdbm.lmdbm.Lmdb.open`."""
return self._open_kwargs
@property
def db_path(self) -> tp.Path:
"""Path to the database."""
dir_path = self.dir_path
if dir_path is None:
dir_path = "."
dir_path = Path(dir_path)
return dir_path / self.store_id
@property
def mirror_store_id(self) -> str:
return str(self.db_path.resolve())
@property
def db(self) -> tp.Optional[LmdbT]:
"""Database."""
return self._db
def open(self) -> None:
ObjectStore.open(self)
from lmdbm import Lmdb
from vectorbtpro.utils.path_ import check_mkdir
check_mkdir(self.db_path.parent, **self.mkdir_kwargs)
self._db = Lmdb.open(str(self.db_path.resolve()), **self.open_kwargs)
def close(self) -> None:
ObjectStore.close(self)
if self.db:
self.db.close()
self._db = None
def purge(self) -> None:
ObjectStore.purge(self)
from vectorbtpro.utils.path_ import remove_dir
remove_dir(self.db_path, missing_ok=True, with_contents=True)
def encode(self, obj: StoreObjectT) -> bytes:
"""Encode an object."""
from vectorbtpro.utils.pickling import dumps
return dumps(obj, **self.dumps_kwargs)
def decode(self, bytes_: bytes) -> StoreObjectT:
"""Decode an object."""
from vectorbtpro.utils.pickling import loads
return loads(bytes_, **self.loads_kwargs)
def __getitem__(self, id_: str) -> StoreObjectT:
self.check_opened()
return self.decode(self.db[id_])
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
self.check_opened()
self.db[id_] = self.encode(obj)
def __delitem__(self, id_: str) -> None:
self.check_opened()
del self.db[id_]
def __iter__(self) -> tp.Iterator[str]:
self.check_opened()
return iter(self.db)
def __len__(self) -> int:
self.check_opened()
return len(self.db)
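# --- Usage sketch (added for illustration; not part of the original module) ---
# `LMDBStore` persists pickled objects in an LMDB database and requires the
# optional `lmdbm` dependency; extra keyword arguments not consumed by
# `__init__` are forwarded to `lmdbm.lmdbm.Lmdb.open`. This assumes the default
# open flags from settings allow creating a new database.
def _example_lmdb_store() -> None:
    store = LMDBStore(store_id="example_db", dir_path=".", purge_on_open=False)
    with store:
        store["doc-1"] = TextDocument.from_data("Hello", id_="doc-1")
        doc = store["doc-1"]  # decoded back from bytes via `decode`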
class CachedStore(DictStore):
"""Store class that acts as a (temporary) cache to another store.
For defaults, see `chat.obj_store_configs.cached` in `vectorbtpro._settings.knowledge`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "cached"
_settings_path: tp.SettingsPath = "knowledge.chat.obj_store_configs.cached"
def __init__(
self,
obj_store: ObjectStore,
lazy_open: tp.Optional[bool] = None,
mirror: tp.Optional[bool] = None,
**kwargs,
) -> None:
DictStore.__init__(
self,
obj_store=obj_store,
lazy_open=lazy_open,
mirror=mirror,
**kwargs,
)
lazy_open = self.resolve_setting(lazy_open, "lazy_open")
mirror = obj_store.resolve_setting(mirror, "mirror", default=None)
mirror = self.resolve_setting(mirror, "mirror")
if mirror and obj_store.mirror_store_id is None:
mirror = False
self._obj_store = obj_store
self._lazy_open = lazy_open
self._mirror = mirror
self._force_open = False
@property
def obj_store(self) -> ObjectStore:
"""Object store."""
return self._obj_store
@property
def lazy_open(self) -> bool:
"""Whether to open the store lazily."""
return self._lazy_open
@property
def mirror(self) -> bool:
"""Whether to mirror the store in `memory_store`."""
return self._mirror
@property
def force_open(self) -> bool:
"""Whether to open the store forcefully."""
return self._force_open
def open(self) -> None:
DictStore.open(self)
if self.mirror and self.obj_store.mirror_store_id in memory_store:
self.store.update(memory_store[self.obj_store.mirror_store_id])
elif not self.lazy_open or self.force_open:
self.obj_store.open()
def check_opened(self) -> None:
if self.lazy_open and not self.obj_store.opened:
self._force_open = True
self.obj_store.open()
DictStore.check_opened(self)
def commit(self) -> None:
DictStore.commit(self)
self.check_opened()
self.obj_store.commit()
if self.mirror:
memory_store[self.obj_store.mirror_store_id] = dict(self.store)
def close(self) -> None:
DictStore.close(self)
self.obj_store.close()
self._force_open = False
def purge(self) -> None:
DictStore.purge(self)
self.obj_store.purge()
if self.mirror and self.obj_store.mirror_store_id in memory_store:
del memory_store[self.obj_store.mirror_store_id]
def __getitem__(self, id_: str) -> StoreObjectT:
if id_ in self.store:
return self.store[id_]
self.check_opened()
obj = self.obj_store[id_]
self.store[id_] = obj
return obj
def __setitem__(self, id_: str, obj: StoreObjectT) -> None:
self.check_opened()
self.store[id_] = obj
self.obj_store[id_] = obj
def __delitem__(self, id_: str) -> None:
self.check_opened()
if id_ in self.store:
del self.store[id_]
del self.obj_store[id_]
def __iter__(self) -> tp.Iterator[str]:
self.check_opened()
return iter(self.obj_store)
def __len__(self) -> int:
self.check_opened()
return len(self.obj_store)
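# --- Usage sketch (added for illustration; not part of the original module) ---
# `CachedStore` fronts another store with an in-process dict: reads are cached,
# writes go to both, and with mirroring enabled the committed contents are also
# kept in `memory_store` under the backing store's `mirror_store_id`.
def _example_cached_store() -> None:
    backing = LMDBStore(store_id="example_db", dir_path=".", purge_on_open=False)
    cached = CachedStore(backing, lazy_open=True)
    with cached:
        cached["doc-1"] = TextDocument.from_data("Hello", id_="doc-1")
        doc = cached["doc-1"]  # served from the cache, without hitting the backing store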
def resolve_obj_store(obj_store: tp.ObjectStoreLike = None) -> tp.MaybeType[ObjectStore]:
"""Resolve a subclass or an instance of `ObjectStore`.
The following values are supported:
* "dict" (`DictStore`)
* "memory" (`MemoryStore`)
* "file" (`FileStore`)
* "lmdb" (`LMDBStore`)
* "cached" (`CachedStore`)
* A subclass or an instance of `ObjectStore`
"""
if obj_store is None:
from vectorbtpro._settings import settings
chat_cfg = settings["knowledge"]["chat"]
obj_store = chat_cfg["obj_store"]
if isinstance(obj_store, str):
curr_module = sys.modules[__name__]
found_obj_store = None
for name, cls in inspect.getmembers(curr_module, inspect.isclass):
if name.endswith("Store"):
_short_name = getattr(cls, "_short_name", None)
if _short_name is not None and _short_name.lower() == obj_store.lower():
found_obj_store = cls
break
if found_obj_store is None:
raise ValueError(f"Invalid object store: '{obj_store}'")
obj_store = found_obj_store
if isinstance(obj_store, type):
checks.assert_subclass_of(obj_store, ObjectStore, arg_name="obj_store")
else:
checks.assert_instance_of(obj_store, ObjectStore, arg_name="obj_store")
return obj_store
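# --- Usage sketch (added for illustration; not part of the original module) ---
# Short names are matched against the `_short_name` attribute of the store
# classes defined in this module; instances are passed through after a type check.
def _example_resolve_obj_store() -> None:
    assert resolve_obj_store("lmdb") is LMDBStore
    assert resolve_obj_store("file") is FileStore
    store = resolve_obj_store(MemoryStore(store_id="example"))  # instance passes through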
# ############# Ranking ############# #
@define
class EmbeddedDocument(DefineMixin):
"""Abstract class for embedded documents."""
document: StoreDocument = define.field()
"""Document."""
embedding: tp.Optional[tp.List[float]] = define.field(default=None)
"""Embedding."""
child_documents: tp.List["EmbeddedDocument"] = define.field(factory=list)
"""Embedded child documents."""
@define
class ScoredDocument(DefineMixin):
"""Abstract class for scored documents."""
document: StoreDocument = define.field()
"""Document."""
score: float = define.field(default=float("nan"))
"""Score."""
child_documents: tp.List["ScoredDocument"] = define.field(factory=list)
"""Scored child documents."""
class DocumentRanker(Configured):
"""Class for embedding, scoring, and ranking documents.
For defaults, see `knowledge.chat.doc_ranker_config` in `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat", "knowledge.chat.doc_ranker_config"]
def __init__(
self,
dataset_id: tp.Optional[str] = None,
embeddings: tp.EmbeddingsLike = None,
embeddings_kwargs: tp.KwargsLike = None,
doc_store: tp.ObjectStoreLike = None,
doc_store_kwargs: tp.KwargsLike = None,
cache_doc_store: tp.Optional[bool] = None,
emb_store: tp.ObjectStoreLike = None,
emb_store_kwargs: tp.KwargsLike = None,
cache_emb_store: tp.Optional[bool] = None,
score_func: tp.Union[None, str, tp.Callable] = None,
score_agg_func: tp.Union[None, str, tp.Callable] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
Configured.__init__(
self,
dataset_id=dataset_id,
embeddings=embeddings,
embeddings_kwargs=embeddings_kwargs,
doc_store=doc_store,
doc_store_kwargs=doc_store_kwargs,
cache_doc_store=cache_doc_store,
emb_store=emb_store,
emb_store_kwargs=emb_store_kwargs,
cache_emb_store=cache_emb_store,
score_func=score_func,
score_agg_func=score_agg_func,
show_progress=show_progress,
pbar_kwargs=pbar_kwargs,
template_context=template_context,
**kwargs,
)
dataset_id = self.resolve_setting(dataset_id, "dataset_id")
embeddings = self.resolve_setting(embeddings, "embeddings", default=None)
embeddings_kwargs = self.resolve_setting(embeddings_kwargs, "embeddings_kwargs", default=None, merge=True)
doc_store = self.resolve_setting(doc_store, "doc_store", default=None)
doc_store_kwargs = self.resolve_setting(doc_store_kwargs, "doc_store_kwargs", default=None, merge=True)
cache_doc_store = self.resolve_setting(cache_doc_store, "cache_doc_store")
emb_store = self.resolve_setting(emb_store, "emb_store", default=None)
emb_store_kwargs = self.resolve_setting(emb_store_kwargs, "emb_store_kwargs", default=None, merge=True)
cache_emb_store = self.resolve_setting(cache_emb_store, "cache_emb_store")
score_func = self.resolve_setting(score_func, "score_func")
score_agg_func = self.resolve_setting(score_agg_func, "score_agg_func")
show_progress = self.resolve_setting(show_progress, "show_progress")
pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
obj_store = self.get_setting("obj_store", default=None)
obj_store_kwargs = self.get_setting("obj_store_kwargs", default=None, merge=True)
if doc_store is None:
doc_store = obj_store
doc_store_kwargs = merge_dicts(obj_store_kwargs, doc_store_kwargs)
if emb_store is None:
emb_store = obj_store
emb_store_kwargs = merge_dicts(obj_store_kwargs, emb_store_kwargs)
embeddings = resolve_embeddings(embeddings)
if isinstance(embeddings, type):
embeddings_kwargs = dict(embeddings_kwargs)
embeddings_kwargs["template_context"] = merge_dicts(
template_context, embeddings_kwargs.get("template_context", None)
)
embeddings = embeddings(**embeddings_kwargs)
elif embeddings_kwargs:
embeddings = embeddings.replace(**embeddings_kwargs)
if isinstance(self._settings_path, list):
if not isinstance(self._settings_path[-1], str):
raise TypeError("_settings_path[-1] for DocumentRanker and its subclasses must be a string")
target_settings_path = self._settings_path[-1]
elif isinstance(self._settings_path, str):
target_settings_path = self._settings_path
else:
raise TypeError("_settings_path for DocumentRanker and its subclasses must be a list or string")
doc_store = resolve_obj_store(doc_store)
if not isinstance(doc_store._settings_path, str):
raise TypeError("_settings_path for ObjectStore and its subclasses must be a string")
doc_store_cls = doc_store if isinstance(doc_store, type) else type(doc_store)
doc_store_settings_path = doc_store._settings_path
doc_store_settings_path = doc_store_settings_path.replace("knowledge.chat", target_settings_path)
doc_store_settings_path = doc_store_settings_path.replace("obj_store", "doc_store")
with ExtSettingsPath([(doc_store_cls, doc_store_settings_path)]):
if isinstance(doc_store, type):
doc_store_kwargs = dict(doc_store_kwargs)
if dataset_id is not None and "store_id" not in doc_store_kwargs:
doc_store_kwargs["store_id"] = dataset_id
doc_store_kwargs["template_context"] = merge_dicts(
template_context, doc_store_kwargs.get("template_context", None)
)
doc_store = doc_store(**doc_store_kwargs)
elif doc_store_kwargs:
doc_store = doc_store.replace(**doc_store_kwargs)
if cache_doc_store and not isinstance(doc_store, CachedStore):
doc_store = CachedStore(doc_store)
emb_store = resolve_obj_store(emb_store)
if not isinstance(emb_store._settings_path, str):
raise TypeError("_settings_path for ObjectStore and its subclasses must be a string")
emb_store_cls = emb_store if isinstance(emb_store, type) else type(emb_store)
emb_store_settings_path = emb_store._settings_path
emb_store_settings_path = emb_store_settings_path.replace("knowledge.chat", target_settings_path)
emb_store_settings_path = emb_store_settings_path.replace("obj_store", "emb_store")
with ExtSettingsPath([(emb_store_cls, emb_store_settings_path)]):
if isinstance(emb_store, type):
emb_store_kwargs = dict(emb_store_kwargs)
if dataset_id is not None and "store_id" not in emb_store_kwargs:
emb_store_kwargs["store_id"] = dataset_id
emb_store_kwargs["template_context"] = merge_dicts(
template_context, emb_store_kwargs.get("template_context", None)
)
emb_store = emb_store(**emb_store_kwargs)
elif emb_store_kwargs:
emb_store = emb_store.replace(**emb_store_kwargs)
if cache_emb_store and not isinstance(emb_store, CachedStore):
emb_store = CachedStore(emb_store)
if isinstance(score_agg_func, str):
score_agg_func = getattr(np, score_agg_func)
self._embeddings = embeddings
self._doc_store = doc_store
self._emb_store = emb_store
self._score_func = score_func
self._score_agg_func = score_agg_func
self._show_progress = show_progress
self._pbar_kwargs = pbar_kwargs
self._template_context = template_context
@property
def embeddings(self) -> Embeddings:
"""An instance of `Embeddings`."""
return self._embeddings
@property
def doc_store(self) -> ObjectStore:
"""An instance of `ObjectStore` for documents."""
return self._doc_store
@property
def emb_store(self) -> ObjectStore:
"""An instance of `ObjectStore` for embeddings."""
return self._emb_store
@property
def score_func(self) -> tp.Union[str, tp.Callable]:
"""Score function.
See `DocumentRanker.compute_score`."""
return self._score_func
@property
def score_agg_func(self) -> tp.Callable:
"""Score aggregation function."""
return self._score_agg_func
@property
def show_progress(self) -> tp.Optional[bool]:
"""Whether to show progress bar."""
return self._show_progress
@property
def pbar_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `vectorbtpro.utils.pbar.ProgressBar`."""
return self._pbar_kwargs
@property
def template_context(self) -> tp.Kwargs:
"""Context used to substitute templates."""
return self._template_context
def embed_documents(
self,
documents: tp.Iterable[StoreDocument],
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_embeddings: bool = False,
return_documents: bool = False,
) -> tp.Optional[tp.EmbeddedDocuments]:
"""Embed documents.
Enable `refresh` or its sub-arguments to refresh documents and/or embeddings in their
particular stores. Without refreshing, will rely on the persisted objects.
If `return_embeddings` and `return_documents` are both False, returns nothing.
If `return_documents` is True, for each document, returns the document together with either
an embedding or a list of document chunks and their embeddings.
If only `return_embeddings` is True, returns only the embeddings."""
if refresh_documents is None:
refresh_documents = refresh
if refresh_embeddings is None:
refresh_embeddings = refresh
with self.doc_store, self.emb_store:
documents = list(documents)
documents_to_split = []
document_splits = {}
for document in documents:
if refresh_documents or refresh_embeddings or document.id_ not in self.emb_store:
documents_to_split.append(document)
if documents_to_split:
from vectorbtpro.utils.pbar import ProgressBar
pbar_kwargs = merge_dicts(dict(prefix="split_documents"), self.pbar_kwargs)
with ProgressBar(
total=len(documents_to_split),
show_progress=self.show_progress,
**pbar_kwargs,
) as pbar:
for document in documents_to_split:
document_splits[document.id_] = document.split()
pbar.update()
obj_contents = {}
for document in documents:
if refresh_documents or document.id_ not in self.doc_store:
self.doc_store[document.id_] = document
if document.id_ in document_splits:
document_chunks = document_splits[document.id_]
obj = StoreEmbedding(document.id_)
for document_chunk in document_chunks:
if document_chunk.id_ != document.id_:
if refresh_documents or document_chunk.id_ not in self.doc_store:
self.doc_store[document_chunk.id_] = document_chunk
if refresh_embeddings or document_chunk.id_ not in self.emb_store:
child_obj = StoreEmbedding(document_chunk.id_, parent_id=document.id_)
self.emb_store[child_obj.id_] = child_obj
else:
child_obj = self.emb_store[document_chunk.id_]
obj.child_ids.append(child_obj.id_)
if not child_obj.embedding:
content = document_chunk.get_content(for_embed=True)
if content:
obj_contents[child_obj.id_] = content
if refresh_documents or refresh_embeddings or document.id_ not in self.emb_store:
self.emb_store[obj.id_] = obj
else:
obj = self.emb_store[document.id_]
if not obj.child_ids and not obj.embedding:
content = document.get_content(for_embed=True)
if content:
obj_contents[obj.id_] = content
if obj_contents:
total = 0
for batch in self.embeddings.iter_embedding_batches(list(obj_contents.values())):
batch_keys = list(obj_contents.keys())[total : total + len(batch)]
obj_embeddings = dict(zip(batch_keys, batch))
for obj_id, embedding in obj_embeddings.items():
obj = self.emb_store[obj_id]
new_obj = obj.replace(embedding=embedding)
self.emb_store[new_obj.id_] = new_obj
total += len(batch)
if return_embeddings or return_documents:
embeddings = []
for document in documents:
obj = self.emb_store[document.id_]
if obj.embedding:
if return_documents:
embeddings.append(EmbeddedDocument(document, embedding=obj.embedding))
else:
embeddings.append(obj.embedding)
elif obj.child_ids:
child_embeddings = []
for child_id in obj.child_ids:
child_obj = self.emb_store[child_id]
if child_obj.embedding:
if return_documents:
child_document = self.doc_store[child_id]
child_embeddings.append(
EmbeddedDocument(child_document, embedding=child_obj.embedding)
)
else:
child_embeddings.append(child_obj.embedding)
else:
if return_documents:
child_document = self.doc_store[child_id]
child_embeddings.append(EmbeddedDocument(child_document))
else:
child_embeddings.append(None)
if return_documents:
embeddings.append(EmbeddedDocument(document, child_documents=child_embeddings))
else:
embeddings.append(child_embeddings)
else:
if return_documents:
embeddings.append(EmbeddedDocument(document))
else:
embeddings.append(None)
return embeddings
def compute_score(
self,
emb1: tp.Union[tp.MaybeIterable[tp.List[float]], np.ndarray],
emb2: tp.Union[tp.MaybeIterable[tp.List[float]], np.ndarray],
) -> tp.Union[float, np.ndarray]:
"""Compute scores between embeddings, which can be either single or multiple.
Supported distance functions are 'cosine', 'euclidean', and 'dot'. A metric can also be
a callable that should take two and return one 2-dim NumPy array."""
emb1 = np.asarray(emb1)
emb2 = np.asarray(emb2)
emb1_single = emb1.ndim == 1
emb2_single = emb2.ndim == 1
if emb1_single:
emb1 = emb1.reshape(1, -1)
if emb2_single:
emb2 = emb2.reshape(1, -1)
if isinstance(self.score_func, str):
if self.score_func.lower() == "cosine":
emb1_norm = emb1 / np.linalg.norm(emb1, axis=1, keepdims=True)
emb2_norm = emb2 / np.linalg.norm(emb2, axis=1, keepdims=True)
emb1_norm = np.nan_to_num(emb1_norm)
emb2_norm = np.nan_to_num(emb2_norm)
score_matrix = np.dot(emb1_norm, emb2_norm.T)
elif self.score_func.lower() == "euclidean":
diff = emb1[:, np.newaxis, :] - emb2[np.newaxis, :, :]
distances = np.linalg.norm(diff, axis=2)
score_matrix = np.divide(1, distances, where=distances != 0, out=np.full_like(distances, np.inf))
elif self.score_func.lower() == "dot":
score_matrix = np.dot(emb1, emb2.T)
else:
raise ValueError(f"Invalid distance function: '{self.score_func}'")
else:
score_matrix = self.score_func(emb1, emb2)
if emb1_single and emb2_single:
return float(score_matrix[0, 0])
if emb1_single or emb2_single:
return score_matrix.flatten()
return score_matrix
def score_documents(
self,
query: str,
documents: tp.Optional[tp.Iterable[StoreDocument]] = None,
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_chunks: bool = False,
return_documents: bool = False,
) -> tp.ScoredDocuments:
"""Score documents by relevance to a query."""
with self.doc_store, self.emb_store:
if documents is None:
if self.doc_store is None:
raise ValueError("Must provide at least documents or doc_store")
documents = self.doc_store.values()
documents_provided = False
else:
documents_provided = True
documents = list(documents)
if not documents:
return []
self.embed_documents(
documents,
refresh=refresh,
refresh_documents=refresh_documents,
refresh_embeddings=refresh_embeddings,
)
if return_chunks:
document_chunks = []
for document in documents:
obj = self.emb_store[document.id_]
if obj.child_ids:
for child_id in obj.child_ids:
document_chunk = self.doc_store[child_id]
document_chunks.append(document_chunk)
elif not obj.parent_id or obj.parent_id not in self.doc_store:
document_chunk = self.doc_store[obj.id_]
document_chunks.append(document_chunk)
documents = document_chunks
elif not documents_provided:
document_parents = []
for document in documents:
obj = self.emb_store[document.id_]
if not obj.parent_id or obj.parent_id not in self.doc_store:
document_parent = self.doc_store[obj.id_]
document_parents.append(document_parent)
documents = document_parents
obj_embeddings = {}
for document in documents:
obj = self.emb_store[document.id_]
if obj.embedding:
obj_embeddings[obj.id_] = obj.embedding
elif obj.child_ids:
for child_id in obj.child_ids:
child_obj = self.emb_store[child_id]
if child_obj.embedding:
obj_embeddings[child_id] = child_obj.embedding
if obj_embeddings:
query_embedding = self.embeddings.get_embedding(query)
scores = self.compute_score(query_embedding, list(obj_embeddings.values()))
obj_scores = dict(zip(obj_embeddings.keys(), scores))
else:
obj_scores = {}
scores = []
for document in documents:
obj = self.emb_store[document.id_]
child_scores = []
if obj.child_ids:
for child_id in obj.child_ids:
if child_id in obj_scores:
child_score = obj_scores[child_id]
if return_documents:
child_document = self.doc_store[child_id]
child_scores.append(ScoredDocument(child_document, score=child_score))
else:
child_scores.append(child_score)
if child_scores:
if return_documents:
doc_score = self.score_agg_func([document.score for document in child_scores])
else:
doc_score = self.score_agg_func(child_scores)
else:
doc_score = float("nan")
else:
if obj.id_ in obj_scores:
doc_score = obj_scores[obj.id_]
else:
doc_score = float("nan")
if return_documents:
scores.append(ScoredDocument(document, score=doc_score, child_documents=child_scores))
else:
scores.append(doc_score)
return scores
@classmethod
def resolve_top_k(cls, scores: tp.Iterable[float], top_k: tp.TopKLike = None) -> tp.Optional[int]:
"""Resolve `top_k` based on _sorted_ scores.
Supported values are integers (top number), floats (top fraction), strings (supported methods are
'elbow' and 'kmeans'), as well as callables that should take a 1-dim NumPy array and return
an integer or a float. Filters out NaN before computation (requires them to be at the tail)."""
if top_k is None:
return None
scores = np.asarray(scores)
scores = scores[~np.isnan(scores)]
if isinstance(top_k, str):
if top_k.lower() == "elbow":
if scores.size == 0:
return 0
diffs = np.diff(scores)
top_k = np.argmax(-diffs) + 1
elif top_k.lower() == "kmeans":
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=2, random_state=0).fit(scores.reshape(-1, 1))
high_score_cluster = np.argmax(kmeans.cluster_centers_)
top_k_indices = np.where(kmeans.labels_ == high_score_cluster)[0]
top_k = max(top_k_indices) + 1
else:
raise ValueError(f"Invalid top_k method: '{top_k}'")
elif callable(top_k):
top_k = top_k(scores)
if checks.is_float(top_k):
top_k = int(top_k * len(scores))
return top_k
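# Illustrative sketch (not part of the original source) of how `resolve_top_k` interprets its
# argument; scores are made up and already sorted in descending order.
# >>> scores = [0.9, 0.85, 0.4, 0.35, 0.3]
# >>> DocumentRanker.resolve_top_k(scores, top_k=2)        # integer -> taken as-is
# 2
# >>> DocumentRanker.resolve_top_k(scores, top_k=0.4)      # float -> fraction of all scores
# 2
# >>> DocumentRanker.resolve_top_k(scores, top_k="elbow")  # cut at the largest drop (0.85 -> 0.4)
# 2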
@classmethod
def top_k_from_cutoff(cls, scores: tp.Iterable[float], cutoff: tp.Optional[float] = None) -> tp.Optional[int]:
"""Get `top_k` from `cutoff` based on _sorted_ scores."""
if cutoff is None:
return None
scores = np.asarray(scores)
scores = scores[~np.isnan(scores)]
return len(scores[scores >= cutoff])
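# Illustrative sketch (not part of the original source): with the same made-up descending scores,
# a cutoff keeps every score greater than or equal to it.
# >>> DocumentRanker.top_k_from_cutoff([0.9, 0.85, 0.4, 0.35, 0.3], cutoff=0.4)
# 3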
def rank_documents(
self,
query: str,
documents: tp.Optional[tp.Iterable[StoreDocument]] = None,
top_k: tp.TopKLike = None,
min_top_k: tp.TopKLike = None,
max_top_k: tp.TopKLike = None,
cutoff: tp.Optional[float] = None,
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_chunks: bool = False,
return_scores: bool = False,
) -> tp.RankedDocuments:
"""Sort documents by relevance to a query.
Top-k, minimum top-k, and maximum top-k are resolved with `DocumentRanker.resolve_top_k`.
Score cutoff is converted into top-k with `DocumentRanker.top_k_from_cutoff`.
Minimum and maximum top-k are used to override a non-integer top-k and the cutoff; they have no
effect on an integer top-k, which can fall outside these bounds and won't be overridden."""
scored_documents = self.score_documents(
query,
documents=documents,
refresh=refresh,
refresh_documents=refresh_documents,
refresh_embeddings=refresh_embeddings,
return_chunks=return_chunks,
return_documents=True,
)
scored_documents = sorted(scored_documents, key=lambda x: (not np.isnan(x.score), x.score), reverse=True)
scores = [document.score for document in scored_documents]
int_top_k = top_k is not None and checks.is_int(top_k)
top_k = self.resolve_top_k(scores, top_k=top_k)
min_top_k = self.resolve_top_k(scores, top_k=min_top_k)
max_top_k = self.resolve_top_k(scores, top_k=max_top_k)
cutoff = self.top_k_from_cutoff(scores, cutoff=cutoff)
if not int_top_k and min_top_k is not None and min_top_k > top_k:
top_k = min_top_k
if not int_top_k and max_top_k is not None and max_top_k < top_k:
top_k = max_top_k
if cutoff is not None and min_top_k is not None and min_top_k > cutoff:
cutoff = min_top_k
if cutoff is not None and max_top_k is not None and max_top_k < cutoff:
cutoff = max_top_k
if top_k is None:
top_k = len(scores)
if cutoff is None:
cutoff = len(scores)
top_k = min(top_k, cutoff)
if top_k == 0:
raise ValueError("No documents selected after ranking. Change top_k or cutoff.")
scored_documents = scored_documents[:top_k]
if return_scores:
return scored_documents
return [document.document for document in scored_documents]
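# Illustrative usage sketch (not part of the original source): assumes `doc_ranker` has populated
# stores; the query and keyword values are placeholders.
# >>> top_docs = doc_ranker.rank_documents("moving averages", top_k=0.2, min_top_k=5)
# >>> scored = doc_ranker.rank_documents("moving averages", cutoff=0.8, return_scores=True)  # ScoredDocument objects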
def embed_documents(
documents: tp.Iterable[StoreDocument],
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_embeddings: bool = False,
return_documents: bool = False,
doc_ranker: tp.Optional[tp.MaybeType[DocumentRanker]] = None,
**kwargs,
) -> tp.Optional[tp.EmbeddedDocuments]:
"""Embed documents.
Keyword arguments are used either to initialize the `DocumentRanker` class or to replace
attributes of a `DocumentRanker` instance."""
if doc_ranker is None:
doc_ranker = DocumentRanker
if isinstance(doc_ranker, type):
checks.assert_subclass_of(doc_ranker, DocumentRanker, "doc_ranker")
doc_ranker = doc_ranker(**kwargs)
else:
checks.assert_instance_of(doc_ranker, DocumentRanker, "doc_ranker")
if kwargs:
doc_ranker = doc_ranker.replace(**kwargs)
return doc_ranker.embed_documents(
documents,
refresh=refresh,
refresh_documents=refresh_documents,
refresh_embeddings=refresh_embeddings,
return_embeddings=return_embeddings,
return_documents=return_documents,
)
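# Illustrative usage sketch (not part of the original source): `documents` is an iterable of
# StoreDocument instances prepared elsewhere; extra keyword arguments configure `DocumentRanker`.
# >>> embed_documents(documents)                        # default DocumentRanker class
# >>> embed_documents(documents, doc_ranker=my_ranker)  # or reuse a configured instance (placeholder)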
def rank_documents(
query: str,
documents: tp.Optional[tp.Iterable[StoreDocument]] = None,
top_k: tp.TopKLike = None,
min_top_k: tp.TopKLike = None,
max_top_k: tp.TopKLike = None,
cutoff: tp.Optional[float] = None,
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_chunks: bool = False,
return_scores: bool = False,
doc_ranker: tp.Optional[tp.MaybeType[DocumentRanker]] = None,
**kwargs,
) -> tp.RankedDocuments:
"""Rank documents by their relevance to a query.
Keyword arguments are used either to initialize the `DocumentRanker` class or to replace
attributes of a `DocumentRanker` instance."""
if doc_ranker is None:
doc_ranker = DocumentRanker
if isinstance(doc_ranker, type):
checks.assert_subclass_of(doc_ranker, DocumentRanker, "doc_ranker")
doc_ranker = doc_ranker(**kwargs)
else:
checks.assert_instance_of(doc_ranker, DocumentRanker, "doc_ranker")
if kwargs:
doc_ranker = doc_ranker.replace(**kwargs)
return doc_ranker.rank_documents(
query,
documents=documents,
top_k=top_k,
min_top_k=min_top_k,
max_top_k=max_top_k,
cutoff=cutoff,
refresh=refresh,
refresh_documents=refresh_documents,
refresh_embeddings=refresh_embeddings,
return_chunks=return_chunks,
return_scores=return_scores,
)
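# Illustrative usage sketch (not part of the original source): the query, documents, and keyword
# values are placeholders; extra keyword arguments configure `DocumentRanker`.
# >>> top_docs = rank_documents("how to resample data", documents, top_k=10)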
RankableT = tp.TypeVar("RankableT", bound="Rankable")
class Rankable(HasSettings):
"""Abstract class that can be ranked."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat"]
def embed(
self: RankableT,
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_embeddings: bool = False,
return_documents: bool = False,
**kwargs,
) -> tp.Optional[RankableT]:
"""Embed documents."""
raise NotImplementedError
def rank(
self: RankableT,
query: str,
top_k: tp.TopKLike = None,
min_top_k: tp.TopKLike = None,
max_top_k: tp.TopKLike = None,
cutoff: tp.Optional[float] = None,
refresh: bool = False,
refresh_documents: tp.Optional[bool] = None,
refresh_embeddings: tp.Optional[bool] = None,
return_chunks: bool = False,
return_scores: bool = False,
**kwargs,
) -> RankableT:
"""Rank documents by their relevance to a query."""
raise NotImplementedError
# ############# Contexting ############# #
class Contextable(HasSettings):
"""Abstract class that can be converted into a context."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.chat"]
def to_context(self, *args, **kwargs) -> str:
"""Convert to a context."""
raise NotImplementedError
def count_tokens(
self,
to_context_kwargs: tp.KwargsLike = None,
tokenizer: tp.TokenizerLike = None,
tokenizer_kwargs: tp.KwargsLike = None,
) -> int:
"""Count the number of tokens in the context."""
to_context_kwargs = self.resolve_setting(to_context_kwargs, "to_context_kwargs", merge=True)
tokenizer = self.resolve_setting(tokenizer, "tokenizer", default=None)
tokenizer_kwargs = self.resolve_setting(tokenizer_kwargs, "tokenizer_kwargs", default=None, merge=True)
context = self.to_context(**to_context_kwargs)
tokenizer = resolve_tokenizer(tokenizer)
if isinstance(tokenizer, type):
tokenizer = tokenizer(**tokenizer_kwargs)
elif tokenizer_kwargs:
tokenizer = tokenizer.replace(**tokenizer_kwargs)
return len(tokenizer.encode(context))
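# Illustrative usage sketch (not part of the original source): assumes `asset` is a Contextable
# instance; the tokenizer is resolved from the settings.
# >>> asset.count_tokens()  # number of tokens in the rendered context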
def create_chat(
self,
to_context_kwargs: tp.KwargsLike = None,
completions: tp.CompletionsLike = None,
**kwargs,
) -> tp.Completions:
"""Create a chat by returning an instance of `Completions`.
Uses `Contextable.to_context` to turn this instance to a context.
Usage:
```pycon
>>> chat = asset.create_chat()
>>> chat.get_completion("What's the value under 'xyz'?")
The value under 'xyz' is 123.
>>> chat.get_completion("Are you sure?")
Yes, I am sure. The value under 'xyz' is 123 for the entry where `s` is "EFG".
```"""
to_context_kwargs = self.resolve_setting(to_context_kwargs, "to_context_kwargs", merge=True)
context = self.to_context(**to_context_kwargs)
completions = resolve_completions(completions=completions)
if isinstance(completions, type):
completions = completions(context=context, **kwargs)
else:
completions = completions.replace(context=context, **kwargs)
return completions
@hybrid_method
def chat(
cls_or_self,
message: str,
chat_history: tp.Optional[tp.ChatHistory] = None,
*,
return_chat: bool = False,
**kwargs,
) -> tp.MaybeChatOutput:
"""Chat with an LLM while using the instance as a context.
Uses `Contextable.create_chat` and then `Completions.get_completion`.
!!! note
Context is recalculated each time this method is invoked. For multiple turns,
it's more efficient to use `Contextable.create_chat`.
Usage:
```pycon
>>> asset.chat("What's the value under 'xyz'?")
The value under 'xyz' is 123.
>>> chat_history = []
>>> asset.chat("What's the value under 'xyz'?", chat_history=chat_history)
The value under 'xyz' is 123.
>>> asset.chat("Are you sure?", chat_history=chat_history)
Yes, I am sure. The value under 'xyz' is 123 for the entry where `s` is "EFG".
```
"""
if isinstance(cls_or_self, type):
args, kwargs = get_forward_args(super().chat, locals())
return super().chat(*args, **kwargs)
completions = cls_or_self.create_chat(chat_history=chat_history, **kwargs)
if return_chat:
return completions.get_completion(message), completions
return completions.get_completion(message)
class RankContextable(Rankable, Contextable):
"""Abstract class that combines both `Rankable` and `Contextable` to rank a context."""
@hybrid_method
def chat(
cls_or_self,
message: str,
chat_history: tp.Optional[tp.ChatHistory] = None,
*,
incl_past_queries: tp.Optional[bool] = None,
rank: tp.Optional[bool] = None,
top_k: tp.TopKLike = None,
min_top_k: tp.TopKLike = None,
max_top_k: tp.TopKLike = None,
cutoff: tp.Optional[float] = None,
return_chunks: tp.Optional[bool] = None,
rank_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeChatOutput:
"""See `Contextable.chat`.
If `rank` is True, or `rank` is None and any of `top_k`, `min_top_k`, `max_top_k`, `cutoff`, or
`return_chunks` is set, will rank the documents with `Rankable.rank` first."""
if isinstance(cls_or_self, type):
args, kwargs = get_forward_args(super().chat, locals())
return super().chat(*args, **kwargs)
incl_past_queries = cls_or_self.resolve_setting(incl_past_queries, "incl_past_queries")
rank = cls_or_self.resolve_setting(rank, "rank")
rank_kwargs = cls_or_self.resolve_setting(rank_kwargs, "rank_kwargs", merge=True)
def_top_k = rank_kwargs.pop("top_k")
if top_k is None:
top_k = def_top_k
def_min_top_k = rank_kwargs.pop("min_top_k")
if min_top_k is None:
min_top_k = def_min_top_k
def_max_top_k = rank_kwargs.pop("max_top_k")
if max_top_k is None:
max_top_k = def_max_top_k
def_cutoff = rank_kwargs.pop("cutoff")
if cutoff is None:
cutoff = def_cutoff
def_return_chunks = rank_kwargs.pop("return_chunks")
if return_chunks is None:
return_chunks = def_return_chunks
if rank or (rank is None and (top_k or min_top_k or max_top_k or cutoff or return_chunks)):
if incl_past_queries and chat_history is not None:
queries = []
for message_dct in chat_history:
if "role" in message_dct and message_dct["role"] == "user":
queries.append(message_dct["content"])
queries.append(message)
if len(queries) > 1:
query = "\n\n".join(queries)
else:
query = queries[0]
else:
query = message
_cls_or_self = cls_or_self.rank(
query,
top_k=top_k,
min_top_k=min_top_k,
max_top_k=max_top_k,
cutoff=cutoff,
return_chunks=return_chunks,
**rank_kwargs,
)
else:
_cls_or_self = cls_or_self
return Contextable.chat.__func__(_cls_or_self, message, chat_history, **kwargs)
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom asset function classes."""
from vectorbtpro import _typing as tp
from vectorbtpro.utils.config import flat_merge_dicts
from vectorbtpro.utils.knowledge.base_asset_funcs import AssetFunc, RemoveAssetFunc
from vectorbtpro.utils.knowledge.formatting import to_markdown, to_html, format_html
__all__ = []
class ToMarkdownAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.VBTAsset.to_markdown`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "to_markdown"
_wrap: tp.ClassVar[tp.Optional[bool]] = True
@classmethod
def prepare(
cls,
root_metadata_key: tp.Optional[tp.Key] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**to_markdown_kwargs,
) -> tp.ArgsKwargs:
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc
if asset_cls is None:
from vectorbtpro.utils.knowledge.custom_assets import VBTAsset
asset_cls = VBTAsset
root_metadata_key = asset_cls.resolve_setting(root_metadata_key, "root_metadata_key")
minimize_metadata = asset_cls.resolve_setting(minimize_metadata, "minimize_metadata")
minimize_keys = asset_cls.resolve_setting(minimize_keys, "minimize_keys")
clean_metadata = asset_cls.resolve_setting(clean_metadata, "clean_metadata")
clean_metadata_kwargs = asset_cls.resolve_setting(clean_metadata_kwargs, "clean_metadata_kwargs", merge=True)
dump_metadata_kwargs = asset_cls.resolve_setting(dump_metadata_kwargs, "dump_metadata_kwargs", merge=True)
to_markdown_kwargs = asset_cls.resolve_setting(to_markdown_kwargs, "to_markdown_kwargs", merge=True)
clean_metadata_kwargs = flat_merge_dicts(dict(target=FindRemoveAssetFunc.is_empty_func), clean_metadata_kwargs)
_, clean_metadata_kwargs = FindRemoveAssetFunc.prepare(**clean_metadata_kwargs)
_, dump_metadata_kwargs = DumpAssetFunc.prepare(**dump_metadata_kwargs)
return (), {
**dict(
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
root_metadata_key=root_metadata_key,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
),
**to_markdown_kwargs,
}
@classmethod
def get_markdown_metadata(
cls,
d: dict,
root_metadata_key: tp.Optional[tp.Key] = None,
allow_empty: tp.Optional[bool] = None,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
**to_markdown_kwargs,
) -> str:
"""Get metadata in Markdown format."""
from vectorbtpro.utils.formatting import get_dump_language
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc
if allow_empty is None:
allow_empty = root_metadata_key is not None
if clean_metadata_kwargs is None:
clean_metadata_kwargs = {}
if dump_metadata_kwargs is None:
dump_metadata_kwargs = {}
metadata = dict(d)
if "content" in metadata:
del metadata["content"]
if metadata and minimize_metadata and minimize_keys:
metadata = RemoveAssetFunc.call(metadata, minimize_keys, skip_missing=True)
if metadata and clean_metadata:
metadata = FindRemoveAssetFunc.call(metadata, **clean_metadata_kwargs)
if not metadata and not allow_empty:
return ""
if root_metadata_key is not None:
if not metadata:
metadata = None
metadata = {root_metadata_key: metadata}
text = DumpAssetFunc.call(metadata, **dump_metadata_kwargs).strip()
dump_engine = dump_metadata_kwargs.get("dump_engine", "nestedtext")
dump_language = get_dump_language(dump_engine)
text = f"```{dump_language}\n{text}\n```"
return to_markdown(text, **to_markdown_kwargs)
@classmethod
def get_markdown_content(cls, d: dict, **kwargs) -> str:
"""Get content in Markdown format."""
if d["content"] is None:
return ""
return to_markdown(d["content"], **kwargs)
@classmethod
def call(
cls,
d: tp.Any,
root_metadata_key: tp.Optional[tp.Key] = None,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
**to_markdown_kwargs,
) -> tp.Any:
if not isinstance(d, (str, dict)):
raise TypeError("Data item must be a string or dict")
if isinstance(d, str):
d = dict(content=d)
markdown_metadata = cls.get_markdown_metadata(
d,
root_metadata_key=root_metadata_key,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
markdown_content = cls.get_markdown_content(d, **to_markdown_kwargs)
if markdown_metadata and markdown_content:
markdown_content = markdown_metadata + "\n\n" + markdown_content
elif markdown_metadata:
markdown_content = markdown_metadata
return markdown_content
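# Illustrative sketch (not part of the original source): data items are normally converted through
# `VBTAsset.to_markdown`, which prepares the keyword arguments above via `prepare`; the result is
# the dumped metadata in a fenced block followed by the content converted to Markdown.
# >>> asset.to_markdown()  # `asset` is a placeholder VBTAsset instance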
class ToHTMLAssetFunc(ToMarkdownAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.VBTAsset.to_html`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "to_html"
@classmethod
def prepare(
cls,
root_metadata_key: tp.Optional[tp.Key] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
format_html_kwargs: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**to_html_kwargs,
) -> tp.ArgsKwargs:
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc
if asset_cls is None:
from vectorbtpro.utils.knowledge.custom_assets import VBTAsset
asset_cls = VBTAsset
root_metadata_key = asset_cls.resolve_setting(root_metadata_key, "root_metadata_key")
minimize_metadata = asset_cls.resolve_setting(minimize_metadata, "minimize_metadata")
minimize_keys = asset_cls.resolve_setting(minimize_keys, "minimize_keys")
clean_metadata = asset_cls.resolve_setting(clean_metadata, "clean_metadata")
clean_metadata_kwargs = asset_cls.resolve_setting(clean_metadata_kwargs, "clean_metadata_kwargs", merge=True)
dump_metadata_kwargs = asset_cls.resolve_setting(dump_metadata_kwargs, "dump_metadata_kwargs", merge=True)
to_markdown_kwargs = asset_cls.resolve_setting(to_markdown_kwargs, "to_markdown_kwargs", merge=True)
format_html_kwargs = asset_cls.resolve_setting(format_html_kwargs, "format_html_kwargs", merge=True)
to_html_kwargs = asset_cls.resolve_setting(to_html_kwargs, "to_html_kwargs", merge=True)
clean_metadata_kwargs = flat_merge_dicts(dict(target=FindRemoveAssetFunc.is_empty_func), clean_metadata_kwargs)
_, clean_metadata_kwargs = FindRemoveAssetFunc.prepare(**clean_metadata_kwargs)
_, dump_metadata_kwargs = DumpAssetFunc.prepare(**dump_metadata_kwargs)
return (), {
**dict(
root_metadata_key=root_metadata_key,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
format_html_kwargs=format_html_kwargs,
),
**to_html_kwargs,
}
@classmethod
def get_html_metadata(
cls,
d: dict,
root_metadata_key: tp.Optional[tp.Key] = None,
allow_empty: tp.Optional[bool] = None,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
**to_html_kwargs,
) -> str:
"""Get metadata in HTML format."""
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
metadata = cls.get_markdown_metadata(
d,
root_metadata_key=root_metadata_key,
allow_empty=allow_empty,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
if not metadata:
return ""
return to_html(metadata, **to_html_kwargs)
@classmethod
def get_html_content(cls, d: dict, to_markdown_kwargs: tp.KwargsLike = None, **kwargs) -> str:
"""Get content in HTML format."""
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
content = cls.get_markdown_content(d, **to_markdown_kwargs)
return to_html(content, **kwargs)
@classmethod
def call(
cls,
d: tp.Any,
root_metadata_key: tp.Optional[tp.Key] = None,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
format_html_kwargs: tp.KwargsLike = None,
**to_html_kwargs,
) -> tp.Any:
if not isinstance(d, (str, dict, list)):
raise TypeError("Data item must be a string, dict, or list of such")
if isinstance(d, str):
d = dict(content=d)
if isinstance(d, list):
html_metadata = []
for _d in d:
if not isinstance(_d, (str, dict)):
raise TypeError("Data item must be a string, dict, or list of such")
if isinstance(_d, str):
_d = dict(content=_d)
html_metadata.append(
cls.get_html_metadata(
_d,
root_metadata_key=root_metadata_key,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
**to_html_kwargs,
)
)
html = format_html(
title="/",
html_metadata="\n".join(html_metadata),
**format_html_kwargs,
)
else:
html_metadata = cls.get_html_metadata(
d,
root_metadata_key=root_metadata_key,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
**to_html_kwargs,
)
html_content = cls.get_html_content(
d,
to_markdown_kwargs=to_markdown_kwargs,
**to_html_kwargs,
)
html = format_html(
title=d["link"] if "link" in d else "",
html_metadata=html_metadata,
html_content=html_content,
**format_html_kwargs,
)
return html
class AggMessageAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_messages`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "agg_message"
_wrap: tp.ClassVar[tp.Optional[bool]] = True
@classmethod
def prepare(
cls,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc
if asset_cls is None:
from vectorbtpro.utils.knowledge.custom_assets import MessagesAsset
asset_cls = MessagesAsset
minimize_metadata = asset_cls.resolve_setting(minimize_metadata, "minimize_metadata")
minimize_keys = asset_cls.resolve_setting(minimize_keys, "minimize_keys")
clean_metadata = asset_cls.resolve_setting(clean_metadata, "clean_metadata")
clean_metadata_kwargs = asset_cls.resolve_setting(clean_metadata_kwargs, "clean_metadata_kwargs", merge=True)
dump_metadata_kwargs = asset_cls.resolve_setting(dump_metadata_kwargs, "dump_metadata_kwargs", merge=True)
clean_metadata_kwargs = flat_merge_dicts(dict(target=FindRemoveAssetFunc.is_empty_func), clean_metadata_kwargs)
_, clean_metadata_kwargs = FindRemoveAssetFunc.prepare(**clean_metadata_kwargs)
_, dump_metadata_kwargs = DumpAssetFunc.prepare(**dump_metadata_kwargs)
return (), {
**dict(
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
) -> tp.Any:
if not isinstance(d, dict):
raise TypeError("Data item must be a dict")
if "attachments" not in d:
return dict(d)
if clean_metadata_kwargs is None:
clean_metadata_kwargs = {}
if dump_metadata_kwargs is None:
dump_metadata_kwargs = {}
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
new_d = dict(d)
new_d["content"] = new_d["content"].strip()
attachments = new_d.pop("attachments", [])
for attachment in attachments:
content = attachment["content"].strip()
if new_d["content"]:
new_d["content"] += "\n\n"
metadata = ToMarkdownAssetFunc.get_markdown_metadata(
attachment,
root_metadata_key="attachment",
allow_empty=not content,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
new_d["content"] += metadata
if content:
new_d["content"] += "\n\n" + content
return new_d
class AggBlockAssetFunc(AssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_blocks`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "agg_block"
_wrap: tp.ClassVar[tp.Optional[bool]] = True
@classmethod
def prepare(
cls,
aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
parent_links_only: tp.Optional[bool] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
link_map: tp.Optional[tp.Dict[str, dict]] = None,
asset_cls: tp.Optional[tp.Type[tp.KnowledgeAsset]] = None,
**kwargs,
) -> tp.ArgsKwargs:
from vectorbtpro.utils.knowledge.base_asset_funcs import FindRemoveAssetFunc, DumpAssetFunc
if asset_cls is None:
from vectorbtpro.utils.knowledge.custom_assets import MessagesAsset
asset_cls = MessagesAsset
aggregate_fields = asset_cls.resolve_setting(aggregate_fields, "aggregate_fields")
parent_links_only = asset_cls.resolve_setting(parent_links_only, "parent_links_only")
minimize_metadata = asset_cls.resolve_setting(minimize_metadata, "minimize_metadata")
minimize_keys = asset_cls.resolve_setting(minimize_keys, "minimize_keys")
clean_metadata = asset_cls.resolve_setting(clean_metadata, "clean_metadata")
clean_metadata_kwargs = asset_cls.resolve_setting(clean_metadata_kwargs, "clean_metadata_kwargs", merge=True)
dump_metadata_kwargs = asset_cls.resolve_setting(dump_metadata_kwargs, "dump_metadata_kwargs", merge=True)
clean_metadata_kwargs = flat_merge_dicts(dict(target=FindRemoveAssetFunc.is_empty_func), clean_metadata_kwargs)
_, clean_metadata_kwargs = FindRemoveAssetFunc.prepare(**clean_metadata_kwargs)
_, dump_metadata_kwargs = DumpAssetFunc.prepare(**dump_metadata_kwargs)
return (), {
**dict(
aggregate_fields=aggregate_fields,
parent_links_only=parent_links_only,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
link_map=link_map,
),
**kwargs,
}
@classmethod
def call(
cls,
d: tp.Any,
aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
parent_links_only: bool = True,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
link_map: tp.Optional[tp.Dict[str, dict]] = None,
) -> tp.Any:
if not isinstance(d, dict):
raise TypeError("Data item must be a dict")
if isinstance(aggregate_fields, bool):
if aggregate_fields:
aggregate_fields = {"mentions", "attachments", "reactions"}
else:
aggregate_fields = set()
elif isinstance(aggregate_fields, str):
aggregate_fields = {aggregate_fields}
elif not isinstance(aggregate_fields, set):
aggregate_fields = set(aggregate_fields)
if clean_metadata_kwargs is None:
clean_metadata_kwargs = {}
if dump_metadata_kwargs is None:
dump_metadata_kwargs = {}
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
new_d = {}
metadata_keys = []
for k, v in d.items():
if k == "link":
new_d[k] = d["block"][0]
if k == "block":
continue
if k == "timestamp":
new_d[k] = v[0]
if k in {"thread", "channel", "author"}:
new_d[k] = v[0]
continue
if k == "reference" and link_map is not None:
found_missing = False
new_v = []
for _v in v:
if _v:
if _v in link_map:
_v = link_map[_v]["block"]
else:
found_missing = True
break
if _v not in new_v:
new_v.append(_v)
if found_missing or len(new_v) > 1:
new_d[k] = "?"
else:
new_d[k] = new_v[0]
if k == "replies" and link_map is not None:
new_v = []
for _v in v:
for __v in _v:
if __v and __v in link_map:
__v = link_map[__v]["block"]
if __v not in new_v:
new_v.append(__v)
else:
new_v.append("?")
new_d[k] = new_v
if k == "content":
new_d[k] = []
continue
if k in aggregate_fields and isinstance(v[0], list):
new_v = []
for _v in v:
for __v in _v:
if __v not in new_v:
new_v.append(__v)
new_d[k] = new_v
continue
if k == "reactions" and k in aggregate_fields:
new_d[k] = sum(v)
continue
if parent_links_only:
if k in ("link", "block", "thread", "reference", "replies"):
continue
metadata_keys.append(k)
if len(metadata_keys) > 0:
for i in range(len(d[metadata_keys[0]])):
content = d["content"][i].strip()
metadata = {}
for k in metadata_keys:
metadata[k] = d[k][i]
if len(new_d["content"]) > 0:
new_d["content"].append("\n\n")
metadata = ToMarkdownAssetFunc.get_markdown_metadata(
metadata,
root_metadata_key="message",
allow_empty=not content,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
new_d["content"].append(metadata)
if content:
new_d["content"].append("\n\n" + content)
new_d["content"] = "".join(new_d["content"])
return new_d
class AggThreadAssetFunc(AggBlockAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_threads`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "agg_thread"
@classmethod
def call(
cls,
d: tp.Any,
aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
parent_links_only: bool = True,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
link_map: tp.Optional[tp.Dict[str, dict]] = None,
) -> tp.Any:
if not isinstance(d, dict):
raise TypeError("Data item must be a dict")
if isinstance(aggregate_fields, bool):
if aggregate_fields:
aggregate_fields = {"mentions", "attachments", "reactions"}
else:
aggregate_fields = set()
elif isinstance(aggregate_fields, str):
aggregate_fields = {aggregate_fields}
elif not isinstance(aggregate_fields, set):
aggregate_fields = set(aggregate_fields)
if clean_metadata_kwargs is None:
clean_metadata_kwargs = {}
if dump_metadata_kwargs is None:
dump_metadata_kwargs = {}
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
new_d = {}
metadata_keys = []
for k, v in d.items():
if k == "link":
new_d[k] = d["thread"][0]
if k == "thread":
continue
if k == "timestamp":
new_d[k] = v[0]
if k == "channel":
new_d[k] = v[0]
continue
if k == "content":
new_d[k] = []
continue
if k in aggregate_fields and isinstance(v[0], list):
new_v = []
for _v in v:
for __v in _v:
if __v not in new_v:
new_v.append(__v)
new_d[k] = new_v
continue
if k == "reactions" and k in aggregate_fields:
new_d[k] = sum(v)
continue
if parent_links_only:
if k in ("link", "block", "thread", "reference", "replies"):
continue
metadata_keys.append(k)
if len(metadata_keys) > 0:
for i in range(len(d[metadata_keys[0]])):
content = d["content"][i].strip()
metadata = {}
for k in metadata_keys:
metadata[k] = d[k][i]
if len(new_d["content"]) > 0:
new_d["content"].append("\n\n")
metadata = ToMarkdownAssetFunc.get_markdown_metadata(
metadata,
root_metadata_key="message",
allow_empty=not content,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
new_d["content"].append(metadata)
if content:
new_d["content"].append("\n\n" + content)
new_d["content"] = "".join(new_d["content"])
return new_d
class AggChannelAssetFunc(AggThreadAssetFunc):
"""Asset function class for `vectorbtpro.utils.knowledge.custom_assets.MessagesAsset.aggregate_channels`."""
_short_name: tp.ClassVar[tp.Optional[str]] = "agg_channel"
@classmethod
def get_channel_link(cls, link: str) -> str:
"""Get channel link from a message link."""
if link.startswith("$discord/"):
link = link[len("$discord/") :]
link_parts = link.split("/")
channel_id = link_parts[0]
return "$discord/" + channel_id
if link.startswith("https://discord.com/channels/"):
link = link[len("https://discord.com/channels/") :]
link_parts = link.split("/")
guild_id = link_parts[0]
channel_id = link_parts[1]
return f"https://discord.com/channels/{guild_id}/{channel_id}"
raise ValueError(f"Invalid link: '{link}'")
@classmethod
def call(
cls,
d: tp.Any,
aggregate_fields: tp.Union[bool, tp.MaybeIterable[str]] = False,
parent_links_only: bool = True,
minimize_metadata: bool = False,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: bool = True,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
link_map: tp.Optional[tp.Dict[str, dict]] = None,
) -> tp.Any:
if not isinstance(d, dict):
raise TypeError("Data item must be a dict")
if isinstance(aggregate_fields, bool):
if aggregate_fields:
aggregate_fields = {"mentions", "attachments", "reactions"}
else:
aggregate_fields = set()
elif isinstance(aggregate_fields, str):
aggregate_fields = {aggregate_fields}
elif not isinstance(aggregate_fields, set):
aggregate_fields = set(aggregate_fields)
if clean_metadata_kwargs is None:
clean_metadata_kwargs = {}
if dump_metadata_kwargs is None:
dump_metadata_kwargs = {}
if to_markdown_kwargs is None:
to_markdown_kwargs = {}
new_d = {}
metadata_keys = []
for k, v in d.items():
if k == "link":
new_d[k] = cls.get_channel_link(v[0])
if k == "timestamp":
new_d[k] = v[0]
if k == "channel":
new_d[k] = v[0]
continue
if k == "content":
new_d[k] = []
continue
if k in aggregate_fields and isinstance(v[0], list):
new_v = []
for _v in v:
for __v in _v:
if __v not in new_v:
new_v.append(__v)
new_d[k] = new_v
continue
if k == "reactions" and k in aggregate_fields:
new_d[k] = sum(v)
continue
if parent_links_only:
if k in ("link", "block", "thread", "reference", "replies"):
continue
metadata_keys.append(k)
if len(metadata_keys) > 0:
for i in range(len(d[metadata_keys[0]])):
content = d["content"][i].strip()
metadata = {}
for k in metadata_keys:
metadata[k] = d[k][i]
if len(new_d["content"]) > 0:
new_d["content"].append("\n\n")
metadata = ToMarkdownAssetFunc.get_markdown_metadata(
metadata,
root_metadata_key="message",
allow_empty=not content,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**to_markdown_kwargs,
)
new_d["content"].append(metadata)
if content:
new_d["content"].append("\n\n" + content)
new_d["content"] = "".join(new_d["content"])
return new_d
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Custom asset classes."""
import inspect
import io
import os
import pkgutil
import re
import base64
from collections import defaultdict, deque
from pathlib import Path
from types import ModuleType
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import merge_dicts, flat_merge_dicts, reorder_list, HybridConfig, SpecSettingsPath
from vectorbtpro.utils.decorators import hybrid_method
from vectorbtpro.utils.knowledge.base_assets import AssetCacheManager, KnowledgeAsset
from vectorbtpro.utils.knowledge.formatting import FormatHTML
from vectorbtpro.utils.module_ import prepare_refname, get_caller_qualname
from vectorbtpro.utils.parsing import get_func_arg_names
from vectorbtpro.utils.path_ import check_mkdir, remove_dir, get_common_prefix, dir_tree_from_paths
from vectorbtpro.utils.pbar import ProgressBar
from vectorbtpro.utils.pickling import suggest_compression
from vectorbtpro.utils.search_ import find, replace
from vectorbtpro.utils.template import CustomTemplate
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"VBTAsset",
"PagesAsset",
"MessagesAsset",
"find_api",
"find_docs",
"find_messages",
"find_examples",
"find_assets",
"chat_about",
"search",
"chat",
]
__pdoc__ = {}
class_abbr_config = HybridConfig(
dict(
Accessor={"acc"},
Array={"arr"},
ArrayWrapper={"wrapper"},
Benchmark={"bm"},
Cacheable={"ca"},
Chunkable={"ch"},
Drawdowns={"dd"},
Jitable={"jit"},
Figure={"fig"},
MappedArray={"ma"},
NumPy={"np"},
Numba={"nb"},
Optimizer={"opt"},
Pandas={"pd"},
Portfolio={"pf"},
ProgressBar={"pbar"},
Registry={"reg"},
Returns_={"ret"},
Returns={"rets"},
QuantStats={"qs"},
Signals_={"sig"},
)
)
"""_"""
__pdoc__[
"class_abbr_config"
] = f"""Config for class name (part) abbreviations.
```python
{class_abbr_config.prettify()}
```
"""
class NoItemFoundError(Exception):
"""Exception raised when no data item was found."""
class MultipleItemsFoundError(Exception):
"""Exception raised when multiple data items were found."""
VBTAssetT = tp.TypeVar("VBTAssetT", bound="VBTAsset")
class VBTAsset(KnowledgeAsset):
"""Class for working with VBT content.
For defaults, see `assets.vbt` in `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = "knowledge.assets.vbt"
def __init__(self, *args, release_name: tp.Optional[str] = None, **kwargs) -> None:
KnowledgeAsset.__init__(self, *args, release_name=release_name, **kwargs)
self._release_name = release_name
@property
def release_name(self) -> tp.Optional[str]:
"""Release name."""
return self._release_name
@classmethod
def pull(
cls: tp.Type[VBTAssetT],
release_name: tp.Optional[str] = None,
asset_name: tp.Optional[str] = None,
repo_owner: tp.Optional[str] = None,
repo_name: tp.Optional[str] = None,
token: tp.Optional[str] = None,
token_required: tp.Optional[bool] = None,
use_pygithub: tp.Optional[bool] = None,
chunk_size: tp.Optional[int] = None,
cache: tp.Optional[bool] = None,
cache_dir: tp.Optional[tp.PathLike] = None,
cache_mkdir_kwargs: tp.KwargsLike = None,
clear_cache: tp.Optional[bool] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> VBTAssetT:
"""Build `VBTAsset` from a JSON asset of a release.
Examples of a release name include None or 'current' for the current release, 'latest' for the
latest release, and any other tag name such as 'v2024.12.15'.
An example of an asset file name is 'messages.json.zip'. You can find all asset file names
at https://github.com/polakowo/vectorbt.pro/releases/latest
Token must be a valid GitHub token. It doesn't have to be provided if the asset has already been downloaded.
If `use_pygithub` is True, uses https://github.com/PyGithub/PyGithub; otherwise, uses `requests`.
Argument `chunk_size` denotes the number of bytes in each chunk when downloading an asset file.
If `cache` is True, uses the cache directory (`assets_dir` in settings). Otherwise, builds the asset
instance in memory. If `clear_cache` is True, deletes any existing directory before creating a new one."""
import requests
release_name = cls.resolve_setting(release_name, "release_name")
asset_name = cls.resolve_setting(asset_name, "asset_name")
repo_owner = cls.resolve_setting(repo_owner, "repo_owner")
repo_name = cls.resolve_setting(repo_name, "repo_name")
token = cls.resolve_setting(token, "token")
token_required = cls.resolve_setting(token_required, "token_required")
use_pygithub = cls.resolve_setting(use_pygithub, "use_pygithub")
chunk_size = cls.resolve_setting(chunk_size, "chunk_size")
cache = cls.resolve_setting(cache, "cache")
assets_dir = cls.resolve_setting(cache_dir, "assets_dir")
cache_mkdir_kwargs = cls.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True)
clear_cache = cls.resolve_setting(clear_cache, "clear_cache")
show_progress = cls.resolve_setting(show_progress, "show_progress")
pbar_kwargs = cls.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
template_context = cls.resolve_setting(template_context, "template_context", merge=True)
if release_name is None or release_name.lower() == "current":
from vectorbtpro._version import __release__
release_name = __release__
if release_name.lower() == "latest":
if token is None:
token = os.environ.get("GITHUB_TOKEN", None)
if token is None and token_required:
raise ValueError("GitHub token is required")
if use_pygithub is None:
from vectorbtpro.utils.module_ import check_installed
use_pygithub = check_installed("github")
if use_pygithub:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("github")
from github import Github, Auth
from github.GithubException import UnknownObjectException
if token is not None:
g = Github(auth=Auth.Token(token))
else:
g = Github()
try:
repo = g.get_repo(f"{repo_owner}/{repo_name}")
except UnknownObjectException:
raise Exception(f"Repository '{repo_owner}/{repo_name}' not found or access denied")
try:
release = repo.get_latest_release()
except UnknownObjectException:
raise Exception("Latest release not found")
release_name = release.title
else:
headers = {"Accept": "application/vnd.github+json"}
if token is not None:
headers["Authorization"] = f"token {token}"
release_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
response = requests.get(release_url, headers=headers)
response.raise_for_status()
release_info = response.json()
release_name = release_info.get("name")
template_context = flat_merge_dicts(dict(release_name=release_name), template_context)
if isinstance(assets_dir, CustomTemplate):
cache_dir = cls.get_setting("cache_dir")
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = cls.get_setting("release_dir")
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
assets_dir = assets_dir.substitute(template_context, eval_id="assets_dir")
if cache:
if assets_dir.exists():
if clear_cache:
remove_dir(assets_dir, missing_ok=True, with_contents=True)
else:
cache_file = None
for file in assets_dir.iterdir():
if file.is_file() and file.name == asset_name:
cache_file = file
break
if cache_file is not None:
return cls.from_json_file(cache_file, release_name=release_name, **kwargs)
if token is None:
token = os.environ.get("GITHUB_TOKEN", None)
if token is None and token_required:
raise ValueError("GitHub token is required")
if use_pygithub is None:
from vectorbtpro.utils.module_ import check_installed
use_pygithub = check_installed("github")
if use_pygithub:
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("github")
from github import Github, Auth
from github.GithubException import UnknownObjectException
if token is not None:
g = Github(auth=Auth.Token(token))
else:
g = Github()
try:
repo = g.get_repo(f"{repo_owner}/{repo_name}")
except UnknownObjectException:
raise Exception(f"Repository '{repo_owner}/{repo_name}' not found or access denied")
releases = repo.get_releases()
found_release = None
for release in releases:
if release.title == release_name:
found_release = release
if found_release is None:
raise Exception(f"Release '{release_name}' not found")
release = found_release
assets = release.get_assets()
if asset_name is not None:
asset = next((a for a in assets if a.name == asset_name), None)
if asset is None:
raise Exception(f"Asset '{asset_name}' not found in release {release_name}")
else:
assets_list = list(assets)
if len(assets_list) == 1:
asset = assets_list[0]
else:
raise Exception("Please specify asset_name")
asset_url = asset.url
else:
headers = {"Accept": "application/vnd.github+json"}
if token is not None:
headers["Authorization"] = f"token {token}"
releases_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases"
response = requests.get(releases_url, headers=headers)
response.raise_for_status()
releases = response.json()
release_info = None
for release in releases:
if release.get("name") == release_name:
release_info = release
if release_info is None:
raise ValueError(f"Release '{release_name}' not found")
assets = release_info.get("assets", [])
if asset_name is not None:
asset = next((a for a in assets if a["name"] == asset_name), None)
if asset is None:
raise Exception(f"Asset '{asset_name}' not found in release {release_name}")
else:
if len(assets) == 1:
asset = assets[0]
else:
raise Exception("Please specify asset_name")
asset_url = asset["url"]
asset_headers = {"Accept": "application/octet-stream"}
if token is not None:
asset_headers["Authorization"] = f"token {token}"
asset_response = requests.get(asset_url, headers=asset_headers, stream=True)
asset_response.raise_for_status()
file_size = int(asset_response.headers.get("Content-Length", 0))
if file_size == 0:
file_size = asset.get("size", 0)
if show_progress is None:
show_progress = True
pbar_kwargs = flat_merge_dicts(
dict(
bar_id=get_caller_qualname(),
unit="iB",
unit_scale=True,
prefix=f"Downloading {asset_name}",
),
pbar_kwargs,
)
if cache:
check_mkdir(assets_dir, **cache_mkdir_kwargs)
cache_file = assets_dir / asset_name
with open(cache_file, "wb") as f:
with ProgressBar(total=file_size, show_progress=show_progress, **pbar_kwargs) as pbar:
for chunk in asset_response.iter_content(chunk_size=chunk_size):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
return cls.from_json_file(cache_file, release_name=release_name, **kwargs)
else:
with io.BytesIO() as bytes_io:
with ProgressBar(total=file_size, show_progress=show_progress, **pbar_kwargs) as pbar:
for chunk in asset_response.iter_content(chunk_size=chunk_size):
if chunk:
bytes_io.write(chunk)
pbar.update(len(chunk))
bytes_ = bytes_io.getvalue()
compression = suggest_compression(asset_name)
if compression is not None and "compression" not in kwargs:
kwargs["compression"] = compression
return cls.from_json_bytes(bytes_, release_name=release_name, **kwargs)
def find_link(
self: VBTAssetT,
link: tp.MaybeList[str],
mode: str = "end",
per_path: bool = False,
single_item: bool = True,
consolidate: bool = True,
allow_empty: bool = False,
**kwargs,
) -> tp.MaybeVBTAsset:
"""Find item(s) corresponding to link(s)."""
def _extend_link(link):
from urllib.parse import urlparse
if not urlparse(link).fragment:
if link.endswith("/"):
return [link, link[:-1]]
return [link, link + "/"]
return [link]
links = link
if mode.lower() in ("exact", "end"):
if isinstance(link, str):
links = _extend_link(link)
elif isinstance(link, list):
from itertools import chain
links = list(chain(*map(_extend_link, link)))
else:
raise TypeError("Link must be either string or list")
found = self.find(links, path="link", mode=mode, per_path=per_path, single_item=single_item, **kwargs)
if isinstance(found, (type(self), list)):
if len(found) == 0:
if allow_empty:
return found
raise NoItemFoundError(f"No item matching '{link}'")
if single_item and len(found) > 1:
if consolidate:
top_parents = self.get_top_parent_links(list(found))
if len(top_parents) == 1:
for i, d in enumerate(found):
if d["link"] == top_parents[0]:
if isinstance(found, type(self)):
return found.replace(data=[d], single_item=True)
return d
links_block = "\n".join([d["link"] for d in found])
raise MultipleItemsFoundError(f"Multiple items matching '{link}':\n\n{links_block}")
return found
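# Illustrative usage sketch (not part of the original source): assumes `asset` is a populated
# VBTAsset instance; the link is a placeholder and matches with or without a trailing slash.
# >>> page = asset.find_link("https://vectorbt.pro/features/data/")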
@classmethod
def minimize_link(cls, link: str, rules: tp.Optional[tp.Dict[str, str]] = None) -> str:
"""Minimize a single link."""
rules = cls.resolve_setting(rules, "minimize_link_rules", merge=True)
for k, v in rules.items():
link = replace(k, v, link, mode="regex")
return link
def minimize_links(self, rules: tp.Optional[tp.Dict[str, str]] = None) -> tp.MaybeVBTAsset:
"""Minimize links."""
rules = self.resolve_setting(rules, "minimize_link_rules", merge=True)
return self.find_replace(rules, mode="regex")
def minimize(
self,
keys: tp.Optional[tp.List[str]] = None,
links: tp.Optional[bool] = None,
) -> tp.MaybeVBTAsset:
"""Minimize by keeping the most useful information.
If `minimize_links` is True, replaces redundant URL prefixes by templates that can
be easily substituted later."""
keys = self.resolve_setting(keys, "minimize_keys")
links = self.resolve_setting(links, "minimize_links")
new_instance = self.find_remove_empty()
if keys:
new_instance = new_instance.remove(keys, skip_missing=True)
if links:
return new_instance.minimize_links()
return new_instance
def select_previous(self: VBTAssetT, link: str, **kwargs) -> VBTAssetT:
"""Select the previous data item."""
d = self.find_link(link, wrap=False, **kwargs)
d_index = self.index(d)
new_data = []
if d_index > 0:
new_data.append(self.data[d_index - 1])
return self.replace(data=new_data, single_item=True)
def select_next(self: VBTAssetT, link: str, **kwargs) -> VBTAssetT:
"""Select the next data item."""
d = self.find_link(link, wrap=False, **kwargs)
d_index = self.index(d)
new_data = []
if d_index < len(self.data) - 1:
new_data.append(self.data[d_index + 1])
return self.replace(data=new_data, single_item=True)
def to_markdown(
self,
root_metadata_key: tp.Optional[tp.Key] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeVBTAsset:
"""Convert to Markdown.
Uses `VBTAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.ToMarkdownAssetFunc`.
Use `root_metadata_key` to provide the root key for the metadata markdown.
If `clean_metadata` is True, removes empty fields from the metadata. Arguments in
`clean_metadata_kwargs` are passed to `vectorbtpro.utils.knowledge.base_asset_funcs.FindRemoveAssetFunc`,
while `dump_metadata_kwargs` are passed to `vectorbtpro.utils.knowledge.base_asset_funcs.DumpAssetFunc`.
Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_markdown`."""
return self.apply(
"to_markdown",
root_metadata_key=root_metadata_key,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
**kwargs,
)
@classmethod
def links_to_paths(
cls,
urls: tp.Iterable[str],
extension: tp.Optional[str] = None,
allow_fragments: bool = True,
) -> tp.List[Path]:
"""Convert links to corresponding paths."""
from urllib.parse import urlparse
url_paths = []
for url in urls:
parsed = urlparse(url, allow_fragments=allow_fragments)
path_parts = [parsed.netloc]
url_path = parsed.path.strip("/")
if url_path:
parts = url_path.split("/")
if parsed.fragment:
path_parts.extend(parts)
if extension is not None:
file_name = parsed.fragment + "." + extension
else:
file_name = parsed.fragment
path_parts.append(file_name)
else:
if len(parts) > 1:
path_parts.extend(parts[:-1])
last_part = parts[-1]
if extension is not None:
file_name = last_part + "." + extension
else:
file_name = last_part
path_parts.append(file_name)
else:
if parsed.fragment:
if extension is not None:
file_name = parsed.fragment + "." + extension
else:
file_name = parsed.fragment
path_parts.append(file_name)
else:
if extension is not None:
path_parts.append("index." + extension)
else:
path_parts.append("index")
url_paths.append(Path(os.path.join(*path_parts)))
return url_paths
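# Illustrative sketch (not part of the original source; URLs are made up):
# >>> VBTAsset.links_to_paths([
# ...     "https://example.com/features/data/",
# ...     "https://example.com/api/generic/#some.anchor",
# ... ], extension="md")
# [Path('example.com/features/data.md'), Path('example.com/api/generic/some.anchor.md')]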
def save_to_markdown(
self,
cache: tp.Optional[bool] = None,
cache_dir: tp.Optional[tp.PathLike] = None,
cache_mkdir_kwargs: tp.KwargsLike = None,
clear_cache: tp.Optional[bool] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> Path:
"""Save to Markdown files.
If `cache` is True, uses the cache directory (`markdown_dir` in settings). Otherwise,
creates a temporary directory. If `clear_cache` is True, deletes any existing directory
before creating a new one. Returns the path of the directory where Markdown files are stored.
Keyword arguments are passed to `vectorbtpro.utils.knowledge.custom_asset_funcs.ToMarkdownAssetFunc`.
Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_markdown`."""
import tempfile
from vectorbtpro.utils.knowledge.custom_asset_funcs import ToMarkdownAssetFunc
cache = self.resolve_setting(cache, "cache")
markdown_dir = self.resolve_setting(cache_dir, "markdown_dir")
cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True)
clear_cache = self.resolve_setting(clear_cache, "clear_cache")
show_progress = self.resolve_setting(show_progress, "show_progress")
pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
if cache:
if self.release_name:
template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context)
if isinstance(markdown_dir, CustomTemplate):
cache_dir = self.get_setting("cache_dir")
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = self.get_setting("release_dir")
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
markdown_dir = markdown_dir.substitute(template_context, eval_id="markdown_dir")
if markdown_dir.exists():
if clear_cache:
remove_dir(markdown_dir, missing_ok=True, with_contents=True)
check_mkdir(markdown_dir, **cache_mkdir_kwargs)
else:
markdown_dir = Path(tempfile.mkdtemp(prefix=get_caller_qualname() + "_"))
link_map = {d["link"]: dict(d) for d in self.data}
url_paths = self.links_to_paths(link_map.keys(), extension="md")
url_file_map = dict(zip(link_map.keys(), [markdown_dir / p for p in url_paths]))
_, kwargs = ToMarkdownAssetFunc.prepare(**kwargs)
if show_progress is None:
show_progress = not self.single_item
prefix = get_caller_qualname().split(".")[-1]
pbar_kwargs = flat_merge_dicts(
dict(
bar_id=get_caller_qualname(),
prefix=prefix,
),
pbar_kwargs,
)
with ProgressBar(total=len(self.data), show_progress=show_progress, **pbar_kwargs) as pbar:
for d in self.data:
if not url_file_map[d["link"]].exists():
markdown_content = ToMarkdownAssetFunc.call(d, **kwargs)
check_mkdir(url_file_map[d["link"]].parent, mkdir=True)
with open(url_file_map[d["link"]], "w", encoding="utf-8") as f:
f.write(markdown_content)
pbar.update()
return markdown_dir
def to_html(
self: VBTAssetT,
root_metadata_key: tp.Optional[tp.Key] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
format_html_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeVBTAsset:
"""Convert to HTML.
Uses `VBTAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.ToHTMLAssetFunc`.
Arguments in `format_html_kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.format_html`.
Last keyword arguments in `kwargs` are passed to `vectorbtpro.utils.knowledge.formatting.to_html`.
For other arguments, see `VBTAsset.to_markdown`."""
return self.apply(
"to_html",
root_metadata_key=root_metadata_key,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
format_html_kwargs=format_html_kwargs,
**kwargs,
)
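    # Usage sketch (illustrative; assumes `asset` is a VBTAsset instance):
    #     >>> html_asset = asset.to_html()  # one HTML string per data item, wrapped as an asset
    #     >>> html = asset.to_html(wrap=False, single_item=True)  # plain string(s) instead of an asset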
@classmethod
def get_top_parent_links(cls, data: tp.List) -> tp.List[str]:
"""Get links of top parents in data."""
link_map = {d["link"]: dict(d) for d in data}
top_parents = []
for d in data:
if d.get("parent", None) is None or d["parent"] not in link_map:
top_parents.append(d["link"])
return top_parents
@property
def top_parent_links(self) -> tp.List[str]:
"""Get links of top parents."""
return self.get_top_parent_links(self.data)
@classmethod
def replace_urls_in_html(cls, html: str, url_map: dict) -> str:
"""Replace URLs in attributes based on a provided mapping."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("bs4")
from bs4 import BeautifulSoup
from urllib.parse import urlparse, urlunparse
soup = BeautifulSoup(html, "html.parser")
for a_tag in soup.find_all("a", href=True):
original_href = a_tag["href"]
if original_href in url_map:
a_tag["href"] = url_map[original_href]
else:
try:
parsed_href = urlparse(original_href)
base_url = urlunparse(parsed_href._replace(fragment=""))
if base_url in url_map:
new_base_url = url_map[base_url]
new_parsed = urlparse(new_base_url)
new_parsed = new_parsed._replace(fragment=parsed_href.fragment)
new_href = urlunparse(new_parsed)
a_tag["href"] = new_href
except ValueError:
pass
return str(soup)
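    # For instance (illustrative; the URLs below are made up):
    #     >>> html = '<a href="https://example.com/page#part">link</a>'
    #     >>> VBTAsset.replace_urls_in_html(html, {"https://example.com/page": "file:///tmp/page.html"})
    #     '<a href="file:///tmp/page.html#part">link</a>'  # fragment is preserved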
def save_to_html(
self,
cache: tp.Optional[bool] = None,
cache_dir: tp.Optional[tp.PathLike] = None,
cache_mkdir_kwargs: tp.KwargsLike = None,
clear_cache: tp.Optional[bool] = None,
show_progress: tp.Optional[bool] = None,
pbar_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
return_url_map: bool = False,
**kwargs,
) -> tp.Union[Path, tp.Tuple[Path, dict]]:
"""Save to HTML files.
        Returns the path of the directory where HTML files are stored,
        and if `return_url_map` is True, also returns the link->file map.
In addition, if there are multiple top-level parents, creates an index page.
If `cache` is True, uses the cache directory (`html_dir` in settings). Otherwise, creates
a temporary directory. If `clear_cache` is True, deletes any existing directory before creating a new one.
Keyword arguments are passed to `vectorbtpro.utils.knowledge.custom_asset_funcs.ToHTMLAssetFunc`."""
import tempfile
from vectorbtpro.utils.knowledge.custom_asset_funcs import ToHTMLAssetFunc
cache = self.resolve_setting(cache, "cache")
html_dir = self.resolve_setting(cache_dir, "html_dir")
cache_mkdir_kwargs = self.resolve_setting(cache_mkdir_kwargs, "cache_mkdir_kwargs", merge=True)
clear_cache = self.resolve_setting(clear_cache, "clear_cache")
show_progress = self.resolve_setting(show_progress, "show_progress")
pbar_kwargs = self.resolve_setting(pbar_kwargs, "pbar_kwargs", merge=True)
if cache:
if self.release_name:
template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context)
if isinstance(html_dir, CustomTemplate):
cache_dir = self.get_setting("cache_dir")
if isinstance(cache_dir, CustomTemplate):
cache_dir = cache_dir.substitute(template_context, eval_id="cache_dir")
template_context = flat_merge_dicts(dict(cache_dir=cache_dir), template_context)
release_dir = self.get_setting("release_dir")
if isinstance(release_dir, CustomTemplate):
release_dir = release_dir.substitute(template_context, eval_id="release_dir")
template_context = flat_merge_dicts(dict(release_dir=release_dir), template_context)
html_dir = html_dir.substitute(template_context, eval_id="html_dir")
if html_dir.exists():
if clear_cache:
remove_dir(html_dir, missing_ok=True, with_contents=True)
check_mkdir(html_dir, **cache_mkdir_kwargs)
else:
html_dir = Path(tempfile.mkdtemp(prefix=get_caller_qualname() + "_"))
link_map = {d["link"]: dict(d) for d in self.data}
top_parents = self.top_parent_links
if len(top_parents) > 1:
link_map["/"] = {}
url_paths = self.links_to_paths(link_map.keys(), extension="html")
url_file_map = dict(zip(link_map.keys(), [html_dir / p for p in url_paths]))
url_map = {k: "file://" + str(v.resolve()) for k, v in url_file_map.items()}
_, kwargs = ToHTMLAssetFunc.prepare(**kwargs)
if len(top_parents) > 1:
entry_link = "/"
if not url_file_map[entry_link].exists():
html = ToHTMLAssetFunc.call([link_map[link] for link in top_parents], **kwargs)
html = self.replace_urls_in_html(html, url_map)
check_mkdir(url_file_map[entry_link].parent, mkdir=True)
with open(url_file_map[entry_link], "w", encoding="utf-8") as f:
f.write(html)
if show_progress is None:
show_progress = not self.single_item
prefix = get_caller_qualname().split(".")[-1]
pbar_kwargs = flat_merge_dicts(
dict(
bar_id=get_caller_qualname(),
prefix=prefix,
),
pbar_kwargs,
)
with ProgressBar(total=len(self.data), show_progress=show_progress, **pbar_kwargs) as pbar:
for d in self.data:
if not url_file_map[d["link"]].exists():
html = ToHTMLAssetFunc.call(d, **kwargs)
html = self.replace_urls_in_html(html, url_map)
check_mkdir(url_file_map[d["link"]].parent, mkdir=True)
with open(url_file_map[d["link"]], "w", encoding="utf-8") as f:
f.write(html)
pbar.update()
if return_url_map:
return html_dir, url_map
return html_dir
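    # Usage sketch (illustrative; assumes `asset` is a VBTAsset instance):
    #     >>> html_dir, url_map = asset.save_to_html(return_url_map=True)
    #     >>> url_map[asset.data[0]["link"]]  # "file://..." location of the first page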
def browse(
self,
entry_link: tp.Optional[str] = None,
find_kwargs: tp.KwargsLike = None,
open_browser: tp.Optional[bool] = None,
**kwargs,
) -> Path:
"""Browse one or more HTML pages.
Opens the web browser. Also, returns the path of the directory where HTML files are stored.
Use `entry_link` to specify the link of the page that should be displayed first.
If `entry_link` is None and there are multiple top-level parents, displays them as an index.
If it's not None, it will be matched using `VBTAsset.find_link` and `find_kwargs`.
        Keyword arguments are passed to `VBTAsset.save_to_html`."""
open_browser = self.resolve_setting(open_browser, "open_browser")
if entry_link is None:
if len(self.data) == 1:
entry_link = self.data[0]["link"]
else:
top_parents = self.top_parent_links
if len(top_parents) == 1:
entry_link = top_parents[0]
else:
entry_link = "/"
else:
if find_kwargs is None:
find_kwargs = {}
d = self.find_link(entry_link, wrap=False, **find_kwargs)
entry_link = d["link"]
html_dir, url_map = self.save_to_html(return_url_map=True, **kwargs)
if open_browser:
import webbrowser
webbrowser.open(url_map[entry_link])
return html_dir
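    # Usage sketch (illustrative; the entry link is made up):
    #     >>> asset.browse()  # shows an index page if there are multiple top-level parents
    #     >>> asset.browse(entry_link="features/data", open_browser=True)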
def display(
self,
link: tp.Optional[str] = None,
find_kwargs: tp.KwargsLike = None,
open_browser: tp.Optional[bool] = None,
html_template: tp.Optional[str] = None,
style_extras: tp.Optional[tp.MaybeList[str]] = None,
head_extras: tp.Optional[tp.MaybeList[str]] = None,
body_extras: tp.Optional[tp.MaybeList[str]] = None,
invert_colors: tp.Optional[bool] = None,
title: str = "",
template_context: tp.KwargsLike = None,
**kwargs,
) -> Path:
"""Display as an HTML page.
If there are multiple HTML pages, shows them as iframes inside a parent HTML page with pagination.
For this, uses `vectorbtpro.utils.knowledge.formatting.FormatHTML`.
Opens the web browser. Also, returns the path of the temporary HTML file.
!!! note
The file __won't__ be deleted automatically.
"""
import tempfile
open_browser = self.resolve_setting(open_browser, "open_browser", sub_path="display")
if link is not None:
if find_kwargs is None:
find_kwargs = {}
instance = self.find_link(link, **find_kwargs)
else:
instance = self
html = instance.to_html(wrap=False, single_item=True, **kwargs)
if len(instance) > 1:
from vectorbtpro.utils.config import ExtSettingsPath
encoded_pages = map(lambda x: base64.b64encode(x.encode("utf-8")).decode("ascii"), html)
pages = "[\n" + ",\n".join(f' "{page}"' for page in encoded_pages) + "\n]"
ext_settings_paths = []
for cls_ in type(self).__mro__[::-1]:
if issubclass(cls_, VBTAsset):
if not isinstance(cls_._settings_path, str):
raise TypeError("_settings_path for VBTAsset and its subclasses must be a string")
ext_settings_paths.append((FormatHTML, cls_._settings_path + ".display"))
with ExtSettingsPath(ext_settings_paths):
html = FormatHTML(
html_template=html_template,
style_extras=style_extras,
head_extras=head_extras,
body_extras=body_extras,
invert_colors=invert_colors,
auto_scroll=False,
).format_html(title=title, pages=pages)
with tempfile.NamedTemporaryFile(
"w",
encoding="utf-8",
prefix=get_caller_qualname() + "_",
suffix=".html",
delete=False,
) as f:
f.write(html)
file_path = Path(f.name)
if open_browser:
import webbrowser
webbrowser.open("file://" + str(file_path.resolve()))
return file_path
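    # Usage sketch (illustrative):
    #     >>> file_path = asset.display()  # multiple items are shown as paginated iframes
    #     >>> file_path.unlink()  # the temporary file is not deleted automatically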
@classmethod
def prepare_mention_target(
cls,
target: str,
as_code: bool = False,
as_regex: bool = True,
allow_prefix: bool = False,
allow_suffix: bool = False,
) -> str:
"""Prepare a mention target."""
if as_regex:
escaped_target = re.escape(target)
new_target = ""
if not allow_prefix and re.match(r"\w", target[0]):
new_target += r"(? tp.List[str]:
"""Split a class name constituent parts."""
return re.findall(r"[A-Z]+(?=[A-Z][a-z]|$)|[A-Z][a-z]+", name)
@classmethod
def get_class_abbrs(cls, name: str) -> tp.List[str]:
"""Convert a class name to snake case and its abbreviated versions."""
from itertools import product
parts = cls.split_class_name(name)
replacement_lists = []
for i, part in enumerate(parts):
replacements = [part.lower()]
if i == 0 and f"{part}_" in class_abbr_config:
replacements.extend(class_abbr_config[f"{part}_"])
if part in class_abbr_config:
replacements.extend(class_abbr_config[part])
replacement_lists.append(replacements)
all_combinations = list(product(*replacement_lists))
snake_case_names = ["_".join(combo) for combo in all_combinations]
return snake_case_names
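    # For instance (illustrative; extra abbreviations depend on `class_abbr_config`):
    #     >>> VBTAsset.get_class_abbrs("PortfolioOptimizer")
    #     ['portfolio_optimizer', ...]  # snake-cased name plus configured abbreviated variants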
@classmethod
def generate_refname_targets(
cls,
refname: str,
resolve: bool = True,
incl_shortcuts: tp.Optional[bool] = None,
incl_shortcut_access: tp.Optional[bool] = None,
incl_shortcut_call: tp.Optional[bool] = None,
incl_instances: tp.Optional[bool] = None,
as_code: tp.Optional[bool] = None,
as_regex: tp.Optional[bool] = None,
allow_prefix: tp.Optional[bool] = None,
allow_suffix: tp.Optional[bool] = None,
) -> tp.List[str]:
"""Generate reference name targets.
If `incl_shortcuts` is True, includes shortcuts found in `import vectorbtpro as vbt`.
In addition, if `incl_shortcut_access` is True and the object is a class or module, includes a version
with attribute access, and if `incl_shortcut_call` is True and the object is callable, includes a version
that is being called.
If `incl_instances` is True, includes typical short names of classes, which
include the snake-cased class name and mapped name parts found in `class_abbr_config`.
Prepares each mention target with `VBTAsset.prepare_mention_target`."""
from vectorbtpro.utils.module_ import annotate_refname_parts
import vectorbtpro as vbt
incl_shortcuts = cls.resolve_setting(incl_shortcuts, "incl_shortcuts")
incl_shortcut_access = cls.resolve_setting(incl_shortcut_access, "incl_shortcut_access")
incl_shortcut_call = cls.resolve_setting(incl_shortcut_call, "incl_shortcut_call")
incl_instances = cls.resolve_setting(incl_instances, "incl_instances")
as_code = cls.resolve_setting(as_code, "as_code")
as_regex = cls.resolve_setting(as_regex, "as_regex")
allow_prefix = cls.resolve_setting(allow_prefix, "allow_prefix")
allow_suffix = cls.resolve_setting(allow_suffix, "allow_suffix")
def _prepare_target(
target,
_as_code=as_code,
_as_regex=as_regex,
_allow_prefix=allow_prefix,
_allow_suffix=allow_suffix,
):
return cls.prepare_mention_target(
target,
as_code=_as_code,
as_regex=_as_regex,
allow_prefix=_allow_prefix,
allow_suffix=_allow_suffix,
)
targets = set()
new_target = _prepare_target(refname)
targets.add(new_target)
refname_parts = refname.split(".")
if resolve:
annotated_parts = annotate_refname_parts(refname)
if len(annotated_parts) >= 2 and isinstance(annotated_parts[-2]["obj"], type):
cls_refname = ".".join(refname_parts[:-1])
cls_aliases = {annotated_parts[-2]["name"]}
attr_aliases = set()
for k, v in vbt.__dict__.items():
v_refname = prepare_refname(v, raise_error=False)
if v_refname is not None:
if v_refname == cls_refname:
cls_aliases.add(k)
elif v_refname == refname:
attr_aliases.add(k)
if incl_shortcuts:
new_target = _prepare_target("vbt." + k)
targets.add(new_target)
if incl_shortcuts:
for cls_alias in cls_aliases:
new_target = _prepare_target(cls_alias + "." + annotated_parts[-1]["name"])
targets.add(new_target)
for attr_alias in attr_aliases:
if incl_shortcut_call and callable(annotated_parts[-1]["obj"]):
new_target = _prepare_target(attr_alias + "(")
targets.add(new_target)
if incl_instances:
for cls_alias in cls_aliases:
for class_abbr in cls.get_class_abbrs(cls_alias):
new_target = _prepare_target(class_abbr + "." + annotated_parts[-1]["name"])
targets.add(new_target)
else:
if len(refname_parts) >= 2:
module_name = ".".join(refname_parts[:-1])
attr_name = refname_parts[-1]
new_target = _prepare_target("from {} import {}".format(module_name, attr_name))
targets.add(new_target)
aliases = {annotated_parts[-1]["name"]}
for k, v in vbt.__dict__.items():
v_refname = prepare_refname(v, raise_error=False)
if v_refname is not None:
if v_refname == refname:
aliases.add(k)
if incl_shortcuts:
new_target = _prepare_target("vbt." + k)
targets.add(new_target)
if incl_shortcuts:
for alias in aliases:
if incl_shortcut_access and isinstance(annotated_parts[-1]["obj"], (type, ModuleType)):
new_target = _prepare_target(alias + ".")
targets.add(new_target)
if incl_shortcut_call and callable(annotated_parts[-1]["obj"]):
new_target = _prepare_target(alias + "(")
targets.add(new_target)
if incl_instances and isinstance(annotated_parts[-1]["obj"], type):
for alias in aliases:
for class_abbr in cls.get_class_abbrs(alias):
new_target = _prepare_target(class_abbr + " =")
targets.add(new_target)
new_target = _prepare_target(class_abbr + ".")
targets.add(new_target)
return sorted(targets)
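    # Usage sketch (illustrative; the exact targets depend on settings and `class_abbr_config`):
    #     >>> targets = VBTAsset.generate_refname_targets("vectorbtpro.portfolio.base.Portfolio")
    #     >>> targets  # regex targets covering e.g. "Portfolio", "vbt.PF", and instance names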
def generate_mention_targets(
self,
obj: tp.MaybeList,
*,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
incl_base_attr: tp.Optional[bool] = None,
incl_shortcuts: tp.Optional[bool] = None,
incl_shortcut_access: tp.Optional[bool] = None,
incl_shortcut_call: tp.Optional[bool] = None,
incl_instances: tp.Optional[bool] = None,
as_code: tp.Optional[bool] = None,
as_regex: tp.Optional[bool] = None,
allow_prefix: tp.Optional[bool] = None,
allow_suffix: tp.Optional[bool] = None,
) -> tp.List[str]:
"""Generate mention targets.
Prepares the object reference with `vectorbtpro.utils.module_.prepare_refname`.
If an attribute is provided, checks whether the attribute is defined by the object itself
or by one of its base classes. If the latter and `incl_base_attr` is True, generates
reference name targets for both the object attribute and the base class attribute.
Generates reference name targets with `VBTAsset.generate_refname_targets`."""
from vectorbtpro.utils.module_ import prepare_refname
incl_base_attr = self.resolve_setting(incl_base_attr, "incl_base_attr")
targets = []
if not isinstance(obj, list):
objs = [obj]
else:
objs = obj
for obj in objs:
obj_refname = prepare_refname(obj, module=module, resolve=resolve)
if attr is not None:
checks.assert_instance_of(attr, str, arg_name="attr")
if isinstance(obj, tuple):
attr_obj = (*obj, attr)
else:
attr_obj = (obj, attr)
base_attr_refname = prepare_refname(attr_obj, module=module, resolve=resolve)
obj_refname += "." + attr
if base_attr_refname == obj_refname:
obj_refname = base_attr_refname
base_attr_refname = None
else:
base_attr_refname = None
targets.extend(
self.generate_refname_targets(
obj_refname,
resolve=resolve,
incl_shortcuts=incl_shortcuts,
incl_shortcut_access=incl_shortcut_access,
incl_shortcut_call=incl_shortcut_call,
incl_instances=incl_instances,
as_code=as_code,
as_regex=as_regex,
allow_prefix=allow_prefix,
allow_suffix=allow_suffix,
)
)
if incl_base_attr and base_attr_refname is not None:
targets.extend(
self.generate_refname_targets(
base_attr_refname,
resolve=resolve,
incl_shortcuts=incl_shortcuts,
incl_shortcut_access=incl_shortcut_access,
incl_shortcut_call=incl_shortcut_call,
incl_instances=incl_instances,
as_code=as_code,
as_regex=as_regex,
allow_prefix=allow_prefix,
allow_suffix=allow_suffix,
)
)
seen = set()
targets = [x for x in targets if not (x in seen or seen.add(x))]
return targets
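    # Usage sketch (illustrative; assumes `asset` is a VBTAsset instance):
    #     >>> asset.generate_mention_targets(vbt.Portfolio, attr="from_signals")
    #     >>> asset.generate_mention_targets([vbt.PF, vbt.PFO], incl_instances=False)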
@classmethod
def merge_mention_targets(cls, targets: tp.List[str], as_regex: bool = True) -> str:
"""Merge mention targets into a single regular expression."""
if as_regex:
prefixed_targets = []
non_prefixed_targets = []
common_prefix = r"(? tp.MaybeVBTAsset:
"""Find mentions of a VBT object.
Generates mention targets with `VBTAsset.generate_mention_targets`.
Provide custom mentions in `incl_custom`. If regular expressions are provided,
set `is_custom_regex` to True.
If `as_code` is True, uses `VBTAsset.find_code`, otherwise, uses `VBTAsset.find`.
If `as_regex` is True, search is refined by using regular expressions. For instance,
`vbt.PF` may match `vbt.PFO` if RegEx is not used.
        If `merge_targets` is True, uses `VBTAsset.merge_mention_targets` to reduce the number of targets
        and then forces `as_regex` to True (after the targets have been generated)."""
as_code = self.resolve_setting(as_code, "as_code")
as_regex = self.resolve_setting(as_regex, "as_regex")
allow_prefix = self.resolve_setting(allow_prefix, "allow_prefix")
allow_suffix = self.resolve_setting(allow_suffix, "allow_suffix")
merge_targets = self.resolve_setting(merge_targets, "merge_targets")
mention_targets = self.generate_mention_targets(
obj,
attr=attr,
module=module,
resolve=resolve,
incl_shortcuts=incl_shortcuts,
incl_shortcut_access=incl_shortcut_access,
incl_shortcut_call=incl_shortcut_call,
incl_instances=incl_instances,
as_code=as_code,
as_regex=as_regex,
allow_prefix=allow_prefix,
allow_suffix=allow_suffix,
)
if incl_custom:
def _prepare_target(
target,
_as_code=as_code,
_as_regex=as_regex,
_allow_prefix=allow_prefix,
_allow_suffix=allow_suffix,
):
return self.prepare_mention_target(
target,
as_code=_as_code,
as_regex=_as_regex,
allow_prefix=_allow_prefix,
allow_suffix=_allow_suffix,
)
if isinstance(incl_custom, str):
incl_custom = [incl_custom]
for custom in incl_custom:
new_target = _prepare_target(custom, _as_regex=is_custom_regex)
if new_target not in mention_targets:
mention_targets.append(new_target)
if merge_targets:
mention_targets = self.merge_mention_targets(mention_targets, as_regex=as_regex)
as_regex = True
if as_code:
mentions_asset = self.find_code(
mention_targets,
escape_target=not as_regex,
path=path,
per_path=per_path,
return_type=return_type,
**kwargs,
)
elif as_regex:
mentions_asset = self.find(
mention_targets,
mode="regex",
path=path,
per_path=per_path,
return_type=return_type,
**kwargs,
)
else:
mentions_asset = self.find(
mention_targets,
path=path,
per_path=per_path,
return_type=return_type,
**kwargs,
)
return mentions_asset
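    # Usage sketch (illustrative):
    #     >>> asset.find_obj_mentions(vbt.PF)  # mentions anywhere in the content
    #     >>> asset.find_obj_mentions(vbt.PF, as_code=True)  # mentions in code only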
@classmethod
def resolve_spec_settings_path(cls) -> dict:
"""Resolve specialized settings paths."""
spec_settings_path = {}
for cls_ in cls.__mro__[::-1]:
if issubclass(cls_, VBTAsset):
if not isinstance(cls_._settings_path, str):
raise TypeError("_settings_path for VBTAsset and its subclasses must be a string")
if "knowledge" not in spec_settings_path:
spec_settings_path["knowledge"] = []
spec_settings_path["knowledge"].append(cls_._settings_path)
if "knowledge.chat" not in spec_settings_path:
spec_settings_path["knowledge.chat"] = []
spec_settings_path["knowledge.chat"].append(cls_._settings_path + ".chat")
return spec_settings_path
def embed(self, *args, template_context: tp.KwargsLike = None, **kwargs) -> tp.Optional[tp.MaybeVBTAsset]:
template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context)
spec_settings_path = self.resolve_spec_settings_path()
if spec_settings_path:
with SpecSettingsPath(spec_settings_path):
return KnowledgeAsset.embed(self, *args, template_context=template_context, **kwargs)
return KnowledgeAsset.embed(self, *args, template_context=template_context, **kwargs)
def rank(self, *args, template_context: tp.KwargsLike = None, **kwargs) -> tp.MaybeVBTAsset:
template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context)
spec_settings_path = self.resolve_spec_settings_path()
if spec_settings_path:
with SpecSettingsPath(spec_settings_path):
return KnowledgeAsset.rank(self, *args, template_context=template_context, **kwargs)
return KnowledgeAsset.rank(self, *args, template_context=template_context, **kwargs)
def create_chat(self, *args, template_context: tp.KwargsLike = None, **kwargs) -> tp.Completions:
template_context = flat_merge_dicts(dict(release_name=self.release_name), template_context)
spec_settings_path = self.resolve_spec_settings_path()
if spec_settings_path:
with SpecSettingsPath(spec_settings_path):
return KnowledgeAsset.create_chat(self, *args, template_context=template_context, **kwargs)
return KnowledgeAsset.create_chat(self, *args, template_context=template_context, **kwargs)
@hybrid_method
def chat(cls_or_self, *args, template_context: tp.KwargsLike = None, **kwargs) -> tp.MaybeChatOutput:
if not isinstance(cls_or_self, type):
template_context = flat_merge_dicts(dict(release_name=cls_or_self.release_name), template_context)
spec_settings_path = cls_or_self.resolve_spec_settings_path()
if spec_settings_path:
with SpecSettingsPath(spec_settings_path):
return KnowledgeAsset.chat.__func__(cls_or_self, *args, template_context=template_context, **kwargs)
return KnowledgeAsset.chat.__func__(cls_or_self, *args, template_context=template_context, **kwargs)
PagesAssetT = tp.TypeVar("PagesAssetT", bound="PagesAsset")
class PagesAsset(VBTAsset):
"""Class for working with website pages.
Has the following fields:
* link: URL of the page (without fragment), such as "https://vectorbt.pro/features/data/", or
URL of the heading (with fragment), such as "https://vectorbt.pro/features/data/#trading-view"
* parent: URL of the parent page or heading. For example, a heading 1 is a parent of a heading 2.
* children: List of URLs of the child pages and/or headings. For example, a heading 2 is a child of a heading 1.
* name: Name of the page or heading. Within the API, the name of the object that the heading represents,
such as "Portfolio.from_signals".
* type: Type of the page or heading, such as "page", "heading 1", "heading 2", etc.
* icon: Icon, such as "material-brain"
* tags: List of tags, such as ["portfolio", "records"]
* content: String content of the page or heading. Can be None in pages that solely redirect.
* obj_type: Within the API, the type of the object that the heading represents, such as "property"
* github_link: Within the API, the URL to the source code of the object that the heading represents
For defaults, see `assets.pages` in `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = "knowledge.assets.pages"
def descend_links(self: PagesAssetT, links: tp.List[str], **kwargs) -> PagesAssetT:
"""Descend links by removing redundant ones.
Only headings are descended."""
redundant_links = set()
new_data = {}
for link in links:
if link in redundant_links:
continue
descendant_headings = self.select_descendant_headings(link, incl_link=True)
for d in descendant_headings:
if d["link"] != link:
redundant_links.add(d["link"])
new_data[d["link"]] = d
for link in links:
if link in redundant_links and link in new_data:
del new_data[link]
return self.replace(data=list(new_data.values()), **kwargs)
def aggregate_links(
self: PagesAssetT,
links: tp.List[str],
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> PagesAssetT:
"""Aggregate links by removing redundant ones.
Only headings are aggregated."""
if aggregate_kwargs is None:
aggregate_kwargs = {}
redundant_links = set()
new_data = {}
for link in links:
if link in redundant_links:
continue
descendant_headings = self.select_descendant_headings(link, incl_link=True)
for d in descendant_headings:
if d["link"] != link:
redundant_links.add(d["link"])
descendant_headings = descendant_headings.aggregate(**aggregate_kwargs)
new_data[link] = descendant_headings[0]
for link in links:
if link in redundant_links and link in new_data:
del new_data[link]
return self.replace(data=list(new_data.values()), **kwargs)
def find_page(
self: PagesAssetT,
link: tp.MaybeList[str],
aggregate: bool = False,
aggregate_kwargs: tp.KwargsLike = None,
incl_descendants: bool = False,
single_item: bool = True,
**kwargs,
) -> tp.MaybePagesAsset:
"""Find the page(s) corresponding to link(s).
Keyword arguments are passed to `VBTAsset.find_link`."""
found = self.find_link(link, single_item=single_item, **kwargs)
if not isinstance(found, (type(self), list)):
return found
if aggregate:
return self.aggregate_links(
[d["link"] for d in found],
aggregate_kwargs=aggregate_kwargs,
single_item=single_item,
)
if incl_descendants:
return self.descend_links(
[d["link"] for d in found],
single_item=single_item,
)
return found
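    # Usage sketch (illustrative; the link below is made up, assumes `pages_asset` is a pulled PagesAsset):
    #     >>> pages_asset.find_page("https://vectorbt.pro/features/data/", incl_descendants=True)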
def find_refname(
self,
refname: tp.MaybeList[str],
**kwargs,
) -> tp.MaybePagesAsset:
"""Find the page corresponding to a reference."""
if isinstance(refname, list):
link = list(map(lambda x: f"#({re.escape(x)})$", refname))
else:
link = f"#({re.escape(refname)})$"
return self.find_page(link, mode="regex", **kwargs)
def find_obj(
self,
obj: tp.Any,
*,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
**kwargs,
) -> tp.MaybePagesAsset:
"""Find the page corresponding a single (internal) object or reference name.
Prepares the reference with `vectorbtpro.utils.module_.prepare_refname`."""
if attr is not None:
checks.assert_instance_of(attr, str, arg_name="attr")
if isinstance(obj, tuple):
obj = (*obj, attr)
else:
obj = (obj, attr)
refname = prepare_refname(obj, module=module, resolve=resolve)
return self.find_refname(refname, **kwargs)
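    # Usage sketch (illustrative):
    #     >>> pages_asset.find_refname("vectorbtpro.portfolio.base.Portfolio")
    #     >>> pages_asset.find_obj(vbt.PF, attr="from_signals", aggregate=True)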
@classmethod
def parse_content_links(cls, content: str) -> tp.List[str]:
"""Parse all links from a content."""
link_pattern = r'(? tp.Optional[str]:
"""Parse the reference name from a link."""
if "/api/" not in link:
return None
if "#" in link:
refname = link.split("#")[1]
if refname.startswith("vectorbtpro"):
return refname
return None
return "vectorbtpro." + ".".join(link.split("/api/")[1].strip("/").split("/"))
@classmethod
def is_link_module(cls, link: str) -> bool:
"""Return whether a link is a module."""
if "/api/" not in link:
return False
if "#" not in link:
return True
refname = link.split("#")[1]
if "/".join(refname.split(".")) in link:
return True
return False
def find_obj_api(
self,
obj: tp.MaybeList,
*,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
use_parent: tp.Optional[bool] = None,
use_base_parents: tp.Optional[bool] = None,
use_ref_parents: tp.Optional[bool] = None,
incl_bases: tp.Union[None, bool, int] = None,
incl_ancestors: tp.Union[None, bool, int] = None,
incl_base_ancestors: tp.Union[None, bool, int] = None,
incl_refs: tp.Union[None, bool, int] = None,
incl_descendants: tp.Optional[bool] = None,
incl_ancestor_descendants: tp.Optional[bool] = None,
incl_ref_descendants: tp.Optional[bool] = None,
aggregate: tp.Optional[bool] = None,
aggregate_ancestors: tp.Optional[bool] = None,
aggregate_refs: tp.Optional[bool] = None,
aggregate_kwargs: tp.KwargsLike = None,
topo_sort: tp.Optional[bool] = None,
return_refname_graph: bool = False,
) -> tp.Union[PagesAssetT, tp.Tuple[PagesAssetT, dict]]:
"""Find API pages and headings relevant to object(s).
Prepares the object reference with `vectorbtpro.utils.module_.prepare_refname`.
If `incl_bases` is True, extends the asset with the base classes/attributes if the object is
a class/attribute. For instance, `vectorbtpro.portfolio.base.Portfolio` has
`vectorbtpro.generic.analyzable.Analyzable` as one of its base classes. It can also be an
integer indicating the maximum inheritance level. If `obj` is a module, then bases are sub-modules.
If `incl_ancestors` is True, extends the asset with the ancestors of the object.
For instance, `vectorbtpro.portfolio.base.Portfolio` has `vectorbtpro.portfolio.base` as its ancestor.
It can also be an integer indicating the maximum inheritance level. Provide `incl_base_ancestors`
to override `incl_ancestors` for base classes/attributes.
If `incl_refs` is True, extends the asset with the references found in the content of the object.
It can also be an integer indicating the maximum reference level. Defaults to False for modules
and classes, and True otherwise. If resolution of reference names is disabled, defaults to False.
If `incl_descendants` is True, extends the asset page or heading with any descendant headings.
Provide `incl_ancestor_descendants` and `incl_ref_descendants` to override `incl_descendants`
for ancestors and references respectively.
If `aggregate` is True, aggregates any descendant headings into pages for this object
and all base classes/attributes. Provide `aggregate_ancestors` and `aggregate_refs` to
override `aggregate` for ancestors and references respectively.
If `topo_sort` is True, creates a topological graph from all reference names and sorts pages
        and headings based on this graph. Set `return_refname_graph` to True to also return the graph."""
from vectorbtpro.utils.module_ import prepare_refname, annotate_refname_parts
incl_bases = self.resolve_setting(incl_bases, "incl_bases")
incl_ancestors = self.resolve_setting(incl_ancestors, "incl_ancestors")
incl_base_ancestors = self.resolve_setting(incl_base_ancestors, "incl_base_ancestors")
incl_refs = self.resolve_setting(incl_refs, "incl_refs")
incl_descendants = self.resolve_setting(incl_descendants, "incl_descendants")
incl_ancestor_descendants = self.resolve_setting(incl_ancestor_descendants, "incl_ancestor_descendants")
incl_ref_descendants = self.resolve_setting(incl_ref_descendants, "incl_ref_descendants")
aggregate = self.resolve_setting(aggregate, "aggregate")
aggregate_ancestors = self.resolve_setting(aggregate_ancestors, "aggregate_ancestors")
aggregate_refs = self.resolve_setting(aggregate_refs, "aggregate_refs")
topo_sort = self.resolve_setting(topo_sort, "topo_sort")
base_refnames = []
base_refnames_set = set()
if not isinstance(obj, list):
objs = [obj]
else:
objs = obj
for obj in objs:
if attr is not None:
checks.assert_instance_of(attr, str, arg_name="attr")
if isinstance(obj, tuple):
obj = (*obj, attr)
else:
obj = (obj, attr)
obj_refname = prepare_refname(obj, module=module, resolve=resolve)
refname_graph = defaultdict(list)
if resolve:
annotated_parts = annotate_refname_parts(obj_refname)
if isinstance(annotated_parts[-1]["obj"], ModuleType):
_module = annotated_parts[-1]["obj"]
_cls = None
_attr = None
elif isinstance(annotated_parts[-1]["obj"], type):
_module = None
_cls = annotated_parts[-1]["obj"]
_attr = None
elif len(annotated_parts) >= 2 and isinstance(annotated_parts[-2]["obj"], type):
_module = None
_cls = annotated_parts[-2]["obj"]
_attr = annotated_parts[-1]["name"]
else:
_module = None
_cls = None
_attr = None
if use_parent is None:
use_parent = _cls is not None and _attr is None
if not aggregate and not incl_descendants:
use_parent = False
use_base_parents = False
if incl_refs is None:
incl_refs = _module is None and _cls is None
if _cls is not None and incl_bases:
level_classes = defaultdict(set)
visited = set()
queue = deque([(_cls, 0)])
while queue:
current_cls, current_level = queue.popleft()
if current_cls in visited:
continue
visited.add(current_cls)
level_classes[current_level].add(current_cls)
for base in current_cls.__bases__:
queue.append((base, current_level + 1))
mro = inspect.getmro(_cls)
classes = []
levels = list(level_classes.keys())
if not isinstance(incl_bases, bool):
if isinstance(incl_bases, int):
levels = levels[: incl_bases + 1]
else:
raise TypeError(f"Invalid incl_bases: {incl_bases}")
for level in levels:
classes.extend([_cls for _cls in mro if _cls in level_classes[level]])
for c in classes:
if c.__module__.split(".")[0] != "vectorbtpro":
continue
if _attr is not None:
if not hasattr(c, _attr):
continue
refname = prepare_refname((c, _attr))
else:
refname = prepare_refname(c)
if (use_parent and refname == obj_refname) or use_base_parents:
refname = ".".join(refname.split(".")[:-1])
if refname not in base_refnames_set:
base_refnames.append(refname)
base_refnames_set.add(refname)
for b in c.__bases__:
if b.__module__.split(".")[0] == "vectorbtpro":
if _attr is not None:
if not hasattr(b, _attr):
continue
b_refname = prepare_refname((b, _attr))
else:
b_refname = prepare_refname(b)
if use_base_parents:
b_refname = ".".join(b_refname.split(".")[:-1])
if refname != b_refname:
refname_graph[refname].append(b_refname)
elif _module is not None and hasattr(_module, "__path__") and incl_bases:
base_refnames.append(_module.__name__)
base_refnames_set.add(_module.__name__)
refname_level = {}
refname_level[_module.__name__] = 0
for _, refname, _ in pkgutil.walk_packages(_module.__path__, prefix=f"{_module.__name__}."):
if refname not in base_refnames_set:
parent_refname = ".".join(refname.split(".")[:-1])
if not isinstance(incl_bases, bool):
if isinstance(incl_bases, int):
if refname_level[parent_refname] + 1 > incl_bases:
continue
else:
raise TypeError(f"Invalid incl_bases: {incl_bases}")
base_refnames.append(refname)
base_refnames_set.add(refname)
refname_level[refname] = refname_level[parent_refname] + 1
if parent_refname != refname:
refname_graph[parent_refname].append(refname)
else:
base_refnames.append(obj_refname)
base_refnames_set.add(obj_refname)
else:
if incl_refs is None:
incl_refs = False
base_refnames.append(obj_refname)
base_refnames_set.add(obj_refname)
api_asset = self.find_refname(
base_refnames,
single_item=False,
incl_descendants=incl_descendants,
aggregate=aggregate,
aggregate_kwargs=aggregate_kwargs,
allow_empty=True,
wrap=True,
)
if len(api_asset) == 0:
return api_asset
if not topo_sort:
refname_indices = {refname: [] for refname in base_refnames}
remaining_indices = []
for i, d in enumerate(api_asset):
refname = self.parse_link_refname(d["link"])
if refname is not None:
while refname not in refname_indices:
if not refname:
break
refname = ".".join(refname.split(".")[:-1])
if refname:
refname_indices[refname].append(i)
else:
remaining_indices.append(i)
get_indices = [i for v in refname_indices.values() for i in v] + remaining_indices
api_asset = api_asset.get_items(get_indices)
if incl_ancestors or incl_refs:
refnames_aggregated = {}
for d in api_asset:
refname = self.parse_link_refname(d["link"])
if refname is not None:
refnames_aggregated[refname] = aggregate
to_ref_api_asset = api_asset
if incl_ancestors:
anc_refnames = []
anc_refnames_set = set(refnames_aggregated.keys())
for d in api_asset:
child_refname = refname = self.parse_link_refname(d["link"])
if refname is not None:
if incl_base_ancestors or refname == obj_refname:
refname = ".".join(refname.split(".")[:-1])
anc_level = 1
while refname:
if isinstance(incl_base_ancestors, bool) or refname == obj_refname:
if not isinstance(incl_ancestors, bool):
if isinstance(incl_ancestors, int):
if anc_level > incl_ancestors:
break
else:
raise TypeError(f"Invalid incl_ancestors: {incl_ancestors}")
else:
if not isinstance(incl_base_ancestors, bool):
if isinstance(incl_base_ancestors, int):
if anc_level > incl_base_ancestors:
break
else:
raise TypeError(f"Invalid incl_base_ancestors: {incl_base_ancestors}")
if refname not in anc_refnames_set:
anc_refnames.append(refname)
anc_refnames_set.add(refname)
if refname != child_refname:
refname_graph[refname].append(child_refname)
child_refname = refname
refname = ".".join(refname.split(".")[:-1])
anc_level += 1
anc_api_asset = self.find_refname(
anc_refnames,
single_item=False,
incl_descendants=incl_ancestor_descendants,
aggregate=aggregate_ancestors,
aggregate_kwargs=aggregate_kwargs,
allow_empty=True,
wrap=True,
)
if aggregate_ancestors or incl_ancestor_descendants:
obj_index = None
for i, d in enumerate(api_asset):
d_refname = self.parse_link_refname(d["link"])
if d_refname == obj_refname:
obj_index = i
break
if obj_index is not None:
del api_asset[obj_index]
for d in anc_api_asset:
refname = self.parse_link_refname(d["link"])
if refname is not None:
refnames_aggregated[refname] = aggregate_ancestors
api_asset = anc_api_asset + api_asset
if incl_refs:
if not aggregate and not incl_descendants:
use_ref_parents = False
main_ref_api_asset = None
ref_api_asset = to_ref_api_asset
while incl_refs:
content_refnames = []
content_refnames_set = set()
for d in ref_api_asset:
d_refname = self.parse_link_refname(d["link"])
if d_refname is not None:
for link in self.parse_content_links(d["content"]):
if "/api/" in link:
refname = self.parse_link_refname(link)
if refname is not None:
if use_ref_parents and not self.is_link_module(link):
refname = ".".join(refname.split(".")[:-1])
if refname not in content_refnames_set:
content_refnames.append(refname)
content_refnames_set.add(refname)
if d_refname != refname and refname not in refname_graph:
refname_graph[d_refname].append(refname)
ref_refnames = []
ref_refnames_set = set(refnames_aggregated.keys()) | content_refnames_set
for refname in content_refnames:
if refname in refnames_aggregated and (refnames_aggregated[refname] or not aggregate_refs):
continue
_refname = refname
while _refname:
_refname = ".".join(_refname.split(".")[:-1])
if _refname in ref_refnames_set and refnames_aggregated.get(_refname, aggregate_refs):
break
if not _refname:
ref_refnames.append(refname)
if len(ref_refnames) == 0:
break
ref_api_asset = self.find_refname(
ref_refnames,
single_item=False,
incl_descendants=incl_ref_descendants,
aggregate=aggregate_refs,
aggregate_kwargs=aggregate_kwargs,
allow_empty=True,
wrap=True,
)
for d in ref_api_asset:
refname = self.parse_link_refname(d["link"])
if refname is not None:
refnames_aggregated[refname] = aggregate_refs
if main_ref_api_asset is None:
main_ref_api_asset = ref_api_asset
else:
main_ref_api_asset += ref_api_asset
incl_refs -= 1
if main_ref_api_asset is not None:
api_asset += main_ref_api_asset
aggregated_refnames_set = set()
for refname, aggregated in refnames_aggregated.items():
if aggregated:
aggregated_refnames_set.add(refname)
delete_indices = []
for i, d in enumerate(api_asset):
refname = self.parse_link_refname(d["link"])
if refname is not None:
if not refnames_aggregated[refname] and refname in aggregated_refnames_set:
delete_indices.append(i)
continue
while refname:
refname = ".".join(refname.split(".")[:-1])
if refname in aggregated_refnames_set:
break
if refname:
delete_indices.append(i)
if len(delete_indices) > 0:
api_asset.delete_items(delete_indices, inplace=True)
if topo_sort:
from graphlib import TopologicalSorter
refname_topo_graph = defaultdict(set)
refname_topo_sorter = TopologicalSorter(refname_topo_graph)
for parent_node, child_nodes in refname_graph.items():
for child_node in child_nodes:
refname_topo_sorter.add(child_node, parent_node)
refname_topo_order = refname_topo_sorter.static_order()
refname_indices = {refname: [] for refname in refname_topo_order}
remaining_indices = []
for i, d in enumerate(api_asset):
refname = self.parse_link_refname(d["link"])
if refname is not None:
while refname not in refname_indices:
if not refname:
break
refname = ".".join(refname.split(".")[:-1])
if refname:
refname_indices[refname].append(i)
else:
remaining_indices.append(i)
else:
remaining_indices.append(i)
get_indices = [i for v in refname_indices.values() for i in v] + remaining_indices
api_asset = api_asset.get_items(get_indices)
if return_refname_graph:
return api_asset, refname_graph
return api_asset
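    # Usage sketch (illustrative; assumes `pages_asset` is a pulled PagesAsset):
    #     >>> api_asset = pages_asset.find_obj_api(vbt.PF, aggregate=True)
    #     >>> api_asset, graph = pages_asset.find_obj_api(vbt.PF, topo_sort=True, return_refname_graph=True)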
def find_obj_docs(
self,
obj: tp.MaybeList,
*,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
incl_pages: tp.Optional[tp.MaybeIterable[str]] = None,
excl_pages: tp.Optional[tp.MaybeIterable[str]] = None,
page_find_mode: tp.Optional[str] = None,
up_aggregate: tp.Optional[bool] = None,
up_aggregate_th: tp.Union[None, int, float] = None,
up_aggregate_pages: tp.Optional[bool] = None,
aggregate: tp.Optional[bool] = None,
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybePagesAsset:
"""Find documentation relevant to object(s).
If a link matches one of the links or link parts in `incl_pages`, it will be included,
otherwise, it will be excluded if `incl_pages` is not empty. If a link matches one of the links
or link parts in `excl_pages`, it will be excluded, otherwise, it will be included. Matching is
done using `vectorbtpro.utils.search_.find` with `page_find_mode` used as `mode`.
For example, using `excl_pages=["release-notes"]` won't search in release notes.
        If `up_aggregate` is True, aggregates each set of matched headings into their parent once their
        count reaches the threshold `up_aggregate_th`, which is measured against the total number of headings
        in the parent. It can be an integer for an absolute count or a float for a relative one.
        For example, `up_aggregate_th=2/3` means at least 2 headings out of 3 must be matched in order to
        replace them with the full parent heading/page. If `up_aggregate_pages` is True, does the same
        for pages. For example, if 2 tutorial pages out of 3 are matched, the whole tutorial series is used.
If `aggregate` is True, aggregates any descendant headings into pages for this object
and all base classes/attributes using `PagesAsset.aggregate_links`.
Uses `PagesAsset.find_obj_mentions`."""
incl_pages = self.resolve_setting(incl_pages, "incl_pages")
excl_pages = self.resolve_setting(excl_pages, "excl_pages")
page_find_mode = self.resolve_setting(page_find_mode, "page_find_mode")
up_aggregate = self.resolve_setting(up_aggregate, "up_aggregate")
up_aggregate_th = self.resolve_setting(up_aggregate_th, "up_aggregate_th")
up_aggregate_pages = self.resolve_setting(up_aggregate_pages, "up_aggregate_pages")
aggregate = self.resolve_setting(aggregate, "aggregate")
if incl_pages is None:
incl_pages = ()
elif isinstance(incl_pages, str):
incl_pages = (incl_pages,)
if excl_pages is None:
excl_pages = ()
elif isinstance(excl_pages, str):
excl_pages = (excl_pages,)
def _filter_func(x):
if "link" not in x:
return False
if "/api/" in x["link"]:
return False
if excl_pages:
for page in excl_pages:
if find(page, x["link"], mode=page_find_mode):
return False
if incl_pages:
for page in incl_pages:
if find(page, x["link"], mode=page_find_mode):
return True
return False
return True
docs_asset = self.filter(_filter_func)
mentions_asset = docs_asset.find_obj_mentions(
obj,
attr=attr,
module=module,
resolve=resolve,
**kwargs,
)
if (
isinstance(mentions_asset, PagesAsset)
and len(mentions_asset) > 0
and isinstance(mentions_asset[0], dict)
and "link" in mentions_asset[0]
):
if up_aggregate:
link_map = {d["link"]: dict(d) for d in docs_asset.data}
new_links = {d["link"] for d in mentions_asset}
while True:
parent_map = defaultdict(list)
without_parent = set()
for link in new_links:
if link_map[link]["parent"] is not None:
parent_map[link_map[link]["parent"]].append(link)
else:
without_parent.add(link)
_new_links = set()
for parent, children in parent_map.items():
headings = set()
non_headings = set()
for child in children:
if link_map[child]["type"].startswith("heading"):
headings.add(child)
else:
non_headings.add(child)
if up_aggregate_pages:
_children = children
else:
_children = headings
if checks.is_float(up_aggregate_th) and 0 <= abs(up_aggregate_th) <= 1:
_up_aggregate_th = int(up_aggregate_th * len(link_map[parent]["children"]))
elif checks.is_number(up_aggregate_th):
if checks.is_float(up_aggregate_th) and not up_aggregate_th.is_integer():
raise TypeError(f"Up-aggregation threshold ({up_aggregate_th}) must be between 0 and 1")
_up_aggregate_th = int(up_aggregate_th)
else:
raise TypeError(f"Up-aggregation threshold must be a number")
if 0 < len(_children) >= _up_aggregate_th:
_new_links.add(parent)
else:
_new_links |= headings
_new_links |= non_headings
if _new_links == new_links:
break
new_links = _new_links | without_parent
return docs_asset.find_page(
list(new_links),
single_item=False,
aggregate=aggregate,
aggregate_kwargs=aggregate_kwargs,
)
if aggregate:
return docs_asset.aggregate_links(
[d["link"] for d in mentions_asset],
aggregate_kwargs=aggregate_kwargs,
)
return mentions_asset
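    # Usage sketch (illustrative):
    #     >>> docs_asset = pages_asset.find_obj_docs(vbt.PF, excl_pages=["release-notes"])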
def browse(
self,
entry_link: tp.Optional[str] = None,
descendants_only: bool = False,
aggregate: bool = False,
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> Path:
new_instance = self
if entry_link is not None and entry_link != "/" and descendants_only:
new_instance = new_instance.select_descendants(entry_link, incl_link=True)
if aggregate:
if aggregate_kwargs is None:
aggregate_kwargs = {}
new_instance = new_instance.aggregate(**aggregate_kwargs)
return VBTAsset.browse(new_instance, entry_link=entry_link, **kwargs)
def display(
self,
link: tp.Optional[str] = None,
aggregate: bool = False,
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> Path:
new_instance = self
if link is not None:
new_instance = new_instance.find_page(
link,
aggregate=aggregate,
aggregate_kwargs=aggregate_kwargs,
)
elif aggregate:
if aggregate_kwargs is None:
aggregate_kwargs = {}
new_instance = new_instance.aggregate(**aggregate_kwargs)
return VBTAsset.display(new_instance, **kwargs)
def aggregate(
self: PagesAssetT,
append_obj_type: tp.Optional[bool] = None,
append_github_link: tp.Optional[bool] = None,
) -> PagesAssetT:
"""Aggregate pages.
Content of each heading will be converted into markdown and concatenated into the content
of the parent heading or page. Only regular pages and headings without parents will be left.
If `append_obj_type` is True, will also append object type to the heading name.
If `append_github_link` is True, will also append GitHub link to the heading name."""
append_obj_type = self.resolve_setting(append_obj_type, "append_obj_type")
append_github_link = self.resolve_setting(append_github_link, "append_github_link")
link_map = {d["link"]: dict(d) for d in self.data}
top_parents = self.top_parent_links
aggregated_links = set()
def _aggregate_content(link):
node = link_map[link]
content = node["content"]
if content is None:
content = ""
if node["type"].startswith("heading"):
level = int(node["type"].split(" ")[1])
heading_markdown = "#" * level + " " + node["name"]
if append_obj_type and node.get("obj_type", None) is not None:
heading_markdown += f" | {node['obj_type']}"
if append_github_link and node.get("github_link", None) is not None:
heading_markdown += f" | [source]({node['github_link']})"
if content == "":
content = heading_markdown
else:
content = f"{heading_markdown}\n\n{content}"
children = list(node["children"])
for child in list(children):
if child in link_map:
child_node = link_map[child]
child_content = _aggregate_content(child)
if child_node["type"].startswith("heading"):
if child_content.startswith("# "):
content = child_content
else:
content += f"\n\n{child_content}"
children.remove(child)
aggregated_links.add(child)
if content != "":
node["content"] = content
node["children"] = children
return content
for top_parent in top_parents:
_aggregate_content(top_parent)
new_data = [link_map[link] for link in link_map if link not in aggregated_links]
return self.replace(data=new_data)
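    # Usage sketch (illustrative):
    #     >>> agg_asset = pages_asset.aggregate(append_obj_type=True, append_github_link=True)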
def select_parent(self: PagesAssetT, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select the parent page of a link."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
if d.get("parent", None):
if d["parent"] in link_map:
new_data.append(link_map[d["parent"]])
return self.replace(data=new_data, single_item=True)
def select_children(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select the child pages of a link."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
if d.get("children", []):
for child in d["children"]:
if child in link_map:
new_data.append(link_map[child])
return self.replace(data=new_data, single_item=False)
def select_siblings(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select the sibling pages of a link."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
if d.get("parent", None):
if d["parent"] in link_map:
parent_d = link_map[d["parent"]]
if parent_d.get("children", []):
for child in parent_d["children"]:
if incl_link or child != d["link"]:
if child in link_map:
new_data.append(link_map[child])
return self.replace(data=new_data, single_item=False)
def select_descendants(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select all descendant pages of a link."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
descendants = set()
stack = [d]
while stack:
d = stack.pop()
children = d.get("children", [])
for child in children:
if child in link_map and child not in descendants:
descendants.add(child)
new_data.append(link_map[child])
stack.append(link_map[child])
return self.replace(data=new_data, single_item=False)
def select_branch(self, link: str, **kwargs) -> PagesAssetT:
"""Select all descendant pages of a link including the link."""
return self.select_descendants(link, incl_link=True, **kwargs)
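    # Usage sketch (illustrative; the link below is made up):
    #     >>> pages_asset.select_children("https://vectorbt.pro/features/data/")
    #     >>> pages_asset.select_branch("https://vectorbt.pro/features/data/")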
def select_ancestors(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select all ancestor pages of a link."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
ancestors = set()
parent = d.get("parent", None)
while parent and parent in link_map:
if parent in ancestors:
break
ancestors.add(parent)
new_data.append(link_map[parent])
parent = link_map[parent].get("parent", None)
return self.replace(data=new_data, single_item=False)
def select_parent_page(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select parent page."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
ancestors = set()
parent = d.get("parent", None)
while parent and parent in link_map:
if parent in ancestors:
break
ancestors.add(parent)
new_data.append(link_map[parent])
if link_map[parent]["type"] == "page":
break
parent = link_map[parent].get("parent", None)
return self.replace(data=new_data, single_item=False)
def select_descendant_headings(self, link: str, incl_link: bool = False, **kwargs) -> PagesAssetT:
"""Select descendant headings."""
d = self.find_page(link, wrap=False, **kwargs)
link_map = {d["link"]: dict(d) for d in self.data}
new_data = []
if incl_link:
new_data.append(d)
descendants = set()
stack = [d]
while stack:
d = stack.pop()
children = d.get("children", [])
for child in children:
if child in link_map and child not in descendants:
if link_map[child]["type"].startswith("heading"):
descendants.add(child)
new_data.append(link_map[child])
stack.append(link_map[child])
return self.replace(data=new_data, single_item=False)
def print_site_schema(
self,
append_type: bool = False,
append_obj_type: bool = False,
structure_fragments: bool = True,
split_fragments: bool = True,
**dir_tree_kwargs,
) -> None:
"""Print site schema.
If `structure_fragments` is True, builds a hierarchy of fragments. Otherwise,
displays them on the same level.
        If `split_fragments` is True, displays fragments as a continuation of their parents.
        Otherwise, displays them in full length.
        Keyword arguments are passed to `vectorbtpro.utils.path_.dir_tree_from_paths`."""
link_map = {d["link"]: dict(d) for d in self.data}
links = []
for link, d in link_map.items():
if not structure_fragments:
links.append(link)
continue
x = d
link_base = None
link_fragments = []
while x["type"].startswith("heading") and "#" in x["link"]:
link_parts = x["link"].split("#")
if link_base is None:
link_base = link_parts[0]
link_fragments.append("#" + link_parts[1])
if not x.get("parent", None) or x["parent"] not in link_map:
if x["type"].startswith("heading"):
level = int(x["type"].split()[1])
for i in range(level - 1):
link_fragments.append("?")
break
x = link_map[x["parent"]]
if link_base is None:
links.append(link)
else:
if split_fragments and len(link_fragments) > 1:
link_fragments = link_fragments[::-1]
new_link_fragments = [link_fragments[0]]
for i in range(1, len(link_fragments)):
link_fragment1 = link_fragments[i - 1]
link_fragment2 = link_fragments[i]
if link_fragment2.startswith(link_fragment1 + "."):
new_link_fragments.append("." + link_fragment2[len(link_fragment1 + ".") :])
else:
new_link_fragments.append(link_fragment2)
link_fragments = new_link_fragments
links.append(link_base + "/".join(link_fragments))
paths = self.links_to_paths(links, allow_fragments=not structure_fragments)
path_names = []
for i, d in enumerate(link_map.values()):
path_name = paths[i].name
brackets = []
if append_type:
brackets.append(d["type"])
if append_obj_type and d["obj_type"]:
brackets.append(d["obj_type"])
if brackets:
path_name += f" [{', '.join(brackets)}]"
path_names.append(path_name)
if "root_name" not in dir_tree_kwargs:
root_name = get_common_prefix(link_map.keys())
if not root_name:
root_name = "/"
dir_tree_kwargs["root_name"] = root_name
if "sort" not in dir_tree_kwargs:
dir_tree_kwargs["sort"] = False
if "path_names" not in dir_tree_kwargs:
dir_tree_kwargs["path_names"] = path_names
if "length_limit" not in dir_tree_kwargs:
dir_tree_kwargs["length_limit"] = None
print(dir_tree_from_paths(paths, **dir_tree_kwargs))
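    # Usage sketch (illustrative):
    #     >>> pages_asset.print_site_schema(append_type=True, append_obj_type=True)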
MessagesAssetT = tp.TypeVar("MessagesAssetT", bound="MessagesAsset")
class MessagesAsset(VBTAsset):
"""Class for working with Discord messages.
Each message has the following fields:
    * link: URL of the message, such as "https://discord.com/channels/918629562441695344/919715148896301067/923327319882485851"
    * block: URL of the first message in the block. A block is a bunch of messages of the same author
        that either reference a message of another author, or don't reference any message at all.
    * thread: URL of the first message in the thread. A thread is a bunch of blocks that reference each other
        in a chain, such as questions, answers, follow-up questions, etc.
    * reference: URL of the message that the message references. Can be None.
    * replies: List of URLs of the messages that reference the message
    * channel: Channel of the message, such as "support"
    * timestamp: Timestamp of the message, such as "2024-01-01 00:00:00"
    * author: Author of the message, such as "@polakowo"
    * content: String content of the message
    * mentions: List of Discord usernames that this message mentions, such as ["@polakowo"]
    * attachments: List of attachments. Each attachment has two fields: "file_name", such as "some_image.png",
        and "content" containing the string content extracted from the file.
    * reactions: Total number of reactions that this message has received
For defaults, see `assets.messages` in `vectorbtpro._settings.knowledge`."""
_settings_path: tp.SettingsPath = "knowledge.assets.messages"
def latest_first(self, **kwargs) -> tp.MaybeMessagesAsset:
"""Sort by timestamp in descending order."""
return self.sort(keys=self.get("timestamp"), ascending=False, **kwargs)
def aggregate_messages(
self: MessagesAssetT,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Aggregate attachments by message.
For keyword arguments, see `MessagesAsset.to_markdown`.
Uses `MessagesAsset.apply` on `vectorbtpro.utils.knowledge.custom_asset_funcs.AggMessageAssetFunc`."""
return self.apply(
"agg_message",
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
**kwargs,
)
def aggregate_blocks(
self: MessagesAssetT,
collect_kwargs: tp.KwargsLike = None,
aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
parent_links_only: tp.Optional[bool] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Aggregate messages by block.
First, uses `MessagesAsset.reduce` on `vectorbtpro.utils.knowledge.base_asset_funcs.CollectAssetFunc`
to collect data items by the field "block". Keyword arguments in `collect_kwargs` are passed here.
Argument `uniform_groups` is True by default. Then, uses `MessagesAsset.apply` on
`vectorbtpro.utils.knowledge.custom_asset_funcs.AggBlockAssetFunc` to aggregate each collected data item.
Use `aggregate_fields` to provide a set of fields to be aggregated rather than used in child metadata.
It can be True to aggregate all lists and False to aggregate none.
If `parent_links_only` is True, doesn't include links in the metadata of each message.
For other keyword arguments, see `MessagesAsset.to_markdown`."""
if collect_kwargs is None:
collect_kwargs = {}
if "uniform_groups" not in collect_kwargs:
collect_kwargs["uniform_groups"] = True
instance = self.collect(by="block", wrap=True, **collect_kwargs)
return instance.apply(
"agg_block",
aggregate_fields=aggregate_fields,
parent_links_only=parent_links_only,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
link_map={d["link"]: dict(d) for d in self.data},
**kwargs,
)
def aggregate_threads(
self: MessagesAssetT,
collect_kwargs: tp.KwargsLike = None,
aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
parent_links_only: tp.Optional[bool] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Aggregate messages by thread.
Same as `MessagesAsset.aggregate_blocks` but for threads.
Uses `vectorbtpro.utils.knowledge.custom_asset_funcs.AggThreadAssetFunc`."""
if collect_kwargs is None:
collect_kwargs = {}
if "uniform_groups" not in collect_kwargs:
collect_kwargs["uniform_groups"] = True
instance = self.collect(by="thread", wrap=True, **collect_kwargs)
return instance.apply(
"agg_thread",
aggregate_fields=aggregate_fields,
parent_links_only=parent_links_only,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
link_map={d["link"]: dict(d) for d in self.data},
**kwargs,
)
def aggregate_channels(
self: MessagesAssetT,
collect_kwargs: tp.KwargsLike = None,
aggregate_fields: tp.Union[None, bool, tp.MaybeIterable[str]] = None,
parent_links_only: tp.Optional[bool] = None,
minimize_metadata: tp.Optional[bool] = None,
minimize_keys: tp.Optional[tp.MaybeList[tp.PathLikeKey]] = None,
clean_metadata: tp.Optional[bool] = None,
clean_metadata_kwargs: tp.KwargsLike = None,
dump_metadata_kwargs: tp.KwargsLike = None,
to_markdown_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Aggregate messages by channel.
Same as `MessagesAsset.aggregate_threads` but for channels.
Uses `vectorbtpro.utils.knowledge.custom_asset_funcs.AggChannelAssetFunc`."""
if collect_kwargs is None:
collect_kwargs = {}
if "uniform_groups" not in collect_kwargs:
collect_kwargs["uniform_groups"] = True
instance = self.collect(by="channel", wrap=True, **collect_kwargs)
return instance.apply(
"agg_channel",
aggregate_fields=aggregate_fields,
parent_links_only=parent_links_only,
minimize_metadata=minimize_metadata,
minimize_keys=minimize_keys,
clean_metadata=clean_metadata,
clean_metadata_kwargs=clean_metadata_kwargs,
dump_metadata_kwargs=dump_metadata_kwargs,
to_markdown_kwargs=to_markdown_kwargs,
link_map={d["link"]: dict(d) for d in self.data},
**kwargs,
)
@property
def lowest_aggregate_by(self) -> tp.Optional[str]:
"""Get the lowest level that aggregates all messages."""
try:
if self.get("attachments"):
return "message"
except KeyError:
pass
try:
if len(set(self.get("block"))) == 1:
return "block"
except KeyError:
pass
try:
if len(set(self.get("thread"))) == 1:
return "thread"
except KeyError:
pass
try:
if len(set(self.get("channel"))) == 1:
return "channel"
except KeyError:
pass
@property
def highest_aggregate_by(self) -> tp.Optional[str]:
"""Get the highest level that aggregates all messages."""
try:
if len(set(self.get("channel"))) == 1:
return "channel"
except KeyError:
pass
try:
if len(set(self.get("thread"))) == 1:
return "thread"
except KeyError:
pass
try:
if len(set(self.get("block"))) == 1:
return "block"
except KeyError:
pass
try:
if self.get("attachments"):
return "message"
except KeyError:
pass
def aggregate(self, by: str = "lowest", **kwargs) -> tp.MaybeMessagesAsset:
"""Aggregate by "message" (attachments), "block", "thread", or "channel".
If `by` is None, uses `MessagesAsset.lowest_aggregate_by`."""
if by.lower() == "lowest":
by = self.lowest_aggregate_by
elif by.lower() == "highest":
by = self.highest_aggregate_by
if by is None:
raise ValueError("Must provide by")
if not by.lower().endswith("s"):
by += "s"
return getattr(self, "aggregate_" + by.lower())(**kwargs)
def select_reference(self: MessagesAssetT, link: str, **kwargs) -> MessagesAssetT:
"""Select the reference message."""
d = self.find_link(link, wrap=False, **kwargs)
reference = d.get("reference", None)
new_data = []
if reference:
for d2 in self.data:
if d2["reference"] == reference:
new_data.append(d2)
break
return self.replace(data=new_data, single_item=True)
def select_replies(self: MessagesAssetT, link: str, **kwargs) -> MessagesAssetT:
"""Select the reply messages."""
d = self.find_link(link, wrap=False, **kwargs)
replies = d.get("replies", [])
new_data = []
if replies:
reply_data = {reply: None for reply in replies}
replies_found = 0
for d2 in self.data:
if d2["link"] in reply_data:
reply_data[d2["link"]] = d2
replies_found += 1
if replies_found == len(replies):
break
new_data = list(reply_data.values())
return self.replace(data=new_data, single_item=True)
def select_block(self: MessagesAssetT, link: str, incl_link: bool = True, **kwargs) -> MessagesAssetT:
"""Select the messages that belong to the block of a link."""
d = self.find_link(link, wrap=False, **kwargs)
new_data = []
for d2 in self.data:
if d2["block"] == d["block"] and (incl_link or d2["link"] != d["link"]):
new_data.append(d2)
return self.replace(data=new_data, single_item=False)
def select_thread(self: MessagesAssetT, link: str, incl_link: bool = True, **kwargs) -> MessagesAssetT:
"""Select the messages that belong to the thread of a link."""
d = self.find_link(link, wrap=False, **kwargs)
new_data = []
for d2 in self.data:
if d2["thread"] == d["thread"] and (incl_link or d2["link"] != d["link"]):
new_data.append(d2)
return self.replace(data=new_data, single_item=False)
def select_channel(self: MessagesAssetT, link: str, incl_link: bool = True, **kwargs) -> MessagesAssetT:
"""Select the messages that belong to the channel of a link."""
d = self.find_link(link, wrap=False, **kwargs)
new_data = []
for d2 in self.data:
if d2["channel"] == d["channel"] and (incl_link or d2["link"] != d["link"]):
new_data.append(d2)
return self.replace(data=new_data, single_item=False)
def find_obj_messages(
self,
obj: tp.MaybeList,
*,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Find messages relevant to object(s).
Uses `MessagesAsset.find_obj_mentions`."""
return self.find_obj_mentions(obj, attr=attr, module=module, resolve=resolve, **kwargs)
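# A minimal usage sketch of the aggregation and selection API above (the Discord link is a
# placeholder, and pulling assumes a configured token where required):
# >>> messages = MessagesAsset.pull()
# >>> threads = messages.aggregate(by="thread")    # or "message", "block", "channel"
# >>> latest = messages.latest_first()             # newest messages first
# >>> thread = messages.select_thread("https://discord.com/channels/.../...")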
def is_obj_or_query_ref(obj_or_query: tp.MaybeList) -> bool:
"""Return whether `obj_or_query` is a reference to an object."""
if isinstance(obj_or_query, str):
return all(segment.isidentifier() for segment in obj_or_query.split("."))
return True
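# For example, per the logic above, dotted identifier paths count as object references while
# anything else is treated as a free-text query; non-string inputs always count as references:
# >>> is_obj_or_query_ref("Portfolio.from_signals")     # True
# >>> is_obj_or_query_ref("how to rebalance monthly?")  # False
# >>> is_obj_or_query_ref(["Portfolio", "Trades"])      # True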
def find_api(
obj_or_query: tp.Optional[tp.MaybeList] = None,
*,
as_query: tp.Optional[bool] = None,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
pull_kwargs: tp.KwargsLike = None,
aggregate: bool = False,
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybePagesAsset:
"""Find API pages and headings relevant to object(s) or a query.
If `obj_or_query` is None, returns all API pages. If it's a reference to an object, uses
`PagesAsset.find_obj_api`. Otherwise, uses `PagesAsset.rank`."""
if pages_asset is None:
pages_asset = PagesAsset
if isinstance(pages_asset, type):
checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset")
if pull_kwargs is None:
pull_kwargs = {}
pages_asset = pages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset")
if aggregate:
if aggregate_kwargs is None:
aggregate_kwargs = {}
pages_asset = pages_asset.aggregate(**aggregate_kwargs)
if as_query is None:
as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
if obj_or_query is not None and not as_query:
return pages_asset.find_obj_api(obj_or_query, attr=attr, module=module, resolve=resolve, **kwargs)
pages_asset = pages_asset.filter(lambda x: "link" in x and "/api/" in x["link"])
if obj_or_query is None:
return pages_asset
return pages_asset.rank(obj_or_query, **kwargs)
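# A usage sketch of the dispatch above (object and query strings are illustrative):
# >>> find_api()                          # obj_or_query is None -> all API pages
# >>> find_api("Portfolio.from_orders")   # identifier path -> PagesAsset.find_obj_api
# >>> find_api("slippage and fees")       # free text -> "/api/" pages ranked via PagesAsset.rank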
def find_docs(
obj_or_query: tp.Optional[tp.MaybeList] = None,
*,
as_query: tp.Optional[bool] = None,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
pull_kwargs: tp.KwargsLike = None,
aggregate: bool = False,
aggregate_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybePagesAsset:
"""Find documentation pages and headings relevant to object(s) or a query.
If `obj_or_query` is None, returns all documentation pages. If it's a reference to an object, uses
`PagesAsset.find_obj_docs`. Otherwise, uses `PagesAsset.rank`."""
if pages_asset is None:
pages_asset = PagesAsset
if isinstance(pages_asset, type):
checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset")
if pull_kwargs is None:
pull_kwargs = {}
pages_asset = pages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset")
if aggregate:
if aggregate_kwargs is None:
aggregate_kwargs = {}
pages_asset = pages_asset.aggregate(**aggregate_kwargs)
if as_query is None:
as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
if obj_or_query is not None and not as_query:
return pages_asset.find_obj_docs(obj_or_query, attr=attr, module=module, resolve=resolve, **kwargs)
pages_asset = pages_asset.filter(lambda x: "link" in x and "/api/" not in x["link"])
if obj_or_query is None:
return pages_asset
return pages_asset.rank(obj_or_query, **kwargs)
def find_messages(
obj_or_query: tp.Optional[tp.MaybeList] = None,
*,
as_query: tp.Optional[bool] = None,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
messages_asset: tp.Optional[tp.MaybeType[MessagesAssetT]] = None,
pull_kwargs: tp.KwargsLike = None,
aggregate: tp.Union[bool, str] = "messages",
aggregate_kwargs: tp.KwargsLike = None,
latest_first: bool = False,
shuffle: bool = False,
**kwargs,
) -> tp.MaybeMessagesAsset:
"""Find messages relevant to object(s) or a query.
If `obj_or_query` is None, returns all messages. If it's a reference to an object, uses
`MessagesAsset.find_obj_messages`. Otherwise, uses `MessagesAsset.rank`."""
if messages_asset is None:
messages_asset = MessagesAsset
if isinstance(messages_asset, type):
checks.assert_subclass_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if pull_kwargs is None:
pull_kwargs = {}
messages_asset = messages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if aggregate:
if aggregate_kwargs is None:
aggregate_kwargs = {}
if isinstance(aggregate, str) and "by" not in aggregate_kwargs:
aggregate_kwargs["by"] = aggregate
messages_asset = messages_asset.aggregate(**aggregate_kwargs)
if latest_first:
messages_asset = messages_asset.latest_first()
elif shuffle:
messages_asset = messages_asset.shuffle()
if as_query is None:
as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
if obj_or_query is not None and not as_query:
return messages_asset.find_obj_messages(obj_or_query, attr=attr, module=module, resolve=resolve, **kwargs)
if obj_or_query is None:
return messages_asset
return messages_asset.rank(obj_or_query, **kwargs)
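# A usage sketch (values are illustrative; `aggregate` and `latest_first` are applied above
# before mention search or ranking):
# >>> find_messages("Portfolio.from_signals", aggregate="threads", latest_first=True)
# >>> find_messages("numba parallelization issues")   # free text -> MessagesAsset.rank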
def find_examples(
obj_or_query: tp.Optional[tp.MaybeList] = None,
*,
as_query: tp.Optional[bool] = None,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
as_code: bool = True,
return_type: tp.Optional[str] = "field",
pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
messages_asset: tp.Optional[tp.MaybeType[MessagesAssetT]] = None,
pull_kwargs: tp.KwargsLike = None,
aggregate_pages: bool = False,
aggregate_pages_kwargs: tp.KwargsLike = None,
aggregate_messages: tp.Union[bool, str] = "messages",
aggregate_messages_kwargs: tp.KwargsLike = None,
latest_messages_first: bool = False,
shuffle_messages: bool = False,
find_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeVBTAsset:
"""Find (code) examples relevant to object(s) or a query.
If `obj_or_query` is None, returns all examples with `VBTAsset.find_code` or `VBTAsset.find`.
If it's a reference to an object, uses `VBTAsset.find_obj_mentions`. Otherwise, uses `VBTAsset.find_code`
or `VBTAsset.find` and then `VBTAsset.rank`. Keyword arguments are distributed among these methods
automatically, unless some keys cannot be found in both signatures. In such a case, the key will be
used for ranking. If this is not wanted, specify `find_kwargs`.
By default, extracts code with text. Use `return_type="match"` to extract code without text,
or, for instance, `return_type="item"` to also get links."""
if pages_asset is None:
pages_asset = PagesAsset
if isinstance(pages_asset, type):
checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset")
if pull_kwargs is None:
pull_kwargs = {}
pages_asset = pages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset")
if aggregate_pages:
if aggregate_pages_kwargs is None:
aggregate_pages_kwargs = {}
pages_asset = pages_asset.aggregate(**aggregate_pages_kwargs)
if messages_asset is None:
messages_asset = MessagesAsset
if isinstance(messages_asset, type):
checks.assert_subclass_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if pull_kwargs is None:
pull_kwargs = {}
messages_asset = messages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if aggregate_messages:
if aggregate_messages_kwargs is None:
aggregate_messages_kwargs = {}
if isinstance(aggregate_messages, str) and "by" not in aggregate_messages_kwargs:
aggregate_messages_kwargs["by"] = aggregate_messages
messages_asset = messages_asset.aggregate(**aggregate_messages_kwargs)
if latest_messages_first:
messages_asset = messages_asset.latest_first()
elif shuffle_messages:
messages_asset = messages_asset.shuffle()
combined_asset = pages_asset + messages_asset
if as_query is None:
as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
if find_kwargs is None:
find_kwargs = {}
else:
find_kwargs = dict(find_kwargs)
find_kwargs["return_type"] = return_type
if obj_or_query is not None and not as_query:
return combined_asset.find_obj_mentions(
obj_or_query,
attr=attr,
module=module,
resolve=resolve,
as_code=as_code,
**find_kwargs,
**kwargs,
)
if as_code:
method = combined_asset.find_code
else:
method = combined_asset.find
if obj_or_query is None:
return method(**find_kwargs, **kwargs)
find_code_arg_names = set(get_func_arg_names(method))
rank_kwargs = {}
for k, v in kwargs.items():
if k in find_code_arg_names:
if k not in find_kwargs:
find_kwargs[k] = v
else:
rank_kwargs[k] = v
return method(**find_kwargs).rank(obj_or_query, **rank_kwargs)
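# A usage sketch of how keyword arguments are split above: keys accepted by the find method go
# into `find_kwargs`, everything else is forwarded to ranking (values are illustrative):
# >>> find_examples("Portfolio.from_signals", return_type="match")  # code matches without text
# >>> find_examples("trailing stop exits", top_k=20)  # top_k isn't a find argument -> used by rank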
def find_assets(
obj_or_query: tp.Optional[tp.MaybeList] = None,
*,
as_query: tp.Optional[bool] = None,
attr: tp.Optional[str] = None,
module: tp.Union[None, str, ModuleType] = None,
resolve: bool = True,
asset_names: tp.Optional[tp.MaybeIterable[str]] = None,
pages_asset: tp.Optional[tp.MaybeType[PagesAssetT]] = None,
messages_asset: tp.Optional[tp.MaybeType[MessagesAssetT]] = None,
pull_kwargs: tp.KwargsLike = None,
aggregate_pages: bool = False,
aggregate_pages_kwargs: tp.KwargsLike = None,
aggregate_messages: tp.Union[bool, str] = "messages",
aggregate_messages_kwargs: tp.KwargsLike = None,
latest_messages_first: bool = False,
shuffle_messages: bool = False,
api_kwargs: tp.KwargsLike = None,
docs_kwargs: tp.KwargsLike = None,
messages_kwargs: tp.KwargsLike = None,
examples_kwargs: tp.KwargsLike = None,
minimize: tp.Optional[bool] = None,
minimize_pages: tp.Optional[bool] = None,
minimize_messages: tp.Optional[bool] = None,
minimize_kwargs: tp.KwargsLike = None,
minimize_pages_kwargs: tp.KwargsLike = None,
minimize_messages_kwargs: tp.KwargsLike = None,
combine: bool = True,
combine_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeDict[tp.VBTAsset]:
"""Find all assets relevant to object(s) or a query.
Argument `asset_names` can be a list of asset names in any order. It defaults to "api", "docs",
and "messages", It can also include ellipsis (`...`). For example, `["messages", ...]` puts
"messages" at the beginning and all other assets in their usual order at the end.
The following asset names are supported:
* "api": `find_api` with `api_kwargs`
* "docs": `find_docs` with `docs_kwargs`
* "messages": `find_messages` with `messages_kwargs`
* "examples": `find_examples` with `examples_kwargs`
* "all": All of the above
!!! note
Examples usually overlap with other assets, thus they are excluded by default.
Set `combine` to True to combine all assets into a single asset. Uses
`vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.combine` with `combine_kwargs`.
If `obj_or_query` is a query, will rank the combined asset. Otherwise, will rank each individual asset.
Set `minimize` to True (or `minimize_pages` for pages and `minimize_messages` for messages)
in order to remove fields that aren't relevant for chatting.
It defaults to True if `combine` is True, otherwise, it defaults to False. Uses `VBTAsset.minimize`
with `minimize_kwargs`, `PagesAsset.minimize` with `minimize_pages_kwargs`, and `MessagesAsset.minimize`
with `minimize_messages_kwargs`. Arguments `minimize_pages_kwargs` and `minimize_messages_kwargs`
are merged over `minimize_kwargs`.
Keyword arguments are passed to all functions (except for `find_api` when `obj_or_query` is an object
since it doesn't share common arguments with the other three functions), unless `combine` and `as_query`
are both True; in this case they are passed to `VBTAsset.rank`. Use specialized arguments like
`api_kwargs` to provide keyword arguments to the respective function."""
if pages_asset is None:
pages_asset = PagesAsset
if isinstance(pages_asset, type):
checks.assert_subclass_of(pages_asset, PagesAsset, arg_name="pages_asset")
if pull_kwargs is None:
pull_kwargs = {}
pages_asset = pages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(pages_asset, PagesAsset, arg_name="pages_asset")
if aggregate_pages:
if aggregate_pages_kwargs is None:
aggregate_pages_kwargs = {}
pages_asset = pages_asset.aggregate(**aggregate_pages_kwargs)
if messages_asset is None:
messages_asset = MessagesAsset
if isinstance(messages_asset, type):
checks.assert_subclass_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if pull_kwargs is None:
pull_kwargs = {}
messages_asset = messages_asset.pull(**pull_kwargs)
else:
checks.assert_instance_of(messages_asset, MessagesAsset, arg_name="messages_asset")
if aggregate_messages:
if aggregate_messages_kwargs is None:
aggregate_messages_kwargs = {}
if isinstance(aggregate_messages, str) and "by" not in aggregate_messages_kwargs:
aggregate_messages_kwargs["by"] = aggregate_messages
messages_asset = messages_asset.aggregate(**aggregate_messages_kwargs)
if latest_messages_first:
messages_asset = messages_asset.latest_first()
elif shuffle_messages:
messages_asset = messages_asset.shuffle()
if as_query is None:
as_query = obj_or_query is not None and not is_obj_or_query_ref(obj_or_query)
if combine and as_query and obj_or_query is not None:
if api_kwargs is None:
api_kwargs = {}
if docs_kwargs is None:
docs_kwargs = {}
if messages_kwargs is None:
messages_kwargs = {}
if examples_kwargs is None:
examples_kwargs = {}
else:
if as_query:
api_kwargs = merge_dicts(kwargs, api_kwargs)
else:
if api_kwargs is None:
api_kwargs = {}
docs_kwargs = merge_dicts(kwargs, docs_kwargs)
messages_kwargs = merge_dicts(kwargs, messages_kwargs)
examples_kwargs = merge_dicts(kwargs, examples_kwargs)
asset_dict = {}
all_asset_names = ["api", "docs", "messages", "examples"]
if asset_names is not None:
if isinstance(asset_names, str) and asset_names.lower() == "all":
asset_names = all_asset_names
else:
if isinstance(asset_names, (str, type(Ellipsis))):
asset_names = [asset_names]
asset_keys = []
for asset_name in asset_names:
if asset_name is not Ellipsis:
if asset_name.lower() not in all_asset_names:
raise ValueError(f"Invalid asset name: '{asset_name}'")
asset_key = all_asset_names.index(asset_name.lower())
asset_keys.append(asset_key)
else:
asset_keys.append(Ellipsis)
new_asset_names = reorder_list(all_asset_names, asset_keys, skip_missing=True)
if "examples" not in asset_names and "examples" in new_asset_names:
new_asset_names.remove("examples")
asset_names = new_asset_names
else:
asset_names = ["api", "docs", "messages"]
for asset_name in asset_names:
if asset_name == "api":
asset = find_api(
None if combine and as_query else obj_or_query,
as_query=as_query,
attr=attr,
module=module,
resolve=resolve,
pages_asset=pages_asset,
aggregate=False,
**api_kwargs,
)
if len(asset) > 0:
asset_dict[asset_name] = asset
elif asset_name == "docs":
asset = find_docs(
None if combine and as_query else obj_or_query,
as_query=as_query,
attr=attr,
module=module,
resolve=resolve,
pages_asset=pages_asset,
aggregate=False,
**docs_kwargs,
)
if len(asset) > 0:
asset_dict[asset_name] = asset
elif asset_name == "messages":
asset = find_messages(
None if combine and as_query else obj_or_query,
as_query=as_query,
attr=attr,
module=module,
resolve=resolve,
messages_asset=messages_asset,
aggregate=False,
latest_first=False,
**messages_kwargs,
)
if len(asset) > 0:
asset_dict[asset_name] = asset
elif asset_name == "examples":
if examples_kwargs is None:
examples_kwargs = {}
asset = find_examples(
None if combine and as_query else obj_or_query,
as_query=as_query,
attr=attr,
module=module,
resolve=resolve,
pages_asset=pages_asset,
messages_asset=messages_asset,
aggregate_messages=False,
aggregate_pages=False,
latest_messages_first=False,
**examples_kwargs,
)
if len(asset) > 0:
asset_dict[asset_name] = asset
if minimize is None:
minimize = combine and not as_query
if minimize:
if minimize_kwargs is None:
minimize_kwargs = {}
for k, v in asset_dict.items():
if (
isinstance(v, VBTAsset)
and not isinstance(v, (PagesAsset, MessagesAsset))
and len(v) > 0
and not isinstance(v[0], str)
):
asset_dict[k] = v.minimize(**minimize_kwargs)
if minimize_pages is None:
minimize_pages = minimize
if minimize_pages:
minimize_pages_kwargs = merge_dicts(minimize_kwargs, minimize_pages_kwargs)
for k, v in asset_dict.items():
if isinstance(v, PagesAsset) and len(v) > 0 and not isinstance(v[0], str):
asset_dict[k] = v.minimize(**minimize_pages_kwargs)
if minimize_messages is None:
minimize_messages = minimize
if minimize_messages:
minimize_messages_kwargs = merge_dicts(minimize_kwargs, minimize_messages_kwargs)
for k, v in asset_dict.items():
if isinstance(v, MessagesAsset) and len(v) > 0 and not isinstance(v[0], str):
asset_dict[k] = v.minimize(**minimize_messages_kwargs)
if combine:
if len(asset_dict) >= 2:
if combine_kwargs is None:
combine_kwargs = {}
combined_asset = VBTAsset.combine(*asset_dict.values(), **combine_kwargs)
elif len(asset_dict) == 1:
combined_asset = list(asset_dict.values())[0]
else:
combined_asset = VBTAsset()
if combined_asset and as_query and obj_or_query is not None:
combined_asset = combined_asset.rank(obj_or_query, **kwargs)
return combined_asset
return asset_dict
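# A usage sketch of asset selection and combination (object and query strings are illustrative):
# >>> find_assets("Portfolio.from_signals", asset_names=["messages", ...])  # messages first
# >>> find_assets("How to compute expectancy?", asset_names="all")          # ranked as one asset
# >>> find_assets(None, combine=False)   # dict of pulled assets keyed by asset name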
def chat_about(
obj: tp.MaybeList,
message: str,
chat_history: tp.ChatHistory = None,
*,
asset_names: tp.Optional[tp.MaybeIterable[str]] = "examples",
latest_messages_first: bool = True,
shuffle_messages: tp.Optional[bool] = None,
shuffle: tp.Optional[bool] = None,
find_assets_kwargs: tp.KwargsLike = None,
**kwargs,
) -> tp.MaybeChatOutput:
"""Chat about object(s).
By default, uses examples only.
Uses `find_assets` with `combine=True` and `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.chat`.
Keyword arguments are distributed among these two methods automatically, unless some keys cannot be
found in both signatures. In such a case, the key will be used for chatting. If this is not wanted,
specify the `find_assets`-related arguments explicitly with `find_assets_kwargs`.
If `shuffle` is True, shuffles the combined asset. By default, shuffles only messages (`shuffle=False`
and `shuffle_messages=True`). If `shuffle` is False, shuffles neither messages nor combined asset."""
if shuffle is not None:
if shuffle_messages is None:
shuffle_messages = False
else:
shuffle = False
if shuffle_messages is None:
shuffle_messages = True
find_arg_names = set(get_func_arg_names(find_assets))
if find_assets_kwargs is None:
find_assets_kwargs = {}
else:
find_assets_kwargs = dict(find_assets_kwargs)
chat_kwargs = {}
for k, v in kwargs.items():
if k in find_arg_names:
if k not in find_assets_kwargs:
find_assets_kwargs[k] = v
else:
chat_kwargs[k] = v
asset = find_assets(
obj,
as_query=False,
asset_names=asset_names,
combine=True,
latest_messages_first=latest_messages_first,
shuffle_messages=shuffle_messages,
**find_assets_kwargs,
)
if shuffle:
asset = asset.shuffle()
return asset.chat(message, chat_history, **chat_kwargs)
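# A usage sketch (the question is made up; keyword arguments are split between `find_assets`
# and `chat` exactly as implemented above):
# >>> output = chat_about("Portfolio.from_signals", "How do I set a trailing stop?")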
def search(
query: str,
cache_documents: bool = True,
cache_key: tp.Optional[str] = None,
asset_cache_manager: tp.Optional[tp.MaybeType[AssetCacheManager]] = None,
asset_cache_manager_kwargs: tp.KwargsLike = None,
aggregate_messages: tp.Union[bool, str] = "threads",
aggregate_messages_kwargs: tp.KwargsLike = None,
find_assets_kwargs: tp.KwargsLike = None,
display: tp.Union[bool, int] = 20,
display_kwargs: tp.KwargsLike = None,
silence_warnings: bool = False,
**kwargs,
) -> tp.Union[tp.MaybeVBTAsset, tp.Path]:
"""Search for a query.
By default, uses API, documentation, and messages.
Uses `find_assets` with `combine=True` and `vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.rank`.
Keyword arguments are distributed among these two methods automatically, unless some keys cannot be
found in both signatures. In such a case, the key will be used for ranking. If this is not wanted,
specify the `find_assets`-related arguments explicitly with `find_assets_kwargs`.
If `display` is True, displays the top results as static HTML pages with `VBTAsset.display`.
Pass an integer to display n top results. Will return the path to the temporary file.
Metadata when aggregating messages will be minimized by default.
If `cache_documents` is True, will use an asset cache manager to store the generated text documents
in a local and/or disk cache after conversion. Running the same method again will use the cached documents."""
find_arg_names = set(get_func_arg_names(find_assets))
if find_assets_kwargs is None:
find_assets_kwargs = {}
else:
find_assets_kwargs = dict(find_assets_kwargs)
rank_kwargs = {}
for k, v in kwargs.items():
if k in find_arg_names:
if k not in find_assets_kwargs:
find_assets_kwargs[k] = v
else:
rank_kwargs[k] = v
find_assets_kwargs["aggregate_messages"] = aggregate_messages
if aggregate_messages_kwargs is None:
aggregate_messages_kwargs = {}
else:
aggregate_messages_kwargs = dict(aggregate_messages_kwargs)
if "minimize_metadata" not in aggregate_messages_kwargs:
aggregate_messages_kwargs["minimize_metadata"] = True
find_assets_kwargs["aggregate_messages_kwargs"] = aggregate_messages_kwargs
if cache_documents:
if asset_cache_manager is None:
asset_cache_manager = AssetCacheManager
if asset_cache_manager_kwargs is None:
asset_cache_manager_kwargs = {}
if isinstance(asset_cache_manager, type):
checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs)
else:
checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
if asset_cache_manager_kwargs:
asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs)
asset_cache_manager_kwargs = {}
if cache_key is None:
cache_key = asset_cache_manager.generate_cache_key(**find_assets_kwargs)
asset = asset_cache_manager.load_asset(cache_key)
if asset is None:
if not silence_warnings:
warn("Caching documents...")
silence_warnings = True
else:
asset = None
if asset is None:
asset = find_assets(None, as_query=True, **find_assets_kwargs)
found_asset = asset.rank(
query,
cache_documents=cache_documents,
cache_key=cache_key,
asset_cache_manager=asset_cache_manager,
asset_cache_manager_kwargs=asset_cache_manager_kwargs,
silence_warnings=silence_warnings,
**rank_kwargs,
)
if display:
if display_kwargs is None:
display_kwargs = {}
else:
display_kwargs = dict(display_kwargs)
if "title" not in display_kwargs:
display_kwargs["title"] = query
if isinstance(display, bool):
display_asset = found_asset
else:
display_asset = found_asset[:display]
return display_asset.display(**display_kwargs)
return found_asset
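# A usage sketch (the query is made up). With `display` enabled (the default), the top results
# are rendered to a temporary HTML file and its path is returned:
# >>> path = search("How to compute the Sharpe ratio per column?")
# >>> found = search("walk-forward optimization", display=False)  # return the ranked asset instead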
def chat(
query: str,
chat_history: tp.ChatHistory = None,
*,
cache_documents: bool = True,
cache_key: tp.Optional[str] = None,
asset_cache_manager: tp.Optional[tp.MaybeType[AssetCacheManager]] = None,
asset_cache_manager_kwargs: tp.KwargsLike = None,
aggregate_messages: tp.Union[bool, str] = "threads",
aggregate_messages_kwargs: tp.KwargsLike = None,
find_assets_kwargs: tp.KwargsLike = None,
rank: tp.Optional[bool] = True,
top_k: tp.TopKLike = "elbow",
min_top_k: tp.TopKLike = 20,
max_top_k: tp.TopKLike = 100,
cutoff: tp.Optional[float] = None,
return_chunks: tp.Optional[bool] = True,
rank_kwargs: tp.KwargsLike = None,
wrap_documents: tp.Optional[bool] = True,
silence_warnings: bool = False,
**kwargs,
) -> tp.MaybeChatOutput:
"""Chat about a query.
By default, uses API, documentation, and messages.
Uses `find_assets` with `obj_or_query=None`, `as_query=True`, and `combine=True`, and
`vectorbtpro.utils.knowledge.base_assets.KnowledgeAsset.chat`. Keyword arguments are distributed
among these two methods automatically, unless some keys cannot be found in both signatures.
In such a case, the key will be used for chatting. If this is not wanted, specify the `find_assets`-
related arguments explicitly with `find_assets_kwargs`.
Metadata when aggregating messages will be minimized by default.
If `cache_documents` is True, will use an asset cache manager to store the generated text documents
in a local and/or disk cache after conversion. Running the same method again will use the cached documents."""
find_arg_names = set(get_func_arg_names(find_assets))
if find_assets_kwargs is None:
find_assets_kwargs = {}
else:
find_assets_kwargs = dict(find_assets_kwargs)
chat_kwargs = {}
for k, v in kwargs.items():
if k in find_arg_names:
if k not in find_assets_kwargs:
find_assets_kwargs[k] = v
else:
chat_kwargs[k] = v
find_assets_kwargs["aggregate_messages"] = aggregate_messages
if aggregate_messages_kwargs is None:
aggregate_messages_kwargs = {}
else:
aggregate_messages_kwargs = dict(aggregate_messages_kwargs)
if "minimize_metadata" not in aggregate_messages_kwargs:
aggregate_messages_kwargs["minimize_metadata"] = True
find_assets_kwargs["aggregate_messages_kwargs"] = aggregate_messages_kwargs
if cache_documents:
if asset_cache_manager is None:
asset_cache_manager = AssetCacheManager
if asset_cache_manager_kwargs is None:
asset_cache_manager_kwargs = {}
if isinstance(asset_cache_manager, type):
checks.assert_subclass_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
asset_cache_manager = asset_cache_manager(**asset_cache_manager_kwargs)
else:
checks.assert_instance_of(asset_cache_manager, AssetCacheManager, "asset_cache_manager")
if asset_cache_manager_kwargs:
asset_cache_manager = asset_cache_manager.replace(**asset_cache_manager_kwargs)
asset_cache_manager_kwargs = {}
if cache_key is None:
cache_key = asset_cache_manager.generate_cache_key(**find_assets_kwargs)
asset = asset_cache_manager.load_asset(cache_key)
if asset is None:
if not silence_warnings:
warn("Caching documents...")
silence_warnings = True
else:
asset = None
if asset is None:
asset = find_assets(None, as_query=True, **find_assets_kwargs)
if rank_kwargs is None:
rank_kwargs = {}
else:
rank_kwargs = dict(rank_kwargs)
rank_kwargs["cache_documents"] = cache_documents
rank_kwargs["cache_key"] = cache_key
rank_kwargs["asset_cache_manager"] = asset_cache_manager
rank_kwargs["asset_cache_manager_kwargs"] = asset_cache_manager_kwargs
rank_kwargs["silence_warnings"] = silence_warnings
if "wrap_documents" not in rank_kwargs:
rank_kwargs["wrap_documents"] = wrap_documents
return asset.chat(
query,
chat_history,
rank=rank,
top_k=top_k,
min_top_k=min_top_k,
max_top_k=max_top_k,
cutoff=cutoff,
return_chunks=return_chunks,
rank_kwargs=rank_kwargs,
**chat_kwargs,
)
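# A usage sketch (the queries are made up; this assumes the chat history is a mutable list that
# gets appended to across calls):
# >>> chat_history = []
# >>> chat("What does save_returns do in from_signals?", chat_history)
# >>> chat("And how does it affect memory usage?", chat_history)  # follow-up reuses the history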
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Classes for content formatting.
See `vectorbtpro.utils.knowledge` for the toy dataset."""
import re
import inspect
import time
import sys
from pathlib import Path
from vectorbtpro import _typing as tp
from vectorbtpro.utils import checks
from vectorbtpro.utils.config import Configured, flat_merge_dicts
from vectorbtpro.utils.module_ import get_caller_qualname
from vectorbtpro.utils.path_ import check_mkdir
from vectorbtpro.utils.template import CustomTemplate, SafeSub, RepFunc
try:
if not tp.TYPE_CHECKING:
raise ImportError
from IPython.display import DisplayHandle as DisplayHandleT
except ImportError:
DisplayHandleT = "DisplayHandle"
__all__ = [
"ContentFormatter",
"PlainFormatter",
"IPythonFormatter",
"IPythonMarkdownFormatter",
"IPythonHTMLFormatter",
"HTMLFileFormatter",
]
class ToMarkdown(Configured):
"""Class to convert text to Markdown."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.formatting"]
def __init__(
self,
remove_code_title: tp.Optional[bool] = None,
even_indentation: tp.Optional[bool] = None,
newline_before_list: tp.Optional[bool] = None,
**kwargs,
) -> None:
Configured.__init__(
self,
remove_code_title=remove_code_title,
even_indentation=even_indentation,
newline_before_list=newline_before_list,
**kwargs,
)
remove_code_title = self.resolve_setting(remove_code_title, "remove_code_title")
even_indentation = self.resolve_setting(even_indentation, "even_indentation")
newline_before_list = self.resolve_setting(newline_before_list, "newline_before_list")
self._remove_code_title = remove_code_title
self._even_indentation = even_indentation
self._newline_before_list = newline_before_list
@property
def remove_code_title(self) -> bool:
"""Whether to remove `title` attribute from a code block and puts it above it."""
return self._remove_code_title
@property
def newline_before_list(self) -> bool:
"""Whether to add a new line before a list."""
return self._newline_before_list
@property
def even_indentation(self) -> bool:
"""Whether to make leading spaces even.
For example, 3 leading spaces become 4."""
return self._even_indentation
def to_markdown(self, text: str) -> str:
"""Convert text to Markdown."""
markdown = text
if self.remove_code_title:
def _replace_code_block(match):
language = match.group(1)
title = match.group(2)
code = match.group(3)
if title:
title_md = f"**{title}**\n\n"
else:
title_md = ""
code_md = f"```{language}\n{code}\n```"
return title_md + code_md
code_block_pattern = re.compile(r'```(\w+)\s+title="([^"]*)"\s*\n(.*?)\n```', re.DOTALL)
markdown = code_block_pattern.sub(_replace_code_block, markdown)
if self.even_indentation:
leading_spaces_pattern = re.compile(r"^( +)(?=\S|$|\n)")
fixed_lines = []
for line in markdown.splitlines(keepends=True):
match = leading_spaces_pattern.match(line)
if match and len(match.group(0)) % 2 != 0:
line = " " + line
fixed_lines.append(line)
markdown = "".join(fixed_lines)
if self.newline_before_list:
markdown = re.sub(r"(?<=[^\n])\n(?=[ \t]*(?:[*+-]\s|\d+\.\s))", "\n\n", markdown)
return markdown
def to_markdown(text: str, **kwargs) -> str:
"""Convert text to Markdown using `ToMarkdown`."""
return ToMarkdown(**kwargs).to_markdown(text)
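# For example, with `remove_code_title` enabled, the title attribute is lifted out of the code
# fence and placed above it in bold (output shown under that assumption; actual defaults come
# from the `knowledge.formatting` settings):
# >>> to_markdown('```python title="example.py"\nprint("hi")\n```')
# '**example.py**\n\n```python\nprint("hi")\n```'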
class ToHTML(Configured):
"""Class to convert Markdown to HTML."""
_expected_keys_mode: tp.ExpectedKeysMode = "disable"
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.formatting"]
def __init__(
self,
resolve_extensions: tp.Optional[bool] = None,
make_links: tp.Optional[bool] = None,
**markdown_kwargs,
) -> None:
Configured.__init__(
self,
resolve_extensions=resolve_extensions,
make_links=make_links,
**markdown_kwargs,
)
resolve_extensions = self.resolve_setting(resolve_extensions, "resolve_extensions")
make_links = self.resolve_setting(make_links, "make_links")
markdown_kwargs = self.resolve_setting(markdown_kwargs, "markdown_kwargs", merge=True)
self._resolve_extensions = resolve_extensions
self._make_links = make_links
self._markdown_kwargs = markdown_kwargs
@property
def resolve_extensions(self) -> bool:
"""Whether to resolve Markdown extensions.
Uses `pymdownx` extensions over native extensions if installed."""
return self._resolve_extensions
@property
def make_links(self) -> bool:
"""Whether to detect raw URLs in HTML text (`p` and `span` elements only) and convert them to links."""
return self._make_links
@property
def markdown_kwargs(self) -> tp.Kwargs:
"""Keyword arguments passed to `markdown.markdown`."""
return self._markdown_kwargs
def to_html(self, markdown: str) -> str:
"""Convert Markdown to HTML."""
from vectorbtpro.utils.module_ import assert_can_import
assert_can_import("markdown")
import markdown as md
markdown_kwargs = dict(self.markdown_kwargs)
extensions = markdown_kwargs.pop("extensions", [])
if self.resolve_extensions:
from vectorbtpro.utils.module_ import check_installed
filtered_extensions = [
ext for ext in extensions if "." not in ext or check_installed(ext.partition(".")[0])
]
ext_set = set(filtered_extensions)
remove_fenced_code = "fenced_code" in ext_set and "pymdownx.superfences" in ext_set
remove_codehilite = "codehilite" in ext_set and "pymdownx.highlight" in ext_set
if remove_fenced_code or remove_codehilite:
filtered_extensions = [
ext
for ext in filtered_extensions
if not (
(ext == "fenced_code" and remove_fenced_code) or (ext == "codehilite" and remove_codehilite)
)
]
extensions = filtered_extensions
html = md.markdown(markdown, extensions=extensions, **markdown_kwargs)
if self.make_links:
tag_pattern = re.compile(r"<(p|span)(\s[^>]*)?>(.*?)\1>", re.DOTALL | re.IGNORECASE)
url_pattern = re.compile(r'(https?://[^\s<>"\'`]+?)(?=[.,;:!?)\]]*(?:\s|$))', re.IGNORECASE)
def _replace_urls(match, _url_pattern=url_pattern):
tag = match.group(1)
attributes = match.group(2) if match.group(2) else ""
content = match.group(3)
parts = re.split(r"(]*>.*?)", content, flags=re.DOTALL | re.IGNORECASE)
for i, part in enumerate(parts):
if not re.match(r"]*>.*?", part, re.DOTALL | re.IGNORECASE):
part = _url_pattern.sub(r'\1', part)
parts[i] = part
new_content = "".join(parts)
return f"<{tag}{attributes}>{new_content}{tag}>"
html = tag_pattern.sub(_replace_urls, html)
return html.strip()
def to_html(text: str, **kwargs) -> str:
"""Convert Markdown to HTML using `ToHTML`."""
return ToHTML(**kwargs).to_html(text)
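# A small sketch (the exact HTML depends on the configured Markdown extensions, which come from
# the `knowledge.formatting` settings):
# >>> html = to_html("# Title\n\nVisit https://vectorbt.pro for details.")
# With `make_links` enabled, the raw URL inside the resulting <p> element is additionally
# wrapped in an <a href="..."> tag.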
class FormatHTML(Configured):
"""Class to format HTML.
If `use_pygments` is True, uses Pygments package for code highlighting. Arguments in
`pygments_kwargs` are then passed to `pygments.formatters.HtmlFormatter`.
Use `style_extras` to inject additional CSS rules outside the predefined ones.
Use `head_extras` to inject additional HTML elements into the `<head>` section, such as meta tags,
links to external stylesheets, or scripts. Use `body_extras` to inject JavaScript files or inline
scripts at the end of the `<body>`. All of these arguments can be lists.
The HTML template can use all the arguments except those related to Pygments.
It can be either a custom template, or a string or function that will be converted into one."""
_settings_path: tp.SettingsPath = ["knowledge", "knowledge.formatting"]
def __init__(
self,
html_template: tp.Optional[str] = None,
style_extras: tp.Optional[tp.MaybeList[str]] = None,
head_extras: tp.Optional[tp.MaybeList[str]] = None,
body_extras: tp.Optional[tp.MaybeList[str]] = None,
invert_colors: tp.Optional[bool] = None,
invert_colors_style: tp.Optional[str] = None,
auto_scroll: tp.Optional[bool] = None,
auto_scroll_body: tp.Optional[str] = None,
show_spinner: tp.Optional[bool] = None,
spinner_style: tp.Optional[str] = None,
spinner_body: tp.Optional[str] = None,
use_pygments: tp.Optional[bool] = None,
pygments_kwargs: tp.KwargsLike = None,
template_context: tp.KwargsLike = None,
**kwargs,
) -> None:
from vectorbtpro.utils.module_ import check_installed, assert_can_import
Configured.__init__(
self,
html_template=html_template,
style_extras=style_extras,
head_extras=head_extras,
body_extras=body_extras,
invert_colors=invert_colors,
invert_colors_style=invert_colors_style,
auto_scroll=auto_scroll,
auto_scroll_body=auto_scroll_body,
show_spinner=show_spinner,
spinner_style=spinner_style,
spinner_body=spinner_body,
use_pygments=use_pygments,
pygments_kwargs=pygments_kwargs,
template_context=template_context,
**kwargs,
)
html_template = self.resolve_setting(html_template, "html_template")
invert_colors = self.resolve_setting(invert_colors, "invert_colors")
invert_colors_style = self.resolve_setting(invert_colors_style, "invert_colors_style")
auto_scroll = self.resolve_setting(auto_scroll, "auto_scroll")
auto_scroll_body = self.resolve_setting(auto_scroll_body, "auto_scroll_body")
show_spinner = self.resolve_setting(show_spinner, "show_spinner")
spinner_style = self.resolve_setting(spinner_style, "spinner_style")
spinner_body = self.resolve_setting(spinner_body, "spinner_body")
use_pygments = self.resolve_setting(use_pygments, "use_pygments")
pygments_kwargs = self.resolve_setting(pygments_kwargs, "pygments_kwargs", merge=True)
template_context = self.resolve_setting(template_context, "template_context", merge=True)
def _prepare_extras(extras):
if extras is None:
extras = []
if isinstance(extras, str):
extras = [extras]
if not isinstance(extras, list):
extras = list(extras)
return "\n".join(extras)
if isinstance(html_template, str):
html_template = SafeSub(html_template)
elif checks.is_function(html_template):
html_template = RepFunc(html_template)
elif not isinstance(html_template, CustomTemplate):
raise TypeError(f"HTML template must be a string, function, or template")
style_extras = _prepare_extras(self.get_setting("style_extras")) + _prepare_extras(style_extras)
head_extras = _prepare_extras(self.get_setting("head_extras")) + _prepare_extras(head_extras)
body_extras = _prepare_extras(self.get_setting("body_extras")) + _prepare_extras(body_extras)
if invert_colors:
style_extras = "\n".join([style_extras, invert_colors_style])
if auto_scroll:
body_extras = "\n".join([body_extras, auto_scroll_body])
if show_spinner:
style_extras = "\n".join([style_extras, spinner_style])
body_extras = "\n".join([body_extras, spinner_body])
if use_pygments is None:
use_pygments = check_installed("pygments")
if use_pygments:
assert_can_import("pygments")
from pygments.formatters import HtmlFormatter
formatter = HtmlFormatter(**pygments_kwargs)
highlight_css = formatter.get_style_defs(".highlight")
if style_extras == "":
style_extras = highlight_css
else:
style_extras = highlight_css + "\n" + style_extras
self._html_template = html_template
self._style_extras = style_extras
self._head_extras = head_extras
self._body_extras = body_extras
self._template_context = template_context
@property
def html_template(self) -> CustomTemplate:
"""HTML template."""
return self._html_template
@property
def style_extras(self) -> str:
"""Extras for `
$head_extras
$html_metadata
$html_content
$body_extras
""",
root_style_extras=[],
style_extras=[],
head_extras=[],
body_extras=[
r"""""",
r"""""",
r"""""",
r"""""",
r"""""",
],
invert_colors=False,
invert_colors_style=""":root {
filter: invert(100%);
}""",
auto_scroll=False,
auto_scroll_body="""""",
show_spinner=False,
spinner_style=""".loader {
width: 300px;
height: 5px;
margin: 0 auto;
display: block;
position: relative;
overflow: hidden;
}
.loader::after {
content: '';
width: 300px;
height: 5px;
background: blue;
position: absolute;
top: 0;
left: 0;
box-sizing: border-box;
animation: animloader 1s ease-in-out infinite;
}
@keyframes animloader {
0%, 5% {
left: 0;
transform: translateX(-100%);
}
95%, 100% {
left: 100%;
transform: translateX(0%);
}
}
""",
spinner_body="""""",
output_to=None,
flush_output=True,
buffer_output=True,
close_output=None,
update_interval=None,
minimal_format=False,
formatter="ipython_auto",
formatter_config=flex_cfg(),
formatter_configs=flex_cfg(
plain=flex_cfg(),
ipython=flex_cfg(),
ipython_markdown=flex_cfg(),
ipython_html=flex_cfg(),
html=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'html'"),
mkdir_kwargs=flex_cfg(),
temp_files=False,
refresh_page=True,
file_prefix_len=20,
file_suffix_len=6,
auto_scroll=True,
show_spinner=True,
),
),
),
chat=flex_cfg(
chat_dir=RepEval("Path(cache_dir) / 'chat'"),
stream=True,
to_context_kwargs=flex_cfg(),
incl_past_queries=True,
rank=None,
rank_kwargs=flex_cfg(
top_k=None,
min_top_k=None,
max_top_k=None,
cutoff=None,
return_chunks=False,
),
max_tokens=120_000,
system_prompt=r"You are a helpful assistant. Given the context information and not prior knowledge, answer the query.",
system_as_user=True,
context_prompt=r"""Context information is below.
---------------------
$context
---------------------""",
minimal_format=True,
tokenizer="tiktoken",
tokenizer_config=flex_cfg(),
tokenizer_configs=flex_cfg(
tiktoken=flex_cfg(
encoding="model_or_o200k_base",
model=None,
tokens_per_message=3,
tokens_per_name=1,
),
),
embeddings="auto",
embeddings_config=flex_cfg(
batch_size=512,
),
embeddings_configs=flex_cfg(
openai=flex_cfg(
model="text-embedding-3-large",
dimensions=256,
),
litellm=flex_cfg(
model="text-embedding-3-large",
dimensions=256,
),
llama_index=flex_cfg(
embedding="openai",
embedding_configs=flex_cfg(
openai=flex_cfg(
model="text-embedding-3-large",
dimensions=256,
)
),
),
),
completions="auto",
completions_config=flex_cfg(),
completions_configs=flex_cfg(
openai=flex_cfg(
model="gpt-4o",
),
litellm=flex_cfg(
model="gpt-4o",
),
llama_index=flex_cfg(
llm="openai",
llm_configs=flex_cfg(
openai=flex_cfg(
model="gpt-4o",
)
),
),
),
text_splitter="segment",
text_splitter_config=flex_cfg(
chunk_template=r"""... (previous text omitted)
$chunk_text""",
),
text_splitter_configs=flex_cfg(
token=flex_cfg(
chunk_size=800,
chunk_overlap=400,
tokenizer="tiktoken",
tokenizer_kwargs=flex_cfg(
encoding="cl100k_base",
),
),
segment=flex_cfg(
separators=[[r"\n\s*\n", r"(?<=[^\s.?!])[.?!]+(?:\s+|$)"], r"\s+", None],
min_chunk_size=0.8,
fixed_overlap=False,
),
llama_index=flex_cfg(
node_parser="sentence",
node_parser_configs=flex_cfg(),
),
),
obj_store="memory",
obj_store_config=flex_cfg(
store_id="default",
purge_on_open=False,
),
obj_store_configs=flex_cfg(
memory=flex_cfg(),
file=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'file_store'"),
compression=None,
save_kwargs=flex_cfg(
mkdir_kwargs=flex_cfg(),
),
load_kwargs=flex_cfg(),
use_patching=True,
consolidate=False,
mirror=True,
),
lmdb=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'lmdb_store'"),
mkdir_kwargs=flex_cfg(),
dumps_kwargs=flex_cfg(),
loads_kwargs=flex_cfg(),
mirror=True,
flag="c",
),
cached=flex_cfg(
lazy_open=True,
mirror=False,
),
),
doc_ranker_config=flex_cfg(
dataset_id=None,
cache_doc_store=True,
cache_emb_store=True,
doc_store_configs=flex_cfg(
memory=flex_cfg(
store_id="doc_default",
),
file=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'doc_file_store'"),
),
lmdb=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'doc_lmdb_store'"),
),
),
emb_store_configs=flex_cfg(
memory=flex_cfg(
store_id="emb_default",
),
file=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'emb_file_store'"),
),
lmdb=flex_cfg(
dir_path=RepEval("Path(cache_dir) / 'emb_lmdb_store'"),
),
),
score_func="cosine",
score_agg_func="mean",
),
),
assets=flex_cfg(
vbt=flex_cfg(
cache_dir="./knowledge/vbt/",
release_dir=RepEval("(Path(cache_dir) / release_name) if release_name else cache_dir"),
assets_dir=RepEval("Path(release_dir) / 'assets'"),
markdown_dir=RepEval("Path(release_dir) / 'markdown'"),
html_dir=RepEval("Path(release_dir) / 'html'"),
release_name=None,
asset_name=None,
repo_owner="polakowo",
repo_name="vectorbt.pro",
token=None,
token_required=False,
use_pygithub=None,
chunk_size=8192,
document_cls=None,
document_kwargs=flex_cfg(
text_path="content",
excl_metadata=RepEval("asset_cls.get_setting('minimize_keys')"),
excl_embed_metadata=True,
split_text_kwargs=flex_cfg(),
),
minimize_metadata=False,
minimize_keys=[
"parent",
"children",
"type",
"icon",
"tags",
"block",
"thread",
"replies",
"mentions",
"reactions",
],
minimize_links=False,
minimize_link_rules=flex_cfg(
{
r"(https://vectorbt\.pro/pvt_[a-zA-Z0-9]+)": "$pvt_site",
r"(https://vectorbt\.pro)": "$pub_site",
r"(https://discord\.com/channels/[0-9]+)": "$discord",
r"(https://github\.com/polakowo/vectorbt\.pro)": "$github",
}
),
root_metadata_key=None,
aggregate_fields=False,
parent_links_only=True,
clean_metadata=True,
clean_metadata_kwargs=flex_cfg(),
dump_metadata_kwargs=flex_cfg(),
incl_base_attr=True,
incl_shortcuts=True,
incl_shortcut_access=True,
incl_shortcut_call=True,
incl_instances=True,
incl_custom=None,
is_custom_regex=False,
as_code=False,
as_regex=True,
allow_prefix=False,
allow_suffix=False,
merge_targets=True,
display=flex_cfg(
html_template=r"""
$title
$head_extras
$body_extras
""",
style_extras=[],
head_extras=[],
body_extras=[],
),
chat=flex_cfg(
chat_dir=RepEval("Path(release_dir) / 'chat'"),
system_prompt=r"""You are a helpful assistant with access to VectorBT PRO (also called VBT or vectorbtpro) documentation and relevant Discord history. Use only this provided context to generate clear, accurate answers. Do not reference the open‑source vectorbt, as VectorBT PRO is a proprietary successor with significant differences.\n\nWhen coding in Python, use:\n```python\nimport vectorbtpro as vbt\n```\n\nIf metadata includes links, reference them to support your answer. Do not include external or fabricated links, and exclude any information not present in the given context.\n\nFor each query, follow this structure:\n1. Optionally restate the question in your own words.\n2. Answer using only the available context.\n3. Include any relevant links.""",
doc_ranker_config=flex_cfg(
doc_store="lmdb",
doc_store_configs=flex_cfg(
file=flex_cfg(
dir_path=RepEval("Path(release_dir) / 'doc_file_store'"),
),
lmdb=flex_cfg(
dir_path=RepEval("Path(release_dir) / 'doc_lmdb_store'"),
),
),
emb_store="lmdb",
emb_store_configs=flex_cfg(
file=flex_cfg(
dir_path=RepEval("Path(release_dir) / 'emb_file_store'"),
),
lmdb=flex_cfg(
dir_path=RepEval("Path(release_dir) / 'emb_lmdb_store'"),
),
),
),
),
),
pages=flex_cfg(
assets_dir=RepEval("Path(release_dir) / 'pages' / 'assets'"),
markdown_dir=RepEval("Path(release_dir) / 'pages' / 'markdown'"),
html_dir=RepEval("Path(release_dir) / 'pages' / 'html'"),
asset_name="pages.json.zip",
token_required=True,
append_obj_type=True,
append_github_link=True,
use_parent=None,
use_base_parents=False,
use_ref_parents=False,
incl_bases=True,
incl_ancestors=True,
incl_base_ancestors=False,
incl_refs=None,
incl_descendants=True,
incl_ancestor_descendants=False,
incl_ref_descendants=False,
aggregate=True,
aggregate_ancestors=False,
aggregate_refs=False,
topo_sort=True,
incl_pages=None,
excl_pages=None,
page_find_mode="substring",
up_aggregate=True,
up_aggregate_th=2 / 3,
up_aggregate_pages=True,
),
messages=flex_cfg(
assets_dir=RepEval("Path(release_dir) / 'messages' / 'assets'"),
markdown_dir=RepEval("Path(release_dir) / 'messages' / 'markdown'"),
html_dir=RepEval("Path(release_dir) / 'messages' / 'html'"),
asset_name="messages.json.zip",
token_required=True,
),
),
)
"""_"""
__pdoc__["knowledge"] = Sub(
"""Sub-config with settings applied across `vectorbtpro.utils.knowledge`.
```python
${config_doc}
```"""
)
_settings["knowledge"] = knowledge
# ############# Settings config ############# #
class SettingsConfig(Config):
"""Extends `vectorbtpro.utils.config.Config` for global settings."""
def __init__(
self,
*args,
**kwargs,
) -> None:
options_ = kwargs.pop("options_", None)
if options_ is None:
options_ = {}
copy_kwargs = options_.pop("copy_kwargs", None)
if copy_kwargs is None:
copy_kwargs = {}
copy_kwargs["copy_mode"] = "deep"
options_["copy_kwargs"] = copy_kwargs
options_["frozen_keys"] = True
options_["as_attrs"] = True
Config.__init__(self, *args, options_=options_, **kwargs)
def register_template(self, theme: str) -> None:
"""Register template of a theme."""
if check_installed("plotly"):
import plotly.io as pio
import plotly.graph_objects as go
template_path = self["plotting"]["themes"][theme]["path"]
if template_path is None:
raise ValueError(f"Must provide template path for the theme '{theme}'")
if template_path.startswith("__name__/"):
template_path = template_path.replace("__name__/", "")
template = Config(json.loads(pkgutil.get_data(__name__, template_path)))
else:
with open(template_path, "r") as f:
template = Config(json.load(f))
pio.templates["vbt_" + theme] = go.layout.Template(template)
def register_templates(self) -> None:
"""Register templates of all themes."""
for theme in self["plotting"]["themes"]:
self.register_template(theme)
def set_theme(self, theme: str) -> None:
"""Set default theme."""
self.register_template(theme)
self["plotting"]["color_schema"].update(self["plotting"]["themes"][theme]["color_schema"])
self["plotting"]["layout"]["template"] = "vbt_" + theme
def reset_theme(self) -> None:
"""Reset to default theme."""
self.set_theme(self["plotting"]["default_theme"])
def substitute_sub_config_docs(self, __pdoc__: dict, prettify_kwargs: tp.KwargsLike = None) -> None:
"""Substitute templates in sub-config docs."""
if prettify_kwargs is None:
prettify_kwargs = {}
for k, v in __pdoc__.items():
if k in self:
config_doc = self[k].prettify(**prettify_kwargs.get(k, {}))
__pdoc__[k] = substitute_templates(
v,
context=dict(config_doc=config_doc),
eval_id="__pdoc__",
)
def get(self, key: tp.PathLikeKey, default: tp.Any = MISSING) -> tp.Any:
"""Get setting(s) under a path.
See `vectorbtpro.utils.search_.get_pathlike_key` for path format."""
from vectorbtpro.utils.search_ import get_pathlike_key
try:
return get_pathlike_key(self, key)
except (KeyError, IndexError, AttributeError) as e:
if default is MISSING:
raise e
return default
def set(self, key: tp.PathLikeKey, value: tp.Any, default_config_type: tp.Type[Config] = flex_cfg) -> None:
"""Set setting(s) under a path.
See `vectorbtpro.utils.search_.get_pathlike_key` for path format."""
from vectorbtpro.utils.search_ import resolve_pathlike_key
tokens = resolve_pathlike_key(key)
obj = self
for i, token in enumerate(tokens):
if isinstance(obj, Config):
if token not in obj:
obj[token] = default_config_type()
if i < len(tokens) - 1:
if isinstance(obj, (set, frozenset)):
obj = list(obj)[token]
elif hasattr(obj, "__getitem__"):
obj = obj[token]
elif isinstance(token, str) and hasattr(obj, token):
obj = getattr(obj, token)
else:
raise TypeError(f"Cannot navigate object of type {type(obj).__name__}")
else:
if hasattr(obj, "__setitem__"):
obj[token] = value
elif hasattr(obj, "__dict__"):
setattr(obj, token, value)
else:
raise TypeError(f"Cannot modify object of type {type(obj).__name__}")
settings = SettingsConfig(_settings)
"""Global settings config.
Combines all sub-configs defined in this module."""
settings_name = os.environ.get("VBT_SETTINGS_NAME", "vbt")
if "VBT_SETTINGS_PATH" in os.environ:
if len(os.environ["VBT_SETTINGS_PATH"]) > 0:
settings.load_update(os.environ["VBT_SETTINGS_PATH"])
elif settings.file_exists(settings_name):
settings.load_update(settings_name)
settings.reset_theme()
settings.register_templates()
settings.make_checkpoint()
settings.substitute_sub_config_docs(__pdoc__)
if settings["numba"]["disable"]:
nb_config.DISABLE_JIT = True
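# Usage sketch (illustrative): the environment variables read above can point vectorbtpro
# at a custom settings file before the package is imported, e.g. in the shell:
#   export VBT_SETTINGS_PATH=/path/to/my_settings   # hypothetical path, loaded via load_update
#   export VBT_SETTINGS_PATH=                       # empty value skips loading any settings file
#   export VBT_SETTINGS_NAME=my_vbt                 # changes the default file name looked up on disk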
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""General types used across vectorbtpro."""
from datetime import datetime, timedelta, tzinfo, date, time
from enum import EnumMeta
from pathlib import Path
import sys
if sys.version_info < (3, 9):
import typing
typing.__all__.append("TextIO")
from typing import *
import numpy as np
import pandas as pd
from mypy_extensions import VarArg
from pandas import Series, DataFrame as Frame, Index
from pandas.core.groupby import GroupBy as PandasGroupBy
from pandas.core.indexing import _IndexSlice as IndexSlice
from pandas.core.resample import Resampler as PandasResampler
from pandas.tseries.offsets import BaseOffset
try:
if not TYPE_CHECKING:
raise ImportError
from plotly.graph_objects import Figure, FigureWidget
from plotly.basedatatypes import BaseFigure, BaseTraceType
except ImportError:
Figure = Any
FigureWidget = Any
BaseFigure = Any
BaseTraceType = Any
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
try:
from typing import Self
except ImportError:
from typing_extensions import Self
if TYPE_CHECKING:
from vectorbtpro.utils.parsing import Regex
from vectorbtpro.utils.execution import Task, ExecutionEngine
from vectorbtpro.utils.chunking import Sizer, NotChunked, ChunkTaker, ChunkMeta, ChunkMetaGenerator
from vectorbtpro.utils.jitting import Jitter
from vectorbtpro.utils.template import CustomTemplate
from vectorbtpro.utils.datetime_ import DTC, DTCNT
from vectorbtpro.utils.selection import PosSel, LabelSel
from vectorbtpro.utils.merging import MergeFunc
from vectorbtpro.utils.knowledge.base_asset_funcs import AssetFunc
from vectorbtpro.utils.knowledge.asset_pipelines import AssetPipeline
from vectorbtpro.utils.knowledge.base_assets import KnowledgeAsset
from vectorbtpro.utils.knowledge.chatting import (
Tokenizer,
Embeddings,
Completions,
TextSplitter,
StoreDocument,
ObjectStore,
EmbeddedDocument,
ScoredDocument,
)
from vectorbtpro.utils.knowledge.custom_assets import VBTAsset, PagesAsset, MessagesAsset
from vectorbtpro.utils.knowledge.formatting import ContentFormatter
from vectorbtpro.base.indexing import hslice
from vectorbtpro.base.grouping.base import Grouper
from vectorbtpro.base.resampling.base import Resampler
from vectorbtpro.generic.splitting.base import FixRange, RelRange
else:
Regex = "Regex"
Task = "Task"
ExecutionEngine = "ExecutionEngine"
Sizer = "Sizer"
NotChunked = "NotChunked"
ChunkTaker = "ChunkTaker"
ChunkMeta = "ChunkMeta"
ChunkMetaGenerator = "ChunkMetaGenerator"
TraceUpdater = "TraceUpdater"
Jitter = "Jitter"
CustomTemplate = "CustomTemplate"
DTC = "DTC"
DTCNT = "DTCNT"
PosSel = "PosSel"
LabelSel = "LabelSel"
MergeFunc = "MergeFunc"
AssetFunc = "AssetFunc"
AssetPipeline = "AssetPipeline"
KnowledgeAsset = "KnowledgeAsset"
Tokenizer = "Tokenizer"
Embeddings = "Embeddings"
Completions = "Completions"
TextSplitter = "TextSplitter"
StoreDocument = "StoreDocument"
ObjectStore = "ObjectStore"
EmbeddedDocument = "EmbeddedDocument"
ScoredDocument = "ScoredDocument"
VBTAsset = "VBTAsset"
PagesAsset = "PagesAsset"
MessagesAsset = "MessagesAsset"
ContentFormatter = "ContentFormatter"
hslice = "hslice"
Grouper = "Grouper"
Resampler = "Resampler"
FixRange = "FixRange"
RelRange = "RelRange"
__all__ = []
# Generic types
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
MaybeType = Union[T, Type[T]]
# Scalars
Scalar = Union[str, float, int, complex, bool, object, np.generic]
Number = Union[int, float, complex, np.number, np.bool_]
Int = Union[int, np.integer]
Float = Union[float, np.floating]
IntFloat = Union[Int, Float]
IntStr = Union[Int, str]
# Basic sequences
MaybeTuple = Union[T, Tuple[T, ...]]
MaybeList = Union[T, List[T]]
MaybeSet = Union[T, Set[T]]
TupleList = Union[List[T], Tuple[T, ...]]
MaybeTupleList = Union[T, List[T], Tuple[T, ...]]
MaybeIterable = Union[T, Iterable[T]]
MaybeSequence = Union[T, Sequence[T]]
MaybeDict = Union[Dict[Hashable, T], T]
MappingSequence = Union[Mapping[Hashable, T], Sequence[T]]
MaybeMappingSequence = Union[T, Mapping[Hashable, T], Sequence[T]]
SetLike = Union[None, Set[T]]
Items = Iterator[Tuple[Hashable, Any]]
# Arrays
class SupportsArrayT(Protocol):
def __array__(self) -> np.ndarray: ...
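# Illustrative only: any object exposing `__array__` satisfies `SupportsArrayT`,
# e.g. a thin wrapper around a Python list:
# >>> class MyValues:
# ...     def __init__(self, values): self.values = values
# ...     def __array__(self) -> np.ndarray: return np.asarray(self.values)
# >>> np.asarray(MyValues([1, 2, 3]))
# array([1, 2, 3])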
DTypeLike = Any
PandasDTypeLike = Any
TypeLike = MaybeIterable[Union[Type, str, Regex]]
Shape = Tuple[int, ...]
ShapeLike = Union[int, Shape]
Array = np.ndarray # ready to be used for n-dim data
Array1d = np.ndarray
Array2d = np.ndarray
Array3d = np.ndarray
Record = np.void
RecordArray = np.ndarray
RecordArray2d = np.ndarray
RecArray = np.recarray
MaybeArray = Union[Scalar, Array]
MaybeIndexArray = Union[int, slice, Array1d, Array2d]
SeriesFrame = Union[Series, Frame]
MaybeSeries = Union[Scalar, Series]
MaybeSeriesFrame = Union[T, Series, Frame]
PandasArray = Union[Index, Series, Frame]
AnyArray = Union[Array, PandasArray]
AnyArray1d = Union[Array1d, Index, Series]
AnyArray2d = Union[Array2d, Frame]
ArrayLike = Union[Scalar, Sequence[Scalar], Sequence[Sequence[Any]], SupportsArrayT, Array]
IndexLike = Union[range, Sequence[Scalar], SupportsArrayT]
FlexArray1d = Array1d
FlexArray2d = Array2d
FlexArray1dLike = Union[Scalar, Array1d, Array2d]
FlexArray2dLike = Union[Scalar, Array1d, Array2d]
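# Illustrative semantics (an assumption; see `vectorbtpro.base.flex_indexing` for the actual
# implementation): a "flexible" array stands in for a full-size array by treating a
# length-one axis as broadcast, roughly:
# >>> def flex_select_1d(arr: np.ndarray, i: int):
# ...     return arr[0] if arr.shape[0] == 1 else arr[i]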
# Templates
CustomTemplateLike = Union[str, Callable, CustomTemplate]
# Labels
Label = Hashable
Labels = Sequence[Label]
Level = Union[str, int]
LevelSequence = Sequence[Level]
MaybeLevelSequence = Union[Level, LevelSequence]
# Datetime
Datetime = Union[pd.Timestamp, np.datetime64, datetime]
DatetimeLike = Union[str, int, float, Datetime]
Timedelta = Union[pd.Timedelta, np.timedelta64, timedelta]
TimedeltaLike = Union[str, int, float, Timedelta]
Frequency = Union[BaseOffset, Timedelta]
FrequencyLike = Union[BaseOffset, TimedeltaLike]
TimezoneLike = Union[None, str, int, float, timedelta, tzinfo]
TimeLike = Union[str, time]
PandasFrequency = Union[BaseOffset, pd.Timedelta]
PandasDatetimeIndex = Union[pd.DatetimeIndex, pd.PeriodIndex]
AnyPandasFrequency = Union[None, int, float, PandasFrequency]
DTCLike = Union[None, str, int, time, date, Datetime, DTC, DTCNT]
# Indexing
Slice = Union[slice, hslice]
PandasIndexingFunc = Callable[[SeriesFrame], MaybeSeriesFrame]
# Grouping
PandasGroupByLike = Union[PandasGroupBy, PandasResampler, FrequencyLike]
GroupByLike = Union[None, bool, MaybeLevelSequence, IndexLike, CustomTemplate]
AnyGroupByLike = Union[Grouper, PandasGroupByLike, GroupByLike]
AnyRuleLike = Union[Resampler, PandasResampler, FrequencyLike, IndexLike]
GroupIdxs = Array1d
GroupLens = Array1d
GroupMap = Tuple[GroupIdxs, GroupLens]
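# Illustrative layout (an assumption based on the naming): a group map packs element
# indices group by group together with per-group counts, e.g. columns [0..4] split into
# groups {a: [0, 2], b: [1, 3, 4]} would be represented as:
# >>> group_idxs = np.array([0, 2, 1, 3, 4])
# >>> group_lens = np.array([2, 3])
# >>> group_map = (group_idxs, group_lens)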
# Wrapping
NameIndex = Union[None, Any, Index]
# Search
PathKeyToken = Hashable
PathKeyTokens = Sequence[Hashable]
PathKey = Tuple[PathKeyToken, ...]
MaybePathKey = Union[None, PathKeyToken, PathKey]
PathLikeKey = Union[MaybePathKey, Path]
PathLikeKeys = Sequence[PathLikeKey]
PathMoveDict = Dict[PathLikeKey, PathLikeKey]
PathRenameDict = Dict[PathLikeKey, PathKeyToken]
PathDict = Dict[PathLikeKey, Any]
# Config
DictLike = Union[None, dict]
DictLikeSequence = MaybeSequence[DictLike]
Args = Tuple[Any, ...]
ArgsLike = Union[None, Args]
Kwargs = Dict[str, Any]
KwargsLike = Union[None, Kwargs]
KwargsLikeSequence = MaybeSequence[KwargsLike]
ArgsKwargs = Tuple[Args, Kwargs]
PathLike = Union[str, Path]
_SettingsPath = Union[None, MaybeList[PathLikeKey], Dict[Hashable, PathLikeKey]]
SettingsPath = ClassVar[_SettingsPath]
ExtSettingsPaths = List[Tuple[type, _SettingsPath]]
SpecSettingsPaths = Dict[PathLikeKey, MaybeList[PathLikeKey]]
WriteableAttrs = ClassVar[Optional[Set[str]]]
ExpectedKeysMode = ClassVar[str]
ExpectedKeys = ClassVar[Optional[Set[str]]]
# Data
Column = Key = Feature = Symbol = Hashable
Columns = Keys = Features = Symbols = Sequence[Hashable]
MaybeColumns = MaybeKeys = MaybeFeatures = MaybeSymbols = Union[Hashable, Sequence[Hashable]]
KeyData = FeatureData = SymbolData = Union[None, SeriesFrame, Tuple[SeriesFrame, Kwargs]]
# Plotting
TraceName = Union[str, None]
TraceNames = MaybeSequence[TraceName]
# Generic
MapFunc = Callable[[Scalar, VarArg()], Scalar]
MapMetaFunc = Callable[[int, int, Scalar, VarArg()], Scalar]
ApplyFunc = Callable[[Array1d, VarArg()], MaybeArray]
ApplyMetaFunc = Callable[[int, VarArg()], MaybeArray]
ReduceFunc = Callable[[Array1d, VarArg()], Scalar]
ReduceMetaFunc = Callable[[int, VarArg()], Scalar]
ReduceToArrayFunc = Callable[[Array1d, VarArg()], Array1d]
ReduceToArrayMetaFunc = Callable[[int, VarArg()], Array1d]
ReduceGroupedFunc = Callable[[Array2d, VarArg()], Scalar]
ReduceGroupedMetaFunc = Callable[[GroupIdxs, int, VarArg()], Scalar]
ReduceGroupedToArrayFunc = Callable[[Array2d, VarArg()], Array1d]
ReduceGroupedToArrayMetaFunc = Callable[[GroupIdxs, int, VarArg()], Array1d]
RangeReduceMetaFunc = Callable[[int, int, int, VarArg()], Scalar]
ProximityReduceMetaFunc = Callable[[int, int, int, int, VarArg()], Scalar]
GroupByReduceMetaFunc = Callable[[GroupIdxs, int, int, VarArg()], Scalar]
GroupSqueezeMetaFunc = Callable[[int, GroupIdxs, int, VarArg()], Scalar]
GroupByTransformFunc = Callable[[Array2d, VarArg()], MaybeArray]
GroupByTransformMetaFunc = Callable[[GroupIdxs, int, VarArg()], MaybeArray]
# Signals
PlaceFunc = Callable[[NamedTuple, VarArg()], int]
RankFunc = Callable[[NamedTuple, VarArg()], int]
# Records
RecordsMapFunc = Callable[[np.void, VarArg()], Scalar]
RecordsMapMetaFunc = Callable[[int, VarArg()], Scalar]
MappedReduceMetaFunc = Callable[[GroupIdxs, int, VarArg()], Scalar]
MappedReduceToArrayMetaFunc = Callable[[GroupIdxs, int, VarArg()], Array1d]
# Indicators
ParamValue = Any
ParamValues = Sequence[ParamValue]
MaybeParamValues = MaybeSequence[ParamValue]
MaybeParams = Sequence[MaybeParamValues]
Params = Sequence[ParamValues]
ParamsOrLens = Sequence[Union[ParamValues, int]]
ParamsOrDict = Union[Params, Dict[Hashable, ParamValues]]
ParamGrid = Union[ParamsOrLens, Dict[Hashable, ParamsOrLens]]
ParamComb = Sequence[ParamValue]
ParamCombOrDict = Union[ParamComb, Dict[Hashable, ParamValue]]
# Mappings
MappingLike = Union[str, Mapping, NamedTuple, EnumMeta, IndexLike]
RecordsLike = Union[SeriesFrame, RecordArray, Sequence[MappingLike]]
# Annotations
Annotation = object
Annotations = Dict[str, Annotation]
# Parsing
AnnArgs = Dict[str, Kwargs]
FlatAnnArgs = Dict[str, Kwargs]
AnnArgQuery = Union[int, str, Regex]
# Execution
FuncArgs = Tuple[Callable, Args, Kwargs]
FuncsArgs = Iterable[FuncArgs]
TaskLike = Union[FuncArgs, Task]
TasksLike = Iterable[TaskLike]
ExecutionEngineLike = Union[str, type, ExecutionEngine, Callable]
ExecResult = Any
ExecResults = List[Any]
# JIT
JittedOption = Union[None, bool, str, Callable, Kwargs]
JitterLike = Union[str, Jitter, Type[Jitter]]
TaskId = Union[Hashable, Callable]
# Merging
MergeFuncLike = MaybeSequence[Union[None, str, Callable, MergeFunc]]
MergeResult = Any
MergeableResults = Union[ExecResults, MergeResult]
# Chunking
SizeFunc = Callable[[AnnArgs], int]
SizeLike = Union[int, str, Sizer, SizeFunc]
ChunkMetaFunc = Callable[[AnnArgs], Iterable[ChunkMeta]]
ChunkMetaLike = Union[Iterable[ChunkMeta], ChunkMetaGenerator, ChunkMetaFunc]
TakeSpec = Union[None, Type[NotChunked], Type[ChunkTaker], NotChunked, ChunkTaker]
ArgTakeSpec = Mapping[AnnArgQuery, TakeSpec]
ArgTakeSpecFunc = Callable[[AnnArgs, ChunkMeta], Tuple[Args, Kwargs]]
ArgTakeSpecLike = Union[Sequence[TakeSpec], ArgTakeSpec, ArgTakeSpecFunc, CustomTemplate]
MappingTakeSpec = Mapping[Hashable, TakeSpec]
SequenceTakeSpec = Sequence[TakeSpec]
ContainerTakeSpec = Union[MappingTakeSpec, SequenceTakeSpec]
ChunkedOption = Union[None, bool, str, Callable, Kwargs]
# Decorators
ClassWrapper = Callable[[Type[T]], Type[T]]
FlexClassWrapper = Union[Callable[[Type[T]], Type[T]], Type[T]]
# Splitting
FixRangeLike = Union[Slice, Sequence[int], Sequence[bool], Callable, CustomTemplate, FixRange]
RelRangeLike = Union[int, float, Callable, CustomTemplate, RelRange]
RangeLike = Union[FixRangeLike, RelRangeLike]
ReadyRangeLike = Union[slice, Array1d]
FixSplit = Sequence[FixRangeLike]
SplitLike = Union[str, int, float, MaybeSequence[RangeLike]]
Splits = Sequence[SplitLike]
SplitsArray = Array2d
SplitsMask = Array3d
BoundsArray = Array3d
# Staticization
StaticizedOption = Union[None, bool, Kwargs, TaskId]
# Selection
Selection = Union[PosSel, LabelSel, MaybeIterable[Union[PosSel, LabelSel, Hashable]]]
# Knowledge
AssetFuncLike = Union[str, Type[AssetFunc], FuncArgs, Task, Callable]
MaybeAsset = Union[None, T, dict, list, Iterator[T]]
MaybeKnowledgeAsset = MaybeAsset[KnowledgeAsset]
MaybeVBTAsset = MaybeAsset[VBTAsset]
MaybePagesAsset = MaybeAsset[PagesAsset]
MaybeMessagesAsset = MaybeAsset[MessagesAsset]
ContentFormatterLike = Union[None, str, MaybeType[ContentFormatter]]
TokenizerLike = Union[None, str, MaybeType[Tokenizer]]
Token = int
Tokens = List[Token]
EmbeddingsLike = Union[None, str, MaybeType[Embeddings]]
CompletionsLike = Union[None, str, MaybeType[Completions]]
ChatMessage = dict
ChatMessages = List[ChatMessage]
ChatHistory = MutableSequence[ChatMessage]
ChatOutput = Union[Optional[Path], Tuple[Optional[Path], Any]]
MaybeChatOutput = Union[ChatOutput, Tuple[ChatOutput, Completions]]
TextSplitterLike = Union[None, str, MaybeType[TextSplitter]]
TSRange = Tuple[int, int]
TSRangeChunks = Iterator[TSRange]
TSSegment = Tuple[int, int, bool]
TSSegmentChunks = Iterator[TSSegment]
TSTextChunks = Iterator[str]
ObjectStoreLike = Union[None, str, MaybeType[ObjectStore]]
EmbeddedDocuments = List[EmbeddedDocument]
ScoredDocuments = List[Union[float, ScoredDocument]]
RankedDocuments = List[Union[StoreDocument, ScoredDocument]]
TopKLike = Union[None, int, float, str, Callable]
# Chaining
PipeFunc = Union[str, Callable, Tuple[Union[str, Callable], str]]
PipeTask = Union[PipeFunc, Tuple[PipeFunc, Args, Kwargs], Task]
PipeTasks = Iterable[PipeTask]
# Pickling
CompressionLike = Union[None, bool, str]
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
try:
try:
import tomllib
except ModuleNotFoundError:
import tomli as tomllib
from pathlib import Path
with open(Path(__file__).resolve().parent.parent / "pyproject.toml", "rb") as f:
pyproject = tomllib.load(f)
__version__ = pyproject["project"]["version"]
except Exception as e:
import importlib.metadata
__version__ = importlib.metadata.version(__package__ or __name__)
__release__ = "v" + __version__
__all__ = [
"__version__",
"__release__",
]
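# Usage sketch (illustrative), assuming these names are re-exported at the package root:
# >>> import vectorbtpro as vbt
# >>> vbt.__version__   # whatever version pyproject.toml (or the package metadata) declares
# >>> vbt.__release__   # the same string prefixed with "v"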
# ==================================== VBTPROXYZ ====================================
# Copyright (c) 2021-2025 Oleg Polakow. All rights reserved.
#
# This file is part of the proprietary VectorBT® PRO package and is licensed under
# the VectorBT® PRO License available at https://vectorbt.pro/terms/software-license/
#
# Unauthorized publishing, distribution, sublicensing, or sale of this software
# or its parts is strictly prohibited.
# ===================================================================================
"""Root Pandas accessors of vectorbtpro.
An accessor adds an additional "namespace" to pandas objects.
The `vectorbtpro.accessors` module registers a custom `vbt` accessor on top of each `pd.Index`, `pd.Series`,
and `pd.DataFrame` object. It is the main entry point for all other accessors:
```plaintext
vbt.base.accessors.BaseIDX/SR/DFAccessor -> pd.Index/Series/DataFrame.vbt.*
vbt.generic.accessors.GenericSR/DFAccessor -> pd.Series/DataFrame.vbt.*
vbt.signals.accessors.SignalsSR/DFAccessor -> pd.Series/DataFrame.vbt.signals.*
vbt.returns.accessors.ReturnsSR/DFAccessor -> pd.Series/DataFrame.vbt.returns.*
vbt.ohlcv.accessors.OHLCVDFAccessor -> pd.DataFrame.vbt.ohlcv.*
vbt.px.accessors.PXSR/DFAccessor -> pd.Series/DataFrame.vbt.px.*
```
Additionally, some accessors subclass other accessors, building the following inheritance hierarchy:
```plaintext
vbt.base.accessors.BaseIDXAccessor
vbt.base.accessors.BaseSR/DFAccessor
-> vbt.generic.accessors.GenericSR/DFAccessor
-> vbt.signals.accessors.SignalsSR/DFAccessor
-> vbt.returns.accessors.ReturnsSR/DFAccessor
-> vbt.ohlcv.accessors.OHLCVDFAccessor
-> vbt.px.accessors.PXSR/DFAccessor
```
So, for example, the method `pd.Series.vbt.to_2d_array` is also available as
`pd.Series.vbt.returns.to_2d_array`.
Class methods of any accessor can be conveniently accessed using `pd_acc`, `sr_acc`, and `df_acc` shortcuts:
```pycon
>>> from vectorbtpro import *
>>> vbt.pd_acc.signals.generate
<bound method SignalsAccessor.generate of <class 'vectorbtpro.signals.accessors.SignalsAccessor'>>
```
!!! note
Accessors in vectorbt are not cached, so querying `df.vbt` twice will also call `Vbt_DFAccessor` twice.
You can change this in global settings."""
import pandas as pd
from pandas.core.accessor import DirNamesMixin
from vectorbtpro import _typing as tp
from vectorbtpro.base.accessors import BaseIDXAccessor
from vectorbtpro.base.wrapping import ArrayWrapper
from vectorbtpro.generic.accessors import GenericAccessor, GenericSRAccessor, GenericDFAccessor
from vectorbtpro.utils.base import Base
from vectorbtpro.utils.warnings_ import warn
__all__ = [
"Vbt_Accessor",
"Vbt_SRAccessor",
"Vbt_DFAccessor",
"idx_acc",
"pd_acc",
"sr_acc",
"df_acc",
]
__pdoc__ = {}
ParentAccessorT = tp.TypeVar("ParentAccessorT", bound=object)
AccessorT = tp.TypeVar("AccessorT", bound=object)
class Accessor(Base):
"""Accessor."""
def __init__(self, name: str, accessor: tp.Type[AccessorT]) -> None:
self._name = name
self._accessor = accessor
def __get__(self, obj: ParentAccessorT, cls: DirNamesMixin) -> AccessorT:
if obj is None:
return self._accessor
if isinstance(obj, (pd.Index, pd.Series, pd.DataFrame)):
accessor_obj = self._accessor(obj)
elif issubclass(self._accessor, type(obj)):
accessor_obj = obj.replace(cls_=self._accessor)
else:
accessor_obj = self._accessor(obj.wrapper, obj=obj._obj)
return accessor_obj
class CachedAccessor(Base):
"""Cached accessor."""
def __init__(self, name: str, accessor: tp.Type[AccessorT]) -> None:
self._name = name
self._accessor = accessor
def __get__(self, obj: ParentAccessorT, cls: DirNamesMixin) -> AccessorT:
if obj is None:
return self._accessor
if isinstance(obj, (pd.Index, pd.Series, pd.DataFrame)):
accessor_obj = self._accessor(obj)
elif issubclass(self._accessor, type(obj)):
accessor_obj = obj.replace(cls_=self._accessor)
else:
accessor_obj = self._accessor(obj.wrapper, obj=obj._obj)
object.__setattr__(obj, self._name, accessor_obj)
return accessor_obj
def register_accessor(name: str, cls: tp.Type[DirNamesMixin]) -> tp.Callable:
"""Register a custom accessor.
`cls` must subclass `pandas.core.accessor.DirNamesMixin`."""
def decorator(accessor: tp.Type[AccessorT]) -> tp.Type[AccessorT]:
from vectorbtpro._settings import settings
caching_cfg = settings["caching"]
if hasattr(cls, name):
warn(
f"registration of accessor {repr(accessor)} under name "
f"{repr(name)} for type {repr(cls)} is overriding a preexisting "
"attribute with the same name."
)
if caching_cfg["use_cached_accessors"]:
setattr(cls, name, CachedAccessor(name, accessor))
else:
setattr(cls, name, Accessor(name, accessor))
cls._accessors.add(name)
return accessor
return decorator
def register_index_accessor(name: str) -> tp.Callable:
"""Decorator to register a custom `pd.Index` accessor."""
return register_accessor(name, pd.Index)
def register_series_accessor(name: str) -> tp.Callable:
"""Decorator to register a custom `pd.Series` accessor."""
return register_accessor(name, pd.Series)
def register_dataframe_accessor(name: str) -> tp.Callable:
"""Decorator to register a custom `pd.DataFrame` accessor."""
return register_accessor(name, pd.DataFrame)
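# Usage sketch (hypothetical accessor, illustrative only): a registered accessor class is
# instantiated with the pandas object it is accessed on (see `Accessor.__get__` above):
# >>> @register_dataframe_accessor("my_stats")
# ... class MyStatsAccessor:
# ...     def __init__(self, obj: pd.DataFrame) -> None:
# ...         self._obj = obj
# ...     def col_means(self) -> pd.Series:
# ...         return self._obj.mean()
# >>> pd.DataFrame({"a": [1, 2, 3]}).my_stats.col_means()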
@register_index_accessor("vbt")
class Vbt_IDXAccessor(DirNamesMixin, BaseIDXAccessor):
"""The main vectorbt accessor for `pd.Index`."""
def __init__(self, obj: tp.Index, **kwargs) -> None:
self._obj = obj
DirNamesMixin.__init__(self)
BaseIDXAccessor.__init__(self, obj, **kwargs)
idx_acc = Vbt_IDXAccessor
"""Shortcut for `Vbt_IDXAccessor`."""
__pdoc__["idx_acc"] = False
class Vbt_Accessor(DirNamesMixin, GenericAccessor):
"""The main vectorbt accessor for `pd.Series` and `pd.DataFrame`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
DirNamesMixin.__init__(self)
GenericAccessor.__init__(self, wrapper, obj=obj, **kwargs)
pd_acc = Vbt_Accessor
"""Shortcut for `Vbt_Accessor`."""
__pdoc__["pd_acc"] = False
@register_series_accessor("vbt")
class Vbt_SRAccessor(DirNamesMixin, GenericSRAccessor):
"""The main vectorbt accessor for `pd.Series`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
DirNamesMixin.__init__(self)
GenericSRAccessor.__init__(self, wrapper, obj=obj, **kwargs)
sr_acc = Vbt_SRAccessor
"""Shortcut for `Vbt_SRAccessor`."""
__pdoc__["sr_acc"] = False
@register_dataframe_accessor("vbt")
class Vbt_DFAccessor(DirNamesMixin, GenericDFAccessor):
"""The main vectorbt accessor for `pd.DataFrame`."""
def __init__(
self,
wrapper: tp.Union[ArrayWrapper, tp.ArrayLike],
obj: tp.Optional[tp.ArrayLike] = None,
**kwargs,
) -> None:
DirNamesMixin.__init__(self)
GenericDFAccessor.__init__(self, wrapper, obj=obj, **kwargs)
df_acc = Vbt_DFAccessor
"""Shortcut for `Vbt_DFAccessor`."""
__pdoc__["df_acc"] = False
def register_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_Accessor) -> tp.Callable:
"""Decorator to register an accessor on top of a parent accessor."""
return register_accessor(name, parent)
def register_idx_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_IDXAccessor) -> tp.Callable:
"""Decorator to register a `pd.Index` accessor on top of a parent accessor."""
return register_accessor(name, parent)
def register_sr_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_SRAccessor) -> tp.Callable:
"""Decorator to register a `pd.Series` accessor on top of a parent accessor."""
return register_accessor(name, parent)
def register_df_vbt_accessor(name: str, parent: tp.Type[DirNamesMixin] = Vbt_DFAccessor) -> tp.Callable:
"""Decorator to register a `pd.DataFrame` accessor on top of a parent accessor."""
return register_accessor(name, parent)
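# Registration of sub-accessors elsewhere in the package follows this pattern (sketch,
# illustrative only; the actual registrations live in the respective accessor modules):
# >>> @register_df_vbt_accessor("signals")
# ... class SignalsDFAccessor(GenericDFAccessor):
# ...     ...
# which makes the sub-accessor reachable as `df.vbt.signals.*`.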