feat: initial commit with code quality refactoring

kalshi prediction market backtesting framework with:
- trading pipeline (sources, filters, scorers, selectors)
- position sizing with kelly criterion
- multiple scoring strategies (momentum, mean reversion, etc)
- random baseline for comparison

refactoring includes:
- extract shared resolve_closed_positions() function
- reduce RandomBaseline::run() nesting with helper functions
- move MarketCandidate Default impl to types.rs
- add explanatory comments to complex logic
Nicholai Vogel 2026-01-21 09:32:12 -07:00
commit 025322219c
21 changed files with 4762 additions and 0 deletions

5
.gitignore vendored Normal file

@@ -0,0 +1,5 @@
/target
/data/*.csv
/data/*.parquet
/results/*.json
Cargo.lock

30
Cargo.toml Normal file

@@ -0,0 +1,30 @@
[package]
name = "kalshi-backtest"
version = "0.1.0"
edition = "2021"
[dependencies]
tokio = { version = "1", features = ["full"] }
async-trait = "0.1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
csv = "1.3"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4", features = ["derive"] }
anyhow = "1"
thiserror = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1", features = ["v4"] }
rust_decimal = { version = "1", features = ["serde"] }
rust_decimal_macros = "1"
ort = { version = "2.0.0-rc.11", optional = true }
ndarray = { version = "0.16", optional = true }
[features]
default = []
ml = ["ort", "ndarray"]
[dev-dependencies]
tempfile = "3"

236
README.md Normal file

@@ -0,0 +1,236 @@
kalshi-backtest
===
quant-style backtesting framework for kalshi prediction markets, built around a candidate-pipeline architecture.
features
---
- **multi-timeframe momentum** - detects divergence between short and long-term trends
- **bollinger bands mean reversion** - signals when price touches statistical extremes
- **order flow analysis** - tracks buying vs. selling pressure via the `taker_side` field
- **kelly criterion position sizing** - dynamic sizing based on edge and win probability
- **exit signals** - take profit, stop loss, time stops, and score reversal triggers
- **category-aware weighting** - different strategies for politics, weather, sports, etc.
- **ensemble scoring** - combine multiple models with dynamic weighting
- **cross-market correlations** - lead-lag relationships between related markets
- **ML ensemble (optional)** - LSTM + MLP models via ONNX runtime
architecture
---
```
Historical Data (CSV)
|
v
+------------------+
| Backtest Loop | <- simulates time progression
+------------------+
|
v
+------------------+
| Candidate Pipeline |
+------------------+
| |
v v
Sources Filters -> Scorers -> Selector
|
v
+------------------+
| Trade Executor | <- kelly sizing, exit signals
+------------------+
|
v
+------------------+
| P&L Tracker | <- tracks positions, returns
+------------------+
|
v
Performance Metrics
```
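the flow above can be condensed into a single function (an illustrative simplification using closures; the actual crate defines `Source`, `Filter`, `Scorer`, and `Selector` traits in `src/pipeline`, and the stage names here follow the diagram):

```rust
/// Candidates flow through the pipeline: sources produce, filters prune,
/// scorers adjust scores, and a top-k selector picks the final set.
struct Candidate {
    ticker: String,
    score: f64,
}

fn run_pipeline(
    sources: &[Box<dyn Fn() -> Vec<Candidate>>],
    filters: &[Box<dyn Fn(&Candidate) -> bool>],
    scorers: &[Box<dyn Fn(&mut Candidate)>],
    select_k: usize,
) -> Vec<Candidate> {
    // gather candidates from every source
    let mut cands: Vec<Candidate> = sources.iter().flat_map(|s| s()).collect();
    // drop anything a filter rejects
    cands.retain(|c| filters.iter().all(|f| f(c)));
    // let each scorer adjust the running score
    for c in cands.iter_mut() {
        for s in scorers {
            s(c);
        }
    }
    // TopKSelector equivalent: best scores first, truncated to k
    cands.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
    cands.truncate(select_k);
    cands
}
```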
data format
---
fetch data from the kalshi API using the included script:
```bash
python scripts/fetch_kalshi_data.py
```
or download from https://www.deltabase.tech/
**markets.csv**:
```csv
ticker,title,category,open_time,close_time,result,status,yes_bid,yes_ask,volume,open_interest
PRES-2024-DEM,Will Democrats win?,politics,2024-01-01 00:00:00,2024-11-06 00:00:00,no,finalized,45,47,10000,5000
```
**trades.csv**:
```csv
timestamp,ticker,price,volume,taker_side
2024-01-05 12:00:00,PRES-2024-DEM,45,100,yes
2024-01-05 13:00:00,PRES-2024-DEM,46,50,no
```
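the `taker_side` column is what drives the order-flow signal: yes-taker volume counts as buying pressure, no-taker volume as selling. a minimal sketch of the imbalance computation (illustrative only; the real `OrderFlowScorer` lives in `src/pipeline`):

```rust
/// Imbalance in [-1, 1]: (buy_volume - sell_volume) / total_volume.
/// Each trade is (volume, taker_was_yes).
fn order_flow_imbalance(trades: &[(f64, bool)]) -> f64 {
    let total: f64 = trades.iter().map(|&(v, _)| v).sum();
    if total == 0.0 {
        return 0.0; // no trades: neutral signal
    }
    let buy: f64 = trades.iter().filter(|&&(_, yes)| yes).map(|&(v, _)| v).sum();
    let sell = total - buy;
    (buy - sell) / total
}
```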
usage
---
```bash
# build
cargo build --release
# run backtest with quant features
cargo run --release -- run \
--data-dir data \
--start 2024-01-01 \
--end 2024-06-01 \
--capital 10000 \
--max-position 500 \
--max-positions 10 \
--kelly-fraction 0.25 \
--max-position-pct 0.25 \
--take-profit 0.20 \
--stop-loss 0.15 \
--max-hold-hours 72 \
--compare-random
# view results
cargo run --release -- summary --results-file results/backtest_result.json
```
cli options
---
| option | default | description |
|--------|---------|-------------|
| --data-dir | data | directory with markets.csv and trades.csv |
| --start | required | backtest start date |
| --end | required | backtest end date |
| --capital | 10000 | initial capital |
| --max-position | 100 | max shares per position |
| --max-positions | 5 | max concurrent positions |
| --kelly-fraction | 0.25 | fraction of kelly criterion (0.1=conservative, 1.0=full) |
| --max-position-pct | 0.25 | max % of capital per position |
| --take-profit | 0.20 | take profit threshold (20% gain) |
| --stop-loss | 0.15 | stop loss threshold (15% loss) |
| --max-hold-hours | 72 | time stop in hours |
| --compare-random | false | compare vs random baseline |
scorers
---
**basic scorers**:
- `MomentumScorer` - price change over lookback period
- `MeanReversionScorer` - deviation from historical mean
- `VolumeScorer` - unusual volume detection
- `TimeDecayScorer` - prefer markets with more time to close
**quant scorers**:
- `MultiTimeframeMomentumScorer` - analyzes 1h, 4h, 12h, 24h windows, detects divergence
- `BollingerMeanReversionScorer` - triggers at upper/lower band touches (2 std)
- `OrderFlowScorer` - buy/sell imbalance from taker_side
- `CategoryWeightedScorer` - different weights per category
- `EnsembleScorer` - combines models with dynamic weights
- `CorrelationScorer` - cross-market lead-lag signals
**ml scorers** (requires `ml` feature):
- `MLEnsembleScorer` - LSTM + MLP via ONNX
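as an illustration of the mean-reversion family, a bollinger-style check reduces to a z-score of the latest price against a rolling window (a sketch under assumed window semantics; the actual `BollingerMeanReversionScorer` configuration may differ):

```rust
/// Returns Some(direction) when the latest price sits at or beyond the
/// k-sigma band of the preceding window; the direction opposes the move
/// (+1.0 after a drop below the lower band, -1.0 after a spike above).
fn bollinger_signal(prices: &[f64], k: f64) -> Option<f64> {
    let n = prices.len();
    if n < 2 {
        return None;
    }
    let (last, window) = (prices[n - 1], &prices[..n - 1]);
    let mean = window.iter().sum::<f64>() / window.len() as f64;
    let var = window.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / window.len() as f64;
    let std = var.sqrt();
    if std == 0.0 {
        return None; // flat window: no band to touch
    }
    let z = (last - mean) / std;
    if z.abs() >= k { Some(-z.signum()) } else { None }
}
```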
position sizing
---
uses the kelly criterion with a safety multiplier:
```
kelly = (odds * win_prob - (1 - win_prob)) / odds
safe_kelly = kelly * kelly_fraction
position = min(bankroll * safe_kelly, max_position_pct * bankroll)
```
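the same formula as a small function (a sketch in plain `f64`; the crate's sizing code uses `rust_decimal` and lives in `src/execution`):

```rust
/// Fractional Kelly with a hard cap: stake = min(bankroll * kelly * fraction,
/// bankroll * max_position_pct), floored at zero when there is no edge.
fn kelly_position(
    bankroll: f64,
    odds: f64, // net odds b in "b-to-1"
    win_prob: f64,
    kelly_fraction: f64,
    max_position_pct: f64,
) -> f64 {
    let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
    let safe_kelly = (kelly * kelly_fraction).max(0.0); // negative edge => no bet
    (bankroll * safe_kelly).min(bankroll * max_position_pct)
}
```

with the defaults (`--kelly-fraction 0.25`, `--max-position-pct 0.25`), a 60% win probability at even odds stakes 5% of bankroll.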
exit signals
---
positions can exit via:
1. **resolution** - market resolves yes/no
2. **take profit** - pnl exceeds threshold
3. **stop loss** - pnl below threshold
4. **time stop** - held too long (capital rotation)
5. **score reversal** - strategy flips bearish
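signals 2-5 reduce to a priority-ordered check per position (an illustrative condensation; resolution is handled separately, and the real executor's ordering and types may differ):

```rust
#[derive(Debug, PartialEq)]
enum ExitReason {
    TakeProfit,
    StopLoss,
    TimeStop,
    ScoreReversal,
}

/// First matching rule wins; `None` means hold.
fn check_exit(
    pnl_pct: f64, // unrealized return, e.g. 0.25 = +25%
    hours_held: f64,
    score: f64, // current strategy score for the held side
    take_profit: f64,
    stop_loss: f64,
    max_hold_hours: f64,
) -> Option<ExitReason> {
    if pnl_pct >= take_profit {
        Some(ExitReason::TakeProfit)
    } else if pnl_pct <= -stop_loss {
        Some(ExitReason::StopLoss)
    } else if hours_held >= max_hold_hours {
        Some(ExitReason::TimeStop)
    } else if score < 0.0 {
        Some(ExitReason::ScoreReversal)
    } else {
        None
    }
}
```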
ml training (optional)
---
train ML models using pytorch, then export to ONNX:
```bash
# install dependencies
pip install torch pandas numpy
# train models
python scripts/train_ml_models.py \
--data data/trades.csv \
--markets data/markets.csv \
--output models/ \
--epochs 50
# enable ml feature
cargo build --release --features ml
```
metrics
---
- total return ($ and %)
- sharpe ratio (annualized)
- max drawdown
- win rate
- average trade P&L
- average hold time
- trades per day
- return by category
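for instance, the annualized sharpe ratio follows the usual form: mean over standard deviation of per-interval returns, scaled by the square root of intervals per year (a sketch; the framework's exact annualization convention isn't shown here):

```rust
/// Annualized Sharpe from per-interval returns (risk-free rate assumed 0).
fn sharpe_annualized(returns: &[f64], periods_per_year: f64) -> f64 {
    let n = returns.len() as f64;
    let mean = returns.iter().sum::<f64>() / n;
    let var = returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / n;
    if var == 0.0 {
        return 0.0; // zero variance: ratio undefined, report 0
    }
    mean / var.sqrt() * periods_per_year.sqrt()
}
```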
extending
---
add custom scorers by implementing the `Scorer` trait:
```rust
use async_trait::async_trait;
pub struct MyScorer;
#[async_trait]
impl Scorer for MyScorer {
fn name(&self) -> &'static str {
"MyScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
        // compute scores, then return the (re-)scored candidates
        Ok(candidates.to_vec())
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("my_score") {
candidate.scores.insert("my_score".to_string(), *score);
}
}
}
```
then add to the pipeline in `backtest.rs`.

0
data/.gitkeep Normal file

1
data/fetch_state.json Normal file

@@ -0,0 +1 @@
{"markets_cursor": "CgsI-rDDywYQkKOiMRI5S1hNVkVTUE9SVFNNVUxUSUdBTUVFWFRFTkRFRC1TMjAyNTBDMDMzMDBBRkYyLTkxNTVFNjFERTk3", "markets_count": 25000, "trades_cursor": null, "trades_count": 0, "markets_done": false, "trades_done": false}

254
scripts/fetch_kalshi_data.py Executable file

@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
Fetch historical trade and market data from Kalshi's public API.
No authentication required for public endpoints.
Features:
- Incremental saves (writes batches to disk)
- Resume capability (tracks cursor position)
- Retry logic with exponential backoff
"""
import json
import csv
import time
import urllib.request
import urllib.error
from datetime import datetime
from pathlib import Path
BASE_URL = "https://api.elections.kalshi.com/trade-api/v2"
STATE_FILE = "fetch_state.json"
def fetch_json(url: str, max_retries: int = 5) -> dict:
"""Fetch JSON from URL with retries and exponential backoff."""
req = urllib.request.Request(url, headers={"Accept": "application/json"})
for attempt in range(max_retries):
try:
with urllib.request.urlopen(req, timeout=30) as resp:
return json.loads(resp.read().decode())
except (urllib.error.HTTPError, urllib.error.URLError) as e:
wait = 2 ** attempt
print(f" attempt {attempt + 1}/{max_retries} failed: {e}")
if attempt < max_retries - 1:
print(f" retrying in {wait}s...")
time.sleep(wait)
else:
raise
except Exception as e:
wait = 2 ** attempt
print(f" unexpected error: {e}")
if attempt < max_retries - 1:
print(f" retrying in {wait}s...")
time.sleep(wait)
else:
raise
def load_state(output_dir: Path) -> dict:
"""Load saved state for resuming."""
state_path = output_dir / STATE_FILE
if state_path.exists():
with open(state_path) as f:
return json.load(f)
return {"markets_cursor": None, "markets_count": 0,
"trades_cursor": None, "trades_count": 0,
"markets_done": False, "trades_done": False}
def save_state(output_dir: Path, state: dict):
"""Save state for resuming."""
state_path = output_dir / STATE_FILE
with open(state_path, "w") as f:
json.dump(state, f)
def append_markets_csv(markets: list, output_path: Path, write_header: bool):
"""Append markets to CSV."""
mode = "w" if write_header else "a"
with open(output_path, mode, newline="") as f:
writer = csv.writer(f)
if write_header:
writer.writerow(["ticker", "title", "category", "open_time",
"close_time", "result", "status", "yes_bid",
"yes_ask", "volume", "open_interest"])
for m in markets:
result = ""
if m.get("result") == "yes":
result = "yes"
elif m.get("result") == "no":
result = "no"
elif m.get("status") == "finalized" and m.get("result"):
result = m.get("result")
writer.writerow([
m.get("ticker", ""),
m.get("title", ""),
m.get("category", ""),
m.get("open_time", ""),
m.get("close_time", m.get("expiration_time", "")),
result,
m.get("status", ""),
m.get("yes_bid", ""),
m.get("yes_ask", ""),
m.get("volume", ""),
m.get("open_interest", ""),
])
def append_trades_csv(trades: list, output_path: Path, write_header: bool):
"""Append trades to CSV."""
mode = "w" if write_header else "a"
with open(output_path, mode, newline="") as f:
writer = csv.writer(f)
if write_header:
writer.writerow(["timestamp", "ticker", "price", "volume", "taker_side"])
for t in trades:
price = t.get("yes_price", t.get("price", 50))
taker_side = t.get("taker_side", "")
if not taker_side:
taker_side = "yes" if t.get("is_taker_side_yes", True) else "no"
writer.writerow([
t.get("created_time", t.get("ts", "")),
t.get("ticker", t.get("market_ticker", "")),
price,
t.get("count", t.get("volume", 1)),
taker_side,
])
def fetch_markets_incremental(output_dir: Path, state: dict) -> int:
"""Fetch markets incrementally with state tracking."""
output_path = output_dir / "markets.csv"
cursor = state["markets_cursor"]
total = state["markets_count"]
write_header = total == 0
print(f"Resuming from {total} markets...")
while True:
url = f"{BASE_URL}/markets?limit=1000"
if cursor:
url += f"&cursor={cursor}"
print(f"Fetching markets... ({total:,} so far)")
try:
data = fetch_json(url)
except Exception as e:
print(f"Error fetching markets: {e}")
print(f"Progress saved. Run again to resume from {total:,} markets.")
return total
batch = data.get("markets", [])
if batch:
append_markets_csv(batch, output_path, write_header)
write_header = False
total += len(batch)
cursor = data.get("cursor")
state["markets_cursor"] = cursor
state["markets_count"] = total
save_state(output_dir, state)
if not cursor:
state["markets_done"] = True
save_state(output_dir, state)
break
time.sleep(0.3)
return total
def fetch_trades_incremental(output_dir: Path, state: dict, limit: int) -> int:
"""Fetch trades incrementally with state tracking."""
output_path = output_dir / "trades.csv"
cursor = state["trades_cursor"]
total = state["trades_count"]
write_header = total == 0
print(f"Resuming from {total} trades...")
while total < limit:
url = f"{BASE_URL}/markets/trades?limit=1000"
if cursor:
url += f"&cursor={cursor}"
print(f"Fetching trades... ({total:,}/{limit:,})")
try:
data = fetch_json(url)
except Exception as e:
print(f"Error fetching trades: {e}")
print(f"Progress saved. Run again to resume from {total:,} trades.")
return total
batch = data.get("trades", [])
if not batch:
break
append_trades_csv(batch, output_path, write_header)
write_header = False
total += len(batch)
cursor = data.get("cursor")
state["trades_cursor"] = cursor
state["trades_count"] = total
save_state(output_dir, state)
if not cursor:
state["trades_done"] = True
save_state(output_dir, state)
break
time.sleep(0.3)
return total
def main():
output_dir = Path("data")
output_dir.mkdir(parents=True, exist_ok=True)
print("=" * 50)
print("Kalshi Data Fetcher (with resume)")
print("=" * 50)
state = load_state(output_dir)
# fetch markets
if not state["markets_done"]:
print("\n[1/2] Fetching markets...")
markets_count = fetch_markets_incremental(output_dir, state)
if state["markets_done"]:
print(f"Markets complete: {markets_count:,}")
else:
print(f"Markets paused at: {markets_count:,}")
return 1
else:
print(f"\n[1/2] Markets already complete: {state['markets_count']:,}")
# fetch trades
if not state["trades_done"]:
print("\n[2/2] Fetching trades...")
trades_count = fetch_trades_incremental(output_dir, state, limit=1000000)
if state["trades_done"]:
print(f"Trades complete: {trades_count:,}")
else:
print(f"Trades paused at: {trades_count:,}")
return 1
else:
print(f"\n[2/2] Trades already complete: {state['trades_count']:,}")
print("\n" + "=" * 50)
print("Done!")
print(f"Markets: {state['markets_count']:,}")
print(f"Trades: {state['trades_count']:,}")
print(f"Output: {output_dir}")
print("=" * 50)
# clear state for next run
(output_dir / STATE_FILE).unlink(missing_ok=True)
return 0
if __name__ == "__main__":
exit(main())

280
scripts/train_ml_models.py Normal file

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
"""
Train ML models for the kalshi backtest framework.
Models:
- LSTM: learns patterns from price history sequences
- MLP: learns optimal combination of hand-crafted features
Usage:
python scripts/train_ml_models.py --data data/trades.csv --output models/
"""
import argparse
import json
import numpy as np
import pandas as pd
from pathlib import Path
try:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
print("warning: pytorch not installed. run: pip install torch")
def parse_args():
parser = argparse.ArgumentParser(description="Train ML models for kalshi backtest")
parser.add_argument("--data", type=Path, default=Path("data/trades.csv"))
parser.add_argument("--markets", type=Path, default=Path("data/markets.csv"))
parser.add_argument("--output", type=Path, default=Path("models"))
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--seq-len", type=int, default=24)
parser.add_argument("--train-split", type=float, default=0.8)
return parser.parse_args()
class LSTMPredictor(nn.Module):
def __init__(self, input_size=1, hidden_size=128, num_layers=2, dropout=0.2):
super().__init__()
self.lstm = nn.LSTM(
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
batch_first=True,
)
self.fc = nn.Linear(hidden_size, 1)
self.tanh = nn.Tanh()
def forward(self, x):
lstm_out, _ = self.lstm(x)
last_output = lstm_out[:, -1, :]
out = self.fc(last_output)
return self.tanh(out)
class MLPPredictor(nn.Module):
def __init__(self, input_size=7, hidden_sizes=(64, 32)):
super().__init__()
layers = []
prev_size = input_size
for h in hidden_sizes:
layers.append(nn.Linear(prev_size, h))
layers.append(nn.ReLU())
layers.append(nn.Dropout(0.2))
prev_size = h
layers.append(nn.Linear(prev_size, 1))
layers.append(nn.Tanh())
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x)
def load_data(trades_path: Path, markets_path: Path, seq_len: int):
print(f"loading trades from {trades_path}...")
trades = pd.read_csv(trades_path)
trades["timestamp"] = pd.to_datetime(trades["timestamp"])
trades = trades.sort_values(["ticker", "timestamp"])
print(f"loading markets from {markets_path}...")
markets = pd.read_csv(markets_path)
markets["close_time"] = pd.to_datetime(markets["close_time"])
result_map = dict(zip(markets["ticker"], markets["result"]))
sequences = []
features = []
labels = []
for ticker, group in trades.groupby("ticker"):
result = result_map.get(ticker)
if result not in ["yes", "no"]:
continue
label = 1.0 if result == "yes" else -1.0
prices = group["price"].values / 100.0
volumes = group["volume"].values
taker_sides = (group["taker_side"] == "yes").astype(float).values
if len(prices) < seq_len:
continue
for i in range(seq_len, len(prices)):
seq = prices[i - seq_len : i]
log_returns = np.diff(np.log(np.clip(seq, 1e-6, 1.0)))
if len(log_returns) == seq_len - 1:
log_returns = np.pad(log_returns, (1, 0), mode="constant")
sequences.append(log_returns)
curr_price = prices[i - 1]
momentum = prices[i - 1] - prices[i - seq_len] if len(prices) > seq_len else 0
mean_price = np.mean(prices[i - seq_len : i])
mean_reversion = mean_price - curr_price
vol_sum = np.sum(volumes[i - seq_len : i])
buy_vol = np.sum(volumes[i - seq_len : i] * taker_sides[i - seq_len : i])
sell_vol = vol_sum - buy_vol
order_flow = (buy_vol - sell_vol) / max(vol_sum, 1)
feat = [
momentum,
mean_reversion,
np.log1p(vol_sum),
order_flow,
curr_price,
np.std(log_returns) if len(log_returns) > 1 else 0,
len(group) / 1000.0,
]
features.append(feat)
labels.append(label)
print(f"created {len(sequences)} training samples")
return np.array(sequences), np.array(features), np.array(labels)
def train_lstm(sequences, labels, args):
print("\n" + "=" * 50)
print("Training LSTM")
print("=" * 50)
n = len(sequences)
split = int(n * args.train_split)
X_train = torch.tensor(sequences[:split], dtype=torch.float32).unsqueeze(-1)
y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
X_test = torch.tensor(sequences[split:], dtype=torch.float32).unsqueeze(-1)
y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
model = LSTMPredictor(input_size=1, hidden_size=128, num_layers=2)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(args.epochs):
model.train()
total_loss = 0
for X_batch, y_batch in train_loader:
optimizer.zero_grad()
output = model(X_batch)
loss = criterion(output, y_batch)
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
model.eval()
with torch.no_grad():
train_pred = model(X_train)
test_pred = model(X_test)
train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
return model
def train_mlp(features, labels, args):
print("\n" + "=" * 50)
print("Training MLP")
print("=" * 50)
features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-8)
n = len(features)
split = int(n * args.train_split)
X_train = torch.tensor(features[:split], dtype=torch.float32)
y_train = torch.tensor(labels[:split], dtype=torch.float32).unsqueeze(-1)
X_test = torch.tensor(features[split:], dtype=torch.float32)
y_test = torch.tensor(labels[split:], dtype=torch.float32).unsqueeze(-1)
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
model = MLPPredictor(input_size=features.shape[1])
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(args.epochs):
model.train()
total_loss = 0
for X_batch, y_batch in train_loader:
optimizer.zero_grad()
output = model(X_batch)
loss = criterion(output, y_batch)
loss.backward()
optimizer.step()
total_loss += loss.item()
if (epoch + 1) % 10 == 0:
model.eval()
with torch.no_grad():
train_pred = model(X_train)
test_pred = model(X_test)
train_acc = ((train_pred > 0) == (y_train > 0)).float().mean()
test_acc = ((test_pred > 0) == (y_test > 0)).float().mean()
print(f"epoch {epoch + 1}/{args.epochs}: loss={total_loss/len(train_loader):.4f}, train_acc={train_acc:.3f}, test_acc={test_acc:.3f}")
return model
def export_onnx(model, output_path: Path, input_shape, input_name="input", output_name="output"):
model.eval()
dummy_input = torch.randn(*input_shape)
torch.onnx.export(
model,
dummy_input,
output_path,
input_names=[input_name],
output_names=[output_name],
dynamic_axes={
input_name: {0: "batch_size"},
output_name: {0: "batch_size"},
},
opset_version=14,
)
print(f"exported to {output_path}")
def main():
args = parse_args()
if not HAS_TORCH:
print("error: pytorch required for training. install with: pip install torch")
return 1
if not args.data.exists():
print(f"error: data file not found: {args.data}")
return 1
if not args.markets.exists():
print(f"error: markets file not found: {args.markets}")
return 1
args.output.mkdir(parents=True, exist_ok=True)
sequences, features, labels = load_data(args.data, args.markets, args.seq_len)
if len(sequences) < 100:
print(f"error: not enough training data ({len(sequences)} samples)")
return 1
lstm_model = train_lstm(sequences, labels, args)
export_onnx(lstm_model, args.output / "lstm.onnx", (1, args.seq_len, 1))
mlp_model = train_mlp(features, labels, args)
export_onnx(mlp_model, args.output / "mlp.onnx", (1, features.shape[1]))
print("\n" + "=" * 50)
print("Training complete!")
print(f"Models saved to: {args.output}")
print("=" * 50)
return 0
if __name__ == "__main__":
exit(main())

445
src/backtest.rs Normal file

@@ -0,0 +1,445 @@
use crate::data::HistoricalData;
use crate::execution::{Executor, PositionSizingConfig};
use crate::metrics::{BacktestResult, MetricsCollector};
use crate::pipeline::{
AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, Filter,
HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer,
MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, Selector, Source, TimeDecayScorer,
TimeToCloseFilter, TopKSelector, TradingPipeline, VolumeScorer,
};
use crate::types::{
BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType,
TradingContext,
};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::info;
/// resolves any positions in markets that have closed
/// returns list of (ticker, result, pnl) for logging purposes
fn resolve_closed_positions(
portfolio: &mut Portfolio,
data: &HistoricalData,
resolved: &mut HashSet<String>,
at: DateTime<Utc>,
history: &mut Vec<Trade>,
metrics: &mut MetricsCollector,
) -> Vec<(String, MarketResult, Option<Decimal>)> {
let tickers: Vec<String> = portfolio.positions.keys().cloned().collect();
let mut resolutions = Vec::new();
for ticker in tickers {
if resolved.contains(&ticker) {
continue;
}
let Some(result) = data.get_resolution_at(&ticker, at) else {
continue;
};
resolved.insert(ticker.clone());
let Some(pos) = portfolio.positions.get(&ticker).cloned() else {
continue;
};
let pnl = portfolio.resolve_position(&ticker, result);
let exit_price = match result {
MarketResult::Yes => match pos.side {
Side::Yes => Decimal::ONE,
Side::No => Decimal::ZERO,
},
MarketResult::No => match pos.side {
Side::Yes => Decimal::ZERO,
Side::No => Decimal::ONE,
},
MarketResult::Cancelled => pos.avg_entry_price,
};
let category = data
.markets
.get(&ticker)
.map(|m| m.category.clone())
.unwrap_or_default();
let trade = Trade {
ticker: ticker.clone(),
side: pos.side,
quantity: pos.quantity,
price: exit_price,
timestamp: at,
trade_type: TradeType::Resolution,
};
history.push(trade.clone());
metrics.record_trade(&trade, &category);
resolutions.push((ticker, result, pnl));
}
resolutions
}
pub struct Backtester {
config: BacktestConfig,
data: Arc<HistoricalData>,
pipeline: TradingPipeline,
executor: Executor,
}
impl Backtester {
pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
let pipeline = Self::build_default_pipeline(data.clone(), &config);
let executor = Executor::new(data.clone(), 10, config.max_position_size);
Self {
config,
data,
pipeline,
executor,
}
}
pub fn with_configs(
config: BacktestConfig,
data: Arc<HistoricalData>,
sizing_config: PositionSizingConfig,
exit_config: ExitConfig,
) -> Self {
let pipeline = Self::build_default_pipeline(data.clone(), &config);
let executor = Executor::new(data.clone(), 10, config.max_position_size)
.with_sizing_config(sizing_config)
.with_exit_config(exit_config);
Self {
config,
data,
pipeline,
executor,
}
}
pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self {
self.pipeline = pipeline;
self
}
fn build_default_pipeline(data: Arc<HistoricalData>, config: &BacktestConfig) -> TradingPipeline {
let sources: Vec<Box<dyn Source>> = vec![
Box::new(HistoricalMarketSource::new(data, 24)),
];
let filters: Vec<Box<dyn Filter>> = vec![
Box::new(LiquidityFilter::new(100)),
Box::new(TimeToCloseFilter::new(2, Some(720))),
Box::new(AlreadyPositionedFilter::new(config.max_position_size)),
];
let scorers: Vec<Box<dyn Scorer>> = vec![
Box::new(MomentumScorer::new(6)),
Box::new(MultiTimeframeMomentumScorer::default_windows()),
Box::new(MeanReversionScorer::new(24)),
Box::new(BollingerMeanReversionScorer::default_config()),
Box::new(VolumeScorer::new(6)),
Box::new(OrderFlowScorer::new()),
Box::new(TimeDecayScorer::new()),
Box::new(CategoryWeightedScorer::with_defaults()),
];
let selector: Box<dyn Selector> = Box::new(TopKSelector::new(config.max_positions));
TradingPipeline::new(sources, filters, scorers, selector, config.max_positions)
}
pub async fn run(&self) -> BacktestResult {
let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
let mut metrics = MetricsCollector::new(self.config.initial_capital);
let mut resolved_markets: HashSet<String> = HashSet::new();
let mut current_time = self.config.start_time;
info!(
start = %self.config.start_time,
end = %self.config.end_time,
interval_hours = self.config.interval.num_hours(),
"starting backtest"
);
while current_time < self.config.end_time {
context.timestamp = current_time;
context.request_id = uuid::Uuid::new_v4().to_string();
let resolutions = resolve_closed_positions(
&mut context.portfolio,
&self.data,
&mut resolved_markets,
current_time,
&mut context.trading_history,
&mut metrics,
);
for (ticker, result, pnl) in resolutions {
info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
}
let result = self.pipeline.execute(context.clone()).await;
let candidate_scores: std::collections::HashMap<String, f64> = result
.selected_candidates
.iter()
.map(|c| (c.ticker.clone(), c.final_score))
.collect();
let exit_signals = self.executor.generate_exit_signals(&context, &candidate_scores);
for exit in exit_signals {
if let Some(position) = context.portfolio.positions.get(&exit.ticker).cloned() {
let pnl = context.portfolio.close_position(&exit.ticker, exit.current_price);
info!(
ticker = %exit.ticker,
reason = ?exit.reason,
pnl = ?pnl,
"exit triggered"
);
let category = self
.data
.markets
.get(&exit.ticker)
.map(|m| m.category.clone())
.unwrap_or_default();
let exit_price = match position.side {
crate::types::Side::Yes => exit.current_price,
crate::types::Side::No => Decimal::ONE - exit.current_price,
};
let trade = Trade {
ticker: exit.ticker.clone(),
side: position.side,
quantity: position.quantity,
price: exit_price,
timestamp: current_time,
trade_type: TradeType::Close,
};
context.trading_history.push(trade.clone());
metrics.record_trade(&trade, &category);
}
}
let signals = self.executor.generate_signals(&result.selected_candidates, &context);
for signal in signals {
if let Some(fill) = self.executor.execute_signal(&signal, &context) {
info!(
ticker = %fill.ticker,
side = ?fill.side,
quantity = fill.quantity,
price = %fill.price,
"executed trade"
);
context.portfolio.apply_fill(&fill);
let category = self
.data
.markets
.get(&fill.ticker)
.map(|m| m.category.clone())
.unwrap_or_default();
let trade = Trade {
ticker: fill.ticker.clone(),
side: fill.side,
quantity: fill.quantity,
price: fill.price,
timestamp: fill.timestamp,
trade_type: TradeType::Open,
};
context.trading_history.push(trade.clone());
metrics.record_trade(&trade, &category);
}
}
let market_prices = self.get_current_prices(current_time);
metrics.record(current_time, &context.portfolio, &market_prices);
current_time = current_time + self.config.interval;
}
let resolutions = resolve_closed_positions(
&mut context.portfolio,
&self.data,
&mut resolved_markets,
self.config.end_time,
&mut context.trading_history,
&mut metrics,
);
for (ticker, result, pnl) in resolutions {
info!(ticker = %ticker, result = ?result, pnl = ?pnl, "market resolved");
}
info!(
trades = context.trading_history.len(),
positions = context.portfolio.positions.len(),
cash = %context.portfolio.cash,
"backtest complete"
);
metrics.finalize()
}
fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
self.data
.markets
.keys()
.filter_map(|ticker| {
self.data
.get_current_price(ticker, at)
.map(|p| (ticker.clone(), p))
})
.collect()
}
}
pub struct RandomBaseline {
config: BacktestConfig,
data: Arc<HistoricalData>,
}
impl RandomBaseline {
pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
Self { config, data }
}
pub async fn run(&self) -> BacktestResult {
let mut context = TradingContext::new(self.config.initial_capital, self.config.start_time);
let mut metrics = MetricsCollector::new(self.config.initial_capital);
let mut resolved_markets: HashSet<String> = HashSet::new();
let mut rng_state: u64 = 42;
let mut current_time = self.config.start_time;
while current_time < self.config.end_time {
context.timestamp = current_time;
resolve_closed_positions(
&mut context.portfolio,
&self.data,
&mut resolved_markets,
current_time,
&mut context.trading_history,
&mut metrics,
);
if let Some(fill) = self.try_random_trade(&context, current_time, &mut rng_state) {
let category = self
.data
.markets
.get(&fill.ticker)
.map(|m| m.category.clone())
.unwrap_or_default();
context.portfolio.apply_fill(&fill);
let trade = Trade {
ticker: fill.ticker.clone(),
side: fill.side,
quantity: fill.quantity,
price: fill.price,
timestamp: current_time,
trade_type: TradeType::Open,
};
context.trading_history.push(trade.clone());
metrics.record_trade(&trade, &category);
}
let market_prices = self.get_current_prices(current_time);
metrics.record(current_time, &context.portfolio, &market_prices);
current_time = current_time + self.config.interval;
}
resolve_closed_positions(
&mut context.portfolio,
&self.data,
&mut resolved_markets,
self.config.end_time,
&mut context.trading_history,
&mut metrics,
);
metrics.finalize()
}
fn try_random_trade(
&self,
context: &TradingContext,
at: DateTime<Utc>,
rng_state: &mut u64,
) -> Option<Fill> {
if context.portfolio.positions.len() >= self.config.max_positions {
return None;
}
let active_markets = self.data.get_active_markets(at);
let unpositioned: Vec<_> = active_markets
.iter()
.filter(|m| !context.portfolio.has_position(&m.ticker))
.collect();
if unpositioned.is_empty() {
return None;
}
*rng_state = lcg_next(*rng_state);
let idx = (*rng_state as usize) % unpositioned.len();
let market = unpositioned[idx];
let price = self.data.get_current_price(&market.ticker, at)?;
let side = if *rng_state % 2 == 0 { Side::Yes } else { Side::No };
let effective_price = match side {
Side::Yes => price,
Side::No => Decimal::ONE - price,
};
let quantity = self
.config
.max_position_size
.min((context.portfolio.cash / effective_price).to_u64().unwrap_or(0));
if quantity == 0 {
return None;
}
Some(Fill {
ticker: market.ticker.clone(),
side,
quantity,
price: effective_price,
timestamp: at,
})
}
fn get_current_prices(&self, at: DateTime<Utc>) -> HashMap<String, Decimal> {
self.data
.markets
.keys()
.filter_map(|ticker| {
self.data
.get_current_price(ticker, at)
.map(|p| (ticker.clone(), p))
})
.collect()
}
}
/// linear congruential generator for deterministic random baseline
fn lcg_next(state: u64) -> u64 {
state.wrapping_mul(1103515245).wrapping_add(12345)
}
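The determinism matters: the random baseline is reproducible run-to-run with the same seed. A minimal standalone check (the `lcg_nth` helper is illustrative, not part of the codebase); note that because the multiplier is odd, the map `state -> a*state + c mod 2^64` is a bijection, so distinct seeds can never collapse to the same state:

```rust
// same glibc-style constants as the baseline's lcg_next
fn lcg_next(state: u64) -> u64 {
    state.wrapping_mul(1103515245).wrapping_add(12345)
}

// advance the generator n steps from a seed
fn lcg_nth(seed: u64, n: usize) -> u64 {
    (0..n).fold(seed, |s, _| lcg_next(s))
}
```
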

290
src/data/loader.rs Normal file

@ -0,0 +1,290 @@
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use csv::ReaderBuilder;
use rust_decimal::Decimal;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use crate::types::{MarketData, MarketResult, PricePoint, Side, TradeData};
#[derive(Debug, Deserialize)]
struct CsvMarket {
ticker: String,
title: String,
category: String,
#[serde(with = "flexible_datetime")]
open_time: DateTime<Utc>,
#[serde(with = "flexible_datetime")]
close_time: DateTime<Utc>,
result: Option<String>,
}
#[derive(Debug, Deserialize)]
struct CsvTrade {
#[serde(with = "flexible_datetime")]
timestamp: DateTime<Utc>,
ticker: String,
price: f64,
volume: u64,
taker_side: String,
}
mod flexible_datetime {
use chrono::{DateTime, NaiveDateTime, Utc};
use serde::{self, Deserialize, Deserializer};
pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
if let Ok(dt) = DateTime::parse_from_rfc3339(&s) {
return Ok(dt.with_timezone(&Utc));
}
if let Ok(dt) = NaiveDateTime::parse_from_str(&s, "%Y-%m-%d %H:%M:%S") {
return Ok(dt.and_utc());
}
if let Ok(ts) = s.parse::<i64>() {
return DateTime::from_timestamp(ts, 0)
.ok_or_else(|| serde::de::Error::custom("invalid timestamp"));
}
Err(serde::de::Error::custom(format!(
"could not parse datetime: {}",
s
)))
}
}
pub struct HistoricalData {
pub markets: HashMap<String, MarketData>,
pub trades: Vec<TradeData>,
trade_index: HashMap<String, Vec<usize>>,
}
impl HistoricalData {
pub fn load(data_dir: &Path) -> Result<Self> {
let markets_path = data_dir.join("markets.csv");
let trades_path = data_dir.join("trades.csv");
let markets = load_markets(&markets_path)
.with_context(|| format!("loading markets from {:?}", markets_path))?;
let trades =
load_trades(&trades_path).with_context(|| format!("loading trades from {:?}", trades_path))?;
let mut trade_index: HashMap<String, Vec<usize>> = HashMap::new();
for (i, trade) in trades.iter().enumerate() {
trade_index.entry(trade.ticker.clone()).or_default().push(i);
}
Ok(Self {
markets,
trades,
trade_index,
})
}
pub fn get_active_markets(&self, at: DateTime<Utc>) -> Vec<&MarketData> {
self.markets
.values()
.filter(|m| at >= m.open_time && at < m.close_time)
.collect()
}
pub fn get_trades_for_market(&self, ticker: &str, from: DateTime<Utc>, to: DateTime<Utc>) -> Vec<&TradeData> {
self.trade_index
.get(ticker)
.map(|indices| {
indices
.iter()
.filter_map(|&i| {
let trade = &self.trades[i];
if trade.timestamp >= from && trade.timestamp < to {
Some(trade)
} else {
None
}
})
.collect()
})
.unwrap_or_default()
}
pub fn get_current_price(&self, ticker: &str, at: DateTime<Utc>) -> Option<Decimal> {
// trades are sorted by timestamp at load time, so each ticker's index
// list is sorted too: scan from the end for the most recent trade at
// or before `at` instead of walking the whole list
self.trade_index.get(ticker).and_then(|indices| {
indices
.iter()
.rev()
.map(|&i| &self.trades[i])
.find(|t| t.timestamp <= at)
.map(|t| t.price)
})
}
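Because the per-ticker trade lists are timestamp-sorted, the "last trade at or before `at`" lookup could also be a binary search. A sketch over plain `(timestamp, price)` pairs (the tuple encoding and function name are illustrative, not from this codebase):

```rust
// last price at or before `at` in a timestamp-sorted slice
fn price_at_or_before(points: &[(i64, f64)], at: i64) -> Option<f64> {
    // partition_point returns the index of the first point strictly after `at`
    let idx = points.partition_point(|&(ts, _)| ts <= at);
    if idx == 0 { None } else { Some(points[idx - 1].1) }
}
```
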
pub fn get_price_history(
&self,
ticker: &str,
from: DateTime<Utc>,
to: DateTime<Utc>,
) -> Vec<PricePoint> {
self.get_trades_for_market(ticker, from, to)
.into_iter()
.map(|t| PricePoint {
timestamp: t.timestamp,
yes_price: t.price,
volume: t.volume,
})
.collect()
}
pub fn get_volume_24h(&self, ticker: &str, at: DateTime<Utc>) -> u64 {
let from = at - chrono::Duration::hours(24);
self.get_trades_for_market(ticker, from, at)
.iter()
.map(|t| t.volume)
.sum()
}
pub fn get_order_flow_24h(&self, ticker: &str, at: DateTime<Utc>) -> (u64, u64) {
let from = at - chrono::Duration::hours(24);
let trades = self.get_trades_for_market(ticker, from, at);
let buy_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::Yes).map(|t| t.volume).sum();
let sell_vol: u64 = trades.iter().filter(|t| t.taker_side == Side::No).map(|t| t.volume).sum();
(buy_vol, sell_vol)
}
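A scorer consuming `get_order_flow_24h` would typically reduce the `(buy, sell)` pair to a signed imbalance in [-1, 1]; this normalization is a common convention, an assumption here rather than something this codebase defines:

```rust
// signed order-flow imbalance: +1 all buys, -1 all sells, 0 balanced or empty
fn order_flow_imbalance(buy_vol: u64, sell_vol: u64) -> f64 {
    let total = buy_vol + sell_vol;
    if total == 0 {
        return 0.0; // no trades in the window: no signal
    }
    (buy_vol as f64 - sell_vol as f64) / total as f64
}
```
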
pub fn get_resolutions(&self, at: DateTime<Utc>) -> Vec<(&MarketData, MarketResult)> {
self.markets
.values()
.filter_map(|m| {
if m.close_time <= at {
m.result.map(|r| (m, r))
} else {
None
}
})
.collect()
}
pub fn get_resolution_at(&self, ticker: &str, at: DateTime<Utc>) -> Option<MarketResult> {
self.markets.get(ticker).and_then(|m| {
if m.close_time <= at {
m.result
} else {
None
}
})
}
}
fn load_markets(path: &Path) -> Result<HashMap<String, MarketData>> {
let mut reader = ReaderBuilder::new()
.has_headers(true)
.flexible(true)
.from_path(path)?;
let mut markets = HashMap::new();
for result in reader.deserialize() {
let record: CsvMarket = result?;
let result = record.result.as_ref().and_then(|r| match r.to_lowercase().as_str() {
"yes" => Some(MarketResult::Yes),
"no" => Some(MarketResult::No),
"cancelled" | "canceled" => Some(MarketResult::Cancelled),
// empty or unrecognized result: market treated as unresolved
_ => None,
});
markets.insert(
record.ticker.clone(),
MarketData {
ticker: record.ticker,
title: record.title,
category: record.category,
open_time: record.open_time,
close_time: record.close_time,
result,
},
);
}
Ok(markets)
}
fn load_trades(path: &Path) -> Result<Vec<TradeData>> {
let mut reader = ReaderBuilder::new()
.has_headers(true)
.flexible(true)
.from_path(path)?;
let mut trades = Vec::new();
for result in reader.deserialize() {
let record: CsvTrade = result?;
let side = match record.taker_side.to_lowercase().as_str() {
"yes" | "buy" => Side::Yes,
"no" | "sell" => Side::No,
_ => continue,
};
trades.push(TradeData {
timestamp: record.timestamp,
ticker: record.ticker,
// kalshi quotes prices in cents; divide in Decimal space so 55 -> 0.55
// without f64 rounding error
price: Decimal::try_from(record.price).unwrap_or(Decimal::ZERO) / Decimal::from(100),
volume: record.volume,
taker_side: side,
});
}
trades.sort_by_key(|t| t.timestamp);
Ok(trades)
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::TempDir;
fn create_test_data() -> TempDir {
let dir = TempDir::new().unwrap();
let markets_csv = r#"ticker,title,category,open_time,close_time,result
TEST-MKT-1,Test Market 1,politics,2024-01-01 00:00:00,2024-01-15 00:00:00,yes
TEST-MKT-2,Test Market 2,economics,2024-01-01 00:00:00,2024-01-20 00:00:00,no
"#;
let mut f = std::fs::File::create(dir.path().join("markets.csv")).unwrap();
f.write_all(markets_csv.as_bytes()).unwrap();
let trades_csv = r#"timestamp,ticker,price,volume,taker_side
2024-01-05 12:00:00,TEST-MKT-1,55,100,yes
2024-01-05 13:00:00,TEST-MKT-1,57,50,yes
2024-01-06 10:00:00,TEST-MKT-2,45,200,no
"#;
let mut f = std::fs::File::create(dir.path().join("trades.csv")).unwrap();
f.write_all(trades_csv.as_bytes()).unwrap();
dir
}
#[test]
fn test_load_historical_data() {
let dir = create_test_data();
let data = HistoricalData::load(dir.path()).unwrap();
assert_eq!(data.markets.len(), 2);
assert_eq!(data.trades.len(), 3);
}
}

3
src/data/mod.rs Normal file

@ -0,0 +1,3 @@
mod loader;
pub use loader::HistoricalData;

325
src/execution.rs Normal file

@ -0,0 +1,325 @@
use crate::data::HistoricalData;
use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::sync::Arc;
#[derive(Debug, Clone)]
pub struct PositionSizingConfig {
pub kelly_fraction: f64,
pub max_position_pct: f64,
pub min_position_size: u64,
pub max_position_size: u64,
}
impl Default for PositionSizingConfig {
fn default() -> Self {
Self {
kelly_fraction: 0.25,
max_position_pct: 0.25,
min_position_size: 10,
max_position_size: 1000,
}
}
}
impl PositionSizingConfig {
pub fn conservative() -> Self {
Self {
kelly_fraction: 0.1,
max_position_pct: 0.1,
min_position_size: 10,
max_position_size: 500,
}
}
pub fn aggressive() -> Self {
Self {
kelly_fraction: 0.5,
max_position_pct: 0.4,
min_position_size: 10,
max_position_size: 2000,
}
}
}
/// maps a scoring edge in (-inf, +inf) to a win probability in (0, 1):
/// tanh squashes extreme values smoothly, and (tanh(edge) + 1) / 2
/// shifts the range from (-1, 1) to (0, 1)
fn edge_to_win_probability(edge: f64) -> f64 {
(1.0 + edge.tanh()) / 2.0
}
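The mapping is symmetric around 0.5 and saturates for large edges, which keeps the downstream Kelly stake bounded. A standalone property check of the same formula:

```rust
// identical formula to edge_to_win_probability above, isolated for checking
fn edge_to_win_probability(edge: f64) -> f64 {
    (1.0 + edge.tanh()) / 2.0
}
```
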
fn kelly_size(
edge: f64,
price: f64,
bankroll: f64,
config: &PositionSizingConfig,
) -> u64 {
if edge.abs() < 0.01 || price <= 0.0 || price >= 1.0 {
return 0;
}
let win_prob = edge_to_win_probability(edge);
let odds = (1.0 - price) / price;
if odds <= 0.0 {
return 0;
}
let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
let safe_kelly = (kelly * config.kelly_fraction).max(0.0);
let position_value = bankroll * safe_kelly.min(config.max_position_pct);
let shares = (position_value / price).floor() as u64;
// below the minimum viable size, skip the trade entirely rather than
// rounding a near-zero kelly stake up to min_position_size
if shares < config.min_position_size {
return 0;
}
shares.min(config.max_position_size)
}
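For a binary contract at price p, the net odds are b = (1 - p)/p and the raw Kelly fraction is f* = (b·q - (1 - q))/b for win probability q. A worked f64 sketch of that arithmetic, without the fractional-Kelly and size clamps applied above (function name is illustrative):

```rust
// raw (full) Kelly fraction of bankroll for a binary contract
fn raw_kelly_fraction(win_prob: f64, price: f64) -> f64 {
    let odds = (1.0 - price) / price; // net payout per unit staked
    (odds * win_prob - (1.0 - win_prob)) / odds
}
```

For example, a 60% win probability on an even-money contract (price 0.5) gives f* = 0.2, and the default quarter-Kelly multiplier reduces that to 5% of bankroll.
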
pub struct Executor {
data: Arc<HistoricalData>,
slippage_bps: u32,
max_position_size: u64,
sizing_config: PositionSizingConfig,
exit_config: ExitConfig,
}
impl Executor {
pub fn new(data: Arc<HistoricalData>, slippage_bps: u32, max_position_size: u64) -> Self {
Self {
data,
slippage_bps,
max_position_size,
sizing_config: PositionSizingConfig::default(),
exit_config: ExitConfig::default(),
}
}
pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self {
self.sizing_config = config;
self
}
pub fn with_exit_config(mut self, config: ExitConfig) -> Self {
self.exit_config = config;
self
}
pub fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &std::collections::HashMap<String, f64>,
) -> Vec<ExitSignal> {
let mut exits = Vec::new();
for (ticker, position) in &context.portfolio.positions {
let current_price = match self.data.get_current_price(ticker, context.timestamp) {
Some(p) => p,
None => continue,
};
let effective_price = match position.side {
Side::Yes => current_price,
Side::No => Decimal::ONE - current_price,
};
let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);
if entry_price_f64 <= 0.0 {
continue;
}
let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;
if pnl_pct >= self.exit_config.take_profit_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TakeProfit { pnl_pct },
current_price,
});
continue;
}
if pnl_pct <= -self.exit_config.stop_loss_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::StopLoss { pnl_pct },
current_price,
});
continue;
}
let hours_held = (context.timestamp - position.entry_time).num_hours();
if hours_held >= self.exit_config.max_hold_hours {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TimeStop { hours_held },
current_price,
});
continue;
}
if let Some(&new_score) = candidate_scores.get(ticker) {
if new_score < self.exit_config.score_reversal_threshold {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::ScoreReversal { new_score },
current_price,
});
}
}
}
exits
}
pub fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| self.candidate_to_signal(c, context))
.collect()
}
fn candidate_to_signal(
&self,
candidate: &MarketCandidate,
context: &TradingContext,
) -> Option<Signal> {
let current_position = context.portfolio.get_position(&candidate.ticker);
let current_qty = current_position.map(|p| p.quantity).unwrap_or(0);
if current_qty >= self.max_position_size {
return None;
}
let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5);
// positive score = bullish signal, so buy the cheaper side (better risk/reward)
// negative score = bearish signal, so buy against the expensive side
let side = if candidate.final_score > 0.0 {
if yes_price < 0.5 { Side::Yes } else { Side::No }
} else if candidate.final_score < 0.0 {
if yes_price > 0.5 { Side::No } else { Side::Yes }
} else {
return None;
};
let price = match side {
Side::Yes => candidate.current_yes_price,
Side::No => candidate.current_no_price,
};
let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0);
let price_f64 = price.to_f64().unwrap_or(0.5);
if price_f64 <= 0.0 {
return None;
}
let kelly_qty = kelly_size(
candidate.final_score,
price_f64,
available_cash,
&self.sizing_config,
);
let max_affordable = (available_cash / price_f64) as u64;
let quantity = kelly_qty
.min(max_affordable)
.min(self.max_position_size - current_qty);
if quantity < self.sizing_config.min_position_size {
return None;
}
Some(Signal {
ticker: candidate.ticker.clone(),
side,
quantity,
limit_price: Some(price),
reason: format!(
"score={:.3}, side={:?}, price={:.2}",
candidate.final_score, side, price_f64
),
})
}
pub fn execute_signal(
&self,
signal: &Signal,
context: &TradingContext,
) -> Option<Fill> {
let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?;
let effective_price = match signal.side {
Side::Yes => market_price,
Side::No => Decimal::ONE - market_price,
};
let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000);
let fill_price = effective_price * (Decimal::ONE + slippage);
if let Some(limit) = signal.limit_price {
// tolerate fills up to 2x slippage past the limit before rejecting,
// modeling a marketable limit order with some give
if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) {
return None;
}
}
let cost = fill_price * Decimal::from(signal.quantity);
if cost > context.portfolio.cash {
let affordable = (context.portfolio.cash / fill_price)
.to_u64()
.unwrap_or(0);
if affordable == 0 {
return None;
}
return Some(Fill {
ticker: signal.ticker.clone(),
side: signal.side,
quantity: affordable,
price: fill_price,
timestamp: context.timestamp,
});
}
Some(Fill {
ticker: signal.ticker.clone(),
side: signal.side,
quantity: signal.quantity,
price: fill_price,
timestamp: context.timestamp,
})
}
}
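`execute_signal` marks every fill up by `slippage_bps`; the same adjustment in plain f64 terms (the Decimal code above is authoritative, this is just the arithmetic):

```rust
// apply basis-point slippage to an effective price
fn fill_price_with_slippage(effective_price: f64, slippage_bps: u32) -> f64 {
    // 1 basis point = 0.01%, so 100 bps = 1%
    effective_price * (1.0 + slippage_bps as f64 / 10_000.0)
}
```
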
pub fn simple_signal_generator(
candidates: &[MarketCandidate],
context: &TradingContext,
position_size: u64,
) -> Vec<Signal> {
candidates
.iter()
.filter(|c| c.final_score > 0.0)
.filter(|c| !context.portfolio.has_position(&c.ticker))
.map(|c| {
let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5);
let (side, price) = if yes_price < 0.5 {
(Side::Yes, c.current_yes_price)
} else {
(Side::No, c.current_no_price)
};
Signal {
ticker: c.ticker.clone(),
side,
quantity: position_size,
limit_price: Some(price),
reason: format!("simple: score={:.3}", c.final_score),
}
})
.collect()
}

216
src/main.rs Normal file

@ -0,0 +1,216 @@
mod backtest;
mod data;
mod execution;
mod metrics;
mod pipeline;
mod types;
use anyhow::{Context, Result};
use backtest::{Backtester, RandomBaseline};
use chrono::{DateTime, NaiveDate, TimeZone, Utc};
use clap::{Parser, Subcommand};
use data::HistoricalData;
use execution::{Executor, PositionSizingConfig};
use rust_decimal::Decimal;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use types::{BacktestConfig, ExitConfig};
#[derive(Parser)]
#[command(name = "kalshi-backtest")]
#[command(about = "backtesting framework for kalshi prediction markets")]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
Run {
#[arg(short, long, default_value = "data")]
data_dir: PathBuf,
#[arg(long)]
start: String,
#[arg(long)]
end: String,
#[arg(long, default_value = "10000")]
capital: f64,
#[arg(long, default_value = "100")]
max_position: u64,
#[arg(long, default_value = "5")]
max_positions: usize,
#[arg(long, default_value = "1")]
interval_hours: i64,
#[arg(long, default_value = "results")]
output_dir: PathBuf,
#[arg(long)]
compare_random: bool,
#[arg(long, default_value = "0.25")]
kelly_fraction: f64,
#[arg(long, default_value = "0.25")]
max_position_pct: f64,
#[arg(long, default_value = "0.20")]
take_profit: f64,
#[arg(long, default_value = "0.15")]
stop_loss: f64,
#[arg(long, default_value = "72")]
max_hold_hours: i64,
},
Summary {
#[arg(short, long)]
results_file: PathBuf,
},
}
fn parse_date(s: &str) -> Result<DateTime<Utc>> {
if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
return Ok(dt.with_timezone(&Utc));
}
if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap()));
}
Err(anyhow::anyhow!("could not parse date: {}", s))
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "kalshi_backtest=info".into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let cli = Cli::parse();
match cli.command {
Commands::Run {
data_dir,
start,
end,
capital,
max_position,
max_positions,
interval_hours,
output_dir,
compare_random,
kelly_fraction,
max_position_pct,
take_profit,
stop_loss,
max_hold_hours,
} => {
let start_time = parse_date(&start).context("parsing start date")?;
let end_time = parse_date(&end).context("parsing end date")?;
info!(
data_dir = %data_dir.display(),
start = %start_time,
end = %end_time,
capital = capital,
"loading data"
);
let data = Arc::new(
HistoricalData::load(&data_dir).context("loading historical data")?,
);
info!(
markets = data.markets.len(),
trades = data.trades.len(),
"data loaded"
);
let config = BacktestConfig {
start_time,
end_time,
interval: chrono::Duration::hours(interval_hours),
initial_capital: Decimal::try_from(capital).unwrap(),
max_position_size: max_position,
max_positions,
};
let sizing_config = PositionSizingConfig {
kelly_fraction,
max_position_pct,
min_position_size: 10,
max_position_size: max_position,
};
let exit_config = ExitConfig {
take_profit_pct: take_profit,
stop_loss_pct: stop_loss,
max_hold_hours,
score_reversal_threshold: -0.3,
};
let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config);
let result = backtester.run().await;
println!("{}", result.summary());
std::fs::create_dir_all(&output_dir)?;
let result_path = output_dir.join("backtest_result.json");
let json = serde_json::to_string_pretty(&result)?;
std::fs::write(&result_path, json)?;
info!(path = %result_path.display(), "results saved");
if compare_random {
println!("\n--- random baseline ---\n");
let baseline = RandomBaseline::new(config, data);
let baseline_result = baseline.run().await;
println!("{}", baseline_result.summary());
let baseline_path = output_dir.join("baseline_result.json");
let json = serde_json::to_string_pretty(&baseline_result)?;
std::fs::write(&baseline_path, json)?;
println!("\n--- comparison ---\n");
println!(
"strategy return: {:.2}% vs baseline: {:.2}%",
result.total_return_pct, baseline_result.total_return_pct
);
println!(
"strategy sharpe: {:.3} vs baseline: {:.3}",
result.sharpe_ratio, baseline_result.sharpe_ratio
);
println!(
"strategy win rate: {:.1}% vs baseline: {:.1}%",
result.win_rate, baseline_result.win_rate
);
}
Ok(())
}
Commands::Summary { results_file } => {
let content = std::fs::read_to_string(&results_file)
.context("reading results file")?;
let result: metrics::BacktestResult =
serde_json::from_str(&content).context("parsing results")?;
println!("{}", result.summary());
Ok(())
}
}
}

280
src/metrics.rs Normal file

@ -0,0 +1,280 @@
use crate::types::{Portfolio, Trade, TradeType};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BacktestResult {
pub total_return: f64,
pub total_return_pct: f64,
pub sharpe_ratio: f64,
pub max_drawdown: f64,
pub max_drawdown_pct: f64,
pub win_rate: f64,
pub total_trades: usize,
pub winning_trades: usize,
pub losing_trades: usize,
pub avg_trade_pnl: f64,
pub avg_hold_time_hours: f64,
pub trades_per_day: f64,
pub return_by_category: HashMap<String, f64>,
pub equity_curve: Vec<EquityPoint>,
pub trade_log: Vec<TradeRecord>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EquityPoint {
pub timestamp: DateTime<Utc>,
pub equity: f64,
pub cash: f64,
pub positions_value: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeRecord {
pub ticker: String,
pub entry_time: DateTime<Utc>,
pub exit_time: Option<DateTime<Utc>>,
pub side: String,
pub quantity: u64,
pub entry_price: f64,
pub exit_price: Option<f64>,
pub pnl: Option<f64>,
pub category: String,
}
pub struct MetricsCollector {
initial_capital: Decimal,
equity_curve: Vec<EquityPoint>,
trade_records: HashMap<String, TradeRecord>,
closed_trades: Vec<TradeRecord>,
daily_returns: Vec<f64>,
last_equity: f64,
peak_equity: f64,
max_drawdown: f64,
}
impl MetricsCollector {
pub fn new(initial_capital: Decimal) -> Self {
let capital = initial_capital.to_f64().unwrap_or(10000.0);
Self {
initial_capital,
equity_curve: Vec::new(),
trade_records: HashMap::new(),
closed_trades: Vec::new(),
daily_returns: Vec::new(),
last_equity: capital,
peak_equity: capital,
max_drawdown: 0.0,
}
}
pub fn record(
&mut self,
timestamp: DateTime<Utc>,
portfolio: &Portfolio,
market_prices: &HashMap<String, Decimal>,
) {
let positions_value = portfolio
.positions
.values()
.map(|p| {
let price = market_prices
.get(&p.ticker)
.copied()
.unwrap_or(p.avg_entry_price);
(price * Decimal::from(p.quantity)).to_f64().unwrap_or(0.0)
})
.sum();
let cash = portfolio.cash.to_f64().unwrap_or(0.0);
let equity = cash + positions_value;
if equity > self.peak_equity {
self.peak_equity = equity;
}
let drawdown = (self.peak_equity - equity) / self.peak_equity;
if drawdown > self.max_drawdown {
self.max_drawdown = drawdown;
}
if self.last_equity > 0.0 {
let daily_return = (equity - self.last_equity) / self.last_equity;
self.daily_returns.push(daily_return);
}
self.last_equity = equity;
self.equity_curve.push(EquityPoint {
timestamp,
equity,
cash,
positions_value,
});
}
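The running peak/drawdown bookkeeping in `record()` is equivalent to a single pass over the finished equity curve; a sketch assuming positive equity values:

```rust
// max fractional drawdown over an equity series
fn max_drawdown(equity: &[f64]) -> f64 {
    let mut peak = f64::NEG_INFINITY;
    let mut worst = 0.0_f64;
    for &e in equity {
        peak = peak.max(e);               // running high-water mark
        worst = worst.max((peak - e) / peak); // fractional drop from peak
    }
    worst
}
```
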
pub fn record_trade(&mut self, trade: &Trade, category: &str) {
match trade.trade_type {
TradeType::Open => {
let record = TradeRecord {
ticker: trade.ticker.clone(),
entry_time: trade.timestamp,
exit_time: None,
side: format!("{:?}", trade.side),
quantity: trade.quantity,
entry_price: trade.price.to_f64().unwrap_or(0.0),
exit_price: None,
pnl: None,
category: category.to_string(),
};
self.trade_records.insert(trade.ticker.clone(), record);
}
TradeType::Close | TradeType::Resolution => {
if let Some(mut record) = self.trade_records.remove(&trade.ticker) {
let exit_price = trade.price.to_f64().unwrap_or(0.0);
let entry_cost = record.entry_price * record.quantity as f64;
let exit_value = exit_price * record.quantity as f64;
let pnl = exit_value - entry_cost;
record.exit_time = Some(trade.timestamp);
record.exit_price = Some(exit_price);
record.pnl = Some(pnl);
self.closed_trades.push(record);
}
}
}
}
pub fn finalize(self) -> BacktestResult {
let initial = self.initial_capital.to_f64().unwrap_or(10000.0);
let final_equity = self.equity_curve.last().map(|e| e.equity).unwrap_or(initial);
let total_return = final_equity - initial;
let total_return_pct = total_return / initial * 100.0;
// sharpe: mean/std of per-interval returns, annualized with sqrt(252).
// note: 252 assumes daily sampling; with sub-daily intervals this
// overstates the annualized ratio
let sharpe_ratio = if self.daily_returns.len() > 1 {
let mean: f64 = self.daily_returns.iter().sum::<f64>() / self.daily_returns.len() as f64;
let variance: f64 = self
.daily_returns
.iter()
.map(|r| (r - mean).powi(2))
.sum::<f64>()
/ (self.daily_returns.len() - 1) as f64;
let std_dev = variance.sqrt();
if std_dev > 0.0 {
(mean / std_dev) * (252.0_f64).sqrt()
} else {
0.0
}
} else {
0.0
};
let winning_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) > 0.0).count();
let losing_trades = self.closed_trades.iter().filter(|t| t.pnl.unwrap_or(0.0) < 0.0).count();
let total_trades = self.closed_trades.len();
let win_rate = if total_trades > 0 {
winning_trades as f64 / total_trades as f64 * 100.0
} else {
0.0
};
let avg_trade_pnl = if total_trades > 0 {
self.closed_trades.iter().filter_map(|t| t.pnl).sum::<f64>() / total_trades as f64
} else {
0.0
};
let avg_hold_time_hours = if total_trades > 0 {
self.closed_trades
.iter()
.filter_map(|t| {
t.exit_time.map(|exit| (exit - t.entry_time).num_hours() as f64)
})
.sum::<f64>()
/ total_trades as f64
} else {
0.0
};
let duration_days = if self.equity_curve.len() >= 2 {
let start = self.equity_curve.first().unwrap().timestamp;
let end = self.equity_curve.last().unwrap().timestamp;
(end - start).num_days().max(1) as f64
} else {
1.0
};
let trades_per_day = total_trades as f64 / duration_days;
let mut return_by_category: HashMap<String, f64> = HashMap::new();
for trade in &self.closed_trades {
*return_by_category.entry(trade.category.clone()).or_insert(0.0) +=
trade.pnl.unwrap_or(0.0);
}
BacktestResult {
total_return,
total_return_pct,
sharpe_ratio,
// max_drawdown in dollars (fraction of peak times peak equity),
// max_drawdown_pct as a percentage
max_drawdown: self.max_drawdown * self.peak_equity,
max_drawdown_pct: self.max_drawdown * 100.0,
win_rate,
total_trades,
winning_trades,
losing_trades,
avg_trade_pnl,
avg_hold_time_hours,
trades_per_day,
return_by_category,
equity_curve: self.equity_curve,
trade_log: self.closed_trades,
}
}
}
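The Sharpe computation in `finalize()`, isolated: sample standard deviation (n-1 denominator) of per-period returns, scaled by the square root of the sampling frequency:

```rust
// annualized sharpe ratio over per-period returns
fn annualized_sharpe(returns: &[f64], periods_per_year: f64) -> f64 {
    if returns.len() < 2 {
        return 0.0;
    }
    let n = returns.len() as f64;
    let mean = returns.iter().sum::<f64>() / n;
    // unbiased sample variance, matching the (len - 1) divisor above
    let var = returns.iter().map(|r| (r - mean).powi(2)).sum::<f64>() / (n - 1.0);
    let sd = var.sqrt();
    if sd > 0.0 { mean / sd * periods_per_year.sqrt() } else { 0.0 }
}
```
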
impl BacktestResult {
pub fn summary(&self) -> String {
format!(
r#"
backtest results
================
performance
-----------
total return: ${:.2} ({:.2}%)
sharpe ratio: {:.3}
max drawdown: {:.2}%
trades
------
total trades: {}
win rate: {:.1}%
avg trade pnl: ${:.2}
avg hold time: {:.1} hours
trades per day: {:.2}
by category
-----------
{}
"#,
self.total_return,
self.total_return_pct,
self.sharpe_ratio,
self.max_drawdown_pct,
self.total_trades,
self.win_rate,
self.avg_trade_pnl,
self.avg_hold_time_hours,
self.trades_per_day,
self.return_by_category
.iter()
.map(|(k, v)| format!(" {}: ${:.2}", k, v))
.collect::<Vec<_>>()
.join("\n")
)
}
}


@ -0,0 +1,259 @@
//! Cross-market correlation scorer
//!
//! Uses lead-lag relationships between related markets to generate signals.
//! When a related market moves, we expect similar movements in correlated markets.
use crate::pipeline::Scorer;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// correlation entry between two markets
#[derive(Debug, Clone)]
pub struct CorrelationEntry {
pub ticker_a: String,
pub ticker_b: String,
pub correlation: f64,
pub lag_hours: i64,
}
/// cross-market correlation scorer
/// uses precomputed correlations to generate signals based on related market movements
pub struct CorrelationScorer {
correlations: Arc<RwLock<HashMap<String, Vec<CorrelationEntry>>>>,
lookback_hours: i64,
min_correlation: f64,
}
impl CorrelationScorer {
pub fn new(lookback_hours: i64, min_correlation: f64) -> Self {
Self {
correlations: Arc::new(RwLock::new(HashMap::new())),
lookback_hours,
min_correlation,
}
}
pub fn default_config() -> Self {
Self::new(24, 0.5)
}
pub fn load_correlations(&self, correlations: Vec<CorrelationEntry>) {
let mut map = self.correlations.write().unwrap();
map.clear();
for entry in correlations {
map.entry(entry.ticker_a.clone())
.or_insert_with(Vec::new)
.push(entry.clone());
map.entry(entry.ticker_b.clone())
.or_insert_with(Vec::new)
.push(CorrelationEntry {
ticker_a: entry.ticker_b.clone(),
ticker_b: entry.ticker_a.clone(),
correlation: entry.correlation,
lag_hours: -entry.lag_hours,
});
}
}
pub fn add_category_correlations(&self, categories: &[&str]) {
let mut entries = Vec::new();
for (i, cat_a) in categories.iter().enumerate() {
for cat_b in categories.iter().skip(i + 1) {
entries.push(CorrelationEntry {
ticker_a: cat_a.to_string(),
ticker_b: cat_b.to_string(),
correlation: 0.3,
lag_hours: 0,
});
}
}
self.load_correlations(entries);
}
fn get_related_signals(
&self,
ticker: &str,
all_candidates: &[MarketCandidate],
) -> f64 {
let correlations = self.correlations.read().unwrap();
let Some(related) = correlations.get(ticker) else {
return 0.0;
};
let candidate_map: HashMap<&str, &MarketCandidate> = all_candidates
.iter()
.map(|c| (c.ticker.as_str(), c))
.collect();
let mut weighted_signal = 0.0;
let mut total_weight = 0.0;
for entry in related {
if entry.correlation.abs() < self.min_correlation {
continue;
}
if let Some(related_candidate) = candidate_map.get(entry.ticker_b.as_str()) {
let related_momentum = related_candidate
.scores
.get("momentum")
.copied()
.unwrap_or(0.0);
let signal = related_momentum * entry.correlation;
weighted_signal += signal * entry.correlation.abs();
total_weight += entry.correlation.abs();
}
}
if total_weight > 0.0 {
weighted_signal / total_weight
} else {
0.0
}
}
fn calculate_category_correlation(
candidate: &MarketCandidate,
all_candidates: &[MarketCandidate],
) -> f64 {
let same_category: Vec<&MarketCandidate> = all_candidates
.iter()
.filter(|c| c.category == candidate.category && c.ticker != candidate.ticker)
.collect();
if same_category.is_empty() {
return 0.0;
}
let avg_momentum: f64 = same_category
.iter()
.filter_map(|c| c.scores.get("momentum").copied())
.sum::<f64>()
/ same_category.len() as f64;
avg_momentum * 0.3
}
}
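`get_related_signals` reduces to a |correlation|-weighted average of sign-adjusted momenta. Isolated as pure arithmetic (the `(momentum, correlation)` pair encoding is illustrative):

```rust
// correlation-weighted average: each related market's momentum is scaled
// by its signed correlation, then averaged with weight |correlation|
fn weighted_signal(pairs: &[(f64, f64)]) -> f64 {
    let (num, den) = pairs.iter().fold((0.0, 0.0), |(n, d), &(m, c)| {
        (n + m * c * c.abs(), d + c.abs())
    });
    if den > 0.0 { num / den } else { 0.0 }
}
```
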
#[async_trait]
impl Scorer for CorrelationScorer {
fn name(&self) -> &'static str {
"CorrelationScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let related_signal = self.get_related_signals(&c.ticker, candidates);
let category_signal = Self::calculate_category_correlation(c, candidates);
let combined = related_signal * 0.7 + category_signal * 0.3;
// clone the whole candidate, not just its scores: rebuilding from
// Default::default() would drop the ticker, prices, and category and
// break downstream matching
let mut scored = c.clone();
scored.scores.insert("cross_market".to_string(), combined);
scored.scores.insert("related_signal".to_string(), related_signal);
scored.scores.insert("category_signal".to_string(), category_signal);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
for key in ["cross_market", "related_signal", "category_signal"] {
if let Some(score) = scored.scores.get(key) {
candidate.scores.insert(key.to_string(), *score);
}
}
}
}
/// lead-lag correlation calculator (a simplified stand-in for granger
/// causality: it finds the lag that maximizes cross-correlation of log
/// returns rather than running a statistical test)
pub fn calculate_granger_causality(
prices_a: &[f64],
prices_b: &[f64],
max_lag: i64,
) -> Option<(f64, i64)> {
if prices_a.len() < 20 || prices_b.len() < 20 {
return None;
}
let min_len = prices_a.len().min(prices_b.len());
let prices_a = &prices_a[..min_len];
let prices_b = &prices_b[..min_len];
let returns_a: Vec<f64> = prices_a
.windows(2)
.map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 })
.collect();
let returns_b: Vec<f64> = prices_b
.windows(2)
.map(|w| if w[0] > 0.0 { (w[1] / w[0]).ln() } else { 0.0 })
.collect();
let mut best_corr: f64 = 0.0;
let mut best_lag: i64 = 0;
for lag in -max_lag..=max_lag {
let (a_slice, b_slice) = if lag >= 0 {
let l = lag as usize;
if l >= returns_a.len() || l >= returns_b.len() {
continue;
}
(&returns_a[l..], &returns_b[..returns_b.len() - l])
} else {
let l = (-lag) as usize;
if l >= returns_a.len() || l >= returns_b.len() {
continue;
}
(&returns_a[..returns_a.len() - l], &returns_b[l..])
};
if a_slice.len() < 10 || b_slice.len() < 10 || a_slice.len() != b_slice.len() {
continue;
}
let n = a_slice.len() as f64;
let mean_a: f64 = a_slice.iter().sum::<f64>() / n;
let mean_b: f64 = b_slice.iter().sum::<f64>() / n;
let cov: f64 = a_slice
.iter()
.zip(b_slice.iter())
.map(|(a, b)| (a - mean_a) * (b - mean_b))
.sum::<f64>()
/ n;
let std_a: f64 = (a_slice.iter().map(|a| (a - mean_a).powi(2)).sum::<f64>() / n).sqrt();
let std_b: f64 = (b_slice.iter().map(|b| (b - mean_b).powi(2)).sum::<f64>() / n).sqrt();
if std_a > 0.0 && std_b > 0.0 {
let corr = cov / (std_a * std_b);
if corr.abs() > best_corr.abs() {
best_corr = corr;
best_lag = lag;
}
}
}
if best_corr.abs() > 0.1 {
Some((best_corr, best_lag))
} else {
None
}
}

182
src/pipeline/filters.rs Normal file

@ -0,0 +1,182 @@
use crate::pipeline::{Filter, FilterResult};
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use chrono::Duration;
use std::collections::HashSet;
pub struct LiquidityFilter {
min_volume_24h: u64,
}
impl LiquidityFilter {
pub fn new(min_volume_24h: u64) -> Self {
Self { min_volume_24h }
}
}
#[async_trait]
impl Filter for LiquidityFilter {
fn name(&self) -> &'static str {
"LiquidityFilter"
}
async fn filter(
&self,
_context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String> {
let (kept, removed): (Vec<_>, Vec<_>) = candidates
.into_iter()
.partition(|c| c.volume_24h >= self.min_volume_24h);
Ok(FilterResult { kept, removed })
}
}
pub struct TimeToCloseFilter {
min_hours: i64,
max_hours: Option<i64>,
}
impl TimeToCloseFilter {
pub fn new(min_hours: i64, max_hours: Option<i64>) -> Self {
Self { min_hours, max_hours }
}
}
#[async_trait]
impl Filter for TimeToCloseFilter {
fn name(&self) -> &'static str {
"TimeToCloseFilter"
}
async fn filter(
&self,
context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String> {
let min_duration = Duration::hours(self.min_hours);
let max_duration = self.max_hours.map(Duration::hours);
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
let ttc = c.time_to_close(context.timestamp);
let above_min = ttc >= min_duration;
let below_max = max_duration.map(|max| ttc <= max).unwrap_or(true);
above_min && below_max
});
Ok(FilterResult { kept, removed })
}
}
pub struct AlreadyPositionedFilter {
max_position_per_market: u64,
}
impl AlreadyPositionedFilter {
pub fn new(max_position_per_market: u64) -> Self {
Self {
max_position_per_market,
}
}
}
#[async_trait]
impl Filter for AlreadyPositionedFilter {
fn name(&self) -> &'static str {
"AlreadyPositionedFilter"
}
async fn filter(
&self,
context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String> {
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
context
.portfolio
.get_position(&c.ticker)
.map(|p| p.quantity < self.max_position_per_market)
.unwrap_or(true)
});
Ok(FilterResult { kept, removed })
}
}
pub struct CategoryFilter {
whitelist: Option<HashSet<String>>,
blacklist: HashSet<String>,
}
impl CategoryFilter {
pub fn whitelist(categories: Vec<String>) -> Self {
Self {
whitelist: Some(categories.into_iter().collect()),
blacklist: HashSet::new(),
}
}
pub fn blacklist(categories: Vec<String>) -> Self {
Self {
whitelist: None,
blacklist: categories.into_iter().collect(),
}
}
}
#[async_trait]
impl Filter for CategoryFilter {
fn name(&self) -> &'static str {
"CategoryFilter"
}
async fn filter(
&self,
_context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String> {
let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
let in_whitelist = self
.whitelist
.as_ref()
.map(|w| w.contains(&c.category))
.unwrap_or(true);
let not_blacklisted = !self.blacklist.contains(&c.category);
in_whitelist && not_blacklisted
});
Ok(FilterResult { kept, removed })
}
}
pub struct PriceRangeFilter {
min_price: f64,
max_price: f64,
}
impl PriceRangeFilter {
pub fn new(min_price: f64, max_price: f64) -> Self {
Self { min_price, max_price }
}
}
#[async_trait]
impl Filter for PriceRangeFilter {
fn name(&self) -> &'static str {
"PriceRangeFilter"
}
async fn filter(
&self,
_context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String> {
        let (kept, removed): (Vec<_>, Vec<_>) = candidates.into_iter().partition(|c| {
            use rust_decimal::prelude::ToPrimitive;
            // convert the Decimal directly rather than round-tripping through a String
            let price = c.current_yes_price.to_f64().unwrap_or(0.5);
            price >= self.min_price && price <= self.max_price
        });
Ok(FilterResult { kept, removed })
}
}
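Every filter above reduces to the same shape: a predicate fed to `partition`, splitting candidates into kept and removed in one pass. A standalone sketch with a simplified hypothetical `Candidate` type (not the crate's `MarketCandidate`):

```rust
#[derive(Debug)]
struct Candidate {
    ticker: &'static str,
    volume_24h: u64,
}

// Mirrors LiquidityFilter: keep candidates at or above a volume floor.
fn liquidity_filter(candidates: Vec<Candidate>, min_volume: u64) -> (Vec<Candidate>, Vec<Candidate>) {
    candidates.into_iter().partition(|c| c.volume_24h >= min_volume)
}

fn main() {
    let candidates = vec![
        Candidate { ticker: "A", volume_24h: 5_000 },
        Candidate { ticker: "B", volume_24h: 120 },
        Candidate { ticker: "C", volume_24h: 9_999 },
    ];
    let (kept, removed) = liquidity_filter(candidates, 1_000);
    println!("kept: {:?}", kept);
    println!("removed: {:?}", removed);
}
```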

src/pipeline/ml_scorer.rs (new file, 272 lines)
//! ML-based scorer using ONNX models
//!
//! This module provides ML-based scoring using pre-trained ONNX models.
//! Models are trained separately using the Python scripts in scripts/
//!
//! Enable with the `ml` feature:
//! ```toml
//! kalshi-backtest = { features = ["ml"] }
//! ```
use crate::pipeline::Scorer;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use std::path::Path;
#[cfg(feature = "ml")]
use {
ndarray::{Array1, Array2},
ort::{session::Session, value::Value},
std::sync::Arc,
};
/// ML ensemble scorer that combines multiple ONNX models
///
/// Models:
/// - LSTM: sequence model for price history
/// - MLP: feedforward on hand-crafted features
#[cfg(feature = "ml")]
pub struct MLEnsembleScorer {
lstm_session: Option<Arc<Session>>,
mlp_session: Option<Arc<Session>>,
ensemble_weights: Vec<f64>,
}
#[cfg(feature = "ml")]
impl MLEnsembleScorer {
pub fn new() -> Self {
Self {
lstm_session: None,
mlp_session: None,
ensemble_weights: vec![0.5, 0.5],
}
}
pub fn load_models(model_dir: &Path) -> Result<Self, String> {
let lstm_path = model_dir.join("lstm.onnx");
let mlp_path = model_dir.join("mlp.onnx");
let lstm_session = if lstm_path.exists() {
Some(Arc::new(
Session::builder()
.map_err(|e| format!("failed to create session builder: {}", e))?
.commit_from_file(&lstm_path)
.map_err(|e| format!("failed to load LSTM model: {}", e))?,
))
} else {
None
};
let mlp_session = if mlp_path.exists() {
Some(Arc::new(
Session::builder()
.map_err(|e| format!("failed to create session builder: {}", e))?
.commit_from_file(&mlp_path)
.map_err(|e| format!("failed to load MLP model: {}", e))?,
))
} else {
None
};
Ok(Self {
lstm_session,
mlp_session,
ensemble_weights: vec![0.5, 0.5],
})
}
pub fn with_weights(mut self, weights: Vec<f64>) -> Self {
self.ensemble_weights = weights;
self
}
fn extract_features(candidate: &MarketCandidate) -> Vec<f64> {
vec![
candidate.scores.get("momentum").copied().unwrap_or(0.0),
candidate.scores.get("mean_reversion").copied().unwrap_or(0.0),
candidate.scores.get("volume").copied().unwrap_or(0.0),
candidate.scores.get("time_decay").copied().unwrap_or(0.0),
candidate.scores.get("order_flow").copied().unwrap_or(0.0),
candidate.scores.get("bollinger_reversion").copied().unwrap_or(0.0),
candidate.scores.get("mtf_momentum").copied().unwrap_or(0.0),
]
}
fn extract_price_sequence(candidate: &MarketCandidate, max_len: usize) -> Vec<f64> {
use rust_decimal::prelude::ToPrimitive;
let prices: Vec<f64> = candidate
.price_history
.iter()
.rev()
.take(max_len)
.filter_map(|p| p.yes_price.to_f64())
.collect();
let mut sequence = vec![0.0; max_len];
for (i, &price) in prices.iter().enumerate() {
if i < max_len {
sequence[max_len - 1 - i] = price;
}
}
if prices.len() >= 2 {
let mut log_returns = Vec::with_capacity(max_len);
for i in 1..sequence.len() {
if sequence[i - 1] > 0.0 && sequence[i] > 0.0 {
log_returns.push((sequence[i] / sequence[i - 1]).ln());
} else {
log_returns.push(0.0);
}
}
log_returns.insert(0, 0.0);
log_returns
} else {
sequence
}
}
fn predict_lstm(&self, sequence: &[f64]) -> f64 {
let Some(session) = &self.lstm_session else {
return 0.0;
};
        let Ok(input) = Array2::from_shape_vec((1, sequence.len()), sequence.to_vec()) else {
            // a shape mismatch should be unreachable here, but avoid panicking in a scorer
            return 0.0;
        };
match session.run(ort::inputs!["input" => input.view()]) {
Ok(outputs) => {
if let Some(output) = outputs.get("output") {
if let Ok(tensor) = output.try_extract_tensor::<f32>() {
return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
}
}
0.0
}
Err(_) => 0.0,
}
}
fn predict_mlp(&self, features: &[f64]) -> f64 {
let Some(session) = &self.mlp_session else {
return 0.0;
};
let input = Array1::from_vec(features.to_vec());
let input_2d = input.insert_axis(ndarray::Axis(0));
match session.run(ort::inputs!["input" => input_2d.view()]) {
Ok(outputs) => {
if let Some(output) = outputs.get("output") {
if let Ok(tensor) = output.try_extract_tensor::<f32>() {
return tensor.view().iter().next().copied().unwrap_or(0.0) as f64;
}
}
0.0
}
Err(_) => 0.0,
}
}
fn predict(&self, candidate: &MarketCandidate) -> f64 {
let sequence = Self::extract_price_sequence(candidate, 24);
let features = Self::extract_features(candidate);
let lstm_pred = self.predict_lstm(&sequence);
let mlp_pred = self.predict_mlp(&features);
let mut ensemble = 0.0;
let mut total_weight = 0.0;
        if self.lstm_session.is_some() && !self.ensemble_weights.is_empty() {
ensemble += lstm_pred * self.ensemble_weights[0];
total_weight += self.ensemble_weights[0];
}
if self.mlp_session.is_some() && self.ensemble_weights.len() > 1 {
ensemble += mlp_pred * self.ensemble_weights[1];
total_weight += self.ensemble_weights[1];
}
if total_weight > 0.0 {
ensemble / total_weight
} else {
0.0
}
}
}
#[cfg(feature = "ml")]
#[async_trait]
impl Scorer for MLEnsembleScorer {
fn name(&self) -> &'static str {
"MLEnsembleScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let ml_score = self.predict(c);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("ml_ensemble".to_string(), ml_score);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("ml_ensemble") {
candidate.scores.insert("ml_ensemble".to_string(), *score);
}
}
}
/// stub scorer when ML feature is disabled
#[cfg(not(feature = "ml"))]
pub struct MLEnsembleScorer;
#[cfg(not(feature = "ml"))]
impl MLEnsembleScorer {
pub fn new() -> Self {
Self
}
pub fn load_models(_model_dir: &Path) -> Result<Self, String> {
Ok(Self)
}
pub fn with_weights(self, _weights: Vec<f64>) -> Self {
self
}
}
#[cfg(not(feature = "ml"))]
#[async_trait]
impl Scorer for MLEnsembleScorer {
fn name(&self) -> &'static str {
"MLEnsembleScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
        Ok(candidates
            .iter()
            .map(|c| MarketCandidate {
                scores: c.scores.clone(),
                ..Default::default()
            })
            .collect())
}
fn update(&self, _candidate: &mut MarketCandidate, _scored: MarketCandidate) {}
}
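The ensemble average in `predict` weights each available model and renormalizes by the total weight actually used, so a missing model does not drag the result toward zero. A standalone sketch of that combination rule (hypothetical `ensemble` helper, with `None` standing in for an unloaded session):

```rust
// Weighted average over the predictions that are present, renormalized
// by the weight mass actually used.
fn ensemble(preds: &[Option<f64>], weights: &[f64]) -> f64 {
    let mut acc = 0.0;
    let mut total_w = 0.0;
    for (p, w) in preds.iter().zip(weights) {
        if let Some(p) = p {
            acc += p * w;
            total_w += w;
        }
    }
    if total_w > 0.0 { acc / total_w } else { 0.0 }
}

fn main() {
    // both models present: plain weighted average (~0.6)
    println!("{}", ensemble(&[Some(0.8), Some(0.4)], &[0.5, 0.5]));
    // LSTM missing: the MLP prediction passes through unchanged (~0.4)
    println!("{}", ensemble(&[None, Some(0.4)], &[0.5, 0.5]));
}
```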

src/pipeline/mod.rs (new file, 230 lines)
mod correlation_scorer;
mod filters;
mod ml_scorer;
mod scorers;
mod selector;
mod sources;
pub use correlation_scorer::*;
pub use filters::*;
pub use ml_scorer::*;
pub use scorers::*;
pub use selector::*;
pub use sources::*;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use std::sync::Arc;
use tracing::{error, info};
pub struct PipelineResult {
pub retrieved_candidates: Vec<MarketCandidate>,
pub filtered_candidates: Vec<MarketCandidate>,
pub selected_candidates: Vec<MarketCandidate>,
pub context: Arc<TradingContext>,
}
pub struct FilterResult {
pub kept: Vec<MarketCandidate>,
pub removed: Vec<MarketCandidate>,
}
#[async_trait]
pub trait Source: Send + Sync {
fn name(&self) -> &'static str;
fn enable(&self, _context: &TradingContext) -> bool {
true
}
async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String>;
}
#[async_trait]
pub trait Filter: Send + Sync {
fn name(&self) -> &'static str;
fn enable(&self, _context: &TradingContext) -> bool {
true
}
async fn filter(
&self,
context: &TradingContext,
candidates: Vec<MarketCandidate>,
) -> Result<FilterResult, String>;
}
#[async_trait]
pub trait Scorer: Send + Sync {
fn name(&self) -> &'static str;
fn enable(&self, _context: &TradingContext) -> bool {
true
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String>;
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate);
fn update_all(&self, candidates: &mut [MarketCandidate], scored: Vec<MarketCandidate>) {
for (c, s) in candidates.iter_mut().zip(scored) {
self.update(c, s);
}
}
}
pub trait Selector: Send + Sync {
fn name(&self) -> &'static str;
fn enable(&self, _context: &TradingContext) -> bool {
true
}
fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate>;
}
pub struct TradingPipeline {
sources: Vec<Box<dyn Source>>,
filters: Vec<Box<dyn Filter>>,
scorers: Vec<Box<dyn Scorer>>,
selector: Box<dyn Selector>,
result_size: usize,
}
impl TradingPipeline {
pub fn new(
sources: Vec<Box<dyn Source>>,
filters: Vec<Box<dyn Filter>>,
scorers: Vec<Box<dyn Scorer>>,
selector: Box<dyn Selector>,
result_size: usize,
) -> Self {
Self {
sources,
filters,
scorers,
selector,
result_size,
}
}
pub async fn execute(&self, context: TradingContext) -> PipelineResult {
let request_id = context.request_id().to_string();
let candidates = self.fetch_candidates(&context).await;
info!(
request_id = %request_id,
candidates = candidates.len(),
"fetched candidates"
);
let (kept, filtered) = self.filter(&context, candidates.clone()).await;
info!(
request_id = %request_id,
kept = kept.len(),
filtered = filtered.len(),
"filtered candidates"
);
let scored = self.score(&context, kept).await;
let mut selected = self.select(&context, scored);
selected.truncate(self.result_size);
info!(
request_id = %request_id,
selected = selected.len(),
"selected candidates"
);
PipelineResult {
retrieved_candidates: candidates,
filtered_candidates: filtered,
selected_candidates: selected,
context: Arc::new(context),
}
}
async fn fetch_candidates(&self, context: &TradingContext) -> Vec<MarketCandidate> {
let mut all_candidates = Vec::new();
for source in self.sources.iter().filter(|s| s.enable(context)) {
match source.get_candidates(context).await {
Ok(mut candidates) => {
info!(
source = source.name(),
count = candidates.len(),
"source returned candidates"
);
all_candidates.append(&mut candidates);
}
Err(e) => {
error!(source = source.name(), error = %e, "source failed");
}
}
}
all_candidates
}
async fn filter(
&self,
context: &TradingContext,
mut candidates: Vec<MarketCandidate>,
) -> (Vec<MarketCandidate>, Vec<MarketCandidate>) {
let mut all_removed = Vec::new();
for filter in self.filters.iter().filter(|f| f.enable(context)) {
let backup = candidates.clone();
match filter.filter(context, candidates).await {
Ok(result) => {
info!(
filter = filter.name(),
kept = result.kept.len(),
removed = result.removed.len(),
"filter applied"
);
candidates = result.kept;
all_removed.extend(result.removed);
}
Err(e) => {
error!(filter = filter.name(), error = %e, "filter failed");
candidates = backup;
}
}
}
(candidates, all_removed)
}
async fn score(&self, context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
let expected_len = candidates.len();
for scorer in self.scorers.iter().filter(|s| s.enable(context)) {
match scorer.score(context, &candidates).await {
Ok(scored) => {
if scored.len() == expected_len {
scorer.update_all(&mut candidates, scored);
} else {
error!(
scorer = scorer.name(),
expected = expected_len,
got = scored.len(),
"scorer returned wrong number of candidates"
);
}
}
Err(e) => {
error!(scorer = scorer.name(), error = %e, "scorer failed");
}
}
}
candidates
}
fn select(&self, context: &TradingContext, candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
if self.selector.enable(context) {
self.selector.select(context, candidates)
} else {
candidates
}
}
}
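The filter stage clones a backup before each filter so a failing filter is skipped rather than wiping out the candidate set. A simplified standalone sketch of that recovery semantics (hypothetical `apply_filters` over plain integers):

```rust
// Each filter runs against a backup copy; on error the stage is a no-op.
fn apply_filters(
    mut candidates: Vec<i32>,
    filters: &[fn(Vec<i32>) -> Result<Vec<i32>, String>],
) -> Vec<i32> {
    for f in filters {
        let backup = candidates.clone();
        match f(candidates) {
            Ok(kept) => candidates = kept,
            Err(_) => candidates = backup, // failing filter leaves candidates untouched
        }
    }
    candidates
}

fn keep_even(v: Vec<i32>) -> Result<Vec<i32>, String> {
    Ok(v.into_iter().filter(|n| n % 2 == 0).collect())
}

fn broken(_: Vec<i32>) -> Result<Vec<i32>, String> {
    Err("upstream data unavailable".to_string())
}

fn main() {
    let filters: [fn(Vec<i32>) -> Result<Vec<i32>, String>; 2] = [keep_even, broken];
    let out = apply_filters(vec![1, 2, 3, 4], &filters);
    println!("{:?}", out); // prints [2, 4]
}
```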

src/pipeline/scorers.rs (new file, 978 lines)
use crate::pipeline::Scorer;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use rust_decimal::prelude::ToPrimitive;
use std::sync::{Arc, Mutex};
/// rolling statistics for z-score normalization
/// tracks mean and std deviation over a sliding window
#[derive(Debug, Clone)]
pub struct RollingStats {
values: Vec<f64>,
max_size: usize,
sum: f64,
sum_sq: f64,
}
impl RollingStats {
pub fn new(max_size: usize) -> Self {
Self {
values: Vec::with_capacity(max_size),
max_size,
sum: 0.0,
sum_sq: 0.0,
}
}
pub fn push(&mut self, value: f64) {
if !value.is_finite() {
return;
}
        if self.values.len() >= self.max_size {
            // Vec::remove(0) is O(n); fine for small windows, prefer VecDeque
            // if the window grows large
            let old = self.values.remove(0);
            self.sum -= old;
            self.sum_sq -= old * old;
        }
self.values.push(value);
self.sum += value;
self.sum_sq += value * value;
}
pub fn push_batch(&mut self, values: &[f64]) {
for &v in values {
self.push(v);
}
}
pub fn mean(&self) -> f64 {
if self.values.is_empty() {
0.0
} else {
self.sum / self.values.len() as f64
}
}
pub fn std(&self) -> f64 {
let n = self.values.len();
if n < 2 {
return 1.0;
}
let mean = self.mean();
let variance = (self.sum_sq / n as f64) - (mean * mean);
variance.max(0.0).sqrt()
}
pub fn normalize(&self, value: f64) -> f64 {
let std = self.std().max(0.001);
(value - self.mean()) / std
}
pub fn len(&self) -> usize {
self.values.len()
}
pub fn is_ready(&self) -> bool {
self.values.len() >= self.max_size / 4
}
}
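The bookkeeping above can be exercised standalone. This sketch mirrors `RollingStats` with a `VecDeque` (a hypothetical `Rolling` type, not the crate's) and checks the running sum / sum-of-squares window arithmetic:

```rust
use std::collections::VecDeque;

struct Rolling {
    values: VecDeque<f64>,
    max_size: usize,
    sum: f64,
    sum_sq: f64,
}

impl Rolling {
    fn new(max_size: usize) -> Self {
        Self { values: VecDeque::new(), max_size, sum: 0.0, sum_sq: 0.0 }
    }
    fn push(&mut self, v: f64) {
        if self.values.len() >= self.max_size {
            // evict the oldest value and back out its contribution
            let old = self.values.pop_front().unwrap();
            self.sum -= old;
            self.sum_sq -= old * old;
        }
        self.values.push_back(v);
        self.sum += v;
        self.sum_sq += v * v;
    }
    fn mean(&self) -> f64 {
        self.sum / self.values.len() as f64
    }
    fn std(&self) -> f64 {
        let n = self.values.len() as f64;
        // population variance via E[x^2] - E[x]^2, clamped against rounding
        ((self.sum_sq / n) - self.mean().powi(2)).max(0.0).sqrt()
    }
    fn normalize(&self, v: f64) -> f64 {
        (v - self.mean()) / self.std().max(0.001)
    }
}

fn main() {
    let mut r = Rolling::new(3);
    for v in [1.0, 2.0, 3.0, 4.0] {
        r.push(v); // a window of 3 holds [2, 3, 4] at the end
    }
    println!("mean = {}", r.mean());
    println!("z(5) = {}", r.normalize(5.0));
}
```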
/// wrapper that normalizes any scorer's output to z-scores
pub struct NormalizedScorer<S> {
inner: S,
score_key: String,
stats: Arc<Mutex<RollingStats>>,
}
impl<S> NormalizedScorer<S> {
pub fn new(inner: S, score_key: &str, history_size: usize) -> Self {
Self {
inner,
score_key: score_key.to_string(),
stats: Arc::new(Mutex::new(RollingStats::new(history_size))),
}
}
}
#[async_trait]
impl<S: Scorer + Send + Sync> Scorer for NormalizedScorer<S> {
fn name(&self) -> &'static str {
self.inner.name()
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let raw_scored = self.inner.score(context, candidates).await?;
let raw_scores: Vec<f64> = raw_scored
.iter()
.filter_map(|c| c.scores.get(&self.score_key).copied())
.collect();
{
let mut stats = self.stats.lock().unwrap();
stats.push_batch(&raw_scores);
}
let stats = self.stats.lock().unwrap();
let normalized = raw_scored
.into_iter()
.map(|mut c| {
if let Some(&raw) = c.scores.get(&self.score_key) {
let z = if stats.is_ready() {
stats.normalize(raw)
} else {
raw
};
c.scores.insert(self.score_key.clone(), z);
}
c
})
.collect();
Ok(normalized)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
self.inner.update(candidate, scored);
}
}
pub struct MomentumScorer {
lookback_hours: i64,
}
impl MomentumScorer {
pub fn new(lookback_hours: i64) -> Self {
Self { lookback_hours }
}
fn calculate_momentum(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
let lookback_start = now - chrono::Duration::hours(lookback_hours);
let relevant_history: Vec<_> = candidate
.price_history
.iter()
.filter(|p| p.timestamp >= lookback_start)
.collect();
if relevant_history.len() < 2 {
return 0.0;
}
let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);
last - first
}
}
#[async_trait]
impl Scorer for MomentumScorer {
fn name(&self) -> &'static str {
"MomentumScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let momentum = Self::calculate_momentum(c, context.timestamp, self.lookback_hours);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("momentum".to_string(), momentum);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("momentum") {
candidate.scores.insert("momentum".to_string(), *score);
}
}
}
/// multi-timeframe momentum scorer
/// looks at multiple windows and detects divergence between short and long term
pub struct MultiTimeframeMomentumScorer {
windows: Vec<i64>,
}
impl MultiTimeframeMomentumScorer {
pub fn new(windows: Vec<i64>) -> Self {
Self { windows }
}
pub fn default_windows() -> Self {
Self::new(vec![1, 4, 12, 24])
}
fn calculate_momentum_for_window(
candidate: &MarketCandidate,
now: chrono::DateTime<chrono::Utc>,
hours: i64,
) -> f64 {
let lookback_start = now - chrono::Duration::hours(hours);
let relevant_history: Vec<_> = candidate
.price_history
.iter()
.filter(|p| p.timestamp >= lookback_start)
.collect();
if relevant_history.len() < 2 {
return 0.0;
}
let first = relevant_history.first().unwrap().yes_price.to_f64().unwrap_or(0.5);
let last = relevant_history.last().unwrap().yes_price.to_f64().unwrap_or(0.5);
last - first
}
fn calculate_score(
candidate: &MarketCandidate,
now: chrono::DateTime<chrono::Utc>,
windows: &[i64],
) -> (f64, f64, f64) {
let momentums: Vec<f64> = windows
.iter()
.map(|&w| Self::calculate_momentum_for_window(candidate, now, w))
.collect();
if momentums.is_empty() {
return (0.0, 0.0, 1.0);
}
let avg_momentum = momentums.iter().sum::<f64>() / momentums.len() as f64;
        let signs: Vec<i32> = momentums
            .iter()
            .map(|&m| if m > 0.0 { 1 } else if m < 0.0 { -1 } else { 0 })
            .collect();
let all_same_sign = signs.iter().all(|&s| s == signs[0]) && signs[0] != 0;
let alignment = if all_same_sign { 1.0 } else { 0.5 };
let short_avg = if momentums.len() >= 2 {
momentums[..momentums.len() / 2].iter().sum::<f64>() / (momentums.len() / 2) as f64
} else {
momentums[0]
};
let long_avg = if momentums.len() >= 2 {
momentums[momentums.len() / 2..].iter().sum::<f64>() / (momentums.len() - momentums.len() / 2) as f64
} else {
momentums[0]
};
let divergence = if (short_avg > 0.0 && long_avg < 0.0) || (short_avg < 0.0 && long_avg > 0.0) {
(short_avg - long_avg).abs()
} else {
0.0
};
(avg_momentum * alignment, divergence, alignment)
}
}
#[async_trait]
impl Scorer for MultiTimeframeMomentumScorer {
fn name(&self) -> &'static str {
"MultiTimeframeMomentumScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let (momentum, divergence, alignment) = Self::calculate_score(c, context.timestamp, &self.windows);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("mtf_momentum".to_string(), momentum);
scored.scores.insert("mtf_divergence".to_string(), divergence);
scored.scores.insert("mtf_alignment".to_string(), alignment);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
for key in ["mtf_momentum", "mtf_divergence", "mtf_alignment"] {
if let Some(score) = scored.scores.get(key) {
candidate.scores.insert(key.to_string(), *score);
}
}
}
}
pub struct MeanReversionScorer {
lookback_hours: i64,
}
impl MeanReversionScorer {
pub fn new(lookback_hours: i64) -> Self {
Self { lookback_hours }
}
fn calculate_deviation(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
let lookback_start = now - chrono::Duration::hours(lookback_hours);
let prices: Vec<f64> = candidate
.price_history
.iter()
.filter(|p| p.timestamp >= lookback_start)
.filter_map(|p| p.yes_price.to_f64())
.collect();
if prices.is_empty() {
return 0.0;
}
let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
let current = candidate.current_yes_price.to_f64().unwrap_or(0.5);
let deviation = current - mean;
-deviation
}
}
#[async_trait]
impl Scorer for MeanReversionScorer {
fn name(&self) -> &'static str {
"MeanReversionScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let reversion = Self::calculate_deviation(c, context.timestamp, self.lookback_hours);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("mean_reversion".to_string(), reversion);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("mean_reversion") {
candidate.scores.insert("mean_reversion".to_string(), *score);
}
}
}
/// bollinger bands mean reversion scorer
/// triggers when price touches statistical extremes (upper/lower bands)
pub struct BollingerMeanReversionScorer {
lookback_hours: i64,
num_std: f64,
}
impl BollingerMeanReversionScorer {
pub fn new(lookback_hours: i64, num_std: f64) -> Self {
Self { lookback_hours, num_std }
}
pub fn default_config() -> Self {
Self::new(24, 2.0)
}
fn calculate_bands(
candidate: &MarketCandidate,
now: chrono::DateTime<chrono::Utc>,
lookback_hours: i64,
) -> Option<(f64, f64, f64)> {
let lookback_start = now - chrono::Duration::hours(lookback_hours);
let prices: Vec<f64> = candidate
.price_history
.iter()
.filter(|p| p.timestamp >= lookback_start)
.filter_map(|p| p.yes_price.to_f64())
.collect();
if prices.len() < 5 {
return None;
}
let mean: f64 = prices.iter().sum::<f64>() / prices.len() as f64;
let variance: f64 = prices.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / prices.len() as f64;
let std = variance.sqrt();
Some((mean, std, *prices.last().unwrap_or(&mean)))
}
fn calculate_score(
candidate: &MarketCandidate,
now: chrono::DateTime<chrono::Utc>,
lookback_hours: i64,
num_std: f64,
) -> (f64, f64) {
let (mean, std, current) = match Self::calculate_bands(candidate, now, lookback_hours) {
Some(v) => v,
None => return (0.0, 0.0),
};
let upper_band = mean + num_std * std;
let lower_band = mean - num_std * std;
let band_width = upper_band - lower_band;
if band_width < 0.001 {
return (0.0, 0.0);
}
let position = (current - lower_band) / band_width;
let score = if current >= upper_band {
-(current - upper_band) / std.max(0.001)
} else if current <= lower_band {
(lower_band - current) / std.max(0.001)
} else if current > mean {
-(position - 0.5) * 0.5
} else {
(0.5 - position) * 0.5
};
let band_position = (position * 2.0 - 1.0).clamp(-1.0, 1.0);
(score.clamp(-2.0, 2.0), band_position)
}
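The band math above follows the textbook construction: mean ± k·std defines the envelope, and a close beyond either band yields a reversion signal against the move. A standalone sketch with hypothetical `bollinger` and `reversion_score` helpers (simplified, without the inside-band lean of the real scorer):

```rust
// Returns (lower band, mean, upper band) over a price window.
fn bollinger(prices: &[f64], k: f64) -> (f64, f64, f64) {
    let n = prices.len() as f64;
    let mean = prices.iter().sum::<f64>() / n;
    let var = prices.iter().map(|p| (p - mean).powi(2)).sum::<f64>() / n;
    let std = var.sqrt();
    (mean - k * std, mean, mean + k * std)
}

fn reversion_score(prices: &[f64], current: f64, k: f64) -> f64 {
    let (lower, _mean, upper) = bollinger(prices, k);
    if current >= upper {
        // above the upper band: expect a move back down (negative score)
        -(current - upper)
    } else if current <= lower {
        // below the lower band: expect a move back up (positive score)
        lower - current
    } else {
        0.0
    }
}

fn main() {
    let prices = [0.48, 0.50, 0.52, 0.50, 0.50];
    let (lower, mean, upper) = bollinger(&prices, 2.0);
    println!("bands: {:.4} / {:.4} / {:.4}", lower, mean, upper);
    println!("score at 0.56: {:.4}", reversion_score(&prices, 0.56, 2.0));
}
```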
}
#[async_trait]
impl Scorer for BollingerMeanReversionScorer {
fn name(&self) -> &'static str {
"BollingerMeanReversionScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let (score, band_pos) = Self::calculate_score(c, context.timestamp, self.lookback_hours, self.num_std);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("bollinger_reversion".to_string(), score);
scored.scores.insert("bollinger_position".to_string(), band_pos);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
for key in ["bollinger_reversion", "bollinger_position"] {
if let Some(score) = scored.scores.get(key) {
candidate.scores.insert(key.to_string(), *score);
}
}
}
}
pub struct VolumeScorer {
lookback_hours: i64,
}
impl VolumeScorer {
pub fn new(lookback_hours: i64) -> Self {
Self { lookback_hours }
}
fn calculate_volume_score(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>, lookback_hours: i64) -> f64 {
let lookback_start = now - chrono::Duration::hours(lookback_hours);
let recent_volume: u64 = candidate
.price_history
.iter()
.filter(|p| p.timestamp >= lookback_start)
.map(|p| p.volume)
.sum();
if candidate.total_volume == 0 {
return 0.0;
}
let avg_hourly_volume = candidate.total_volume as f64
/ ((now - candidate.open_time).num_hours().max(1) as f64);
let recent_hourly_volume = recent_volume as f64 / lookback_hours.max(1) as f64;
if avg_hourly_volume > 0.0 {
            (recent_hourly_volume / avg_hourly_volume).ln().clamp(-2.0, 2.0)
} else {
0.0
}
}
}
#[async_trait]
impl Scorer for VolumeScorer {
fn name(&self) -> &'static str {
"VolumeScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let volume = Self::calculate_volume_score(c, context.timestamp, self.lookback_hours);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("volume".to_string(), volume);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("volume") {
candidate.scores.insert("volume".to_string(), *score);
}
}
}
pub struct TimeDecayScorer;
impl TimeDecayScorer {
pub fn new() -> Self {
Self
}
fn calculate_time_decay(candidate: &MarketCandidate, now: chrono::DateTime<chrono::Utc>) -> f64 {
let ttc = candidate.time_to_close(now);
let hours_remaining = ttc.num_hours() as f64;
if hours_remaining <= 0.0 {
return -1.0;
}
let decay = 1.0 - (1.0 / (hours_remaining / 24.0 + 1.0));
        decay.clamp(0.0, 1.0)
}
}
impl Default for TimeDecayScorer {
fn default() -> Self {
Self::new()
}
}
/// order flow imbalance scorer
/// measures buying vs selling pressure using taker_side from trades
pub struct OrderFlowScorer;
impl OrderFlowScorer {
pub fn new() -> Self {
Self
}
fn calculate_imbalance(candidate: &MarketCandidate) -> f64 {
let buy_vol = candidate.buy_volume_24h as f64;
let sell_vol = candidate.sell_volume_24h as f64;
let total = buy_vol + sell_vol;
if total == 0.0 {
return 0.0;
}
(buy_vol - sell_vol) / total
}
}
impl Default for OrderFlowScorer {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Scorer for OrderFlowScorer {
fn name(&self) -> &'static str {
"OrderFlowScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let imbalance = Self::calculate_imbalance(c);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("order_flow".to_string(), imbalance);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("order_flow") {
candidate.scores.insert("order_flow".to_string(), *score);
}
}
}
#[async_trait]
impl Scorer for TimeDecayScorer {
fn name(&self) -> &'static str {
"TimeDecayScorer"
}
async fn score(
&self,
context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let time_decay = Self::calculate_time_decay(c, context.timestamp);
let mut scored = MarketCandidate {
scores: c.scores.clone(),
..Default::default()
};
scored.scores.insert("time_decay".to_string(), time_decay);
scored
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
if let Some(score) = scored.scores.get("time_decay") {
candidate.scores.insert("time_decay".to_string(), *score);
}
}
}
pub struct WeightedScorer {
weights: Vec<(String, f64)>,
}
impl WeightedScorer {
pub fn new(weights: Vec<(String, f64)>) -> Self {
Self { weights }
}
pub fn default_weights() -> Self {
Self::new(vec![
("momentum".to_string(), 0.4),
("mean_reversion".to_string(), 0.3),
("volume".to_string(), 0.2),
("time_decay".to_string(), 0.1),
])
}
fn compute_weighted_score(&self, candidate: &MarketCandidate) -> f64 {
self.weights
.iter()
.map(|(name, weight)| {
candidate.scores.get(name).copied().unwrap_or(0.0) * weight
})
.sum()
}
}
#[async_trait]
impl Scorer for WeightedScorer {
fn name(&self) -> &'static str {
"WeightedScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let weighted_score = self.compute_weighted_score(c);
MarketCandidate {
final_score: weighted_score,
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
candidate.final_score = scored.final_score;
}
}
#[derive(Clone)]
pub struct ScorerWeights {
pub momentum: f64,
pub mean_reversion: f64,
pub volume: f64,
pub time_decay: f64,
pub order_flow: f64,
pub bollinger: f64,
pub mtf_momentum: f64,
}
impl Default for ScorerWeights {
fn default() -> Self {
Self {
momentum: 0.2,
mean_reversion: 0.2,
volume: 0.15,
time_decay: 0.1,
order_flow: 0.15,
bollinger: 0.1,
mtf_momentum: 0.1,
}
}
}
impl ScorerWeights {
pub fn politics() -> Self {
Self {
momentum: 0.35,
mean_reversion: 0.1,
volume: 0.1,
time_decay: 0.1,
order_flow: 0.15,
bollinger: 0.05,
mtf_momentum: 0.15,
}
}
pub fn weather() -> Self {
Self {
momentum: 0.1,
mean_reversion: 0.35,
volume: 0.1,
time_decay: 0.15,
order_flow: 0.1,
bollinger: 0.15,
mtf_momentum: 0.05,
}
}
pub fn sports() -> Self {
Self {
momentum: 0.2,
mean_reversion: 0.1,
volume: 0.15,
time_decay: 0.1,
order_flow: 0.3,
bollinger: 0.05,
mtf_momentum: 0.1,
}
}
pub fn economics() -> Self {
Self {
momentum: 0.25,
mean_reversion: 0.2,
volume: 0.15,
time_decay: 0.1,
order_flow: 0.15,
bollinger: 0.1,
mtf_momentum: 0.05,
}
}
fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
let get_score = |key: &str| candidate.scores.get(key).copied().unwrap_or(0.0);
self.momentum * get_score("momentum")
+ self.mean_reversion * get_score("mean_reversion")
+ self.volume * get_score("volume")
+ self.time_decay * get_score("time_decay")
+ self.order_flow * get_score("order_flow")
+ self.bollinger * get_score("bollinger_reversion")
+ self.mtf_momentum * get_score("mtf_momentum")
}
}
/// category-aware weighted scorer
/// applies different weights based on market category
pub struct CategoryWeightedScorer {
category_weights: std::collections::HashMap<String, ScorerWeights>,
default_weights: ScorerWeights,
}
impl CategoryWeightedScorer {
pub fn new(
category_weights: std::collections::HashMap<String, ScorerWeights>,
default_weights: ScorerWeights,
) -> Self {
Self { category_weights, default_weights }
}
pub fn with_defaults() -> Self {
let mut category_weights = std::collections::HashMap::new();
category_weights.insert("politics".to_string(), ScorerWeights::politics());
category_weights.insert("weather".to_string(), ScorerWeights::weather());
category_weights.insert("sports".to_string(), ScorerWeights::sports());
category_weights.insert("economics".to_string(), ScorerWeights::economics());
category_weights.insert("financial".to_string(), ScorerWeights::economics());
Self {
category_weights,
default_weights: ScorerWeights::default(),
}
}
fn get_weights(&self, category: &str) -> &ScorerWeights {
let lower = category.to_lowercase();
self.category_weights.get(&lower).unwrap_or(&self.default_weights)
}
}
#[async_trait]
impl Scorer for CategoryWeightedScorer {
fn name(&self) -> &'static str {
"CategoryWeightedScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let weights = self.get_weights(&c.category);
let weighted_score = weights.compute_score(c);
MarketCandidate {
final_score: weighted_score,
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
candidate.final_score = scored.final_score;
}
}
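The category lookup above is a case-insensitive map probe with a fallback. A minimal sketch of the same pattern (illustrative `Weights` struct, not the crate's `ScorerWeights`):

```rust
use std::collections::HashMap;

struct Weights {
    momentum: f64,
    mean_reversion: f64,
}

/// case-insensitive category lookup with a default fallback,
/// mirroring CategoryWeightedScorer::get_weights
fn weights_for<'a>(
    table: &'a HashMap<String, Weights>,
    default: &'a Weights,
    category: &str,
) -> &'a Weights {
    table.get(&category.to_lowercase()).unwrap_or(default)
}

fn main() {
    let mut table = HashMap::new();
    table.insert(
        "politics".to_string(),
        Weights { momentum: 0.35, mean_reversion: 0.1 },
    );
    let default = Weights { momentum: 0.2, mean_reversion: 0.2 };
    // lookup is case-insensitive...
    assert_eq!(weights_for(&table, &default, "Politics").momentum, 0.35);
    // ...and unknown categories fall back to the default weights
    assert_eq!(weights_for(&table, &default, "crypto").mean_reversion, 0.2);
}
```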
/// ensemble scorer that combines multiple models with dynamic weighting
/// weights can be updated based on recent accuracy
pub struct EnsembleScorer {
model_weights: std::sync::Arc<std::sync::Mutex<Vec<f64>>>,
model_keys: Vec<String>,
}
impl EnsembleScorer {
pub fn new(model_keys: Vec<String>, initial_weights: Vec<f64>) -> Self {
assert_eq!(model_keys.len(), initial_weights.len());
Self {
model_weights: std::sync::Arc::new(std::sync::Mutex::new(initial_weights)),
model_keys,
}
}
pub fn default_ensemble() -> Self {
Self::new(
vec![
"momentum".to_string(),
"mean_reversion".to_string(),
"bollinger_reversion".to_string(),
"order_flow".to_string(),
"mtf_momentum".to_string(),
],
vec![0.25, 0.2, 0.2, 0.2, 0.15],
)
}
pub fn update_weights(&self, new_weights: Vec<f64>) {
// keep the weight vector aligned with model_keys, matching the
// invariant enforced in new()
assert_eq!(new_weights.len(), self.model_keys.len());
let mut weights = self.model_weights.lock().unwrap();
*weights = new_weights;
}
fn compute_score(&self, candidate: &MarketCandidate) -> f64 {
let weights = self.model_weights.lock().unwrap();
self.model_keys
.iter()
.zip(weights.iter())
.map(|(key, weight)| {
candidate.scores.get(key).copied().unwrap_or(0.0) * weight
})
.sum()
}
}
#[async_trait]
impl Scorer for EnsembleScorer {
fn name(&self) -> &'static str {
"EnsembleScorer"
}
async fn score(
&self,
_context: &TradingContext,
candidates: &[MarketCandidate],
) -> Result<Vec<MarketCandidate>, String> {
let scored = candidates
.iter()
.map(|c| {
let ensemble_score = self.compute_score(c);
MarketCandidate {
final_score: ensemble_score,
..Default::default()
}
})
.collect();
Ok(scored)
}
fn update(&self, candidate: &mut MarketCandidate, scored: MarketCandidate) {
candidate.final_score = scored.final_score;
}
}
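`update_weights` leaves the weighting policy to the caller. One plausible policy, not implemented in this file, is to renormalize recent per-model hit rates into weights:

```rust
/// renormalize recent per-model hit rates into ensemble weights that
/// sum to 1.0, with a uniform fallback when there is no signal
/// (illustrative policy, not taken from the backtester)
fn weights_from_accuracy(hit_rates: &[f64]) -> Vec<f64> {
    let total: f64 = hit_rates.iter().sum();
    if total <= 0.0 {
        return vec![1.0 / hit_rates.len() as f64; hit_rates.len()];
    }
    hit_rates.iter().map(|h| h / total).collect()
}

fn main() {
    let w = weights_from_accuracy(&[0.6, 0.3, 0.1]);
    assert!((w.iter().sum::<f64>() - 1.0).abs() < 1e-12);
    assert!((w[0] - 0.6).abs() < 1e-12);
    // no signal at all -> equal weights
    assert_eq!(weights_from_accuracy(&[0.0, 0.0]), vec![0.5, 0.5]);
}
```

The output of a function like this would be handed straight to `EnsembleScorer::update_weights` between backtest intervals.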

64
src/pipeline/selector.rs Normal file

@ -0,0 +1,64 @@
use crate::pipeline::Selector;
use crate::types::{MarketCandidate, TradingContext};
pub struct TopKSelector {
k: usize,
}
impl TopKSelector {
pub fn new(k: usize) -> Self {
Self { k }
}
}
impl Selector for TopKSelector {
fn name(&self) -> &'static str {
"TopKSelector"
}
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
candidates.sort_by(|a, b| b.final_score.total_cmp(&a.final_score));
candidates.truncate(self.k);
candidates
}
}
pub struct ThresholdSelector {
min_score: f64,
max_candidates: Option<usize>,
}
impl ThresholdSelector {
pub fn new(min_score: f64, max_candidates: Option<usize>) -> Self {
Self {
min_score,
max_candidates,
}
}
}
impl Selector for ThresholdSelector {
fn name(&self) -> &'static str {
"ThresholdSelector"
}
fn select(&self, _context: &TradingContext, mut candidates: Vec<MarketCandidate>) -> Vec<MarketCandidate> {
candidates.retain(|c| c.final_score >= self.min_score);
candidates.sort_by(|a, b| b.final_score.total_cmp(&a.final_score));
if let Some(max) = self.max_candidates {
candidates.truncate(max);
}
candidates
}
}
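Both selectors boil down to a descending sort plus truncation. A sketch over bare scores (using `f64::total_cmp`, which gives a NaN-safe total order):

```rust
/// descending sort + truncate, as in TopKSelector
fn top_k(mut scores: Vec<f64>, k: usize) -> Vec<f64> {
    scores.sort_by(|a, b| b.total_cmp(a));
    scores.truncate(k);
    scores
}

/// filter by minimum score, then sort and optionally cap, as in ThresholdSelector
fn threshold(scores: Vec<f64>, min: f64, max_n: Option<usize>) -> Vec<f64> {
    let mut kept: Vec<f64> = scores.into_iter().filter(|s| *s >= min).collect();
    kept.sort_by(|a, b| b.total_cmp(a));
    if let Some(n) = max_n {
        kept.truncate(n);
    }
    kept
}

fn main() {
    assert_eq!(top_k(vec![0.1, 0.9, 0.5], 2), vec![0.9, 0.5]);
    assert_eq!(threshold(vec![0.1, 0.9, 0.5], 0.4, Some(1)), vec![0.9]);
    assert_eq!(threshold(vec![0.1, 0.9, 0.5], 0.4, None), vec![0.9, 0.5]);
}
```

`total_cmp` cannot panic and sorts NaN scores deterministically, which matters when an upstream scorer divides by a zero variance or volume.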

69
src/pipeline/sources.rs Normal file

@ -0,0 +1,69 @@
use crate::data::HistoricalData;
use crate::pipeline::Source;
use crate::types::{MarketCandidate, TradingContext};
use async_trait::async_trait;
use rust_decimal::Decimal;
use std::collections::HashMap;
use std::sync::Arc;
pub struct HistoricalMarketSource {
data: Arc<HistoricalData>,
lookback_hours: i64,
}
impl HistoricalMarketSource {
pub fn new(data: Arc<HistoricalData>, lookback_hours: i64) -> Self {
Self { data, lookback_hours }
}
}
#[async_trait]
impl Source for HistoricalMarketSource {
fn name(&self) -> &'static str {
"HistoricalMarketSource"
}
async fn get_candidates(&self, context: &TradingContext) -> Result<Vec<MarketCandidate>, String> {
let now = context.timestamp;
let active_markets = self.data.get_active_markets(now);
let candidates: Vec<MarketCandidate> = active_markets
.into_iter()
.filter_map(|market| {
let current_price = self.data.get_current_price(&market.ticker, now)?;
let lookback_start = now - chrono::Duration::hours(self.lookback_hours);
let price_history = self.data.get_price_history(&market.ticker, lookback_start, now);
let volume_24h = self.data.get_volume_24h(&market.ticker, now);
let total_volume: u64 = self
.data
.get_trades_for_market(&market.ticker, market.open_time, now)
.iter()
.map(|t| t.volume)
.sum();
let (buy_volume_24h, sell_volume_24h) = self.data.get_order_flow_24h(&market.ticker, now);
Some(MarketCandidate {
ticker: market.ticker.clone(),
title: market.title.clone(),
category: market.category.clone(),
current_yes_price: current_price,
current_no_price: Decimal::ONE - current_price,
volume_24h,
total_volume,
buy_volume_24h,
sell_volume_24h,
open_time: market.open_time,
close_time: market.close_time,
result: market.result,
price_history,
scores: HashMap::new(),
final_score: 0.0,
})
})
.collect();
Ok(candidates)
}
}

343
src/types.rs Normal file

@ -0,0 +1,343 @@
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Side {
Yes,
No,
}
impl Side {
pub fn opposite(&self) -> Self {
match self {
Side::Yes => Side::No,
Side::No => Side::Yes,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MarketResult {
Yes,
No,
Cancelled,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketCandidate {
pub ticker: String,
pub title: String,
pub category: String,
pub current_yes_price: Decimal,
pub current_no_price: Decimal,
pub volume_24h: u64,
pub total_volume: u64,
pub buy_volume_24h: u64,
pub sell_volume_24h: u64,
pub open_time: DateTime<Utc>,
pub close_time: DateTime<Utc>,
pub result: Option<MarketResult>,
pub price_history: Vec<PricePoint>,
pub scores: HashMap<String, f64>,
pub final_score: f64,
}
impl MarketCandidate {
pub fn time_to_close(&self, now: DateTime<Utc>) -> chrono::Duration {
self.close_time - now
}
pub fn is_open(&self, now: DateTime<Utc>) -> bool {
now >= self.open_time && now < self.close_time
}
}
impl Default for MarketCandidate {
fn default() -> Self {
Self {
ticker: String::new(),
title: String::new(),
category: String::new(),
current_yes_price: Decimal::ZERO,
current_no_price: Decimal::ZERO,
volume_24h: 0,
total_volume: 0,
buy_volume_24h: 0,
sell_volume_24h: 0,
open_time: Utc::now(),
close_time: Utc::now(),
result: None,
price_history: Vec::new(),
scores: HashMap::new(),
final_score: 0.0,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricePoint {
pub timestamp: DateTime<Utc>,
pub yes_price: Decimal,
pub volume: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradingContext {
pub request_id: String,
pub timestamp: DateTime<Utc>,
pub portfolio: Portfolio,
pub trading_history: Vec<Trade>,
}
impl TradingContext {
pub fn new(initial_capital: Decimal, start_time: DateTime<Utc>) -> Self {
Self {
request_id: uuid::Uuid::new_v4().to_string(),
timestamp: start_time,
portfolio: Portfolio::new(initial_capital),
trading_history: Vec::new(),
}
}
pub fn request_id(&self) -> &str {
&self.request_id
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Portfolio {
pub positions: HashMap<String, Position>,
pub cash: Decimal,
pub initial_capital: Decimal,
}
impl Portfolio {
pub fn new(initial_capital: Decimal) -> Self {
Self {
positions: HashMap::new(),
cash: initial_capital,
initial_capital,
}
}
pub fn total_value(&self, market_prices: &HashMap<String, Decimal>) -> Decimal {
let position_value: Decimal = self
.positions
.values()
.map(|p| {
let price = market_prices.get(&p.ticker).copied().unwrap_or(p.avg_entry_price);
let effective_price = match p.side {
Side::Yes => price,
Side::No => Decimal::ONE - price,
};
effective_price * Decimal::from(p.quantity)
})
.sum();
self.cash + position_value
}
pub fn has_position(&self, ticker: &str) -> bool {
self.positions.contains_key(ticker)
}
pub fn get_position(&self, ticker: &str) -> Option<&Position> {
self.positions.get(ticker)
}
pub fn apply_fill(&mut self, fill: &Fill) {
let cost = fill.price * Decimal::from(fill.quantity);
self.cash -= cost;
// assumes repeat fills on a ticker share the existing position's side;
// the entry price is volume-weighted across fills
let position = self.positions.entry(fill.ticker.clone()).or_insert_with(|| {
Position {
ticker: fill.ticker.clone(),
side: fill.side,
quantity: 0,
avg_entry_price: Decimal::ZERO,
entry_time: fill.timestamp,
}
});
let total_cost =
position.avg_entry_price * Decimal::from(position.quantity) + cost;
position.quantity += fill.quantity;
if position.quantity > 0 {
position.avg_entry_price = total_cost / Decimal::from(position.quantity);
}
}
pub fn resolve_position(&mut self, ticker: &str, result: MarketResult) -> Option<Decimal> {
let position = self.positions.remove(ticker)?;
let payout = match (result, position.side) {
(MarketResult::Yes, Side::Yes) | (MarketResult::No, Side::No) => {
Decimal::from(position.quantity)
}
(MarketResult::Cancelled, _) => {
position.avg_entry_price * Decimal::from(position.quantity)
}
_ => Decimal::ZERO,
};
self.cash += payout;
let cost = position.avg_entry_price * Decimal::from(position.quantity);
Some(payout - cost)
}
pub fn close_position(&mut self, ticker: &str, exit_price: Decimal) -> Option<Decimal> {
let position = self.positions.remove(ticker)?;
let effective_exit_price = match position.side {
Side::Yes => exit_price,
Side::No => Decimal::ONE - exit_price,
};
let exit_value = effective_exit_price * Decimal::from(position.quantity);
self.cash += exit_value;
let cost = position.avg_entry_price * Decimal::from(position.quantity);
Some(exit_value - cost)
}
}
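Two pieces of arithmetic drive the `Portfolio` methods above: a No position is marked at one minus the Yes price, and resolution pays $1 per winning contract, refunds cost on cancellation, and zero on a loss. A sketch in integer cents (simplified names; the real code uses `Decimal` dollars):

```rust
#[derive(Clone, Copy)]
enum Side { Yes, No }

#[derive(Clone, Copy)]
enum Outcome { Yes, No, Cancelled }

/// mark-to-market value in cents: a No position is worth
/// 100 - yes_price per contract
fn position_value_cents(side: Side, qty: i64, yes_price_cents: i64) -> i64 {
    let effective = match side {
        Side::Yes => yes_price_cents,
        Side::No => 100 - yes_price_cents,
    };
    effective * qty
}

/// settlement p&l in cents: winners pay 100c per contract,
/// cancellations refund the entry cost, losers pay nothing
fn resolution_pnl(side: Side, qty: i64, avg_entry_cents: i64, outcome: Outcome) -> i64 {
    let cost = avg_entry_cents * qty;
    let payout = match (outcome, side) {
        (Outcome::Yes, Side::Yes) | (Outcome::No, Side::No) => 100 * qty,
        (Outcome::Cancelled, _) => cost,
        _ => 0,
    };
    payout - cost
}

fn main() {
    // 10 No contracts with yes at 40c are marked at 60c each
    assert_eq!(position_value_cents(Side::No, 10, 40), 600);
    // 10 Yes contracts bought at 40c: a win nets (100 - 40) * 10 = 600c
    assert_eq!(resolution_pnl(Side::Yes, 10, 40, Outcome::Yes), 600);
    // the same position loses its 400c cost if the market resolves No
    assert_eq!(resolution_pnl(Side::Yes, 10, 40, Outcome::No), -400);
    // cancellation is flat
    assert_eq!(resolution_pnl(Side::Yes, 10, 40, Outcome::Cancelled), 0);
}
```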
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Position {
pub ticker: String,
pub side: Side,
pub quantity: u64,
pub avg_entry_price: Decimal,
pub entry_time: DateTime<Utc>,
}
impl Position {
pub fn cost_basis(&self) -> Decimal {
self.avg_entry_price * Decimal::from(self.quantity)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Trade {
pub ticker: String,
pub side: Side,
pub quantity: u64,
pub price: Decimal,
pub timestamp: DateTime<Utc>,
pub trade_type: TradeType,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TradeType {
Open,
Close,
Resolution,
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum ExitReason {
Resolution(MarketResult),
TakeProfit { pnl_pct: f64 },
StopLoss { pnl_pct: f64 },
TimeStop { hours_held: i64 },
ScoreReversal { new_score: f64 },
}
#[derive(Debug, Clone)]
pub struct ExitSignal {
pub ticker: String,
pub reason: ExitReason,
pub current_price: Decimal,
}
#[derive(Debug, Clone)]
pub struct ExitConfig {
pub take_profit_pct: f64,
pub stop_loss_pct: f64,
pub max_hold_hours: i64,
pub score_reversal_threshold: f64,
}
impl Default for ExitConfig {
fn default() -> Self {
Self {
take_profit_pct: 0.20,
stop_loss_pct: 0.15,
max_hold_hours: 72,
score_reversal_threshold: -0.3,
}
}
}
impl ExitConfig {
pub fn conservative() -> Self {
Self {
take_profit_pct: 0.15,
stop_loss_pct: 0.10,
max_hold_hours: 48,
score_reversal_threshold: -0.2,
}
}
pub fn aggressive() -> Self {
Self {
take_profit_pct: 0.30,
stop_loss_pct: 0.20,
max_hold_hours: 120,
score_reversal_threshold: -0.5,
}
}
}
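`ExitConfig` only stores thresholds; applying them amounts to a first-match check of the take-profit, stop-loss, time-stop, and score-reversal conditions. A sketch using the default thresholds (the backtester's actual check order is not shown in this section, so the precedence here is an assumption):

```rust
/// first-match exit check over ExitConfig-style thresholds; the check
/// order shown here is an assumption, not taken from the backtester
fn exit_reason(pnl_pct: f64, hours_held: i64, score: f64) -> Option<&'static str> {
    const TAKE_PROFIT_PCT: f64 = 0.20;
    const STOP_LOSS_PCT: f64 = 0.15;
    const MAX_HOLD_HOURS: i64 = 72;
    const SCORE_REVERSAL_THRESHOLD: f64 = -0.3;
    if pnl_pct >= TAKE_PROFIT_PCT {
        Some("take_profit")
    } else if pnl_pct <= -STOP_LOSS_PCT {
        Some("stop_loss")
    } else if hours_held >= MAX_HOLD_HOURS {
        Some("time_stop")
    } else if score <= SCORE_REVERSAL_THRESHOLD {
        Some("score_reversal")
    } else {
        None
    }
}

fn main() {
    assert_eq!(exit_reason(0.25, 10, 0.0), Some("take_profit"));
    assert_eq!(exit_reason(-0.16, 10, 0.0), Some("stop_loss"));
    assert_eq!(exit_reason(0.0, 80, 0.0), Some("time_stop"));
    assert_eq!(exit_reason(0.0, 10, -0.4), Some("score_reversal"));
    assert_eq!(exit_reason(0.0, 10, 0.0), None);
}
```

Whichever rule fires first maps directly onto one of the `ExitReason` variants above.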
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fill {
pub ticker: String,
pub side: Side,
pub quantity: u64,
pub price: Decimal,
pub timestamp: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Signal {
pub ticker: String,
pub side: Side,
pub quantity: u64,
pub limit_price: Option<Decimal>,
pub reason: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MarketData {
pub ticker: String,
pub title: String,
pub category: String,
pub open_time: DateTime<Utc>,
pub close_time: DateTime<Utc>,
pub result: Option<MarketResult>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TradeData {
pub timestamp: DateTime<Utc>,
pub ticker: String,
pub price: Decimal,
pub volume: u64,
pub taker_side: Side,
}
#[derive(Debug, Clone)]
pub struct BacktestConfig {
pub start_time: DateTime<Utc>,
pub end_time: DateTime<Utc>,
pub interval: chrono::Duration,
pub initial_capital: Decimal,
pub max_position_size: u64,
pub max_positions: usize,
}