Event driven trading system

This article is inspired by VNPY, which is available under the MIT License, and by the official ByBit Python SDK, PyBit.

Algorithmic trading has inspired many people to implement their strategies. It frees the traders' hands and doesn’t suffer from human emotion.

Here I will introduce an event-driven trading system: img.png

  • Websocket and Rest API As the graph above shows, we have two data sources: the WebSocket API and the REST API. We usually use the WebSocket API for the market data stream, since it is a long-lived connection, and the REST API for sending orders, since each request is a quick one-time connection. The functions of these two types of API are, however, interchangeable.

  • Shared queue We have a queue that stores the data received from the API; it is a first-in, first-out queue. We use a loop to take the data off the queue.

  • Event handler We define different kinds of events for the data, such as trade events and order events, and we keep a function map that maps each kind of event to its handler function.

  • Strategy The data processed by the handler functions is pushed to the strategy module, where we define our trading conditions.

Code

The code uses the official ByBit Python SDK, PyBit, and the trading system is inspired by VNPY.

  • Websocket and Rest client
import threading
import time
from queue import Queue
from pybit.unified_trading import WebSocket


class WebSocketClient:
    """Producer that forwards Bybit WebSocket pushes into queues as events.

    Every callback wraps the raw message in an ``(event_type, message)``
    tuple.  Order book messages go to ``orderbook_queue`` (created here);
    position, order, wallet, trade and timer events go to the shared
    ``queue`` supplied by the caller.
    """

    def __init__(self, setting, queue: Queue, speed, testnet=False, demo=True):
        """Open the public and the private WebSocket connection.

        Args:
            setting: Mapping that provides ``symbols``, ``depth_count``,
                ``category``, ``api_key`` and ``api_secret``.
            queue: Shared queue that receives all non-orderbook events.
            speed: Number of seconds to sleep between two timer events.
            testnet: Forwarded to both WebSocket connections.
            demo: Forwarded to the private WebSocket connection only.
        """
        self.setting = setting
        self.symbols = setting['symbols']
        self.queue = queue
        self.orderbook_queue = Queue()
        self.speed = speed
        self.depth_count = setting['depth_count']

        self.ws_public = WebSocket(
            testnet=testnet,
            channel_type=setting['category'],
        )
        print("public websocket connected")
        self.ws_private = WebSocket(
            testnet=testnet,
            demo=demo,
            channel_type="private",
            api_key=setting['api_key'],
            api_secret=setting['api_secret'],
            trace_logging=False,
        )
        print("private websocket connected")
        self.started = False

    def _orderbook_handler(self, message):
        """Handle order book updates."""
        self.orderbook_queue.put(("orderbook", message))

    def _position_handler(self, message):
        """Handle position updates."""
        self.queue.put(("position", message))

    def _order_handler(self, message):
        """Handle order updates."""
        self.queue.put(("order", message))

    def _wallet_handler(self, message):
        """Handle wallet updates."""
        self.queue.put(("wallet", message))

    def _timer_event(self):
        """Emit a ``timer`` event every ``speed`` seconds while started."""
        while self.started:
            time.sleep(self.speed)
            # stop() may have been called while sleeping; do not emit a
            # stale timer event after shutdown.
            if not self.started:
                break
            self.queue.put(("timer", {"timer": self.speed}))

    def _trade_handler(self, message):
        """Handle public trade updates and echo them to stdout."""
        self.queue.put(("trade", message))
        print(message)

    def start(self):
        """Start the timer thread and subscribe to all streams."""
        self.started = True
        threading.Thread(target=self._timer_event, daemon=False).start()
        self._start_ws()

    def stop(self):
        """Stop the timer loop and close both WebSocket connections."""
        self.started = False
        # Close the private connection as well; leaving it open would keep
        # an authenticated socket alive after shutdown.
        for socket in (self.ws_public, self.ws_private):
            socket.exit()

    def _start_ws(self):
        """Subscribe to the order book of every symbol and to the private streams."""
        for symbol in self.symbols:
            self.ws_public.orderbook_stream(self.depth_count, symbol, self._orderbook_handler)
            print("subscribe to symbol : {}".format(symbol))
        self.ws_private.position_stream(self._position_handler)
        self.ws_private.order_stream(self._order_handler)
        self.ws_private.wallet_stream(self._wallet_handler)
        # The public trade stream subscription is currently disabled:
        # self.ws_public.trade_stream(self.symbols, self._trade_handler)
  • Trader object
