Skip to content

Instantly share code, notes, and snippets.

@gerarldlee
Forked from wrighter/download_bars.py
Created May 22, 2024 07:28
Show Gist options
  • Select an option

  • Save gerarldlee/d412e7293bb7631c8b9bd999b9c08abf to your computer and use it in GitHub Desktop.

Select an option

Save gerarldlee/d412e7293bb7631c8b9bd999b9c08abf to your computer and use it in GitHub Desktop.

Revisions

  1. @wrighter wrighter revised this gist Jun 7, 2021. 1 changed file with 26 additions and 5 deletions.
    31 changes: 26 additions & 5 deletions download_bars.py
    Original file line number Diff line number Diff line change
    @@ -183,15 +183,22 @@ def start(self):
    @iswrapper
    def error(self, req_id: TickerId, error_code: int, error: str):
    super().error(req_id, error_code, error)
    print("Error. Id:", req_id, "Code:", error_code, "Msg:", error)
    if req_id < 0:
    logging.debug("Error. Id: %s Code %s Msg: %s", req_id, error_code, error)
    else:
    logging.error("Error. Id: %s Code %s Msg: %s", req_id, error_code, error)
    # we will always exit on error since data will need to be validated
    self.done = True


    def make_contract(symbol: str, sec_type: str, currency: str, exchange: str) -> Contract:
    def make_contract(symbol: str, sec_type: str, currency: str, exchange: str, localsymbol: str) -> Contract:
    contract = Contract()
    contract.symbol = symbol
    contract.secType = sec_type
    contract.currency = currency
    contract.exchange = exchange
    if localsymbol:
    contract.localSymbol = localsymbol
    return contract


    @@ -272,6 +279,9 @@ def __call__(
    argp.add_argument(
    "-d", "--debug", action="store_true", help="turn on debug logging"
    )
    argp.add_argument(
    "--logfile", help="log to file"
    )
    argp.add_argument(
    "-p", "--port", type=int, default=7496, help="local port for TWS connection"
    )
    @@ -292,6 +302,9 @@ def __call__(
    argp.add_argument(
    "--exchange", type=str, default="SMART", help="exchange for symbols"
    )
    argp.add_argument(
    "--localsymbol", type=str, default="", help="local symbol (for futures)"
    )
    argp.add_argument(
    "--security-type", type=str, default="STK", help="security type for symbols"
    )
    @@ -309,10 +322,18 @@ def __call__(
    )
    args = argp.parse_args()

    logargs = dict(format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
    datefmt='%H:%M:%S')
    if args.debug:
    logging.basicConfig(level=logging.DEBUG)
    logargs['level'] = logging.DEBUG
    else:
    logging.basicConfig(level=logging.INFO)
    logargs['level'] = logging.INFO

    if args.logfile:
    logargs['filemode'] = 'a'
    logargs['filename'] = args.logfile

    logging.basicConfig(**logargs)

    try:
    validate_duration(args.duration)
    @@ -326,7 +347,7 @@ def __call__(
    logging.debug(f"args={args}")
    contracts = []
    for s in args.symbol:
    contract = make_contract(s, args.security_type, args.currency, args.exchange)
    contract = make_contract(s, args.security_type, args.currency, args.exchange, args.localsymbol)
    contracts.append(contract)
    os.makedirs(make_download_path(args, contract), exist_ok=True)
    app = DownloadApp(contracts, args)
  2. @wrighter wrighter created this gist Feb 9, 2020.
    339 changes: 339 additions & 0 deletions download_bars.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,339 @@
    #!/usr/bin/env python

    import os
    import sys
    import argparse
    import logging
    from datetime import datetime, timedelta

    from typing import List, Optional
    from collections import defaultdict
    from dateutil.parser import parse

    import numpy as np
    import pandas as pd

    from ibapi import wrapper
    from ibapi.common import TickerId, BarData
    from ibapi.client import EClient
    from ibapi.contract import Contract
    from ibapi.utils import iswrapper

    ContractList = List[Contract]
    BarDataList = List[BarData]
    OptionalDate = Optional[datetime]


    def make_download_path(args: argparse.Namespace, contract: Contract) -> str:
    """Make path for saving csv files.
    Files to be stored in base_directory/<security_type>/<size>/<symbol>/
    """
    path = os.path.sep.join(
    [
    args.base_directory,
    args.security_type,
    args.size.replace(" ", "_"),
    contract.symbol,
    ]
    )
    return path


    class DownloadApp(EClient, wrapper.EWrapper):
    def __init__(self, contracts: ContractList, args: argparse.Namespace):
    EClient.__init__(self, wrapper=self)
    wrapper.EWrapper.__init__(self)
    self.request_id = 0
    self.started = False
    self.next_valid_order_id = None
    self.contracts = contracts
    self.requests = {}
    self.bar_data = defaultdict(list)
    self.pending_ends = set()
    self.args = args
    self.current = self.args.end_date
    self.duration = self.args.duration
    self.useRTH = 0

    def next_request_id(self, contract: Contract) -> int:
    self.request_id += 1
    self.requests[self.request_id] = contract
    return self.request_id

    def historicalDataRequest(self, contract: Contract) -> None:
    cid = self.next_request_id(contract)
    self.pending_ends.add(cid)

    self.reqHistoricalData(
    cid, # tickerId, used to identify incoming data
    contract,
    self.current.strftime("%Y%m%d 00:00:00"), # always go to midnight
    self.duration, # amount of time to go back
    self.args.size, # bar size
    self.args.data_type, # historical data type
    self.useRTH, # useRTH (regular trading hours)
    1, # format the date in yyyyMMdd HH:mm:ss
    False, # keep up to date after snapshot
    [], # chart options
    )

    def save_data(self, contract: Contract, bars: BarDataList) -> None:
    data = [
    [b.date, b.open, b.high, b.low, b.close, b.volume, b.barCount, b.average]
    for b in bars
    ]
    df = pd.DataFrame(
    data,
    columns=[
    "date",
    "open",
    "high",
    "low",
    "close",
    "volume",
    "barCount",
    "average",
    ],
    )
    if self.daily_files():
    path = "%s.csv" % make_download_path(self.args, contract)
    else:
    # since we fetched data until midnight, store data in
    # date file to which it belongs
    last = (self.current - timedelta(days=1)).strftime("%Y%m%d")
    path = os.path.sep.join(
    [make_download_path(self.args, contract), "%s.csv" % last,]
    )
    df.to_csv(path, index=False)

    def daily_files(self):
    return SIZES.index(self.args.size.split()[1]) >= 5

    @iswrapper
    def headTimestamp(self, reqId: int, headTimestamp: str) -> None:
    contract = self.requests.get(reqId)
    ts = datetime.strptime(headTimestamp, "%Y%m%d %H:%M:%S")
    logging.info("Head Timestamp for %s is %s", contract, ts)
    if ts > self.args.start_date or self.args.max_days:
    logging.warning("Overriding start date, setting to %s", ts)
    self.args.start_date = ts # TODO make this per contract
    if ts > self.args.end_date:
    logging.warning("Data for %s is not available before %s", contract, ts)
    self.done = True
    return
    # if we are getting daily data or longer, we'll grab the entire amount at once
    if self.daily_files():
    days = (self.args.end_date - self.args.start_date).days
    if days < 365:
    self.duration = "%d D" % days
    else:
    self.duration = "%d Y" % np.ceil(days / 365)
    # when getting daily data, look at regular trading hours only
    # to get accurate daily closing prices
    self.useRTH = 1
    # round up current time to midnight for even days
    self.current = self.current.replace(
    hour=0, minute=0, second=0, microsecond=0
    )

    self.historicalDataRequest(contract)

    @iswrapper
    def historicalData(self, reqId: int, bar) -> None:
    self.bar_data[reqId].append(bar)

    @iswrapper
    def historicalDataEnd(self, reqId: int, start: str, end: str) -> None:
    super().historicalDataEnd(reqId, start, end)
    self.pending_ends.remove(reqId)
    if len(self.pending_ends) == 0:
    print(f"All requests for {self.current} complete.")
    for rid, bars in self.bar_data.items():
    self.save_data(self.requests[rid], bars)
    self.current = datetime.strptime(start, "%Y%m%d %H:%M:%S")
    if self.current <= self.args.start_date:
    self.done = True
    else:
    for contract in self.contracts:
    self.historicalDataRequest(contract)

    @iswrapper
    def connectAck(self):
    logging.info("Connected")

    @iswrapper
    def nextValidId(self, order_id: int):
    super().nextValidId(order_id)

    self.next_valid_order_id = order_id
    logging.info(f"nextValidId: {order_id}")
    # we can start now
    self.start()

    def start(self):
    if self.started:
    return

    self.started = True
    for contract in self.contracts:
    self.reqHeadTimeStamp(
    self.next_request_id(contract), contract, self.args.data_type, 0, 1
    )

    @iswrapper
    def error(self, req_id: TickerId, error_code: int, error: str):
    super().error(req_id, error_code, error)
    print("Error. Id:", req_id, "Code:", error_code, "Msg:", error)


    def make_contract(symbol: str, sec_type: str, currency: str, exchange: str) -> Contract:
    contract = Contract()
    contract.symbol = symbol
    contract.secType = sec_type
    contract.currency = currency
    contract.exchange = exchange
    return contract


    class ValidationException(Exception):
    pass


    def _validate_in(value: str, name: str, valid: List[str]) -> None:
    if value not in valid:
    raise ValidationException(f"{value} not a valid {name} unit: {','.join(valid)}")


    def _validate(value: str, name: str, valid: List[str]) -> None:
    tokens = value.split()
    if len(tokens) != 2:
    raise ValidationException("{name} should be in the form <digit> <{name}>")
    _validate_in(tokens[1], name, valid)
    try:
    int(tokens[0])
    except ValueError as ve:
    raise ValidationException(f"{name} dimenion not a valid number: {ve}")


    SIZES = ["secs", "min", "mins", "hour", "hours", "day", "week", "month"]
    DURATIONS = ["S", "D", "W", "M", "Y"]


    def validate_duration(duration: str) -> None:
    _validate(duration, "duration", DURATIONS)


    def validate_size(size: str) -> None:
    _validate(size, "size", SIZES)


    def validate_data_type(data_type: str) -> None:
    _validate_in(
    data_type,
    "data_type",
    [
    "TRADES",
    "MIDPOINT",
    "BID",
    "ASK",
    "BID_ASK",
    "ADJUSTED_LAST",
    "HISTORICAL_VOLATILITY",
    "OPTION_IMPLIED_VOLATILITY",
    "REBATE_RATE",
    "FEE_RATE",
    "YIELD_BID",
    "YIELD_ASK",
    "YIELD_BID_ASK",
    "YIELD_LAST",
    ],
    )


    def main():

    now = datetime.now()

    class DateAction(argparse.Action):
    """Parses date strings."""

    def __call__(
    self,
    parser: argparse.ArgumentParser,
    namespace: argparse.Namespace,
    value: str,
    option_string: str = None,
    ):
    """Parse the date."""
    setattr(namespace, self.dest, parse(value))

    argp = argparse.ArgumentParser()
    argp.add_argument("symbol", nargs="+")
    argp.add_argument(
    "-d", "--debug", action="store_true", help="turn on debug logging"
    )
    argp.add_argument(
    "-p", "--port", type=int, default=7496, help="local port for TWS connection"
    )
    argp.add_argument("--size", type=str, default="1 min", help="bar size")
    argp.add_argument("--duration", type=str, default="1 D", help="bar duration")
    argp.add_argument(
    "-t", "--data-type", type=str, default="TRADES", help="bar data type"
    )
    argp.add_argument(
    "--base-directory",
    type=str,
    default="data",
    help="base directory to write bar files",
    )
    argp.add_argument(
    "--currency", type=str, default="USD", help="currency for symbols"
    )
    argp.add_argument(
    "--exchange", type=str, default="SMART", help="exchange for symbols"
    )
    argp.add_argument(
    "--security-type", type=str, default="STK", help="security type for symbols"
    )
    argp.add_argument(
    "--start-date",
    help="First day for bars",
    default=now - timedelta(days=2),
    action=DateAction,
    )
    argp.add_argument(
    "--end-date", help="Last day for bars", default=now, action=DateAction,
    )
    argp.add_argument(
    "--max-days", help="Set start date to earliest date", action="store_true",
    )
    args = argp.parse_args()

    if args.debug:
    logging.basicConfig(level=logging.DEBUG)
    else:
    logging.basicConfig(level=logging.INFO)

    try:
    validate_duration(args.duration)
    validate_size(args.size)
    args.data_type = args.data_type.upper()
    validate_data_type(args.data_type)
    except ValidationException as ve:
    print(ve)
    sys.exit(1)

    logging.debug(f"args={args}")
    contracts = []
    for s in args.symbol:
    contract = make_contract(s, args.security_type, args.currency, args.exchange)
    contracts.append(contract)
    os.makedirs(make_download_path(args, contract), exist_ok=True)
    app = DownloadApp(contracts, args)
    app.connect("127.0.0.1", args.port, clientId=0)

    app.run()


    if __name__ == "__main__":
    main()