From 025322219c35d132e43c0d53fe9e11b0637d5bfd Mon Sep 17 00:00:00 2001 From: Nicholai Date: Wed, 21 Jan 2026 09:32:12 -0700 Subject: [PATCH] feat: initial commit with code quality refactoring kalshi prediction market backtesting framework with: - trading pipeline (sources, filters, scorers, selectors) - position sizing with kelly criterion - multiple scoring strategies (momentum, mean reversion, etc) - random baseline for comparison refactoring includes: - extract shared resolve_closed_positions() function - reduce RandomBaseline::run() nesting with helper functions - move MarketCandidate Default impl to types.rs - add explanatory comments to complex logic --- .gitignore | 5 + Cargo.toml | 30 + README.md | 236 +++++++ data/.gitkeep | 0 data/fetch_state.json | 1 + scripts/fetch_kalshi_data.py | 254 ++++++++ scripts/train_ml_models.py | 280 +++++++++ src/backtest.rs | 445 +++++++++++++ src/data/loader.rs | 290 +++++++++ src/data/mod.rs | 3 + src/execution.rs | 325 ++++++++++ src/main.rs | 216 +++++++ src/metrics.rs | 280 +++++++++ src/pipeline/correlation_scorer.rs | 259 ++++++++ src/pipeline/filters.rs | 182 ++++++ src/pipeline/ml_scorer.rs | 272 ++++++++ src/pipeline/mod.rs | 230 +++++++ src/pipeline/scorers.rs | 978 +++++++++++++++++++++++++++++ src/pipeline/selector.rs | 64 ++ src/pipeline/sources.rs | 69 ++ src/types.rs | 343 ++++++++++ 21 files changed, 4762 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 data/.gitkeep create mode 100644 data/fetch_state.json create mode 100755 scripts/fetch_kalshi_data.py create mode 100644 scripts/train_ml_models.py create mode 100644 src/backtest.rs create mode 100644 src/data/loader.rs create mode 100644 src/data/mod.rs create mode 100644 src/execution.rs create mode 100644 src/main.rs create mode 100644 src/metrics.rs create mode 100644 src/pipeline/correlation_scorer.rs create mode 100644 src/pipeline/filters.rs create mode 100644 
src/pipeline/ml_scorer.rs create mode 100644 src/pipeline/mod.rs create mode 100644 src/pipeline/scorers.rs create mode 100644 src/pipeline/selector.rs create mode 100644 src/pipeline/sources.rs create mode 100644 src/types.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..53a6493 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +/data/*.csv +/data/*.parquet +/results/*.json +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..5bd8ad8 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "kalshi-backtest" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1", features = ["full"] } +async-trait = "0.1" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +csv = "1.3" +chrono = { version = "0.4", features = ["serde"] } +clap = { version = "4", features = ["derive"] } +anyhow = "1" +thiserror = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +uuid = { version = "1", features = ["v4"] } +rust_decimal = { version = "1", features = ["serde"] } +rust_decimal_macros = "1" + +ort = { version = "2.0.0-rc.11", optional = true } +ndarray = { version = "0.16", optional = true } + +[features] +default = [] +ml = ["ort", "ndarray"] + +[dev-dependencies] +tempfile = "3" diff --git a/README.md b/README.md new file mode 100644 index 0000000..af80a44 --- /dev/null +++ b/README.md @@ -0,0 +1,236 @@ +kalshi-backtest +=== + +quant-level backtesting framework for kalshi prediction markets, using a candidate pipeline architecture. 
+ + +features +--- + +- **multi-timeframe momentum** - detects divergence between short and long-term trends +- **bollinger bands mean reversion** - signals when price touches statistical extremes +- **order flow analysis** - tracks buying vs selling pressure via taker_side +- **kelly criterion position sizing** - dynamic sizing based on edge and win probability +- **exit signals** - take profit, stop loss, time stops, and score reversal triggers +- **category-aware weighting** - different strategies for politics, weather, sports, etc. +- **ensemble scoring** - combine multiple models with dynamic weighting +- **cross-market correlations** - lead-lag relationships between related markets +- **ML ensemble (optional)** - LSTM + MLP models via ONNX runtime + + +architecture +--- + +``` +Historical Data (CSV) + | + v ++------------------+ +| Backtest Loop | <- simulates time progression ++------------------+ + | + v ++------------------+ +| Candidate Pipeline | ++------------------+ + | | + v v + Sources Filters -> Scorers -> Selector + | + v ++------------------+ +| Trade Executor | <- kelly sizing, exit signals ++------------------+ + | + v ++------------------+ +| P&L Tracker | <- tracks positions, returns ++------------------+ + | + v + Performance Metrics +``` + + +data format +--- + +fetch data from kalshi API using the included script: + +```bash +python scripts/fetch_kalshi_data.py +``` + +or download from https://www.deltabase.tech/ + +**markets.csv**: +```csv +ticker,title,category,open_time,close_time,result,status,yes_bid,yes_ask,volume,open_interest +PRES-2024-DEM,Will Democrats win?,politics,2024-01-01 00:00:00,2024-11-06 00:00:00,no,finalized,45,47,10000,5000 +``` + +**trades.csv**: +```csv +timestamp,ticker,price,volume,taker_side +2024-01-05 12:00:00,PRES-2024-DEM,45,100,yes +2024-01-05 13:00:00,PRES-2024-DEM,46,50,no +``` + + +usage +--- + +```bash +# build +cargo build --release + +# run backtest with quant features +cargo run --release -- run \ + 
--data-dir data \ + --start 2024-01-01 \ + --end 2024-06-01 \ + --capital 10000 \ + --max-position 500 \ + --max-positions 10 \ + --kelly-fraction 0.25 \ + --max-position-pct 0.25 \ + --take-profit 0.20 \ + --stop-loss 0.15 \ + --max-hold-hours 72 \ + --compare-random + +# view results +cargo run --release -- summary --results-file results/backtest_result.json +``` + + +cli options +--- + +| option | default | description | +|--------|---------|-------------| +| --data-dir | data | directory with markets.csv and trades.csv | +| --start | required | backtest start date | +| --end | required | backtest end date | +| --capital | 10000 | initial capital | +| --max-position | 100 | max shares per position | +| --max-positions | 5 | max concurrent positions | +| --kelly-fraction | 0.25 | fraction of kelly criterion (0.1=conservative, 1.0=full) | +| --max-position-pct | 0.25 | max % of capital per position | +| --take-profit | 0.20 | take profit threshold (20% gain) | +| --stop-loss | 0.15 | stop loss threshold (15% loss) | +| --max-hold-hours | 72 | time stop in hours | +| --compare-random | false | compare vs random baseline | + + +scorers +--- + +**basic scorers**: +- `MomentumScorer` - price change over lookback period +- `MeanReversionScorer` - deviation from historical mean +- `VolumeScorer` - unusual volume detection +- `TimeDecayScorer` - prefer markets with more time to close + +**quant scorers**: +- `MultiTimeframeMomentumScorer` - analyzes 1h, 4h, 12h, 24h windows, detects divergence +- `BollingerMeanReversionScorer` - triggers at upper/lower band touches (2 std) +- `OrderFlowScorer` - buy/sell imbalance from taker_side +- `CategoryWeightedScorer` - different weights per category +- `EnsembleScorer` - combines models with dynamic weights +- `CorrelationScorer` - cross-market lead-lag signals + +**ml scorers** (requires `ml` feature): +- `MLEnsembleScorer` - LSTM + MLP via ONNX + + +position sizing +--- + +uses kelly criterion with safety multiplier: + +``` 
+kelly = (odds * win_prob - (1 - win_prob)) / odds +safe_kelly = kelly * kelly_fraction +position = min(bankroll * safe_kelly, max_position_pct * bankroll) +``` + + +exit signals +--- + +positions can exit via: +1. **resolution** - market resolves yes/no +2. **take profit** - pnl exceeds threshold +3. **stop loss** - pnl below threshold +4. **time stop** - held too long (capital rotation) +5. **score reversal** - strategy flips bearish + + +ml training (optional) +--- + +train ML models using pytorch, then export to ONNX: + +```bash +# install dependencies +pip install torch pandas numpy + +# train models +python scripts/train_ml_models.py \ + --data data/trades.csv \ + --markets data/markets.csv \ + --output models/ \ + --epochs 50 + +# enable ml feature +cargo build --release --features ml +``` + + +metrics +--- + +- total return ($ and %) +- sharpe ratio (annualized) +- max drawdown +- win rate +- average trade P&L +- average hold time +- trades per day +- return by category + + +extending +--- + +add custom scorers by implementing the `Scorer` trait: + +```rust +use async_trait::async_trait; + +pub struct MyScorer; + +#[async_trait] +impl Scorer for MyScorer { + fn name(&self) -> &'static str { + "MyScorer" + } + + async fn score( + &self, + context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + // compute scores... + } + + fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) { + if let Some(score) = scored.scores.get("my_score") { + candidate.scores.insert("my_score".to_string(), *score); + } + } +} +``` + +then add to the pipeline in `backtest.rs`. 
diff --git a/data/.gitkeep b/data/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/fetch_state.json b/data/fetch_state.json new file mode 100644 index 0000000..713216a --- /dev/null +++ b/data/fetch_state.json @@ -0,0 +1 @@ +{"markets_cursor": "CgsI-rDDywYQkKOiMRI5S1hNVkVTUE9SVFNNVUxUSUdBTUVFWFRFTkRFRC1TMjAyNTBDMDMzMDBBRkYyLTkxNTVFNjFERTk3", "markets_count": 25000, "trades_cursor": null, "trades_count": 0, "markets_done": false, "trades_done": false} \ No newline at end of file diff --git a/scripts/fetch_kalshi_data.py b/scripts/fetch_kalshi_data.py new file mode 100755 index 0000000..ddde339 --- /dev/null +++ b/scripts/fetch_kalshi_data.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +""" +Fetch historical trade and market data from Kalshi's public API. +No authentication required for public endpoints. + +Features: +- Incremental saves (writes batches to disk) +- Resume capability (tracks cursor position) +- Retry logic with exponential backoff +""" + +import json +import csv +import time +import urllib.request +import urllib.error +from datetime import datetime +from pathlib import Path + +BASE_URL = "https://api.elections.kalshi.com/trade-api/v2" +STATE_FILE = "fetch_state.json" + +def fetch_json(url: str, max_retries: int = 5) -> dict: + """Fetch JSON from URL with retries and exponential backoff.""" + req = urllib.request.Request(url, headers={"Accept": "application/json"}) + + for attempt in range(max_retries): + try: + with urllib.request.urlopen(req, timeout=30) as resp: + return json.loads(resp.read().decode()) + except (urllib.error.HTTPError, urllib.error.URLError) as e: + wait = 2 ** attempt + print(f" attempt {attempt + 1}/{max_retries} failed: {e}") + if attempt < max_retries - 1: + print(f" retrying in {wait}s...") + time.sleep(wait) + else: + raise + except Exception as e: + wait = 2 ** attempt + print(f" unexpected error: {e}") + if attempt < max_retries - 1: + print(f" retrying in {wait}s...") + time.sleep(wait) + else: + raise 
+ +def load_state(output_dir: Path) -> dict: + """Load saved state for resuming.""" + state_path = output_dir / STATE_FILE + if state_path.exists(): + with open(state_path) as f: + return json.load(f) + return {"markets_cursor": None, "markets_count": 0, + "trades_cursor": None, "trades_count": 0, + "markets_done": False, "trades_done": False} + +def save_state(output_dir: Path, state: dict): + """Save state for resuming.""" + state_path = output_dir / STATE_FILE + with open(state_path, "w") as f: + json.dump(state, f) + +def append_markets_csv(markets: list, output_path: Path, write_header: bool): + """Append markets to CSV.""" + mode = "w" if write_header else "a" + with open(output_path, mode, newline="") as f: + writer = csv.writer(f) + if write_header: + writer.writerow(["ticker", "title", "category", "open_time", + "close_time", "result", "status", "yes_bid", + "yes_ask", "volume", "open_interest"]) + + for m in markets: + result = "" + if m.get("result") == "yes": + result = "yes" + elif m.get("result") == "no": + result = "no" + elif m.get("status") == "finalized" and m.get("result"): + result = m.get("result") + + writer.writerow([ + m.get("ticker", ""), + m.get("title", ""), + m.get("category", ""), + m.get("open_time", ""), + m.get("close_time", m.get("expiration_time", "")), + result, + m.get("status", ""), + m.get("yes_bid", ""), + m.get("yes_ask", ""), + m.get("volume", ""), + m.get("open_interest", ""), + ]) + +def append_trades_csv(trades: list, output_path: Path, write_header: bool): + """Append trades to CSV.""" + mode = "w" if write_header else "a" + with open(output_path, mode, newline="") as f: + writer = csv.writer(f) + if write_header: + writer.writerow(["timestamp", "ticker", "price", "volume", "taker_side"]) + + for t in trades: + price = t.get("yes_price", t.get("price", 50)) + taker_side = t.get("taker_side", "") + if not taker_side: + taker_side = "yes" if t.get("is_taker_side_yes", True) else "no" + + writer.writerow([ + 
t.get("created_time", t.get("ts", "")), + t.get("ticker", t.get("market_ticker", "")), + price, + t.get("count", t.get("volume", 1)), + taker_side, + ]) + +def fetch_markets_incremental(output_dir: Path, state: dict) -> int: + """Fetch markets incrementally with state tracking.""" + output_path = output_dir / "markets.csv" + cursor = state["markets_cursor"] + total = state["markets_count"] + write_header = total == 0 + + print(f"Resuming from {total} markets...") + + while True: + url = f"{BASE_URL}/markets?limit=1000" + if cursor: + url += f"&cursor={cursor}" + + print(f"Fetching markets... ({total:,} so far)") + + try: + data = fetch_json(url) + except Exception as e: + print(f"Error fetching markets: {e}") + print(f"Progress saved. Run again to resume from {total:,} markets.") + return total + + batch = data.get("markets", []) + if batch: + append_markets_csv(batch, output_path, write_header) + write_header = False + total += len(batch) + + cursor = data.get("cursor") + state["markets_cursor"] = cursor + state["markets_count"] = total + save_state(output_dir, state) + + if not cursor: + state["markets_done"] = True + save_state(output_dir, state) + break + + time.sleep(0.3) + + return total + +def fetch_trades_incremental(output_dir: Path, state: dict, limit: int) -> int: + """Fetch trades incrementally with state tracking.""" + output_path = output_dir / "trades.csv" + cursor = state["trades_cursor"] + total = state["trades_count"] + write_header = total == 0 + + print(f"Resuming from {total} trades...") + + while total < limit: + url = f"{BASE_URL}/markets/trades?limit=1000" + if cursor: + url += f"&cursor={cursor}" + + print(f"Fetching trades... ({total:,}/{limit:,})") + + try: + data = fetch_json(url) + except Exception as e: + print(f"Error fetching trades: {e}") + print(f"Progress saved. 
Run again to resume from {total:,} trades.") + return total + + batch = data.get("trades", []) + if not batch: + break + + append_trades_csv(batch, output_path, write_header) + write_header = False + total += len(batch) + + cursor = data.get("cursor") + state["trades_cursor"] = cursor + state["trades_count"] = total + save_state(output_dir, state) + + if not cursor: + state["trades_done"] = True + save_state(output_dir, state) + break + + time.sleep(0.3) + + return total + +def main(): + output_dir = Path("/mnt/work/kalshi-data") + output_dir.mkdir(exist_ok=True) + + print("=" * 50) + print("Kalshi Data Fetcher (with resume)") + print("=" * 50) + + state = load_state(output_dir) + + # fetch markets + if not state["markets_done"]: + print("\n[1/2] Fetching markets...") + markets_count = fetch_markets_incremental(output_dir, state) + if state["markets_done"]: + print(f"Markets complete: {markets_count:,}") + else: + print(f"Markets paused at: {markets_count:,}") + return 1 + else: + print(f"\n[1/2] Markets already complete: {state['markets_count']:,}") + + # fetch trades + if not state["trades_done"]: + print("\n[2/2] Fetching trades...") + trades_count = fetch_trades_incremental(output_dir, state, limit=1000000) + if state["trades_done"]: + print(f"Trades complete: {trades_count:,}") + else: + print(f"Trades paused at: {trades_count:,}") + return 1 + else: + print(f"\n[2/2] Trades already complete: {state['trades_count']:,}") + + print("\n" + "=" * 50) + print("Done!") + print(f"Markets: {state['markets_count']:,}") + print(f"Trades: {state['trades_count']:,}") + print(f"Output: {output_dir}") + print("=" * 50) + + # clear state for next run + (output_dir / STATE_FILE).unlink(missing_ok=True) + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/scripts/train_ml_models.py b/scripts/train_ml_models.py new file mode 100644 index 0000000..c83e198 --- /dev/null +++ b/scripts/train_ml_models.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +""" +Train ML 
models for the kalshi backtest framework. + +Models: +- LSTM: learns patterns from price history sequences +- MLP: learns optimal combination of hand-crafted features + +Usage: + python scripts/train_ml_models.py --data data/trades.csv --output models/ +""" + +import argparse +import json +import numpy as np +import pandas as pd +from pathlib import Path + +try: + import torch + import torch.nn as nn + import torch.optim as optim + from torch.utils.data import DataLoader, TensorDataset + HAS_TORCH = True +except ImportError: + HAS_TORCH = False + print("warning: pytorch not installed. run: pip install torch") + +def parse_args(): + parser = argparse.ArgumentParser(description="Train ML models for kalshi backtest") + parser.add_argument("--data", type=Path, default=Path("data/trades.csv")) + parser.add_argument("--markets", type=Path, default=Path("data/markets.csv")) + parser.add_argument("--output", type=Path, default=Path("models")) + parser.add_argument("--epochs", type=int, default=50) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--seq-len", type=int, default=24) + parser.add_argument("--train-split", type=float, default=0.8) + return parser.parse_args() + +class LSTMPredictor(nn.Module): + def __init__(self, input_size=1, hidden_size=128, num_layers=2, dropout=0.2): + super().__init__() + self.lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=num_layers, + dropout=dropout, + batch_first=True, + ) + self.fc = nn.Linear(hidden_size, 1) + self.tanh = nn.Tanh() + + def forward(self, x): + lstm_out, _ = self.lstm(x) + last_output = lstm_out[:, -1, :] + out = self.fc(last_output) + return self.tanh(out) + +class MLPPredictor(nn.Module): + def __init__(self, input_size=7, hidden_sizes=[64, 32]): + super().__init__() + layers = [] + prev_size = input_size + for h in hidden_sizes: + layers.append(nn.Linear(prev_size, h)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(0.2)) + prev_size = h + 
layers.append(nn.Linear(prev_size, 1)) + layers.append(nn.Tanh()) + self.net = nn.Sequential(*layers) + + def forward(self, x): + return self.net(x) + +def load_data(trades_path: Path, markets_path: Path, seq_len: int): + print(f"loading trades from {trades_path}...") + trades = pd.read_csv(trades_path) + trades["timestamp"] = pd.to_datetime(trades["timestamp"]) + trades = trades.sort_values(["ticker", "timestamp"]) + + print(f"loading markets from {markets_path}...") + markets = pd.read_csv(markets_path) + markets["close_time"] = pd.to_datetime(markets["close_time"]) + + result_map = dict(zip(markets["ticker"], markets["result"])) + + sequences = [] + features = [] + labels = [] + + for ticker, group in trades.groupby("ticker"): + result = result_map.get(ticker) + if result not in ["yes", "no"]: + continue + + label = 1.0 if result == "yes" else -1.0 + + prices = group["price"].values / 100.0 + volumes = group["volume"].values + taker_sides = (group["taker_side"] == "yes").astype(float).values + + if len(prices) < seq_len: + continue + + for i in range(seq_len, len(prices)): + seq = prices[i - seq_len : i] + log_returns = np.diff(np.log(np.clip(seq, 1e-6, 1.0))) + + if len(log_returns) == seq_len - 1: + log_returns = np.pad(log_returns, (1, 0), mode="constant") + + sequences.append(log_returns) + + curr_price = prices[i - 1] + momentum = prices[i - 1] - prices[i - seq_len] if len(prices) > seq_len else 0 + mean_price = np.mean(prices[i - seq_len : i]) + mean_reversion = mean_price - curr_price + vol_sum = np.sum(volumes[i - seq_len : i]) + buy_vol = np.sum(volumes[i - seq_len : i] * taker_sides[i - seq_len : i]) + sell_vol = vol_sum - buy_vol + order_flow = (buy_vol - sell_vol) / max(vol_sum, 1) + + feat = [ + momentum, + mean_reversion, + np.log1p(vol_sum), + order_flow, + curr_price, + np.std(log_returns) if len(log_returns) > 1 else 0, + len(group) / 1000.0, + ] + features.append(feat) + labels.append(label) + + print(f"created {len(sequences)} training 
samples") + return np.array(sequences), np.array(features), np.array(labels) + +def train_lstm(sequences, labels, args): + print("\n" + "=" * 50) + print("Training LSTM") + print("=" * 50) + + n = len(sequences) + split = int(n * args.train_split) + + X_train = torch.tensor(sequences[:split], dtype=torch.float32).unsqueeze(-1) + y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1) + X_test = torch.tensor(sequences[split:], dtype=torch.float32).unsqueeze(-1) + y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1) + + train_dataset = TensorDataset(X_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) + + model = LSTMPredictor(input_size=1, hidden_size=128, num_layers=2) + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + + for epoch in range(args.epochs): + model.train() + total_loss = 0 + for X_batch, y_batch in train_loader: + optimizer.zero_grad() + output = model(X_batch) + loss = criterion(output, y_batch) + loss.backward() + optimizer.step() + total_loss += loss.item() + + if (epoch + 1) % 10 == 0: + model.set_mode_to_inference() + with torch.no_grad(): + train_pred = model(X_train) + test_pred = model(X_test) + train_acc = ((train_pred > 0) == (y_train > 0)).float().mean() + test_acc = ((test_pred > 0) == (y_test > 0)).float().mean() + print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}") + + return model + +def train_mlp(features, labels, args): + print("\n" + "=" * 50) + print("Training MLP") + print("=" * 50) + + features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8) + + n = len(features) + split = int(n * args.train_split) + + X_train = torch.tensor(features[:split], dtype=torch.float32) + y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1) + X_test = torch.tensor(features[split:], dtype=torch.float32) + y_test = 
torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1) + + train_dataset = TensorDataset(X_train, y_train) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) + + model = MLPPredictor(input_size=features.shape[1]) + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=0.001) + + for epoch in range(args.epochs): + model.train() + total_loss = 0 + for X_batch, y_batch in train_loader: + optimizer.zero_grad() + output = model(X_batch) + loss = criterion(output, y_batch) + loss.backward() + optimizer.step() + total_loss += loss.item() + + if (epoch + 1) % 10 == 0: + model.eval() + with torch.no_grad(): + train_pred = model(X_train) + test_pred = model(X_test) + train_acc = ((train_pred > 0) == (y_train > 0)).float().mean() + test_acc = ((test_pred > 0) == (y_test > 0)).float().mean() + print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}") + + return model + +def export_onnx(model, output_path: Path, input_shape, input_name="input", output_name="output"): + model.eval() + dummy_input = torch.randn(*input_shape) + + torch.onnx.export( + model, + dummy_input, + output_path, + input_names=[input_name], + output_names=[output_name], + dynamic_axes={ + input_name: {0: "batch_size"}, + output_name: {0: "batch_size"}, + }, + opset_version=14, + ) + print(f"exported to {output_path}") + +def main(): + args = parse_args() + + if not HAS_TORCH: + print("error: pytorch required for training. 
install with: pip install torch") + return 1 + + if not args.data.exists(): + print(f"error: data file not found: {args.data}") + return 1 + + if not args.markets.exists(): + print(f"error: markets file not found: {args.markets}") + return 1 + + args.output.mkdir(parents=True, exist_ok=True) + + sequences, features, labels = load_data(args.data, args.markets, args.seq_len) + + if len(sequences) < 100: + print(f"error: not enough training data ({len(sequences)} samples)") + return 1 + + lstm_model = train_lstm(sequences, labels, args) + export_onnx(lstm_model, args.output / "lstm.onnx", (1, args.seq_len, 1)) + + mlp_model = train_mlp(features, labels, args) + export_onnx(mlp_model, args.output / "mlp.onnx", (1, features.shape[1])) + + print("\n" + "=" * 50) + print("Training complete!") + print(f"Models saved to: {args.output}") + print("=" * 50) + + return 0 + +if __name__ == "__main__": + exit(main()) diff --git a/src/backtest.rs b/src/backtest.rs new file mode 100644 index 0000000..800f520 --- /dev/null +++ b/src/backtest.rs @@ -0,0 +1,445 @@ +use crate::data::HistoricalData; +use crate::execution::{Executor, PositionSizingConfig}; +use crate::metrics::{BacktestResult, MetricsCollector}; +use crate::pipeline::{ + AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, Filter, + HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer, + MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, Selector, Source, TimeDecayScorer, + TimeToCloseFilter, TopKSelector, TradingPipeline, VolumeScorer, +}; +use crate::types::{ + BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType, + TradingContext, +}; +use chrono::{DateTime, Utc}; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tracing::info; + +/// resolves any positions in markets that have closed +/// returns list of (ticker, result, pnl) for logging purposes 
+fn resolve_closed_positions( + portfolio: &mut Portfolio, + data: &HistoricalData, + resolved: &mut HashSet, + at: DateTime, + history: &mut Vec, + metrics: &mut MetricsCollector, +) -> Vec<(String, MarketResult, Option)> { + let tickers: Vec = portfolio.positions.keys().cloned().collect(); + let mut resolutions = Vec::new(); + + for ticker in tickers { + if resolved.contains(&ticker) { + continue; + } + + let Some(result) = data.get_resolution_at(&ticker, at) else { + continue; + }; + + resolved.insert(ticker.clone()); + let Some(pos) = portfolio.positions.get(&ticker).cloned() else { + continue; + }; + + let pnl = portfolio.resolve_position(&ticker, result); + + let exit_price = match result { + MarketResult::Yes => match pos.side { + Side::Yes => Decimal::ONE, + Side::No => Decimal::ZERO, + }, + MarketResult::No => match pos.side { + Side::Yes => Decimal::ZERO, + Side::No => Decimal::ONE, + }, + MarketResult::Cancelled => pos.avg_entry_price, + }; + + let category = data + .markets + .get(&ticker) + .map(|m| m.category.clone()) + .unwrap_or_default(); + + let trade = Trade { + ticker: ticker.clone(), + side: pos.side, + quantity: pos.quantity, + price: exit_price, + timestamp: at, + trade_type: TradeType::Resolution, + }; + + history.push(trade.clone()); + metrics.record_trade(&trade, &category); + resolutions.push((ticker, result, pnl)); + } + + resolutions +} + +pub struct Backtester { + config: BacktestConfig, + data: Arc, + pipeline: TradingPipeline, + executor: Executor, +} + +impl Backtester { + pub fn new(config: BacktestConfig, data: Arc) -> Self { + let pipeline = Self::build_default_pipeline(data.clone(), &config); + let executor = Executor::new(data.clone(), 10, config.max_position_size); + + Self { + config, + data, + pipeline, + executor, + } + } + + pub fn with_configs( + config: BacktestConfig, + data: Arc, + sizing_config: PositionSizingConfig, + exit_config: ExitConfig, + ) -> Self { + let pipeline = Self::build_default_pipeline(data.clone(), 
&config); + let executor = Executor::new(data.clone(), 10, config.max_position_size) + .with_sizing_config(sizing_config) + .with_exit_config(exit_config); + + Self { + config, + data, + pipeline, + executor, + } + } + + pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self { + self.pipeline = pipeline; + self + } + + fn build_default_pipeline(data: Arc, config: &BacktestConfig) -> TradingPipeline { + let sources: Vec> = vec![ + Box::new(HistoricalMarketSource::new(data, 24)), + ]; + + let filters: Vec> = vec![ + Box::new(LiquidityFilter::new(100)), + Box::new(TimeToCloseFilter::new(2, Some(720))), + Box::new(AlreadyPositionedFilter::new(config.max_position_size)), + ]; + + let scorers: Vec> = vec![ + Box::new(MomentumScorer::new(6)), + Box::new(MultiTimeframeMomentumScorer::default_windows()), + Box::new(MeanReversionScorer::new(24)), + Box::new(BollingerMeanReversionScorer::default_config()), + Box::new(VolumeScorer::new(6)), + Box::new(OrderFlowScorer::new()), + Box::new(TimeDecayScorer::new()), + Box::new(CategoryWeightedScorer::with_defaults()), + ]; + + let selector: Box = Box::new(TopKSelector::new(config.max_positions)); + + TradingPipeline::new(sources, filters, scorers, selector, config.max_positions) + } + + pub async fn run(&self) -> BacktestResult { + let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time); + let mut metrics = MetricsCollector::new(self.config.initial_capital); + let mut resolved_markets: HashSet = HashSet::new(); + + let mut current_time = self.config.start_time; + + info!( + start = %self.config.start_time, + end = %self.config.end_time, + interval_hours = self.config.interval.num_hours(), + "starting backtest" + ); + + while current_time < self.config.end_time { + context.timestamp = current_time; + context.request_id = uuid::Uuid::new_v4().to_string(); + + let resolutions = resolve_closed_positions( + &mut context.portfolio, + &self.data, + &mut resolved_markets, + current_time, + &mut 
context.trading_history, + &mut metrics, + ); + for (ticker, result, pnl) in resolutions { + info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved"); + } + + let result = self.pipeline.execute(context.clone()).await; + + let candidate_scores: std::collections::HashMap = result + .selected_candidates + .iter() + .map(|c| (c.ticker.clone(), c.final_score)) + .collect(); + + let exit_signals = self.executor.generate_exit_signals(&context, &candidate_scores); + for exit in exit_signals { + if let Some(position) = context.portfolio.positions.get(&exit.ticker).cloned() { + let pnl = context.portfolio.close_position(&exit.ticker, exit.current_price); + + info!( + ticker = %exit.ticker, + reason = ?exit.reason, + pnl = ?pnl, + "exit triggered" + ); + + let category = self + .data + .markets + .get(&exit.ticker) + .map(|m| m.category.clone()) + .unwrap_or_default(); + + let exit_price = match position.side { + crate::types::Side::Yes => exit.current_price, + crate::types::Side::No => Decimal::ONE - exit.current_price, + }; + + let trade = Trade { + ticker: exit.ticker.clone(), + side: position.side, + quantity: position.quantity, + price: exit_price, + timestamp: current_time, + trade_type: TradeType::Close, + }; + + context.trading_history.push(trade.clone()); + metrics.record_trade(&trade, &category); + } + } + + let signals = self.executor.generate_signals(&result.selected_candidates, &context); + + for signal in signals { + if let Some(fill) = self.executor.execute_signal(&signal, &context) { + info!( + ticker = %fill.ticker, + side = ?fill.side, + quantity = fill.quantity, + price = %fill.price, + "executed trade" + ); + + context.portfolio.apply_fill(&fill); + + let category = self + .data + .markets + .get(&fill.ticker) + .map(|m| m.category.clone()) + .unwrap_or_default(); + + let trade = Trade { + ticker: fill.ticker.clone(), + side: fill.side, + quantity: fill.quantity, + price: fill.price, + timestamp: fill.timestamp, + trade_type: 
TradeType::Open, + }; + + context.trading_history.push(trade.clone()); + metrics.record_trade(&trade, &category); + } + } + + let market_prices = self.get_current_prices(current_time); + metrics.record(current_time, &context.portfolio, &market_prices); + + current_time = current_time + self.config.interval; + } + + let resolutions = resolve_closed_positions( + &mut context.portfolio, + &self.data, + &mut resolved_markets, + self.config.end_time, + &mut context.trading_history, + &mut metrics, + ); + for (ticker, result, pnl) in resolutions { + info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved"); + } + + info!( + trades = context.trading_history.len(), + positions = context.portfolio.positions.len(), + cash = %context.portfolio.cash, + "backtest complete" + ); + + metrics.finalize() + } + + fn get_current_prices(&self, at: DateTime) -> HashMap { + self.data + .markets + .keys() + .filter_map(|ticker| { + self.data + .get_current_price(ticker, at) + .map(|p| (ticker.clone(), p)) + }) + .collect() + } +} + +pub struct RandomBaseline { + config: BacktestConfig, + data: Arc, +} + +impl RandomBaseline { + pub fn new(config: BacktestConfig, data: Arc) -> Self { + Self { config, data } + } + + pub async fn run(&self) -> BacktestResult { + let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time); + let mut metrics = MetricsCollector::new(self.config.initial_capital); + let mut resolved_markets: HashSet = HashSet::new(); + let mut rng_state: u64 = 42; + + let mut current_time = self.config.start_time; + + while current_time < self.config.end_time { + context.timestamp = current_time; + + resolve_closed_positions( + &mut context.portfolio, + &self.data, + &mut resolved_markets, + current_time, + &mut context.trading_history, + &mut metrics, + ); + + if let Some(fill) = self.try_random_trade(&context, current_time, &mut rng_state) { + let category = self + .data + .markets + .get(&fill.ticker) + .map(|m| 
m.category.clone()) + .unwrap_or_default(); + + context.portfolio.apply_fill(&fill); + + let trade = Trade { + ticker: fill.ticker.clone(), + side: fill.side, + quantity: fill.quantity, + price: fill.price, + timestamp: current_time, + trade_type: TradeType::Open, + }; + + context.trading_history.push(trade.clone()); + metrics.record_trade(&trade, &category); + } + + let market_prices = self.get_current_prices(current_time); + metrics.record(current_time, &context.portfolio, &market_prices); + + current_time = current_time + self.config.interval; + } + + resolve_closed_positions( + &mut context.portfolio, + &self.data, + &mut resolved_markets, + self.config.end_time, + &mut context.trading_history, + &mut metrics, + ); + + metrics.finalize() + } + + fn try_random_trade( + &self, + context: &TradingContext, + at: DateTime, + rng_state: &mut u64, + ) -> Option { + if context.portfolio.positions.len() >= self.config.max_positions { + return None; + } + + let active_markets = self.data.get_active_markets(at); + let unpositioned: Vec<_> = active_markets + .iter() + .filter(|m| !context.portfolio.has_position(&m.ticker)) + .collect(); + + if unpositioned.is_empty() { + return None; + } + + *rng_state = lcg_next(*rng_state); + let idx = (*rng_state as usize) % unpositioned.len(); + let market = unpositioned[idx]; + + let price = self.data.get_current_price(&market.ticker, at)?; + let side = if *rng_state % 2 == 0 { Side::Yes } else { Side::No }; + + let effective_price = match side { + Side::Yes => price, + Side::No => Decimal::ONE - price, + }; + + let quantity = self + .config + .max_position_size + .min((context.portfolio.cash / effective_price).to_u64().unwrap_or(0)); + + if quantity == 0 { + return None; + } + + Some(Fill { + ticker: market.ticker.clone(), + side, + quantity, + price: effective_price, + timestamp: at, + }) + } + + fn get_current_prices(&self, at: DateTime) -> HashMap { + self.data + .markets + .keys() + .filter_map(|ticker| { + self.data + 
.get_current_price(ticker, at) + .map(|p| (ticker.clone(), p)) + }) + .collect() + } +} + +/// linear congruential generator for deterministic random baseline +fn lcg_next(state: u64) -> u64 { + state.wrapping_mul(1103515245).wrapping_add(12345) +} diff --git a/src/data/loader.rs b/src/data/loader.rs new file mode 100644 index 0000000..4b3114a --- /dev/null +++ b/src/data/loader.rs @@ -0,0 +1,290 @@ +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use csv::ReaderBuilder; +use rust_decimal::Decimal; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::Path; + +use crate::types::{MarketData, MarketResult, PricePoint, Side, TradeData}; + +#[derive(Debug, Deserialize)] +struct CsvMarket { + ticker: String, + title: String, + category: String, + #[serde(with = "flexible_datetime")] + open_time: DateTime, + #[serde(with = "flexible_datetime")] + close_time: DateTime, + result: Option, +} + +#[derive(Debug, Deserialize)] +struct CsvTrade { + #[serde(with = "flexible_datetime")] + timestamp: DateTime, + ticker: String, + price: f64, + volume: u64, + taker_side: String, +} + +mod flexible_datetime { + use chrono::{DateTime, NaiveDateTime, Utc}; + use serde::{self, Deserialize, Deserializer}; + + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + + if let Ok(dt) = DateTime::parse_from_rfc3339(&s) { + return Ok(dt.with_timezone(&Utc)); + } + + if let Ok(dt) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") { + return Ok(dt.and_utc()); + } + + if let Ok(ts) = s.parse::() { + return DateTime::from_timestamp(ts, 0) + .ok_or_else(|| serde::de::Error::custom("invalid timestamp")); + } + + Err(serde::de::Error::custom(format!( + "could not parse datetime: {}", + s + ))) + } +} + +pub struct HistoricalData { + pub markets: HashMap, + pub trades: Vec, + trade_index: HashMap>, +} + +impl HistoricalData { + pub fn load(data_dir: &Path) -> Result { + 
let markets_path = data_dir.join("markets.csv"); + let trades_path = data_dir.join("trades.csv"); + + let markets = load_markets(&markets_path) + .with_context(|| format!("loading markets from {:?}", markets_path))?; + + let trades = + load_trades(&trades_path).with_context(|| format!("loading trades from {:?}", trades_path))?; + + let mut trade_index: HashMap> = HashMap::new(); + for (i, trade) in trades.iter().enumerate() { + trade_index.entry(trade.ticker.clone()).or_default().push(i); + } + + Ok(Self { + markets, + trades, + trade_index, + }) + } + + pub fn get_active_markets(&self, at: DateTime) -> Vec<&MarketData> { + self.markets + .values() + .filter(|m| at >= m.open_time && at < m.close_time) + .collect() + } + + pub fn get_trades_for_market(&self, ticker: &str, from: DateTime, to: DateTime) -> Vec<&TradeData> { + self.trade_index + .get(ticker) + .map(|indices| { + indices + .iter() + .filter_map(|&i| { + let trade = &self.trades[i]; + if trade.timestamp >= from && trade.timestamp < to { + Some(trade) + } else { + None + } + }) + .collect() + }) + .unwrap_or_default() + } + + pub fn get_current_price(&self, ticker: &str, at: DateTime) -> Option { + self.trade_index.get(ticker).and_then(|indices| { + indices + .iter() + .filter_map(|&i| { + let trade = &self.trades[i]; + if trade.timestamp <= at { + Some(trade) + } else { + None + } + }) + .last() + .map(|t| t.price) + }) + } + + pub fn get_price_history( + &self, + ticker: &str, + from: DateTime, + to: DateTime, + ) -> Vec { + self.get_trades_for_market(ticker, from, to) + .into_iter() + .map(|t| PricePoint { + timestamp: t.timestamp, + yes_price: t.price, + volume: t.volume, + }) + .collect() + } + + pub fn get_volume_24h(&self, ticker: &str, at: DateTime) -> u64 { + let from = at - chrono::Duration::hours(24); + self.get_trades_for_market(ticker, from, at) + .iter() + .map(|t| t.volume) + .sum() + } + + pub fn get_order_flow_24h(&self, ticker: &str, at: DateTime) -> (u64, u64) { + let from = at - 
chrono::Duration::hours(24);
        let trades = self.get_trades_for_market(ticker, from, at);
        // split taker volume by side: yes-takers count as buy flow, no-takers as sell flow
        let buy_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::Yes).map(|t| t.volume).sum();
        let sell_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::No).map(|t| t.volume).sum();
        (buy_vol, sell_vol)
    }

    /// All markets whose close time has passed by `at` and that carry a
    /// recorded resolution, paired with that resolution.
    pub fn get_resolutions(&self, at: DateTime<Utc>) -> Vec<(&MarketData, MarketResult)> {
        self.markets
            .values()
            .filter_map(|m| {
                if m.close_time <= at {
                    m.result.map(|r| (m, r))
                } else {
                    None
                }
            })
            .collect()
    }

    /// Resolution for a single market, if it has closed by `at` and a result
    /// is recorded; `None` for unknown tickers, still-open markets, or
    /// closed-but-unresolved ones.
    pub fn get_resolution_at(&self, ticker: &str, at: DateTime<Utc>) -> Option<MarketResult> {
        self.markets.get(ticker).and_then(|m| {
            if m.close_time <= at {
                m.result
            } else {
                None
            }
        })
    }
}

/// Load the markets CSV into a ticker-keyed map.
///
/// Unrecognized or empty `result` strings are treated as "unresolved"
/// (`None`) rather than rejected, so partially-resolved datasets still load.
fn load_markets(path: &Path) -> Result<HashMap<String, MarketData>> {
    let mut reader = ReaderBuilder::new()
        .has_headers(true)
        .flexible(true)
        .from_path(path)?;

    let mut markets = HashMap::new();

    for result in reader.deserialize() {
        let record: CsvMarket = result?;
        // normalize the free-text result column; anything unrecognized
        // (including the empty string) maps to None / unresolved
        let result = record.result.as_ref().and_then(|r| match r.to_lowercase().as_str() {
            "yes" => Some(MarketResult::Yes),
            "no" => Some(MarketResult::No),
            "cancelled" | "canceled" => Some(MarketResult::Cancelled),
            _ => None,
        });

        markets.insert(
            record.ticker.clone(),
            MarketData {
                ticker: record.ticker,
                title: record.title,
                category: record.category,
                open_time: record.open_time,
                close_time: record.close_time,
                result,
            },
        );
    }

    Ok(markets)
}

/// Load the trades CSV, returned sorted ascending by timestamp.
///
/// Prices arrive in cents and are divided by 100 into decimal probabilities.
/// Rows with an unrecognized taker side are skipped; a price that fails the
/// decimal conversion falls back to zero rather than aborting the load.
fn load_trades(path: &Path) -> Result<Vec<TradeData>> {
    let mut reader = ReaderBuilder::new()
        .has_headers(true)
        .flexible(true)
        .from_path(path)?;

    let mut trades = Vec::new();

    for result in reader.deserialize() {
        let record: CsvTrade = result?;
        let side = match record.taker_side.to_lowercase().as_str() {
            "yes" | "buy" => Side::Yes,
            "no" | "sell" => Side::No,
            // unknown taker side: drop the row rather than guess a direction
            _ => continue,
        };

        trades.push(TradeData {
            timestamp: record.timestamp,
            ticker: record.ticker,
            price: 
Decimal::try_from(record.price / 100.0).unwrap_or(Decimal::ZERO), + volume: record.volume, + taker_side: side, + }); + } + + trades.sort_by_key(|t| t.timestamp); + + Ok(trades) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::TempDir; + + fn create_test_data() -> TempDir { + let dir = TempDir::new().unwrap(); + + let markets_csv = r#"ticker,title,category,open_time,close_time,result +TEST-MKT-1,Test Market 1,politics,2024-01-01 00:00:00,2024-01-15 00:00:00,yes +TEST-MKT-2,Test Market 2,economics,2024-01-01 00:00:00,2024-01-20 00:00:00,no +"#; + let mut f = std::fs::File::create(dir.path().join("markets.csv")).unwrap(); + f.write_all(markets_csv.as_bytes()).unwrap(); + + let trades_csv = r#"timestamp,ticker,price,volume,taker_side +2024-01-05 12:00:00,TEST-MKT-1,55,100,yes +2024-01-05 13:00:00,TEST-MKT-1,57,50,yes +2024-01-06 10:00:00,TEST-MKT-2,45,200,no +"#; + let mut f = std::fs::File::create(dir.path().join("trades.csv")).unwrap(); + f.write_all(trades_csv.as_bytes()).unwrap(); + + dir + } + + #[test] + fn test_load_historical_data() { + let dir = create_test_data(); + let data = HistoricalData::load(dir.path()).unwrap(); + + assert_eq!(data.markets.len(), 2); + assert_eq!(data.trades.len(), 3); + } +} diff --git a/src/data/mod.rs b/src/data/mod.rs new file mode 100644 index 0000000..57a9d73 --- /dev/null +++ b/src/data/mod.rs @@ -0,0 +1,3 @@ +mod loader; + +pub use loader::HistoricalData; diff --git a/src/execution.rs b/src/execution.rs new file mode 100644 index 0000000..75d4e30 --- /dev/null +++ b/src/execution.rs @@ -0,0 +1,325 @@ +use crate::data::HistoricalData; +use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext}; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct PositionSizingConfig { + pub kelly_fraction: f64, + pub max_position_pct: f64, + pub min_position_size: u64, + pub 
max_position_size: u64, +} + +impl Default for PositionSizingConfig { + fn default() -> Self { + Self { + kelly_fraction: 0.25, + max_position_pct: 0.25, + min_position_size: 10, + max_position_size: 1000, + } + } +} + +impl PositionSizingConfig { + pub fn conservative() -> Self { + Self { + kelly_fraction: 0.1, + max_position_pct: 0.1, + min_position_size: 10, + max_position_size: 500, + } + } + + pub fn aggressive() -> Self { + Self { + kelly_fraction: 0.5, + max_position_pct: 0.4, + min_position_size: 10, + max_position_size: 2000, + } + } +} + +/// maps scoring edge [-inf, +inf] to win probability [0, 1] +/// tanh squashes extreme values smoothly; +1)/2 shifts from [-1,1] to [0,1] +fn edge_to_win_probability(edge: f64) -> f64 { + (1.0 + edge.tanh()) / 2.0 +} + +fn kelly_size( + edge: f64, + price: f64, + bankroll: f64, + config: &PositionSizingConfig, +) -> u64 { + if edge.abs() < 0.01 || price <= 0.0 || price >= 1.0 { + return 0; + } + + let win_prob = edge_to_win_probability(edge); + let odds = (1.0 - price) / price; + + if odds <= 0.0 { + return 0; + } + + let kelly = (odds * win_prob - (1.0 - win_prob)) / odds; + let safe_kelly = (kelly * config.kelly_fraction).max(0.0); + let position_value = bankroll * safe_kelly.min(config.max_position_pct); + let shares = (position_value / price).floor() as u64; + + shares.max(config.min_position_size).min(config.max_position_size) +} + +pub struct Executor { + data: Arc, + slippage_bps: u32, + max_position_size: u64, + sizing_config: PositionSizingConfig, + exit_config: ExitConfig, +} + +impl Executor { + pub fn new(data: Arc, slippage_bps: u32, max_position_size: u64) -> Self { + Self { + data, + slippage_bps, + max_position_size, + sizing_config: PositionSizingConfig::default(), + exit_config: ExitConfig::default(), + } + } + + pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self { + self.sizing_config = config; + self + } + + pub fn with_exit_config(mut self, config: ExitConfig) -> Self { + 
self.exit_config = config;
        self
    }

    /// Scan all open positions and emit exit signals for any that trip the
    /// configured take-profit, stop-loss, time-stop, or score-reversal rules.
    ///
    /// `candidate_scores` maps ticker -> latest pipeline score; a position
    /// whose fresh score has fallen below the reversal threshold is flagged.
    /// Positions with no observable market price at the current timestamp are
    /// skipped. Rules are checked in priority order and at most one signal is
    /// emitted per position. `ExitSignal.current_price` carries the raw
    /// yes-price; callers side-adjust it when booking the close.
    pub fn generate_exit_signals(
        &self,
        context: &TradingContext,
        candidate_scores: &std::collections::HashMap<String, f64>,
    ) -> Vec<ExitSignal> {
        let mut exits = Vec::new();

        for (ticker, position) in &context.portfolio.positions {
            // no quote at this timestamp -> cannot evaluate exit conditions
            let current_price = match self.data.get_current_price(ticker, context.timestamp) {
                Some(p) => p,
                None => continue,
            };

            // convert the quoted yes-price into the price of the side we hold
            let effective_price = match position.side {
                Side::Yes => current_price,
                Side::No => Decimal::ONE - current_price,
            };

            let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
            let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);

            // guard against division by a non-positive entry price below
            if entry_price_f64 <= 0.0 {
                continue;
            }

            let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;

            if pnl_pct >= self.exit_config.take_profit_pct {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::TakeProfit { pnl_pct },
                    current_price,
                });
                continue;
            }

            if pnl_pct <= -self.exit_config.stop_loss_pct {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::StopLoss { pnl_pct },
                    current_price,
                });
                continue;
            }

            let hours_held = (context.timestamp - position.entry_time).num_hours();
            if hours_held >= self.exit_config.max_hold_hours {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::TimeStop { hours_held },
                    current_price,
                });
                continue;
            }

            // score reversal is only evaluated when the pipeline re-scored
            // this ticker in the current cycle
            if let Some(&new_score) = candidate_scores.get(ticker) {
                if new_score < self.exit_config.score_reversal_threshold {
                    exits.push(ExitSignal {
                        ticker: ticker.clone(),
                        reason: ExitReason::ScoreReversal { new_score },
                        current_price,
                    });
                }
            }
        }

        exits
    }

    /// Convert the pipeline's selected candidates into entry signals,
    /// dropping any candidate that does not yield a viable trade.
    pub fn generate_signals(
        &self,
        candidates: &[MarketCandidate],
        context: &TradingContext,
    ) -> Vec<Signal> {
        candidates
            .iter()
            .filter_map(|c| self.candidate_to_signal(c, context))
            .collect()
    }

    /// Size and direction for a single candidate, or `None` when the position
    /// cap is already reached, the score is exactly zero, the price is
    /// degenerate, or the Kelly-sized quantity falls below the minimum.
    fn candidate_to_signal(
        &self,
        candidate: &MarketCandidate,
        context: 
&TradingContext, + ) -> Option { + let current_position = context.portfolio.get_position(&candidate.ticker); + let current_qty = current_position.map(|p| p.quantity).unwrap_or(0); + + if current_qty >= self.max_position_size { + return None; + } + + let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5); + + // positive score = bullish signal, so buy the cheaper side (better risk/reward) + // negative score = bearish signal, so buy against the expensive side + let side = if candidate.final_score > 0.0 { + if yes_price < 0.5 { Side::Yes } else { Side::No } + } else if candidate.final_score < 0.0 { + if yes_price > 0.5 { Side::No } else { Side::Yes } + } else { + return None; + }; + + let price = match side { + Side::Yes => candidate.current_yes_price, + Side::No => candidate.current_no_price, + }; + + let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0); + let price_f64 = price.to_f64().unwrap_or(0.5); + + if price_f64 <= 0.0 { + return None; + } + + let kelly_qty = kelly_size( + candidate.final_score, + price_f64, + available_cash, + &self.sizing_config, + ); + + let max_affordable = (available_cash / price_f64) as u64; + let quantity = kelly_qty + .min(max_affordable) + .min(self.max_position_size - current_qty); + + if quantity < self.sizing_config.min_position_size { + return None; + } + + Some(Signal { + ticker: candidate.ticker.clone(), + side, + quantity, + limit_price: Some(price), + reason: format!( + "score={:.3}, side={:?}, price={:.2}", + candidate.final_score, side, price_f64 + ), + }) + } + + pub fn execute_signal( + &self, + signal: &Signal, + context: &TradingContext, + ) -> Option { + let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?; + + let effective_price = match signal.side { + Side::Yes => market_price, + Side::No => Decimal::ONE - market_price, + }; + + let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000); + let fill_price = effective_price * (Decimal::ONE + 
slippage); + + if let Some(limit) = signal.limit_price { + if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) { + return None; + } + } + + let cost = fill_price * Decimal::from(signal.quantity); + if cost > context.portfolio.cash { + let affordable = (context.portfolio.cash / fill_price) + .to_u64() + .unwrap_or(0); + if affordable == 0 { + return None; + } + + return Some(Fill { + ticker: signal.ticker.clone(), + side: signal.side, + quantity: affordable, + price: fill_price, + timestamp: context.timestamp, + }); + } + + Some(Fill { + ticker: signal.ticker.clone(), + side: signal.side, + quantity: signal.quantity, + price: fill_price, + timestamp: context.timestamp, + }) + } +} + +pub fn simple_signal_generator( + candidates: &[MarketCandidate], + context: &TradingContext, + position_size: u64, +) -> Vec { + candidates + .iter() + .filter(|c| c.final_score > 0.0) + .filter(|c| !context.portfolio.has_position(&c.ticker)) + .map(|c| { + let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5); + let (side, price) = if yes_price < 0.5 { + (Side::Yes, c.current_yes_price) + } else { + (Side::No, c.current_no_price) + }; + + Signal { + ticker: c.ticker.clone(), + side, + quantity: position_size, + limit_price: Some(price), + reason: format!("simple: score={:.3}", c.final_score), + } + }) + .collect() +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..f35108d --- /dev/null +++ b/src/main.rs @@ -0,0 +1,216 @@ +mod backtest; +mod data; +mod execution; +mod metrics; +mod pipeline; +mod types; + +use anyhow::{Context, Result}; +use backtest::{Backtester, RandomBaseline}; +use chrono::{DateTime, NaiveDate, TimeZone, Utc}; +use clap::{Parser, Subcommand}; +use data::HistoricalData; +use execution::{Executor, PositionSizingConfig}; +use rust_decimal::Decimal; +use std::path::PathBuf; +use std::sync::Arc; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use types::{BacktestConfig, 
ExitConfig}; + +#[derive(Parser)] +#[command(name = "kalshi-backtest")] +#[command(about = "backtesting framework for kalshi prediction markets")] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + Run { + #[arg(short, long, default_value = "data")] + data_dir: PathBuf, + + #[arg(long)] + start: String, + + #[arg(long)] + end: String, + + #[arg(long, default_value = "10000")] + capital: f64, + + #[arg(long, default_value = "100")] + max_position: u64, + + #[arg(long, default_value = "5")] + max_positions: usize, + + #[arg(long, default_value = "1")] + interval_hours: i64, + + #[arg(long, default_value = "results")] + output_dir: PathBuf, + + #[arg(long)] + compare_random: bool, + + #[arg(long, default_value = "0.25")] + kelly_fraction: f64, + + #[arg(long, default_value = "0.25")] + max_position_pct: f64, + + #[arg(long, default_value = "0.20")] + take_profit: f64, + + #[arg(long, default_value = "0.15")] + stop_loss: f64, + + #[arg(long, default_value = "72")] + max_hold_hours: i64, + }, + + Summary { + #[arg(short, long)] + results_file: PathBuf, + }, +} + +fn parse_date(s: &str) -> Result> { + if let Ok(dt) = DateTime::parse_from_rfc3339(s) { + return Ok(dt.with_timezone(&Utc)); + } + + if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") { + return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap())); + } + + Err(anyhow::anyhow!("could not parse date: {}", s)) +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "kalshi_backtest=info".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let cli = Cli::parse(); + + match cli.command { + Commands::Run { + data_dir, + start, + end, + capital, + max_position, + max_positions, + interval_hours, + output_dir, + compare_random, + kelly_fraction, + max_position_pct, + take_profit, + stop_loss, + max_hold_hours, + 
} => { + let start_time = parse_date(&start).context("parsing start date")?; + let end_time = parse_date(&end).context("parsing end date")?; + + info!( + data_dir = %data_dir.display(), + start = %start_time, + end = %end_time, + capital = capital, + "loading data" + ); + + let data = Arc::new( + HistoricalData::load(&data_dir).context("loading historical data")?, + ); + + info!( + markets = data.markets.len(), + trades = data.trades.len(), + "data loaded" + ); + + let config = BacktestConfig { + start_time, + end_time, + interval: chrono::Duration::hours(interval_hours), + initial_capital: Decimal::try_from(capital).unwrap(), + max_position_size: max_position, + max_positions, + }; + + let sizing_config = PositionSizingConfig { + kelly_fraction, + max_position_pct, + min_position_size: 10, + max_position_size: max_position, + }; + + let exit_config = ExitConfig { + take_profit_pct: take_profit, + stop_loss_pct: stop_loss, + max_hold_hours, + score_reversal_threshold: -0.3, + }; + + let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config); + let result = backtester.run().await; + + println!("{}", result.summary()); + + std::fs::create_dir_all(&output_dir)?; + let result_path = output_dir.join("backtest_result.json"); + let json = serde_json::to_string_pretty(&result)?; + std::fs::write(&result_path, json)?; + info!(path = %result_path.display(), "results saved"); + + if compare_random { + println!("\n--- random baseline ---\n"); + let baseline = RandomBaseline::new(config, data); + let baseline_result = baseline.run().await; + println!("{}", baseline_result.summary()); + + let baseline_path = output_dir.join("baseline_result.json"); + let json = serde_json::to_string_pretty(&baseline_result)?; + std::fs::write(&baseline_path, json)?; + + println!("\n--- comparison ---\n"); + println!( + "strategy return: {:.2}% vs baseline: {:.2}%", + result.total_return_pct, baseline_result.total_return_pct + ); + println!( + "strategy 
sharpe: {:.3} vs baseline: {:.3}", + result.sharpe_ratio, baseline_result.sharpe_ratio + ); + println!( + "strategy win rate: {:.1}% vs baseline: {:.1}%", + result.win_rate, baseline_result.win_rate + ); + } + + Ok(()) + } + + Commands::Summary { results_file } => { + let content = std::fs::read_to_string(&results_file) + .context("reading results file")?; + let result: metrics::BacktestResult = + serde_json::from_str(&content).context("parsing results")?; + + println!("{}", result.summary()); + + Ok(()) + } + } +} diff --git a/src/metrics.rs b/src/metrics.rs new file mode 100644 index 0000000..60b6c9a --- /dev/null +++ b/src/metrics.rs @@ -0,0 +1,280 @@ +use crate::types::{Portfolio, Trade, TradeType}; +use chrono::{DateTime, Utc}; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BacktestResult { + pub total_return: f64, + pub total_return_pct: f64, + pub sharpe_ratio: f64, + pub max_drawdown: f64, + pub max_drawdown_pct: f64, + pub win_rate: f64, + pub total_trades: usize, + pub winning_trades: usize, + pub losing_trades: usize, + pub avg_trade_pnl: f64, + pub avg_hold_time_hours: f64, + pub trades_per_day: f64, + pub return_by_category: HashMap, + pub equity_curve: Vec, + pub trade_log: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EquityPoint { + pub timestamp: DateTime, + pub equity: f64, + pub cash: f64, + pub positions_value: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TradeRecord { + pub ticker: String, + pub entry_time: DateTime, + pub exit_time: Option>, + pub side: String, + pub quantity: u64, + pub entry_price: f64, + pub exit_price: Option, + pub pnl: Option, + pub category: String, +} + +pub struct MetricsCollector { + initial_capital: Decimal, + equity_curve: Vec, + trade_records: HashMap, + closed_trades: Vec, + daily_returns: Vec, + last_equity: 
f64,
    peak_equity: f64,
    max_drawdown: f64,
}

impl MetricsCollector {
    /// Start collecting from `initial_capital`; the peak- and last-equity
    /// trackers begin at the starting capital so the first sample produces
    /// no spurious drawdown or return.
    pub fn new(initial_capital: Decimal) -> Self {
        let capital = initial_capital.to_f64().unwrap_or(10000.0);
        Self {
            initial_capital,
            equity_curve: Vec::new(),
            trade_records: HashMap::new(),
            closed_trades: Vec::new(),
            daily_returns: Vec::new(),
            last_equity: capital,
            peak_equity: capital,
            max_drawdown: 0.0,
        }
    }

    /// Record one equity-curve sample: marks open positions to
    /// `market_prices` (falling back to entry price when no quote exists),
    /// updates the running peak/drawdown, and appends the per-step return.
    ///
    /// NOTE(review): `daily_returns` actually holds one return per sampling
    /// interval, not per calendar day — the sqrt(252) annualization in
    /// `finalize` is only accurate for a daily interval; confirm intent.
    pub fn record(
        &mut self,
        timestamp: DateTime<Utc>,
        portfolio: &Portfolio,
        market_prices: &HashMap<String, Decimal>,
    ) {
        let positions_value = portfolio
            .positions
            .values()
            .map(|p| {
                // positions with no quote are carried at their entry price
                let price = market_prices
                    .get(&p.ticker)
                    .copied()
                    .unwrap_or(p.avg_entry_price);
                (price * Decimal::from(p.quantity)).to_f64().unwrap_or(0.0)
            })
            .sum();

        let cash = portfolio.cash.to_f64().unwrap_or(0.0);
        let equity = cash + positions_value;

        if equity > self.peak_equity {
            self.peak_equity = equity;
        }

        // drawdown is measured as a fraction of the running peak
        let drawdown = (self.peak_equity - equity) / self.peak_equity;
        if drawdown > self.max_drawdown {
            self.max_drawdown = drawdown;
        }

        if self.last_equity > 0.0 {
            let daily_return = (equity - self.last_equity) / self.last_equity;
            self.daily_returns.push(daily_return);
        }
        self.last_equity = equity;

        self.equity_curve.push(EquityPoint {
            timestamp,
            equity,
            cash,
            positions_value,
        });
    }

    /// Track a trade: opens create a pending record keyed by ticker; closes
    /// and resolutions pop the pending record and realize its P&L.
    ///
    /// NOTE(review): keying pending records by ticker means a second open on
    /// the same ticker overwrites the first before it closes — confirm "at
    /// most one open position per ticker" is an invariant upstream.
    pub fn record_trade(&mut self, trade: &Trade, category: &str) {
        match trade.trade_type {
            TradeType::Open => {
                let record = TradeRecord {
                    ticker: trade.ticker.clone(),
                    entry_time: trade.timestamp,
                    exit_time: None,
                    side: format!("{:?}", trade.side),
                    quantity: trade.quantity,
                    entry_price: trade.price.to_f64().unwrap_or(0.0),
                    exit_price: None,
                    pnl: None,
                    category: category.to_string(),
                };
                self.trade_records.insert(trade.ticker.clone(), record);
            }
            TradeType::Close | TradeType::Resolution => {
                if let Some(mut record) = self.trade_records.remove(&trade.ticker) {
                    let exit_price = trade.price.to_f64().unwrap_or(0.0);
                    
let entry_cost = record.entry_price * record.quantity as f64; + let exit_value = exit_price * record.quantity as f64; + let pnl = exit_value - entry_cost; + + record.exit_time = Some(trade.timestamp); + record.exit_price = Some(exit_price); + record.pnl = Some(pnl); + + self.closed_trades.push(record); + } + } + } + } + + pub fn finalize(self) -> BacktestResult { + let initial = self.initial_capital.to_f64().unwrap_or(10000.0); + let final_equity = self.equity_curve.last().map(|e| e.equity).unwrap_or(initial); + let total_return = final_equity - initial; + let total_return_pct = total_return / initial * 100.0; + + let sharpe_ratio = if self.daily_returns.len() > 1 { + let mean: f64 = self.daily_returns.iter().sum::() / self.daily_returns.len() as f64; + let variance: f64 = self + .daily_returns + .iter() + .map(|r| (r - mean).powi(2)) + .sum::() + / (self.daily_returns.len() - 1) as f64; + let std_dev = variance.sqrt(); + if std_dev > 0.0 { + (mean / std_dev) * (252.0_f64).sqrt() + } else { + 0.0 + } + } else { + 0.0 + }; + + let winning_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) > 0.0).count(); + let losing_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) < 0.0).count(); + let total_trades = self.closed_trades.len(); + + let win_rate = if total_trades > 0 { + winning_trades as f64 / total_trades as f64 * 100.0 + } else { + 0.0 + }; + + let avg_trade_pnl = if total_trades > 0 { + self.closed_trades.iter().filter_map(|t| t.pnl).sum::() / total_trades as f64 + } else { + 0.0 + }; + + let avg_hold_time_hours = if total_trades > 0 { + self.closed_trades + .iter() + .filter_map(|t| { + t.exit_time.map(|exit| (exit - t.entry_time).num_hours() as f64) + }) + .sum::() + / total_trades as f64 + } else { + 0.0 + }; + + let duration_days = if self.equity_curve.len() >= 2 { + let start = self.equity_curve.first().unwrap().timestamp; + let end = self.equity_curve.last().unwrap().timestamp; + (end - start).num_days().max(1) as f64 + } 
else { + 1.0 + }; + + let trades_per_day = total_trades as f64 / duration_days; + + let mut return_by_category: HashMap = HashMap::new(); + for trade in &self.closed_trades { + *return_by_category.entry(trade.category.clone()).or_insert(0.0) += + trade.pnl.unwrap_or(0.0); + } + + BacktestResult { + total_return, + total_return_pct, + sharpe_ratio, + max_drawdown: self.max_drawdown * 100.0, + max_drawdown_pct: self.max_drawdown * 100.0, + win_rate, + total_trades, + winning_trades, + losing_trades, + avg_trade_pnl, + avg_hold_time_hours, + trades_per_day, + return_by_category, + equity_curve: self.equity_curve, + trade_log: self.closed_trades, + } + } +} + +impl BacktestResult { + pub fn summary(&self) -> String { + format!( + r#" +backtest results +================ + +performance +----------- +total return: ${:.2} ({:.2}%) +sharpe ratio: {:.3} +max drawdown: {:.2}% + +trades +------ +total trades: {} +win rate: {:.1}% +avg trade pnl: ${:.2} +avg hold time: {:.1} hours +trades per day: {:.2} + +by category +----------- +{} +"#, + self.total_return, + self.total_return_pct, + self.sharpe_ratio, + self.max_drawdown_pct, + self.total_trades, + self.win_rate, + self.avg_trade_pnl, + self.avg_hold_time_hours, + self.trades_per_day, + self.return_by_category + .iter() + .map(|(k, v)| format!(" {}: ${:.2}", k, v)) + .collect::>() + .join("\n") + ) + } +} diff --git a/src/pipeline/correlation_scorer.rs b/src/pipeline/correlation_scorer.rs new file mode 100644 index 0000000..4d9aea9 --- /dev/null +++ b/src/pipeline/correlation_scorer.rs @@ -0,0 +1,259 @@ +//! Cross-market correlation scorer +//! +//! Uses lead-lag relationships between related markets to generate signals. +//! When a related market moves, we expect similar movements in correlated markets. 
+ +use crate::pipeline::Scorer; +use crate::types::{MarketCandidate, TradingContext}; +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +/// correlation entry between two markets +#[derive(Debug, Clone)] +pub struct CorrelationEntry { + pub ticker_a: String, + pub ticker_b: String, + pub correlation: f64, + pub lag_hours: i64, +} + +/// cross-market correlation scorer +/// uses precomputed correlations to generate signals based on related market movements +pub struct CorrelationScorer { + correlations: Arc>>>, + lookback_hours: i64, + min_correlation: f64, +} + +impl CorrelationScorer { + pub fn new(lookback_hours: i64, min_correlation: f64) -> Self { + Self { + correlations: Arc::new(RwLock::new(HashMap::new())), + lookback_hours, + min_correlation, + } + } + + pub fn default_config() -> Self { + Self::new(24, 0.5) + } + + pub fn load_correlations(&self, correlations: Vec) { + let mut map = self.correlations.write().unwrap(); + map.clear(); + + for entry in correlations { + map.entry(entry.ticker_a.clone()) + .or_insert_with(Vec::new) + .push(entry.clone()); + map.entry(entry.ticker_b.clone()) + .or_insert_with(Vec::new) + .push(CorrelationEntry { + ticker_a: entry.ticker_b.clone(), + ticker_b: entry.ticker_a.clone(), + correlation: entry.correlation, + lag_hours: -entry.lag_hours, + }); + } + } + + pub fn add_category_correlations(&self, categories: &[&str]) { + let mut entries = Vec::new(); + + for (i, cat_a) in categories.iter().enumerate() { + for cat_b in categories.iter().skip(i + 1) { + entries.push(CorrelationEntry { + ticker_a: cat_a.to_string(), + ticker_b: cat_b.to_string(), + correlation: 0.3, + lag_hours: 0, + }); + } + } + + self.load_correlations(entries); + } + + fn get_related_signals( + &self, + ticker: &str, + all_candidates: &[MarketCandidate], + ) -> f64 { + let correlations = self.correlations.read().unwrap(); + let Some(related) = correlations.get(ticker) else { + return 0.0; + }; + + let 
candidate_map: HashMap<&str, &MarketCandidate> = all_candidates + .iter() + .map(|c| (c.ticker.as_str(), c)) + .collect(); + + let mut weighted_signal = 0.0; + let mut total_weight = 0.0; + + for entry in related { + if entry.correlation.abs() < self.min_correlation { + continue; + } + + if let Some(related_candidate) = candidate_map.get(entry.ticker_b.as_str()) { + let related_momentum = related_candidate + .scores + .get("momentum") + .copied() + .unwrap_or(0.0); + + let signal = related_momentum * entry.correlation; + weighted_signal += signal * entry.correlation.abs(); + total_weight += entry.correlation.abs(); + } + } + + if total_weight > 0.0 { + weighted_signal / total_weight + } else { + 0.0 + } + } + + fn calculate_category_correlation( + candidate: &MarketCandidate, + all_candidates: &[MarketCandidate], + ) -> f64 { + let same_category: Vec<&MarketCandidate> = all_candidates + .iter() + .filter(|c| c.category == candidate.category && c.ticker != candidate.ticker) + .collect(); + + if same_category.is_empty() { + return 0.0; + } + + let avg_momentum: f64 = same_category + .iter() + .filter_map(|c| c.scores.get("momentum").copied()) + .sum::() + / same_category.len() as f64; + + avg_momentum * 0.3 + } +} + +#[async_trait] +impl Scorer for CorrelationScorer { + fn name(&self) -> &'static str { + "CorrelationScorer" + } + + async fn score( + &self, + _context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + let scored = candidates + .iter() + .map(|c| { + let related_signal = self.get_related_signals(&c.ticker, candidates); + let category_signal = Self::calculate_category_correlation(c, candidates); + let combined = related_signal * 0.7 + category_signal * 0.3; + + let mut scored = MarketCandidate { + scores: c.scores.clone(), + ..Default::default() + }; + scored.scores.insert("cross_market".to_string(), combined); + scored.scores.insert("related_signal".to_string(), related_signal); + 
scored.scores.insert("category_signal".to_string(), category_signal); + scored + }) + .collect(); + + Ok(scored) + } + + fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) { + for key in ["cross_market", "related_signal", "category_signal"] { + if let Some(score) = scored.scores.get(key) { + candidate.scores.insert(key.to_string(), *score); + } + } + } +} + +/// granger causality calculator (simplified version) +/// full implementation would use statistical tests +pub fn calculate_granger_causality( + prices_a: &[f64], + prices_b: &[f64], + max_lag: i64, +) -> Option<(f64, i64)> { + if prices_a.len() < 20 || prices_b.len() < 20 { + return None; + } + + let min_len = prices_a.len().min(prices_b.len()); + let prices_a = &prices_a[..min_len]; + let prices_b = &prices_b[..min_len]; + + let returns_a: Vec = prices_a + .windows(2) + .map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 }) + .collect(); + let returns_b: Vec = prices_b + .windows(2) + .map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 }) + .collect(); + + let mut best_corr: f64 = 0.0; + let mut best_lag: i64 = 0; + + for lag in -max_lag..=max_lag { + let (a_slice, b_slice) = if lag >= 0 { + let l = lag as usize; + if l >= returns_a.len() || l >= returns_b.len() { + continue; + } + (&returns_a[l..], &returns_b[..returns_b.len() - l]) + } else { + let l = (-lag) as usize; + if l >= returns_a.len() || l >= returns_b.len() { + continue; + } + (&returns_a[..returns_a.len() - l], &returns_b[l..]) + }; + + if a_slice.len() < 10 || b_slice.len() < 10 || a_slice.len() != b_slice.len() { + continue; + } + + let n = a_slice.len() as f64; + let mean_a: f64 = a_slice.iter().sum::() / n; + let mean_b: f64 = b_slice.iter().sum::() / n; + + let cov: f64 = a_slice + .iter() + .zip(b_slice.iter()) + .map(|(a, b)| (a - mean_a) * (b - mean_b)) + .sum::() + / n; + + let std_a: f64 = (a_slice.iter().map(|a| (a - mean_a).powi(2)).sum::() / n).sqrt(); + let std_b: f64 = (b_slice.iter().map(|b| (b 
- mean_b).powi(2)).sum::() / n).sqrt(); + + if std_a > 0.0 && std_b > 0.0 { + let corr = cov / (std_a * std_b); + if corr.abs() > best_corr.abs() { + best_corr = corr; + best_lag = lag; + } + } + } + + if best_corr.abs() > 0.1 { + Some((best_corr, best_lag)) + } else { + None + } +} diff --git a/src/pipeline/filters.rs b/src/pipeline/filters.rs new file mode 100644 index 0000000..4781173 --- /dev/null +++ b/src/pipeline/filters.rs @@ -0,0 +1,182 @@ +use crate::pipeline::{Filter, FilterResult}; +use crate::types::{MarketCandidate, TradingContext}; +use async_trait::async_trait; +use chrono::Duration; +use std::collections::HashSet; + +pub struct LiquidityFilter { + min_volume_24h: u64, +} + +impl LiquidityFilter { + pub fn new(min_volume_24h: u64) -> Self { + Self { min_volume_24h } + } +} + +#[async_trait] +impl Filter for LiquidityFilter { + fn name(&self) -> &'static str { + "LiquidityFilter" + } + + async fn filter( + &self, + _context: &TradingContext, + candidates: Vec, + ) -> Result { + let (kept, removed): (Vec<_>, Vec<_>) = candidates + .into_iter() + .partition(|c| c.volume_24h >= self.min_volume_24h); + + Ok(FilterResult { kept, removed }) + } +} + +pub struct TimeToCloseFilter { + min_hours: i64, + max_hours: Option, +} + +impl TimeToCloseFilter { + pub fn new(min_hours: i64, max_hours: Option) -> Self { + Self { min_hours, max_hours } + } +} + +#[async_trait] +impl Filter for TimeToCloseFilter { + fn name(&self) -> &'static str { + "TimeToCloseFilter" + } + + async fn filter( + &self, + context: &TradingContext, + candidates: Vec, + ) -> Result { + let min_duration = Duration::hours(self.min_hours); + let max_duration = self.max_hours.map(Duration::hours); + + let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| { + let ttc = c.time_to_close(context.timestamp); + let above_min = ttc >= min_duration; + let below_max = max_duration.map(|max| ttc <= max).unwrap_or(true); + above_min && below_max + }); + + Ok(FilterResult { kept, 
removed })
    }
}

/// drops markets where the portfolio already holds the per-market cap
pub struct AlreadyPositionedFilter {
    max_position_per_market: u64,
}

impl AlreadyPositionedFilter {
    pub fn new(max_position_per_market: u64) -> Self {
        Self {
            max_position_per_market,
        }
    }
}

#[async_trait]
impl Filter for AlreadyPositionedFilter {
    fn name(&self) -> &'static str {
        "AlreadyPositionedFilter"
    }

    async fn filter(
        &self,
        context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Result<FilterResult, String> {
        let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
            // tickers without an existing position always pass
            context
                .portfolio
                .get_position(&c.ticker)
                .map(|p| p.quantity < self.max_position_per_market)
                .unwrap_or(true)
        });

        Ok(FilterResult { kept, removed })
    }
}

/// whitelist/blacklist filter on market category
pub struct CategoryFilter {
    whitelist: Option<HashSet<String>>,
    blacklist: HashSet<String>,
}

impl CategoryFilter {
    /// keep only these categories
    pub fn whitelist(categories: Vec<String>) -> Self {
        Self {
            whitelist: Some(categories.into_iter().collect()),
            blacklist: HashSet::new(),
        }
    }

    /// keep everything except these categories
    pub fn blacklist(categories: Vec<String>) -> Self {
        Self {
            whitelist: None,
            blacklist: categories.into_iter().collect(),
        }
    }
}

#[async_trait]
impl Filter for CategoryFilter {
    fn name(&self) -> &'static str {
        "CategoryFilter"
    }

    async fn filter(
        &self,
        _context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Result<FilterResult, String> {
        let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
            // an absent whitelist admits every category
            let in_whitelist = self
                .whitelist
                .as_ref()
                .map(|w| w.contains(&c.category))
                .unwrap_or(true);
            let not_blacklisted = !self.blacklist.contains(&c.category);
            in_whitelist && not_blacklisted
        });

        Ok(FilterResult { kept, removed })
    }
}

/// keeps markets whose current YES price lies within [min_price, max_price]
pub struct PriceRangeFilter {
    min_price: f64,
    max_price: f64,
}

impl PriceRangeFilter {
    pub fn new(min_price: f64, max_price: f64) -> Self {
        Self { min_price, max_price }
    }
}

#[async_trait]
impl Filter for PriceRangeFilter {
    fn name(&self) -> &'static str {
        "PriceRangeFilter"
    }

    async fn filter(
        &self,
        _context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Result<FilterResult, String> {
        let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
            // convert the Decimal price directly (as the rest of the codebase
            // does) instead of round-tripping through a string; falls back to
            // 0.5 on the (unlikely) conversion failure, as before
            let price = rust_decimal::prelude::ToPrimitive::to_f64(&c.current_yes_price)
                .unwrap_or(0.5);
            price >= self.min_price && price <= self.max_price
        });

        Ok(FilterResult { kept, removed })
    }
}
diff --git a/src/pipeline/ml_scorer.rs b/src/pipeline/ml_scorer.rs
new file mode 100644
index 0000000..3b86410
--- /dev/null
+++ b/src/pipeline/ml_scorer.rs
@@ -0,0 +1,272 @@
//! ML-based scorer using ONNX models
//!
//! This module provides ML-based scoring using pre-trained ONNX models.
//! Models are trained separately using the Python scripts in scripts/
//!
//! Enable with the `ml` feature:
//! ```toml
//! kalshi-backtest = { features = ["ml"] }
//! ```

use crate::pipeline::Scorer;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use std::path::Path;

#[cfg(feature = "ml")]
use {
    ndarray::{Array1, Array2},
    ort::{session::Session, value::Value},
    std::sync::Arc,
};

/// ML ensemble scorer that combines multiple ONNX models
///
/// Models:
/// - LSTM: sequence model for price history
/// - MLP: feedforward on hand-crafted features
#[cfg(feature = "ml")]
pub struct MLEnsembleScorer {
    lstm_session: Option<Arc<Session>>,
    mlp_session: Option<Arc<Session>>,
    ensemble_weights: Vec<f64>,
}

#[cfg(feature = "ml")]
impl MLEnsembleScorer {
    /// empty ensemble; both models absent until loaded
    pub fn new() -> Self {
        Self {
            lstm_session: None,
            mlp_session: None,
            ensemble_weights: vec![0.5, 0.5],
        }
    }

    /// Loads `lstm.onnx` / `mlp.onnx` from `model_dir`. Either file may be
    /// missing, in which case that model simply contributes nothing.
    pub fn load_models(model_dir: &Path) -> Result<Self, String> {
        let lstm_path = model_dir.join("lstm.onnx");
        let mlp_path = model_dir.join("mlp.onnx");

        let lstm_session = if lstm_path.exists() {
            Some(Arc::new(
                Session::builder()
                    .map_err(|e| format!("failed to create session builder: {}", e))?
+ .commit_from_file(&lstm_path) + .map_err(|e| format!("failed to load LSTM model: {}", e))?, + )) + } else { + None + }; + + let mlp_session = if mlp_path.exists() { + Some(Arc::new( + Session::builder() + .map_err(|e| format!("failed to create session builder: {}", e))? + .commit_from_file(&mlp_path) + .map_err(|e| format!("failed to load MLP model: {}", e))?, + )) + } else { + None + }; + + Ok(Self { + lstm_session, + mlp_session, + ensemble_weights: vec![0.5, 0.5], + }) + } + + pub fn with_weights(mut self, weights: Vec) -> Self { + self.ensemble_weights = weights; + self + } + + fn extract_features(candidate: &MarketCandidate) -> Vec { + vec![ + candidate.scores.get("momentum").copied().unwrap_or(0.0), + candidate.scores.get("mean_reversion").copied().unwrap_or(0.0), + candidate.scores.get("volume").copied().unwrap_or(0.0), + candidate.scores.get("time_decay").copied().unwrap_or(0.0), + candidate.scores.get("order_flow").copied().unwrap_or(0.0), + candidate.scores.get("bollinger_reversion").copied().unwrap_or(0.0), + candidate.scores.get("mtf_momentum").copied().unwrap_or(0.0), + ] + } + + fn extract_price_sequence(candidate: &MarketCandidate, max_len: usize) -> Vec { + use rust_decimal::prelude::ToPrimitive; + + let prices: Vec = candidate + .price_history + .iter() + .rev() + .take(max_len) + .filter_map(|p| p.yes_price.to_f64()) + .collect(); + + let mut sequence = vec![0.0; max_len]; + for (i, &price) in prices.iter().enumerate() { + if i < max_len { + sequence[max_len - 1 - i] = price; + } + } + + if prices.len() >= 2 { + let mut log_returns = Vec::with_capacity(max_len); + for i in 1..sequence.len() { + if sequence[i - 1] > 0.0 && sequence[i] > 0.0 { + log_returns.push((sequence[i] / sequence[i - 1]).ln()); + } else { + log_returns.push(0.0); + } + } + log_returns.insert(0, 0.0); + log_returns + } else { + sequence + } + } + + fn predict_lstm(&self, sequence: &[f64]) -> f64 { + let Some(session) = &self.lstm_session else { + return 0.0; + }; + + let 
input = Array2::from_shape_vec((1, sequence.len()), sequence.to_vec()) + .expect("invalid shape"); + + match session.run(ort::inputs!["input" => input.view()]) { + Ok(outputs) => { + if let Some(output) = outputs.get("output") { + if let Ok(tensor) = output.try_extract_tensor::() { + return tensor.view().iter().next().copied().unwrap_or(0.0) as f64; + } + } + 0.0 + } + Err(_) => 0.0, + } + } + + fn predict_mlp(&self, features: &[f64]) -> f64 { + let Some(session) = &self.mlp_session else { + return 0.0; + }; + + let input = Array1::from_vec(features.to_vec()); + let input_2d = input.insert_axis(ndarray::Axis(0)); + + match session.run(ort::inputs!["input" => input_2d.view()]) { + Ok(outputs) => { + if let Some(output) = outputs.get("output") { + if let Ok(tensor) = output.try_extract_tensor::() { + return tensor.view().iter().next().copied().unwrap_or(0.0) as f64; + } + } + 0.0 + } + Err(_) => 0.0, + } + } + + fn predict(&self, candidate: &MarketCandidate) -> f64 { + let sequence = Self::extract_price_sequence(candidate, 24); + let features = Self::extract_features(candidate); + + let lstm_pred = self.predict_lstm(&sequence); + let mlp_pred = self.predict_mlp(&features); + + let mut ensemble = 0.0; + let mut total_weight = 0.0; + + if self.lstm_session.is_some() && self.ensemble_weights.len() > 0 { + ensemble += lstm_pred * self.ensemble_weights[0]; + total_weight += self.ensemble_weights[0]; + } + + if self.mlp_session.is_some() && self.ensemble_weights.len() > 1 { + ensemble += mlp_pred * self.ensemble_weights[1]; + total_weight += self.ensemble_weights[1]; + } + + if total_weight > 0.0 { + ensemble / total_weight + } else { + 0.0 + } + } +} + +#[cfg(feature = "ml")] +#[async_trait] +impl Scorer for MLEnsembleScorer { + fn name(&self) -> &'static str { + "MLEnsembleScorer" + } + + async fn score( + &self, + _context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + let scored = candidates + .iter() + .map(|c| { + let ml_score = 
self.predict(c); + let mut scored = MarketCandidate { + scores: c.scores.clone(), + ..Default::default() + }; + scored.scores.insert("ml_ensemble".to_string(), ml_score); + scored + }) + .collect(); + + Ok(scored) + } + + fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) { + if let Some(score) = scored.scores.get("ml_ensemble") { + candidate.scores.insert("ml_ensemble".to_string(), *score); + } + } +} + +/// stub scorer when ML feature is disabled +#[cfg(not(feature = "ml"))] +pub struct MLEnsembleScorer; + +#[cfg(not(feature = "ml"))] +impl MLEnsembleScorer { + pub fn new() -> Self { + Self + } + + pub fn load_models(_model_dir: &Path) -> Result { + Ok(Self) + } + + pub fn with_weights(self, _weights: Vec) -> Self { + self + } +} + +#[cfg(not(feature = "ml"))] +#[async_trait] +impl Scorer for MLEnsembleScorer { + fn name(&self) -> &'static str { + "MLEnsembleScorer" + } + + async fn score( + &self, + _context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + Ok(candidates.iter().map(|c| MarketCandidate { + scores: c.scores.clone(), + ..Default::default() + }).collect()) + } + + fn update(&self, _candidate: &mut MarketCandidate, _scored: MarketCandidate) {} +} diff --git a/src/pipeline/mod.rs b/src/pipeline/mod.rs new file mode 100644 index 0000000..1e5cbc3 --- /dev/null +++ b/src/pipeline/mod.rs @@ -0,0 +1,230 @@ +mod correlation_scorer; +mod filters; +mod ml_scorer; +mod scorers; +mod selector; +mod sources; + +pub use correlation_scorer::*; +pub use filters::*; +pub use ml_scorer::*; +pub use scorers::*; +pub use selector::*; +pub use sources::*; + +use crate::types::{MarketCandidate, TradingContext}; +use async_trait::async_trait; +use std::sync::Arc; +use tracing::{error, info}; + +pub struct PipelineResult { + pub retrieved_candidates: Vec, + pub filtered_candidates: Vec, + pub selected_candidates: Vec, + pub context: Arc, +} + +pub struct FilterResult { + pub kept: Vec, + pub removed: Vec, +} + 
#[async_trait]
pub trait Source: Send + Sync {
    fn name(&self) -> &'static str;
    /// whether this source should run for the given context (default: yes)
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    async fn get_candidates(
        &self,
        context: &TradingContext,
    ) -> Result<Vec<MarketCandidate>, String>;
}

#[async_trait]
pub trait Filter: Send + Sync {
    fn name(&self) -> &'static str;
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    async fn filter(
        &self,
        context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Result<FilterResult, String>;
}

#[async_trait]
pub trait Scorer: Send + Sync {
    fn name(&self) -> &'static str;
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    /// must return one scored candidate per input, in the same order
    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String>;

    /// merges one scored result back into the live candidate
    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate);

    /// pairwise merge; relies on `score` preserving order and length
    fn update_all(&self, candidates: &mut [MarketCandidate], scored: Vec<MarketCandidate>) {
        for (c, s) in candidates.iter_mut().zip(scored) {
            self.update(c, s);
        }
    }
}

pub trait Selector: Send + Sync {
    fn name(&self) -> &'static str;
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    fn select(
        &self,
        context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Vec<MarketCandidate>;
}

/// source -> filter -> score -> select candidate pipeline
pub struct TradingPipeline {
    sources: Vec<Box<dyn Source>>,
    filters: Vec<Box<dyn Filter>>,
    scorers: Vec<Box<dyn Scorer>>,
    selector: Box<dyn Selector>,
    result_size: usize,
}

impl TradingPipeline {
    pub fn new(
        sources: Vec<Box<dyn Source>>,
        filters: Vec<Box<dyn Filter>>,
        scorers: Vec<Box<dyn Scorer>>,
        selector: Box<dyn Selector>,
        result_size: usize,
    ) -> Self {
        Self {
            sources,
            filters,
            scorers,
            selector,
            result_size,
        }
    }

    /// Runs the full pipeline: fetch from all sources, apply filters in
    /// order, run every scorer, then select and truncate to `result_size`.
    pub async fn execute(&self, context: TradingContext) -> PipelineResult {
        let request_id = context.request_id().to_string();

        let candidates = self.fetch_candidates(&context).await;
        info!(
            request_id = %request_id,
            candidates = candidates.len(),
            "fetched candidates"
        );

        let (kept, filtered) = self.filter(&context, candidates.clone()).await;
        info!(
            request_id = %request_id,
            kept = kept.len(),
            filtered = filtered.len(),
            "filtered candidates"
        );

        let scored = self.score(&context, kept).await;

        let mut selected = self.select(&context, scored);
        selected.truncate(self.result_size);

        info!(
            request_id = %request_id,
            selected = selected.len(),
            "selected candidates"
        );

        PipelineResult {
            retrieved_candidates: candidates,
            filtered_candidates: filtered,
            selected_candidates: selected,
            context: Arc::new(context),
        }
    }

    /// Gathers candidates from every enabled source; a failing source is
    /// logged and skipped rather than failing the whole run.
    async fn fetch_candidates(&self, context: &TradingContext) -> Vec<MarketCandidate> {
        let mut all_candidates = Vec::new();

        for source in self.sources.iter().filter(|s| s.enable(context)) {
            match source.get_candidates(context).await {
                Ok(mut candidates) => {
                    info!(
                        source = source.name(),
                        count = candidates.len(),
                        "source returned candidates"
                    );
                    all_candidates.append(&mut candidates);
                }
                Err(e) => {
                    error!(source = source.name(), error = %e, "source failed");
                }
            }
        }

        all_candidates
    }

    /// Applies every enabled filter in sequence, accumulating removals.
    /// On filter error the candidate set is restored from a pre-call backup.
    async fn filter(
        &self,
        context: &TradingContext,
        mut candidates: Vec<MarketCandidate>,
    ) -> (Vec<MarketCandidate>, Vec<MarketCandidate>) {
        let mut all_removed = Vec::new();

        for filter in self.filters.iter().filter(|f| f.enable(context)) {
            // NOTE(review): the backup clone runs even on the success path;
            // fine for backtests, worth revisiting for live-sized batches.
            let backup = candidates.clone();
            match filter.filter(context, candidates).await {
                Ok(result) => {
                    info!(
                        filter = filter.name(),
                        kept = result.kept.len(),
                        removed = result.removed.len(),
                        "filter applied"
                    );
                    candidates = result.kept;
                    all_removed.extend(result.removed);
                }
                Err(e) => {
                    error!(filter = filter.name(), error = %e, "filter failed");
                    candidates = backup;
                }
            }
        }

        (candidates, all_removed)
    }

    /// Runs each enabled scorer and merges its scores back in; a scorer
    /// returning the wrong number of results is logged and ignored.
    async fn score(
        &self,
        context: &TradingContext,
        mut candidates: Vec<MarketCandidate>,
    ) -> Vec<MarketCandidate> {
        let expected_len = candidates.len();

        for scorer in self.scorers.iter().filter(|s| s.enable(context)) {
            match scorer.score(context, &candidates).await {
                Ok(scored) => {
                    if scored.len() == expected_len {
                        scorer.update_all(&mut candidates, scored);
                    } else {
                        error!(
                            scorer = scorer.name(),
                            expected =
expected_len,
                            got = scored.len(),
                            "scorer returned wrong number of candidates"
                        );
                    }
                }
                Err(e) => {
                    error!(scorer = scorer.name(), error = %e, "scorer failed");
                }
            }
        }

        candidates
    }

    /// Delegates to the selector when enabled; otherwise passes through.
    fn select(
        &self,
        context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Vec<MarketCandidate> {
        if self.selector.enable(context) {
            self.selector.select(context, candidates)
        } else {
            candidates
        }
    }
}
diff --git a/src/pipeline/scorers.rs b/src/pipeline/scorers.rs
new file mode 100644
index 0000000..97df496
--- /dev/null
+++ b/src/pipeline/scorers.rs
@@ -0,0 +1,978 @@
use crate::pipeline::Scorer;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use rust_decimal::prelude::ToPrimitive;
use std::sync::{Arc, Mutex};

/// rolling statistics for z-score normalization
/// tracks mean and std deviation over a sliding window
#[derive(Debug, Clone)]
pub struct RollingStats {
    values: Vec<f64>,
    max_size: usize,
    sum: f64,
    sum_sq: f64,
}

impl RollingStats {
    pub fn new(max_size: usize) -> Self {
        Self {
            values: Vec::with_capacity(max_size),
            max_size,
            sum: 0.0,
            sum_sq: 0.0,
        }
    }

    /// Adds a value, evicting the oldest once the window is full.
    /// Non-finite values are silently dropped so they can't poison the sums.
    pub fn push(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }

        if self.values.len() >= self.max_size {
            // NOTE(review): Vec::remove(0) is O(window); a VecDeque would
            // make eviction O(1) if windows ever get large.
            let old = self.values.remove(0);
            self.sum -= old;
            self.sum_sq -= old * old;
        }

        self.values.push(value);
        self.sum += value;
        self.sum_sq += value * value;
    }

    pub fn push_batch(&mut self, values: &[f64]) {
        for &v in values {
            self.push(v);
        }
    }

    pub fn mean(&self) -> f64 {
        if self.values.is_empty() {
            0.0
        } else {
            self.sum / self.values.len() as f64
        }
    }

    /// Population std-dev from the running sums; 1.0 until two samples exist.
    pub fn std(&self) -> f64 {
        let n = self.values.len();
        if n < 2 {
            return 1.0;
        }

        let mean = self.mean();
        // E[x^2] - E[x]^2, clamped at zero against rounding error
        let variance = (self.sum_sq / n as f64) - (mean * mean);
        variance.max(0.0).sqrt()
    }

    /// z-score of `value` against the window (std floored at 0.001)
    pub fn normalize(&self, value: f64) -> f64 {
        let std = self.std().max(0.001);
        (value - self.mean()) / std
    }

    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// ready once the window is at least a quarter full
    pub fn is_ready(&self) -> bool {
        self.values.len() >= self.max_size / 4
    }
}

/// wrapper that normalizes any scorer's output to z-scores
pub struct NormalizedScorer<S> {
    inner: S,
    score_key: String,
    stats: Arc<Mutex<RollingStats>>,
}

impl<S> NormalizedScorer<S> {
    pub fn new(inner: S, score_key: &str, history_size: usize) -> Self {
        Self {
            inner,
            score_key: score_key.to_string(),
            stats: Arc::new(Mutex::new(RollingStats::new(history_size))),
        }
    }
}

#[async_trait]
impl<S: Scorer> Scorer for NormalizedScorer<S> {
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Scores via the inner scorer, feeds the batch's raw values into the
    /// rolling window, then replaces the configured key with its z-score
    /// (raw values pass through until the window is ready).
    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let raw_scored = self.inner.score(context, candidates).await?;

        let raw_scores: Vec<f64> = raw_scored
            .iter()
            .filter_map(|c| c.scores.get(&self.score_key).copied())
            .collect();

        {
            let mut stats = self.stats.lock().unwrap();
            stats.push_batch(&raw_scores);
        }

        let stats = self.stats.lock().unwrap();
        let normalized = raw_scored
            .into_iter()
            .map(|mut c| {
                if let Some(&raw) = c.scores.get(&self.score_key) {
                    let z = if stats.is_ready() {
                        stats.normalize(raw)
                    } else {
                        raw
                    };
                    c.scores.insert(self.score_key.clone(), z);
                }
                c
            })
            .collect();

        Ok(normalized)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        self.inner.update(candidate, scored);
    }
}

/// price change over the last `lookback_hours`
pub struct MomentumScorer {
    lookback_hours: i64,
}

impl MomentumScorer {
    pub fn new(lookback_hours: i64) -> Self {
        Self { lookback_hours }
    }

    /// Last price minus first price within the lookback window;
    /// 0.0 when fewer than two samples fall inside it.
    fn calculate_momentum(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
    ) -> f64 {
        let lookback_start = now - chrono::Duration::hours(lookback_hours);
        let relevant_history: Vec<_> = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .collect();

        if relevant_history.len() < 2 {
            return 0.0;
        }

        let first =
relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5); + let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5); + + last - first + } +} + +#[async_trait] +impl Scorer for MomentumScorer { + fn name(&self) -> &'static str { + "MomentumScorer" + } + + async fn score( + &self, + context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + let scored = candidates + .iter() + .map(|c| { + let momentum = Self::calculate_momentum(c, context.timestamp, self.lookback_hours); + let mut scored = MarketCandidate { + scores: c.scores.clone(), + ..Default::default() + }; + scored.scores.insert("momentum".to_string(), momentum); + scored + }) + .collect(); + + Ok(scored) + } + + fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) { + if let Some(score) = scored.scores.get("momentum") { + candidate.scores.insert("momentum".to_string(), *score); + } + } +} + +/// multi-timeframe momentum scorer +/// looks at multiple windows and detects divergence between short and long term +pub struct MultiTimeframeMomentumScorer { + windows: Vec, +} + +impl MultiTimeframeMomentumScorer { + pub fn new(windows: Vec) -> Self { + Self { windows } + } + + pub fn default_windows() -> Self { + Self::new(vec![1, 4, 12, 24]) + } + + fn calculate_momentum_for_window( + candidate: &MarketCandidate, + now: chrono::DateTime, + hours: i64, + ) -> f64 { + let lookback_start = now - chrono::Duration::hours(hours); + let relevant_history: Vec<_> = candidate + .price_history + .iter() + .filter(|p| p.timestamp >= lookback_start) + .collect(); + + if relevant_history.len() < 2 { + return 0.0; + } + + let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5); + let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5); + + last - first + } + + fn calculate_score( + candidate: &MarketCandidate, + now: chrono::DateTime, + windows: &[i64], + ) -> (f64, f64, f64) { + let momentums: Vec = 
windows + .iter() + .map(|&w| Self::calculate_momentum_for_window(candidate, now, w)) + .collect(); + + if momentums.is_empty() { + return (0.0, 0.0, 1.0); + } + + let avg_momentum = momentums.iter().sum::() / momentums.len() as f64; + + let signs: Vec = momentums.iter().map(|&m| if m > 0.0 { 1 } else if m < 0.0 { -1 } else { 0 }).collect(); + let all_same_sign = signs.iter().all(|&s| s == signs[0]) && signs[0] != 0; + let alignment = if all_same_sign { 1.0 } else { 0.5 }; + + let short_avg = if momentums.len() >= 2 { + momentums[..momentums.len() / 2].iter().sum::() / (momentums.len() / 2) as f64 + } else { + momentums[0] + }; + let long_avg = if momentums.len() >= 2 { + momentums[momentums.len() / 2..].iter().sum::() / (momentums.len() - momentums.len() / 2) as f64 + } else { + momentums[0] + }; + + let divergence = if (short_avg > 0.0 && long_avg < 0.0) || (short_avg < 0.0 && long_avg > 0.0) { + (short_avg - long_avg).abs() + } else { + 0.0 + }; + + (avg_momentum * alignment, divergence, alignment) + } +} + +#[async_trait] +impl Scorer for MultiTimeframeMomentumScorer { + fn name(&self) -> &'static str { + "MultiTimeframeMomentumScorer" + } + + async fn score( + &self, + context: &TradingContext, + candidates: &[MarketCandidate], + ) -> Result, String> { + let scored = candidates + .iter() + .map(|c| { + let (momentum, divergence, alignment) = Self::calculate_score(c, context.timestamp, &self.windows); + let mut scored = MarketCandidate { + scores: c.scores.clone(), + ..Default::default() + }; + scored.scores.insert("mtf_momentum".to_string(), momentum); + scored.scores.insert("mtf_divergence".to_string(), divergence); + scored.scores.insert("mtf_alignment".to_string(), alignment); + scored + }) + .collect(); + + Ok(scored) + } + + fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) { + for key in ["mtf_momentum", "mtf_divergence", "mtf_alignment"] { + if let Some(score) = scored.scores.get(key) { + candidate.scores.insert(key.to_string(), 
*score);
            }
        }
    }
}

/// Scores candidates by deviation of the current YES price from its
/// lookback-window mean.
pub struct MeanReversionScorer {
    lookback_hours: i64,
}

impl MeanReversionScorer {
    pub fn new(lookback_hours: i64) -> Self {
        Self { lookback_hours }
    }

    /// Negated deviation of the current price from the `lookback_hours` mean:
    /// positive when price sits below the mean (expected to revert upward),
    /// negative when above. Returns 0.0 when no prices fall in the window.
    fn calculate_deviation(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
    ) -> f64 {
        let lookback_start = now - chrono::Duration::hours(lookback_hours);
        let prices: Vec<f64> = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .filter_map(|p| p.yes_price.to_f64())
            .collect();

        if prices.is_empty() {
            return 0.0;
        }

        let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
        let current = candidate.current_yes_price.to_f64().unwrap_or(0.5);
        let deviation = current - mean;

        // Negate: a positive score means "buy the dip".
        -deviation
    }
}

#[async_trait]
impl Scorer for MeanReversionScorer {
    fn name(&self) -> &'static str {
        "MeanReversionScorer"
    }

    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let reversion =
                    Self::calculate_deviation(c, context.timestamp, self.lookback_hours);
                let mut scored = MarketCandidate {
                    scores: c.scores.clone(),
                    ..Default::default()
                };
                scored.scores.insert("mean_reversion".to_string(), reversion);
                scored
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        if let Some(score) = scored.scores.get("mean_reversion") {
            candidate.scores.insert("mean_reversion".to_string(), *score);
        }
    }
}

/// bollinger bands mean reversion scorer
/// triggers when price touches statistical extremes (upper/lower bands)
pub struct BollingerMeanReversionScorer {
    lookback_hours: i64,
    num_std: f64,
}

impl BollingerMeanReversionScorer {
    pub fn new(lookback_hours: i64, num_std: f64) -> Self {
        Self { lookback_hours, num_std }
    }

    /// 24h lookback, 2 standard deviations — the classic Bollinger setup.
    pub fn default_config() -> Self {
        Self::new(24, 2.0)
    }

    /// Returns (mean, std, latest price) over the lookback window, or `None`
    /// when fewer than 5 samples are available (too noisy to band).
    fn calculate_bands(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
    ) -> Option<(f64, f64, f64)> {
        let lookback_start = now - chrono::Duration::hours(lookback_hours);
        let prices: Vec<f64> = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .filter_map(|p| p.yes_price.to_f64())
            .collect();

        if prices.len() < 5 {
            return None;
        }

        let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
        let variance: f64 =
            prices.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / prices.len() as f64;
        let std = variance.sqrt();

        Some((mean, std, *prices.last().unwrap_or(&mean)))
    }

    /// Returns (reversion score in [-2, 2], band position in [-1, 1]).
    fn calculate_score(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
        num_std: f64,
    ) -> (f64, f64) {
        let (mean, std, current) = match Self::calculate_bands(candidate, now, lookback_hours) {
            Some(v) => v,
            None => return (0.0, 0.0),
        };

        let upper_band = mean + num_std * std;
        let lower_band = mean - num_std * std;
        let band_width = upper_band - lower_band;

        // Degenerate band (flat prices): no signal.
        if band_width < 0.001 {
            return (0.0, 0.0);
        }

        // 0.0 at the lower band, 1.0 at the upper band.
        let position = (current - lower_band) / band_width;

        let score = if current >= upper_band {
            // Above the upper band: fade (short) proportionally to the breach.
            -(current - upper_band) / std.max(0.001)
        } else if current <= lower_band {
            // Below the lower band: buy proportionally to the breach.
            (lower_band - current) / std.max(0.001)
        } else if current > mean {
            // Inside the band: weak lean back toward the mean.
            -(position - 0.5) * 0.5
        } else {
            (0.5 - position) * 0.5
        };

        let band_position = (position * 2.0 - 1.0).clamp(-1.0, 1.0);

        (score.clamp(-2.0, 2.0), band_position)
    }
}

#[async_trait]
impl Scorer for BollingerMeanReversionScorer {
    fn name(&self) -> &'static str {
        "BollingerMeanReversionScorer"
    }

    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let (score, band_pos) =
                    Self::calculate_score(c, context.timestamp, self.lookback_hours, self.num_std);
                let mut scored = MarketCandidate {
                    scores: c.scores.clone(),
                    ..Default::default()
                };
                scored.scores.insert("bollinger_reversion".to_string(), score);
                scored.scores.insert("bollinger_position".to_string(), band_pos);
                scored
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        for key in ["bollinger_reversion", "bollinger_position"] {
            if let Some(score) = scored.scores.get(key) {
                candidate.scores.insert(key.to_string(), *score);
            }
        }
    }
}

/// Scores unusual trading activity relative to the market's own baseline.
pub struct VolumeScorer {
    lookback_hours: i64,
}

impl VolumeScorer {
    pub fn new(lookback_hours: i64) -> Self {
        Self { lookback_hours }
    }

    /// Log-ratio of recent hourly volume to lifetime-average hourly volume,
    /// clamped to [-2, 2]; 0.0 when the market has no volume at all.
    fn calculate_volume_score(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
    ) -> f64 {
        let lookback_start = now - chrono::Duration::hours(lookback_hours);
        let recent_volume: u64 = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .map(|p| p.volume)
            .sum();

        if candidate.total_volume == 0 {
            return 0.0;
        }

        // `.max(1)` guards both divisions against zero-hour markets.
        let avg_hourly_volume = candidate.total_volume as f64
            / ((now - candidate.open_time).num_hours().max(1) as f64);
        let recent_hourly_volume = recent_volume as f64 / lookback_hours.max(1) as f64;

        if avg_hourly_volume > 0.0 {
            (recent_hourly_volume / avg_hourly_volume).ln().clamp(-2.0, 2.0)
        } else {
            0.0
        }
    }
}

#[async_trait]
impl Scorer for VolumeScorer {
    fn name(&self) -> &'static str {
        "VolumeScorer"
    }

    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let volume =
                    Self::calculate_volume_score(c, context.timestamp, self.lookback_hours);
                let mut scored = MarketCandidate {
                    scores: c.scores.clone(),
                    ..Default::default()
                };
                scored.scores.insert("volume".to_string(), volume);
                scored
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        if let Some(score) =
scored.scores.get("volume") {
            candidate.scores.insert("volume".to_string(), *score);
        }
    }
}

/// Scores markets by time remaining until close: far-dated markets approach
/// 1.0, near-expiry markets approach 0.0, already-closed markets get -1.0.
pub struct TimeDecayScorer;

impl TimeDecayScorer {
    pub fn new() -> Self {
        Self
    }

    fn calculate_time_decay(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
    ) -> f64 {
        let ttc = candidate.time_to_close(now);
        let hours_remaining = ttc.num_hours() as f64;

        // Closed (or closing within the hour): strongly penalize.
        if hours_remaining <= 0.0 {
            return -1.0;
        }

        // Hyperbolic decay: 0 at expiry, asymptotically 1 for far-dated markets.
        let decay = 1.0 - (1.0 / (hours_remaining / 24.0 + 1.0));
        decay.clamp(0.0, 1.0)
    }
}

impl Default for TimeDecayScorer {
    fn default() -> Self {
        Self::new()
    }
}

/// order flow imbalance scorer
/// measures buying vs selling pressure using taker_side from trades
pub struct OrderFlowScorer;

impl OrderFlowScorer {
    pub fn new() -> Self {
        Self
    }

    /// Net taker-buy pressure over 24h, normalized to [-1, 1];
    /// 0.0 when there was no flow at all.
    fn calculate_imbalance(candidate: &MarketCandidate) -> f64 {
        let buy_vol = candidate.buy_volume_24h as f64;
        let sell_vol = candidate.sell_volume_24h as f64;
        let total = buy_vol + sell_vol;

        if total == 0.0 {
            return 0.0;
        }

        (buy_vol - sell_vol) / total
    }
}

impl Default for OrderFlowScorer {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Scorer for OrderFlowScorer {
    fn name(&self) -> &'static str {
        "OrderFlowScorer"
    }

    async fn score(
        &self,
        _context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let imbalance = Self::calculate_imbalance(c);
                let mut scored = MarketCandidate {
                    scores: c.scores.clone(),
                    ..Default::default()
                };
                scored.scores.insert("order_flow".to_string(), imbalance);
                scored
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        if let Some(score) = scored.scores.get("order_flow") {
            candidate.scores.insert("order_flow".to_string(), *score);
        }
    }
}

#[async_trait]
impl Scorer for TimeDecayScorer {
    fn name(&self) -> &'static str {
        "TimeDecayScorer"
    }

    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let time_decay = Self::calculate_time_decay(c, context.timestamp);
                let mut scored = MarketCandidate {
                    scores: c.scores.clone(),
                    ..Default::default()
                };
                scored.scores.insert("time_decay".to_string(), time_decay);
                scored
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        if let Some(score) = scored.scores.get("time_decay") {
            candidate.scores.insert("time_decay".to_string(), *score);
        }
    }
}

/// Final-stage scorer: linear combination of previously-computed component
/// scores (missing components contribute 0.0).
pub struct WeightedScorer {
    weights: Vec<(String, f64)>,
}

impl WeightedScorer {
    pub fn new(weights: Vec<(String, f64)>) -> Self {
        Self { weights }
    }

    pub fn default_weights() -> Self {
        Self::new(vec![
            ("momentum".to_string(), 0.4),
            ("mean_reversion".to_string(), 0.3),
            ("volume".to_string(), 0.2),
            ("time_decay".to_string(), 0.1),
        ])
    }

    fn compute_weighted_score(&self, candidate: &MarketCandidate) -> f64 {
        self.weights
            .iter()
            .map(|(name, weight)| candidate.scores.get(name).copied().unwrap_or(0.0) * weight)
            .sum()
    }
}

#[async_trait]
impl Scorer for WeightedScorer {
    fn name(&self) -> &'static str {
        "WeightedScorer"
    }

    async fn score(
        &self,
        _context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let weighted_score = self.compute_weighted_score(c);
                // Only final_score is produced; `update` copies just that field.
                MarketCandidate {
                    final_score: weighted_score,
                    ..Default::default()
                }
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        candidate.final_score = scored.final_score;
    }
}

/// Per-component weights for the category-aware scorer.
#[derive(Clone)]
pub struct ScorerWeights {
    pub momentum: f64,
    pub mean_reversion: f64,
    pub volume: f64,
    pub time_decay: f64,
    pub order_flow: f64,
    pub
bollinger: f64,
    pub mtf_momentum: f64,
}

impl Default for ScorerWeights {
    fn default() -> Self {
        Self {
            momentum: 0.2,
            mean_reversion: 0.2,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.1,
            mtf_momentum: 0.1,
        }
    }
}

impl ScorerWeights {
    /// Momentum-heavy preset for political markets.
    pub fn politics() -> Self {
        Self {
            momentum: 0.35,
            mean_reversion: 0.1,
            volume: 0.1,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.05,
            mtf_momentum: 0.15,
        }
    }

    /// Reversion-heavy preset for weather markets.
    pub fn weather() -> Self {
        Self {
            momentum: 0.1,
            mean_reversion: 0.35,
            volume: 0.1,
            time_decay: 0.15,
            order_flow: 0.1,
            bollinger: 0.15,
            mtf_momentum: 0.05,
        }
    }

    /// Order-flow-heavy preset for sports markets.
    pub fn sports() -> Self {
        Self {
            momentum: 0.2,
            mean_reversion: 0.1,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.3,
            bollinger: 0.05,
            mtf_momentum: 0.1,
        }
    }

    /// Balanced preset for economics / financial markets.
    pub fn economics() -> Self {
        Self {
            momentum: 0.25,
            mean_reversion: 0.2,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.1,
            mtf_momentum: 0.05,
        }
    }

    /// Weighted sum of the candidate's component scores (missing scores = 0.0).
    fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
        let get_score = |key: &str| candidate.scores.get(key).copied().unwrap_or(0.0);

        self.momentum * get_score("momentum")
            + self.mean_reversion * get_score("mean_reversion")
            + self.volume * get_score("volume")
            + self.time_decay * get_score("time_decay")
            + self.order_flow * get_score("order_flow")
            + self.bollinger * get_score("bollinger_reversion")
            + self.mtf_momentum * get_score("mtf_momentum")
    }
}

/// category-aware weighted scorer
/// applies different weights based on market category
pub struct CategoryWeightedScorer {
    category_weights: std::collections::HashMap<String, ScorerWeights>,
    default_weights: ScorerWeights,
}

impl CategoryWeightedScorer {
    pub fn new(
        category_weights: std::collections::HashMap<String, ScorerWeights>,
        default_weights: ScorerWeights,
    ) -> Self {
        Self { category_weights, default_weights }
    }

    /// Per-category presets; "financial" shares the economics weights.
    pub fn with_defaults() -> Self {
        let mut category_weights = std::collections::HashMap::new();
        category_weights.insert("politics".to_string(), ScorerWeights::politics());
        category_weights.insert("weather".to_string(), ScorerWeights::weather());
        category_weights.insert("sports".to_string(), ScorerWeights::sports());
        category_weights.insert("economics".to_string(), ScorerWeights::economics());
        category_weights.insert("financial".to_string(), ScorerWeights::economics());

        Self {
            category_weights,
            default_weights: ScorerWeights::default(),
        }
    }

    /// Case-insensitive category lookup, falling back to the default weights.
    fn get_weights(&self, category: &str) -> &ScorerWeights {
        let lower = category.to_lowercase();
        self.category_weights.get(&lower).unwrap_or(&self.default_weights)
    }
}

#[async_trait]
impl Scorer for CategoryWeightedScorer {
    fn name(&self) -> &'static str {
        "CategoryWeightedScorer"
    }

    async fn score(
        &self,
        _context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let weights = self.get_weights(&c.category);
                let weighted_score = weights.compute_score(c);
                MarketCandidate {
                    final_score: weighted_score,
                    ..Default::default()
                }
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        candidate.final_score = scored.final_score;
    }
}

/// ensemble scorer that combines multiple models with dynamic weighting
/// weights can be updated based on recent accuracy
pub struct EnsembleScorer {
    // Mutex allows `update_weights` through a shared reference.
    model_weights: std::sync::Arc<std::sync::Mutex<Vec<f64>>>,
    model_keys: Vec<String>,
}

impl EnsembleScorer {
    /// Panics if `model_keys` and `initial_weights` differ in length.
    pub fn new(model_keys: Vec<String>, initial_weights: Vec<f64>) -> Self {
        assert_eq!(model_keys.len(), initial_weights.len());
        Self {
            model_weights: std::sync::Arc::new(std::sync::Mutex::new(initial_weights)),
            model_keys,
        }
    }

    pub fn default_ensemble() -> Self {
        Self::new(
            vec![
                "momentum".to_string(),
                "mean_reversion".to_string(),
                "bollinger_reversion".to_string(),
                "order_flow".to_string(),
"mtf_momentum".to_string(),
            ],
            vec![0.25, 0.2, 0.2, 0.2, 0.15],
        )
    }

    /// Replace the ensemble weights (e.g. after re-evaluating model accuracy).
    pub fn update_weights(&self, new_weights: Vec<f64>) {
        let mut weights = self.model_weights.lock().unwrap();
        *weights = new_weights;
    }

    fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
        let weights = self.model_weights.lock().unwrap();
        self.model_keys
            .iter()
            .zip(weights.iter())
            .map(|(key, weight)| candidate.scores.get(key).copied().unwrap_or(0.0) * weight)
            .sum()
    }
}

#[async_trait]
impl Scorer for EnsembleScorer {
    fn name(&self) -> &'static str {
        "EnsembleScorer"
    }

    async fn score(
        &self,
        _context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let scored = candidates
            .iter()
            .map(|c| {
                let ensemble_score = self.compute_score(c);
                MarketCandidate {
                    final_score: ensemble_score,
                    ..Default::default()
                }
            })
            .collect();

        Ok(scored)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        candidate.final_score = scored.final_score;
    }
}

diff --git a/src/pipeline/selector.rs b/src/pipeline/selector.rs
new file mode 100644
index 0000000..7bc7744
--- /dev/null
+++ b/src/pipeline/selector.rs
@@ -0,0 +1,64 @@
use crate::pipeline::Selector;
use crate::types::{MarketCandidate, TradingContext};

/// Keeps the `k` highest-scoring candidates.
pub struct TopKSelector {
    k: usize,
}

impl TopKSelector {
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}

impl Selector for TopKSelector {
    fn name(&self) -> &'static str {
        "TopKSelector"
    }

    fn select(
        &self,
        _context: &TradingContext,
        mut candidates: Vec<MarketCandidate>,
    ) -> Vec<MarketCandidate> {
        // Sort descending by final_score; NaN compares as Equal to avoid a panic.
        candidates.sort_by(|a, b| {
            b.final_score
                .partial_cmp(&a.final_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        candidates.truncate(self.k);
        candidates
    }
}

/// Keeps every candidate at or above a score threshold, optionally capped.
pub struct ThresholdSelector {
    min_score: f64,
    max_candidates: Option<usize>,
}

impl ThresholdSelector {
    pub fn new(min_score: f64, max_candidates: Option<usize>) -> Self {
        Self {
            min_score,
            max_candidates,
}
    }
}

impl Selector for ThresholdSelector {
    fn name(&self) -> &'static str {
        "ThresholdSelector"
    }

    /// Candidates scoring at least `min_score`, sorted best-first, then
    /// truncated to `max_candidates` when a cap is configured.
    fn select(
        &self,
        _context: &TradingContext,
        mut candidates: Vec<MarketCandidate>,
    ) -> Vec<MarketCandidate> {
        candidates.retain(|c| c.final_score >= self.min_score);
        candidates.sort_by(|a, b| {
            b.final_score
                .partial_cmp(&a.final_score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        if let Some(max) = self.max_candidates {
            candidates.truncate(max);
        }

        candidates
    }
}

diff --git a/src/pipeline/sources.rs b/src/pipeline/sources.rs
new file mode 100644
index 0000000..ea34056
--- /dev/null
+++ b/src/pipeline/sources.rs
@@ -0,0 +1,69 @@
use crate::data::HistoricalData;
use crate::pipeline::Source;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::sync::Arc;

/// Candidate source that replays recorded market data for backtesting.
pub struct HistoricalMarketSource {
    data: Arc<HistoricalData>,
    lookback_hours: i64,
}

impl HistoricalMarketSource {
    pub fn new(data: Arc<HistoricalData>, lookback_hours: i64) -> Self {
        Self { data, lookback_hours }
    }
}

#[async_trait]
impl Source for HistoricalMarketSource {
    fn name(&self) -> &'static str {
        "HistoricalMarketSource"
    }

    /// Builds a `MarketCandidate` for every market active at
    /// `context.timestamp`; markets with no quotable price are skipped.
    async fn get_candidates(
        &self,
        context: &TradingContext,
    ) -> Result<Vec<MarketCandidate>, String> {
        let now = context.timestamp;
        let active_markets = self.data.get_active_markets(now);

        let candidates: Vec<MarketCandidate> = active_markets
            .into_iter()
            .filter_map(|market| {
                // `?` drops markets without a current price.
                let current_price = self.data.get_current_price(&market.ticker, now)?;
                let lookback_start = now - chrono::Duration::hours(self.lookback_hours);
                let price_history =
                    self.data.get_price_history(&market.ticker, lookback_start, now);
                let volume_24h = self.data.get_volume_24h(&market.ticker, now);

                // Lifetime traded volume since the market opened.
                let total_volume: u64 = self
                    .data
                    .get_trades_for_market(&market.ticker, market.open_time, now)
                    .iter()
                    .map(|t| t.volume)
                    .sum();

                let (buy_volume_24h, sell_volume_24h) =
                    self.data.get_order_flow_24h(&market.ticker, now);

Some(MarketCandidate {
                    ticker: market.ticker.clone(),
                    title: market.title.clone(),
                    category: market.category.clone(),
                    current_yes_price: current_price,
                    // Binary market: NO price is the complement of YES.
                    current_no_price: Decimal::ONE - current_price,
                    volume_24h,
                    total_volume,
                    buy_volume_24h,
                    sell_volume_24h,
                    open_time: market.open_time,
                    close_time: market.close_time,
                    result: market.result,
                    price_history,
                    scores: HashMap::new(),
                    final_score: 0.0,
                })
            })
            .collect();

        Ok(candidates)
    }
}

diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..0cc9cbc
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,343 @@
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Side of a binary market a trade or position is on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Side {
    Yes,
    No,
}

impl Side {
    pub fn opposite(&self) -> Self {
        match self {
            Side::Yes => Side::No,
            Side::No => Side::Yes,
        }
    }
}

/// How a market ultimately settled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MarketResult {
    Yes,
    No,
    Cancelled,
}

/// A market under consideration by the pipeline, together with the score
/// map that successive scorers fill in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketCandidate {
    pub ticker: String,
    pub title: String,
    pub category: String,
    pub current_yes_price: Decimal,
    pub current_no_price: Decimal,
    pub volume_24h: u64,
    pub total_volume: u64,
    pub buy_volume_24h: u64,
    pub sell_volume_24h: u64,
    pub open_time: DateTime<Utc>,
    pub close_time: DateTime<Utc>,
    // Known only in backtests, where settlement is already recorded.
    pub result: Option<MarketResult>,
    pub price_history: Vec<PricePoint>,

    // Component scores written by individual scorers, keyed by score name.
    pub scores: HashMap<String, f64>,
    // Aggregate written by the final (weighted/ensemble) scorer.
    pub final_score: f64,
}

impl MarketCandidate {
    /// Time remaining until the market closes (negative once closed).
    pub fn time_to_close(&self, now: DateTime<Utc>) -> chrono::Duration {
        self.close_time - now
    }

    /// True while `now` lies in the half-open interval [open_time, close_time).
    pub fn is_open(&self, now: DateTime<Utc>) -> bool {
        now >= self.open_time && now < self.close_time
    }
}

impl Default for MarketCandidate {
    fn default() -> Self {
        Self {
            ticker: String::new(),
            title: String::new(),
            category: String::new(),
            current_yes_price: Decimal::ZERO,
current_no_price: Decimal::ZERO,
            volume_24h: 0,
            total_volume: 0,
            buy_volume_24h: 0,
            sell_volume_24h: 0,
            open_time: Utc::now(),
            close_time: Utc::now(),
            result: None,
            price_history: Vec::new(),
            scores: HashMap::new(),
            final_score: 0.0,
        }
    }
}

/// One observation in a market's price history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricePoint {
    pub timestamp: DateTime<Utc>,
    pub yes_price: Decimal,
    pub volume: u64,
}

/// Mutable state threaded through one pipeline evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradingContext {
    pub request_id: String,
    pub timestamp: DateTime<Utc>,
    pub portfolio: Portfolio,
    pub trading_history: Vec<Trade>,
}

impl TradingContext {
    pub fn new(initial_capital: Decimal, start_time: DateTime<Utc>) -> Self {
        Self {
            request_id: uuid::Uuid::new_v4().to_string(),
            timestamp: start_time,
            portfolio: Portfolio::new(initial_capital),
            trading_history: Vec::new(),
        }
    }

    pub fn request_id(&self) -> &str {
        &self.request_id
    }
}

/// Cash plus open positions, keyed by ticker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Portfolio {
    pub positions: HashMap<String, Position>,
    pub cash: Decimal,
    pub initial_capital: Decimal,
}

impl Portfolio {
    pub fn new(initial_capital: Decimal) -> Self {
        Self {
            positions: HashMap::new(),
            cash: initial_capital,
            initial_capital,
        }
    }

    /// Cash plus mark-to-market value of all open positions. Positions whose
    /// ticker is missing from `market_prices` are valued at entry price.
    pub fn total_value(&self, market_prices: &HashMap<String, Decimal>) -> Decimal {
        let position_value: Decimal = self
            .positions
            .values()
            .map(|p| {
                let price =
                    market_prices.get(&p.ticker).copied().unwrap_or(p.avg_entry_price);
                // Quotes are YES prices; a NO position is worth the complement.
                let effective_price = match p.side {
                    Side::Yes => price,
                    Side::No => Decimal::ONE - price,
                };
                effective_price * Decimal::from(p.quantity)
            })
            .sum();
        self.cash + position_value
    }

    pub fn has_position(&self, ticker: &str) -> bool {
        self.positions.contains_key(ticker)
    }

    pub fn get_position(&self, ticker: &str) -> Option<&Position> {
        self.positions.get(ticker)
    }

    /// Debits cash and adds the fill to the ticker's position, recomputing the
    /// volume-weighted average entry price.
    pub fn apply_fill(&mut self, fill: &Fill) {
        let cost = fill.price * Decimal::from(fill.quantity);

        // Both sides are handled identically here; the side only matters at
        // valuation/resolution time. The match is kept for exhaustiveness.
        match fill.side {
            Side::Yes | Side::No => {
                self.cash -= cost;
                let position = self.positions.entry(fill.ticker.clone()).or_insert_with(|| {
                    Position {
                        ticker: fill.ticker.clone(),
                        side: fill.side,
                        quantity: 0,
                        avg_entry_price: Decimal::ZERO,
                        entry_time: fill.timestamp,
                    }
                });

                let total_cost =
                    position.avg_entry_price * Decimal::from(position.quantity) + cost;
                position.quantity += fill.quantity;
                if position.quantity > 0 {
                    position.avg_entry_price = total_cost / Decimal::from(position.quantity);
                }
            }
        }
    }

    /// Settles a position at market resolution and returns its realized PnL,
    /// or `None` if no position exists for `ticker`. Winning contracts pay
    /// $1 each; cancelled markets refund the cost basis.
    pub fn resolve_position(&mut self, ticker: &str, result: MarketResult) -> Option<Decimal> {
        let position = self.positions.remove(ticker)?;

        let payout = match (result, position.side) {
            (MarketResult::Yes, Side::Yes) | (MarketResult::No, Side::No) => {
                Decimal::from(position.quantity)
            }
            (MarketResult::Cancelled, _) => {
                position.avg_entry_price * Decimal::from(position.quantity)
            }
            _ => Decimal::ZERO,
        };

        self.cash += payout;

        let cost = position.avg_entry_price * Decimal::from(position.quantity);
        Some(payout - cost)
    }

    /// Closes a position at the given YES price before resolution and returns
    /// realized PnL, or `None` if no position exists for `ticker`.
    pub fn close_position(&mut self, ticker: &str, exit_price: Decimal) -> Option<Decimal> {
        let position = self.positions.remove(ticker)?;

        let effective_exit_price = match position.side {
            Side::Yes => exit_price,
            Side::No => Decimal::ONE - exit_price,
        };

        let exit_value = effective_exit_price * Decimal::from(position.quantity);
        self.cash += exit_value;

        let cost = position.avg_entry_price * Decimal::from(position.quantity);
        Some(exit_value - cost)
    }
}

/// An open holding in a single market.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Position {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    pub avg_entry_price: Decimal,
    pub entry_time: DateTime<Utc>,
}

impl Position {
    pub fn cost_basis(&self) -> Decimal {
        self.avg_entry_price * Decimal::from(self.quantity)
    }
}

/// A completed trade recorded in the trading history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
pub price: Decimal,
    pub timestamp: DateTime<Utc>,
    pub trade_type: TradeType,
}

/// Why a trade was recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TradeType {
    Open,
    Close,
    Resolution,
}

/// Why a position was exited.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ExitReason {
    Resolution(MarketResult),
    TakeProfit { pnl_pct: f64 },
    StopLoss { pnl_pct: f64 },
    TimeStop { hours_held: i64 },
    ScoreReversal { new_score: f64 },
}

/// Instruction to close an open position.
#[derive(Debug, Clone)]
pub struct ExitSignal {
    pub ticker: String,
    pub reason: ExitReason,
    pub current_price: Decimal,
}

/// Thresholds governing when positions are exited early.
#[derive(Debug, Clone)]
pub struct ExitConfig {
    pub take_profit_pct: f64,
    pub stop_loss_pct: f64,
    pub max_hold_hours: i64,
    pub score_reversal_threshold: f64,
}

impl Default for ExitConfig {
    fn default() -> Self {
        Self {
            take_profit_pct: 0.20,
            stop_loss_pct: 0.15,
            max_hold_hours: 72,
            score_reversal_threshold: -0.3,
        }
    }
}

impl ExitConfig {
    /// Tighter exits: smaller profit target, tighter stop, shorter hold.
    pub fn conservative() -> Self {
        Self {
            take_profit_pct: 0.15,
            stop_loss_pct: 0.10,
            max_hold_hours: 48,
            score_reversal_threshold: -0.2,
        }
    }

    /// Looser exits: larger profit target, wider stop, longer hold.
    pub fn aggressive() -> Self {
        Self {
            take_profit_pct: 0.30,
            stop_loss_pct: 0.20,
            max_hold_hours: 120,
            score_reversal_threshold: -0.5,
        }
    }
}

/// An executed order as returned by the execution layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fill {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    pub price: Decimal,
    pub timestamp: DateTime<Utc>,
}

/// A proposed order produced by the strategy, before execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signal {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    // None means a market order.
    pub limit_price: Option<Decimal>,
    pub reason: String,
}

/// Static metadata describing one market.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketData {
    pub ticker: String,
    pub title: String,
    pub category: String,
    pub open_time: DateTime<Utc>,
    pub close_time: DateTime<Utc>,
    pub result: Option<MarketResult>,
}

/// One recorded trade from the historical feed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeData {
    pub timestamp: DateTime<Utc>,
    pub ticker: String,
    pub price: Decimal,
    pub volume: u64,
    pub taker_side: Side,
}

/// Parameters for a backtest run.
#[derive(Debug, Clone)]
pub struct BacktestConfig {
    pub start_time: DateTime<Utc>,
    pub end_time: DateTime<Utc>,
    // Simulation step between pipeline evaluations.
    pub interval: chrono::Duration,
    pub initial_capital: Decimal,
    pub max_position_size: u64,
    pub max_positions: usize,
}