feat: initial commit with code quality refactoring
kalshi prediction market backtesting framework with: - trading pipeline (sources, filters, scorers, selectors) - position sizing with kelly criterion - multiple scoring strategies (momentum, mean reversion, etc) - random baseline for comparison refactoring includes: - extract shared resolve_closed_positions() function - reduce RandomBaseline::run() nesting with helper functions - move MarketCandidate Default impl to types.rs - add explanatory comments to complex logic
This commit is contained in:
commit
025322219c
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
||||
/target
|
||||
/data/*.csv
|
||||
/data/*.parquet
|
||||
/results/*.json
|
||||
Cargo.lock
|
||||
30
Cargo.toml
Normal file
30
Cargo.toml
Normal file
@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "kalshi-backtest"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
async-trait = "0.1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
csv = "1.3"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
anyhow = "1"
|
||||
thiserror = "1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
rust_decimal = { version = "1", features = ["serde"] }
|
||||
rust_decimal_macros = "1"
|
||||
|
||||
ort = { version = "2.0.0-rc.11", optional = true }
|
||||
ndarray = { version = "0.16", optional = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
ml = ["ort", "ndarray"]
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
236
README.md
Normal file
236
README.md
Normal file
@ -0,0 +1,236 @@
|
||||
kalshi-backtest
|
||||
===
|
||||
|
||||
quant-level backtesting framework for kalshi prediction markets, using a candidate pipeline architecture.
|
||||
|
||||
|
||||
features
|
||||
---
|
||||
|
||||
- **multi-timeframe momentum** - detects divergence between short and long-term trends
|
||||
- **bollinger bands mean reversion** - signals when price touches statistical extremes
|
||||
- **order flow analysis** - tracks buying vs selling pressure via taker_side
|
||||
- **kelly criterion position sizing** - dynamic sizing based on edge and win probability
|
||||
- **exit signals** - take profit, stop loss, time stops, and score reversal triggers
|
||||
- **category-aware weighting** - different strategies for politics, weather, sports, etc.
|
||||
- **ensemble scoring** - combine multiple models with dynamic weighting
|
||||
- **cross-market correlations** - lead-lag relationships between related markets
|
||||
- **ML ensemble (optional)** - LSTM + MLP models via ONNX runtime
|
||||
|
||||
|
||||
architecture
|
||||
---
|
||||
|
||||
```
|
||||
Historical Data (CSV)
|
||||
|
|
||||
v
|
||||
+------------------+
|
||||
| Backtest Loop | <- simulates time progression
|
||||
+------------------+
|
||||
|
|
||||
v
|
||||
+------------------+
|
||||
| Candidate Pipeline |
|
||||
+------------------+
|
||||
| |
|
||||
v v
|
||||
Sources Filters -> Scorers -> Selector
|
||||
|
|
||||
v
|
||||
+------------------+
|
||||
| Trade Executor | <- kelly sizing, exit signals
|
||||
+------------------+
|
||||
|
|
||||
v
|
||||
+------------------+
|
||||
| P&L Tracker | <- tracks positions, returns
|
||||
+------------------+
|
||||
|
|
||||
v
|
||||
Performance Metrics
|
||||
```
|
||||
|
||||
|
||||
data format
|
||||
---
|
||||
|
||||
fetch data from the Kalshi API using the included script (note: `scripts/fetch_kalshi_data.py` currently writes to a hardcoded output directory, not `data/` — check its `output_dir` before running):
|
||||
|
||||
```bash
|
||||
python scripts/fetch_kalshi_data.py
|
||||
```
|
||||
|
||||
or download from https://www.deltabase.tech/
|
||||
|
||||
**markets.csv**:
|
||||
```csv
|
||||
ticker,title,category,open_time,close_time,result,status,yes_bid,yes_ask,volume,open_interest
|
||||
PRES-2024-DEM,Will Democrats win?,politics,2024-01-01 00:00:00,2024-11-06 00:00:00,no,finalized,45,47,10000,5000
|
||||
```
|
||||
|
||||
**trades.csv**:
|
||||
```csv
|
||||
timestamp,ticker,price,volume,taker_side
|
||||
2024-01-05 12:00:00,PRES-2024-DEM,45,100,yes
|
||||
2024-01-05 13:00:00,PRES-2024-DEM,46,50,no
|
||||
```
|
||||
|
||||
|
||||
usage
|
||||
---
|
||||
|
||||
```bash
|
||||
# build
|
||||
cargo build --release
|
||||
|
||||
# run backtest with quant features
|
||||
cargo run --release -- run \
|
||||
--data-dir data \
|
||||
--start 2024-01-01 \
|
||||
--end 2024-06-01 \
|
||||
--capital 10000 \
|
||||
--max-position 500 \
|
||||
--max-positions 10 \
|
||||
--kelly-fraction 0.25 \
|
||||
--max-position-pct 0.25 \
|
||||
--take-profit 0.20 \
|
||||
--stop-loss 0.15 \
|
||||
--max-hold-hours 72 \
|
||||
--compare-random
|
||||
|
||||
# view results
|
||||
cargo run --release -- summary --results-file results/backtest_result.json
|
||||
```
|
||||
|
||||
|
||||
cli options
|
||||
---
|
||||
|
||||
| option | default | description |
|
||||
|--------|---------|-------------|
|
||||
| --data-dir | data | directory with markets.csv and trades.csv |
|
||||
| --start | required | backtest start date |
|
||||
| --end | required | backtest end date |
|
||||
| --capital | 10000 | initial capital |
|
||||
| --max-position | 100 | max shares per position |
|
||||
| --max-positions | 5 | max concurrent positions |
|
||||
| --kelly-fraction | 0.25 | fraction of kelly criterion (0.1=conservative, 1.0=full) |
|
||||
| --max-position-pct | 0.25 | max % of capital per position |
|
||||
| --take-profit | 0.20 | take profit threshold (20% gain) |
|
||||
| --stop-loss | 0.15 | stop loss threshold (15% loss) |
|
||||
| --max-hold-hours | 72 | time stop in hours |
|
||||
| --compare-random | false | compare vs random baseline |
|
||||
|
||||
|
||||
scorers
|
||||
---
|
||||
|
||||
**basic scorers**:
|
||||
- `MomentumScorer` - price change over lookback period
|
||||
- `MeanReversionScorer` - deviation from historical mean
|
||||
- `VolumeScorer` - unusual volume detection
|
||||
- `TimeDecayScorer` - prefer markets with more time to close
|
||||
|
||||
**quant scorers**:
|
||||
- `MultiTimeframeMomentumScorer` - analyzes 1h, 4h, 12h, 24h windows, detects divergence
|
||||
- `BollingerMeanReversionScorer` - triggers at upper/lower band touches (2 std)
|
||||
- `OrderFlowScorer` - buy/sell imbalance from taker_side
|
||||
- `CategoryWeightedScorer` - different weights per category
|
||||
- `EnsembleScorer` - combines models with dynamic weights
|
||||
- `CorrelationScorer` - cross-market lead-lag signals
|
||||
|
||||
**ml scorers** (requires `ml` feature):
|
||||
- `MLEnsembleScorer` - LSTM + MLP via ONNX
|
||||
|
||||
|
||||
position sizing
|
||||
---
|
||||
|
||||
uses kelly criterion with safety multiplier:
|
||||
|
||||
```
|
||||
kelly = (odds * win_prob - (1 - win_prob)) / odds
|
||||
safe_kelly = kelly * kelly_fraction
|
||||
position = min(bankroll * safe_kelly, max_position_pct * bankroll)
|
||||
```
|
||||
|
||||
|
||||
exit signals
|
||||
---
|
||||
|
||||
positions can exit via:
|
||||
1. **resolution** - market resolves yes/no
|
||||
2. **take profit** - pnl exceeds threshold
|
||||
3. **stop loss** - pnl below threshold
|
||||
4. **time stop** - held too long (capital rotation)
|
||||
5. **score reversal** - strategy flips bearish
|
||||
|
||||
|
||||
ml training (optional)
|
||||
---
|
||||
|
||||
train ML models using pytorch, then export to ONNX:
|
||||
|
||||
```bash
|
||||
# install dependencies
|
||||
pip install torch pandas numpy
|
||||
|
||||
# train models
|
||||
python scripts/train_ml_models.py \
|
||||
--data data/trades.csv \
|
||||
--markets data/markets.csv \
|
||||
--output models/ \
|
||||
--epochs 50
|
||||
|
||||
# enable ml feature
|
||||
cargo build --release --features ml
|
||||
```
|
||||
|
||||
|
||||
metrics
|
||||
---
|
||||
|
||||
- total return ($ and %)
|
||||
- sharpe ratio (annualized)
|
||||
- max drawdown
|
||||
- win rate
|
||||
- average trade P&L
|
||||
- average hold time
|
||||
- trades per day
|
||||
- return by category
|
||||
|
||||
|
||||
extending
|
||||
---
|
||||
|
||||
add custom scorers by implementing the `Scorer` trait:
|
||||
|
||||
```rust
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub struct MyScorer;
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for MyScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"MyScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
// compute scores...
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("my_score") {
|
||||
candidate.scores.insert("my_score".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
then add to the pipeline in `backtest.rs`.
|
||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
1
data/fetch_state.json
Normal file
1
data/fetch_state.json
Normal file
@ -0,0 +1 @@
|
||||
{"markets_cursor": "CgsI-rDDywYQkKOiMRI5S1hNVkVTUE9SVFNNVUxUSUdBTUVFWFRFTkRFRC1TMjAyNTBDMDMzMDBBRkYyLTkxNTVFNjFERTk3", "markets_count": 25000, "trades_cursor": null, "trades_count": 0, "markets_done": false, "trades_done": false}
|
||||
254
scripts/fetch_kalshi_data.py
Executable file
254
scripts/fetch_kalshi_data.py
Executable file
@ -0,0 +1,254 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fetch historical trade and market data from Kalshi's public API.
|
||||
No authentication required for public endpoints.
|
||||
|
||||
Features:
|
||||
- Incremental saves (writes batches to disk)
|
||||
- Resume capability (tracks cursor position)
|
||||
- Retry logic with exponential backoff
|
||||
"""
|
||||
|
||||
import json
|
||||
import csv
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
BASE_URL = "https://api.elections.kalshi.com/trade-api/v2"
|
||||
STATE_FILE = "fetch_state.json"
|
||||
|
||||
def fetch_json(url: str, max_retries: int = 5) -> dict:
    """Fetch JSON from URL with retries and exponential backoff.

    Re-raises the last exception if all attempts fail.
    """
    req = urllib.request.Request(url, headers={"Accept": "application/json"})

    for attempt in range(max_retries):
        try:
            with urllib.request.urlopen(req, timeout=30) as resp:
                return json.loads(resp.read().decode())
        except Exception as e:
            # the two original except clauses duplicated the same retry logic;
            # merged here, keeping the distinct log message for expected
            # network errors vs. anything else
            if isinstance(e, (urllib.error.HTTPError, urllib.error.URLError)):
                print(f" attempt {attempt + 1}/{max_retries} failed: {e}")
            else:
                print(f" unexpected error: {e}")
            if attempt == max_retries - 1:
                raise
            wait = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
            print(f" retrying in {wait}s...")
            time.sleep(wait)
|
||||
|
||||
def load_state(output_dir: Path) -> dict:
    """Return the saved fetch checkpoint, or a fresh default state if none exists."""
    state_path = output_dir / STATE_FILE
    if not state_path.exists():
        # nothing saved yet: start both phases from scratch
        return {"markets_cursor": None, "markets_count": 0,
                "trades_cursor": None, "trades_count": 0,
                "markets_done": False, "trades_done": False}
    with open(state_path) as f:
        return json.load(f)
|
||||
|
||||
def save_state(output_dir: Path, state: dict):
    """Persist the fetch checkpoint so an interrupted run can resume."""
    with open(output_dir / STATE_FILE, "w") as f:
        json.dump(state, f)
|
||||
|
||||
def append_markets_csv(markets: list, output_path: Path, write_header: bool):
    """Append a batch of market records to CSV; write the header on the first batch."""
    with open(output_path, "w" if write_header else "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["ticker", "title", "category", "open_time",
                             "close_time", "result", "status", "yes_bid",
                             "yes_ask", "volume", "open_interest"])

        for market in markets:
            # resolution column: keep explicit yes/no, or any non-empty
            # result once the market is finalized; otherwise blank
            raw_result = market.get("result")
            if raw_result in ("yes", "no"):
                result = raw_result
            elif market.get("status") == "finalized" and raw_result:
                result = raw_result
            else:
                result = ""

            writer.writerow([
                market.get("ticker", ""),
                market.get("title", ""),
                market.get("category", ""),
                market.get("open_time", ""),
                # some API payloads carry expiration_time instead of close_time
                market.get("close_time", market.get("expiration_time", "")),
                result,
                market.get("status", ""),
                market.get("yes_bid", ""),
                market.get("yes_ask", ""),
                market.get("volume", ""),
                market.get("open_interest", ""),
            ])
|
||||
|
||||
def append_trades_csv(trades: list, output_path: Path, write_header: bool):
    """Append a batch of trade records to CSV; write the header on the first batch."""
    with open(output_path, "w" if write_header else "a", newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["timestamp", "ticker", "price", "volume", "taker_side"])

        for trade in trades:
            side = trade.get("taker_side", "")
            if not side:
                # fall back to the boolean flag used by some payload versions
                side = "yes" if trade.get("is_taker_side_yes", True) else "no"

            writer.writerow([
                trade.get("created_time", trade.get("ts", "")),
                trade.get("ticker", trade.get("market_ticker", "")),
                trade.get("yes_price", trade.get("price", 50)),
                trade.get("count", trade.get("volume", 1)),
                side,
            ])
|
||||
|
||||
def fetch_markets_incremental(output_dir: Path, state: dict) -> int:
    """Page through /markets, appending each batch to CSV and checkpointing state.

    Returns the total number of markets written so far (partial on error).
    """
    output_path = output_dir / "markets.csv"
    cursor = state["markets_cursor"]
    total = state["markets_count"]
    write_header = total == 0

    print(f"Resuming from {total} markets...")

    while True:
        url = f"{BASE_URL}/markets?limit=1000"
        if cursor:
            url += f"&cursor={cursor}"

        print(f"Fetching markets... ({total:,} so far)")

        try:
            data = fetch_json(url)
        except Exception as e:
            # leave the checkpoint in place so the next run picks up here
            print(f"Error fetching markets: {e}")
            print(f"Progress saved. Run again to resume from {total:,} markets.")
            return total

        batch = data.get("markets", [])
        if batch:
            append_markets_csv(batch, output_path, write_header)
            write_header = False
            total += len(batch)

        # checkpoint after every page so a crash loses at most one batch
        cursor = data.get("cursor")
        state["markets_cursor"] = cursor
        state["markets_count"] = total
        save_state(output_dir, state)

        if not cursor:
            # no cursor means the API has no more pages
            state["markets_done"] = True
            save_state(output_dir, state)
            break

        time.sleep(0.3)  # throttle requests to the public API

    return total
|
||||
|
||||
def fetch_trades_incremental(output_dir: Path, state: dict, limit: int) -> int:
    """Page through /markets/trades up to `limit` rows, checkpointing each batch.

    Returns the total number of trades written so far (partial on error).
    """
    output_path = output_dir / "trades.csv"
    cursor = state["trades_cursor"]
    total = state["trades_count"]
    write_header = total == 0

    print(f"Resuming from {total} trades...")

    while total < limit:
        url = f"{BASE_URL}/markets/trades?limit=1000"
        if cursor:
            url += f"&cursor={cursor}"

        print(f"Fetching trades... ({total:,}/{limit:,})")

        try:
            data = fetch_json(url)
        except Exception as e:
            # leave the checkpoint in place so the next run picks up here
            print(f"Error fetching trades: {e}")
            print(f"Progress saved. Run again to resume from {total:,} trades.")
            return total

        batch = data.get("trades", [])
        if not batch:
            break

        append_trades_csv(batch, output_path, write_header)
        write_header = False
        total += len(batch)

        # checkpoint after every page so a crash loses at most one batch
        cursor = data.get("cursor")
        state["trades_cursor"] = cursor
        state["trades_count"] = total
        save_state(output_dir, state)

        if not cursor:
            # no cursor means the API has no more pages
            state["trades_done"] = True
            save_state(output_dir, state)
            break

        time.sleep(0.3)  # throttle requests to the public API

    return total
|
||||
|
||||
def main():
    """Fetch markets then trades with resume support; clears the checkpoint on success.

    Returns 0 on completion, 1 if either phase was interrupted (rerun to resume).
    """
    # NOTE(review): README implies data lands in ./data, but this path is
    # hardcoded — confirm the intended output directory
    output_dir = Path("/mnt/work/kalshi-data")
    # parents=True so the run doesn't crash when the parent dir is missing
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 50)
    print("Kalshi Data Fetcher (with resume)")
    print("=" * 50)

    state = load_state(output_dir)

    # phase 1: markets (must finish before trades start)
    if not state["markets_done"]:
        print("\n[1/2] Fetching markets...")
        markets_count = fetch_markets_incremental(output_dir, state)
        if state["markets_done"]:
            print(f"Markets complete: {markets_count:,}")
        else:
            print(f"Markets paused at: {markets_count:,}")
            return 1
    else:
        print(f"\n[1/2] Markets already complete: {state['markets_count']:,}")

    # phase 2: trades
    if not state["trades_done"]:
        print("\n[2/2] Fetching trades...")
        trades_count = fetch_trades_incremental(output_dir, state, limit=1000000)
        if state["trades_done"]:
            print(f"Trades complete: {trades_count:,}")
        else:
            print(f"Trades paused at: {trades_count:,}")
            return 1
    else:
        print(f"\n[2/2] Trades already complete: {state['trades_count']:,}")

    print("\n" + "=" * 50)
    print("Done!")
    print(f"Markets: {state['markets_count']:,}")
    print(f"Trades: {state['trades_count']:,}")
    print(f"Output: {output_dir}")
    print("=" * 50)

    # both phases finished — remove the checkpoint so the next run starts fresh
    (output_dir / STATE_FILE).unlink(missing_ok=True)

    return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
280
scripts/train_ml_models.py
Normal file
280
scripts/train_ml_models.py
Normal file
@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Train ML models for the kalshi backtest framework.
|
||||
|
||||
Models:
|
||||
- LSTM: learns patterns from price history sequences
|
||||
- MLP: learns optimal combination of hand-crafted features
|
||||
|
||||
Usage:
|
||||
python scripts/train_ml_models.py --data data/trades.csv --output models/
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
HAS_TORCH = True
|
||||
except ImportError:
|
||||
HAS_TORCH = False
|
||||
print("warning: pytorch not installed. run: pip install torch")
|
||||
|
||||
def parse_args():
    """Parse CLI options for model training."""
    parser = argparse.ArgumentParser(description="Train ML models for kalshi backtest")
    # data inputs
    parser.add_argument("--data", type=Path, default=Path("data/trades.csv"))
    parser.add_argument("--markets", type=Path, default=Path("data/markets.csv"))
    parser.add_argument("--output", type=Path, default=Path("models"))
    # training hyperparameters
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--seq-len", type=int, default=24)
    parser.add_argument("--train-split", type=float, default=0.8)
    return parser.parse_args()
|
||||
|
||||
class LSTMPredictor(nn.Module):
    """Stacked LSTM over a return sequence, emitting one score in (-1, 1)."""

    def __init__(self, input_size=1, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # project only the final timestep's hidden state through the head
        seq_out, _ = self.lstm(x)
        return self.tanh(self.fc(seq_out[:, -1, :]))
|
||||
|
||||
class MLPPredictor(nn.Module):
    """Feed-forward net over hand-crafted features, emitting one score in (-1, 1).

    Architecture: [Linear -> ReLU -> Dropout(0.2)] per hidden layer,
    then a Linear + Tanh head.
    """

    # tuple default instead of the original mutable list default (shared
    # mutable-default pitfall); any iterable of ints still works, so callers
    # passing lists are unaffected
    def __init__(self, input_size=7, hidden_sizes=(64, 32)):
        super().__init__()
        layers = []
        prev_size = input_size
        for h in hidden_sizes:
            layers.append(nn.Linear(prev_size, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.2))
            prev_size = h
        layers.append(nn.Linear(prev_size, 1))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
def load_data(trades_path: Path, markets_path: Path, seq_len: int):
    """Build training samples from trade history.

    For every ticker with a yes/no resolution, slides a window of length
    `seq_len` over its price series and emits:
      - sequences: log-return windows (length seq_len, left-padded with one zero)
      - features:  7 hand-crafted per-window features
      - labels:    +1.0 if the market resolved yes, -1.0 if no
    """
    print(f"loading trades from {trades_path}...")
    trades = pd.read_csv(trades_path)
    trades["timestamp"] = pd.to_datetime(trades["timestamp"])
    trades = trades.sort_values(["ticker", "timestamp"])

    print(f"loading markets from {markets_path}...")
    markets = pd.read_csv(markets_path)
    markets["close_time"] = pd.to_datetime(markets["close_time"])

    result_map = dict(zip(markets["ticker"], markets["result"]))

    sequences = []
    features = []
    labels = []

    for ticker, group in trades.groupby("ticker"):
        result = result_map.get(ticker)
        if result not in ["yes", "no"]:
            continue  # unresolved/cancelled markets carry no label

        label = 1.0 if result == "yes" else -1.0

        prices = group["price"].values / 100.0  # cents -> probability scale
        volumes = group["volume"].values
        taker_sides = (group["taker_side"] == "yes").astype(float).values

        if len(prices) < seq_len:
            continue  # not enough history for a single window

        for i in range(seq_len, len(prices)):
            seq = prices[i - seq_len : i]
            # clip away zeros before log; diff over a seq_len window always
            # yields seq_len - 1 returns, so unconditionally left-pad one zero
            # (the original guarded conditionals here were always-true dead code)
            log_returns = np.pad(
                np.diff(np.log(np.clip(seq, 1e-6, 1.0))), (1, 0), mode="constant"
            )

            sequences.append(log_returns)

            curr_price = prices[i - 1]
            # the loop body only runs when len(prices) > seq_len, so the
            # original `if len(prices) > seq_len else 0` guard was redundant
            momentum = prices[i - 1] - prices[i - seq_len]
            mean_price = np.mean(prices[i - seq_len : i])
            mean_reversion = mean_price - curr_price
            vol_sum = np.sum(volumes[i - seq_len : i])
            buy_vol = np.sum(volumes[i - seq_len : i] * taker_sides[i - seq_len : i])
            sell_vol = vol_sum - buy_vol
            order_flow = (buy_vol - sell_vol) / max(vol_sum, 1)

            features.append([
                momentum,
                mean_reversion,
                np.log1p(vol_sum),
                order_flow,
                curr_price,
                np.std(log_returns),  # np.std of a single element is 0, matching the old guard
                len(group) / 1000.0,
            ])
            labels.append(label)

    print(f"created {len(sequences)} training samples")
    return np.array(sequences), np.array(features), np.array(labels)
|
||||
|
||||
def train_lstm(sequences, labels, args):
    """Train the LSTM on log-return sequences; reports accuracy every 10 epochs.

    Uses a chronological train/test split (no shuffle across the split
    boundary), MSE loss against the +1/-1 labels, and sign agreement as the
    accuracy metric. Returns the trained model.
    """
    print("\n" + "=" * 50)
    print("Training LSTM")
    print("=" * 50)

    n = len(sequences)
    split = int(n * args.train_split)

    # (batch, seq, 1) — the LSTM expects a trailing feature dimension
    X_train = torch.tensor(sequences[:split], dtype=torch.float32).unsqueeze(-1)
    y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
    X_test = torch.tensor(sequences[split:], dtype=torch.float32).unsqueeze(-1)
    y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)

    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    model = LSTMPredictor(input_size=1, hidden_size=128, num_layers=2)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            # bug fix: nn.Module has no set_mode_to_inference() — this raised
            # AttributeError at the first evaluation; eval() is the correct way
            # to disable dropout for evaluation
            model.eval()
            with torch.no_grad():
                train_pred = model(X_train)
                test_pred = model(X_test)
                train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
                test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
                print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

    return model
|
||||
|
||||
def train_mlp(features, labels, args):
    """Train the MLP on normalized features; reports accuracy every 10 epochs.

    Features are z-score normalized per column before the chronological
    train/test split. Returns the trained model.
    """
    print("\n" + "=" * 50)
    print("Training MLP")
    print("=" * 50)

    # per-column z-score normalization (epsilon avoids divide-by-zero on
    # constant columns)
    features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)

    n = len(features)
    split = int(n * args.train_split)

    X_train = torch.tensor(features[:split], dtype=torch.float32)
    y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
    X_test = torch.tensor(features[split:], dtype=torch.float32)
    y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)

    train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    model = MLPPredictor(input_size=features.shape[1])
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.epochs):
        model.train()
        total_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            output = model(X_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if (epoch + 1) % 10 == 0:
            # bug fix: nn.Module has no set_mode_to_inference() — this raised
            # AttributeError at the first evaluation; eval() is the correct way
            # to disable dropout for evaluation
            model.eval()
            with torch.no_grad():
                train_pred = model(X_train)
                test_pred = model(X_test)
                train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
                test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
                print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")

    return model
|
||||
|
||||
def export_onnx(model, output_path: Path, input_shape, input_name="input", output_name="output"):
    """Export a trained model to ONNX with a dynamic batch dimension.

    `input_shape` is the dummy-input shape used for tracing (batch dim first).
    """
    # bug fix: nn.Module has no set_mode_to_inference() — this raised
    # AttributeError; eval() disables dropout so the traced graph is
    # deterministic
    model.eval()
    dummy_input = torch.randn(*input_shape)

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        input_names=[input_name],
        output_names=[output_name],
        # allow the runtime to feed arbitrary batch sizes
        dynamic_axes={
            input_name: {0: "batch_size"},
            output_name: {0: "batch_size"},
        },
        opset_version=14,
    )
    print(f"exported to {output_path}")
|
||||
|
||||
def main():
    """Entry point: validate inputs, train both models, export each to ONNX.

    Returns 0 on success, 1 on any precondition failure.
    """
    args = parse_args()

    if not HAS_TORCH:
        print("error: pytorch required for training. install with: pip install torch")
        return 1

    # both input files must exist before we do any work
    for kind, path in (("data", args.data), ("markets", args.markets)):
        if not path.exists():
            print(f"error: {kind} file not found: {path}")
            return 1

    args.output.mkdir(parents=True, exist_ok=True)

    sequences, features, labels = load_data(args.data, args.markets, args.seq_len)

    if len(sequences) < 100:
        print(f"error: not enough training data ({len(sequences)} samples)")
        return 1

    # LSTM consumes (1, seq_len, 1) sequences; MLP consumes flat feature rows
    lstm_model = train_lstm(sequences, labels, args)
    export_onnx(lstm_model, args.output / "lstm.onnx", (1, args.seq_len, 1))

    mlp_model = train_mlp(features, labels, args)
    export_onnx(mlp_model, args.output / "mlp.onnx", (1, features.shape[1]))

    print("\n" + "=" * 50)
    print("Training complete!")
    print(f"Models saved to: {args.output}")
    print("=" * 50)

    return 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
||||
445
src/backtest.rs
Normal file
445
src/backtest.rs
Normal file
@ -0,0 +1,445 @@
|
||||
use crate::data::HistoricalData;
|
||||
use crate::execution::{Executor, PositionSizingConfig};
|
||||
use crate::metrics::{BacktestResult, MetricsCollector};
|
||||
use crate::pipeline::{
|
||||
AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, Filter,
|
||||
HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer,
|
||||
MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, Selector, Source, TimeDecayScorer,
|
||||
TimeToCloseFilter, TopKSelector, TradingPipeline, VolumeScorer,
|
||||
};
|
||||
use crate::types::{
|
||||
BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType,
|
||||
TradingContext,
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rust_decimal::Decimal;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
/// resolves any positions in markets that have closed
|
||||
/// returns list of (ticker, result, pnl) for logging purposes
|
||||
fn resolve_closed_positions(
|
||||
portfolio: &mut Portfolio,
|
||||
data: &HistoricalData,
|
||||
resolved: &mut HashSet<String>,
|
||||
at: DateTime<Utc>,
|
||||
history: &mut Vec<Trade>,
|
||||
metrics: &mut MetricsCollector,
|
||||
) -> Vec<(String, MarketResult, Option<Decimal>)> {
|
||||
let tickers: Vec<String> = portfolio.positions.keys().cloned().collect();
|
||||
let mut resolutions = Vec::new();
|
||||
|
||||
for ticker in tickers {
|
||||
if resolved.contains(&ticker) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(result) = data.get_resolution_at(&ticker, at) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
resolved.insert(ticker.clone());
|
||||
let Some(pos) = portfolio.positions.get(&ticker).cloned() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let pnl = portfolio.resolve_position(&ticker, result);
|
||||
|
||||
let exit_price = match result {
|
||||
MarketResult::Yes => match pos.side {
|
||||
Side::Yes => Decimal::ONE,
|
||||
Side::No => Decimal::ZERO,
|
||||
},
|
||||
MarketResult::No => match pos.side {
|
||||
Side::Yes => Decimal::ZERO,
|
||||
Side::No => Decimal::ONE,
|
||||
},
|
||||
MarketResult::Cancelled => pos.avg_entry_price,
|
||||
};
|
||||
|
||||
let category = data
|
||||
.markets
|
||||
.get(&ticker)
|
||||
.map(|m| m.category.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
let trade = Trade {
|
||||
ticker: ticker.clone(),
|
||||
side: pos.side,
|
||||
quantity: pos.quantity,
|
||||
price: exit_price,
|
||||
timestamp: at,
|
||||
trade_type: TradeType::Resolution,
|
||||
};
|
||||
|
||||
history.push(trade.clone());
|
||||
metrics.record_trade(&trade, &category);
|
||||
resolutions.push((ticker, result, pnl));
|
||||
}
|
||||
|
||||
resolutions
|
||||
}
|
||||
|
||||
/// Drives a historical simulation: steps time forward, runs the candidate
/// pipeline, and executes/settles trades against recorded market data.
pub struct Backtester {
    // simulation window, capital, and position limits
    config: BacktestConfig,
    // shared, read-only historical market/trade data
    data: Arc<HistoricalData>,
    // sources -> filters -> scorers -> selector chain producing candidates
    pipeline: TradingPipeline,
    // turns selected candidates into trades (sizing and exit rules)
    executor: Executor,
}
|
||||
|
||||
impl Backtester {
    /// Builds a backtester with the default pipeline and a default-configured
    /// executor (10 bps slippage, default sizing and exit rules).
    pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
        let pipeline = Self::build_default_pipeline(data.clone(), &config);
        let executor = Executor::new(data.clone(), 10, config.max_position_size);

        Self {
            config,
            data,
            pipeline,
            executor,
        }
    }

    /// Like `new`, but with explicit position-sizing and exit configuration.
    pub fn with_configs(
        config: BacktestConfig,
        data: Arc<HistoricalData>,
        sizing_config: PositionSizingConfig,
        exit_config: ExitConfig,
    ) -> Self {
        let pipeline = Self::build_default_pipeline(data.clone(), &config);
        let executor = Executor::new(data.clone(), 10, config.max_position_size)
            .with_sizing_config(sizing_config)
            .with_exit_config(exit_config);

        Self {
            config,
            data,
            pipeline,
            executor,
        }
    }

    /// Replaces the default pipeline (builder-style).
    pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self {
        self.pipeline = pipeline;
        self
    }

    /// Assembles the default source/filter/scorer/selector stack.
    /// Constants here (lookback hours, liquidity floor, time-to-close window)
    /// are the built-in strategy defaults; override via `with_pipeline`.
    fn build_default_pipeline(data: Arc<HistoricalData>, config: &BacktestConfig) -> TradingPipeline {
        let sources: Vec<Box<dyn Source>> = vec![
            Box::new(HistoricalMarketSource::new(data, 24)),
        ];

        let filters: Vec<Box<dyn Filter>> = vec![
            Box::new(LiquidityFilter::new(100)),
            Box::new(TimeToCloseFilter::new(2, Some(720))),
            Box::new(AlreadyPositionedFilter::new(config.max_position_size)),
        ];

        let scorers: Vec<Box<dyn Scorer>> = vec![
            Box::new(MomentumScorer::new(6)),
            Box::new(MultiTimeframeMomentumScorer::default_windows()),
            Box::new(MeanReversionScorer::new(24)),
            Box::new(BollingerMeanReversionScorer::default_config()),
            Box::new(VolumeScorer::new(6)),
            Box::new(OrderFlowScorer::new()),
            Box::new(TimeDecayScorer::new()),
            Box::new(CategoryWeightedScorer::with_defaults()),
        ];

        let selector: Box<dyn Selector> = Box::new(TopKSelector::new(config.max_positions));

        TradingPipeline::new(sources, filters, scorers, selector, config.max_positions)
    }

    /// Runs the simulation loop. Per tick, in order: settle resolved markets,
    /// run the pipeline, process exits (using this tick's candidate scores),
    /// open new positions, then mark the portfolio to market.
    pub async fn run(&self) -> BacktestResult {
        let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
        let mut metrics = MetricsCollector::new(self.config.initial_capital);
        let mut resolved_markets: HashSet<String> = HashSet::new();

        let mut current_time = self.config.start_time;

        info!(
            start = %self.config.start_time,
            end = %self.config.end_time,
            interval_hours = self.config.interval.num_hours(),
            "starting backtest"
        );

        while current_time < self.config.end_time {
            context.timestamp = current_time;
            context.request_id = uuid::Uuid::new_v4().to_string();

            // 1) settle positions in markets that closed since the last tick
            let resolutions = resolve_closed_positions(
                &mut context.portfolio,
                &self.data,
                &mut resolved_markets,
                current_time,
                &mut context.trading_history,
                &mut metrics,
            );
            for (ticker, result, pnl) in resolutions {
                info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
            }

            // 2) score this tick's candidates (runs before exits so that the
            // ScoreReversal exit rule sees the freshest scores)
            let result = self.pipeline.execute(context.clone()).await;

            let candidate_scores: std::collections::HashMap<String, f64> = result
                .selected_candidates
                .iter()
                .map(|c| (c.ticker.clone(), c.final_score))
                .collect();

            // 3) close positions that hit an exit rule
            let exit_signals = self.executor.generate_exit_signals(&context, &candidate_scores);
            for exit in exit_signals {
                if let Some(position) = context.portfolio.positions.get(&exit.ticker).cloned() {
                    // NOTE(review): close_position receives the raw yes-price while
                    // the recorded trade uses the side-adjusted price below —
                    // assumes close_position does its own side adjustment; confirm.
                    let pnl = context.portfolio.close_position(&exit.ticker, exit.current_price);

                    info!(
                        ticker = %exit.ticker,
                        reason = ?exit.reason,
                        pnl = ?pnl,
                        "exit triggered"
                    );

                    let category = self
                        .data
                        .markets
                        .get(&exit.ticker)
                        .map(|m| m.category.clone())
                        .unwrap_or_default();

                    // Convert the market's yes-price into the price of the side we hold.
                    let exit_price = match position.side {
                        crate::types::Side::Yes => exit.current_price,
                        crate::types::Side::No => Decimal::ONE - exit.current_price,
                    };

                    let trade = Trade {
                        ticker: exit.ticker.clone(),
                        side: position.side,
                        quantity: position.quantity,
                        price: exit_price,
                        timestamp: current_time,
                        trade_type: TradeType::Close,
                    };

                    context.trading_history.push(trade.clone());
                    metrics.record_trade(&trade, &category);
                }
            }

            // 4) open new positions from the selected candidates
            let signals = self.executor.generate_signals(&result.selected_candidates, &context);

            for signal in signals {
                if let Some(fill) = self.executor.execute_signal(&signal, &context) {
                    info!(
                        ticker = %fill.ticker,
                        side = ?fill.side,
                        quantity = fill.quantity,
                        price = %fill.price,
                        "executed trade"
                    );

                    context.portfolio.apply_fill(&fill);

                    let category = self
                        .data
                        .markets
                        .get(&fill.ticker)
                        .map(|m| m.category.clone())
                        .unwrap_or_default();

                    let trade = Trade {
                        ticker: fill.ticker.clone(),
                        side: fill.side,
                        quantity: fill.quantity,
                        price: fill.price,
                        timestamp: fill.timestamp,
                        trade_type: TradeType::Open,
                    };

                    context.trading_history.push(trade.clone());
                    metrics.record_trade(&trade, &category);
                }
            }

            // 5) mark-to-market snapshot for the equity curve
            let market_prices = self.get_current_prices(current_time);
            metrics.record(current_time, &context.portfolio, &market_prices);

            current_time = current_time + self.config.interval;
        }

        // Settle anything still open at the end of the window.
        let resolutions = resolve_closed_positions(
            &mut context.portfolio,
            &self.data,
            &mut resolved_markets,
            self.config.end_time,
            &mut context.trading_history,
            &mut metrics,
        );
        for (ticker, result, pnl) in resolutions {
            info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
        }

        info!(
            trades = context.trading_history.len(),
            positions = context.portfolio.positions.len(),
            cash = %context.portfolio.cash,
            "backtest complete"
        );

        metrics.finalize()
    }

    /// Last-trade yes-price for every known market at `at`; markets with no
    /// trade at or before `at` are omitted from the map.
    fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
        self.data
            .markets
            .keys()
            .filter_map(|ticker| {
                self.data
                    .get_current_price(ticker, at)
                    .map(|p| (ticker.clone(), p))
            })
            .collect()
    }
}
|
||||
|
||||
/// Control strategy that opens positions in uniformly "random" markets and
/// sides (deterministic LCG, fixed seed); used to benchmark the real
/// pipeline against chance.
pub struct RandomBaseline {
    config: BacktestConfig,
    data: Arc<HistoricalData>,
}
|
||||
|
||||
impl RandomBaseline {
    pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
        Self { config, data }
    }

    /// Runs the same tick loop as `Backtester::run`, but with at most one
    /// random entry per tick and no exit rules — positions are only closed
    /// by market resolution.
    pub async fn run(&self) -> BacktestResult {
        let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
        let mut metrics = MetricsCollector::new(self.config.initial_capital);
        let mut resolved_markets: HashSet<String> = HashSet::new();
        // Fixed seed keeps the baseline reproducible across runs.
        let mut rng_state: u64 = 42;

        let mut current_time = self.config.start_time;

        while current_time < self.config.end_time {
            context.timestamp = current_time;

            // Settle any markets that closed since the last tick.
            resolve_closed_positions(
                &mut context.portfolio,
                &self.data,
                &mut resolved_markets,
                current_time,
                &mut context.trading_history,
                &mut metrics,
            );

            if let Some(fill) = self.try_random_trade(&context, current_time, &mut rng_state) {
                let category = self
                    .data
                    .markets
                    .get(&fill.ticker)
                    .map(|m| m.category.clone())
                    .unwrap_or_default();

                context.portfolio.apply_fill(&fill);

                let trade = Trade {
                    ticker: fill.ticker.clone(),
                    side: fill.side,
                    quantity: fill.quantity,
                    price: fill.price,
                    timestamp: current_time,
                    trade_type: TradeType::Open,
                };

                context.trading_history.push(trade.clone());
                metrics.record_trade(&trade, &category);
            }

            // Mark-to-market snapshot for the equity curve.
            let market_prices = self.get_current_prices(current_time);
            metrics.record(current_time, &context.portfolio, &market_prices);

            current_time = current_time + self.config.interval;
        }

        // Settle anything still open at the end of the window.
        resolve_closed_positions(
            &mut context.portfolio,
            &self.data,
            &mut resolved_markets,
            self.config.end_time,
            &mut context.trading_history,
            &mut metrics,
        );

        metrics.finalize()
    }

    /// Picks one random unpositioned active market and sizes an all-in-capped
    /// fill for it. Returns `None` when at the position cap, no eligible
    /// market exists, there is no price yet, or cash buys zero shares.
    fn try_random_trade(
        &self,
        context: &TradingContext,
        at: DateTime<Utc>,
        rng_state: &mut u64,
    ) -> Option<Fill> {
        if context.portfolio.positions.len() >= self.config.max_positions {
            return None;
        }

        let active_markets = self.data.get_active_markets(at);
        let unpositioned: Vec<_> = active_markets
            .iter()
            .filter(|m| !context.portfolio.has_position(&m.ticker))
            .collect();

        if unpositioned.is_empty() {
            return None;
        }

        *rng_state = lcg_next(*rng_state);
        // NOTE(review): both the index and the side below use the LCG's low
        // bits, which have very short periods for a power-of-two-modulus LCG
        // (the lowest bit strictly alternates), so the side choice alternates
        // deterministically. Consider using the high bits instead.
        let idx = (*rng_state as usize) % unpositioned.len();
        let market = unpositioned[idx];

        let price = self.data.get_current_price(&market.ticker, at)?;
        let side = if *rng_state % 2 == 0 { Side::Yes } else { Side::No };

        // Price of the side we are buying (no-price is the complement).
        let effective_price = match side {
            Side::Yes => price,
            Side::No => Decimal::ONE - price,
        };

        // Cap at max size and at what cash can afford.
        let quantity = self
            .config
            .max_position_size
            .min((context.portfolio.cash / effective_price).to_u64().unwrap_or(0));

        if quantity == 0 {
            return None;
        }

        Some(Fill {
            ticker: market.ticker.clone(),
            side,
            quantity,
            price: effective_price,
            timestamp: at,
        })
    }

    /// Last-trade yes-price for every known market at `at`; markets with no
    /// trade at or before `at` are omitted.
    fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
        self.data
            .markets
            .keys()
            .filter_map(|ticker| {
                self.data
                    .get_current_price(ticker, at)
                    .map(|p| (ticker.clone(), p))
            })
            .collect()
    }
}
|
||||
|
||||
/// One step of a linear congruential generator; used for the deterministic
/// random baseline so results are reproducible without an RNG dependency.
///
/// Classic glibc-style constants; wrapping arithmetic keeps the state in u64.
fn lcg_next(state: u64) -> u64 {
    const MULTIPLIER: u64 = 1103515245;
    const INCREMENT: u64 = 12345;
    state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT)
}
|
||||
290
src/data/loader.rs
Normal file
290
src/data/loader.rs
Normal file
@ -0,0 +1,290 @@
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use csv::ReaderBuilder;
|
||||
use rust_decimal::Decimal;
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::types::{MarketData, MarketResult, PricePoint, Side, TradeData};
|
||||
|
||||
/// Raw row shape of `markets.csv`; converted into `MarketData` by `load_markets`.
#[derive(Debug, Deserialize)]
struct CsvMarket {
    ticker: String,
    title: String,
    category: String,
    // Accepts RFC 3339, "YYYY-MM-DD HH:MM:SS", or unix seconds.
    #[serde(with = "flexible_datetime")]
    open_time: DateTime<Utc>,
    #[serde(with = "flexible_datetime")]
    close_time: DateTime<Utc>,
    // "yes" / "no" / "cancelled"; empty or missing means unresolved.
    result: Option<String>,
}
|
||||
|
||||
/// Raw row shape of `trades.csv`; converted into `TradeData` by `load_trades`.
#[derive(Debug, Deserialize)]
struct CsvTrade {
    #[serde(with = "flexible_datetime")]
    timestamp: DateTime<Utc>,
    ticker: String,
    // Price in cents (0-100); divided by 100 during load.
    price: f64,
    volume: u64,
    // "yes"/"buy" or "no"/"sell"; other values cause the row to be skipped.
    taker_side: String,
}
|
||||
|
||||
mod flexible_datetime {
|
||||
use chrono::{DateTime, NaiveDateTime, Utc};
|
||||
use serde::{self, Deserialize, Deserializer};
|
||||
|
||||
pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
|
||||
if let Ok(dt) = DateTime::parse_from_rfc3339(&s) {
|
||||
return Ok(dt.with_timezone(&Utc));
|
||||
}
|
||||
|
||||
if let Ok(dt) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") {
|
||||
return Ok(dt.and_utc());
|
||||
}
|
||||
|
||||
if let Ok(ts) = s.parse::<i64>() {
|
||||
return DateTime::from_timestamp(ts, 0)
|
||||
.ok_or_else(|| serde::de::Error::custom("invalid timestamp"));
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom(format!(
|
||||
"could not parse datetime: {}",
|
||||
s
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
/// In-memory store of market metadata plus the full trade tape, with a
/// per-ticker index into `trades` for fast lookups.
pub struct HistoricalData {
    pub markets: HashMap<String, MarketData>,
    // All trades, sorted ascending by timestamp (see `load`).
    pub trades: Vec<TradeData>,
    // ticker -> indices into `trades`, in ascending timestamp order.
    trade_index: HashMap<String, Vec<usize>>,
}
|
||||
|
||||
impl HistoricalData {
|
||||
pub fn load(data_dir: &Path) -> Result<Self> {
|
||||
let markets_path = data_dir.join("markets.csv");
|
||||
let trades_path = data_dir.join("trades.csv");
|
||||
|
||||
let markets = load_markets(&markets_path)
|
||||
.with_context(|| format!("loading markets from {:?}", markets_path))?;
|
||||
|
||||
let trades =
|
||||
load_trades(&trades_path).with_context(|| format!("loading trades from {:?}", trades_path))?;
|
||||
|
||||
let mut trade_index: HashMap<String, Vec<usize>> = HashMap::new();
|
||||
for (i, trade) in trades.iter().enumerate() {
|
||||
trade_index.entry(trade.ticker.clone()).or_default().push(i);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
markets,
|
||||
trades,
|
||||
trade_index,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_active_markets(&self, at: DateTime<Utc>) -> Vec<&MarketData> {
|
||||
self.markets
|
||||
.values()
|
||||
.filter(|m| at >= m.open_time && at < m.close_time)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_trades_for_market(&self, ticker: &str, from: DateTime<Utc>, to: DateTime<Utc>) -> Vec<&TradeData> {
|
||||
self.trade_index
|
||||
.get(ticker)
|
||||
.map(|indices| {
|
||||
indices
|
||||
.iter()
|
||||
.filter_map(|&i| {
|
||||
let trade = &self.trades[i];
|
||||
if trade.timestamp >= from && trade.timestamp < to {
|
||||
Some(trade)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn get_current_price(&self, ticker: &str, at: DateTime<Utc>) -> Option<Decimal> {
|
||||
self.trade_index.get(ticker).and_then(|indices| {
|
||||
indices
|
||||
.iter()
|
||||
.filter_map(|&i| {
|
||||
let trade = &self.trades[i];
|
||||
if trade.timestamp <= at {
|
||||
Some(trade)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.last()
|
||||
.map(|t| t.price)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get_price_history(
|
||||
&self,
|
||||
ticker: &str,
|
||||
from: DateTime<Utc>,
|
||||
to: DateTime<Utc>,
|
||||
) -> Vec<PricePoint> {
|
||||
self.get_trades_for_market(ticker, from, to)
|
||||
.into_iter()
|
||||
.map(|t| PricePoint {
|
||||
timestamp: t.timestamp,
|
||||
yes_price: t.price,
|
||||
volume: t.volume,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_volume_24h(&self, ticker: &str, at: DateTime<Utc>) -> u64 {
|
||||
let from = at - chrono::Duration::hours(24);
|
||||
self.get_trades_for_market(ticker, from, at)
|
||||
.iter()
|
||||
.map(|t| t.volume)
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn get_order_flow_24h(&self, ticker: &str, at: DateTime<Utc>) -> (u64, u64) {
|
||||
let from = at - chrono::Duration::hours(24);
|
||||
let trades = self.get_trades_for_market(ticker, from, at);
|
||||
let buy_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::Yes).map(|t| t.volume).sum();
|
||||
let sell_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::No).map(|t| t.volume).sum();
|
||||
(buy_vol, sell_vol)
|
||||
}
|
||||
|
||||
pub fn get_resolutions(&self, at: DateTime<Utc>) -> Vec<(&MarketData, MarketResult)> {
|
||||
self.markets
|
||||
.values()
|
||||
.filter_map(|m| {
|
||||
if m.close_time <= at {
|
||||
m.result.map(|r| (m, r))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn get_resolution_at(&self, ticker: &str, at: DateTime<Utc>) -> Option<MarketResult> {
|
||||
self.markets.get(ticker).and_then(|m| {
|
||||
if m.close_time <= at {
|
||||
m.result
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn load_markets(path: &Path) -> Result<HashMap<String, MarketData>> {
|
||||
let mut reader = ReaderBuilder::new()
|
||||
.has_headers(true)
|
||||
.flexible(true)
|
||||
.from_path(path)?;
|
||||
|
||||
let mut markets = HashMap::new();
|
||||
|
||||
for result in reader.deserialize() {
|
||||
let record: CsvMarket = result?;
|
||||
let result = record.result.as_ref().and_then(|r| match r.to_lowercase().as_str() {
|
||||
"yes" => Some(MarketResult::Yes),
|
||||
"no" => Some(MarketResult::No),
|
||||
"cancelled" | "canceled" => Some(MarketResult::Cancelled),
|
||||
"" => None,
|
||||
_ => None,
|
||||
});
|
||||
|
||||
markets.insert(
|
||||
record.ticker.clone(),
|
||||
MarketData {
|
||||
ticker: record.ticker,
|
||||
title: record.title,
|
||||
category: record.category,
|
||||
open_time: record.open_time,
|
||||
close_time: record.close_time,
|
||||
result,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
Ok(markets)
|
||||
}
|
||||
|
||||
fn load_trades(path: &Path) -> Result<Vec<TradeData>> {
|
||||
let mut reader = ReaderBuilder::new()
|
||||
.has_headers(true)
|
||||
.flexible(true)
|
||||
.from_path(path)?;
|
||||
|
||||
let mut trades = Vec::new();
|
||||
|
||||
for result in reader.deserialize() {
|
||||
let record: CsvTrade = result?;
|
||||
let side = match record.taker_side.to_lowercase().as_str() {
|
||||
"yes" | "buy" => Side::Yes,
|
||||
"no" | "sell" => Side::No,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
trades.push(TradeData {
|
||||
timestamp: record.timestamp,
|
||||
ticker: record.ticker,
|
||||
price: Decimal::try_from(record.price / 100.0).unwrap_or(Decimal::ZERO),
|
||||
volume: record.volume,
|
||||
taker_side: side,
|
||||
});
|
||||
}
|
||||
|
||||
trades.sort_by_key(|t| t.timestamp);
|
||||
|
||||
Ok(trades)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    // Writes a minimal markets.csv (2 markets) and trades.csv (3 trades)
    // into a fresh temp dir; the dir is cleaned up when the guard drops.
    fn create_test_data() -> TempDir {
        let dir = TempDir::new().unwrap();

        let markets_csv = r#"ticker,title,category,open_time,close_time,result
TEST-MKT-1,Test Market 1,politics,2024-01-01 00:00:00,2024-01-15 00:00:00,yes
TEST-MKT-2,Test Market 2,economics,2024-01-01 00:00:00,2024-01-20 00:00:00,no
"#;
        let mut f = std::fs::File::create(dir.path().join("markets.csv")).unwrap();
        f.write_all(markets_csv.as_bytes()).unwrap();

        let trades_csv = r#"timestamp,ticker,price,volume,taker_side
2024-01-05 12:00:00,TEST-MKT-1,55,100,yes
2024-01-05 13:00:00,TEST-MKT-1,57,50,yes
2024-01-06 10:00:00,TEST-MKT-2,45,200,no
"#;
        let mut f = std::fs::File::create(dir.path().join("trades.csv")).unwrap();
        f.write_all(trades_csv.as_bytes()).unwrap();

        dir
    }

    // Smoke test: both CSVs load and row counts match the fixture.
    #[test]
    fn test_load_historical_data() {
        let dir = create_test_data();
        let data = HistoricalData::load(dir.path()).unwrap();

        assert_eq!(data.markets.len(), 2);
        assert_eq!(data.trades.len(), 3);
    }
}
|
||||
3
src/data/mod.rs
Normal file
3
src/data/mod.rs
Normal file
@ -0,0 +1,3 @@
|
||||
mod loader;
|
||||
|
||||
pub use loader::HistoricalData;
|
||||
325
src/execution.rs
Normal file
325
src/execution.rs
Normal file
@ -0,0 +1,325 @@
|
||||
use crate::data::HistoricalData;
|
||||
use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext};
|
||||
use rust_decimal::Decimal;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Knobs for Kelly-based position sizing (see `kelly_size`).
#[derive(Debug, Clone)]
pub struct PositionSizingConfig {
    // Fraction of full Kelly to bet (e.g. 0.25 = quarter-Kelly).
    pub kelly_fraction: f64,
    // Hard cap on the fraction of bankroll committed to one position.
    pub max_position_pct: f64,
    // Floor/ceiling on the share count of a single position.
    pub min_position_size: u64,
    pub max_position_size: u64,
}
|
||||
|
||||
impl Default for PositionSizingConfig {
    /// Quarter-Kelly with a 25% bankroll cap — the middle ground between
    /// `conservative()` and `aggressive()`.
    fn default() -> Self {
        Self {
            kelly_fraction: 0.25,
            max_position_pct: 0.25,
            min_position_size: 10,
            max_position_size: 1000,
        }
    }
}
|
||||
|
||||
impl PositionSizingConfig {
    /// Tenth-Kelly with tight caps; for risk-averse runs.
    pub fn conservative() -> Self {
        Self {
            kelly_fraction: 0.1,
            max_position_pct: 0.1,
            min_position_size: 10,
            max_position_size: 500,
        }
    }

    /// Half-Kelly with loose caps; for high-variance runs.
    pub fn aggressive() -> Self {
        Self {
            kelly_fraction: 0.5,
            max_position_pct: 0.4,
            min_position_size: 10,
            max_position_size: 2000,
        }
    }
}
|
||||
|
||||
/// Maps an unbounded scoring edge onto a win probability in (0, 1).
///
/// `tanh` squashes the edge smoothly into (-1, 1); halving the shifted value
/// rescales that onto (0, 1), so an edge of 0 maps to exactly 0.5.
fn edge_to_win_probability(edge: f64) -> f64 {
    0.5 * (edge.tanh() + 1.0)
}
|
||||
|
||||
fn kelly_size(
|
||||
edge: f64,
|
||||
price: f64,
|
||||
bankroll: f64,
|
||||
config: &PositionSizingConfig,
|
||||
) -> u64 {
|
||||
if edge.abs() < 0.01 || price <= 0.0 || price >= 1.0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let win_prob = edge_to_win_probability(edge);
|
||||
let odds = (1.0 - price) / price;
|
||||
|
||||
if odds <= 0.0 {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
|
||||
let safe_kelly = (kelly * config.kelly_fraction).max(0.0);
|
||||
let position_value = bankroll * safe_kelly.min(config.max_position_pct);
|
||||
let shares = (position_value / price).floor() as u64;
|
||||
|
||||
shares.max(config.min_position_size).min(config.max_position_size)
|
||||
}
|
||||
|
||||
/// Simulated execution engine: sizes signals from candidates, models slippage
/// on fills, and evaluates exit rules against open positions.
pub struct Executor {
    // Read-only market/trade history used for price lookups.
    data: Arc<HistoricalData>,
    // Slippage applied to fills, in basis points of the fill price.
    slippage_bps: u32,
    // Per-ticker share cap.
    max_position_size: u64,
    sizing_config: PositionSizingConfig,
    exit_config: ExitConfig,
}
|
||||
|
||||
impl Executor {
    /// Creates an executor with default sizing and exit configuration.
    pub fn new(data: Arc<HistoricalData>, slippage_bps: u32, max_position_size: u64) -> Self {
        Self {
            data,
            slippage_bps,
            max_position_size,
            sizing_config: PositionSizingConfig::default(),
            exit_config: ExitConfig::default(),
        }
    }

    /// Overrides the sizing configuration (builder-style).
    pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self {
        self.sizing_config = config;
        self
    }

    /// Overrides the exit configuration (builder-style).
    pub fn with_exit_config(mut self, config: ExitConfig) -> Self {
        self.exit_config = config;
        self
    }

    /// Evaluates each open position against the exit rules, in priority
    /// order: take-profit, stop-loss, time stop, then score reversal. At most
    /// one signal per position (first rule that fires wins). Positions whose
    /// ticker has no current price are skipped.
    pub fn generate_exit_signals(
        &self,
        context: &TradingContext,
        candidate_scores: &std::collections::HashMap<String, f64>,
    ) -> Vec<ExitSignal> {
        let mut exits = Vec::new();

        for (ticker, position) in &context.portfolio.positions {
            let current_price = match self.data.get_current_price(ticker, context.timestamp) {
                Some(p) => p,
                None => continue,
            };

            // Price of the side we actually hold (no-price is the complement).
            let effective_price = match position.side {
                Side::Yes => current_price,
                Side::No => Decimal::ONE - current_price,
            };

            let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
            let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);

            if entry_price_f64 <= 0.0 {
                continue;
            }

            let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;

            if pnl_pct >= self.exit_config.take_profit_pct {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::TakeProfit { pnl_pct },
                    current_price,
                });
                continue;
            }

            if pnl_pct <= -self.exit_config.stop_loss_pct {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::StopLoss { pnl_pct },
                    current_price,
                });
                continue;
            }

            let hours_held = (context.timestamp - position.entry_time).num_hours();
            if hours_held >= self.exit_config.max_hold_hours {
                exits.push(ExitSignal {
                    ticker: ticker.clone(),
                    reason: ExitReason::TimeStop { hours_held },
                    current_price,
                });
                continue;
            }

            // Score reversal only applies when the ticker was re-scored this
            // tick; positions absent from the candidate set are left alone.
            if let Some(&new_score) = candidate_scores.get(ticker) {
                if new_score < self.exit_config.score_reversal_threshold {
                    exits.push(ExitSignal {
                        ticker: ticker.clone(),
                        reason: ExitReason::ScoreReversal { new_score },
                        current_price,
                    });
                }
            }
        }

        exits
    }

    /// Converts selected candidates into sized entry signals, dropping those
    /// that fail sizing or position-cap checks.
    pub fn generate_signals(
        &self,
        candidates: &[MarketCandidate],
        context: &TradingContext,
    ) -> Vec<Signal> {
        candidates
            .iter()
            .filter_map(|c| self.candidate_to_signal(c, context))
            .collect()
    }

    /// Builds one entry signal for a candidate: picks a side from the score's
    /// sign and the current price, Kelly-sizes it, and caps by cash and the
    /// per-ticker position limit. Returns `None` for a zero score or a size
    /// below the minimum lot.
    fn candidate_to_signal(
        &self,
        candidate: &MarketCandidate,
        context: &TradingContext,
    ) -> Option<Signal> {
        let current_position = context.portfolio.get_position(&candidate.ticker);
        let current_qty = current_position.map(|p| p.quantity).unwrap_or(0);

        if current_qty >= self.max_position_size {
            return None;
        }

        let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5);

        // positive score = bullish signal, so buy the cheaper side (better risk/reward)
        // negative score = bearish signal, so buy against the expensive side
        let side = if candidate.final_score > 0.0 {
            if yes_price < 0.5 { Side::Yes } else { Side::No }
        } else if candidate.final_score < 0.0 {
            if yes_price > 0.5 { Side::No } else { Side::Yes }
        } else {
            return None;
        };

        let price = match side {
            Side::Yes => candidate.current_yes_price,
            Side::No => candidate.current_no_price,
        };

        let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0);
        let price_f64 = price.to_f64().unwrap_or(0.5);

        if price_f64 <= 0.0 {
            return None;
        }

        // NOTE(review): kelly_size receives the raw final_score as the edge;
        // assumes scores are roughly calibrated to the tanh mapping — confirm.
        let kelly_qty = kelly_size(
            candidate.final_score,
            price_f64,
            available_cash,
            &self.sizing_config,
        );

        let max_affordable = (available_cash / price_f64) as u64;
        let quantity = kelly_qty
            .min(max_affordable)
            .min(self.max_position_size - current_qty);

        if quantity < self.sizing_config.min_position_size {
            return None;
        }

        Some(Signal {
            ticker: candidate.ticker.clone(),
            side,
            quantity,
            limit_price: Some(price),
            reason: format!(
                "score={:.3}, side={:?}, price={:.2}",
                candidate.final_score, side, price_f64
            ),
        })
    }

    /// Simulates executing a signal at the current market price plus slippage.
    /// Returns `None` when there is no price, the slipped price exceeds the
    /// limit tolerance, or cash affords zero shares; partially fills (reduced
    /// quantity) when cash covers only part of the requested size.
    pub fn execute_signal(
        &self,
        signal: &Signal,
        context: &TradingContext,
    ) -> Option<Fill> {
        let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?;

        // Price of the side being bought.
        let effective_price = match signal.side {
            Side::Yes => market_price,
            Side::No => Decimal::ONE - market_price,
        };

        let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000);
        let fill_price = effective_price * (Decimal::ONE + slippage);

        // NOTE(review): limits tolerate up to twice the slippage above the
        // limit price — presumably to model marketable limit orders; confirm.
        if let Some(limit) = signal.limit_price {
            if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) {
                return None;
            }
        }

        let cost = fill_price * Decimal::from(signal.quantity);
        if cost > context.portfolio.cash {
            // Partial fill: buy as many shares as cash allows.
            let affordable = (context.portfolio.cash / fill_price)
                .to_u64()
                .unwrap_or(0);
            if affordable == 0 {
                return None;
            }

            return Some(Fill {
                ticker: signal.ticker.clone(),
                side: signal.side,
                quantity: affordable,
                price: fill_price,
                timestamp: context.timestamp,
            });
        }

        Some(Fill {
            ticker: signal.ticker.clone(),
            side: signal.side,
            quantity: signal.quantity,
            price: fill_price,
            timestamp: context.timestamp,
        })
    }
}
|
||||
|
||||
pub fn simple_signal_generator(
|
||||
candidates: &[MarketCandidate],
|
||||
context: &TradingContext,
|
||||
position_size: u64,
|
||||
) -> Vec<Signal> {
|
||||
candidates
|
||||
.iter()
|
||||
.filter(|c| c.final_score > 0.0)
|
||||
.filter(|c| !context.portfolio.has_position(&c.ticker))
|
||||
.map(|c| {
|
||||
let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5);
|
||||
let (side, price) = if yes_price < 0.5 {
|
||||
(Side::Yes, c.current_yes_price)
|
||||
} else {
|
||||
(Side::No, c.current_no_price)
|
||||
};
|
||||
|
||||
Signal {
|
||||
ticker: c.ticker.clone(),
|
||||
side,
|
||||
quantity: position_size,
|
||||
limit_price: Some(price),
|
||||
reason: format!("simple: score={:.3}", c.final_score),
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
216
src/main.rs
Normal file
216
src/main.rs
Normal file
@ -0,0 +1,216 @@
|
||||
mod backtest;
|
||||
mod data;
|
||||
mod execution;
|
||||
mod metrics;
|
||||
mod pipeline;
|
||||
mod types;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use backtest::{Backtester, RandomBaseline};
|
||||
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
|
||||
use clap::{Parser, Subcommand};
|
||||
use data::HistoricalData;
|
||||
use execution::{Executor, PositionSizingConfig};
|
||||
use rust_decimal::Decimal;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
use types::{BacktestConfig, ExitConfig};
|
||||
|
||||
// Top-level CLI definition. Plain `//` comments are used here on purpose:
// `///` doc comments on clap derive items become user-visible help text.
#[derive(Parser)]
#[command(name = "kalshi-backtest")]
#[command(about = "backtesting framework for kalshi prediction markets")]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}
|
||||
|
||||
// Subcommands. Plain `//` comments only — `///` on clap fields would change
// the generated --help output.
#[derive(Subcommand)]
enum Commands {
    // Run a backtest over [start, end) and write results as JSON.
    Run {
        // Directory containing markets.csv and trades.csv.
        #[arg(short, long, default_value = "data")]
        data_dir: PathBuf,

        // Start date: RFC 3339 or YYYY-MM-DD (see parse_date).
        #[arg(long)]
        start: String,

        // End date, same formats as --start.
        #[arg(long)]
        end: String,

        // Initial capital in dollars.
        #[arg(long, default_value = "10000")]
        capital: f64,

        // Max shares per single position.
        #[arg(long, default_value = "100")]
        max_position: u64,

        // Max simultaneous open positions.
        #[arg(long, default_value = "5")]
        max_positions: usize,

        // Simulation tick size in hours.
        #[arg(long, default_value = "1")]
        interval_hours: i64,

        #[arg(long, default_value = "results")]
        output_dir: PathBuf,

        // Also run the random baseline and print a comparison.
        #[arg(long)]
        compare_random: bool,

        // Kelly sizing knobs (see PositionSizingConfig).
        #[arg(long, default_value = "0.25")]
        kelly_fraction: f64,

        #[arg(long, default_value = "0.25")]
        max_position_pct: f64,

        // Exit-rule knobs (see ExitConfig); percentages as fractions.
        #[arg(long, default_value = "0.20")]
        take_profit: f64,

        #[arg(long, default_value = "0.15")]
        stop_loss: f64,

        #[arg(long, default_value = "72")]
        max_hold_hours: i64,
    },

    // Pretty-print a previously saved results JSON file.
    Summary {
        #[arg(short, long)]
        results_file: PathBuf,
    },
}
|
||||
|
||||
fn parse_date(s: &str) -> Result<DateTime<Utc>> {
|
||||
if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
|
||||
return Ok(dt.with_timezone(&Utc));
|
||||
}
|
||||
|
||||
if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
|
||||
return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap()));
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("could not parse date: {}", s))
|
||||
}
|
||||
|
||||
/// CLI entry point: set up tracing, parse arguments, and dispatch subcommands.
#[tokio::main]
async fn main() -> Result<()> {
    // Log filter comes from RUST_LOG when set; otherwise default to info
    // for this crate only.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "kalshi_backtest=info".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let cli = Cli::parse();

    match cli.command {
        // Full backtest run, optionally followed by a random baseline for comparison.
        Commands::Run {
            data_dir,
            start,
            end,
            capital,
            max_position,
            max_positions,
            interval_hours,
            output_dir,
            compare_random,
            kelly_fraction,
            max_position_pct,
            take_profit,
            stop_loss,
            max_hold_hours,
        } => {
            let start_time = parse_date(&start).context("parsing start date")?;
            let end_time = parse_date(&end).context("parsing end date")?;

            info!(
                data_dir = %data_dir.display(),
                start = %start_time,
                end = %end_time,
                capital = capital,
                "loading data"
            );

            let data = Arc::new(
                HistoricalData::load(&data_dir).context("loading historical data")?,
            );

            info!(
                markets = data.markets.len(),
                trades = data.trades.len(),
                "data loaded"
            );

            // Core simulation window and portfolio limits.
            let config = BacktestConfig {
                start_time,
                end_time,
                interval: chrono::Duration::hours(interval_hours),
                // NOTE(review): unwrap panics on a non-finite/NaN --capital value;
                // consider validating earlier.
                initial_capital: Decimal::try_from(capital).unwrap(),
                max_position_size: max_position,
                max_positions,
            };

            // Kelly-based sizing; min_position_size is a hard-coded floor of 10.
            let sizing_config = PositionSizingConfig {
                kelly_fraction,
                max_position_pct,
                min_position_size: 10,
                max_position_size: max_position,
            };

            // Exit rules; score_reversal_threshold of -0.3 is a fixed default.
            let exit_config = ExitConfig {
                take_profit_pct: take_profit,
                stop_loss_pct: stop_loss,
                max_hold_hours,
                score_reversal_threshold: -0.3,
            };

            // config/data are cloned because the baseline below reuses them.
            let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config);
            let result = backtester.run().await;

            println!("{}", result.summary());

            // Persist the strategy result as pretty-printed JSON.
            std::fs::create_dir_all(&output_dir)?;
            let result_path = output_dir.join("backtest_result.json");
            let json = serde_json::to_string_pretty(&result)?;
            std::fs::write(&result_path, json)?;
            info!(path = %result_path.display(), "results saved");

            if compare_random {
                println!("\n--- random baseline ---\n");
                let baseline = RandomBaseline::new(config, data);
                let baseline_result = baseline.run().await;
                println!("{}", baseline_result.summary());

                let baseline_path = output_dir.join("baseline_result.json");
                let json = serde_json::to_string_pretty(&baseline_result)?;
                std::fs::write(&baseline_path, json)?;

                // Side-by-side comparison of the headline metrics.
                println!("\n--- comparison ---\n");
                println!(
                    "strategy return: {:.2}% vs baseline: {:.2}%",
                    result.total_return_pct, baseline_result.total_return_pct
                );
                println!(
                    "strategy sharpe: {:.3} vs baseline: {:.3}",
                    result.sharpe_ratio, baseline_result.sharpe_ratio
                );
                println!(
                    "strategy win rate: {:.1}% vs baseline: {:.1}%",
                    result.win_rate, baseline_result.win_rate
                );
            }

            Ok(())
        }

        // Re-print the summary of a previously saved result file.
        Commands::Summary { results_file } => {
            let content = std::fs::read_to_string(&results_file)
                .context("reading results file")?;
            let result: metrics::BacktestResult =
                serde_json::from_str(&content).context("parsing results")?;

            println!("{}", result.summary());

            Ok(())
        }
    }
}
|
||||
280
src/metrics.rs
Normal file
280
src/metrics.rs
Normal file
@ -0,0 +1,280 @@
|
||||
use crate::types::{Portfolio, Trade, TradeType};
|
||||
use chrono::{DateTime, Utc};
|
||||
use rust_decimal::Decimal;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Serializable summary of a completed backtest run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BacktestResult {
    /// Final equity minus initial capital, in dollars.
    pub total_return: f64,
    /// Total return as a percentage of initial capital.
    pub total_return_pct: f64,
    /// Annualized Sharpe ratio of the recorded interval returns.
    pub sharpe_ratio: f64,
    /// NOTE(review): see MetricsCollector::finalize for how this is populated
    /// relative to max_drawdown_pct.
    pub max_drawdown: f64,
    /// Worst peak-to-trough drawdown, in percent.
    pub max_drawdown_pct: f64,
    /// Percentage of closed trades with positive pnl.
    pub win_rate: f64,
    pub total_trades: usize,
    pub winning_trades: usize,
    pub losing_trades: usize,
    /// Mean realized pnl per closed trade, in dollars.
    pub avg_trade_pnl: f64,
    pub avg_hold_time_hours: f64,
    pub trades_per_day: f64,
    /// Total realized pnl keyed by market category.
    pub return_by_category: HashMap<String, f64>,
    pub equity_curve: Vec<EquityPoint>,
    /// Closed trades only; positions still open at the end are not listed.
    pub trade_log: Vec<TradeRecord>,
}
|
||||
|
||||
/// One point on the equity curve: total portfolio value split into cash and
/// marked-to-market position value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EquityPoint {
    pub timestamp: DateTime<Utc>,
    /// cash + positions_value at this snapshot.
    pub equity: f64,
    pub cash: f64,
    pub positions_value: f64,
}
|
||||
|
||||
/// One round-trip trade for the trade log: created on open, completed on
/// close or resolution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeRecord {
    pub ticker: String,
    pub entry_time: DateTime<Utc>,
    /// None while the position is still open.
    pub exit_time: Option<DateTime<Utc>>,
    /// Debug-formatted trade side (filled via format!("{:?}", side)).
    pub side: String,
    pub quantity: u64,
    pub entry_price: f64,
    pub exit_price: Option<f64>,
    /// Realized pnl (exit value minus entry cost); None while open.
    pub pnl: Option<f64>,
    pub category: String,
}
|
||||
|
||||
/// Accumulates equity snapshots and trade events during a backtest and
/// produces a BacktestResult via finalize().
pub struct MetricsCollector {
    initial_capital: Decimal,
    equity_curve: Vec<EquityPoint>,
    /// Currently-open trades, keyed by ticker.
    trade_records: HashMap<String, TradeRecord>,
    closed_trades: Vec<TradeRecord>,
    /// Per-snapshot returns — strictly "daily" only if record() runs daily.
    daily_returns: Vec<f64>,
    last_equity: f64,
    peak_equity: f64,
    /// Worst peak-to-trough drawdown so far, as a fraction of peak equity.
    max_drawdown: f64,
}
|
||||
|
||||
impl MetricsCollector {
|
||||
pub fn new(initial_capital: Decimal) -> Self {
|
||||
let capital = initial_capital.to_f64().unwrap_or(10000.0);
|
||||
Self {
|
||||
initial_capital,
|
||||
equity_curve: Vec::new(),
|
||||
trade_records: HashMap::new(),
|
||||
closed_trades: Vec::new(),
|
||||
daily_returns: Vec::new(),
|
||||
last_equity: capital,
|
||||
peak_equity: capital,
|
||||
max_drawdown: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record(
|
||||
&mut self,
|
||||
timestamp: DateTime<Utc>,
|
||||
portfolio: &Portfolio,
|
||||
market_prices: &HashMap<String, Decimal>,
|
||||
) {
|
||||
let positions_value = portfolio
|
||||
.positions
|
||||
.values()
|
||||
.map(|p| {
|
||||
let price = market_prices
|
||||
.get(&p.ticker)
|
||||
.copied()
|
||||
.unwrap_or(p.avg_entry_price);
|
||||
(price * Decimal::from(p.quantity)).to_f64().unwrap_or(0.0)
|
||||
})
|
||||
.sum();
|
||||
|
||||
let cash = portfolio.cash.to_f64().unwrap_or(0.0);
|
||||
let equity = cash + positions_value;
|
||||
|
||||
if equity > self.peak_equity {
|
||||
self.peak_equity = equity;
|
||||
}
|
||||
|
||||
let drawdown = (self.peak_equity - equity) / self.peak_equity;
|
||||
if drawdown > self.max_drawdown {
|
||||
self.max_drawdown = drawdown;
|
||||
}
|
||||
|
||||
if self.last_equity > 0.0 {
|
||||
let daily_return = (equity - self.last_equity) / self.last_equity;
|
||||
self.daily_returns.push(daily_return);
|
||||
}
|
||||
self.last_equity = equity;
|
||||
|
||||
self.equity_curve.push(EquityPoint {
|
||||
timestamp,
|
||||
equity,
|
||||
cash,
|
||||
positions_value,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn record_trade(&mut self, trade: &Trade, category: &str) {
|
||||
match trade.trade_type {
|
||||
TradeType::Open => {
|
||||
let record = TradeRecord {
|
||||
ticker: trade.ticker.clone(),
|
||||
entry_time: trade.timestamp,
|
||||
exit_time: None,
|
||||
side: format!("{:?}", trade.side),
|
||||
quantity: trade.quantity,
|
||||
entry_price: trade.price.to_f64().unwrap_or(0.0),
|
||||
exit_price: None,
|
||||
pnl: None,
|
||||
category: category.to_string(),
|
||||
};
|
||||
self.trade_records.insert(trade.ticker.clone(), record);
|
||||
}
|
||||
TradeType::Close | TradeType::Resolution => {
|
||||
if let Some(mut record) = self.trade_records.remove(&trade.ticker) {
|
||||
let exit_price = trade.price.to_f64().unwrap_or(0.0);
|
||||
let entry_cost = record.entry_price * record.quantity as f64;
|
||||
let exit_value = exit_price * record.quantity as f64;
|
||||
let pnl = exit_value - entry_cost;
|
||||
|
||||
record.exit_time = Some(trade.timestamp);
|
||||
record.exit_price = Some(exit_price);
|
||||
record.pnl = Some(pnl);
|
||||
|
||||
self.closed_trades.push(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn finalize(self) -> BacktestResult {
|
||||
let initial = self.initial_capital.to_f64().unwrap_or(10000.0);
|
||||
let final_equity = self.equity_curve.last().map(|e| e.equity).unwrap_or(initial);
|
||||
let total_return = final_equity - initial;
|
||||
let total_return_pct = total_return / initial * 100.0;
|
||||
|
||||
let sharpe_ratio = if self.daily_returns.len() > 1 {
|
||||
let mean: f64 = self.daily_returns.iter().sum::<f64>() / self.daily_returns.len() as f64;
|
||||
let variance: f64 = self
|
||||
.daily_returns
|
||||
.iter()
|
||||
.map(|r| (r - mean).powi(2))
|
||||
.sum::<f64>()
|
||||
/ (self.daily_returns.len() - 1) as f64;
|
||||
let std_dev = variance.sqrt();
|
||||
if std_dev > 0.0 {
|
||||
(mean / std_dev) * (252.0_f64).sqrt()
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let winning_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) > 0.0).count();
|
||||
let losing_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) < 0.0).count();
|
||||
let total_trades = self.closed_trades.len();
|
||||
|
||||
let win_rate = if total_trades > 0 {
|
||||
winning_trades as f64 / total_trades as f64 * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let avg_trade_pnl = if total_trades > 0 {
|
||||
self.closed_trades.iter().filter_map(|t| t.pnl).sum::<f64>() / total_trades as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let avg_hold_time_hours = if total_trades > 0 {
|
||||
self.closed_trades
|
||||
.iter()
|
||||
.filter_map(|t| {
|
||||
t.exit_time.map(|exit| (exit - t.entry_time).num_hours() as f64)
|
||||
})
|
||||
.sum::<f64>()
|
||||
/ total_trades as f64
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let duration_days = if self.equity_curve.len() >= 2 {
|
||||
let start = self.equity_curve.first().unwrap().timestamp;
|
||||
let end = self.equity_curve.last().unwrap().timestamp;
|
||||
(end - start).num_days().max(1) as f64
|
||||
} else {
|
||||
1.0
|
||||
};
|
||||
|
||||
let trades_per_day = total_trades as f64 / duration_days;
|
||||
|
||||
let mut return_by_category: HashMap<String, f64> = HashMap::new();
|
||||
for trade in &self.closed_trades {
|
||||
*return_by_category.entry(trade.category.clone()).or_insert(0.0) +=
|
||||
trade.pnl.unwrap_or(0.0);
|
||||
}
|
||||
|
||||
BacktestResult {
|
||||
total_return,
|
||||
total_return_pct,
|
||||
sharpe_ratio,
|
||||
max_drawdown: self.max_drawdown * 100.0,
|
||||
max_drawdown_pct: self.max_drawdown * 100.0,
|
||||
win_rate,
|
||||
total_trades,
|
||||
winning_trades,
|
||||
losing_trades,
|
||||
avg_trade_pnl,
|
||||
avg_hold_time_hours,
|
||||
trades_per_day,
|
||||
return_by_category,
|
||||
equity_curve: self.equity_curve,
|
||||
trade_log: self.closed_trades,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BacktestResult {
    /// Render a human-readable multi-line report of the backtest result.
    ///
    /// NOTE(review): return_by_category is a HashMap, so the per-category
    /// lines appear in an unspecified order.
    pub fn summary(&self) -> String {
        format!(
            r#"
backtest results
================

performance
-----------
total return: ${:.2} ({:.2}%)
sharpe ratio: {:.3}
max drawdown: {:.2}%

trades
------
total trades: {}
win rate: {:.1}%
avg trade pnl: ${:.2}
avg hold time: {:.1} hours
trades per day: {:.2}

by category
-----------
{}
"#,
            self.total_return,
            self.total_return_pct,
            self.sharpe_ratio,
            self.max_drawdown_pct,
            self.total_trades,
            self.win_rate,
            self.avg_trade_pnl,
            self.avg_hold_time_hours,
            self.trades_per_day,
            // one " category: $pnl" line per category
            self.return_by_category
                .iter()
                .map(|(k, v)| format!(" {}: ${:.2}", k, v))
                .collect::<Vec<_>>()
                .join("\n")
        )
    }
}
|
||||
259
src/pipeline/correlation_scorer.rs
Normal file
259
src/pipeline/correlation_scorer.rs
Normal file
@ -0,0 +1,259 @@
|
||||
//! Cross-market correlation scorer
|
||||
//!
|
||||
//! Uses lead-lag relationships between related markets to generate signals.
|
||||
//! When a related market moves, we expect similar movements in correlated markets.
|
||||
|
||||
use crate::pipeline::Scorer;
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// correlation entry between two markets (directed: ticker_a leads ticker_b
/// by `lag_hours`)
#[derive(Debug, Clone)]
pub struct CorrelationEntry {
    pub ticker_a: String,
    pub ticker_b: String,
    /// Correlation coefficient, expected in [-1, 1].
    pub correlation: f64,
    /// Lead/lag in hours; negated when the pair is stored in reverse.
    pub lag_hours: i64,
}
|
||||
|
||||
/// cross-market correlation scorer
/// uses precomputed correlations to generate signals based on related market movements
pub struct CorrelationScorer {
    /// Correlation entries indexed by ticker (both directions of every pair).
    correlations: Arc<RwLock<HashMap<String, Vec<CorrelationEntry>>>>,
    /// NOTE(review): not referenced in the visible impl — confirm it is
    /// used elsewhere before relying on it.
    lookback_hours: i64,
    /// Minimum |correlation| for a pair to contribute to the score.
    min_correlation: f64,
}
|
||||
|
||||
impl CorrelationScorer {
|
||||
pub fn new(lookback_hours: i64, min_correlation: f64) -> Self {
|
||||
Self {
|
||||
correlations: Arc::new(RwLock::new(HashMap::new())),
|
||||
lookback_hours,
|
||||
min_correlation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(24, 0.5)
|
||||
}
|
||||
|
||||
pub fn load_correlations(&self, correlations: Vec<CorrelationEntry>) {
|
||||
let mut map = self.correlations.write().unwrap();
|
||||
map.clear();
|
||||
|
||||
for entry in correlations {
|
||||
map.entry(entry.ticker_a.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(entry.clone());
|
||||
map.entry(entry.ticker_b.clone())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(CorrelationEntry {
|
||||
ticker_a: entry.ticker_b.clone(),
|
||||
ticker_b: entry.ticker_a.clone(),
|
||||
correlation: entry.correlation,
|
||||
lag_hours: -entry.lag_hours,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_category_correlations(&self, categories: &[&str]) {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
for (i, cat_a) in categories.iter().enumerate() {
|
||||
for cat_b in categories.iter().skip(i + 1) {
|
||||
entries.push(CorrelationEntry {
|
||||
ticker_a: cat_a.to_string(),
|
||||
ticker_b: cat_b.to_string(),
|
||||
correlation: 0.3,
|
||||
lag_hours: 0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.load_correlations(entries);
|
||||
}
|
||||
|
||||
fn get_related_signals(
|
||||
&self,
|
||||
ticker: &str,
|
||||
all_candidates: &[MarketCandidate],
|
||||
) -> f64 {
|
||||
let correlations = self.correlations.read().unwrap();
|
||||
let Some(related) = correlations.get(ticker) else {
|
||||
return 0.0;
|
||||
};
|
||||
|
||||
let candidate_map: HashMap<&str, &MarketCandidate> = all_candidates
|
||||
.iter()
|
||||
.map(|c| (c.ticker.as_str(), c))
|
||||
.collect();
|
||||
|
||||
let mut weighted_signal = 0.0;
|
||||
let mut total_weight = 0.0;
|
||||
|
||||
for entry in related {
|
||||
if entry.correlation.abs() < self.min_correlation {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(related_candidate) = candidate_map.get(entry.ticker_b.as_str()) {
|
||||
let related_momentum = related_candidate
|
||||
.scores
|
||||
.get("momentum")
|
||||
.copied()
|
||||
.unwrap_or(0.0);
|
||||
|
||||
let signal = related_momentum * entry.correlation;
|
||||
weighted_signal += signal * entry.correlation.abs();
|
||||
total_weight += entry.correlation.abs();
|
||||
}
|
||||
}
|
||||
|
||||
if total_weight > 0.0 {
|
||||
weighted_signal / total_weight
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_category_correlation(
|
||||
candidate: &MarketCandidate,
|
||||
all_candidates: &[MarketCandidate],
|
||||
) -> f64 {
|
||||
let same_category: Vec<&MarketCandidate> = all_candidates
|
||||
.iter()
|
||||
.filter(|c| c.category == candidate.category && c.ticker != candidate.ticker)
|
||||
.collect();
|
||||
|
||||
if same_category.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let avg_momentum: f64 = same_category
|
||||
.iter()
|
||||
.filter_map(|c| c.scores.get("momentum").copied())
|
||||
.sum::<f64>()
|
||||
/ same_category.len() as f64;
|
||||
|
||||
avg_momentum * 0.3
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for CorrelationScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"CorrelationScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let related_signal = self.get_related_signals(&c.ticker, candidates);
|
||||
let category_signal = Self::calculate_category_correlation(c, candidates);
|
||||
let combined = related_signal * 0.7 + category_signal * 0.3;
|
||||
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("cross_market".to_string(), combined);
|
||||
scored.scores.insert("related_signal".to_string(), related_signal);
|
||||
scored.scores.insert("category_signal".to_string(), category_signal);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
for key in ["cross_market", "related_signal", "category_signal"] {
|
||||
if let Some(score) = scored.scores.get(key) {
|
||||
candidate.scores.insert(key.to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simplified lead/lag analysis ("Granger-style"; a full implementation
/// would use proper statistical tests).
///
/// Converts both price paths to log returns, scans lags in
/// `-max_lag..=max_lag`, and returns the strongest lagged Pearson
/// correlation together with its lag — or `None` when the series are too
/// short or the best |correlation| does not exceed 0.1.
pub fn calculate_granger_causality(
    prices_a: &[f64],
    prices_b: &[f64],
    max_lag: i64,
) -> Option<(f64, i64)> {
    // log-return series; zero where the base price is non-positive
    fn to_log_returns(prices: &[f64]) -> Vec<f64> {
        prices
            .windows(2)
            .map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 })
            .collect()
    }

    // Pearson correlation with population (1/n) normalisation; None when
    // either series has zero variance.
    fn pearson(a: &[f64], b: &[f64]) -> Option<f64> {
        let n = a.len() as f64;
        let mean_a: f64 = a.iter().sum::<f64>() / n;
        let mean_b: f64 = b.iter().sum::<f64>() / n;

        let cov: f64 = a
            .iter()
            .zip(b.iter())
            .map(|(x, y)| (x - mean_a) * (y - mean_b))
            .sum::<f64>()
            / n;

        let std_a: f64 = (a.iter().map(|x| (x - mean_a).powi(2)).sum::<f64>() / n).sqrt();
        let std_b: f64 = (b.iter().map(|y| (y - mean_b).powi(2)).sum::<f64>() / n).sqrt();

        if std_a > 0.0 && std_b > 0.0 {
            Some(cov / (std_a * std_b))
        } else {
            None
        }
    }

    // require a minimum of 20 observations per series
    if prices_a.len() < 20 || prices_b.len() < 20 {
        return None;
    }

    // truncate both series to a common length before differencing
    let len = prices_a.len().min(prices_b.len());
    let returns_a = to_log_returns(&prices_a[..len]);
    let returns_b = to_log_returns(&prices_b[..len]);

    let mut best: (f64, i64) = (0.0, 0);

    for lag in -max_lag..=max_lag {
        let shift = lag.unsigned_abs() as usize;
        if shift >= returns_a.len() || shift >= returns_b.len() {
            continue;
        }

        // positive lag: a shifted forward against b; negative: the reverse
        let (a_slice, b_slice) = if lag >= 0 {
            (&returns_a[shift..], &returns_b[..returns_b.len() - shift])
        } else {
            (&returns_a[..returns_a.len() - shift], &returns_b[shift..])
        };

        if a_slice.len() < 10 || b_slice.len() < 10 || a_slice.len() != b_slice.len() {
            continue;
        }

        if let Some(corr) = pearson(a_slice, b_slice) {
            // strict improvement only, so the earliest best lag wins ties
            if corr.abs() > best.0.abs() {
                best = (corr, lag);
            }
        }
    }

    if best.0.abs() > 0.1 {
        Some(best)
    } else {
        None
    }
}
|
||||
182
src/pipeline/filters.rs
Normal file
182
src/pipeline/filters.rs
Normal file
@ -0,0 +1,182 @@
|
||||
use crate::pipeline::{Filter, FilterResult};
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Duration;
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Drops markets whose 24-hour traded volume is below a minimum threshold.
pub struct LiquidityFilter {
    /// Minimum 24h volume a market must have to pass.
    min_volume_24h: u64,
}
|
||||
|
||||
impl LiquidityFilter {
|
||||
pub fn new(min_volume_24h: u64) -> Self {
|
||||
Self { min_volume_24h }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter for LiquidityFilter {
|
||||
fn name(&self) -> &'static str {
|
||||
"LiquidityFilter"
|
||||
}
|
||||
|
||||
async fn filter(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: Vec<MarketCandidate>,
|
||||
) -> Result<FilterResult, String> {
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates
|
||||
.into_iter()
|
||||
.partition(|c| c.volume_24h >= self.min_volume_24h);
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
}
|
||||
}
|
||||
|
||||
/// Keeps only markets whose remaining time-to-close falls within a window.
pub struct TimeToCloseFilter {
    /// Minimum hours until close.
    min_hours: i64,
    /// Optional maximum hours until close; None means unbounded.
    max_hours: Option<i64>,
}
|
||||
|
||||
impl TimeToCloseFilter {
|
||||
pub fn new(min_hours: i64, max_hours: Option<i64>) -> Self {
|
||||
Self { min_hours, max_hours }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter for TimeToCloseFilter {
|
||||
fn name(&self) -> &'static str {
|
||||
"TimeToCloseFilter"
|
||||
}
|
||||
|
||||
async fn filter(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: Vec<MarketCandidate>,
|
||||
) -> Result<FilterResult, String> {
|
||||
let min_duration = Duration::hours(self.min_hours);
|
||||
let max_duration = self.max_hours.map(Duration::hours);
|
||||
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||
let ttc = c.time_to_close(context.timestamp);
|
||||
let above_min = ttc >= min_duration;
|
||||
let below_max = max_duration.map(|max| ttc <= max).unwrap_or(true);
|
||||
above_min && below_max
|
||||
});
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
}
|
||||
}
|
||||
|
||||
/// Excludes markets where the portfolio already holds the maximum allowed
/// position size.
pub struct AlreadyPositionedFilter {
    /// Per-market position cap (contracts).
    max_position_per_market: u64,
}
|
||||
|
||||
impl AlreadyPositionedFilter {
|
||||
pub fn new(max_position_per_market: u64) -> Self {
|
||||
Self {
|
||||
max_position_per_market,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter for AlreadyPositionedFilter {
|
||||
fn name(&self) -> &'static str {
|
||||
"AlreadyPositionedFilter"
|
||||
}
|
||||
|
||||
async fn filter(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: Vec<MarketCandidate>,
|
||||
) -> Result<FilterResult, String> {
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||
context
|
||||
.portfolio
|
||||
.get_position(&c.ticker)
|
||||
.map(|p| p.quantity < self.max_position_per_market)
|
||||
.unwrap_or(true)
|
||||
});
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
}
|
||||
}
|
||||
|
||||
/// Filters markets by category, via either an allow-list or a deny-list.
pub struct CategoryFilter {
    /// When Some, only these categories pass; when None, all pass the
    /// whitelist check.
    whitelist: Option<HashSet<String>>,
    /// Categories that are always rejected.
    blacklist: HashSet<String>,
}
|
||||
|
||||
impl CategoryFilter {
|
||||
pub fn whitelist(categories: Vec<String>) -> Self {
|
||||
Self {
|
||||
whitelist: Some(categories.into_iter().collect()),
|
||||
blacklist: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn blacklist(categories: Vec<String>) -> Self {
|
||||
Self {
|
||||
whitelist: None,
|
||||
blacklist: categories.into_iter().collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter for CategoryFilter {
|
||||
fn name(&self) -> &'static str {
|
||||
"CategoryFilter"
|
||||
}
|
||||
|
||||
async fn filter(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: Vec<MarketCandidate>,
|
||||
) -> Result<FilterResult, String> {
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||
let in_whitelist = self
|
||||
.whitelist
|
||||
.as_ref()
|
||||
.map(|w| w.contains(&c.category))
|
||||
.unwrap_or(true);
|
||||
let not_blacklisted = !self.blacklist.contains(&c.category);
|
||||
in_whitelist && not_blacklisted
|
||||
});
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
}
|
||||
}
|
||||
|
||||
/// Keeps only markets whose current YES price lies within a band
/// (e.g. to avoid near-certain 0.01/0.99 markets).
pub struct PriceRangeFilter {
    /// Inclusive lower bound on the YES price.
    min_price: f64,
    /// Inclusive upper bound on the YES price.
    max_price: f64,
}
|
||||
|
||||
impl PriceRangeFilter {
|
||||
pub fn new(min_price: f64, max_price: f64) -> Self {
|
||||
Self { min_price, max_price }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Filter for PriceRangeFilter {
|
||||
fn name(&self) -> &'static str {
|
||||
"PriceRangeFilter"
|
||||
}
|
||||
|
||||
async fn filter(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: Vec<MarketCandidate>,
|
||||
) -> Result<FilterResult, String> {
|
||||
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||
let price = c.current_yes_price.to_string().parse::<f64>().unwrap_or(0.5);
|
||||
price >= self.min_price && price <= self.max_price
|
||||
});
|
||||
|
||||
Ok(FilterResult { kept, removed })
|
||||
}
|
||||
}
|
||||
272
src/pipeline/ml_scorer.rs
Normal file
272
src/pipeline/ml_scorer.rs
Normal file
@ -0,0 +1,272 @@
|
||||
//! ML-based scorer using ONNX models
|
||||
//!
|
||||
//! This module provides ML-based scoring using pre-trained ONNX models.
|
||||
//! Models are trained separately using the Python scripts in scripts/
|
||||
//!
|
||||
//! Enable with the `ml` feature:
|
||||
//! ```toml
|
||||
//! kalshi-backtest = { features = ["ml"] }
|
||||
//! ```
|
||||
|
||||
use crate::pipeline::Scorer;
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use std::path::Path;
|
||||
|
||||
#[cfg(feature = "ml")]
|
||||
use {
|
||||
ndarray::{Array1, Array2},
|
||||
ort::{session::Session, value::Value},
|
||||
std::sync::Arc,
|
||||
};
|
||||
|
||||
/// ML ensemble scorer that combines multiple ONNX models
///
/// Models:
/// - LSTM: sequence model for price history
/// - MLP: feedforward on hand-crafted features
#[cfg(feature = "ml")]
pub struct MLEnsembleScorer {
    /// Loaded LSTM session; None when lstm.onnx was absent.
    lstm_session: Option<Arc<Session>>,
    /// Loaded MLP session; None when mlp.onnx was absent.
    mlp_session: Option<Arc<Session>>,
    /// Ensemble weights in [lstm, mlp] order.
    ensemble_weights: Vec<f64>,
}
|
||||
|
||||
#[cfg(feature = "ml")]
impl MLEnsembleScorer {
    /// Scorer with no models loaded; all predictions are 0.0 until models
    /// are provided via `load_models`.
    pub fn new() -> Self {
        Self {
            lstm_session: None,
            mlp_session: None,
            ensemble_weights: vec![0.5, 0.5],
        }
    }

    /// Load whichever of `lstm.onnx` / `mlp.onnx` exist under `model_dir`.
    /// A missing file is skipped (that model contributes nothing); a file
    /// that exists but fails to load is an error.
    pub fn load_models(model_dir: &Path) -> Result<Self, String> {
        let lstm_path = model_dir.join("lstm.onnx");
        let mlp_path = model_dir.join("mlp.onnx");

        let lstm_session = if lstm_path.exists() {
            Some(Arc::new(
                Session::builder()
                    .map_err(|e| format!("failed to create session builder: {}", e))?
                    .commit_from_file(&lstm_path)
                    .map_err(|e| format!("failed to load LSTM model: {}", e))?,
            ))
        } else {
            None
        };

        let mlp_session = if mlp_path.exists() {
            Some(Arc::new(
                Session::builder()
                    .map_err(|e| format!("failed to create session builder: {}", e))?
                    .commit_from_file(&mlp_path)
                    .map_err(|e| format!("failed to load MLP model: {}", e))?,
            ))
        } else {
            None
        };

        Ok(Self {
            lstm_session,
            mlp_session,
            ensemble_weights: vec![0.5, 0.5],
        })
    }

    /// Override the [lstm, mlp] ensemble weights.
    pub fn with_weights(mut self, weights: Vec<f64>) -> Self {
        self.ensemble_weights = weights;
        self
    }

    /// Flatten current pipeline scores into the MLP feature vector.
    /// Missing scores default to 0.0; the order must match training.
    fn extract_features(candidate: &MarketCandidate) -> Vec<f64> {
        vec![
            candidate.scores.get("momentum").copied().unwrap_or(0.0),
            candidate.scores.get("mean_reversion").copied().unwrap_or(0.0),
            candidate.scores.get("volume").copied().unwrap_or(0.0),
            candidate.scores.get("time_decay").copied().unwrap_or(0.0),
            candidate.scores.get("order_flow").copied().unwrap_or(0.0),
            candidate.scores.get("bollinger_reversion").copied().unwrap_or(0.0),
            candidate.scores.get("mtf_momentum").copied().unwrap_or(0.0),
        ]
    }

    /// Build a fixed-length (`max_len`) model input from price history.
    ///
    /// Recent prices are right-aligned and the left side is zero-padded.
    /// With at least 2 prices the returned series is log returns (first
    /// element 0.0); otherwise the raw padded prices are returned.
    ///
    /// NOTE(review): log returns are computed over the zero-padded sequence,
    /// so the first real price after the padding yields a 0.0 return —
    /// confirm this matches how the models were trained.
    fn extract_price_sequence(candidate: &MarketCandidate, max_len: usize) -> Vec<f64> {
        use rust_decimal::prelude::ToPrimitive;

        let prices: Vec<f64> = candidate
            .price_history
            .iter()
            .rev()
            .take(max_len)
            .filter_map(|p| p.yes_price.to_f64())
            .collect();

        // right-align the newest-to-oldest prices into a zero-padded buffer
        let mut sequence = vec![0.0; max_len];
        for (i, &price) in prices.iter().enumerate() {
            if i < max_len {
                sequence[max_len - 1 - i] = price;
            }
        }

        if prices.len() >= 2 {
            let mut log_returns = Vec::with_capacity(max_len);
            for i in 1..sequence.len() {
                if sequence[i - 1] > 0.0 && sequence[i] > 0.0 {
                    log_returns.push((sequence[i] / sequence[i - 1]).ln());
                } else {
                    log_returns.push(0.0);
                }
            }
            log_returns.insert(0, 0.0);
            log_returns
        } else {
            sequence
        }
    }

    /// Run the LSTM on a price sequence; 0.0 when the model is absent or
    /// inference fails.
    fn predict_lstm(&self, sequence: &[f64]) -> f64 {
        let Some(session) = &self.lstm_session else {
            return 0.0;
        };

        // shape (1, len) with exactly len elements is always valid
        let input = Array2::from_shape_vec((1, sequence.len()), sequence.to_vec())
            .expect("invalid shape");

        match session.run(ort::inputs!["input" => input.view()]) {
            Ok(outputs) => {
                if let Some(output) = outputs.get("output") {
                    if let Ok(tensor) = output.try_extract_tensor::<f32>() {
                        return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
                    }
                }
                0.0
            }
            // inference errors are swallowed: ML is a best-effort signal
            Err(_) => 0.0,
        }
    }

    /// Run the MLP on the feature vector; 0.0 when the model is absent or
    /// inference fails.
    fn predict_mlp(&self, features: &[f64]) -> f64 {
        let Some(session) = &self.mlp_session else {
            return 0.0;
        };

        let input = Array1::from_vec(features.to_vec());
        // models expect a leading batch dimension
        let input_2d = input.insert_axis(ndarray::Axis(0));

        match session.run(ort::inputs!["input" => input_2d.view()]) {
            Ok(outputs) => {
                if let Some(output) = outputs.get("output") {
                    if let Ok(tensor) = output.try_extract_tensor::<f32>() {
                        return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
                    }
                }
                0.0
            }
            Err(_) => 0.0,
        }
    }

    /// Weighted average of the loaded models' predictions; models that are
    /// not loaded contribute neither a prediction nor weight. 0.0 when no
    /// model is available.
    fn predict(&self, candidate: &MarketCandidate) -> f64 {
        let sequence = Self::extract_price_sequence(candidate, 24);
        let features = Self::extract_features(candidate);

        let lstm_pred = self.predict_lstm(&sequence);
        let mlp_pred = self.predict_mlp(&features);

        let mut ensemble = 0.0;
        let mut total_weight = 0.0;

        if self.lstm_session.is_some() && !self.ensemble_weights.is_empty() {
            ensemble += lstm_pred * self.ensemble_weights[0];
            total_weight += self.ensemble_weights[0];
        }

        if self.mlp_session.is_some() && self.ensemble_weights.len() > 1 {
            ensemble += mlp_pred * self.ensemble_weights[1];
            total_weight += self.ensemble_weights[1];
        }

        if total_weight > 0.0 {
            ensemble / total_weight
        } else {
            0.0
        }
    }
}
|
||||
|
||||
#[cfg(feature = "ml")]
#[async_trait]
impl Scorer for MLEnsembleScorer {
    fn name(&self) -> &'static str {
        "MLEnsembleScorer"
    }

    /// Attach the ensemble prediction under the "ml_ensemble" score key.
    async fn score(
        &self,
        _context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String> {
        let mut results = Vec::with_capacity(candidates.len());

        for candidate in candidates {
            let prediction = self.predict(candidate);
            // only the score map is carried; update() merges it back later
            let mut out = MarketCandidate {
                scores: candidate.scores.clone(),
                ..Default::default()
            };
            out.scores.insert("ml_ensemble".to_string(), prediction);
            results.push(out);
        }

        Ok(results)
    }

    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
        if let Some(&score) = scored.scores.get("ml_ensemble") {
            candidate.scores.insert("ml_ensemble".to_string(), score);
        }
    }
}
|
||||
|
||||
/// stub scorer when ML feature is disabled — keeps the same public surface
/// so pipelines compile either way, but never produces a score
#[cfg(not(feature = "ml"))]
pub struct MLEnsembleScorer;
|
||||
|
||||
#[cfg(not(feature = "ml"))]
|
||||
impl MLEnsembleScorer {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn load_models(_model_dir: &Path) -> Result<Self, String> {
|
||||
Ok(Self)
|
||||
}
|
||||
|
||||
pub fn with_weights(self, _weights: Vec<f64>) -> Self {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ml"))]
|
||||
#[async_trait]
|
||||
impl Scorer for MLEnsembleScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"MLEnsembleScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
Ok(candidates.iter().map(|c| MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
}).collect())
|
||||
}
|
||||
|
||||
fn update(&self, _candidate: &mut MarketCandidate, _scored: MarketCandidate) {}
|
||||
}
|
||||
230
src/pipeline/mod.rs
Normal file
230
src/pipeline/mod.rs
Normal file
@ -0,0 +1,230 @@
|
||||
mod correlation_scorer;
|
||||
mod filters;
|
||||
mod ml_scorer;
|
||||
mod scorers;
|
||||
mod selector;
|
||||
mod sources;
|
||||
|
||||
pub use correlation_scorer::*;
|
||||
pub use filters::*;
|
||||
pub use ml_scorer::*;
|
||||
pub use scorers::*;
|
||||
pub use selector::*;
|
||||
pub use sources::*;
|
||||
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Output of one full pipeline run: the candidate set at each stage plus the
/// context the run executed under.
pub struct PipelineResult {
    /// All candidates returned by the sources, before filtering.
    pub retrieved_candidates: Vec<MarketCandidate>,
    /// Candidates removed by the filter stage (across all filters).
    pub filtered_candidates: Vec<MarketCandidate>,
    /// Final candidates after scoring, selection, and truncation to `result_size`.
    pub selected_candidates: Vec<MarketCandidate>,
    /// The run's context, shared for downstream consumers.
    pub context: Arc<TradingContext>,
}
|
||||
|
||||
/// Result of a single filter pass: the candidates that survived and the ones
/// it removed.
pub struct FilterResult {
    /// Candidates that passed the filter.
    pub kept: Vec<MarketCandidate>,
    /// Candidates rejected by the filter.
    pub removed: Vec<MarketCandidate>,
}
|
||||
|
||||
/// A producer of market candidates for the pipeline.
#[async_trait]
pub trait Source: Send + Sync {
    /// Human-readable name used in log output.
    fn name(&self) -> &'static str;
    /// Whether this source should run for the given context; defaults to always on.
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    /// Fetches candidates for this context. An `Err` is logged by the
    /// pipeline and the source is skipped; it is never fatal to the run.
    async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String>;
}
|
||||
|
||||
/// A stage that removes candidates from the running set.
#[async_trait]
pub trait Filter: Send + Sync {
    /// Human-readable name used in log output.
    fn name(&self) -> &'static str;
    /// Whether this filter should run for the given context; defaults to always on.
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    /// Partitions `candidates` into kept and removed. On `Err` the pipeline
    /// restores the pre-filter candidate set, so a broken filter drops nothing.
    async fn filter(
        &self,
        context: &TradingContext,
        candidates: Vec<MarketCandidate>,
    ) -> Result<FilterResult, String>;
}
|
||||
|
||||
/// Computes per-candidate scores. Scorers receive an immutable candidate
/// slice and return lightweight "carrier" candidates holding the new scores;
/// the pipeline merges them back via `update`/`update_all`.
#[async_trait]
pub trait Scorer: Send + Sync {
    /// Human-readable name used in log output.
    fn name(&self) -> &'static str;
    /// Whether this scorer should run for the given context; defaults to always on.
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    /// Returns one scored carrier per input candidate, in the same order.
    async fn score(
        &self,
        context: &TradingContext,
        candidates: &[MarketCandidate],
    ) -> Result<Vec<MarketCandidate>, String>;

    /// Merges one carrier's scores back onto the real candidate.
    fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate);

    /// Pairwise merge by position. Note `zip` stops at the shorter list;
    /// the pipeline verifies the lengths match before calling this.
    fn update_all(&self, candidates: &mut [MarketCandidate], scored: Vec<MarketCandidate>) {
        for (c, s) in candidates.iter_mut().zip(scored) {
            self.update(c, s);
        }
    }
}
|
||||
|
||||
/// Final stage: orders/chooses the candidates to act on. Synchronous, unlike
/// the other stages.
pub trait Selector: Send + Sync {
    /// Human-readable name used in log output.
    fn name(&self) -> &'static str;
    /// Whether this selector should run; when disabled the pipeline passes
    /// candidates through unchanged.
    fn enable(&self, _context: &TradingContext) -> bool {
        true
    }
    /// Returns the chosen candidates (the pipeline truncates to `result_size` afterwards).
    fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate>;
}
|
||||
|
||||
/// Candidate pipeline: sources -> filters -> scorers -> selector.
pub struct TradingPipeline {
    /// Candidate producers; results are concatenated in order.
    sources: Vec<Box<dyn Source>>,
    /// Filters applied sequentially to the running candidate set.
    filters: Vec<Box<dyn Filter>>,
    /// Scorers; each annotates candidates with additional scores.
    scorers: Vec<Box<dyn Scorer>>,
    /// Final ordering/selection stage.
    selector: Box<dyn Selector>,
    /// Maximum number of candidates returned by `execute`.
    result_size: usize,
}
|
||||
|
||||
impl TradingPipeline {
    /// Assembles a pipeline from its stages. `result_size` caps how many
    /// candidates `execute` returns after selection.
    pub fn new(
        sources: Vec<Box<dyn Source>>,
        filters: Vec<Box<dyn Filter>>,
        scorers: Vec<Box<dyn Scorer>>,
        selector: Box<dyn Selector>,
        result_size: usize,
    ) -> Self {
        Self {
            sources,
            filters,
            scorers,
            selector,
            result_size,
        }
    }

    /// Runs the full pipeline: fetch -> filter -> score -> select -> truncate.
    /// Each stage logs its counts tagged with the context's request id.
    pub async fn execute(&self, context: TradingContext) -> PipelineResult {
        let request_id = context.request_id().to_string();

        let candidates = self.fetch_candidates(&context).await;
        info!(
            request_id = %request_id,
            candidates = candidates.len(),
            "fetched candidates"
        );

        // clone so the unfiltered set can still be reported in PipelineResult
        let (kept, filtered) = self.filter(&context, candidates.clone()).await;
        info!(
            request_id = %request_id,
            kept = kept.len(),
            filtered = filtered.len(),
            "filtered candidates"
        );

        let scored = self.score(&context, kept).await;

        let mut selected = self.select(&context, scored);
        selected.truncate(self.result_size);

        info!(
            request_id = %request_id,
            selected = selected.len(),
            "selected candidates"
        );

        PipelineResult {
            retrieved_candidates: candidates,
            filtered_candidates: filtered,
            selected_candidates: selected,
            context: Arc::new(context),
        }
    }

    /// Collects candidates from every enabled source, concatenated in source
    /// order. A failing source is logged and skipped — never fatal.
    async fn fetch_candidates(&self, context: &TradingContext) -> Vec<MarketCandidate> {
        let mut all_candidates = Vec::new();

        for source in self.sources.iter().filter(|s| s.enable(context)) {
            match source.get_candidates(context).await {
                Ok(mut candidates) => {
                    info!(
                        source = source.name(),
                        count = candidates.len(),
                        "source returned candidates"
                    );
                    all_candidates.append(&mut candidates);
                }
                Err(e) => {
                    error!(source = source.name(), error = %e, "source failed");
                }
            }
        }

        all_candidates
    }

    /// Applies each enabled filter in sequence. On filter error the candidate
    /// set is rolled back to the pre-filter backup so one broken filter cannot
    /// drop candidates. Returns (kept, all removed across filters).
    async fn filter(
        &self,
        context: &TradingContext,
        mut candidates: Vec<MarketCandidate>,
    ) -> (Vec<MarketCandidate>, Vec<MarketCandidate>) {
        let mut all_removed = Vec::new();

        for filter in self.filters.iter().filter(|f| f.enable(context)) {
            // the filter consumes `candidates`; keep a copy to restore on error
            let backup = candidates.clone();
            match filter.filter(context, candidates).await {
                Ok(result) => {
                    info!(
                        filter = filter.name(),
                        kept = result.kept.len(),
                        removed = result.removed.len(),
                        "filter applied"
                    );
                    candidates = result.kept;
                    all_removed.extend(result.removed);
                }
                Err(e) => {
                    error!(filter = filter.name(), error = %e, "filter failed");
                    candidates = backup;
                }
            }
        }

        (candidates, all_removed)
    }

    /// Runs every enabled scorer over the candidate slice and merges results
    /// back via `update_all`. A scorer that errors, or returns the wrong
    /// number of carriers, is logged and its output discarded.
    async fn score(&self, context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
        let expected_len = candidates.len();

        for scorer in self.scorers.iter().filter(|s| s.enable(context)) {
            match scorer.score(context, &candidates).await {
                Ok(scored) => {
                    if scored.len() == expected_len {
                        scorer.update_all(&mut candidates, scored);
                    } else {
                        // update_all pairs by position, so a length mismatch
                        // would silently mis-assign scores — refuse instead.
                        error!(
                            scorer = scorer.name(),
                            expected = expected_len,
                            got = scored.len(),
                            "scorer returned wrong number of candidates"
                        );
                    }
                }
                Err(e) => {
                    error!(scorer = scorer.name(), error = %e, "scorer failed");
                }
            }
        }

        candidates
    }

    /// Delegates to the selector when enabled for this context; otherwise the
    /// candidates pass through unchanged.
    fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
        if self.selector.enable(context) {
            self.selector.select(context, candidates)
        } else {
            candidates
        }
    }
}
|
||||
978
src/pipeline/scorers.rs
Normal file
978
src/pipeline/scorers.rs
Normal file
@ -0,0 +1,978 @@
|
||||
use crate::pipeline::Scorer;
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use rust_decimal::prelude::ToPrimitive;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
/// rolling statistics for z-score normalization
/// tracks mean and std deviation over a sliding window
///
/// `sum` and `sum_sq` are maintained incrementally so `mean`/`std` are O(1).
#[derive(Debug, Clone)]
pub struct RollingStats {
    // VecDeque so evicting the oldest sample is O(1); the previous Vec-based
    // version paid O(n) per eviction via `remove(0)`.
    values: std::collections::VecDeque<f64>,
    max_size: usize,
    sum: f64,
    sum_sq: f64,
}

impl RollingStats {
    /// Creates an empty window holding at most `max_size` samples.
    pub fn new(max_size: usize) -> Self {
        Self {
            values: std::collections::VecDeque::with_capacity(max_size),
            max_size,
            sum: 0.0,
            sum_sq: 0.0,
        }
    }

    /// Adds a sample, evicting the oldest when the window is full.
    /// Non-finite values (NaN/inf) are ignored so they cannot poison the sums.
    pub fn push(&mut self, value: f64) {
        if !value.is_finite() {
            return;
        }

        if self.values.len() >= self.max_size {
            // `if let` also makes `max_size == 0` safe (the old `remove(0)`
            // would panic on an empty buffer in that case).
            if let Some(old) = self.values.pop_front() {
                self.sum -= old;
                self.sum_sq -= old * old;
            }
        }

        self.values.push_back(value);
        self.sum += value;
        self.sum_sq += value * value;
    }

    /// Pushes every sample in `values`, applying the same finiteness filter.
    pub fn push_batch(&mut self, values: &[f64]) {
        for &v in values {
            self.push(v);
        }
    }

    /// Window mean; 0.0 when empty.
    pub fn mean(&self) -> f64 {
        if self.values.is_empty() {
            0.0
        } else {
            self.sum / self.values.len() as f64
        }
    }

    /// Population standard deviation; returns 1.0 with fewer than 2 samples
    /// so early `normalize` calls do not divide by a degenerate spread.
    pub fn std(&self) -> f64 {
        let n = self.values.len();
        if n < 2 {
            return 1.0;
        }

        let mean = self.mean();
        // E[x^2] - E[x]^2, clamped at 0 against floating-point cancellation.
        let variance = (self.sum_sq / n as f64) - (mean * mean);
        variance.max(0.0).sqrt()
    }

    /// Z-score of `value` against the window; std floored at 0.001 to avoid blowups.
    pub fn normalize(&self, value: f64) -> f64 {
        let std = self.std().max(0.001);
        (value - self.mean()) / std
    }

    /// Number of samples currently in the window.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// True when no samples have been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Ready once the window is at least a quarter full.
    pub fn is_ready(&self) -> bool {
        self.values.len() >= self.max_size / 4
    }
}
|
||||
|
||||
/// wrapper that normalizes any scorer's output to z-scores
pub struct NormalizedScorer<S> {
    /// The wrapped scorer whose raw output gets normalized.
    inner: S,
    /// Which entry in `MarketCandidate::scores` to normalize.
    score_key: String,
    /// Rolling window of raw scores shared across calls (Mutex because
    /// `score` only has `&self`).
    stats: Arc<Mutex<RollingStats>>,
}
|
||||
|
||||
impl<S> NormalizedScorer<S> {
|
||||
pub fn new(inner: S, score_key: &str, history_size: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
score_key: score_key.to_string(),
|
||||
stats: Arc::new(Mutex::new(RollingStats::new(history_size))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<S: Scorer + Send + Sync> Scorer for NormalizedScorer<S> {
|
||||
fn name(&self) -> &'static str {
|
||||
self.inner.name()
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let raw_scored = self.inner.score(context, candidates).await?;
|
||||
|
||||
let raw_scores: Vec<f64> = raw_scored
|
||||
.iter()
|
||||
.filter_map(|c| c.scores.get(&self.score_key).copied())
|
||||
.collect();
|
||||
|
||||
{
|
||||
let mut stats = self.stats.lock().unwrap();
|
||||
stats.push_batch(&raw_scores);
|
||||
}
|
||||
|
||||
let stats = self.stats.lock().unwrap();
|
||||
let normalized = raw_scored
|
||||
.into_iter()
|
||||
.map(|mut c| {
|
||||
if let Some(&raw) = c.scores.get(&self.score_key) {
|
||||
let z = if stats.is_ready() {
|
||||
stats.normalize(raw)
|
||||
} else {
|
||||
raw
|
||||
};
|
||||
c.scores.insert(self.score_key.clone(), z);
|
||||
}
|
||||
c
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(normalized)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
self.inner.update(candidate, scored);
|
||||
}
|
||||
}
|
||||
|
||||
/// Scores candidates by raw price change (last - first) over a lookback window.
pub struct MomentumScorer {
    /// How far back in the price history to look, in hours.
    lookback_hours: i64,
}
|
||||
|
||||
impl MomentumScorer {
|
||||
pub fn new(lookback_hours: i64) -> Self {
|
||||
Self { lookback_hours }
|
||||
}
|
||||
|
||||
fn calculate_momentum(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||
let relevant_history: Vec<_> = candidate
|
||||
.price_history
|
||||
.iter()
|
||||
.filter(|p| p.timestamp >= lookback_start)
|
||||
.collect();
|
||||
|
||||
if relevant_history.len() < 2 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||
let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||
|
||||
last - first
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for MomentumScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"MomentumScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let momentum = Self::calculate_momentum(c, context.timestamp, self.lookback_hours);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("momentum".to_string(), momentum);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("momentum") {
|
||||
candidate.scores.insert("momentum".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// multi-timeframe momentum scorer
/// looks at multiple windows and detects divergence between short and long term
pub struct MultiTimeframeMomentumScorer {
    /// Lookback windows in hours. Expected shortest-to-longest: the first
    /// half is treated as the "short-term" side when computing divergence.
    windows: Vec<i64>,
}
|
||||
|
||||
impl MultiTimeframeMomentumScorer {
    /// Creates a scorer over the given lookback windows (hours).
    pub fn new(windows: Vec<i64>) -> Self {
        Self { windows }
    }

    /// Standard config: 1h, 4h, 12h, and 24h windows.
    pub fn default_windows() -> Self {
        Self::new(vec![1, 4, 12, 24])
    }

    /// Price change (last - first) over history points inside a single
    /// window; 0.0 when fewer than two points fall in the window.
    fn calculate_momentum_for_window(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        hours: i64,
    ) -> f64 {
        let lookback_start = now - chrono::Duration::hours(hours);
        let relevant_history: Vec<_> = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .collect();

        if relevant_history.len() < 2 {
            return 0.0;
        }

        // unwrap is safe: len >= 2 guarantees first/last exist
        let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
        let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);

        last - first
    }

    /// Combines per-window momentums into (score, divergence, alignment):
    /// - score: average momentum, scaled by alignment
    /// - divergence: |short_avg - long_avg| only when the two halves disagree in sign
    /// - alignment: 1.0 when every window agrees in (non-zero) sign, else 0.5
    fn calculate_score(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        windows: &[i64],
    ) -> (f64, f64, f64) {
        let momentums: Vec<f64> = windows
            .iter()
            .map(|&w| Self::calculate_momentum_for_window(candidate, now, w))
            .collect();

        if momentums.is_empty() {
            // no windows configured: neutral score, full alignment
            return (0.0, 0.0, 1.0);
        }

        let avg_momentum = momentums.iter().sum::<f64>() / momentums.len() as f64;

        // +1/-1/0 sign per window; full alignment requires a shared non-zero sign
        let signs: Vec<i32> = momentums.iter().map(|&m| if m > 0.0 { 1 } else if m < 0.0 { -1 } else { 0 }).collect();
        let all_same_sign = signs.iter().all(|&s| s == signs[0]) && signs[0] != 0;
        let alignment = if all_same_sign { 1.0 } else { 0.5 };

        // first half of the windows = short-term side, second half = long-term;
        // with a single window both sides collapse to that one momentum
        let short_avg = if momentums.len() >= 2 {
            momentums[..momentums.len() / 2].iter().sum::<f64>() / (momentums.len() / 2) as f64
        } else {
            momentums[0]
        };
        let long_avg = if momentums.len() >= 2 {
            momentums[momentums.len() / 2..].iter().sum::<f64>() / (momentums.len() - momentums.len() / 2) as f64
        } else {
            momentums[0]
        };

        // divergence only counts when short and long term point opposite ways
        let divergence = if (short_avg > 0.0 && long_avg < 0.0) || (short_avg < 0.0 && long_avg > 0.0) {
            (short_avg - long_avg).abs()
        } else {
            0.0
        };

        (avg_momentum * alignment, divergence, alignment)
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for MultiTimeframeMomentumScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"MultiTimeframeMomentumScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let (momentum, divergence, alignment) = Self::calculate_score(c, context.timestamp, &self.windows);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("mtf_momentum".to_string(), momentum);
|
||||
scored.scores.insert("mtf_divergence".to_string(), divergence);
|
||||
scored.scores.insert("mtf_alignment".to_string(), alignment);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
for key in ["mtf_momentum", "mtf_divergence", "mtf_alignment"] {
|
||||
if let Some(score) = scored.scores.get(key) {
|
||||
candidate.scores.insert(key.to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scores candidates by how far the current price sits below its recent
/// mean (positive score = price under the mean, expected to revert up).
pub struct MeanReversionScorer {
    /// How far back in the price history to average over, in hours.
    lookback_hours: i64,
}
|
||||
|
||||
impl MeanReversionScorer {
|
||||
pub fn new(lookback_hours: i64) -> Self {
|
||||
Self { lookback_hours }
|
||||
}
|
||||
|
||||
fn calculate_deviation(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||
let prices: Vec<f64> = candidate
|
||||
.price_history
|
||||
.iter()
|
||||
.filter(|p| p.timestamp >= lookback_start)
|
||||
.filter_map(|p| p.yes_price.to_f64())
|
||||
.collect();
|
||||
|
||||
if prices.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
|
||||
let current = candidate.current_yes_price.to_f64().unwrap_or(0.5);
|
||||
let deviation = current - mean;
|
||||
|
||||
-deviation
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for MeanReversionScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"MeanReversionScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let reversion = Self::calculate_deviation(c, context.timestamp, self.lookback_hours);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("mean_reversion".to_string(), reversion);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("mean_reversion") {
|
||||
candidate.scores.insert("mean_reversion".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// bollinger bands mean reversion scorer
/// triggers when price touches statistical extremes (upper/lower bands)
pub struct BollingerMeanReversionScorer {
    /// How far back in the price history to compute the bands, in hours.
    lookback_hours: i64,
    /// Band half-width in standard deviations (classic Bollinger uses 2.0).
    num_std: f64,
}
|
||||
|
||||
impl BollingerMeanReversionScorer {
    /// Creates a scorer with the given lookback window and band width.
    pub fn new(lookback_hours: i64, num_std: f64) -> Self {
        Self { lookback_hours, num_std }
    }

    /// Standard config: 24h lookback, 2-sigma bands.
    pub fn default_config() -> Self {
        Self::new(24, 2.0)
    }

    /// Returns (mean, std, current) over prices inside the window, or None
    /// with fewer than 5 points (too little data for meaningful bands).
    /// NOTE(review): "current" here is the last *history* point, not
    /// `candidate.current_yes_price` — confirm that is intentional.
    fn calculate_bands(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
    ) -> Option<(f64, f64, f64)> {
        let lookback_start = now - chrono::Duration::hours(lookback_hours);
        let prices: Vec<f64> = candidate
            .price_history
            .iter()
            .filter(|p| p.timestamp >= lookback_start)
            .filter_map(|p| p.yes_price.to_f64())
            .collect();

        if prices.len() < 5 {
            return None;
        }

        let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
        // population variance over the window
        let variance: f64 = prices.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / prices.len() as f64;
        let std = variance.sqrt();

        Some((mean, std, *prices.last().unwrap_or(&mean)))
    }

    /// Returns (reversion score, band position).
    /// Score > 0 suggests price below the lower band (expected to revert up),
    /// score < 0 above the upper band; inside the bands a smaller pull toward
    /// the mean applies. Band position maps [lower, upper] onto [-1, 1].
    fn calculate_score(
        candidate: &MarketCandidate,
        now: chrono::DateTime<chrono::Utc>,
        lookback_hours: i64,
        num_std: f64,
    ) -> (f64, f64) {
        let (mean, std, current) = match Self::calculate_bands(candidate, now, lookback_hours) {
            Some(v) => v,
            // insufficient data: neutral on both outputs
            None => return (0.0, 0.0),
        };

        let upper_band = mean + num_std * std;
        let lower_band = mean - num_std * std;
        let band_width = upper_band - lower_band;

        // near-zero width means a flat market; avoid dividing by ~0 below
        if band_width < 0.001 {
            return (0.0, 0.0);
        }

        // 0 at the lower band, 1 at the upper band
        let position = (current - lower_band) / band_width;

        let score = if current >= upper_band {
            // beyond the upper band: bet down, scaled by how far past in sigmas
            -(current - upper_band) / std.max(0.001)
        } else if current <= lower_band {
            // beyond the lower band: bet up
            (lower_band - current) / std.max(0.001)
        } else if current > mean {
            // inside the bands: weak pull back toward the mean
            -(position - 0.5) * 0.5
        } else {
            (0.5 - position) * 0.5
        };

        let band_position = (position * 2.0 - 1.0).clamp(-1.0, 1.0);

        (score.clamp(-2.0, 2.0), band_position)
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for BollingerMeanReversionScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"BollingerMeanReversionScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let (score, band_pos) = Self::calculate_score(c, context.timestamp, self.lookback_hours, self.num_std);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("bollinger_reversion".to_string(), score);
|
||||
scored.scores.insert("bollinger_position".to_string(), band_pos);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
for key in ["bollinger_reversion", "bollinger_position"] {
|
||||
if let Some(score) = scored.scores.get(key) {
|
||||
candidate.scores.insert(key.to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scores candidates by how unusually active their recent trading volume is
/// relative to their lifetime average hourly volume.
pub struct VolumeScorer {
    /// Window for the "recent" volume measurement, in hours.
    lookback_hours: i64,
}
|
||||
|
||||
impl VolumeScorer {
|
||||
pub fn new(lookback_hours: i64) -> Self {
|
||||
Self { lookback_hours }
|
||||
}
|
||||
|
||||
fn calculate_volume_score(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||
let recent_volume: u64 = candidate
|
||||
.price_history
|
||||
.iter()
|
||||
.filter(|p| p.timestamp >= lookback_start)
|
||||
.map(|p| p.volume)
|
||||
.sum();
|
||||
|
||||
if candidate.total_volume == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let avg_hourly_volume = candidate.total_volume as f64
|
||||
/ ((now - candidate.open_time).num_hours().max(1) as f64);
|
||||
let recent_hourly_volume = recent_volume as f64 / lookback_hours.max(1) as f64;
|
||||
|
||||
if avg_hourly_volume > 0.0 {
|
||||
(recent_hourly_volume / avg_hourly_volume).ln().max(-2.0).min(2.0)
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for VolumeScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"VolumeScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let volume = Self::calculate_volume_score(c, context.timestamp, self.lookback_hours);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("volume".to_string(), volume);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("volume") {
|
||||
candidate.scores.insert("volume".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scores how much time a market has left before close: near 1.0 far from
/// close, approaching 0 at close, and the sentinel -1.0 once already closed.
pub struct TimeDecayScorer;
|
||||
|
||||
impl TimeDecayScorer {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
fn calculate_time_decay(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>) -> f64 {
|
||||
let ttc = candidate.time_to_close(now);
|
||||
let hours_remaining = ttc.num_hours() as f64;
|
||||
|
||||
if hours_remaining <= 0.0 {
|
||||
return -1.0;
|
||||
}
|
||||
|
||||
let decay = 1.0 - (1.0 / (hours_remaining / 24.0 + 1.0));
|
||||
decay.min(1.0).max(0.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TimeDecayScorer {
    /// Equivalent to `TimeDecayScorer::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// order flow imbalance scorer
/// measures buying vs selling pressure using taker_side from trades
pub struct OrderFlowScorer;
|
||||
|
||||
impl OrderFlowScorer {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
fn calculate_imbalance(candidate: &MarketCandidate) -> f64 {
|
||||
let buy_vol = candidate.buy_volume_24h as f64;
|
||||
let sell_vol = candidate.sell_volume_24h as f64;
|
||||
let total = buy_vol + sell_vol;
|
||||
|
||||
if total == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
(buy_vol - sell_vol) / total
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OrderFlowScorer {
    /// Equivalent to `OrderFlowScorer::new()`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for OrderFlowScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"OrderFlowScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let imbalance = Self::calculate_imbalance(c);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("order_flow".to_string(), imbalance);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("order_flow") {
|
||||
candidate.scores.insert("order_flow".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for TimeDecayScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"TimeDecayScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let time_decay = Self::calculate_time_decay(c, context.timestamp);
|
||||
let mut scored = MarketCandidate {
|
||||
scores: c.scores.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
scored.scores.insert("time_decay".to_string(), time_decay);
|
||||
scored
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
if let Some(score) = scored.scores.get("time_decay") {
|
||||
candidate.scores.insert("time_decay".to_string(), *score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines named component scores into a single `final_score` via a
/// weighted linear sum.
pub struct WeightedScorer {
    /// (score key, weight) pairs; missing scores contribute 0.
    weights: Vec<(String, f64)>,
}
|
||||
|
||||
impl WeightedScorer {
|
||||
pub fn new(weights: Vec<(String, f64)>) -> Self {
|
||||
Self { weights }
|
||||
}
|
||||
|
||||
pub fn default_weights() -> Self {
|
||||
Self::new(vec![
|
||||
("momentum".to_string(), 0.4),
|
||||
("mean_reversion".to_string(), 0.3),
|
||||
("volume".to_string(), 0.2),
|
||||
("time_decay".to_string(), 0.1),
|
||||
])
|
||||
}
|
||||
|
||||
fn compute_weighted_score(&self, candidate: &MarketCandidate) -> f64 {
|
||||
self.weights
|
||||
.iter()
|
||||
.map(|(name, weight)| {
|
||||
candidate.scores.get(name).copied().unwrap_or(0.0) * weight
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for WeightedScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"WeightedScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let weighted_score = self.compute_weighted_score(c);
|
||||
MarketCandidate {
|
||||
final_score: weighted_score,
|
||||
..Default::default()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
candidate.final_score = scored.final_score;
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-component weights for blending scorer outputs into one score.
/// Each preset in `impl ScorerWeights` sums to 1.0.
#[derive(Clone)]
pub struct ScorerWeights {
    pub momentum: f64,
    pub mean_reversion: f64,
    pub volume: f64,
    pub time_decay: f64,
    pub order_flow: f64,
    /// Applied to the "bollinger_reversion" score key.
    pub bollinger: f64,
    /// Applied to the "mtf_momentum" score key.
    pub mtf_momentum: f64,
}
|
||||
|
||||
impl Default for ScorerWeights {
    /// Balanced blend used for markets without a category preset
    /// (weights sum to 1.0).
    fn default() -> Self {
        Self {
            momentum: 0.2,
            mean_reversion: 0.2,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.1,
            mtf_momentum: 0.1,
        }
    }
}
|
||||
|
||||
impl ScorerWeights {
    /// Politics preset: emphasizes momentum (0.35) and multi-timeframe momentum.
    pub fn politics() -> Self {
        Self {
            momentum: 0.35,
            mean_reversion: 0.1,
            volume: 0.1,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.05,
            mtf_momentum: 0.15,
        }
    }

    /// Weather preset: emphasizes mean reversion (0.35) and bollinger bands.
    pub fn weather() -> Self {
        Self {
            momentum: 0.1,
            mean_reversion: 0.35,
            volume: 0.1,
            time_decay: 0.15,
            order_flow: 0.1,
            bollinger: 0.15,
            mtf_momentum: 0.05,
        }
    }

    /// Sports preset: emphasizes order flow (0.3).
    pub fn sports() -> Self {
        Self {
            momentum: 0.2,
            mean_reversion: 0.1,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.3,
            bollinger: 0.05,
            mtf_momentum: 0.1,
        }
    }

    /// Economics/financial preset: momentum-leaning (0.25) with solid reversion.
    pub fn economics() -> Self {
        Self {
            momentum: 0.25,
            mean_reversion: 0.2,
            volume: 0.15,
            time_decay: 0.1,
            order_flow: 0.15,
            bollinger: 0.1,
            mtf_momentum: 0.05,
        }
    }

    /// Weighted linear blend of the candidate's component scores; any score
    /// key a scorer has not produced yet contributes 0.
    fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
        let get_score = |key: &str| candidate.scores.get(key).copied().unwrap_or(0.0);

        self.momentum * get_score("momentum")
            + self.mean_reversion * get_score("mean_reversion")
            + self.volume * get_score("volume")
            + self.time_decay * get_score("time_decay")
            + self.order_flow * get_score("order_flow")
            + self.bollinger * get_score("bollinger_reversion")
            + self.mtf_momentum * get_score("mtf_momentum")
    }
}
|
||||
|
||||
/// category-aware weighted scorer
/// applies different weights based on market category
pub struct CategoryWeightedScorer {
    /// Lower-cased category name -> weight preset.
    category_weights: std::collections::HashMap<String, ScorerWeights>,
    /// Fallback preset for categories without an entry.
    default_weights: ScorerWeights,
}
|
||||
|
||||
impl CategoryWeightedScorer {
|
||||
pub fn new(
|
||||
category_weights: std::collections::HashMap<String, ScorerWeights>,
|
||||
default_weights: ScorerWeights,
|
||||
) -> Self {
|
||||
Self { category_weights, default_weights }
|
||||
}
|
||||
|
||||
pub fn with_defaults() -> Self {
|
||||
let mut category_weights = std::collections::HashMap::new();
|
||||
category_weights.insert("politics".to_string(), ScorerWeights::politics());
|
||||
category_weights.insert("weather".to_string(), ScorerWeights::weather());
|
||||
category_weights.insert("sports".to_string(), ScorerWeights::sports());
|
||||
category_weights.insert("economics".to_string(), ScorerWeights::economics());
|
||||
category_weights.insert("financial".to_string(), ScorerWeights::economics());
|
||||
|
||||
Self {
|
||||
category_weights,
|
||||
default_weights: ScorerWeights::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_weights(&self, category: &str) -> &ScorerWeights {
|
||||
let lower = category.to_lowercase();
|
||||
self.category_weights.get(&lower).unwrap_or(&self.default_weights)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for CategoryWeightedScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"CategoryWeightedScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let weights = self.get_weights(&c.category);
|
||||
let weighted_score = weights.compute_score(c);
|
||||
MarketCandidate {
|
||||
final_score: weighted_score,
|
||||
..Default::default()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
candidate.final_score = scored.final_score;
|
||||
}
|
||||
}
|
||||
|
||||
/// ensemble scorer that combines multiple models with dynamic weighting
|
||||
/// weights can be updated based on recent accuracy
|
||||
pub struct EnsembleScorer {
|
||||
model_weights: std::sync::Arc<std::sync::Mutex<Vec<f64>>>,
|
||||
model_keys: Vec<String>,
|
||||
}
|
||||
|
||||
impl EnsembleScorer {
|
||||
pub fn new(model_keys: Vec<String>, initial_weights: Vec<f64>) -> Self {
|
||||
assert_eq!(model_keys.len(), initial_weights.len());
|
||||
Self {
|
||||
model_weights: std::sync::Arc::new(std::sync::Mutex::new(initial_weights)),
|
||||
model_keys,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_ensemble() -> Self {
|
||||
Self::new(
|
||||
vec![
|
||||
"momentum".to_string(),
|
||||
"mean_reversion".to_string(),
|
||||
"bollinger_reversion".to_string(),
|
||||
"order_flow".to_string(),
|
||||
"mtf_momentum".to_string(),
|
||||
],
|
||||
vec![0.25, 0.2, 0.2, 0.2, 0.15],
|
||||
)
|
||||
}
|
||||
|
||||
pub fn update_weights(&self, new_weights: Vec<f64>) {
|
||||
let mut weights = self.model_weights.lock().unwrap();
|
||||
*weights = new_weights;
|
||||
}
|
||||
|
||||
fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
|
||||
let weights = self.model_weights.lock().unwrap();
|
||||
self.model_keys
|
||||
.iter()
|
||||
.zip(weights.iter())
|
||||
.map(|(key, weight)| {
|
||||
candidate.scores.get(key).copied().unwrap_or(0.0) * weight
|
||||
})
|
||||
.sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Scorer for EnsembleScorer {
|
||||
fn name(&self) -> &'static str {
|
||||
"EnsembleScorer"
|
||||
}
|
||||
|
||||
async fn score(
|
||||
&self,
|
||||
_context: &TradingContext,
|
||||
candidates: &[MarketCandidate],
|
||||
) -> Result<Vec<MarketCandidate>, String> {
|
||||
let scored = candidates
|
||||
.iter()
|
||||
.map(|c| {
|
||||
let ensemble_score = self.compute_score(c);
|
||||
MarketCandidate {
|
||||
final_score: ensemble_score,
|
||||
..Default::default()
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||
candidate.final_score = scored.final_score;
|
||||
}
|
||||
}
|
||||
|
||||
64
src/pipeline/selector.rs
Normal file
64
src/pipeline/selector.rs
Normal file
@ -0,0 +1,64 @@
|
||||
use crate::pipeline::Selector;
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
|
||||
pub struct TopKSelector {
|
||||
k: usize,
|
||||
}
|
||||
|
||||
impl TopKSelector {
|
||||
pub fn new(k: usize) -> Self {
|
||||
Self { k }
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for TopKSelector {
|
||||
fn name(&self) -> &'static str {
|
||||
"TopKSelector"
|
||||
}
|
||||
|
||||
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||
candidates.sort_by(|a, b| {
|
||||
b.final_score
|
||||
.partial_cmp(&a.final_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
candidates.truncate(self.k);
|
||||
candidates
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ThresholdSelector {
|
||||
min_score: f64,
|
||||
max_candidates: Option<usize>,
|
||||
}
|
||||
|
||||
impl ThresholdSelector {
|
||||
pub fn new(min_score: f64, max_candidates: Option<usize>) -> Self {
|
||||
Self {
|
||||
min_score,
|
||||
max_candidates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Selector for ThresholdSelector {
|
||||
fn name(&self) -> &'static str {
|
||||
"ThresholdSelector"
|
||||
}
|
||||
|
||||
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||
candidates.retain(|c| c.final_score >= self.min_score);
|
||||
candidates.sort_by(|a, b| {
|
||||
b.final_score
|
||||
.partial_cmp(&a.final_score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
if let Some(max) = self.max_candidates {
|
||||
candidates.truncate(max);
|
||||
}
|
||||
|
||||
candidates
|
||||
}
|
||||
}
|
||||
69
src/pipeline/sources.rs
Normal file
69
src/pipeline/sources.rs
Normal file
@ -0,0 +1,69 @@
|
||||
use crate::data::HistoricalData;
|
||||
use crate::pipeline::Source;
|
||||
use crate::types::{MarketCandidate, TradingContext};
|
||||
use async_trait::async_trait;
|
||||
use rust_decimal::Decimal;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// candidate source that replays pre-loaded historical kalshi data.
pub struct HistoricalMarketSource {
    // shared, read-only historical dataset
    data: Arc<HistoricalData>,
    // how far back to pull price history for each candidate
    lookback_hours: i64,
}

impl HistoricalMarketSource {
    /// wrap a dataset; `lookback_hours` bounds each candidate's price history.
    pub fn new(data: Arc<HistoricalData>, lookback_hours: i64) -> Self {
        Self { data, lookback_hours }
    }
}

#[async_trait]
impl Source for HistoricalMarketSource {
    fn name(&self) -> &'static str {
        "HistoricalMarketSource"
    }

    /// build one MarketCandidate per market active at `context.timestamp`.
    /// markets with no current price quote are silently skipped.
    async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String> {
        let now = context.timestamp;
        let active_markets = self.data.get_active_markets(now);

        let candidates: Vec<MarketCandidate> = active_markets
            .into_iter()
            .filter_map(|market| {
                // no quote at `now` -> drop the market via filter_map's `?`
                let current_price = self.data.get_current_price(&market.ticker, now)?;
                let lookback_start = now - chrono::Duration::hours(self.lookback_hours);
                let price_history = self.data.get_price_history(&market.ticker, lookback_start, now);
                let volume_24h = self.data.get_volume_24h(&market.ticker, now);

                // lifetime volume: every trade from market open up to `now`
                let total_volume: u64 = self
                    .data
                    .get_trades_for_market(&market.ticker, market.open_time, now)
                    .iter()
                    .map(|t| t.volume)
                    .sum();

                let (buy_volume_24h, sell_volume_24h) = self.data.get_order_flow_24h(&market.ticker, now);

                Some(MarketCandidate {
                    ticker: market.ticker.clone(),
                    title: market.title.clone(),
                    category: market.category.clone(),
                    current_yes_price: current_price,
                    // binary market: NO price is the complement of YES
                    current_no_price: Decimal::ONE - current_price,
                    volume_24h,
                    total_volume,
                    buy_volume_24h,
                    sell_volume_24h,
                    open_time: market.open_time,
                    close_time: market.close_time,
                    result: market.result,
                    price_history,
                    scores: HashMap::new(),
                    final_score: 0.0,
                })
            })
            .collect();

        Ok(candidates)
    }
}
|
||||
343
src/types.rs
Normal file
343
src/types.rs
Normal file
@ -0,0 +1,343 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use rust_decimal::Decimal;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Side {
|
||||
Yes,
|
||||
No,
|
||||
}
|
||||
|
||||
impl Side {
|
||||
pub fn opposite(&self) -> Self {
|
||||
match self {
|
||||
Side::Yes => Side::No,
|
||||
Side::No => Side::Yes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// final settlement outcome of a market.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MarketResult {
    Yes,
    No,
    /// market voided; positions are refunded at cost
    /// (see `Portfolio::resolve_position`)
    Cancelled,
}
|
||||
|
||||
/// a market under consideration by the pipeline, enriched with market
/// state and per-scorer scores as it flows through the stages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketCandidate {
    pub ticker: String,
    pub title: String,
    pub category: String,
    pub current_yes_price: Decimal,
    // complement of the YES price in a binary market
    pub current_no_price: Decimal,
    // trailing-24h traded contracts
    pub volume_24h: u64,
    // lifetime traded contracts
    pub total_volume: u64,
    pub buy_volume_24h: u64,
    pub sell_volume_24h: u64,
    pub open_time: DateTime<Utc>,
    pub close_time: DateTime<Utc>,
    // settlement outcome; None while the market is unresolved
    pub result: Option<MarketResult>,
    pub price_history: Vec<PricePoint>,

    // per-scorer outputs keyed by model name (e.g. "momentum")
    pub scores: HashMap<String, f64>,
    // blended score used by selectors for ranking
    pub final_score: f64,
}
|
||||
|
||||
impl MarketCandidate {
|
||||
pub fn time_to_close(&self, now: DateTime<Utc>) -> chrono::Duration {
|
||||
self.close_time - now
|
||||
}
|
||||
|
||||
pub fn is_open(&self, now: DateTime<Utc>) -> bool {
|
||||
now >= self.open_time && now < self.close_time
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MarketCandidate {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
ticker: String::new(),
|
||||
title: String::new(),
|
||||
category: String::new(),
|
||||
current_yes_price: Decimal::ZERO,
|
||||
current_no_price: Decimal::ZERO,
|
||||
volume_24h: 0,
|
||||
total_volume: 0,
|
||||
buy_volume_24h: 0,
|
||||
sell_volume_24h: 0,
|
||||
open_time: Utc::now(),
|
||||
close_time: Utc::now(),
|
||||
result: None,
|
||||
price_history: Vec::new(),
|
||||
scores: HashMap::new(),
|
||||
final_score: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// one point in a market's yes-price time series.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricePoint {
    pub timestamp: DateTime<Utc>,
    pub yes_price: Decimal,
    // contracts traded at this point
    pub volume: u64,
}
|
||||
|
||||
/// state passed through every pipeline stage for one evaluation step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradingContext {
    // unique id for this run/request
    pub request_id: String,
    // current (simulated) time the pipeline is evaluating at
    pub timestamp: DateTime<Utc>,
    pub portfolio: Portfolio,
    pub trading_history: Vec<Trade>,
}
|
||||
|
||||
impl TradingContext {
|
||||
pub fn new(initial_capital: Decimal, start_time: DateTime<Utc>) -> Self {
|
||||
Self {
|
||||
request_id: uuid::Uuid::new_v4().to_string(),
|
||||
timestamp: start_time,
|
||||
portfolio: Portfolio::new(initial_capital),
|
||||
trading_history: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn request_id(&self) -> &str {
|
||||
&self.request_id
|
||||
}
|
||||
}
|
||||
|
||||
/// cash plus open positions, keyed by market ticker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Portfolio {
    pub positions: HashMap<String, Position>,
    pub cash: Decimal,
    // capital at the start of the run (for performance reporting)
    pub initial_capital: Decimal,
}
|
||||
|
||||
impl Portfolio {
|
||||
pub fn new(initial_capital: Decimal) -> Self {
|
||||
Self {
|
||||
positions: HashMap::new(),
|
||||
cash: initial_capital,
|
||||
initial_capital,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn total_value(&self, market_prices: &HashMap<String, Decimal>) -> Decimal {
|
||||
let position_value: Decimal = self
|
||||
.positions
|
||||
.values()
|
||||
.map(|p| {
|
||||
let price = market_prices.get(&p.ticker).copied().unwrap_or(p.avg_entry_price);
|
||||
let effective_price = match p.side {
|
||||
Side::Yes => price,
|
||||
Side::No => Decimal::ONE - price,
|
||||
};
|
||||
effective_price * Decimal::from(p.quantity)
|
||||
})
|
||||
.sum();
|
||||
self.cash + position_value
|
||||
}
|
||||
|
||||
pub fn has_position(&self, ticker: &str) -> bool {
|
||||
self.positions.contains_key(ticker)
|
||||
}
|
||||
|
||||
pub fn get_position(&self, ticker: &str) -> Option<&Position> {
|
||||
self.positions.get(ticker)
|
||||
}
|
||||
|
||||
pub fn apply_fill(&mut self, fill: &Fill) {
|
||||
let cost = fill.price * Decimal::from(fill.quantity);
|
||||
|
||||
match fill.side {
|
||||
Side::Yes | Side::No => {
|
||||
self.cash -= cost;
|
||||
let position = self.positions.entry(fill.ticker.clone()).or_insert_with(|| {
|
||||
Position {
|
||||
ticker: fill.ticker.clone(),
|
||||
side: fill.side,
|
||||
quantity: 0,
|
||||
avg_entry_price: Decimal::ZERO,
|
||||
entry_time: fill.timestamp,
|
||||
}
|
||||
});
|
||||
|
||||
let total_cost =
|
||||
position.avg_entry_price * Decimal::from(position.quantity) + cost;
|
||||
position.quantity += fill.quantity;
|
||||
if position.quantity > 0 {
|
||||
position.avg_entry_price = total_cost / Decimal::from(position.quantity);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve_position(&mut self, ticker: &str, result: MarketResult) -> Option<Decimal> {
|
||||
let position = self.positions.remove(ticker)?;
|
||||
|
||||
let payout = match (result, position.side) {
|
||||
(MarketResult::Yes, Side::Yes) | (MarketResult::No, Side::No) => {
|
||||
Decimal::from(position.quantity)
|
||||
}
|
||||
(MarketResult::Cancelled, _) => {
|
||||
position.avg_entry_price * Decimal::from(position.quantity)
|
||||
}
|
||||
_ => Decimal::ZERO,
|
||||
};
|
||||
|
||||
self.cash += payout;
|
||||
|
||||
let cost = position.avg_entry_price * Decimal::from(position.quantity);
|
||||
Some(payout - cost)
|
||||
}
|
||||
|
||||
pub fn close_position(&mut self, ticker: &str, exit_price: Decimal) -> Option<Decimal> {
|
||||
let position = self.positions.remove(ticker)?;
|
||||
|
||||
let effective_exit_price = match position.side {
|
||||
Side::Yes => exit_price,
|
||||
Side::No => Decimal::ONE - exit_price,
|
||||
};
|
||||
|
||||
let exit_value = effective_exit_price * Decimal::from(position.quantity);
|
||||
self.cash += exit_value;
|
||||
|
||||
let cost = position.avg_entry_price * Decimal::from(position.quantity);
|
||||
Some(exit_value - cost)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Position {
|
||||
pub ticker: String,
|
||||
pub side: Side,
|
||||
pub quantity: u64,
|
||||
pub avg_entry_price: Decimal,
|
||||
pub entry_time: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl Position {
|
||||
pub fn cost_basis(&self) -> Decimal {
|
||||
self.avg_entry_price * Decimal::from(self.quantity)
|
||||
}
|
||||
}
|
||||
|
||||
/// a completed trade recorded in the trading history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    pub price: Decimal,
    pub timestamp: DateTime<Utc>,
    // how the trade came about (open / close / resolution)
    pub trade_type: TradeType,
}
|
||||
|
||||
/// how a recorded trade came about.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TradeType {
    /// position opened
    Open,
    /// position closed before settlement
    Close,
    /// position settled at market resolution
    Resolution,
}
|
||||
|
||||
/// why a position was (or should be) exited.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ExitReason {
    /// market settled with this outcome
    Resolution(MarketResult),
    TakeProfit { pnl_pct: f64 },
    StopLoss { pnl_pct: f64 },
    TimeStop { hours_held: i64 },
    /// score moved strongly against the position
    ScoreReversal { new_score: f64 },
}
|
||||
|
||||
/// instruction to exit an open position, with the triggering reason.
#[derive(Debug, Clone)]
pub struct ExitSignal {
    pub ticker: String,
    pub reason: ExitReason,
    // current YES quote at the time the signal fired — presumably used as
    // the exit price; confirm against the executor
    pub current_price: Decimal,
}
|
||||
|
||||
/// thresholds governing when open positions are exited early.
#[derive(Debug, Clone)]
pub struct ExitConfig {
    // fractional gain that triggers a take-profit (0.20 = +20%)
    pub take_profit_pct: f64,
    // fractional loss that triggers a stop (0.15 = -15%)
    pub stop_loss_pct: f64,
    // force-close after this many hours held
    pub max_hold_hours: i64,
    // exit when the candidate's score falls below this value
    pub score_reversal_threshold: f64,
}
|
||||
|
||||
impl Default for ExitConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
take_profit_pct: 0.20,
|
||||
stop_loss_pct: 0.15,
|
||||
max_hold_hours: 72,
|
||||
score_reversal_threshold: -0.3,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExitConfig {
|
||||
pub fn conservative() -> Self {
|
||||
Self {
|
||||
take_profit_pct: 0.15,
|
||||
stop_loss_pct: 0.10,
|
||||
max_hold_hours: 48,
|
||||
score_reversal_threshold: -0.2,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn aggressive() -> Self {
|
||||
Self {
|
||||
take_profit_pct: 0.30,
|
||||
stop_loss_pct: 0.20,
|
||||
max_hold_hours: 120,
|
||||
score_reversal_threshold: -0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// an executed order, applied to the portfolio via `Portfolio::apply_fill`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fill {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    // per-contract execution price
    pub price: Decimal,
    pub timestamp: DateTime<Utc>,
}
|
||||
|
||||
/// an order request produced by the strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signal {
    pub ticker: String,
    pub side: Side,
    pub quantity: u64,
    // limit price; None presumably means execute at market — confirm with
    // the execution layer
    pub limit_price: Option<Decimal>,
    // human-readable explanation for logging / audit
    pub reason: String,
}
|
||||
|
||||
/// static metadata for one market in the historical dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketData {
    pub ticker: String,
    pub title: String,
    pub category: String,
    pub open_time: DateTime<Utc>,
    pub close_time: DateTime<Utc>,
    // settlement outcome; None while unresolved
    pub result: Option<MarketResult>,
}
|
||||
|
||||
/// one historical trade record (tick) from the dataset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeData {
    pub timestamp: DateTime<Utc>,
    pub ticker: String,
    pub price: Decimal,
    pub volume: u64,
    // side the aggressor (taker) traded on
    pub taker_side: Side,
}
|
||||
|
||||
/// parameters controlling a single backtest run.
#[derive(Debug, Clone)]
pub struct BacktestConfig {
    pub start_time: DateTime<Utc>,
    pub end_time: DateTime<Utc>,
    // simulation step between pipeline evaluations
    pub interval: chrono::Duration,
    pub initial_capital: Decimal,
    // max size of a single position (contracts, per Fill::quantity units —
    // TODO confirm against the position sizer)
    pub max_position_size: u64,
    // max simultaneous open positions
    pub max_positions: usize,
}
|
||||
Loading…
x
Reference in New Issue
Block a user