import threading
import time
from queue import Queue
from pybit.unified_trading import WebSocket
from pybit.unified_trading import HTTP
from producer import WebSocketClient
from utility import Coin, Wallet, Position, Order


class Trader:
    """Consumer side of the event-driven system.

    Events are ``(event_type, data)`` tuples taken from ``queue`` and
    dispatched through ``func_map``; order book events are read from the
    separate ``orderbook_queue`` on their own thread.  The ``on_*`` methods
    are the hooks a strategy overrides.  REST calls (placing and cancelling
    orders, wallet queries) go through ``self.rest``.
    """

    def __init__(self, setting: dict, ws_private: WebSocket, queue: Queue,
                 orderbook_queue: Queue, demo=True, testnet=False):
        """Store the shared queues and create the REST session.

        Args:
            setting: Mapping that provides ``symbols``, ``category``,
                ``api_key`` and ``api_secret``.
            ws_private: Private WebSocket; only stored on the instance here.
            queue: Queue of ``(event_type, data)`` events.
            orderbook_queue: Queue of order book events.
            demo: Forwarded to the REST session.
            testnet: Forwarded to the REST session.
        """
        self.ws_private = ws_private
        self.setting = setting
        self.queue = queue
        self.orderbook_queue = orderbook_queue
        self.demo = demo
        self.testnet = testnet
        self.orderbook_thread = None
        self.symbols = setting['symbols']
        self.category = setting['category']
        self.active_orders = {}
        self.func_map = {
            "order": self.process_order,
            "orderbook": self.process_orderbook,
            "position": self.process_position,
            "wallet": self.process_wallet,
            "trade": self.process_trade,
            "timer": self.process_timer,
        }
        self.rest = HTTP(
            testnet=testnet,
            demo=demo,
            api_key=setting['api_key'],
            api_secret=setting["api_secret"]
        )
        print("trader private rest session created")

    def run(self):
        """Start the order book thread, then consume ``queue`` on this thread."""
        # Thread.start() returns None, so keep the Thread object itself.
        self.orderbook_thread = threading.Thread(target=self._run_orderbook)
        self.orderbook_thread.start()
        self._run()

    def _run(self):
        """Consume data from the queue and act on it."""
        while True:
            event_type, data = self.queue.get()
            func = self.func_map[event_type]
            func(data)

    def _run_orderbook(self):
        """Consume order book events forever and hand them to ``process_orderbook``."""
        while True:
            _, data = self.orderbook_queue.get()
            self.process_orderbook(data)

    def process_orderbook(self, orderbook):
        """Forward an order book message to ``on_depth``."""
        self.on_depth(orderbook)

    def process_trade(self, trade):
        """Trade events are accepted but ignored."""
        pass

    def process_timer(self, timer):
        """Forward a timer event to ``on_timer``."""
        self.on_timer(timer)

    def process_order(self, order):
        """Call ``on_order`` once for every entry in ``order['data']``."""
        for entry in order['data']:
            self.on_order(Order(**entry))

    def process_position(self, position):
        """Call ``on_position`` once for every entry in ``position['data']``."""
        for entry in position['data']:
            # from_dict tolerates keys that Position does not declare; the
            # plain constructor raises TypeError on them.
            self.on_position(Position.from_dict(entry))

    def process_wallet(self, wallet):
        """Call ``on_account`` for every coin carried by a wallet event.

        Accepts either a stream message that wraps the wallet entries in
        ``wallet['data']`` (the same envelope ``process_order`` and
        ``process_position`` read) or a bare wallet dict.
        """
        entries = wallet['data'] if 'data' in wallet else [wallet]
        for entry in entries:
            for coin in Wallet(**entry).coins:
                self.on_account(coin)

    def on_account(self, account: Coin):
        """Strategy hook: called with one ``Coin`` per balance entry."""
        pass

    def on_order(self, order: Order):
        """Strategy hook: called with each order update."""
        pass

    def on_position(self, position: Position):
        """Strategy hook: called with each position update."""
        pass

    def on_depth(self, depth):
        """Strategy hook: called with each raw order book message."""
        pass

    def on_timer(self, timer):
        """Default timer hook: fetch the wallet balance and publish it.

        The ``timer`` payload itself is not used.
        """
        self._publish_wallet_balance()

    def send_order(self, direction, price=None,
                   qty=None,
                   symbol=None,
                   category=None,
                   is_maker=True,
                   orderLinkId=None,
                   orderType='Limit'):
        """Place an order through the REST session.

        Args:
            direction: Passed to the API as ``side``.
            price: Order price.
            qty: Order quantity.
            symbol: Symbol to trade.
            category: Product category; falls back to ``self.category``
                when omitted.
            is_maker: Request ``PostOnly`` for limit orders when true.
            orderLinkId: Optional client-side order id.
            orderType: Order type, ``'Limit'`` by default.

        Returns:
            ``orderLinkId`` when one was supplied, otherwise the ``orderId``
            found in the API response.
        """
        if category is None:
            category = self.category

        # PostOnly is applied to limit orders only; every other order type
        # is sent as GTC.
        post_only = is_maker and orderType == 'Limit'
        params = {
            "category": category,
            "symbol": symbol,
            "side": direction,
            "price": price,
            "qty": qty,
            "timeInForce": 'PostOnly' if post_only else 'GTC',
            "orderType": orderType,
        }
        if orderLinkId:
            params["orderLinkId"] = orderLinkId

        result = self.rest.place_order(**params)
        if orderLinkId:
            return orderLinkId
        return result['result']['orderId']

    def batch_cancel_orders(self, orderIds: list, symbol, category, isLinkId=True):
        """Cancel several orders of one symbol in a single request.

        Args:
            orderIds: Ids of the orders to cancel.
            symbol: Symbol all the orders belong to.
            category: Product category of the orders.
            isLinkId: When true the ids are sent as ``orderLinkId``,
                otherwise as ``orderId``.
        """
        if not orderIds:
            # Nothing to cancel; skip the API round trip.
            return
        id_key = "orderLinkId" if isLinkId else "orderId"
        request = [{"symbol": symbol, id_key: order_id} for order_id in orderIds]
        self.rest.cancel_batch_order(
            category=category,
            request=request
        )

    def cancell_all(self):
        """Cancel all orders in ``self.category`` (name kept as spelled for callers)."""
        self.rest.cancel_all_orders(category=self.category)

    def query_account(self, ):
        """Fetch the wallet balance and call ``on_account`` for every coin."""
        print("query_account....")
        self._publish_wallet_balance()

    def _publish_wallet_balance(self):
        """Fetch the UNIFIED wallet balance and call ``on_account`` per coin."""
        res = self.rest.get_wallet_balance(accountType='UNIFIED')
        for wallet in res['result']['list']:
            for coin in Wallet(**wallet).coins:
                self.on_account(coin)

    def cancel_order(self, symbol, orderId):
        """Cancel one order of ``symbol`` in ``self.category`` by ``orderId``."""
        self.rest.cancel_order(orderId=orderId,
                               category=self.category,
                               symbol=symbol)
  • Utility
