feat: initial commit with code quality refactoring
kalshi prediction market backtesting framework with: - trading pipeline (sources, filters, scorers, selectors) - position sizing with kelly criterion - multiple scoring strategies (momentum, mean reversion, etc) - random baseline for comparison refactoring includes: - extract shared resolve_closed_positions() function - reduce RandomBaseline::run() nesting with helper functions - move MarketCandidate Default impl to types.rs - add explanatory comments to complex logic
This commit is contained in:
commit
025322219c
5
.gitignore
vendored
Normal file
5
.gitignore
vendored
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
/target
|
||||||
|
/data/*.csv
|
||||||
|
/data/*.parquet
|
||||||
|
/results/*.json
|
||||||
|
Cargo.lock
|
||||||
30
Cargo.toml
Normal file
30
Cargo.toml
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
[package]
|
||||||
|
name = "kalshi-backtest"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
async-trait = "0.1"
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
csv = "1.3"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
clap = { version = "4", features = ["derive"] }
|
||||||
|
anyhow = "1"
|
||||||
|
thiserror = "1"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
uuid = { version = "1", features = ["v4"] }
|
||||||
|
rust_decimal = { version = "1", features = ["serde"] }
|
||||||
|
rust_decimal_macros = "1"
|
||||||
|
|
||||||
|
ort = { version = "2.0.0-rc.11", optional = true }
|
||||||
|
ndarray = { version = "0.16", optional = true }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = []
|
||||||
|
ml = ["ort", "ndarray"]
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tempfile = "3"
|
||||||
236
README.md
Normal file
236
README.md
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
kalshi-backtest
|
||||||
|
===
|
||||||
|
|
||||||
|
quant-level backtesting framework for kalshi prediction markets, using a candidate pipeline architecture.
|
||||||
|
|
||||||
|
|
||||||
|
features
|
||||||
|
---
|
||||||
|
|
||||||
|
- **multi-timeframe momentum** - detects divergence between short and long-term trends
|
||||||
|
- **bollinger bands mean reversion** - signals when price touches statistical extremes
|
||||||
|
- **order flow analysis** - tracks buying vs selling pressure via taker_side
|
||||||
|
- **kelly criterion position sizing** - dynamic sizing based on edge and win probability
|
||||||
|
- **exit signals** - take profit, stop loss, time stops, and score reversal triggers
|
||||||
|
- **category-aware weighting** - different strategies for politics, weather, sports, etc.
|
||||||
|
- **ensemble scoring** - combine multiple models with dynamic weighting
|
||||||
|
- **cross-market correlations** - lead-lag relationships between related markets
|
||||||
|
- **ML ensemble (optional)** - LSTM + MLP models via ONNX runtime
|
||||||
|
|
||||||
|
|
||||||
|
architecture
|
||||||
|
---
|
||||||
|
|
||||||
|
```
|
||||||
|
Historical Data (CSV)
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+------------------+
|
||||||
|
| Backtest Loop | <- simulates time progression
|
||||||
|
+------------------+
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+------------------+
|
||||||
|
| Candidate Pipeline |
|
||||||
|
+------------------+
|
||||||
|
| |
|
||||||
|
v v
|
||||||
|
Sources Filters -> Scorers -> Selector
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+------------------+
|
||||||
|
| Trade Executor | <- kelly sizing, exit signals
|
||||||
|
+------------------+
|
||||||
|
|
|
||||||
|
v
|
||||||
|
+------------------+
|
||||||
|
| P&L Tracker | <- tracks positions, returns
|
||||||
|
+------------------+
|
||||||
|
|
|
||||||
|
v
|
||||||
|
Performance Metrics
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
data format
|
||||||
|
---
|
||||||
|
|
||||||
|
fetch data from kalshi API using the included script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/fetch_kalshi_data.py
|
||||||
|
```
|
||||||
|
|
||||||
|
or download from https://www.deltabase.tech/
|
||||||
|
|
||||||
|
**markets.csv**:
|
||||||
|
```csv
|
||||||
|
ticker,title,category,open_time,close_time,result,status,yes_bid,yes_ask,volume,open_interest
|
||||||
|
PRES-2024-DEM,Will Democrats win?,politics,2024-01-01 00:00:00,2024-11-06 00:00:00,no,finalized,45,47,10000,5000
|
||||||
|
```
|
||||||
|
|
||||||
|
**trades.csv**:
|
||||||
|
```csv
|
||||||
|
timestamp,ticker,price,volume,taker_side
|
||||||
|
2024-01-05 12:00:00,PRES-2024-DEM,45,100,yes
|
||||||
|
2024-01-05 13:00:00,PRES-2024-DEM,46,50,no
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
usage
|
||||||
|
---
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# build
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# run backtest with quant features
|
||||||
|
cargo run --release -- run \
|
||||||
|
--data-dir data \
|
||||||
|
--start 2024-01-01 \
|
||||||
|
--end 2024-06-01 \
|
||||||
|
--capital 10000 \
|
||||||
|
--max-position 500 \
|
||||||
|
--max-positions 10 \
|
||||||
|
--kelly-fraction 0.25 \
|
||||||
|
--max-position-pct 0.25 \
|
||||||
|
--take-profit 0.20 \
|
||||||
|
--stop-loss 0.15 \
|
||||||
|
--max-hold-hours 72 \
|
||||||
|
--compare-random
|
||||||
|
|
||||||
|
# view results
|
||||||
|
cargo run --release -- summary --results-file results/backtest_result.json
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
cli options
|
||||||
|
---
|
||||||
|
|
||||||
|
| option | default | description |
|
||||||
|
|--------|---------|-------------|
|
||||||
|
| --data-dir | data | directory with markets.csv and trades.csv |
|
||||||
|
| --start | required | backtest start date |
|
||||||
|
| --end | required | backtest end date |
|
||||||
|
| --capital | 10000 | initial capital |
|
||||||
|
| --max-position | 100 | max shares per position |
|
||||||
|
| --max-positions | 5 | max concurrent positions |
|
||||||
|
| --kelly-fraction | 0.25 | fraction of kelly criterion (0.1=conservative, 1.0=full) |
|
||||||
|
| --max-position-pct | 0.25 | max % of capital per position |
|
||||||
|
| --take-profit | 0.20 | take profit threshold (20% gain) |
|
||||||
|
| --stop-loss | 0.15 | stop loss threshold (15% loss) |
|
||||||
|
| --max-hold-hours | 72 | time stop in hours |
|
||||||
|
| --compare-random | false | compare vs random baseline |
|
||||||
|
|
||||||
|
|
||||||
|
scorers
|
||||||
|
---
|
||||||
|
|
||||||
|
**basic scorers**:
|
||||||
|
- `MomentumScorer` - price change over lookback period
|
||||||
|
- `MeanReversionScorer` - deviation from historical mean
|
||||||
|
- `VolumeScorer` - unusual volume detection
|
||||||
|
- `TimeDecayScorer` - prefer markets with more time to close
|
||||||
|
|
||||||
|
**quant scorers**:
|
||||||
|
- `MultiTimeframeMomentumScorer` - analyzes 1h, 4h, 12h, 24h windows, detects divergence
|
||||||
|
- `BollingerMeanReversionScorer` - triggers at upper/lower band touches (2 std)
|
||||||
|
- `OrderFlowScorer` - buy/sell imbalance from taker_side
|
||||||
|
- `CategoryWeightedScorer` - different weights per category
|
||||||
|
- `EnsembleScorer` - combines models with dynamic weights
|
||||||
|
- `CorrelationScorer` - cross-market lead-lag signals
|
||||||
|
|
||||||
|
**ml scorers** (requires `ml` feature):
|
||||||
|
- `MLEnsembleScorer` - LSTM + MLP via ONNX
|
||||||
|
|
||||||
|
|
||||||
|
position sizing
|
||||||
|
---
|
||||||
|
|
||||||
|
uses kelly criterion with safety multiplier:
|
||||||
|
|
||||||
|
```
|
||||||
|
kelly = (odds * win_prob - (1 - win_prob)) / odds
|
||||||
|
safe_kelly = kelly * kelly_fraction
|
||||||
|
position = min(bankroll * safe_kelly, max_position_pct * bankroll)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
exit signals
|
||||||
|
---
|
||||||
|
|
||||||
|
positions can exit via:
|
||||||
|
1. **resolution** - market resolves yes/no
|
||||||
|
2. **take profit** - pnl exceeds threshold
|
||||||
|
3. **stop loss** - pnl below threshold
|
||||||
|
4. **time stop** - held too long (capital rotation)
|
||||||
|
5. **score reversal** - strategy flips bearish
|
||||||
|
|
||||||
|
|
||||||
|
ml training (optional)
|
||||||
|
---
|
||||||
|
|
||||||
|
train ML models using pytorch, then export to ONNX:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# install dependencies
|
||||||
|
pip install torch pandas numpy
|
||||||
|
|
||||||
|
# train models
|
||||||
|
python scripts/train_ml_models.py \
|
||||||
|
--data data/trades.csv \
|
||||||
|
--markets data/markets.csv \
|
||||||
|
--output models/ \
|
||||||
|
--epochs 50
|
||||||
|
|
||||||
|
# enable ml feature
|
||||||
|
cargo build --release --features ml
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
metrics
|
||||||
|
---
|
||||||
|
|
||||||
|
- total return ($ and %)
|
||||||
|
- sharpe ratio (annualized)
|
||||||
|
- max drawdown
|
||||||
|
- win rate
|
||||||
|
- average trade P&L
|
||||||
|
- average hold time
|
||||||
|
- trades per day
|
||||||
|
- return by category
|
||||||
|
|
||||||
|
|
||||||
|
extending
|
||||||
|
---
|
||||||
|
|
||||||
|
add custom scorers by implementing the `Scorer` trait:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
pub struct MyScorer;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MyScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MyScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
// compute scores...
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("my_score") {
|
||||||
|
candidate.scores.insert("my_score".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
then add to the pipeline in `backtest.rs`.
|
||||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
1
data/fetch_state.json
Normal file
1
data/fetch_state.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"markets_cursor": "CgsI-rDDywYQkKOiMRI5S1hNVkVTUE9SVFNNVUxUSUdBTUVFWFRFTkRFRC1TMjAyNTBDMDMzMDBBRkYyLTkxNTVFNjFERTk3", "markets_count": 25000, "trades_cursor": null, "trades_count": 0, "markets_done": false, "trades_done": false}
|
||||||
254
scripts/fetch_kalshi_data.py
Executable file
254
scripts/fetch_kalshi_data.py
Executable file
@ -0,0 +1,254 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fetch historical trade and market data from Kalshi's public API.
|
||||||
|
No authentication required for public endpoints.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Incremental saves (writes batches to disk)
|
||||||
|
- Resume capability (tracks cursor position)
|
||||||
|
- Retry logic with exponential backoff
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import csv
|
||||||
|
import time
|
||||||
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
BASE_URL = "https://api.elections.kalshi.com/trade-api/v2"
|
||||||
|
STATE_FILE = "fetch_state.json"
|
||||||
|
|
||||||
|
def fetch_json(url: str, max_retries: int = 5) -> dict:
|
||||||
|
"""Fetch JSON from URL with retries and exponential backoff."""
|
||||||
|
req = urllib.request.Request(url, headers={"Accept": "application/json"})
|
||||||
|
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=30) as resp:
|
||||||
|
return json.loads(resp.read().decode())
|
||||||
|
except (urllib.error.HTTPError, urllib.error.URLError) as e:
|
||||||
|
wait = 2 ** attempt
|
||||||
|
print(f" attempt {attempt + 1}/{max_retries} failed: {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f" retrying in {wait}s...")
|
||||||
|
time.sleep(wait)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
wait = 2 ** attempt
|
||||||
|
print(f" unexpected error: {e}")
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
print(f" retrying in {wait}s...")
|
||||||
|
time.sleep(wait)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_state(output_dir: Path) -> dict:
|
||||||
|
"""Load saved state for resuming."""
|
||||||
|
state_path = output_dir / STATE_FILE
|
||||||
|
if state_path.exists():
|
||||||
|
with open(state_path) as f:
|
||||||
|
return json.load(f)
|
||||||
|
return {"markets_cursor": None, "markets_count": 0,
|
||||||
|
"trades_cursor": None, "trades_count": 0,
|
||||||
|
"markets_done": False, "trades_done": False}
|
||||||
|
|
||||||
|
def save_state(output_dir: Path, state: dict):
|
||||||
|
"""Save state for resuming."""
|
||||||
|
state_path = output_dir / STATE_FILE
|
||||||
|
with open(state_path, "w") as f:
|
||||||
|
json.dump(state, f)
|
||||||
|
|
||||||
|
def append_markets_csv(markets: list, output_path: Path, write_header: bool):
|
||||||
|
"""Append markets to CSV."""
|
||||||
|
mode = "w" if write_header else "a"
|
||||||
|
with open(output_path, mode, newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
if write_header:
|
||||||
|
writer.writerow(["ticker", "title", "category", "open_time",
|
||||||
|
"close_time", "result", "status", "yes_bid",
|
||||||
|
"yes_ask", "volume", "open_interest"])
|
||||||
|
|
||||||
|
for m in markets:
|
||||||
|
result = ""
|
||||||
|
if m.get("result") == "yes":
|
||||||
|
result = "yes"
|
||||||
|
elif m.get("result") == "no":
|
||||||
|
result = "no"
|
||||||
|
elif m.get("status") == "finalized" and m.get("result"):
|
||||||
|
result = m.get("result")
|
||||||
|
|
||||||
|
writer.writerow([
|
||||||
|
m.get("ticker", ""),
|
||||||
|
m.get("title", ""),
|
||||||
|
m.get("category", ""),
|
||||||
|
m.get("open_time", ""),
|
||||||
|
m.get("close_time", m.get("expiration_time", "")),
|
||||||
|
result,
|
||||||
|
m.get("status", ""),
|
||||||
|
m.get("yes_bid", ""),
|
||||||
|
m.get("yes_ask", ""),
|
||||||
|
m.get("volume", ""),
|
||||||
|
m.get("open_interest", ""),
|
||||||
|
])
|
||||||
|
|
||||||
|
def append_trades_csv(trades: list, output_path: Path, write_header: bool):
|
||||||
|
"""Append trades to CSV."""
|
||||||
|
mode = "w" if write_header else "a"
|
||||||
|
with open(output_path, mode, newline="") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
if write_header:
|
||||||
|
writer.writerow(["timestamp", "ticker", "price", "volume", "taker_side"])
|
||||||
|
|
||||||
|
for t in trades:
|
||||||
|
price = t.get("yes_price", t.get("price", 50))
|
||||||
|
taker_side = t.get("taker_side", "")
|
||||||
|
if not taker_side:
|
||||||
|
taker_side = "yes" if t.get("is_taker_side_yes", True) else "no"
|
||||||
|
|
||||||
|
writer.writerow([
|
||||||
|
t.get("created_time", t.get("ts", "")),
|
||||||
|
t.get("ticker", t.get("market_ticker", "")),
|
||||||
|
price,
|
||||||
|
t.get("count", t.get("volume", 1)),
|
||||||
|
taker_side,
|
||||||
|
])
|
||||||
|
|
||||||
|
def fetch_markets_incremental(output_dir: Path, state: dict) -> int:
|
||||||
|
"""Fetch markets incrementally with state tracking."""
|
||||||
|
output_path = output_dir / "markets.csv"
|
||||||
|
cursor = state["markets_cursor"]
|
||||||
|
total = state["markets_count"]
|
||||||
|
write_header = total == 0
|
||||||
|
|
||||||
|
print(f"Resuming from {total} markets...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
url = f"{BASE_URL}/markets?limit=1000"
|
||||||
|
if cursor:
|
||||||
|
url += f"&cursor={cursor}"
|
||||||
|
|
||||||
|
print(f"Fetching markets... ({total:,} so far)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = fetch_json(url)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error fetching markets: {e}")
|
||||||
|
print(f"Progress saved. Run again to resume from {total:,} markets.")
|
||||||
|
return total
|
||||||
|
|
||||||
|
batch = data.get("markets", [])
|
||||||
|
if batch:
|
||||||
|
append_markets_csv(batch, output_path, write_header)
|
||||||
|
write_header = False
|
||||||
|
total += len(batch)
|
||||||
|
|
||||||
|
cursor = data.get("cursor")
|
||||||
|
state["markets_cursor"] = cursor
|
||||||
|
state["markets_count"] = total
|
||||||
|
save_state(output_dir, state)
|
||||||
|
|
||||||
|
if not cursor:
|
||||||
|
state["markets_done"] = True
|
||||||
|
save_state(output_dir, state)
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
return total
|
||||||
|
|
||||||
|
def fetch_trades_incremental(output_dir: Path, state: dict, limit: int) -> int:
|
||||||
|
"""Fetch trades incrementally with state tracking."""
|
||||||
|
output_path = output_dir / "trades.csv"
|
||||||
|
cursor = state["trades_cursor"]
|
||||||
|
total = state["trades_count"]
|
||||||
|
write_header = total == 0
|
||||||
|
|
||||||
|
print(f"Resuming from {total} trades...")
|
||||||
|
|
||||||
|
while total < limit:
|
||||||
|
url = f"{BASE_URL}/markets/trades?limit=1000"
|
||||||
|
if cursor:
|
||||||
|
url += f"&cursor={cursor}"
|
||||||
|
|
||||||
|
print(f"Fetching trades... ({total:,}/{limit:,})")
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = fetch_json(url)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error fetching trades: {e}")
|
||||||
|
print(f"Progress saved. Run again to resume from {total:,} trades.")
|
||||||
|
return total
|
||||||
|
|
||||||
|
batch = data.get("trades", [])
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
|
||||||
|
append_trades_csv(batch, output_path, write_header)
|
||||||
|
write_header = False
|
||||||
|
total += len(batch)
|
||||||
|
|
||||||
|
cursor = data.get("cursor")
|
||||||
|
state["trades_cursor"] = cursor
|
||||||
|
state["trades_count"] = total
|
||||||
|
save_state(output_dir, state)
|
||||||
|
|
||||||
|
if not cursor:
|
||||||
|
state["trades_done"] = True
|
||||||
|
save_state(output_dir, state)
|
||||||
|
break
|
||||||
|
|
||||||
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
return total
|
||||||
|
|
||||||
|
def main():
|
||||||
|
output_dir = Path("/mnt/work/kalshi-data")
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
print("=" * 50)
|
||||||
|
print("Kalshi Data Fetcher (with resume)")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
state = load_state(output_dir)
|
||||||
|
|
||||||
|
# fetch markets
|
||||||
|
if not state["markets_done"]:
|
||||||
|
print("\n[1/2] Fetching markets...")
|
||||||
|
markets_count = fetch_markets_incremental(output_dir, state)
|
||||||
|
if state["markets_done"]:
|
||||||
|
print(f"Markets complete: {markets_count:,}")
|
||||||
|
else:
|
||||||
|
print(f"Markets paused at: {markets_count:,}")
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
print(f"\n[1/2] Markets already complete: {state['markets_count']:,}")
|
||||||
|
|
||||||
|
# fetch trades
|
||||||
|
if not state["trades_done"]:
|
||||||
|
print("\n[2/2] Fetching trades...")
|
||||||
|
trades_count = fetch_trades_incremental(output_dir, state, limit=1000000)
|
||||||
|
if state["trades_done"]:
|
||||||
|
print(f"Trades complete: {trades_count:,}")
|
||||||
|
else:
|
||||||
|
print(f"Trades paused at: {trades_count:,}")
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
print(f"\n[2/2] Trades already complete: {state['trades_count']:,}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Done!")
|
||||||
|
print(f"Markets: {state['markets_count']:,}")
|
||||||
|
print(f"Trades: {state['trades_count']:,}")
|
||||||
|
print(f"Output: {output_dir}")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
# clear state for next run
|
||||||
|
(output_dir / STATE_FILE).unlink(missing_ok=True)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
||||||
280
scripts/train_ml_models.py
Normal file
280
scripts/train_ml_models.py
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Train ML models for the kalshi backtest framework.
|
||||||
|
|
||||||
|
Models:
|
||||||
|
- LSTM: learns patterns from price history sequences
|
||||||
|
- MLP: learns optimal combination of hand-crafted features
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/train_ml_models.py --data data/trades.csv --output models/
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
HAS_TORCH = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_TORCH = False
|
||||||
|
print("warning: pytorch not installed. run: pip install torch")
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description="Train ML models for kalshi backtest")
|
||||||
|
parser.add_argument("--data", type=Path, default=Path("data/trades.csv"))
|
||||||
|
parser.add_argument("--markets", type=Path, default=Path("data/markets.csv"))
|
||||||
|
parser.add_argument("--output", type=Path, default=Path("models"))
|
||||||
|
parser.add_argument("--epochs", type=int, default=50)
|
||||||
|
parser.add_argument("--batch-size", type=int, default=64)
|
||||||
|
parser.add_argument("--seq-len", type=int, default=24)
|
||||||
|
parser.add_argument("--train-split", type=float, default=0.8)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
class LSTMPredictor(nn.Module):
|
||||||
|
def __init__(self, input_size=1, hidden_size=128, num_layers=2, dropout=0.2):
|
||||||
|
super().__init__()
|
||||||
|
self.lstm = nn.LSTM(
|
||||||
|
input_size=input_size,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_layers=num_layers,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
self.fc = nn.Linear(hidden_size, 1)
|
||||||
|
self.tanh = nn.Tanh()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
lstm_out, _ = self.lstm(x)
|
||||||
|
last_output = lstm_out[:, -1, :]
|
||||||
|
out = self.fc(last_output)
|
||||||
|
return self.tanh(out)
|
||||||
|
|
||||||
|
class MLPPredictor(nn.Module):
|
||||||
|
def __init__(self, input_size=7, hidden_sizes=[64, 32]):
|
||||||
|
super().__init__()
|
||||||
|
layers = []
|
||||||
|
prev_size = input_size
|
||||||
|
for h in hidden_sizes:
|
||||||
|
layers.append(nn.Linear(prev_size, h))
|
||||||
|
layers.append(nn.ReLU())
|
||||||
|
layers.append(nn.Dropout(0.2))
|
||||||
|
prev_size = h
|
||||||
|
layers.append(nn.Linear(prev_size, 1))
|
||||||
|
layers.append(nn.Tanh())
|
||||||
|
self.net = nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.net(x)
|
||||||
|
|
||||||
|
def load_data(trades_path: Path, markets_path: Path, seq_len: int):
|
||||||
|
print(f"loading trades from {trades_path}...")
|
||||||
|
trades = pd.read_csv(trades_path)
|
||||||
|
trades["timestamp"] = pd.to_datetime(trades["timestamp"])
|
||||||
|
trades = trades.sort_values(["ticker", "timestamp"])
|
||||||
|
|
||||||
|
print(f"loading markets from {markets_path}...")
|
||||||
|
markets = pd.read_csv(markets_path)
|
||||||
|
markets["close_time"] = pd.to_datetime(markets["close_time"])
|
||||||
|
|
||||||
|
result_map = dict(zip(markets["ticker"], markets["result"]))
|
||||||
|
|
||||||
|
sequences = []
|
||||||
|
features = []
|
||||||
|
labels = []
|
||||||
|
|
||||||
|
for ticker, group in trades.groupby("ticker"):
|
||||||
|
result = result_map.get(ticker)
|
||||||
|
if result not in ["yes", "no"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
label = 1.0 if result == "yes" else -1.0
|
||||||
|
|
||||||
|
prices = group["price"].values / 100.0
|
||||||
|
volumes = group["volume"].values
|
||||||
|
taker_sides = (group["taker_side"] == "yes").astype(float).values
|
||||||
|
|
||||||
|
if len(prices) < seq_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for i in range(seq_len, len(prices)):
|
||||||
|
seq = prices[i - seq_len : i]
|
||||||
|
log_returns = np.diff(np.log(np.clip(seq, 1e-6, 1.0)))
|
||||||
|
|
||||||
|
if len(log_returns) == seq_len - 1:
|
||||||
|
log_returns = np.pad(log_returns, (1, 0), mode="constant")
|
||||||
|
|
||||||
|
sequences.append(log_returns)
|
||||||
|
|
||||||
|
curr_price = prices[i - 1]
|
||||||
|
momentum = prices[i - 1] - prices[i - seq_len] if len(prices) > seq_len else 0
|
||||||
|
mean_price = np.mean(prices[i - seq_len : i])
|
||||||
|
mean_reversion = mean_price - curr_price
|
||||||
|
vol_sum = np.sum(volumes[i - seq_len : i])
|
||||||
|
buy_vol = np.sum(volumes[i - seq_len : i] * taker_sides[i - seq_len : i])
|
||||||
|
sell_vol = vol_sum - buy_vol
|
||||||
|
order_flow = (buy_vol - sell_vol) / max(vol_sum, 1)
|
||||||
|
|
||||||
|
feat = [
|
||||||
|
momentum,
|
||||||
|
mean_reversion,
|
||||||
|
np.log1p(vol_sum),
|
||||||
|
order_flow,
|
||||||
|
curr_price,
|
||||||
|
np.std(log_returns) if len(log_returns) > 1 else 0,
|
||||||
|
len(group) / 1000.0,
|
||||||
|
]
|
||||||
|
features.append(feat)
|
||||||
|
labels.append(label)
|
||||||
|
|
||||||
|
print(f"created {len(sequences)} training samples")
|
||||||
|
return np.array(sequences), np.array(features), np.array(labels)
|
||||||
|
|
||||||
|
def train_lstm(sequences, labels, args):
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Training LSTM")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
n = len(sequences)
|
||||||
|
split = int(n * args.train_split)
|
||||||
|
|
||||||
|
X_train = torch.tensor(sequences[:split], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
X_test = torch.tensor(sequences[split:], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
|
||||||
|
train_dataset = TensorDataset(X_train, y_train)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
model = LSTMPredictor(input_size=1, hidden_size=128, num_layers=2)
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0
|
||||||
|
for X_batch, y_batch in train_loader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(X_batch)
|
||||||
|
loss = criterion(output, y_batch)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
if (epoch + 1) % 10 == 0:
|
||||||
|
model.set_mode_to_inference()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_pred = model(X_train)
|
||||||
|
test_pred = model(X_test)
|
||||||
|
train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
|
||||||
|
test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
|
||||||
|
print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def train_mlp(features, labels, args):
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Training MLP")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
|
||||||
|
|
||||||
|
n = len(features)
|
||||||
|
split = int(n * args.train_split)
|
||||||
|
|
||||||
|
X_train = torch.tensor(features[:split], dtype=torch.float32)
|
||||||
|
y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
X_test = torch.tensor(features[split:], dtype=torch.float32)
|
||||||
|
y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)
|
||||||
|
|
||||||
|
train_dataset = TensorDataset(X_train, y_train)
|
||||||
|
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
|
||||||
|
|
||||||
|
model = MLPPredictor(input_size=features.shape[1])
|
||||||
|
criterion = nn.MSELoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
model.train()
|
||||||
|
total_loss = 0
|
||||||
|
for X_batch, y_batch in train_loader:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
output = model(X_batch)
|
||||||
|
loss = criterion(output, y_batch)
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
if (epoch + 1) % 10 == 0:
|
||||||
|
model.set_mode_to_inference()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_pred = model(X_train)
|
||||||
|
test_pred = model(X_test)
|
||||||
|
train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
|
||||||
|
test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
|
||||||
|
print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def export_onnx(model, output_path: Path, input_shape, input_name="input", output_name="output"):
|
||||||
|
model.set_mode_to_inference()
|
||||||
|
dummy_input = torch.randn(*input_shape)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
dummy_input,
|
||||||
|
output_path,
|
||||||
|
input_names=[input_name],
|
||||||
|
output_names=[output_name],
|
||||||
|
dynamic_axes={
|
||||||
|
input_name: {0: "batch_size"},
|
||||||
|
output_name: {0: "batch_size"},
|
||||||
|
},
|
||||||
|
opset_version=14,
|
||||||
|
)
|
||||||
|
print(f"exported to {output_path}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
if not HAS_TORCH:
|
||||||
|
print("error: pytorch required for training. install with: pip install torch")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if not args.data.exists():
|
||||||
|
print(f"error: data file not found: {args.data}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if not args.markets.exists():
|
||||||
|
print(f"error: markets file not found: {args.markets}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
args.output.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
sequences, features, labels = load_data(args.data, args.markets, args.seq_len)
|
||||||
|
|
||||||
|
if len(sequences) < 100:
|
||||||
|
print(f"error: not enough training data ({len(sequences)} samples)")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
lstm_model = train_lstm(sequences, labels, args)
|
||||||
|
export_onnx(lstm_model, args.output / "lstm.onnx", (1, args.seq_len, 1))
|
||||||
|
|
||||||
|
mlp_model = train_mlp(features, labels, args)
|
||||||
|
export_onnx(mlp_model, args.output / "mlp.onnx", (1, features.shape[1]))
|
||||||
|
|
||||||
|
print("\n" + "=" * 50)
|
||||||
|
print("Training complete!")
|
||||||
|
print(f"Models saved to: {args.output}")
|
||||||
|
print("=" * 50)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
||||||
445
src/backtest.rs
Normal file
445
src/backtest.rs
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
use crate::data::HistoricalData;
|
||||||
|
use crate::execution::{Executor, PositionSizingConfig};
|
||||||
|
use crate::metrics::{BacktestResult, MetricsCollector};
|
||||||
|
use crate::pipeline::{
|
||||||
|
AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, Filter,
|
||||||
|
HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer,
|
||||||
|
MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, Selector, Source, TimeDecayScorer,
|
||||||
|
TimeToCloseFilter, TopKSelector, TradingPipeline, VolumeScorer,
|
||||||
|
};
|
||||||
|
use crate::types::{
|
||||||
|
BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType,
|
||||||
|
TradingContext,
|
||||||
|
};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
/// resolves any positions in markets that have closed
|
||||||
|
/// returns list of (ticker, result, pnl) for logging purposes
|
||||||
|
fn resolve_closed_positions(
|
||||||
|
portfolio: &mut Portfolio,
|
||||||
|
data: &HistoricalData,
|
||||||
|
resolved: &mut HashSet<String>,
|
||||||
|
at: DateTime<Utc>,
|
||||||
|
history: &mut Vec<Trade>,
|
||||||
|
metrics: &mut MetricsCollector,
|
||||||
|
) -> Vec<(String, MarketResult, Option<Decimal>)> {
|
||||||
|
let tickers: Vec<String> = portfolio.positions.keys().cloned().collect();
|
||||||
|
let mut resolutions = Vec::new();
|
||||||
|
|
||||||
|
for ticker in tickers {
|
||||||
|
if resolved.contains(&ticker) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(result) = data.get_resolution_at(&ticker, at) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
resolved.insert(ticker.clone());
|
||||||
|
let Some(pos) = portfolio.positions.get(&ticker).cloned() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let pnl = portfolio.resolve_position(&ticker, result);
|
||||||
|
|
||||||
|
let exit_price = match result {
|
||||||
|
MarketResult::Yes => match pos.side {
|
||||||
|
Side::Yes => Decimal::ONE,
|
||||||
|
Side::No => Decimal::ZERO,
|
||||||
|
},
|
||||||
|
MarketResult::No => match pos.side {
|
||||||
|
Side::Yes => Decimal::ZERO,
|
||||||
|
Side::No => Decimal::ONE,
|
||||||
|
},
|
||||||
|
MarketResult::Cancelled => pos.avg_entry_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let category = data
|
||||||
|
.markets
|
||||||
|
.get(&ticker)
|
||||||
|
.map(|m| m.category.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let trade = Trade {
|
||||||
|
ticker: ticker.clone(),
|
||||||
|
side: pos.side,
|
||||||
|
quantity: pos.quantity,
|
||||||
|
price: exit_price,
|
||||||
|
timestamp: at,
|
||||||
|
trade_type: TradeType::Resolution,
|
||||||
|
};
|
||||||
|
|
||||||
|
history.push(trade.clone());
|
||||||
|
metrics.record_trade(&trade, &category);
|
||||||
|
resolutions.push((ticker, result, pnl));
|
||||||
|
}
|
||||||
|
|
||||||
|
resolutions
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Backtester {
|
||||||
|
config: BacktestConfig,
|
||||||
|
data: Arc<HistoricalData>,
|
||||||
|
pipeline: TradingPipeline,
|
||||||
|
executor: Executor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Backtester {
|
||||||
|
pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
|
||||||
|
let pipeline = Self::build_default_pipeline(data.clone(), &config);
|
||||||
|
let executor = Executor::new(data.clone(), 10, config.max_position_size);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
data,
|
||||||
|
pipeline,
|
||||||
|
executor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_configs(
|
||||||
|
config: BacktestConfig,
|
||||||
|
data: Arc<HistoricalData>,
|
||||||
|
sizing_config: PositionSizingConfig,
|
||||||
|
exit_config: ExitConfig,
|
||||||
|
) -> Self {
|
||||||
|
let pipeline = Self::build_default_pipeline(data.clone(), &config);
|
||||||
|
let executor = Executor::new(data.clone(), 10, config.max_position_size)
|
||||||
|
.with_sizing_config(sizing_config)
|
||||||
|
.with_exit_config(exit_config);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
data,
|
||||||
|
pipeline,
|
||||||
|
executor,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self {
|
||||||
|
self.pipeline = pipeline;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_default_pipeline(data: Arc<HistoricalData>, config: &BacktestConfig) -> TradingPipeline {
|
||||||
|
let sources: Vec<Box<dyn Source>> = vec![
|
||||||
|
Box::new(HistoricalMarketSource::new(data, 24)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let filters: Vec<Box<dyn Filter>> = vec![
|
||||||
|
Box::new(LiquidityFilter::new(100)),
|
||||||
|
Box::new(TimeToCloseFilter::new(2, Some(720))),
|
||||||
|
Box::new(AlreadyPositionedFilter::new(config.max_position_size)),
|
||||||
|
];
|
||||||
|
|
||||||
|
let scorers: Vec<Box<dyn Scorer>> = vec![
|
||||||
|
Box::new(MomentumScorer::new(6)),
|
||||||
|
Box::new(MultiTimeframeMomentumScorer::default_windows()),
|
||||||
|
Box::new(MeanReversionScorer::new(24)),
|
||||||
|
Box::new(BollingerMeanReversionScorer::default_config()),
|
||||||
|
Box::new(VolumeScorer::new(6)),
|
||||||
|
Box::new(OrderFlowScorer::new()),
|
||||||
|
Box::new(TimeDecayScorer::new()),
|
||||||
|
Box::new(CategoryWeightedScorer::with_defaults()),
|
||||||
|
];
|
||||||
|
|
||||||
|
let selector: Box<dyn Selector> = Box::new(TopKSelector::new(config.max_positions));
|
||||||
|
|
||||||
|
TradingPipeline::new(sources, filters, scorers, selector, config.max_positions)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(&self) -> BacktestResult {
|
||||||
|
let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
|
||||||
|
let mut metrics = MetricsCollector::new(self.config.initial_capital);
|
||||||
|
let mut resolved_markets: HashSet<String> = HashSet::new();
|
||||||
|
|
||||||
|
let mut current_time = self.config.start_time;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
start = %self.config.start_time,
|
||||||
|
end = %self.config.end_time,
|
||||||
|
interval_hours = self.config.interval.num_hours(),
|
||||||
|
"starting backtest"
|
||||||
|
);
|
||||||
|
|
||||||
|
while current_time < self.config.end_time {
|
||||||
|
context.timestamp = current_time;
|
||||||
|
context.request_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
let resolutions = resolve_closed_positions(
|
||||||
|
&mut context.portfolio,
|
||||||
|
&self.data,
|
||||||
|
&mut resolved_markets,
|
||||||
|
current_time,
|
||||||
|
&mut context.trading_history,
|
||||||
|
&mut metrics,
|
||||||
|
);
|
||||||
|
for (ticker, result, pnl) in resolutions {
|
||||||
|
info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = self.pipeline.execute(context.clone()).await;
|
||||||
|
|
||||||
|
let candidate_scores: std::collections::HashMap<String, f64> = result
|
||||||
|
.selected_candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| (c.ticker.clone(), c.final_score))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let exit_signals = self.executor.generate_exit_signals(&context, &candidate_scores);
|
||||||
|
for exit in exit_signals {
|
||||||
|
if let Some(position) = context.portfolio.positions.get(&exit.ticker).cloned() {
|
||||||
|
let pnl = context.portfolio.close_position(&exit.ticker, exit.current_price);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
ticker = %exit.ticker,
|
||||||
|
reason = ?exit.reason,
|
||||||
|
pnl = ?pnl,
|
||||||
|
"exit triggered"
|
||||||
|
);
|
||||||
|
|
||||||
|
let category = self
|
||||||
|
.data
|
||||||
|
.markets
|
||||||
|
.get(&exit.ticker)
|
||||||
|
.map(|m| m.category.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let exit_price = match position.side {
|
||||||
|
crate::types::Side::Yes => exit.current_price,
|
||||||
|
crate::types::Side::No => Decimal::ONE - exit.current_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let trade = Trade {
|
||||||
|
ticker: exit.ticker.clone(),
|
||||||
|
side: position.side,
|
||||||
|
quantity: position.quantity,
|
||||||
|
price: exit_price,
|
||||||
|
timestamp: current_time,
|
||||||
|
trade_type: TradeType::Close,
|
||||||
|
};
|
||||||
|
|
||||||
|
context.trading_history.push(trade.clone());
|
||||||
|
metrics.record_trade(&trade, &category);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let signals = self.executor.generate_signals(&result.selected_candidates, &context);
|
||||||
|
|
||||||
|
for signal in signals {
|
||||||
|
if let Some(fill) = self.executor.execute_signal(&signal, &context) {
|
||||||
|
info!(
|
||||||
|
ticker = %fill.ticker,
|
||||||
|
side = ?fill.side,
|
||||||
|
quantity = fill.quantity,
|
||||||
|
price = %fill.price,
|
||||||
|
"executed trade"
|
||||||
|
);
|
||||||
|
|
||||||
|
context.portfolio.apply_fill(&fill);
|
||||||
|
|
||||||
|
let category = self
|
||||||
|
.data
|
||||||
|
.markets
|
||||||
|
.get(&fill.ticker)
|
||||||
|
.map(|m| m.category.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
let trade = Trade {
|
||||||
|
ticker: fill.ticker.clone(),
|
||||||
|
side: fill.side,
|
||||||
|
quantity: fill.quantity,
|
||||||
|
price: fill.price,
|
||||||
|
timestamp: fill.timestamp,
|
||||||
|
trade_type: TradeType::Open,
|
||||||
|
};
|
||||||
|
|
||||||
|
context.trading_history.push(trade.clone());
|
||||||
|
metrics.record_trade(&trade, &category);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let market_prices = self.get_current_prices(current_time);
|
||||||
|
metrics.record(current_time, &context.portfolio, &market_prices);
|
||||||
|
|
||||||
|
current_time = current_time + self.config.interval;
|
||||||
|
}
|
||||||
|
|
||||||
|
let resolutions = resolve_closed_positions(
|
||||||
|
&mut context.portfolio,
|
||||||
|
&self.data,
|
||||||
|
&mut resolved_markets,
|
||||||
|
self.config.end_time,
|
||||||
|
&mut context.trading_history,
|
||||||
|
&mut metrics,
|
||||||
|
);
|
||||||
|
for (ticker, result, pnl) in resolutions {
|
||||||
|
info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(
|
||||||
|
trades = context.trading_history.len(),
|
||||||
|
positions = context.portfolio.positions.len(),
|
||||||
|
cash = %context.portfolio.cash,
|
||||||
|
"backtest complete"
|
||||||
|
);
|
||||||
|
|
||||||
|
metrics.finalize()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
|
||||||
|
self.data
|
||||||
|
.markets
|
||||||
|
.keys()
|
||||||
|
.filter_map(|ticker| {
|
||||||
|
self.data
|
||||||
|
.get_current_price(ticker, at)
|
||||||
|
.map(|p| (ticker.clone(), p))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RandomBaseline {
|
||||||
|
config: BacktestConfig,
|
||||||
|
data: Arc<HistoricalData>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RandomBaseline {
|
||||||
|
pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
|
||||||
|
Self { config, data }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(&self) -> BacktestResult {
|
||||||
|
let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
|
||||||
|
let mut metrics = MetricsCollector::new(self.config.initial_capital);
|
||||||
|
let mut resolved_markets: HashSet<String> = HashSet::new();
|
||||||
|
let mut rng_state: u64 = 42;
|
||||||
|
|
||||||
|
let mut current_time = self.config.start_time;
|
||||||
|
|
||||||
|
while current_time < self.config.end_time {
|
||||||
|
context.timestamp = current_time;
|
||||||
|
|
||||||
|
resolve_closed_positions(
|
||||||
|
&mut context.portfolio,
|
||||||
|
&self.data,
|
||||||
|
&mut resolved_markets,
|
||||||
|
current_time,
|
||||||
|
&mut context.trading_history,
|
||||||
|
&mut metrics,
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(fill) = self.try_random_trade(&context, current_time, &mut rng_state) {
|
||||||
|
let category = self
|
||||||
|
.data
|
||||||
|
.markets
|
||||||
|
.get(&fill.ticker)
|
||||||
|
.map(|m| m.category.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
context.portfolio.apply_fill(&fill);
|
||||||
|
|
||||||
|
let trade = Trade {
|
||||||
|
ticker: fill.ticker.clone(),
|
||||||
|
side: fill.side,
|
||||||
|
quantity: fill.quantity,
|
||||||
|
price: fill.price,
|
||||||
|
timestamp: current_time,
|
||||||
|
trade_type: TradeType::Open,
|
||||||
|
};
|
||||||
|
|
||||||
|
context.trading_history.push(trade.clone());
|
||||||
|
metrics.record_trade(&trade, &category);
|
||||||
|
}
|
||||||
|
|
||||||
|
let market_prices = self.get_current_prices(current_time);
|
||||||
|
metrics.record(current_time, &context.portfolio, &market_prices);
|
||||||
|
|
||||||
|
current_time = current_time + self.config.interval;
|
||||||
|
}
|
||||||
|
|
||||||
|
resolve_closed_positions(
|
||||||
|
&mut context.portfolio,
|
||||||
|
&self.data,
|
||||||
|
&mut resolved_markets,
|
||||||
|
self.config.end_time,
|
||||||
|
&mut context.trading_history,
|
||||||
|
&mut metrics,
|
||||||
|
);
|
||||||
|
|
||||||
|
metrics.finalize()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn try_random_trade(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
at: DateTime<Utc>,
|
||||||
|
rng_state: &mut u64,
|
||||||
|
) -> Option<Fill> {
|
||||||
|
if context.portfolio.positions.len() >= self.config.max_positions {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let active_markets = self.data.get_active_markets(at);
|
||||||
|
let unpositioned: Vec<_> = active_markets
|
||||||
|
.iter()
|
||||||
|
.filter(|m| !context.portfolio.has_position(&m.ticker))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if unpositioned.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
*rng_state = lcg_next(*rng_state);
|
||||||
|
let idx = (*rng_state as usize) % unpositioned.len();
|
||||||
|
let market = unpositioned[idx];
|
||||||
|
|
||||||
|
let price = self.data.get_current_price(&market.ticker, at)?;
|
||||||
|
let side = if *rng_state % 2 == 0 { Side::Yes } else { Side::No };
|
||||||
|
|
||||||
|
let effective_price = match side {
|
||||||
|
Side::Yes => price,
|
||||||
|
Side::No => Decimal::ONE - price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let quantity = self
|
||||||
|
.config
|
||||||
|
.max_position_size
|
||||||
|
.min((context.portfolio.cash / effective_price).to_u64().unwrap_or(0));
|
||||||
|
|
||||||
|
if quantity == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Fill {
|
||||||
|
ticker: market.ticker.clone(),
|
||||||
|
side,
|
||||||
|
quantity,
|
||||||
|
price: effective_price,
|
||||||
|
timestamp: at,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
|
||||||
|
self.data
|
||||||
|
.markets
|
||||||
|
.keys()
|
||||||
|
.filter_map(|ticker| {
|
||||||
|
self.data
|
||||||
|
.get_current_price(ticker, at)
|
||||||
|
.map(|p| (ticker.clone(), p))
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// linear congruential generator for deterministic random baseline
|
||||||
|
fn lcg_next(state: u64) -> u64 {
|
||||||
|
state.wrapping_mul(1103515245).wrapping_add(12345)
|
||||||
|
}
|
||||||
290
src/data/loader.rs
Normal file
290
src/data/loader.rs
Normal file
@ -0,0 +1,290 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use csv::ReaderBuilder;
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use crate::types::{MarketData, MarketResult, PricePoint, Side, TradeData};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct CsvMarket {
|
||||||
|
ticker: String,
|
||||||
|
title: String,
|
||||||
|
category: String,
|
||||||
|
#[serde(with = "flexible_datetime")]
|
||||||
|
open_time: DateTime<Utc>,
|
||||||
|
#[serde(with = "flexible_datetime")]
|
||||||
|
close_time: DateTime<Utc>,
|
||||||
|
result: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct CsvTrade {
|
||||||
|
#[serde(with = "flexible_datetime")]
|
||||||
|
timestamp: DateTime<Utc>,
|
||||||
|
ticker: String,
|
||||||
|
price: f64,
|
||||||
|
volume: u64,
|
||||||
|
taker_side: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
mod flexible_datetime {
|
||||||
|
use chrono::{DateTime, NaiveDateTime, Utc};
|
||||||
|
use serde::{self, Deserialize, Deserializer};
|
||||||
|
|
||||||
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let s = String::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
if let Ok(dt) = DateTime::parse_from_rfc3339(&s) {
|
||||||
|
return Ok(dt.with_timezone(&Utc));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(dt) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") {
|
||||||
|
return Ok(dt.and_utc());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(ts) = s.parse::<i64>() {
|
||||||
|
return DateTime::from_timestamp(ts, 0)
|
||||||
|
.ok_or_else(|| serde::de::Error::custom("invalid timestamp"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(serde::de::Error::custom(format!(
|
||||||
|
"could not parse datetime: {}",
|
||||||
|
s
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct HistoricalData {
|
||||||
|
pub markets: HashMap<String, MarketData>,
|
||||||
|
pub trades: Vec<TradeData>,
|
||||||
|
trade_index: HashMap<String, Vec<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HistoricalData {
|
||||||
|
pub fn load(data_dir: &Path) -> Result<Self> {
|
||||||
|
let markets_path = data_dir.join("markets.csv");
|
||||||
|
let trades_path = data_dir.join("trades.csv");
|
||||||
|
|
||||||
|
let markets = load_markets(&markets_path)
|
||||||
|
.with_context(|| format!("loading markets from {:?}", markets_path))?;
|
||||||
|
|
||||||
|
let trades =
|
||||||
|
load_trades(&trades_path).with_context(|| format!("loading trades from {:?}", trades_path))?;
|
||||||
|
|
||||||
|
let mut trade_index: HashMap<String, Vec<usize>> = HashMap::new();
|
||||||
|
for (i, trade) in trades.iter().enumerate() {
|
||||||
|
trade_index.entry(trade.ticker.clone()).or_default().push(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
markets,
|
||||||
|
trades,
|
||||||
|
trade_index,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_active_markets(&self, at: DateTime<Utc>) -> Vec<&MarketData> {
|
||||||
|
self.markets
|
||||||
|
.values()
|
||||||
|
.filter(|m| at >= m.open_time && at < m.close_time)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_trades_for_market(&self, ticker: &str, from: DateTime<Utc>, to: DateTime<Utc>) -> Vec<&TradeData> {
|
||||||
|
self.trade_index
|
||||||
|
.get(ticker)
|
||||||
|
.map(|indices| {
|
||||||
|
indices
|
||||||
|
.iter()
|
||||||
|
.filter_map(|&i| {
|
||||||
|
let trade = &self.trades[i];
|
||||||
|
if trade.timestamp >= from && trade.timestamp < to {
|
||||||
|
Some(trade)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.unwrap_or_default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_current_price(&self, ticker: &str, at: DateTime<Utc>) -> Option<Decimal> {
|
||||||
|
self.trade_index.get(ticker).and_then(|indices| {
|
||||||
|
indices
|
||||||
|
.iter()
|
||||||
|
.filter_map(|&i| {
|
||||||
|
let trade = &self.trades[i];
|
||||||
|
if trade.timestamp <= at {
|
||||||
|
Some(trade)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.last()
|
||||||
|
.map(|t| t.price)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_price_history(
|
||||||
|
&self,
|
||||||
|
ticker: &str,
|
||||||
|
from: DateTime<Utc>,
|
||||||
|
to: DateTime<Utc>,
|
||||||
|
) -> Vec<PricePoint> {
|
||||||
|
self.get_trades_for_market(ticker, from, to)
|
||||||
|
.into_iter()
|
||||||
|
.map(|t| PricePoint {
|
||||||
|
timestamp: t.timestamp,
|
||||||
|
yes_price: t.price,
|
||||||
|
volume: t.volume,
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_volume_24h(&self, ticker: &str, at: DateTime<Utc>) -> u64 {
|
||||||
|
let from = at - chrono::Duration::hours(24);
|
||||||
|
self.get_trades_for_market(ticker, from, at)
|
||||||
|
.iter()
|
||||||
|
.map(|t| t.volume)
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_order_flow_24h(&self, ticker: &str, at: DateTime<Utc>) -> (u64, u64) {
|
||||||
|
let from = at - chrono::Duration::hours(24);
|
||||||
|
let trades = self.get_trades_for_market(ticker, from, at);
|
||||||
|
let buy_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::Yes).map(|t| t.volume).sum();
|
||||||
|
let sell_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::No).map(|t| t.volume).sum();
|
||||||
|
(buy_vol, sell_vol)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_resolutions(&self, at: DateTime<Utc>) -> Vec<(&MarketData, MarketResult)> {
|
||||||
|
self.markets
|
||||||
|
.values()
|
||||||
|
.filter_map(|m| {
|
||||||
|
if m.close_time <= at {
|
||||||
|
m.result.map(|r| (m, r))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_resolution_at(&self, ticker: &str, at: DateTime<Utc>) -> Option<MarketResult> {
|
||||||
|
self.markets.get(ticker).and_then(|m| {
|
||||||
|
if m.close_time <= at {
|
||||||
|
m.result
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_markets(path: &Path) -> Result<HashMap<String, MarketData>> {
|
||||||
|
let mut reader = ReaderBuilder::new()
|
||||||
|
.has_headers(true)
|
||||||
|
.flexible(true)
|
||||||
|
.from_path(path)?;
|
||||||
|
|
||||||
|
let mut markets = HashMap::new();
|
||||||
|
|
||||||
|
for result in reader.deserialize() {
|
||||||
|
let record: CsvMarket = result?;
|
||||||
|
let result = record.result.as_ref().and_then(|r| match r.to_lowercase().as_str() {
|
||||||
|
"yes" => Some(MarketResult::Yes),
|
||||||
|
"no" => Some(MarketResult::No),
|
||||||
|
"cancelled" | "canceled" => Some(MarketResult::Cancelled),
|
||||||
|
"" => None,
|
||||||
|
_ => None,
|
||||||
|
});
|
||||||
|
|
||||||
|
markets.insert(
|
||||||
|
record.ticker.clone(),
|
||||||
|
MarketData {
|
||||||
|
ticker: record.ticker,
|
||||||
|
title: record.title,
|
||||||
|
category: record.category,
|
||||||
|
open_time: record.open_time,
|
||||||
|
close_time: record.close_time,
|
||||||
|
result,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(markets)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_trades(path: &Path) -> Result<Vec<TradeData>> {
|
||||||
|
let mut reader = ReaderBuilder::new()
|
||||||
|
.has_headers(true)
|
||||||
|
.flexible(true)
|
||||||
|
.from_path(path)?;
|
||||||
|
|
||||||
|
let mut trades = Vec::new();
|
||||||
|
|
||||||
|
for result in reader.deserialize() {
|
||||||
|
let record: CsvTrade = result?;
|
||||||
|
let side = match record.taker_side.to_lowercase().as_str() {
|
||||||
|
"yes" | "buy" => Side::Yes,
|
||||||
|
"no" | "sell" => Side::No,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
trades.push(TradeData {
|
||||||
|
timestamp: record.timestamp,
|
||||||
|
ticker: record.ticker,
|
||||||
|
price: Decimal::try_from(record.price / 100.0).unwrap_or(Decimal::ZERO),
|
||||||
|
volume: record.volume,
|
||||||
|
taker_side: side,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
trades.sort_by_key(|t| t.timestamp);
|
||||||
|
|
||||||
|
Ok(trades)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::io::Write;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
fn create_test_data() -> TempDir {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
let markets_csv = r#"ticker,title,category,open_time,close_time,result
|
||||||
|
TEST-MKT-1,Test Market 1,politics,2024-01-01 00:00:00,2024-01-15 00:00:00,yes
|
||||||
|
TEST-MKT-2,Test Market 2,economics,2024-01-01 00:00:00,2024-01-20 00:00:00,no
|
||||||
|
"#;
|
||||||
|
let mut f = std::fs::File::create(dir.path().join("markets.csv")).unwrap();
|
||||||
|
f.write_all(markets_csv.as_bytes()).unwrap();
|
||||||
|
|
||||||
|
let trades_csv = r#"timestamp,ticker,price,volume,taker_side
|
||||||
|
2024-01-05 12:00:00,TEST-MKT-1,55,100,yes
|
||||||
|
2024-01-05 13:00:00,TEST-MKT-1,57,50,yes
|
||||||
|
2024-01-06 10:00:00,TEST-MKT-2,45,200,no
|
||||||
|
"#;
|
||||||
|
let mut f = std::fs::File::create(dir.path().join("trades.csv")).unwrap();
|
||||||
|
f.write_all(trades_csv.as_bytes()).unwrap();
|
||||||
|
|
||||||
|
dir
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_load_historical_data() {
|
||||||
|
let dir = create_test_data();
|
||||||
|
let data = HistoricalData::load(dir.path()).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(data.markets.len(), 2);
|
||||||
|
assert_eq!(data.trades.len(), 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/data/mod.rs
Normal file
3
src/data/mod.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
mod loader;
|
||||||
|
|
||||||
|
pub use loader::HistoricalData;
|
||||||
325
src/execution.rs
Normal file
325
src/execution.rs
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
use crate::data::HistoricalData;
|
||||||
|
use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext};
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PositionSizingConfig {
|
||||||
|
pub kelly_fraction: f64,
|
||||||
|
pub max_position_pct: f64,
|
||||||
|
pub min_position_size: u64,
|
||||||
|
pub max_position_size: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for PositionSizingConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
kelly_fraction: 0.25,
|
||||||
|
max_position_pct: 0.25,
|
||||||
|
min_position_size: 10,
|
||||||
|
max_position_size: 1000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PositionSizingConfig {
|
||||||
|
pub fn conservative() -> Self {
|
||||||
|
Self {
|
||||||
|
kelly_fraction: 0.1,
|
||||||
|
max_position_pct: 0.1,
|
||||||
|
min_position_size: 10,
|
||||||
|
max_position_size: 500,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn aggressive() -> Self {
|
||||||
|
Self {
|
||||||
|
kelly_fraction: 0.5,
|
||||||
|
max_position_pct: 0.4,
|
||||||
|
min_position_size: 10,
|
||||||
|
max_position_size: 2000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// maps scoring edge [-inf, +inf] to win probability [0, 1]
|
||||||
|
/// tanh squashes extreme values smoothly; +1)/2 shifts from [-1,1] to [0,1]
|
||||||
|
fn edge_to_win_probability(edge: f64) -> f64 {
|
||||||
|
(1.0 + edge.tanh()) / 2.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn kelly_size(
|
||||||
|
edge: f64,
|
||||||
|
price: f64,
|
||||||
|
bankroll: f64,
|
||||||
|
config: &PositionSizingConfig,
|
||||||
|
) -> u64 {
|
||||||
|
if edge.abs() < 0.01 || price <= 0.0 || price >= 1.0 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let win_prob = edge_to_win_probability(edge);
|
||||||
|
let odds = (1.0 - price) / price;
|
||||||
|
|
||||||
|
if odds <= 0.0 {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
|
||||||
|
let safe_kelly = (kelly * config.kelly_fraction).max(0.0);
|
||||||
|
let position_value = bankroll * safe_kelly.min(config.max_position_pct);
|
||||||
|
let shares = (position_value / price).floor() as u64;
|
||||||
|
|
||||||
|
shares.max(config.min_position_size).min(config.max_position_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Executor {
|
||||||
|
data: Arc<HistoricalData>,
|
||||||
|
slippage_bps: u32,
|
||||||
|
max_position_size: u64,
|
||||||
|
sizing_config: PositionSizingConfig,
|
||||||
|
exit_config: ExitConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Executor {
|
||||||
|
pub fn new(data: Arc<HistoricalData>, slippage_bps: u32, max_position_size: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
data,
|
||||||
|
slippage_bps,
|
||||||
|
max_position_size,
|
||||||
|
sizing_config: PositionSizingConfig::default(),
|
||||||
|
exit_config: ExitConfig::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self {
|
||||||
|
self.sizing_config = config;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_exit_config(mut self, config: ExitConfig) -> Self {
|
||||||
|
self.exit_config = config;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_exit_signals(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidate_scores: &std::collections::HashMap<String, f64>,
|
||||||
|
) -> Vec<ExitSignal> {
|
||||||
|
let mut exits = Vec::new();
|
||||||
|
|
||||||
|
for (ticker, position) in &context.portfolio.positions {
|
||||||
|
let current_price = match self.data.get_current_price(ticker, context.timestamp) {
|
||||||
|
Some(p) => p,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let effective_price = match position.side {
|
||||||
|
Side::Yes => current_price,
|
||||||
|
Side::No => Decimal::ONE - current_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
|
||||||
|
let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);
|
||||||
|
|
||||||
|
if entry_price_f64 <= 0.0 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;
|
||||||
|
|
||||||
|
if pnl_pct >= self.exit_config.take_profit_pct {
|
||||||
|
exits.push(ExitSignal {
|
||||||
|
ticker: ticker.clone(),
|
||||||
|
reason: ExitReason::TakeProfit { pnl_pct },
|
||||||
|
current_price,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if pnl_pct <= -self.exit_config.stop_loss_pct {
|
||||||
|
exits.push(ExitSignal {
|
||||||
|
ticker: ticker.clone(),
|
||||||
|
reason: ExitReason::StopLoss { pnl_pct },
|
||||||
|
current_price,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let hours_held = (context.timestamp - position.entry_time).num_hours();
|
||||||
|
if hours_held >= self.exit_config.max_hold_hours {
|
||||||
|
exits.push(ExitSignal {
|
||||||
|
ticker: ticker.clone(),
|
||||||
|
reason: ExitReason::TimeStop { hours_held },
|
||||||
|
current_price,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(&new_score) = candidate_scores.get(ticker) {
|
||||||
|
if new_score < self.exit_config.score_reversal_threshold {
|
||||||
|
exits.push(ExitSignal {
|
||||||
|
ticker: ticker.clone(),
|
||||||
|
reason: ExitReason::ScoreReversal { new_score },
|
||||||
|
current_price,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
exits
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_signals(
|
||||||
|
&self,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
context: &TradingContext,
|
||||||
|
) -> Vec<Signal> {
|
||||||
|
candidates
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| self.candidate_to_signal(c, context))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn candidate_to_signal(
|
||||||
|
&self,
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
context: &TradingContext,
|
||||||
|
) -> Option<Signal> {
|
||||||
|
let current_position = context.portfolio.get_position(&candidate.ticker);
|
||||||
|
let current_qty = current_position.map(|p| p.quantity).unwrap_or(0);
|
||||||
|
|
||||||
|
if current_qty >= self.max_position_size {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
|
||||||
|
// positive score = bullish signal, so buy the cheaper side (better risk/reward)
|
||||||
|
// negative score = bearish signal, so buy against the expensive side
|
||||||
|
let side = if candidate.final_score > 0.0 {
|
||||||
|
if yes_price < 0.5 { Side::Yes } else { Side::No }
|
||||||
|
} else if candidate.final_score < 0.0 {
|
||||||
|
if yes_price > 0.5 { Side::No } else { Side::Yes }
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let price = match side {
|
||||||
|
Side::Yes => candidate.current_yes_price,
|
||||||
|
Side::No => candidate.current_no_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0);
|
||||||
|
let price_f64 = price.to_f64().unwrap_or(0.5);
|
||||||
|
|
||||||
|
if price_f64 <= 0.0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let kelly_qty = kelly_size(
|
||||||
|
candidate.final_score,
|
||||||
|
price_f64,
|
||||||
|
available_cash,
|
||||||
|
&self.sizing_config,
|
||||||
|
);
|
||||||
|
|
||||||
|
let max_affordable = (available_cash / price_f64) as u64;
|
||||||
|
let quantity = kelly_qty
|
||||||
|
.min(max_affordable)
|
||||||
|
.min(self.max_position_size - current_qty);
|
||||||
|
|
||||||
|
if quantity < self.sizing_config.min_position_size {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Signal {
|
||||||
|
ticker: candidate.ticker.clone(),
|
||||||
|
side,
|
||||||
|
quantity,
|
||||||
|
limit_price: Some(price),
|
||||||
|
reason: format!(
|
||||||
|
"score={:.3}, side={:?}, price={:.2}",
|
||||||
|
candidate.final_score, side, price_f64
|
||||||
|
),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute_signal(
|
||||||
|
&self,
|
||||||
|
signal: &Signal,
|
||||||
|
context: &TradingContext,
|
||||||
|
) -> Option<Fill> {
|
||||||
|
let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?;
|
||||||
|
|
||||||
|
let effective_price = match signal.side {
|
||||||
|
Side::Yes => market_price,
|
||||||
|
Side::No => Decimal::ONE - market_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000);
|
||||||
|
let fill_price = effective_price * (Decimal::ONE + slippage);
|
||||||
|
|
||||||
|
if let Some(limit) = signal.limit_price {
|
||||||
|
if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let cost = fill_price * Decimal::from(signal.quantity);
|
||||||
|
if cost > context.portfolio.cash {
|
||||||
|
let affordable = (context.portfolio.cash / fill_price)
|
||||||
|
.to_u64()
|
||||||
|
.unwrap_or(0);
|
||||||
|
if affordable == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Some(Fill {
|
||||||
|
ticker: signal.ticker.clone(),
|
||||||
|
side: signal.side,
|
||||||
|
quantity: affordable,
|
||||||
|
price: fill_price,
|
||||||
|
timestamp: context.timestamp,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(Fill {
|
||||||
|
ticker: signal.ticker.clone(),
|
||||||
|
side: signal.side,
|
||||||
|
quantity: signal.quantity,
|
||||||
|
price: fill_price,
|
||||||
|
timestamp: context.timestamp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn simple_signal_generator(
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
context: &TradingContext,
|
||||||
|
position_size: u64,
|
||||||
|
) -> Vec<Signal> {
|
||||||
|
candidates
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.final_score > 0.0)
|
||||||
|
.filter(|c| !context.portfolio.has_position(&c.ticker))
|
||||||
|
.map(|c| {
|
||||||
|
let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
let (side, price) = if yes_price < 0.5 {
|
||||||
|
(Side::Yes, c.current_yes_price)
|
||||||
|
} else {
|
||||||
|
(Side::No, c.current_no_price)
|
||||||
|
};
|
||||||
|
|
||||||
|
Signal {
|
||||||
|
ticker: c.ticker.clone(),
|
||||||
|
side,
|
||||||
|
quantity: position_size,
|
||||||
|
limit_price: Some(price),
|
||||||
|
reason: format!("simple: score={:.3}", c.final_score),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
216
src/main.rs
Normal file
216
src/main.rs
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
mod backtest;
|
||||||
|
mod data;
|
||||||
|
mod execution;
|
||||||
|
mod metrics;
|
||||||
|
mod pipeline;
|
||||||
|
mod types;
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use backtest::{Backtester, RandomBaseline};
|
||||||
|
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
use data::HistoricalData;
|
||||||
|
use execution::{Executor, PositionSizingConfig};
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::info;
|
||||||
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||||
|
use types::{BacktestConfig, ExitConfig};
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
#[command(name = "kalshi-backtest")]
|
||||||
|
#[command(about = "backtesting framework for kalshi prediction markets")]
|
||||||
|
struct Cli {
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Commands,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand)]
|
||||||
|
enum Commands {
|
||||||
|
Run {
|
||||||
|
#[arg(short, long, default_value = "data")]
|
||||||
|
data_dir: PathBuf,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
start: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
end: String,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "10000")]
|
||||||
|
capital: f64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "100")]
|
||||||
|
max_position: u64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "5")]
|
||||||
|
max_positions: usize,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "1")]
|
||||||
|
interval_hours: i64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "results")]
|
||||||
|
output_dir: PathBuf,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
compare_random: bool,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "0.25")]
|
||||||
|
kelly_fraction: f64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "0.25")]
|
||||||
|
max_position_pct: f64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "0.20")]
|
||||||
|
take_profit: f64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "0.15")]
|
||||||
|
stop_loss: f64,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "72")]
|
||||||
|
max_hold_hours: i64,
|
||||||
|
},
|
||||||
|
|
||||||
|
Summary {
|
||||||
|
#[arg(short, long)]
|
||||||
|
results_file: PathBuf,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_date(s: &str) -> Result<DateTime<Utc>> {
|
||||||
|
if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
|
||||||
|
return Ok(dt.with_timezone(&Utc));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
|
||||||
|
return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow::anyhow!("could not parse date: {}", s))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| "kalshi_backtest=info".into()),
|
||||||
|
)
|
||||||
|
.with(tracing_subscriber::fmt::layer())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let cli = Cli::parse();
|
||||||
|
|
||||||
|
match cli.command {
|
||||||
|
Commands::Run {
|
||||||
|
data_dir,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
capital,
|
||||||
|
max_position,
|
||||||
|
max_positions,
|
||||||
|
interval_hours,
|
||||||
|
output_dir,
|
||||||
|
compare_random,
|
||||||
|
kelly_fraction,
|
||||||
|
max_position_pct,
|
||||||
|
take_profit,
|
||||||
|
stop_loss,
|
||||||
|
max_hold_hours,
|
||||||
|
} => {
|
||||||
|
let start_time = parse_date(&start).context("parsing start date")?;
|
||||||
|
let end_time = parse_date(&end).context("parsing end date")?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
data_dir = %data_dir.display(),
|
||||||
|
start = %start_time,
|
||||||
|
end = %end_time,
|
||||||
|
capital = capital,
|
||||||
|
"loading data"
|
||||||
|
);
|
||||||
|
|
||||||
|
let data = Arc::new(
|
||||||
|
HistoricalData::load(&data_dir).context("loading historical data")?,
|
||||||
|
);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
markets = data.markets.len(),
|
||||||
|
trades = data.trades.len(),
|
||||||
|
"data loaded"
|
||||||
|
);
|
||||||
|
|
||||||
|
let config = BacktestConfig {
|
||||||
|
start_time,
|
||||||
|
end_time,
|
||||||
|
interval: chrono::Duration::hours(interval_hours),
|
||||||
|
initial_capital: Decimal::try_from(capital).unwrap(),
|
||||||
|
max_position_size: max_position,
|
||||||
|
max_positions,
|
||||||
|
};
|
||||||
|
|
||||||
|
let sizing_config = PositionSizingConfig {
|
||||||
|
kelly_fraction,
|
||||||
|
max_position_pct,
|
||||||
|
min_position_size: 10,
|
||||||
|
max_position_size: max_position,
|
||||||
|
};
|
||||||
|
|
||||||
|
let exit_config = ExitConfig {
|
||||||
|
take_profit_pct: take_profit,
|
||||||
|
stop_loss_pct: stop_loss,
|
||||||
|
max_hold_hours,
|
||||||
|
score_reversal_threshold: -0.3,
|
||||||
|
};
|
||||||
|
|
||||||
|
let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config);
|
||||||
|
let result = backtester.run().await;
|
||||||
|
|
||||||
|
println!("{}", result.summary());
|
||||||
|
|
||||||
|
std::fs::create_dir_all(&output_dir)?;
|
||||||
|
let result_path = output_dir.join("backtest_result.json");
|
||||||
|
let json = serde_json::to_string_pretty(&result)?;
|
||||||
|
std::fs::write(&result_path, json)?;
|
||||||
|
info!(path = %result_path.display(), "results saved");
|
||||||
|
|
||||||
|
if compare_random {
|
||||||
|
println!("\n--- random baseline ---\n");
|
||||||
|
let baseline = RandomBaseline::new(config, data);
|
||||||
|
let baseline_result = baseline.run().await;
|
||||||
|
println!("{}", baseline_result.summary());
|
||||||
|
|
||||||
|
let baseline_path = output_dir.join("baseline_result.json");
|
||||||
|
let json = serde_json::to_string_pretty(&baseline_result)?;
|
||||||
|
std::fs::write(&baseline_path, json)?;
|
||||||
|
|
||||||
|
println!("\n--- comparison ---\n");
|
||||||
|
println!(
|
||||||
|
"strategy return: {:.2}% vs baseline: {:.2}%",
|
||||||
|
result.total_return_pct, baseline_result.total_return_pct
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"strategy sharpe: {:.3} vs baseline: {:.3}",
|
||||||
|
result.sharpe_ratio, baseline_result.sharpe_ratio
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"strategy win rate: {:.1}% vs baseline: {:.1}%",
|
||||||
|
result.win_rate, baseline_result.win_rate
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Commands::Summary { results_file } => {
|
||||||
|
let content = std::fs::read_to_string(&results_file)
|
||||||
|
.context("reading results file")?;
|
||||||
|
let result: metrics::BacktestResult =
|
||||||
|
serde_json::from_str(&content).context("parsing results")?;
|
||||||
|
|
||||||
|
println!("{}", result.summary());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
280
src/metrics.rs
Normal file
280
src/metrics.rs
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
use crate::types::{Portfolio, Trade, TradeType};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct BacktestResult {
|
||||||
|
pub total_return: f64,
|
||||||
|
pub total_return_pct: f64,
|
||||||
|
pub sharpe_ratio: f64,
|
||||||
|
pub max_drawdown: f64,
|
||||||
|
pub max_drawdown_pct: f64,
|
||||||
|
pub win_rate: f64,
|
||||||
|
pub total_trades: usize,
|
||||||
|
pub winning_trades: usize,
|
||||||
|
pub losing_trades: usize,
|
||||||
|
pub avg_trade_pnl: f64,
|
||||||
|
pub avg_hold_time_hours: f64,
|
||||||
|
pub trades_per_day: f64,
|
||||||
|
pub return_by_category: HashMap<String, f64>,
|
||||||
|
pub equity_curve: Vec<EquityPoint>,
|
||||||
|
pub trade_log: Vec<TradeRecord>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct EquityPoint {
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub equity: f64,
|
||||||
|
pub cash: f64,
|
||||||
|
pub positions_value: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TradeRecord {
|
||||||
|
pub ticker: String,
|
||||||
|
pub entry_time: DateTime<Utc>,
|
||||||
|
pub exit_time: Option<DateTime<Utc>>,
|
||||||
|
pub side: String,
|
||||||
|
pub quantity: u64,
|
||||||
|
pub entry_price: f64,
|
||||||
|
pub exit_price: Option<f64>,
|
||||||
|
pub pnl: Option<f64>,
|
||||||
|
pub category: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MetricsCollector {
|
||||||
|
initial_capital: Decimal,
|
||||||
|
equity_curve: Vec<EquityPoint>,
|
||||||
|
trade_records: HashMap<String, TradeRecord>,
|
||||||
|
closed_trades: Vec<TradeRecord>,
|
||||||
|
daily_returns: Vec<f64>,
|
||||||
|
last_equity: f64,
|
||||||
|
peak_equity: f64,
|
||||||
|
max_drawdown: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MetricsCollector {
|
||||||
|
pub fn new(initial_capital: Decimal) -> Self {
|
||||||
|
let capital = initial_capital.to_f64().unwrap_or(10000.0);
|
||||||
|
Self {
|
||||||
|
initial_capital,
|
||||||
|
equity_curve: Vec::new(),
|
||||||
|
trade_records: HashMap::new(),
|
||||||
|
closed_trades: Vec::new(),
|
||||||
|
daily_returns: Vec::new(),
|
||||||
|
last_equity: capital,
|
||||||
|
peak_equity: capital,
|
||||||
|
max_drawdown: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record(
|
||||||
|
&mut self,
|
||||||
|
timestamp: DateTime<Utc>,
|
||||||
|
portfolio: &Portfolio,
|
||||||
|
market_prices: &HashMap<String, Decimal>,
|
||||||
|
) {
|
||||||
|
let positions_value = portfolio
|
||||||
|
.positions
|
||||||
|
.values()
|
||||||
|
.map(|p| {
|
||||||
|
let price = market_prices
|
||||||
|
.get(&p.ticker)
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(p.avg_entry_price);
|
||||||
|
(price * Decimal::from(p.quantity)).to_f64().unwrap_or(0.0)
|
||||||
|
})
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
let cash = portfolio.cash.to_f64().unwrap_or(0.0);
|
||||||
|
let equity = cash + positions_value;
|
||||||
|
|
||||||
|
if equity > self.peak_equity {
|
||||||
|
self.peak_equity = equity;
|
||||||
|
}
|
||||||
|
|
||||||
|
let drawdown = (self.peak_equity - equity) / self.peak_equity;
|
||||||
|
if drawdown > self.max_drawdown {
|
||||||
|
self.max_drawdown = drawdown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.last_equity > 0.0 {
|
||||||
|
let daily_return = (equity - self.last_equity) / self.last_equity;
|
||||||
|
self.daily_returns.push(daily_return);
|
||||||
|
}
|
||||||
|
self.last_equity = equity;
|
||||||
|
|
||||||
|
self.equity_curve.push(EquityPoint {
|
||||||
|
timestamp,
|
||||||
|
equity,
|
||||||
|
cash,
|
||||||
|
positions_value,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_trade(&mut self, trade: &Trade, category: &str) {
|
||||||
|
match trade.trade_type {
|
||||||
|
TradeType::Open => {
|
||||||
|
let record = TradeRecord {
|
||||||
|
ticker: trade.ticker.clone(),
|
||||||
|
entry_time: trade.timestamp,
|
||||||
|
exit_time: None,
|
||||||
|
side: format!("{:?}", trade.side),
|
||||||
|
quantity: trade.quantity,
|
||||||
|
entry_price: trade.price.to_f64().unwrap_or(0.0),
|
||||||
|
exit_price: None,
|
||||||
|
pnl: None,
|
||||||
|
category: category.to_string(),
|
||||||
|
};
|
||||||
|
self.trade_records.insert(trade.ticker.clone(), record);
|
||||||
|
}
|
||||||
|
TradeType::Close | TradeType::Resolution => {
|
||||||
|
if let Some(mut record) = self.trade_records.remove(&trade.ticker) {
|
||||||
|
let exit_price = trade.price.to_f64().unwrap_or(0.0);
|
||||||
|
let entry_cost = record.entry_price * record.quantity as f64;
|
||||||
|
let exit_value = exit_price * record.quantity as f64;
|
||||||
|
let pnl = exit_value - entry_cost;
|
||||||
|
|
||||||
|
record.exit_time = Some(trade.timestamp);
|
||||||
|
record.exit_price = Some(exit_price);
|
||||||
|
record.pnl = Some(pnl);
|
||||||
|
|
||||||
|
self.closed_trades.push(record);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn finalize(self) -> BacktestResult {
|
||||||
|
let initial = self.initial_capital.to_f64().unwrap_or(10000.0);
|
||||||
|
let final_equity = self.equity_curve.last().map(|e| e.equity).unwrap_or(initial);
|
||||||
|
let total_return = final_equity - initial;
|
||||||
|
let total_return_pct = total_return / initial * 100.0;
|
||||||
|
|
||||||
|
let sharpe_ratio = if self.daily_returns.len() > 1 {
|
||||||
|
let mean: f64 = self.daily_returns.iter().sum::<f64>() / self.daily_returns.len() as f64;
|
||||||
|
let variance: f64 = self
|
||||||
|
.daily_returns
|
||||||
|
.iter()
|
||||||
|
.map(|r| (r - mean).powi(2))
|
||||||
|
.sum::<f64>()
|
||||||
|
/ (self.daily_returns.len() - 1) as f64;
|
||||||
|
let std_dev = variance.sqrt();
|
||||||
|
if std_dev > 0.0 {
|
||||||
|
(mean / std_dev) * (252.0_f64).sqrt()
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let winning_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) > 0.0).count();
|
||||||
|
let losing_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) < 0.0).count();
|
||||||
|
let total_trades = self.closed_trades.len();
|
||||||
|
|
||||||
|
let win_rate = if total_trades > 0 {
|
||||||
|
winning_trades as f64 / total_trades as f64 * 100.0
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let avg_trade_pnl = if total_trades > 0 {
|
||||||
|
self.closed_trades.iter().filter_map(|t| t.pnl).sum::<f64>() / total_trades as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let avg_hold_time_hours = if total_trades > 0 {
|
||||||
|
self.closed_trades
|
||||||
|
.iter()
|
||||||
|
.filter_map(|t| {
|
||||||
|
t.exit_time.map(|exit| (exit - t.entry_time).num_hours() as f64)
|
||||||
|
})
|
||||||
|
.sum::<f64>()
|
||||||
|
/ total_trades as f64
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let duration_days = if self.equity_curve.len() >= 2 {
|
||||||
|
let start = self.equity_curve.first().unwrap().timestamp;
|
||||||
|
let end = self.equity_curve.last().unwrap().timestamp;
|
||||||
|
(end - start).num_days().max(1) as f64
|
||||||
|
} else {
|
||||||
|
1.0
|
||||||
|
};
|
||||||
|
|
||||||
|
let trades_per_day = total_trades as f64 / duration_days;
|
||||||
|
|
||||||
|
let mut return_by_category: HashMap<String, f64> = HashMap::new();
|
||||||
|
for trade in &self.closed_trades {
|
||||||
|
*return_by_category.entry(trade.category.clone()).or_insert(0.0) +=
|
||||||
|
trade.pnl.unwrap_or(0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
BacktestResult {
|
||||||
|
total_return,
|
||||||
|
total_return_pct,
|
||||||
|
sharpe_ratio,
|
||||||
|
max_drawdown: self.max_drawdown * 100.0,
|
||||||
|
max_drawdown_pct: self.max_drawdown * 100.0,
|
||||||
|
win_rate,
|
||||||
|
total_trades,
|
||||||
|
winning_trades,
|
||||||
|
losing_trades,
|
||||||
|
avg_trade_pnl,
|
||||||
|
avg_hold_time_hours,
|
||||||
|
trades_per_day,
|
||||||
|
return_by_category,
|
||||||
|
equity_curve: self.equity_curve,
|
||||||
|
trade_log: self.closed_trades,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BacktestResult {
|
||||||
|
pub fn summary(&self) -> String {
|
||||||
|
format!(
|
||||||
|
r#"
|
||||||
|
backtest results
|
||||||
|
================
|
||||||
|
|
||||||
|
performance
|
||||||
|
-----------
|
||||||
|
total return: ${:.2} ({:.2}%)
|
||||||
|
sharpe ratio: {:.3}
|
||||||
|
max drawdown: {:.2}%
|
||||||
|
|
||||||
|
trades
|
||||||
|
------
|
||||||
|
total trades: {}
|
||||||
|
win rate: {:.1}%
|
||||||
|
avg trade pnl: ${:.2}
|
||||||
|
avg hold time: {:.1} hours
|
||||||
|
trades per day: {:.2}
|
||||||
|
|
||||||
|
by category
|
||||||
|
-----------
|
||||||
|
{}
|
||||||
|
"#,
|
||||||
|
self.total_return,
|
||||||
|
self.total_return_pct,
|
||||||
|
self.sharpe_ratio,
|
||||||
|
self.max_drawdown_pct,
|
||||||
|
self.total_trades,
|
||||||
|
self.win_rate,
|
||||||
|
self.avg_trade_pnl,
|
||||||
|
self.avg_hold_time_hours,
|
||||||
|
self.trades_per_day,
|
||||||
|
self.return_by_category
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!(" {}: ${:.2}", k, v))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
259
src/pipeline/correlation_scorer.rs
Normal file
259
src/pipeline/correlation_scorer.rs
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
//! Cross-market correlation scorer
|
||||||
|
//!
|
||||||
|
//! Uses lead-lag relationships between related markets to generate signals.
|
||||||
|
//! When a related market moves, we expect similar movements in correlated markets.
|
||||||
|
|
||||||
|
use crate::pipeline::Scorer;
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
|
/// correlation entry between two markets
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CorrelationEntry {
|
||||||
|
pub ticker_a: String,
|
||||||
|
pub ticker_b: String,
|
||||||
|
pub correlation: f64,
|
||||||
|
pub lag_hours: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// cross-market correlation scorer
|
||||||
|
/// uses precomputed correlations to generate signals based on related market movements
|
||||||
|
pub struct CorrelationScorer {
|
||||||
|
correlations: Arc<RwLock<HashMap<String, Vec<CorrelationEntry>>>>,
|
||||||
|
lookback_hours: i64,
|
||||||
|
min_correlation: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CorrelationScorer {
|
||||||
|
pub fn new(lookback_hours: i64, min_correlation: f64) -> Self {
|
||||||
|
Self {
|
||||||
|
correlations: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
lookback_hours,
|
||||||
|
min_correlation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_config() -> Self {
|
||||||
|
Self::new(24, 0.5)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_correlations(&self, correlations: Vec<CorrelationEntry>) {
|
||||||
|
let mut map = self.correlations.write().unwrap();
|
||||||
|
map.clear();
|
||||||
|
|
||||||
|
for entry in correlations {
|
||||||
|
map.entry(entry.ticker_a.clone())
|
||||||
|
.or_insert_with(Vec::new)
|
||||||
|
.push(entry.clone());
|
||||||
|
map.entry(entry.ticker_b.clone())
|
||||||
|
.or_insert_with(Vec::new)
|
||||||
|
.push(CorrelationEntry {
|
||||||
|
ticker_a: entry.ticker_b.clone(),
|
||||||
|
ticker_b: entry.ticker_a.clone(),
|
||||||
|
correlation: entry.correlation,
|
||||||
|
lag_hours: -entry.lag_hours,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_category_correlations(&self, categories: &[&str]) {
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
|
||||||
|
for (i, cat_a) in categories.iter().enumerate() {
|
||||||
|
for cat_b in categories.iter().skip(i + 1) {
|
||||||
|
entries.push(CorrelationEntry {
|
||||||
|
ticker_a: cat_a.to_string(),
|
||||||
|
ticker_b: cat_b.to_string(),
|
||||||
|
correlation: 0.3,
|
||||||
|
lag_hours: 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.load_correlations(entries);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_related_signals(
|
||||||
|
&self,
|
||||||
|
ticker: &str,
|
||||||
|
all_candidates: &[MarketCandidate],
|
||||||
|
) -> f64 {
|
||||||
|
let correlations = self.correlations.read().unwrap();
|
||||||
|
let Some(related) = correlations.get(ticker) else {
|
||||||
|
return 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
let candidate_map: HashMap<&str, &MarketCandidate> = all_candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| (c.ticker.as_str(), c))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut weighted_signal = 0.0;
|
||||||
|
let mut total_weight = 0.0;
|
||||||
|
|
||||||
|
for entry in related {
|
||||||
|
if entry.correlation.abs() < self.min_correlation {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(related_candidate) = candidate_map.get(entry.ticker_b.as_str()) {
|
||||||
|
let related_momentum = related_candidate
|
||||||
|
.scores
|
||||||
|
.get("momentum")
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
|
let signal = related_momentum * entry.correlation;
|
||||||
|
weighted_signal += signal * entry.correlation.abs();
|
||||||
|
total_weight += entry.correlation.abs();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_weight > 0.0 {
|
||||||
|
weighted_signal / total_weight
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_category_correlation(
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
all_candidates: &[MarketCandidate],
|
||||||
|
) -> f64 {
|
||||||
|
let same_category: Vec<&MarketCandidate> = all_candidates
|
||||||
|
.iter()
|
||||||
|
.filter(|c| c.category == candidate.category && c.ticker != candidate.ticker)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if same_category.is_empty() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let avg_momentum: f64 = same_category
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.get("momentum").copied())
|
||||||
|
.sum::<f64>()
|
||||||
|
/ same_category.len() as f64;
|
||||||
|
|
||||||
|
avg_momentum * 0.3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for CorrelationScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"CorrelationScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let related_signal = self.get_related_signals(&c.ticker, candidates);
|
||||||
|
let category_signal = Self::calculate_category_correlation(c, candidates);
|
||||||
|
let combined = related_signal * 0.7 + category_signal * 0.3;
|
||||||
|
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("cross_market".to_string(), combined);
|
||||||
|
scored.scores.insert("related_signal".to_string(), related_signal);
|
||||||
|
scored.scores.insert("category_signal".to_string(), category_signal);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
for key in ["cross_market", "related_signal", "category_signal"] {
|
||||||
|
if let Some(score) = scored.scores.get(key) {
|
||||||
|
candidate.scores.insert(key.to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// granger causality calculator (simplified version)
|
||||||
|
/// full implementation would use statistical tests
|
||||||
|
pub fn calculate_granger_causality(
|
||||||
|
prices_a: &[f64],
|
||||||
|
prices_b: &[f64],
|
||||||
|
max_lag: i64,
|
||||||
|
) -> Option<(f64, i64)> {
|
||||||
|
if prices_a.len() < 20 || prices_b.len() < 20 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let min_len = prices_a.len().min(prices_b.len());
|
||||||
|
let prices_a = &prices_a[..min_len];
|
||||||
|
let prices_b = &prices_b[..min_len];
|
||||||
|
|
||||||
|
let returns_a: Vec<f64> = prices_a
|
||||||
|
.windows(2)
|
||||||
|
.map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 })
|
||||||
|
.collect();
|
||||||
|
let returns_b: Vec<f64> = prices_b
|
||||||
|
.windows(2)
|
||||||
|
.map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut best_corr: f64 = 0.0;
|
||||||
|
let mut best_lag: i64 = 0;
|
||||||
|
|
||||||
|
for lag in -max_lag..=max_lag {
|
||||||
|
let (a_slice, b_slice) = if lag >= 0 {
|
||||||
|
let l = lag as usize;
|
||||||
|
if l >= returns_a.len() || l >= returns_b.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
(&returns_a[l..], &returns_b[..returns_b.len() - l])
|
||||||
|
} else {
|
||||||
|
let l = (-lag) as usize;
|
||||||
|
if l >= returns_a.len() || l >= returns_b.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
(&returns_a[..returns_a.len() - l], &returns_b[l..])
|
||||||
|
};
|
||||||
|
|
||||||
|
if a_slice.len() < 10 || b_slice.len() < 10 || a_slice.len() != b_slice.len() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let n = a_slice.len() as f64;
|
||||||
|
let mean_a: f64 = a_slice.iter().sum::<f64>() / n;
|
||||||
|
let mean_b: f64 = b_slice.iter().sum::<f64>() / n;
|
||||||
|
|
||||||
|
let cov: f64 = a_slice
|
||||||
|
.iter()
|
||||||
|
.zip(b_slice.iter())
|
||||||
|
.map(|(a, b)| (a - mean_a) * (b - mean_b))
|
||||||
|
.sum::<f64>()
|
||||||
|
/ n;
|
||||||
|
|
||||||
|
let std_a: f64 = (a_slice.iter().map(|a| (a - mean_a).powi(2)).sum::<f64>() / n).sqrt();
|
||||||
|
let std_b: f64 = (b_slice.iter().map(|b| (b - mean_b).powi(2)).sum::<f64>() / n).sqrt();
|
||||||
|
|
||||||
|
if std_a > 0.0 && std_b > 0.0 {
|
||||||
|
let corr = cov / (std_a * std_b);
|
||||||
|
if corr.abs() > best_corr.abs() {
|
||||||
|
best_corr = corr;
|
||||||
|
best_lag = lag;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if best_corr.abs() > 0.1 {
|
||||||
|
Some((best_corr, best_lag))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
182
src/pipeline/filters.rs
Normal file
182
src/pipeline/filters.rs
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
use crate::pipeline::{Filter, FilterResult};
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use chrono::Duration;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
pub struct LiquidityFilter {
|
||||||
|
min_volume_24h: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LiquidityFilter {
|
||||||
|
pub fn new(min_volume_24h: u64) -> Self {
|
||||||
|
Self { min_volume_24h }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter for LiquidityFilter {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"LiquidityFilter"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String> {
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates
|
||||||
|
.into_iter()
|
||||||
|
.partition(|c| c.volume_24h >= self.min_volume_24h);
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TimeToCloseFilter {
|
||||||
|
min_hours: i64,
|
||||||
|
max_hours: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TimeToCloseFilter {
|
||||||
|
pub fn new(min_hours: i64, max_hours: Option<i64>) -> Self {
|
||||||
|
Self { min_hours, max_hours }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter for TimeToCloseFilter {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"TimeToCloseFilter"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String> {
|
||||||
|
let min_duration = Duration::hours(self.min_hours);
|
||||||
|
let max_duration = self.max_hours.map(Duration::hours);
|
||||||
|
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
let ttc = c.time_to_close(context.timestamp);
|
||||||
|
let above_min = ttc >= min_duration;
|
||||||
|
let below_max = max_duration.map(|max| ttc <= max).unwrap_or(true);
|
||||||
|
above_min && below_max
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct AlreadyPositionedFilter {
|
||||||
|
max_position_per_market: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AlreadyPositionedFilter {
|
||||||
|
pub fn new(max_position_per_market: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
max_position_per_market,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter for AlreadyPositionedFilter {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"AlreadyPositionedFilter"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String> {
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
context
|
||||||
|
.portfolio
|
||||||
|
.get_position(&c.ticker)
|
||||||
|
.map(|p| p.quantity < self.max_position_per_market)
|
||||||
|
.unwrap_or(true)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct CategoryFilter {
|
||||||
|
whitelist: Option<HashSet<String>>,
|
||||||
|
blacklist: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CategoryFilter {
|
||||||
|
pub fn whitelist(categories: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
whitelist: Some(categories.into_iter().collect()),
|
||||||
|
blacklist: HashSet::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn blacklist(categories: Vec<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
whitelist: None,
|
||||||
|
blacklist: categories.into_iter().collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter for CategoryFilter {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"CategoryFilter"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String> {
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
let in_whitelist = self
|
||||||
|
.whitelist
|
||||||
|
.as_ref()
|
||||||
|
.map(|w| w.contains(&c.category))
|
||||||
|
.unwrap_or(true);
|
||||||
|
let not_blacklisted = !self.blacklist.contains(&c.category);
|
||||||
|
in_whitelist && not_blacklisted
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PriceRangeFilter {
|
||||||
|
min_price: f64,
|
||||||
|
max_price: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PriceRangeFilter {
|
||||||
|
pub fn new(min_price: f64, max_price: f64) -> Self {
|
||||||
|
Self { min_price, max_price }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Filter for PriceRangeFilter {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"PriceRangeFilter"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String> {
|
||||||
|
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
|
||||||
|
let price = c.current_yes_price.to_string().parse::<f64>().unwrap_or(0.5);
|
||||||
|
price >= self.min_price && price <= self.max_price
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(FilterResult { kept, removed })
|
||||||
|
}
|
||||||
|
}
|
||||||
272
src/pipeline/ml_scorer.rs
Normal file
272
src/pipeline/ml_scorer.rs
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
//! ML-based scorer using ONNX models
|
||||||
|
//!
|
||||||
|
//! This module provides ML-based scoring using pre-trained ONNX models.
|
||||||
|
//! Models are trained separately using the Python scripts in scripts/
|
||||||
|
//!
|
||||||
|
//! Enable with the `ml` feature:
|
||||||
|
//! ```toml
|
||||||
|
//! kalshi-backtest = { features = ["ml"] }
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use crate::pipeline::Scorer;
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
use {
|
||||||
|
ndarray::{Array1, Array2},
|
||||||
|
ort::{session::Session, value::Value},
|
||||||
|
std::sync::Arc,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// ML ensemble scorer that combines multiple ONNX models
|
||||||
|
///
|
||||||
|
/// Models:
|
||||||
|
/// - LSTM: sequence model for price history
|
||||||
|
/// - MLP: feedforward on hand-crafted features
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
pub struct MLEnsembleScorer {
|
||||||
|
lstm_session: Option<Arc<Session>>,
|
||||||
|
mlp_session: Option<Arc<Session>>,
|
||||||
|
ensemble_weights: Vec<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
impl MLEnsembleScorer {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
lstm_session: None,
|
||||||
|
mlp_session: None,
|
||||||
|
ensemble_weights: vec![0.5, 0.5],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_models(model_dir: &Path) -> Result<Self, String> {
|
||||||
|
let lstm_path = model_dir.join("lstm.onnx");
|
||||||
|
let mlp_path = model_dir.join("mlp.onnx");
|
||||||
|
|
||||||
|
let lstm_session = if lstm_path.exists() {
|
||||||
|
Some(Arc::new(
|
||||||
|
Session::builder()
|
||||||
|
.map_err(|e| format!("failed to create session builder: {}", e))?
|
||||||
|
.commit_from_file(&lstm_path)
|
||||||
|
.map_err(|e| format!("failed to load LSTM model: {}", e))?,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let mlp_session = if mlp_path.exists() {
|
||||||
|
Some(Arc::new(
|
||||||
|
Session::builder()
|
||||||
|
.map_err(|e| format!("failed to create session builder: {}", e))?
|
||||||
|
.commit_from_file(&mlp_path)
|
||||||
|
.map_err(|e| format!("failed to load MLP model: {}", e))?,
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
lstm_session,
|
||||||
|
mlp_session,
|
||||||
|
ensemble_weights: vec![0.5, 0.5],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_weights(mut self, weights: Vec<f64>) -> Self {
|
||||||
|
self.ensemble_weights = weights;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_features(candidate: &MarketCandidate) -> Vec<f64> {
|
||||||
|
vec![
|
||||||
|
candidate.scores.get("momentum").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("mean_reversion").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("volume").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("time_decay").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("order_flow").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("bollinger_reversion").copied().unwrap_or(0.0),
|
||||||
|
candidate.scores.get("mtf_momentum").copied().unwrap_or(0.0),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_price_sequence(candidate: &MarketCandidate, max_len: usize) -> Vec<f64> {
|
||||||
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
|
||||||
|
let prices: Vec<f64> = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take(max_len)
|
||||||
|
.filter_map(|p| p.yes_price.to_f64())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut sequence = vec![0.0; max_len];
|
||||||
|
for (i, &price) in prices.iter().enumerate() {
|
||||||
|
if i < max_len {
|
||||||
|
sequence[max_len - 1 - i] = price;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prices.len() >= 2 {
|
||||||
|
let mut log_returns = Vec::with_capacity(max_len);
|
||||||
|
for i in 1..sequence.len() {
|
||||||
|
if sequence[i - 1] > 0.0 && sequence[i] > 0.0 {
|
||||||
|
log_returns.push((sequence[i] / sequence[i - 1]).ln());
|
||||||
|
} else {
|
||||||
|
log_returns.push(0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log_returns.insert(0, 0.0);
|
||||||
|
log_returns
|
||||||
|
} else {
|
||||||
|
sequence
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_lstm(&self, sequence: &[f64]) -> f64 {
|
||||||
|
let Some(session) = &self.lstm_session else {
|
||||||
|
return 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = Array2::from_shape_vec((1, sequence.len()), sequence.to_vec())
|
||||||
|
.expect("invalid shape");
|
||||||
|
|
||||||
|
match session.run(ort::inputs!["input" => input.view()]) {
|
||||||
|
Ok(outputs) => {
|
||||||
|
if let Some(output) = outputs.get("output") {
|
||||||
|
if let Ok(tensor) = output.try_extract_tensor::<f32>() {
|
||||||
|
return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
Err(_) => 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict_mlp(&self, features: &[f64]) -> f64 {
|
||||||
|
let Some(session) = &self.mlp_session else {
|
||||||
|
return 0.0;
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = Array1::from_vec(features.to_vec());
|
||||||
|
let input_2d = input.insert_axis(ndarray::Axis(0));
|
||||||
|
|
||||||
|
match session.run(ort::inputs!["input" => input_2d.view()]) {
|
||||||
|
Ok(outputs) => {
|
||||||
|
if let Some(output) = outputs.get("output") {
|
||||||
|
if let Ok(tensor) = output.try_extract_tensor::<f32>() {
|
||||||
|
return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
Err(_) => 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn predict(&self, candidate: &MarketCandidate) -> f64 {
|
||||||
|
let sequence = Self::extract_price_sequence(candidate, 24);
|
||||||
|
let features = Self::extract_features(candidate);
|
||||||
|
|
||||||
|
let lstm_pred = self.predict_lstm(&sequence);
|
||||||
|
let mlp_pred = self.predict_mlp(&features);
|
||||||
|
|
||||||
|
let mut ensemble = 0.0;
|
||||||
|
let mut total_weight = 0.0;
|
||||||
|
|
||||||
|
if self.lstm_session.is_some() && self.ensemble_weights.len() > 0 {
|
||||||
|
ensemble += lstm_pred * self.ensemble_weights[0];
|
||||||
|
total_weight += self.ensemble_weights[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.mlp_session.is_some() && self.ensemble_weights.len() > 1 {
|
||||||
|
ensemble += mlp_pred * self.ensemble_weights[1];
|
||||||
|
total_weight += self.ensemble_weights[1];
|
||||||
|
}
|
||||||
|
|
||||||
|
if total_weight > 0.0 {
|
||||||
|
ensemble / total_weight
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "ml")]
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MLEnsembleScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MLEnsembleScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let ml_score = self.predict(c);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("ml_ensemble".to_string(), ml_score);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("ml_ensemble") {
|
||||||
|
candidate.scores.insert("ml_ensemble".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// stub scorer when ML feature is disabled
|
||||||
|
#[cfg(not(feature = "ml"))]
|
||||||
|
pub struct MLEnsembleScorer;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "ml"))]
|
||||||
|
impl MLEnsembleScorer {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_models(_model_dir: &Path) -> Result<Self, String> {
|
||||||
|
Ok(Self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_weights(self, _weights: Vec<f64>) -> Self {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "ml"))]
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MLEnsembleScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MLEnsembleScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
Ok(candidates.iter().map(|c| MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
}).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, _candidate: &mut MarketCandidate, _scored: MarketCandidate) {}
|
||||||
|
}
|
||||||
230
src/pipeline/mod.rs
Normal file
230
src/pipeline/mod.rs
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
mod correlation_scorer;
|
||||||
|
mod filters;
|
||||||
|
mod ml_scorer;
|
||||||
|
mod scorers;
|
||||||
|
mod selector;
|
||||||
|
mod sources;
|
||||||
|
|
||||||
|
pub use correlation_scorer::*;
|
||||||
|
pub use filters::*;
|
||||||
|
pub use ml_scorer::*;
|
||||||
|
pub use scorers::*;
|
||||||
|
pub use selector::*;
|
||||||
|
pub use sources::*;
|
||||||
|
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{error, info};
|
||||||
|
|
||||||
|
pub struct PipelineResult {
|
||||||
|
pub retrieved_candidates: Vec<MarketCandidate>,
|
||||||
|
pub filtered_candidates: Vec<MarketCandidate>,
|
||||||
|
pub selected_candidates: Vec<MarketCandidate>,
|
||||||
|
pub context: Arc<TradingContext>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct FilterResult {
|
||||||
|
pub kept: Vec<MarketCandidate>,
|
||||||
|
pub removed: Vec<MarketCandidate>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Source: Send + Sync {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
fn enable(&self, _context: &TradingContext) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Filter: Send + Sync {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
fn enable(&self, _context: &TradingContext) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: Vec<MarketCandidate>,
|
||||||
|
) -> Result<FilterResult, String>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Scorer: Send + Sync {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
fn enable(&self, _context: &TradingContext) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String>;
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate);
|
||||||
|
|
||||||
|
fn update_all(&self, candidates: &mut [MarketCandidate], scored: Vec<MarketCandidate>) {
|
||||||
|
for (c, s) in candidates.iter_mut().zip(scored) {
|
||||||
|
self.update(c, s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Selector: Send + Sync {
|
||||||
|
fn name(&self) -> &'static str;
|
||||||
|
fn enable(&self, _context: &TradingContext) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TradingPipeline {
|
||||||
|
sources: Vec<Box<dyn Source>>,
|
||||||
|
filters: Vec<Box<dyn Filter>>,
|
||||||
|
scorers: Vec<Box<dyn Scorer>>,
|
||||||
|
selector: Box<dyn Selector>,
|
||||||
|
result_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TradingPipeline {
|
||||||
|
pub fn new(
|
||||||
|
sources: Vec<Box<dyn Source>>,
|
||||||
|
filters: Vec<Box<dyn Filter>>,
|
||||||
|
scorers: Vec<Box<dyn Scorer>>,
|
||||||
|
selector: Box<dyn Selector>,
|
||||||
|
result_size: usize,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
sources,
|
||||||
|
filters,
|
||||||
|
scorers,
|
||||||
|
selector,
|
||||||
|
result_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn execute(&self, context: TradingContext) -> PipelineResult {
|
||||||
|
let request_id = context.request_id().to_string();
|
||||||
|
|
||||||
|
let candidates = self.fetch_candidates(&context).await;
|
||||||
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
candidates = candidates.len(),
|
||||||
|
"fetched candidates"
|
||||||
|
);
|
||||||
|
|
||||||
|
let (kept, filtered) = self.filter(&context, candidates.clone()).await;
|
||||||
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
kept = kept.len(),
|
||||||
|
filtered = filtered.len(),
|
||||||
|
"filtered candidates"
|
||||||
|
);
|
||||||
|
|
||||||
|
let scored = self.score(&context, kept).await;
|
||||||
|
|
||||||
|
let mut selected = self.select(&context, scored);
|
||||||
|
selected.truncate(self.result_size);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
request_id = %request_id,
|
||||||
|
selected = selected.len(),
|
||||||
|
"selected candidates"
|
||||||
|
);
|
||||||
|
|
||||||
|
PipelineResult {
|
||||||
|
retrieved_candidates: candidates,
|
||||||
|
filtered_candidates: filtered,
|
||||||
|
selected_candidates: selected,
|
||||||
|
context: Arc::new(context),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_candidates(&self, context: &TradingContext) -> Vec<MarketCandidate> {
|
||||||
|
let mut all_candidates = Vec::new();
|
||||||
|
|
||||||
|
for source in self.sources.iter().filter(|s| s.enable(context)) {
|
||||||
|
match source.get_candidates(context).await {
|
||||||
|
Ok(mut candidates) => {
|
||||||
|
info!(
|
||||||
|
source = source.name(),
|
||||||
|
count = candidates.len(),
|
||||||
|
"source returned candidates"
|
||||||
|
);
|
||||||
|
all_candidates.append(&mut candidates);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(source = source.name(), error = %e, "source failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn filter(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
mut candidates: Vec<MarketCandidate>,
|
||||||
|
) -> (Vec<MarketCandidate>, Vec<MarketCandidate>) {
|
||||||
|
let mut all_removed = Vec::new();
|
||||||
|
|
||||||
|
for filter in self.filters.iter().filter(|f| f.enable(context)) {
|
||||||
|
let backup = candidates.clone();
|
||||||
|
match filter.filter(context, candidates).await {
|
||||||
|
Ok(result) => {
|
||||||
|
info!(
|
||||||
|
filter = filter.name(),
|
||||||
|
kept = result.kept.len(),
|
||||||
|
removed = result.removed.len(),
|
||||||
|
"filter applied"
|
||||||
|
);
|
||||||
|
candidates = result.kept;
|
||||||
|
all_removed.extend(result.removed);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(filter = filter.name(), error = %e, "filter failed");
|
||||||
|
candidates = backup;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(candidates, all_removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(&self, context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||||
|
let expected_len = candidates.len();
|
||||||
|
|
||||||
|
for scorer in self.scorers.iter().filter(|s| s.enable(context)) {
|
||||||
|
match scorer.score(context, &candidates).await {
|
||||||
|
Ok(scored) => {
|
||||||
|
if scored.len() == expected_len {
|
||||||
|
scorer.update_all(&mut candidates, scored);
|
||||||
|
} else {
|
||||||
|
error!(
|
||||||
|
scorer = scorer.name(),
|
||||||
|
expected = expected_len,
|
||||||
|
got = scored.len(),
|
||||||
|
"scorer returned wrong number of candidates"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!(scorer = scorer.name(), error = %e, "scorer failed");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||||
|
if self.selector.enable(context) {
|
||||||
|
self.selector.select(context, candidates)
|
||||||
|
} else {
|
||||||
|
candidates
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
978
src/pipeline/scorers.rs
Normal file
978
src/pipeline/scorers.rs
Normal file
@ -0,0 +1,978 @@
|
|||||||
|
use crate::pipeline::Scorer;
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use rust_decimal::prelude::ToPrimitive;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
/// rolling statistics for z-score normalization
|
||||||
|
/// tracks mean and std deviation over a sliding window
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RollingStats {
|
||||||
|
values: Vec<f64>,
|
||||||
|
max_size: usize,
|
||||||
|
sum: f64,
|
||||||
|
sum_sq: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RollingStats {
|
||||||
|
pub fn new(max_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
values: Vec::with_capacity(max_size),
|
||||||
|
max_size,
|
||||||
|
sum: 0.0,
|
||||||
|
sum_sq: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, value: f64) {
|
||||||
|
if !value.is_finite() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.values.len() >= self.max_size {
|
||||||
|
let old = self.values.remove(0);
|
||||||
|
self.sum -= old;
|
||||||
|
self.sum_sq -= old * old;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.values.push(value);
|
||||||
|
self.sum += value;
|
||||||
|
self.sum_sq += value * value;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push_batch(&mut self, values: &[f64]) {
|
||||||
|
for &v in values {
|
||||||
|
self.push(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mean(&self) -> f64 {
|
||||||
|
if self.values.is_empty() {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
self.sum / self.values.len() as f64
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn std(&self) -> f64 {
|
||||||
|
let n = self.values.len();
|
||||||
|
if n < 2 {
|
||||||
|
return 1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mean = self.mean();
|
||||||
|
let variance = (self.sum_sq / n as f64) - (mean * mean);
|
||||||
|
variance.max(0.0).sqrt()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize(&self, value: f64) -> f64 {
|
||||||
|
let std = self.std().max(0.001);
|
||||||
|
(value - self.mean()) / std
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.values.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_ready(&self) -> bool {
|
||||||
|
self.values.len() >= self.max_size / 4
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// wrapper that normalizes any scorer's output to z-scores
|
||||||
|
pub struct NormalizedScorer<S> {
|
||||||
|
inner: S,
|
||||||
|
score_key: String,
|
||||||
|
stats: Arc<Mutex<RollingStats>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S> NormalizedScorer<S> {
|
||||||
|
pub fn new(inner: S, score_key: &str, history_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
score_key: score_key.to_string(),
|
||||||
|
stats: Arc::new(Mutex::new(RollingStats::new(history_size))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl<S: Scorer + Send + Sync> Scorer for NormalizedScorer<S> {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
self.inner.name()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let raw_scored = self.inner.score(context, candidates).await?;
|
||||||
|
|
||||||
|
let raw_scores: Vec<f64> = raw_scored
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| c.scores.get(&self.score_key).copied())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut stats = self.stats.lock().unwrap();
|
||||||
|
stats.push_batch(&raw_scores);
|
||||||
|
}
|
||||||
|
|
||||||
|
let stats = self.stats.lock().unwrap();
|
||||||
|
let normalized = raw_scored
|
||||||
|
.into_iter()
|
||||||
|
.map(|mut c| {
|
||||||
|
if let Some(&raw) = c.scores.get(&self.score_key) {
|
||||||
|
let z = if stats.is_ready() {
|
||||||
|
stats.normalize(raw)
|
||||||
|
} else {
|
||||||
|
raw
|
||||||
|
};
|
||||||
|
c.scores.insert(self.score_key.clone(), z);
|
||||||
|
}
|
||||||
|
c
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(normalized)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
self.inner.update(candidate, scored);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MomentumScorer {
|
||||||
|
lookback_hours: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MomentumScorer {
|
||||||
|
pub fn new(lookback_hours: i64) -> Self {
|
||||||
|
Self { lookback_hours }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_momentum(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||||
|
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||||
|
let relevant_history: Vec<_> = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.filter(|p| p.timestamp >= lookback_start)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if relevant_history.len() < 2 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
|
||||||
|
last - first
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MomentumScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MomentumScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let momentum = Self::calculate_momentum(c, context.timestamp, self.lookback_hours);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("momentum".to_string(), momentum);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("momentum") {
|
||||||
|
candidate.scores.insert("momentum".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// multi-timeframe momentum scorer
|
||||||
|
/// looks at multiple windows and detects divergence between short and long term
|
||||||
|
pub struct MultiTimeframeMomentumScorer {
|
||||||
|
windows: Vec<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiTimeframeMomentumScorer {
|
||||||
|
pub fn new(windows: Vec<i64>) -> Self {
|
||||||
|
Self { windows }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_windows() -> Self {
|
||||||
|
Self::new(vec![1, 4, 12, 24])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_momentum_for_window(
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
now: chrono::DateTime<chrono::Utc>,
|
||||||
|
hours: i64,
|
||||||
|
) -> f64 {
|
||||||
|
let lookback_start = now - chrono::Duration::hours(hours);
|
||||||
|
let relevant_history: Vec<_> = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.filter(|p| p.timestamp >= lookback_start)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if relevant_history.len() < 2 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
|
||||||
|
last - first
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_score(
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
now: chrono::DateTime<chrono::Utc>,
|
||||||
|
windows: &[i64],
|
||||||
|
) -> (f64, f64, f64) {
|
||||||
|
let momentums: Vec<f64> = windows
|
||||||
|
.iter()
|
||||||
|
.map(|&w| Self::calculate_momentum_for_window(candidate, now, w))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if momentums.is_empty() {
|
||||||
|
return (0.0, 0.0, 1.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let avg_momentum = momentums.iter().sum::<f64>() / momentums.len() as f64;
|
||||||
|
|
||||||
|
let signs: Vec<i32> = momentums.iter().map(|&m| if m > 0.0 { 1 } else if m < 0.0 { -1 } else { 0 }).collect();
|
||||||
|
let all_same_sign = signs.iter().all(|&s| s == signs[0]) && signs[0] != 0;
|
||||||
|
let alignment = if all_same_sign { 1.0 } else { 0.5 };
|
||||||
|
|
||||||
|
let short_avg = if momentums.len() >= 2 {
|
||||||
|
momentums[..momentums.len() / 2].iter().sum::<f64>() / (momentums.len() / 2) as f64
|
||||||
|
} else {
|
||||||
|
momentums[0]
|
||||||
|
};
|
||||||
|
let long_avg = if momentums.len() >= 2 {
|
||||||
|
momentums[momentums.len() / 2..].iter().sum::<f64>() / (momentums.len() - momentums.len() / 2) as f64
|
||||||
|
} else {
|
||||||
|
momentums[0]
|
||||||
|
};
|
||||||
|
|
||||||
|
let divergence = if (short_avg > 0.0 && long_avg < 0.0) || (short_avg < 0.0 && long_avg > 0.0) {
|
||||||
|
(short_avg - long_avg).abs()
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
};
|
||||||
|
|
||||||
|
(avg_momentum * alignment, divergence, alignment)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MultiTimeframeMomentumScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MultiTimeframeMomentumScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let (momentum, divergence, alignment) = Self::calculate_score(c, context.timestamp, &self.windows);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("mtf_momentum".to_string(), momentum);
|
||||||
|
scored.scores.insert("mtf_divergence".to_string(), divergence);
|
||||||
|
scored.scores.insert("mtf_alignment".to_string(), alignment);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
for key in ["mtf_momentum", "mtf_divergence", "mtf_alignment"] {
|
||||||
|
if let Some(score) = scored.scores.get(key) {
|
||||||
|
candidate.scores.insert(key.to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct MeanReversionScorer {
|
||||||
|
lookback_hours: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MeanReversionScorer {
|
||||||
|
pub fn new(lookback_hours: i64) -> Self {
|
||||||
|
Self { lookback_hours }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_deviation(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||||
|
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||||
|
let prices: Vec<f64> = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.filter(|p| p.timestamp >= lookback_start)
|
||||||
|
.filter_map(|p| p.yes_price.to_f64())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if prices.is_empty() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
|
||||||
|
let current = candidate.current_yes_price.to_f64().unwrap_or(0.5);
|
||||||
|
let deviation = current - mean;
|
||||||
|
|
||||||
|
-deviation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for MeanReversionScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"MeanReversionScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let reversion = Self::calculate_deviation(c, context.timestamp, self.lookback_hours);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("mean_reversion".to_string(), reversion);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("mean_reversion") {
|
||||||
|
candidate.scores.insert("mean_reversion".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// bollinger bands mean reversion scorer
|
||||||
|
/// triggers when price touches statistical extremes (upper/lower bands)
|
||||||
|
pub struct BollingerMeanReversionScorer {
|
||||||
|
lookback_hours: i64,
|
||||||
|
num_std: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BollingerMeanReversionScorer {
|
||||||
|
pub fn new(lookback_hours: i64, num_std: f64) -> Self {
|
||||||
|
Self { lookback_hours, num_std }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_config() -> Self {
|
||||||
|
Self::new(24, 2.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_bands(
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
now: chrono::DateTime<chrono::Utc>,
|
||||||
|
lookback_hours: i64,
|
||||||
|
) -> Option<(f64, f64, f64)> {
|
||||||
|
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||||
|
let prices: Vec<f64> = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.filter(|p| p.timestamp >= lookback_start)
|
||||||
|
.filter_map(|p| p.yes_price.to_f64())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if prices.len() < 5 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
|
||||||
|
let variance: f64 = prices.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / prices.len() as f64;
|
||||||
|
let std = variance.sqrt();
|
||||||
|
|
||||||
|
Some((mean, std, *prices.last().unwrap_or(&mean)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_score(
|
||||||
|
candidate: &MarketCandidate,
|
||||||
|
now: chrono::DateTime<chrono::Utc>,
|
||||||
|
lookback_hours: i64,
|
||||||
|
num_std: f64,
|
||||||
|
) -> (f64, f64) {
|
||||||
|
let (mean, std, current) = match Self::calculate_bands(candidate, now, lookback_hours) {
|
||||||
|
Some(v) => v,
|
||||||
|
None => return (0.0, 0.0),
|
||||||
|
};
|
||||||
|
|
||||||
|
let upper_band = mean + num_std * std;
|
||||||
|
let lower_band = mean - num_std * std;
|
||||||
|
let band_width = upper_band - lower_band;
|
||||||
|
|
||||||
|
if band_width < 0.001 {
|
||||||
|
return (0.0, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let position = (current - lower_band) / band_width;
|
||||||
|
|
||||||
|
let score = if current >= upper_band {
|
||||||
|
-(current - upper_band) / std.max(0.001)
|
||||||
|
} else if current <= lower_band {
|
||||||
|
(lower_band - current) / std.max(0.001)
|
||||||
|
} else if current > mean {
|
||||||
|
-(position - 0.5) * 0.5
|
||||||
|
} else {
|
||||||
|
(0.5 - position) * 0.5
|
||||||
|
};
|
||||||
|
|
||||||
|
let band_position = (position * 2.0 - 1.0).clamp(-1.0, 1.0);
|
||||||
|
|
||||||
|
(score.clamp(-2.0, 2.0), band_position)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for BollingerMeanReversionScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"BollingerMeanReversionScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let (score, band_pos) = Self::calculate_score(c, context.timestamp, self.lookback_hours, self.num_std);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("bollinger_reversion".to_string(), score);
|
||||||
|
scored.scores.insert("bollinger_position".to_string(), band_pos);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
for key in ["bollinger_reversion", "bollinger_position"] {
|
||||||
|
if let Some(score) = scored.scores.get(key) {
|
||||||
|
candidate.scores.insert(key.to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VolumeScorer {
|
||||||
|
lookback_hours: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VolumeScorer {
|
||||||
|
pub fn new(lookback_hours: i64) -> Self {
|
||||||
|
Self { lookback_hours }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_volume_score(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
|
||||||
|
let lookback_start = now - chrono::Duration::hours(lookback_hours);
|
||||||
|
let recent_volume: u64 = candidate
|
||||||
|
.price_history
|
||||||
|
.iter()
|
||||||
|
.filter(|p| p.timestamp >= lookback_start)
|
||||||
|
.map(|p| p.volume)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
if candidate.total_volume == 0 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let avg_hourly_volume = candidate.total_volume as f64
|
||||||
|
/ ((now - candidate.open_time).num_hours().max(1) as f64);
|
||||||
|
let recent_hourly_volume = recent_volume as f64 / lookback_hours.max(1) as f64;
|
||||||
|
|
||||||
|
if avg_hourly_volume > 0.0 {
|
||||||
|
(recent_hourly_volume / avg_hourly_volume).ln().max(-2.0).min(2.0)
|
||||||
|
} else {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for VolumeScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"VolumeScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let volume = Self::calculate_volume_score(c, context.timestamp, self.lookback_hours);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("volume".to_string(), volume);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("volume") {
|
||||||
|
candidate.scores.insert("volume".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TimeDecayScorer;
|
||||||
|
|
||||||
|
impl TimeDecayScorer {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_time_decay(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>) -> f64 {
|
||||||
|
let ttc = candidate.time_to_close(now);
|
||||||
|
let hours_remaining = ttc.num_hours() as f64;
|
||||||
|
|
||||||
|
if hours_remaining <= 0.0 {
|
||||||
|
return -1.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let decay = 1.0 - (1.0 / (hours_remaining / 24.0 + 1.0));
|
||||||
|
decay.min(1.0).max(0.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for TimeDecayScorer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// order flow imbalance scorer
|
||||||
|
/// measures buying vs selling pressure using taker_side from trades
|
||||||
|
pub struct OrderFlowScorer;
|
||||||
|
|
||||||
|
impl OrderFlowScorer {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calculate_imbalance(candidate: &MarketCandidate) -> f64 {
|
||||||
|
let buy_vol = candidate.buy_volume_24h as f64;
|
||||||
|
let sell_vol = candidate.sell_volume_24h as f64;
|
||||||
|
let total = buy_vol + sell_vol;
|
||||||
|
|
||||||
|
if total == 0.0 {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
(buy_vol - sell_vol) / total
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for OrderFlowScorer {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for OrderFlowScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"OrderFlowScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let imbalance = Self::calculate_imbalance(c);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("order_flow".to_string(), imbalance);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("order_flow") {
|
||||||
|
candidate.scores.insert("order_flow".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for TimeDecayScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"TimeDecayScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let time_decay = Self::calculate_time_decay(c, context.timestamp);
|
||||||
|
let mut scored = MarketCandidate {
|
||||||
|
scores: c.scores.clone(),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
scored.scores.insert("time_decay".to_string(), time_decay);
|
||||||
|
scored
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
if let Some(score) = scored.scores.get("time_decay") {
|
||||||
|
candidate.scores.insert("time_decay".to_string(), *score);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct WeightedScorer {
|
||||||
|
weights: Vec<(String, f64)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WeightedScorer {
|
||||||
|
pub fn new(weights: Vec<(String, f64)>) -> Self {
|
||||||
|
Self { weights }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_weights() -> Self {
|
||||||
|
Self::new(vec![
|
||||||
|
("momentum".to_string(), 0.4),
|
||||||
|
("mean_reversion".to_string(), 0.3),
|
||||||
|
("volume".to_string(), 0.2),
|
||||||
|
("time_decay".to_string(), 0.1),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_weighted_score(&self, candidate: &MarketCandidate) -> f64 {
|
||||||
|
self.weights
|
||||||
|
.iter()
|
||||||
|
.map(|(name, weight)| {
|
||||||
|
candidate.scores.get(name).copied().unwrap_or(0.0) * weight
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for WeightedScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"WeightedScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let weighted_score = self.compute_weighted_score(c);
|
||||||
|
MarketCandidate {
|
||||||
|
final_score: weighted_score,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
candidate.final_score = scored.final_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct ScorerWeights {
|
||||||
|
pub momentum: f64,
|
||||||
|
pub mean_reversion: f64,
|
||||||
|
pub volume: f64,
|
||||||
|
pub time_decay: f64,
|
||||||
|
pub order_flow: f64,
|
||||||
|
pub bollinger: f64,
|
||||||
|
pub mtf_momentum: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ScorerWeights {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
momentum: 0.2,
|
||||||
|
mean_reversion: 0.2,
|
||||||
|
volume: 0.15,
|
||||||
|
time_decay: 0.1,
|
||||||
|
order_flow: 0.15,
|
||||||
|
bollinger: 0.1,
|
||||||
|
mtf_momentum: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScorerWeights {
|
||||||
|
pub fn politics() -> Self {
|
||||||
|
Self {
|
||||||
|
momentum: 0.35,
|
||||||
|
mean_reversion: 0.1,
|
||||||
|
volume: 0.1,
|
||||||
|
time_decay: 0.1,
|
||||||
|
order_flow: 0.15,
|
||||||
|
bollinger: 0.05,
|
||||||
|
mtf_momentum: 0.15,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn weather() -> Self {
|
||||||
|
Self {
|
||||||
|
momentum: 0.1,
|
||||||
|
mean_reversion: 0.35,
|
||||||
|
volume: 0.1,
|
||||||
|
time_decay: 0.15,
|
||||||
|
order_flow: 0.1,
|
||||||
|
bollinger: 0.15,
|
||||||
|
mtf_momentum: 0.05,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn sports() -> Self {
|
||||||
|
Self {
|
||||||
|
momentum: 0.2,
|
||||||
|
mean_reversion: 0.1,
|
||||||
|
volume: 0.15,
|
||||||
|
time_decay: 0.1,
|
||||||
|
order_flow: 0.3,
|
||||||
|
bollinger: 0.05,
|
||||||
|
mtf_momentum: 0.1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn economics() -> Self {
|
||||||
|
Self {
|
||||||
|
momentum: 0.25,
|
||||||
|
mean_reversion: 0.2,
|
||||||
|
volume: 0.15,
|
||||||
|
time_decay: 0.1,
|
||||||
|
order_flow: 0.15,
|
||||||
|
bollinger: 0.1,
|
||||||
|
mtf_momentum: 0.05,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
|
||||||
|
let get_score = |key: &str| candidate.scores.get(key).copied().unwrap_or(0.0);
|
||||||
|
|
||||||
|
self.momentum * get_score("momentum")
|
||||||
|
+ self.mean_reversion * get_score("mean_reversion")
|
||||||
|
+ self.volume * get_score("volume")
|
||||||
|
+ self.time_decay * get_score("time_decay")
|
||||||
|
+ self.order_flow * get_score("order_flow")
|
||||||
|
+ self.bollinger * get_score("bollinger_reversion")
|
||||||
|
+ self.mtf_momentum * get_score("mtf_momentum")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// category-aware weighted scorer
|
||||||
|
/// applies different weights based on market category
|
||||||
|
pub struct CategoryWeightedScorer {
|
||||||
|
category_weights: std::collections::HashMap<String, ScorerWeights>,
|
||||||
|
default_weights: ScorerWeights,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CategoryWeightedScorer {
|
||||||
|
pub fn new(
|
||||||
|
category_weights: std::collections::HashMap<String, ScorerWeights>,
|
||||||
|
default_weights: ScorerWeights,
|
||||||
|
) -> Self {
|
||||||
|
Self { category_weights, default_weights }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_defaults() -> Self {
|
||||||
|
let mut category_weights = std::collections::HashMap::new();
|
||||||
|
category_weights.insert("politics".to_string(), ScorerWeights::politics());
|
||||||
|
category_weights.insert("weather".to_string(), ScorerWeights::weather());
|
||||||
|
category_weights.insert("sports".to_string(), ScorerWeights::sports());
|
||||||
|
category_weights.insert("economics".to_string(), ScorerWeights::economics());
|
||||||
|
category_weights.insert("financial".to_string(), ScorerWeights::economics());
|
||||||
|
|
||||||
|
Self {
|
||||||
|
category_weights,
|
||||||
|
default_weights: ScorerWeights::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_weights(&self, category: &str) -> &ScorerWeights {
|
||||||
|
let lower = category.to_lowercase();
|
||||||
|
self.category_weights.get(&lower).unwrap_or(&self.default_weights)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for CategoryWeightedScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"CategoryWeightedScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let weights = self.get_weights(&c.category);
|
||||||
|
let weighted_score = weights.compute_score(c);
|
||||||
|
MarketCandidate {
|
||||||
|
final_score: weighted_score,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
candidate.final_score = scored.final_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// ensemble scorer that combines multiple models with dynamic weighting
|
||||||
|
/// weights can be updated based on recent accuracy
|
||||||
|
pub struct EnsembleScorer {
|
||||||
|
model_weights: std::sync::Arc<std::sync::Mutex<Vec<f64>>>,
|
||||||
|
model_keys: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EnsembleScorer {
|
||||||
|
pub fn new(model_keys: Vec<String>, initial_weights: Vec<f64>) -> Self {
|
||||||
|
assert_eq!(model_keys.len(), initial_weights.len());
|
||||||
|
Self {
|
||||||
|
model_weights: std::sync::Arc::new(std::sync::Mutex::new(initial_weights)),
|
||||||
|
model_keys,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default_ensemble() -> Self {
|
||||||
|
Self::new(
|
||||||
|
vec![
|
||||||
|
"momentum".to_string(),
|
||||||
|
"mean_reversion".to_string(),
|
||||||
|
"bollinger_reversion".to_string(),
|
||||||
|
"order_flow".to_string(),
|
||||||
|
"mtf_momentum".to_string(),
|
||||||
|
],
|
||||||
|
vec![0.25, 0.2, 0.2, 0.2, 0.15],
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_weights(&self, new_weights: Vec<f64>) {
|
||||||
|
let mut weights = self.model_weights.lock().unwrap();
|
||||||
|
*weights = new_weights;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
|
||||||
|
let weights = self.model_weights.lock().unwrap();
|
||||||
|
self.model_keys
|
||||||
|
.iter()
|
||||||
|
.zip(weights.iter())
|
||||||
|
.map(|(key, weight)| {
|
||||||
|
candidate.scores.get(key).copied().unwrap_or(0.0) * weight
|
||||||
|
})
|
||||||
|
.sum()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Scorer for EnsembleScorer {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"EnsembleScorer"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn score(
|
||||||
|
&self,
|
||||||
|
_context: &TradingContext,
|
||||||
|
candidates: &[MarketCandidate],
|
||||||
|
) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let scored = candidates
|
||||||
|
.iter()
|
||||||
|
.map(|c| {
|
||||||
|
let ensemble_score = self.compute_score(c);
|
||||||
|
MarketCandidate {
|
||||||
|
final_score: ensemble_score,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
|
||||||
|
candidate.final_score = scored.final_score;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
64
src/pipeline/selector.rs
Normal file
64
src/pipeline/selector.rs
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
use crate::pipeline::Selector;
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
|
||||||
|
pub struct TopKSelector {
|
||||||
|
k: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TopKSelector {
|
||||||
|
pub fn new(k: usize) -> Self {
|
||||||
|
Self { k }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Selector for TopKSelector {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"TopKSelector"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||||
|
candidates.sort_by(|a, b| {
|
||||||
|
b.final_score
|
||||||
|
.partial_cmp(&a.final_score)
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
candidates.truncate(self.k);
|
||||||
|
candidates
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ThresholdSelector {
|
||||||
|
min_score: f64,
|
||||||
|
max_candidates: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThresholdSelector {
|
||||||
|
pub fn new(min_score: f64, max_candidates: Option<usize>) -> Self {
|
||||||
|
Self {
|
||||||
|
min_score,
|
||||||
|
max_candidates,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Selector for ThresholdSelector {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"ThresholdSelector"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
|
||||||
|
candidates.retain(|c| c.final_score >= self.min_score);
|
||||||
|
candidates.sort_by(|a, b| {
|
||||||
|
b.final_score
|
||||||
|
.partial_cmp(&a.final_score)
|
||||||
|
.unwrap_or(std::cmp::Ordering::Equal)
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(max) = self.max_candidates {
|
||||||
|
candidates.truncate(max);
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates
|
||||||
|
}
|
||||||
|
}
|
||||||
69
src/pipeline/sources.rs
Normal file
69
src/pipeline/sources.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use crate::data::HistoricalData;
|
||||||
|
use crate::pipeline::Source;
|
||||||
|
use crate::types::{MarketCandidate, TradingContext};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
pub struct HistoricalMarketSource {
|
||||||
|
data: Arc<HistoricalData>,
|
||||||
|
lookback_hours: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HistoricalMarketSource {
|
||||||
|
pub fn new(data: Arc<HistoricalData>, lookback_hours: i64) -> Self {
|
||||||
|
Self { data, lookback_hours }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Source for HistoricalMarketSource {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"HistoricalMarketSource"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String> {
|
||||||
|
let now = context.timestamp;
|
||||||
|
let active_markets = self.data.get_active_markets(now);
|
||||||
|
|
||||||
|
let candidates: Vec<MarketCandidate> = active_markets
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|market| {
|
||||||
|
let current_price = self.data.get_current_price(&market.ticker, now)?;
|
||||||
|
let lookback_start = now - chrono::Duration::hours(self.lookback_hours);
|
||||||
|
let price_history = self.data.get_price_history(&market.ticker, lookback_start, now);
|
||||||
|
let volume_24h = self.data.get_volume_24h(&market.ticker, now);
|
||||||
|
|
||||||
|
let total_volume: u64 = self
|
||||||
|
.data
|
||||||
|
.get_trades_for_market(&market.ticker, market.open_time, now)
|
||||||
|
.iter()
|
||||||
|
.map(|t| t.volume)
|
||||||
|
.sum();
|
||||||
|
|
||||||
|
let (buy_volume_24h, sell_volume_24h) = self.data.get_order_flow_24h(&market.ticker, now);
|
||||||
|
|
||||||
|
Some(MarketCandidate {
|
||||||
|
ticker: market.ticker.clone(),
|
||||||
|
title: market.title.clone(),
|
||||||
|
category: market.category.clone(),
|
||||||
|
current_yes_price: current_price,
|
||||||
|
current_no_price: Decimal::ONE - current_price,
|
||||||
|
volume_24h,
|
||||||
|
total_volume,
|
||||||
|
buy_volume_24h,
|
||||||
|
sell_volume_24h,
|
||||||
|
open_time: market.open_time,
|
||||||
|
close_time: market.close_time,
|
||||||
|
result: market.result,
|
||||||
|
price_history,
|
||||||
|
scores: HashMap::new(),
|
||||||
|
final_score: 0.0,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
343
src/types.rs
Normal file
343
src/types.rs
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use rust_decimal::Decimal;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum Side {
|
||||||
|
Yes,
|
||||||
|
No,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Side {
|
||||||
|
pub fn opposite(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Side::Yes => Side::No,
|
||||||
|
Side::No => Side::Yes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum MarketResult {
|
||||||
|
Yes,
|
||||||
|
No,
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MarketCandidate {
|
||||||
|
pub ticker: String,
|
||||||
|
pub title: String,
|
||||||
|
pub category: String,
|
||||||
|
pub current_yes_price: Decimal,
|
||||||
|
pub current_no_price: Decimal,
|
||||||
|
pub volume_24h: u64,
|
||||||
|
pub total_volume: u64,
|
||||||
|
pub buy_volume_24h: u64,
|
||||||
|
pub sell_volume_24h: u64,
|
||||||
|
pub open_time: DateTime<Utc>,
|
||||||
|
pub close_time: DateTime<Utc>,
|
||||||
|
pub result: Option<MarketResult>,
|
||||||
|
pub price_history: Vec<PricePoint>,
|
||||||
|
|
||||||
|
pub scores: HashMap<String, f64>,
|
||||||
|
pub final_score: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MarketCandidate {
|
||||||
|
pub fn time_to_close(&self, now: DateTime<Utc>) -> chrono::Duration {
|
||||||
|
self.close_time - now
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_open(&self, now: DateTime<Utc>) -> bool {
|
||||||
|
now >= self.open_time && now < self.close_time
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MarketCandidate {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
ticker: String::new(),
|
||||||
|
title: String::new(),
|
||||||
|
category: String::new(),
|
||||||
|
current_yes_price: Decimal::ZERO,
|
||||||
|
current_no_price: Decimal::ZERO,
|
||||||
|
volume_24h: 0,
|
||||||
|
total_volume: 0,
|
||||||
|
buy_volume_24h: 0,
|
||||||
|
sell_volume_24h: 0,
|
||||||
|
open_time: Utc::now(),
|
||||||
|
close_time: Utc::now(),
|
||||||
|
result: None,
|
||||||
|
price_history: Vec::new(),
|
||||||
|
scores: HashMap::new(),
|
||||||
|
final_score: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PricePoint {
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub yes_price: Decimal,
|
||||||
|
pub volume: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TradingContext {
|
||||||
|
pub request_id: String,
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub portfolio: Portfolio,
|
||||||
|
pub trading_history: Vec<Trade>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TradingContext {
|
||||||
|
pub fn new(initial_capital: Decimal, start_time: DateTime<Utc>) -> Self {
|
||||||
|
Self {
|
||||||
|
request_id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
timestamp: start_time,
|
||||||
|
portfolio: Portfolio::new(initial_capital),
|
||||||
|
trading_history: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn request_id(&self) -> &str {
|
||||||
|
&self.request_id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Portfolio {
|
||||||
|
pub positions: HashMap<String, Position>,
|
||||||
|
pub cash: Decimal,
|
||||||
|
pub initial_capital: Decimal,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Portfolio {
|
||||||
|
pub fn new(initial_capital: Decimal) -> Self {
|
||||||
|
Self {
|
||||||
|
positions: HashMap::new(),
|
||||||
|
cash: initial_capital,
|
||||||
|
initial_capital,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn total_value(&self, market_prices: &HashMap<String, Decimal>) -> Decimal {
|
||||||
|
let position_value: Decimal = self
|
||||||
|
.positions
|
||||||
|
.values()
|
||||||
|
.map(|p| {
|
||||||
|
let price = market_prices.get(&p.ticker).copied().unwrap_or(p.avg_entry_price);
|
||||||
|
let effective_price = match p.side {
|
||||||
|
Side::Yes => price,
|
||||||
|
Side::No => Decimal::ONE - price,
|
||||||
|
};
|
||||||
|
effective_price * Decimal::from(p.quantity)
|
||||||
|
})
|
||||||
|
.sum();
|
||||||
|
self.cash + position_value
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn has_position(&self, ticker: &str) -> bool {
|
||||||
|
self.positions.contains_key(ticker)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_position(&self, ticker: &str) -> Option<&Position> {
|
||||||
|
self.positions.get(ticker)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn apply_fill(&mut self, fill: &Fill) {
|
||||||
|
let cost = fill.price * Decimal::from(fill.quantity);
|
||||||
|
|
||||||
|
match fill.side {
|
||||||
|
Side::Yes | Side::No => {
|
||||||
|
self.cash -= cost;
|
||||||
|
let position = self.positions.entry(fill.ticker.clone()).or_insert_with(|| {
|
||||||
|
Position {
|
||||||
|
ticker: fill.ticker.clone(),
|
||||||
|
side: fill.side,
|
||||||
|
quantity: 0,
|
||||||
|
avg_entry_price: Decimal::ZERO,
|
||||||
|
entry_time: fill.timestamp,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let total_cost =
|
||||||
|
position.avg_entry_price * Decimal::from(position.quantity) + cost;
|
||||||
|
position.quantity += fill.quantity;
|
||||||
|
if position.quantity > 0 {
|
||||||
|
position.avg_entry_price = total_cost / Decimal::from(position.quantity);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn resolve_position(&mut self, ticker: &str, result: MarketResult) -> Option<Decimal> {
|
||||||
|
let position = self.positions.remove(ticker)?;
|
||||||
|
|
||||||
|
let payout = match (result, position.side) {
|
||||||
|
(MarketResult::Yes, Side::Yes) | (MarketResult::No, Side::No) => {
|
||||||
|
Decimal::from(position.quantity)
|
||||||
|
}
|
||||||
|
(MarketResult::Cancelled, _) => {
|
||||||
|
position.avg_entry_price * Decimal::from(position.quantity)
|
||||||
|
}
|
||||||
|
_ => Decimal::ZERO,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.cash += payout;
|
||||||
|
|
||||||
|
let cost = position.avg_entry_price * Decimal::from(position.quantity);
|
||||||
|
Some(payout - cost)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn close_position(&mut self, ticker: &str, exit_price: Decimal) -> Option<Decimal> {
|
||||||
|
let position = self.positions.remove(ticker)?;
|
||||||
|
|
||||||
|
let effective_exit_price = match position.side {
|
||||||
|
Side::Yes => exit_price,
|
||||||
|
Side::No => Decimal::ONE - exit_price,
|
||||||
|
};
|
||||||
|
|
||||||
|
let exit_value = effective_exit_price * Decimal::from(position.quantity);
|
||||||
|
self.cash += exit_value;
|
||||||
|
|
||||||
|
let cost = position.avg_entry_price * Decimal::from(position.quantity);
|
||||||
|
Some(exit_value - cost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Position {
|
||||||
|
pub ticker: String,
|
||||||
|
pub side: Side,
|
||||||
|
pub quantity: u64,
|
||||||
|
pub avg_entry_price: Decimal,
|
||||||
|
pub entry_time: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Position {
|
||||||
|
pub fn cost_basis(&self) -> Decimal {
|
||||||
|
self.avg_entry_price * Decimal::from(self.quantity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Trade {
|
||||||
|
pub ticker: String,
|
||||||
|
pub side: Side,
|
||||||
|
pub quantity: u64,
|
||||||
|
pub price: Decimal,
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub trade_type: TradeType,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub enum TradeType {
|
||||||
|
Open,
|
||||||
|
Close,
|
||||||
|
Resolution,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub enum ExitReason {
|
||||||
|
Resolution(MarketResult),
|
||||||
|
TakeProfit { pnl_pct: f64 },
|
||||||
|
StopLoss { pnl_pct: f64 },
|
||||||
|
TimeStop { hours_held: i64 },
|
||||||
|
ScoreReversal { new_score: f64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ExitSignal {
|
||||||
|
pub ticker: String,
|
||||||
|
pub reason: ExitReason,
|
||||||
|
pub current_price: Decimal,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ExitConfig {
|
||||||
|
pub take_profit_pct: f64,
|
||||||
|
pub stop_loss_pct: f64,
|
||||||
|
pub max_hold_hours: i64,
|
||||||
|
pub score_reversal_threshold: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ExitConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
take_profit_pct: 0.20,
|
||||||
|
stop_loss_pct: 0.15,
|
||||||
|
max_hold_hours: 72,
|
||||||
|
score_reversal_threshold: -0.3,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExitConfig {
|
||||||
|
pub fn conservative() -> Self {
|
||||||
|
Self {
|
||||||
|
take_profit_pct: 0.15,
|
||||||
|
stop_loss_pct: 0.10,
|
||||||
|
max_hold_hours: 48,
|
||||||
|
score_reversal_threshold: -0.2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn aggressive() -> Self {
|
||||||
|
Self {
|
||||||
|
take_profit_pct: 0.30,
|
||||||
|
stop_loss_pct: 0.20,
|
||||||
|
max_hold_hours: 120,
|
||||||
|
score_reversal_threshold: -0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Fill {
|
||||||
|
pub ticker: String,
|
||||||
|
pub side: Side,
|
||||||
|
pub quantity: u64,
|
||||||
|
pub price: Decimal,
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Signal {
|
||||||
|
pub ticker: String,
|
||||||
|
pub side: Side,
|
||||||
|
pub quantity: u64,
|
||||||
|
pub limit_price: Option<Decimal>,
|
||||||
|
pub reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MarketData {
|
||||||
|
pub ticker: String,
|
||||||
|
pub title: String,
|
||||||
|
pub category: String,
|
||||||
|
pub open_time: DateTime<Utc>,
|
||||||
|
pub close_time: DateTime<Utc>,
|
||||||
|
pub result: Option<MarketResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct TradeData {
|
||||||
|
pub timestamp: DateTime<Utc>,
|
||||||
|
pub ticker: String,
|
||||||
|
pub price: Decimal,
|
||||||
|
pub volume: u64,
|
||||||
|
pub taker_side: Side,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct BacktestConfig {
|
||||||
|
pub start_time: DateTime<Utc>,
|
||||||
|
pub end_time: DateTime<Utc>,
|
||||||
|
pub interval: chrono::Duration,
|
||||||
|
pub initial_capital: Decimal,
|
||||||
|
pub max_position_size: u64,
|
||||||
|
pub max_positions: usize,
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user