from dataclasses import dataclass, field
from typing import List, Dict, Any

from dataclasses import dataclass, fields
from typing import Any, Dict



@dataclass
class Coin:
    """Per-coin balance entry built from a keyword payload.

    Declared fields missing from the payload are set to ``None``; keys that
    are not declared are kept in ``_extra_attributes`` and are also exposed
    as plain attributes.  Values are stored as received; the annotations
    are not enforced.
    """
    coin: str
    equity: float
    usdValue: float
    walletBalance: float
    availableToWithdraw: float
    totalOrderIM: float
    totalPositionIM: float
    totalPositionMM: float
    unrealisedPnl: float
    cumRealisedPnl: float
    bonus: float
    collateralSwitch: bool
    marginCollateral: bool
    locked: float
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=False)

    def __init__(self, **data):
        # Take every declared field out of the payload (None when absent).
        for spec in fields(type(self)):
            setattr(self, spec.name, data.pop(spec.name, None))

        # What is left was not declared: keep the mapping and mirror each
        # entry as an attribute.
        leftovers = data
        self._extra_attributes = leftovers
        for name in leftovers:
            setattr(self, name, leftovers[name])



@dataclass
class Wallet:
    """Account-level wallet snapshot built from a keyword payload.

    Declared fields missing from the payload are set to ``None``.  The
    ``coins`` field is filled from the payload key ``coin`` (singular),
    each entry being turned into a ``Coin``.  Keys that are not declared
    are kept in ``_extra_attributes`` and also exposed as attributes.
    """
    accountIMRate: float
    accountMMRate: float
    totalEquity: float
    totalWalletBalance: float
    totalMarginBalance: float
    totalAvailableBalance: float
    totalPerpUPL: float
    totalInitialMargin: float
    totalMaintenanceMargin: float
    accountLTV: float
    accountType: str
    coins: List[Coin] = field(default_factory=list)
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=False)

    def __init__(self, **data):
        # Walk over all declared fields.
        for spec in fields(type(self)):
            if spec.name != "coins":
                setattr(self, spec.name, data.pop(spec.name, None))
                continue
            # Note: the payload key is 'coin', not 'coins'.
            raw_coins = data.pop("coin", [])
            self.coins = [Coin(**raw) for raw in raw_coins]

        # Handle the keys that are not declared fields.
        self._extra_attributes = data
        for name in data:
            setattr(self, name, data[name])




@dataclass
class Order:
    """Order update built from a keyword payload.

    Declared fields missing from the payload are set to ``None``; keys that
    are not declared are kept in ``_extra_attributes`` and are also exposed
    as plain attributes.  Values are stored as received; the annotations
    are not enforced.
    """
    category: str
    symbol: str
    orderId: str
    orderLinkId: str
    side: str
    orderStatus: str
    price: float
    qty: float
    avgPrice: float
    leavesQty: float
    cumExecQty: float
    cumExecValue: float
    cumExecFee: float
    orderType: str
    timeInForce: str
    createdTime: str
    updatedTime: str
    reduceOnly: bool
    closeOnTrigger: bool
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=False)

    def __init__(self, **data):
        # Collect the names of all declared fields, then move each one from
        # the payload onto the instance (None when the payload lacks it).
        declared = [spec.name for spec in fields(type(self))]
        for name in declared:
            value = data.pop(name, None)
            setattr(self, name, value)

        # The remaining keys are undeclared: store them and expose them.
        self._extra_attributes = data
        for name, value in list(data.items()):
            setattr(self, name, value)



@dataclass
class Position:
    """Position update.

    The generated constructor accepts exactly the declared fields.  Use
    ``from_dict`` for raw payload dicts: it tolerates undeclared keys
    (kept in ``_extra_attributes`` and exposed as attributes) and missing
    declared keys (set to ``None``).  Values are stored as received; the
    annotations are not enforced.
    """
    positionIdx: int
    tradeMode: int
    riskId: int
    riskLimitValue: float
    symbol: str
    side: str
    size: float
    entryPrice: float
    leverage: float
    positionValue: float
    positionBalance: float
    markPrice: float
    positionIM: float
    positionMM: float
    unrealisedPnl: float
    cumRealisedPnl: float
    curRealisedPnl: float
    createdTime: int
    updatedTime: int
    positionStatus: str
    isReduceOnly: bool
    _extra_attributes: Dict[str, Any] = field(default_factory=dict, init=False)

    def __post_init__(self):
        """Expose every entry of ``_extra_attributes`` as an attribute."""
        for key, value in self._extra_attributes.items():
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Build a ``Position`` from a raw payload dict.

        Args:
            data: Payload mapping; it is not modified.

        Returns:
            A new instance.  Declared fields absent from ``data`` are
            ``None``; undeclared keys end up in ``_extra_attributes`` and
            as attributes.
        """
        declared = fields(cls)
        every_name = {f.name for f in declared}
        # Only init=True fields may be passed to the constructor;
        # `_extra_attributes` is init=False and would raise TypeError.
        known_data = {f.name: data.get(f.name) for f in declared if f.init}
        extra_data = {k: v for k, v in data.items() if k not in every_name}
        instance = cls(**known_data)
        instance._extra_attributes = extra_data
        instance.__post_init__()
        return instance