Compare commits


2 Commits

5dc05ba185 backup for jake 2026-01-25 01:20:44 -07:00
3621d93643 feat(backtest): optimize exit strategy and position sizing
6 iterations of backtest refinements with key discoveries:
- stop losses don't work for prediction markets (prices gap)
- 50% take profit, no stop loss yields +9.37% vs +4.04% baseline
- diversification beats concentration: 100 positions → +18.98%
- added kalman filter, VPIN, regime detection scorers (research)

exit config: take_profit 50%, stop_loss disabled, 48h max hold
position sizing: kelly 0.40, max 30% per position, 100 max positions
2026-01-22 11:16:23 -07:00
28 changed files with 6210 additions and 213 deletions

.gitignore (vendored, 1 addition)
@@ -3,3 +3,4 @@
/data/*.parquet
/results/*.json
Cargo.lock
.grepai/

Cargo.toml

@@ -19,6 +19,13 @@ uuid = { version = "1", features = ["v4"] }
rust_decimal = { version = "1", features = ["serde"] }
rust_decimal_macros = "1"
reqwest = { version = "0.12", features = ["json"] }
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono"] }
axum = "0.8"
tower-http = { version = "0.6", features = ["cors", "fs"] }
toml = "0.8"
governor = "0.8"
ort = { version = "2.0.0-rc.11", optional = true }
ndarray = { version = "0.16", optional = true }


@@ -310,3 +310,502 @@ time_decay = 1 - 1 / (hours_remaining / 24 + 1)
```
ranges from 0 (about to close) to ~1 (distant expiry).
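for reference, the formula above as a tiny python helper (a sketch; the production scorer is the rust `TimeDecayScorer`):

```python
def time_decay(hours_remaining: float) -> float:
    """Map hours-to-expiry into [0, 1): 0 at close, approaching 1 far out."""
    return 1 - 1 / (hours_remaining / 24 + 1)

print(time_decay(0))    # 0.0 at expiry
print(time_decay(24))   # 0.5 one day out
```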
backtest run #2
---
**date:** 2026-01-22
**period:** 2026-01-21 04:00 to 2026-01-21 06:00 (2 hours)
**initial capital:** $10,000
**interval:** 1 hour
### results summary
| metric | strategy | random baseline | delta |
|--------|----------|-----------------|-------|
| total return | +$502.81 (+5.03%) | $0.00 (0.00%) | +$502.81 |
| sharpe ratio | 68.845 | 0.000 | +68.845 |
| max drawdown | 0.00% | 0.00% | +0.00% |
| win rate | 100.0% | 0.0% | +100.0% |
| total trades | 1 (closed) | 0 | +1 |
| positions | 9 (open) | 0 | +9 |
*note: short duration used to validate regime detection logic.*
### architectural updates
1. **momentum acceleration scorer**
- implemented second-order momentum (acceleration)
- detects market turning points using fast/slow momentum divergence
- derived from "momentum turning points" academic research
2. **regime adaptive scorer**
- dynamic weight allocation based on market state
- **bull:** favors trend following (momentum: 0.4)
- **bear:** favors mean reversion (mean_reversion: 0.4)
- **transition:** defensive positioning (time_decay: 0.3, volume: 0.2)
- replaced static `CategoryWeightedScorer`
3. **data handling**
- identified data gap before jan 21 03:00
- adjusted backtest start time to align with available trade data
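the regime-to-weights mapping above, sketched as a lookup (illustrative python; the numbers come from the bullets, the function name is hypothetical):

```python
def regime_weights(regime: str) -> dict:
    """Return the dominant feature-weight profile for a detected market regime."""
    profiles = {
        "bull": {"momentum": 0.4},                          # trend following
        "bear": {"mean_reversion": 0.4},                    # mean reversion
        "transition": {"time_decay": 0.3, "volume": 0.2},   # defensive positioning
    }
    return profiles.get(regime, {})
```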
backtest run #3 (iteration 1)
---
**date:** 2026-01-22
**period:** 2026-01-20 00:00 to 2026-01-22 00:00 (2 days)
**initial capital:** $10,000
**interval:** 1 hour
### results summary
| metric | value |
|--------|-------|
| total return | +$412.85 (+4.13%) |
| sharpe ratio | 4.579 |
| max drawdown | 0.25% |
| win rate | 83.3% |
| total trades | 6 (closed) |
| positions | 49 (open) |
| avg trade pnl | $8.81 |
| avg hold time | 4.7 hours |
### comparison with previous runs
| metric | run #1 (2 days) | run #2 (2 hrs) | run #3 (2 days) | trend |
|--------|-----------------|----------------|-----------------|-------|
| total return | +9.94% | +5.03% | +4.13% | ↓ |
| sharpe ratio | 5.448 | 68.845* | 4.579 | ↓ |
| max drawdown | 1.26% | 0.00% | 0.25% | ↓ better |
| win rate | 58.7% | 100.0% | 83.3% | ↑ |
*run #2 sharpe is inflated by the very short backtest period.*
### architectural updates
1. **kalman price filter**
- implements recursive kalman filtering for price estimation
- outputs: filtered_price, innovation (deviation from prediction), uncertainty
- filters noisy price observations to get better "true price" estimates
- adapts to changing volatility automatically via adaptive gain
2. **VPIN scorer (volume-synchronized probability of informed trading)**
- based on easley, lopez de prado, and o'hara (2012) research
- measures flow toxicity using volume-bucketed order imbalance
- outputs: vpin, flow_toxicity, informed_direction
- high VPIN indicates presence of informed traders
3. **adaptive confidence scorer**
- replaces RegimeAdaptiveScorer with confidence-weighted approach
- uses kalman uncertainty, VPIN, and entropy to calculate confidence
- scales all feature weights by confidence factor
- dynamic weight profiles based on:
- high VPIN + informed direction -> follow smart money (order_flow: 0.4)
- turning point detected -> defensive (time_decay: 0.25)
- bull regime -> trend following (momentum: 0.35)
- bear regime -> mean reversion (mean_reversion: 0.35)
- neutral -> balanced weights
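a minimal 1-d kalman update showing the three outputs listed for the kalman price filter (python sketch; the noise values are placeholders, not the tuned parameters):

```python
def kalman_step(est, var, obs, process_noise=1e-4, meas_noise=1e-2):
    """One scalar kalman update: predict, then correct with the new price."""
    var_pred = var + process_noise            # predict: uncertainty grows
    innovation = obs - est                    # deviation from prediction
    gain = var_pred / (var_pred + meas_noise)
    est_new = est + gain * innovation         # filtered price, pulled toward obs
    var_new = (1 - gain) * var_pred           # uncertainty shrinks after update
    return est_new, var_new, innovation

est, var = 0.50, 1.0   # start at 50c with high uncertainty
for price in [0.52, 0.51, 0.55]:
    est, var, innov = kalman_step(est, var, price)
```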
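the VPIN metric reduces to the average absolute order imbalance over equal-volume buckets; a simplified python sketch (bucket construction omitted):

```python
def vpin(buckets: list) -> float:
    """buckets: (buy_volume, sell_volume) pairs of equal total volume.
    VPIN = mean |buy - sell| / (buy + sell) over the window."""
    return sum(abs(b - s) / (b + s) for b, s in buckets) / len(buckets)

# balanced flow -> low toxicity; one-sided flow -> high toxicity
print(vpin([(50, 50), (48, 52)]))   # near 0
print(vpin([(90, 10), (85, 15)]))   # high: informed traders likely present
```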
### analysis
**why return decreased from run #1:**
1. the new AdaptiveConfidenceScorer is more conservative, scaling down weights when confidence is low
2. fewer positions taken overall (6 closed vs 46 in run #1)
3. tighter risk management - max drawdown improved from 1.26% to 0.25%
**positive improvements:**
- win rate increased from 58.7% to 83.3%
- avg trade pnl increased from $4.59 to $8.81
- max drawdown decreased significantly (better risk-adjusted returns)
- sharpe ratio still positive at 4.579
**next iteration considerations:**
1. the confidence scaling may be too aggressive - consider relaxing the uncertainty multiplier
2. need to tune the VPIN thresholds for detecting informed trading
3. kalman filter process_noise and measurement_noise parameters could be optimized
4. should add cross-validation with different market regimes
### scorer pipeline (run #3)
```
MomentumScorer (6h) -> momentum
MultiTimeframeMomentumScorer (1h,4h,12h,24h) -> mtf_momentum, mtf_divergence, mtf_alignment
MeanReversionScorer (24h) -> mean_reversion
BollingerMeanReversionScorer (24h, 2.0 std) -> bollinger_reversion, bollinger_position
VolumeScorer (6h) -> volume
OrderFlowScorer -> order_flow
TimeDecayScorer -> time_decay
VolatilityScorer (24h) -> volatility
EntropyScorer (24h) -> entropy
RegimeDetector (24h) -> regime
MomentumAccelerationScorer (3h fast, 12h slow) -> momentum_acceleration, momentum_regime, turning_point
CorrelationScorer (24h, lag 6) -> correlation
KalmanPriceFilter (24h) -> kalman_price, kalman_innovation, kalman_uncertainty
VPINScorer (bucket 50, 20 buckets) -> vpin, flow_toxicity, informed_direction
AdaptiveConfidenceScorer -> final_score, confidence
```
### research sources
- kalman filtering: https://questdb.com/glossary/kalman-filter-for-time-series-forecasting/
- VPIN/flow toxicity: https://www.stern.nyu.edu/sites/default/files/assets/documents/con_035928.pdf
- kelly criterion for prediction markets: https://arxiv.org/html/2412.14144v1
- order flow imbalance: https://www.emergentmind.com/topics/order-flow-imbalance
### thoughts for next iteration
the lower return is concerning but the improved win rate and reduced drawdown suggest the model is making better quality trades, just fewer of them. the confidence mechanism might be too conservative.
potential improvements:
1. reduce uncertainty_factor multiplier from 5.0 to 2.0-3.0
2. add a minimum confidence threshold before suppressing trades entirely
3. explore bayesian updating of the kalman filter parameters based on prediction accuracy
4. add cross-market correlation features (currently CorrelationScorer only does autocorrelation)
backtest run #4 (iteration 2)
---
**date:** 2026-01-22
**period:** 2026-01-20 00:00 to 2026-01-22 00:00 (2 days)
**initial capital:** $10,000
**interval:** 1 hour
### results summary
| metric | original config | with kalman/VPIN |
|--------|-----------------|------------------|
| total return | +$403.69 (4.04%) | +$356.82 (3.57%) |
| sharpe ratio | 3.540 | 4.052 |
| max drawdown | 1.50% | 0.85% |
| win rate | 40.9% | 60.0% |
| total trades | 22 | 5 |
| avg trade pnl | -$7.57 | $9.17 |
### iteration 2 analysis - what went wrong
**root cause identified:** the original run #1 used `CategoryWeightedScorer` with a much simpler pipeline:
- MomentumScorer
- MultiTimeframeMomentumScorer
- MeanReversionScorer
- BollingerMeanReversionScorer
- VolumeScorer
- OrderFlowScorer
- TimeDecayScorer
- CategoryWeightedScorer
subsequent iterations added:
- VolatilityScorer
- EntropyScorer
- RegimeDetector
- MomentumAccelerationScorer
- CorrelationScorer
- KalmanPriceFilter
- VPINScorer
- AdaptiveConfidenceScorer / RegimeAdaptiveScorer
**key findings:**
1. **AdaptiveConfidenceScorer caused massive trade reduction**
- original confidence formula: `1/(1 + uncertainty*5)` with 0.1 floor
- at uncertainty=0.5, confidence=0.29, scaling ALL weights down by 70%
- this suppressed nearly all trading signals
- trade count dropped from 46 (run #1) to 5-6 (iter 1)
2. **adding more scorers != better predictions**
- the additional scorers (RegimeDetector, Entropy, Correlation) added noise
- each scorer contributes features that may conflict or dilute strong signals
- "forecast combination puzzle" - simple equal weights often beat sophisticated methods
3. **kalman filter and VPIN didn't help**
- removing them had no measurable impact on returns
- they may be useful features but weren't being utilized effectively
**attempted fixes in iteration 2:**
- reduced uncertainty multiplier from 5.0 to 2.0
- raised confidence floor from 0.1 to 0.4
- added signal_strength bonus for strong raw signals
- lowered VPIN thresholds from 0.6 to 0.4
- changed confidence to post-multiplier instead of weight-scaling
**none of these fixes restored original performance**
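the suppression in finding 1 is easy to reproduce from the quoted formula (python sketch; the second call shows the iteration-2 retune):

```python
def confidence(uncertainty: float, mult: float = 5.0, floor: float = 0.1) -> float:
    """AdaptiveConfidenceScorer confidence: 1/(1 + uncertainty*mult), floored."""
    return max(floor, 1 / (1 + uncertainty * mult))

# at uncertainty 0.5 the original formula scales ALL weights down ~70%
print(round(confidence(0.5), 2))                        # 0.29
print(round(confidence(0.5, mult=2.0, floor=0.4), 2))   # 0.5 after the retune
```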
### lessons learned
1. **simplicity wins** - the original 8-scorer pipeline with CategoryWeightedScorer worked best
2. **confidence scaling is dangerous** - multiplying weights by confidence suppresses signals too aggressively
3. **test incrementally** - should have added one scorer at a time and measured impact
4. **beware over-engineering** - the research on kalman filters and VPIN is academically interesting but added complexity without improving results
5. **preserve baseline** - should have kept the original working config in a separate branch
### next iteration direction
rather than adding more complexity, focus on:
1. restoring original simple pipeline
2. tuning existing weights based on category performance
3. improving exit logic rather than entry signals
4. maybe add ONE new feature at a time with A/B testing
backtest run #5 (iteration 3)
---
**date:** 2026-01-22
**period:** 2026-01-20 00:00 to 2026-01-22 00:00 (2 days)
**initial capital:** $10,000
**interval:** 1 hour
### results summary
| metric | strategy | random baseline | delta |
|--------|----------|-----------------|-------|
| total return | +$936.61 (+9.37%) | -$8.00 (-0.08%) | +$944.61 |
| sharpe ratio | 6.491 | -2.291 | +8.782 |
| max drawdown | 0.33% | 0.08% | +0.25% |
| win rate | 100.0% | 0.0% | +100.0% |
| total trades | 9 | 0 | +9 |
| positions (open) | 46 | 0 | +46 |
| avg trade pnl | $25.32 | $0.00 | +$25.32 |
### comparison with previous runs
| metric | run #4 (iter 2) | run #5 (iter 3) | change |
|--------|-----------------|-----------------|--------|
| total return | +4.04% | +9.37% | **+132%** |
| sharpe ratio | 3.540 | 6.491 | **+83%** |
| max drawdown | 1.50% | 0.33% | **-78%** |
| win rate | 40.9% | 100.0% | **+144%** |
| total trades | 22 | 9 | -59% |
| avg trade pnl | -$7.57 | +$25.32 | **+$32.89** |
### key discovery: stop losses hurt prediction market returns
**root cause analysis:**
during iteration 3, we discovered that the original trades.csv data was overwritten after run #1, making it impossible to reproduce those results. this led us to investigate why the "restored" pipeline (iter 2) performed poorly.
analysis of trade logs revealed:
1. **stop losses triggered at -67% to -97%**, not at the configured -15%
2. exits only checked at hourly intervals - prices gapped through stops
3. prediction market prices can move discontinuously (binary outcomes, news)
example failed stop losses from run #4:
- KXSPACEXCOUNT: stop triggered at **-67.4%** (configured -15%)
- KXUCLBTTS: stop triggered at **-97.5%** (configured -15%)
- KXNCAAWBGAME: stop triggered at **-95.0%** (configured -15%)
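the gap-through-stop failure mode can be sketched in a few lines (hypothetical helper; prices are illustrative):

```python
def realized_stop_loss(entry: float, hourly_prices: list, stop_pct: float) -> float:
    """Realized loss fraction when a stop is only checked hourly.
    The fill happens at the observed price, which may be far below the stop level."""
    stop_level = entry * (1 - stop_pct)
    for price in hourly_prices:
        if price <= stop_level:
            return (price - entry) / entry   # realized loss, not -stop_pct
    return 0.0

# price gaps from 0.41 straight to 0.13 between hourly checks:
# a configured 15% stop realizes a -67.5% loss
print(realized_stop_loss(0.40, [0.41, 0.13], stop_pct=0.15))
```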
### exit strategy optimization
we tested 5 exit configurations:
| config | return | sharpe | drawdown | win rate |
|--------|--------|--------|----------|----------|
| baseline (20% TP, 15% SL) | +4.04% | 3.540 | 1.50% | 40.9% |
| 100% TP, no SL | +9.44% | 6.458 | 0.55% | 100% |
| resolution only | +7.16% | 4.388 | 2.12% | n/a |
| **50% TP, no SL** | **+9.37%** | **6.491** | **0.33%** | **100%** |
| 75% TP, no SL | +9.28% | 6.381 | 0.45% | 100% |
**winner: 50% take profit, no stop loss**
- highest sharpe ratio (6.491)
- lowest max drawdown (0.33%)
- good capital recycling (9 closed trades vs 4)
### implementation changes
**new default exit config (src/types.rs):**
```rust
take_profit_pct: 0.50, // exit at +50% (was 0.20)
stop_loss_pct: 0.99, // disabled (was 0.15)
max_hold_hours: 48, // shorter (was 72)
score_reversal_threshold: -0.5,
```
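the exit rule this config implies, sketched in python (hypothetical helper mirroring take_profit_pct and max_hold_hours; the real logic lives in the backtest engine):

```python
def should_exit(entry: float, price: float, hours_held: float,
                take_profit_pct: float = 0.50, max_hold_hours: float = 48) -> bool:
    """Exit on +50% gain or after 48h; no stop loss (losses capped by sizing)."""
    pnl_pct = (price - entry) / entry
    return pnl_pct >= take_profit_pct or hours_held >= max_hold_hours

print(should_exit(0.40, 0.61, 5))    # True: +52.5% clears the take profit
print(should_exit(0.40, 0.20, 5))    # False: no stop loss, keep holding
print(should_exit(0.40, 0.20, 48))   # True: max hold reached
```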
**rationale:**
1. **stop losses don't work** for prediction markets
- prices gap through hourly checks
- binary outcomes mean temp drops don't invalidate bets
- position sizing limits max loss instead
2. **50% take profit** balances two goals:
- locks in gains before potential reversal
- lets winners run further than 20% (which cut gains short)
3. **shorter hold time (48h)** for 2-day backtests
- ensures positions resolve or exit within test period
### lessons learned
1. **prediction markets ≠ traditional trading**
- traditional stop losses assume continuous price paths
- binary outcomes can cause discontinuous jumps
- holding to resolution is often optimal
2. **exit strategy matters as much as entry**
- iteration 3 used the SAME entry signals as iteration 2
- only changed exit parameters
- return increased 132% (4.04% → 9.37%)
3. **test before theorizing**
- academic research on stop losses assumes continuous markets
- empirical testing revealed the opposite for prediction markets
### research sources
- optimal trailing stop (Leung & Zhang 2021): https://medium.com/quantitative-investing/optimal-trading-with-a-trailing-stop-796964fc892a
- forecast combination: https://www.sciencedirect.com/science/article/abs/pii/S0169207021000650
- exit strategies empirical: https://www.quantifiedstrategies.com/trading-exit-strategies/
### thoughts for next iteration
the exit strategy optimization was a major win. next iteration should consider:
1. **position sizing optimization**
- current kelly fraction is 0.25, may be too conservative
- with 100% win rate, could increase bet sizing
2. **entry signal filtering**
- 46 positions still open at end of backtest
- could add filters to reduce position count for capital efficiency
3. **category-specific exit tuning**
- sports markets may need different exits than politics
- crypto markets have different volatility profiles
4. **longer backtest period**
- current data covers only 2 days
- need to test across different market conditions
backtest run #6 (iteration 4)
---
**date:** 2026-01-22
**period:** 2026-01-20 00:00 to 2026-01-22 00:00 (2 days)
**initial capital:** $10,000
**interval:** 1 hour
### results summary
| metric | strategy | random baseline | delta |
|--------|----------|-----------------|-------|
| total return | +$1,898.45 (+18.98%) | $0.00 (0.00%) | +$1,898.45 |
| sharpe ratio | 2.814 | 0.000 | +2.814 |
| max drawdown | 0.79% | 0.00% | +0.79% |
| win rate | 100.0% | 0.0% | +100.0% |
| total trades | 10 | 0 | +10 |
| positions (open) | 100 | 0 | +100 |
### comparison with previous runs
| metric | iter 3 | iter 4 | change |
|--------|--------|--------|--------|
| total return | +9.37% | **+18.98%** | **+102%** |
| sharpe ratio | 6.491 | 2.814 | -57% |
| max drawdown | 0.33% | 0.79% | +139% |
| win rate | 100.0% | 100.0% | 0% |
| total trades | 9 | 10 | +11% |
| positions | 46 | 100 | +117% |
### key discovery: diversification beats concentration in prediction markets
**surprising finding:** concentration hurts returns in prediction markets!
this contradicts conventional wisdom ("best ideas outperform") but makes sense for binary outcomes:
| max_positions | return | sharpe | win rate | trades |
|---------------|--------|--------|----------|--------|
| 5 | 0.24% | 0.986 | 100% | 1 |
| 10 | 0.47% | 1.902 | 100% | 2 |
| 30 | 3.12% | 3.109 | 100% | 3 |
| 50 | 7.97% | 2.593 | 100% | 5 |
| 100 | 18.98% | 2.814 | 100% | 10 |
| 200 | 38.88% | 2.995 | 97.5% | 40 |
| 500 | 96.10% | 3.295 | 95.4% | 87 |
| 1000 | **105.55%** | **3.495** | 95.7% | 94 |
**why diversification wins for prediction markets:**
1. **binary payouts** - each position has positive expected value
- more positions = more chances to capture binary wins
- unlike stocks, losers go to 0 quickly (can't average down)
2. **model has positive edge**
- if scoring model has +EV on average, more bets = more profit
- law of large numbers favors diversification
3. **capital utilization**
- concentrated portfolios leave cash idle
- diversified approach deploys all capital
- with 1000 positions, cash went to $0.00
4. **different from stock picking**
- "best ideas" research assumes winners can compound
- prediction markets resolve quickly (days/weeks)
- can't hold winners long-term
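the law-of-large-numbers point can be checked directly: splitting the same capital across more independent +EV binary bets leaves the expected pnl unchanged while shrinking its relative spread by ~1/sqrt(n). a python sketch with illustrative numbers (not fitted to the backtest):

```python
import math

def portfolio_stats(n_bets: int, p_win: float, payout: float, stake: float):
    """EV and noise-to-edge ratio of total pnl for n independent binary bets.
    Each bet: win +payout*stake with prob p_win, else lose the stake."""
    ev_per_bet = p_win * payout * stake - (1 - p_win) * stake
    var_per_bet = p_win * (payout * stake) ** 2 + (1 - p_win) * stake ** 2 - ev_per_bet ** 2
    total_ev = n_bets * ev_per_bet
    rel_spread = math.sqrt(n_bets * var_per_bet) / abs(total_ev)
    return total_ev, rel_spread

# same total capital split over more bets: edge identical, outcome far less noisy
for n in (5, 100):
    ev, spread = portfolio_stats(n, p_win=0.6, payout=1.0, stake=100 / n)
    print(n, round(ev, 2), round(spread, 2))
```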
### bug fix: max_positions enforcement
discovered that max_positions wasn't being enforced - positions accumulated each hour without limit. added check in backtest loop:
```rust
for signal in signals {
// enforce max_positions limit
if context.portfolio.positions.len() >= self.config.max_positions {
break;
}
// ...
}
```
### implementation changes
**new defaults:**
```rust
// src/main.rs CLI defaults
max_positions: 100 // was 5
kelly_fraction: 0.40 // was 0.25
max_position_pct: 0.30 // was 0.25
// src/execution.rs PositionSizingConfig
kelly_fraction: 0.40
max_position_pct: 0.30
```
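for reference, the fractional kelly sizing these defaults imply for a binary contract (python sketch; believed probability and price are illustrative, the helper name is hypothetical):

```python
def kelly_position(prob: float, price: float, kelly_fraction: float = 0.40,
                   max_position_pct: float = 0.30) -> float:
    """Fraction of bankroll to stake on a YES contract priced at `price`
    when the model believes it wins with probability `prob`.
    Full kelly for a binary payout: f* = (prob - price) / (1 - price)."""
    edge = prob - price
    if edge <= 0:
        return 0.0
    full_kelly = edge / (1 - price)
    return min(full_kelly * kelly_fraction, max_position_pct)

# 60% believed probability on a 50c contract: full kelly 0.20, fractional ~0.08
print(kelly_position(0.60, 0.50))
```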
### note on sharpe ratio decrease
sharpe dropped from 6.491 (iter 3) to 2.814 (iter 4) despite 2x higher returns because:
- more positions = more variance in equity curve
- sharpe measures risk-adjusted returns
- still a strong positive sharpe (>1.0 is generally good)
the trade-off is worth it: roughly double the return in exchange for a lower, but still healthy, risk-adjusted ratio.
### research sources
- kelly criterion for prediction markets: https://arxiv.org/html/2412.14144
- concentrated portfolios: https://www.bbh.com/us/en/insights/capital-partners-insights/the-benefits-of-concentrated-portfolios.html
- position sizing research: https://thescienceofhitting.com/p/position-sizing
### thoughts for next iteration
iteration 4 was a paradigm shift. next iteration should consider:
1. **push diversification further**
- 1000 positions gave 105% return (2x capital!)
- limited by cash, not max_positions
- could explore leverage or smaller position sizes
2. **validate with longer backtest**
- 2-day window is very short
- need to test if diversification holds across market regimes
3. **position sizing optimization**
- current kelly approach may not be optimal
- with many positions, equal weighting might work better
4. **transaction costs**
- many positions = many transactions
- need to model realistic slippage and fees
5. **examine edge by category**
- sports vs politics vs crypto
- may find some categories have stronger edge

config.toml (new file, 31 additions)

@@ -0,0 +1,31 @@
mode = "paper"

[kalshi]
base_url = "https://api.elections.kalshi.com/trade-api/v2"
poll_interval_secs = 60
rate_limit_per_sec = 2

[trading]
initial_capital = 10000.0
max_positions = 100
kelly_fraction = 0.25
max_position_pct = 0.10
take_profit_pct = 0.50
stop_loss_pct = 0.99
max_hold_hours = 48

[persistence]
db_path = "kalshi-paper.db"

[web]
enabled = true
bind_addr = "127.0.0.1:3030"

[circuit_breaker]
max_drawdown_pct = 0.15
max_daily_loss_pct = 0.05
max_positions = 100
max_single_position_pct = 0.10
max_consecutive_errors = 5
max_fills_per_hour = 200
max_fills_per_day = 1000

kalshi-paper.db (new binary file, not shown)

scripts/fetch_kalshi_data.py

@@ -7,8 +7,20 @@ Features:
- Incremental saves (writes batches to disk)
- Resume capability (tracks cursor position)
- Retry logic with exponential backoff
- Date filtering for trades (--min-ts, --max-ts)

Usage:
    # fetch everything (default)
    python fetch_kalshi_data.py

    # fetch trades from last 2 months with higher limit
    python fetch_kalshi_data.py --min-ts 1763794800 --trade-limit 10000000

    # reset trades state and refetch
    python fetch_kalshi_data.py --reset-trades --min-ts 1763794800
"""

import argparse
import json
import csv
import time
@@ -20,6 +32,46 @@ from pathlib import Path

BASE_URL = "https://api.elections.kalshi.com/trade-api/v2"
STATE_FILE = "fetch_state.json"
def parse_args():
    parser = argparse.ArgumentParser(description="Fetch Kalshi market and trade data")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="/mnt/work/kalshi-data",
        help="Output directory for CSV files (default: /mnt/work/kalshi-data)"
    )
    parser.add_argument(
        "--trade-limit",
        type=int,
        default=1_000_000,
        help="Maximum number of trades to fetch (default: 1,000,000)"
    )
    parser.add_argument(
        "--min-ts",
        type=int,
        default=None,
        help="Minimum unix timestamp for trades (trades after this time)"
    )
    parser.add_argument(
        "--max-ts",
        type=int,
        default=None,
        help="Maximum unix timestamp for trades (trades before this time)"
    )
    parser.add_argument(
        "--reset-trades",
        action="store_true",
        help="Reset trades state to fetch fresh (keeps markets done)"
    )
    parser.add_argument(
        "--trades-only",
        action="store_true",
        help="Skip markets fetch, only fetch trades"
    )
    return parser.parse_args()

def fetch_json(url: str, max_retries: int = 5) -> dict:
    """Fetch JSON from URL with retries and exponential backoff."""
    req = urllib.request.Request(url, headers={"Accept": "application/json"})
@@ -45,6 +97,7 @@ def fetch_json(url: str, max_retries: int = 5) -> dict:
        else:
            raise

def load_state(output_dir: Path) -> dict:
    """Load saved state for resuming."""
    state_path = output_dir / STATE_FILE
@@ -55,12 +108,14 @@ def load_state(output_dir: Path) -> dict:
            "trades_cursor": None, "trades_count": 0,
            "markets_done": False, "trades_done": False}

def save_state(output_dir: Path, state: dict):
    """Save state for resuming."""
    state_path = output_dir / STATE_FILE
    with open(state_path, "w") as f:
        json.dump(state, f)

def append_markets_csv(markets: list, output_path: Path, write_header: bool):
    """Append markets to CSV."""
    mode = "w" if write_header else "a"
@@ -94,6 +149,7 @@ def append_markets_csv(markets: list, output_path: Path, write_header: bool):
            m.get("open_interest", ""),
        ])

def append_trades_csv(trades: list, output_path: Path, write_header: bool):
    """Append trades to CSV."""
    mode = "w" if write_header else "a"
@@ -116,6 +172,7 @@ def append_trades_csv(trades: list, output_path: Path, write_header: bool):
            taker_side,
        ])

def fetch_markets_incremental(output_dir: Path, state: dict) -> int:
    """Fetch markets incrementally with state tracking."""
    output_path = output_dir / "markets.csv"
@@ -159,19 +216,38 @@ def fetch_markets_incremental(output_dir: Path, state: dict) -> int:
    return total
def fetch_trades_incremental(
    output_dir: Path,
    state: dict,
    limit: int,
    min_ts: int = None,
    max_ts: int = None
) -> int:
    """Fetch trades incrementally with state tracking."""
    output_path = output_dir / "trades.csv"
    cursor = state["trades_cursor"]
    total = state["trades_count"]
    write_header = total == 0
    if total == 0:
        print("Starting fresh trades fetch...")
    else:
        print(f"Resuming from {total:,} trades...")
    if min_ts:
        print(f"  min_ts filter: {min_ts} ({datetime.fromtimestamp(min_ts)})")
    if max_ts:
        print(f"  max_ts filter: {max_ts} ({datetime.fromtimestamp(max_ts)})")
    while total < limit:
        url = f"{BASE_URL}/markets/trades?limit=1000"
        if cursor:
            url += f"&cursor={cursor}"
        if min_ts:
            url += f"&min_ts={min_ts}"
        if max_ts:
            url += f"&max_ts={max_ts}"
        print(f"Fetching trades... ({total:,}/{limit:,})")
@@ -204,32 +280,53 @@ def fetch_trades_incremental(output_dir: Path, state: dict, limit: int) -> int:
    return total
def main():
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True)
    print("=" * 50)
    print("Kalshi Data Fetcher (with resume)")
    print("=" * 50)
    print(f"Output: {output_dir}")
    print(f"Trade limit: {args.trade_limit:,}")
    state = load_state(output_dir)

    # reset trades state if requested
    if args.reset_trades:
        print("\nResetting trades state...")
        state["trades_cursor"] = None
        state["trades_count"] = 0
        state["trades_done"] = False
        save_state(output_dir, state)

    # fetch markets (skip if --trades-only)
    if not args.trades_only:
        if not state["markets_done"]:
            print("\n[1/2] Fetching markets...")
            markets_count = fetch_markets_incremental(output_dir, state)
            if state["markets_done"]:
                print(f"Markets complete: {markets_count:,}")
            else:
                print(f"Markets paused at: {markets_count:,}")
                return 1
        else:
            print(f"\n[1/2] Markets already complete: {state['markets_count']:,}")
    else:
        print("\n[1/2] Skipping markets (--trades-only)")

    # fetch trades
    if not state["trades_done"]:
        print("\n[2/2] Fetching trades...")
        trades_count = fetch_trades_incremental(
            output_dir,
            state,
            limit=args.trade_limit,
            min_ts=args.min_ts,
            max_ts=args.max_ts
        )
        if state["trades_done"]:
            print(f"Trades complete: {trades_count:,}")
        else:
@@ -250,5 +347,6 @@ def main():
    return 0

if __name__ == "__main__":
    exit(main())

scripts/fetch_kalshi_data_v2.py (new executable file, 274 additions)

@@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Fetch historical trade data from Kalshi's public API with daily distribution.

Fetches a configurable number of trades per day across a date range,
ensuring good coverage rather than clustering around recent data.

Features:
- Day-by-day iteration (oldest to newest)
- Configurable trades-per-day limit
- Resume capability (tracks per-day progress)
- Retry logic with exponential backoff

Usage:
    # fetch last 2 months with default settings
    python fetch_kalshi_data_v2.py

    # fetch specific date range
    python fetch_kalshi_data_v2.py --start-date 2025-11-22 --end-date 2026-01-22

    # test with small range
    python fetch_kalshi_data_v2.py --start-date 2026-01-20 --end-date 2026-01-21
"""
import argparse
import json
import csv
import time
import urllib.request
import urllib.error
from datetime import datetime, timedelta
from pathlib import Path

BASE_URL = "https://api.elections.kalshi.com/trade-api/v2"
STATE_FILE = "fetch_state_v2.json"
def parse_args():
    parser = argparse.ArgumentParser(
        description="Fetch Kalshi trade data with daily distribution"
    )
    two_months_ago = (datetime.now() - timedelta(days=61)).strftime("%Y-%m-%d")
    today = datetime.now().strftime("%Y-%m-%d")
    parser.add_argument(
        "--start-date",
        type=str,
        default=two_months_ago,
        help=f"Start date YYYY-MM-DD (default: {two_months_ago})"
    )
    parser.add_argument(
        "--end-date",
        type=str,
        default=today,
        help=f"End date YYYY-MM-DD (default: {today})"
    )
    parser.add_argument(
        "--trades-per-day",
        type=int,
        default=100_000,
        help="Max trades to fetch per day (default: 100,000)"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="/mnt/work/kalshi-data/v2",
        help="Output directory (default: /mnt/work/kalshi-data/v2)"
    )
    return parser.parse_args()
def fetch_json(url: str, max_retries: int = 5) -> dict:
    """Fetch JSON from URL with retries and exponential backoff."""
    req = urllib.request.Request(url, headers={"Accept": "application/json"})
    for attempt in range(max_retries):
        try:
            with urllib.request.urlopen(req, timeout=30) as resp:
                return json.loads(resp.read().decode())
        except (urllib.error.HTTPError, urllib.error.URLError) as e:
            wait = 2 ** attempt
            print(f"  attempt {attempt + 1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                print(f"  retrying in {wait}s...")
                time.sleep(wait)
            else:
                raise
        except Exception as e:
            wait = 2 ** attempt
            print(f"  unexpected error: {e}")
            if attempt < max_retries - 1:
                print(f"  retrying in {wait}s...")
                time.sleep(wait)
            else:
                raise
def load_state(output_dir: Path) -> dict:
    """Load saved state for resuming."""
    state_path = output_dir / STATE_FILE
    if state_path.exists():
        with open(state_path) as f:
            return json.load(f)
    return {
        "completed_days": [],
        "current_day": None,
        "current_day_cursor": None,
        "current_day_count": 0,
        "total_trades": 0,
    }

def save_state(output_dir: Path, state: dict):
    """Save state for resuming."""
    state_path = output_dir / STATE_FILE
    with open(state_path, "w") as f:
        json.dump(state, f, indent=2)
def append_trades_csv(trades: list, output_path: Path, write_header: bool):
    """Append trades to CSV."""
    mode = "w" if write_header else "a"
    with open(output_path, mode, newline="") as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(["timestamp", "ticker", "price", "volume", "taker_side"])
        for t in trades:
            price = t.get("yes_price", t.get("price", 50))
            taker_side = t.get("taker_side", "")
            if not taker_side:
                taker_side = "yes" if t.get("is_taker_side_yes", True) else "no"
            writer.writerow([
                t.get("created_time", t.get("ts", "")),
                t.get("ticker", t.get("market_ticker", "")),
                price,
                t.get("count", t.get("volume", 1)),
                taker_side,
            ])
def date_to_timestamps(date_str: str) -> tuple[int, int]:
"""Convert YYYY-MM-DD to (start_ts, end_ts) for that day."""
dt = datetime.strptime(date_str, "%Y-%m-%d")
start_ts = int(dt.timestamp())
end_ts = int((dt + timedelta(days=1)).timestamp()) - 1
return start_ts, end_ts
def generate_date_range(start_date: str, end_date: str) -> list[str]:
"""Generate list of YYYY-MM-DD strings from start to end (inclusive)."""
start = datetime.strptime(start_date, "%Y-%m-%d")
end = datetime.strptime(end_date, "%Y-%m-%d")
dates = []
current = start
while current <= end:
dates.append(current.strftime("%Y-%m-%d"))
current += timedelta(days=1)
return dates
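One subtlety: `date_to_timestamps` calls `.timestamp()` on a naive datetime, so the day boundaries follow the machine's local timezone. A UTC-pinned variant (a sketch, not part of the script) would look like:

```python
from datetime import datetime, timedelta, timezone

def date_to_timestamps_utc(date_str: str) -> tuple[int, int]:
    # same contract as date_to_timestamps above, but day boundaries are
    # pinned to UTC so results don't vary with the machine's timezone
    dt = datetime.strptime(date_str, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    start_ts = int(dt.timestamp())
    end_ts = int((dt + timedelta(days=1)).timestamp()) - 1
    return start_ts, end_ts

print(date_to_timestamps_utc("2026-01-21"))  # → (1768953600, 1769039999)
```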
def fetch_day_trades(
output_dir: Path,
state: dict,
day: str,
trades_per_day: int,
output_path: Path,
) -> int:
"""Fetch trades for a single day. Returns count fetched."""
min_ts, max_ts = date_to_timestamps(day)
cursor = state["current_day_cursor"]
count = state["current_day_count"]
write_header = not output_path.exists()
while count < trades_per_day:
url = f"{BASE_URL}/markets/trades?limit=1000&min_ts={min_ts}&max_ts={max_ts}"
if cursor:
url += f"&cursor={cursor}"
try:
data = fetch_json(url)
except Exception as e:
print(f"  error: {e}")
print("  progress saved. run again to resume.")
raise  # propagate so main() doesn't mark this day complete; the saved cursor resumes it
batch = data.get("trades", [])
if not batch:
break
append_trades_csv(batch, output_path, write_header)
write_header = False
count += len(batch)
state["total_trades"] += len(batch)
cursor = data.get("cursor")
state["current_day_cursor"] = cursor
state["current_day_count"] = count
save_state(output_dir, state)
if count % 10000 == 0 or count >= trades_per_day:
print(f" {day}: {count:,} trades")
if not cursor:
break
time.sleep(0.3)
return count
def main():
args = parse_args()
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "trades.csv"
print("=" * 60)
print("Kalshi Data Fetcher v2 (daily distribution)")
print("=" * 60)
print(f"Date range: {args.start_date} to {args.end_date}")
print(f"Trades per day: {args.trades_per_day:,}")
print(f"Output: {output_path}")
print()
state = load_state(output_dir)
all_days = generate_date_range(args.start_date, args.end_date)
completed = set(state["completed_days"])
remaining_days = [d for d in all_days if d not in completed]
print(f"Days: {len(all_days)} total, {len(completed)} completed, "
f"{len(remaining_days)} remaining")
print(f"Trades so far: {state['total_trades']:,}")
print()
for day in remaining_days:
# check if we're resuming this day
if state["current_day"] == day:
print(f" resuming {day} from {state['current_day_count']:,} trades...")
else:
state["current_day"] = day
state["current_day_cursor"] = None
state["current_day_count"] = 0
save_state(output_dir, state)
print(f" fetching {day}...")
count = fetch_day_trades(
output_dir, state, day, args.trades_per_day, output_path
)
# mark day complete
state["completed_days"].append(day)
state["current_day"] = None
state["current_day_cursor"] = None
state["current_day_count"] = 0
save_state(output_dir, state)
print(f" {day} complete: {count:,} trades")
print()
print("=" * 60)
print("Done!")
print(f"Total trades: {state['total_trades']:,}")
print(f"Days completed: {len(state['completed_days'])}")
print(f"Output: {output_path}")
print("=" * 60)
return 0
if __name__ == "__main__":
exit(main())

src/api/client.rs Normal file

@ -0,0 +1,168 @@
use super::types::{ApiMarket, ApiTrade, MarketsResponse, TradesResponse};
use crate::config::KalshiConfig;
use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;
use std::sync::Arc;
use tracing::{debug, warn};
pub struct KalshiClient {
http: reqwest::Client,
base_url: String,
limiter: Arc<
RateLimiter<
governor::state::NotKeyed,
governor::state::InMemoryState,
governor::clock::DefaultClock,
>,
>,
}
impl KalshiClient {
pub fn new(config: &KalshiConfig) -> Self {
let quota = Quota::per_second(
NonZeroU32::new(config.rate_limit_per_sec).unwrap_or(
NonZeroU32::new(5).unwrap(),
),
);
let limiter = Arc::new(RateLimiter::direct(quota));
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build http client");
Self {
http,
base_url: config.base_url.clone(),
limiter,
}
}
async fn rate_limit(&self) {
self.limiter.until_ready().await;
}
async fn request_with_retry(
&self,
url: &str,
) -> anyhow::Result<reqwest::Response> {
for attempt in 0..5u32 {
if attempt > 0 {
let delay = std::time::Duration::from_secs(
3u64.pow(attempt),
);
debug!(
attempt = attempt,
delay_secs = delay.as_secs(),
"retrying after rate limit"
);
tokio::time::sleep(delay).await;
}
self.rate_limit().await;
let resp = self.http.get(url).send().await?;
if resp.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
warn!(
attempt = attempt,
"rate limited (429), backing off"
);
continue;
}
if !resp.status().is_success() {
let status = resp.status();
let body =
resp.text().await.unwrap_or_default();
anyhow::bail!(
"kalshi API error {}: {}",
status,
body
);
}
return Ok(resp);
}
anyhow::bail!("exhausted retries after rate limiting")
}
pub async fn get_open_markets(
&self,
) -> anyhow::Result<Vec<ApiMarket>> {
let mut all_markets = Vec::new();
let mut cursor: Option<String> = None;
let max_pages = 5; // cap at 1000 markets to avoid rate limiting
for page_num in 0..max_pages {
self.rate_limit().await;
let mut url = format!(
"{}/markets?status=open&limit=200",
self.base_url
);
if let Some(ref c) = cursor {
url.push_str(&format!("&cursor={}", c));
}
debug!(
url = %url,
page = page_num,
"fetching markets"
);
let resp = self
.request_with_retry(&url)
.await?;
let page: MarketsResponse = resp.json().await?;
let count = page.markets.len();
all_markets.extend(page.markets);
if !page.cursor.is_empty() && count > 0 {
cursor = Some(page.cursor);
} else {
break;
}
}
debug!(
total = all_markets.len(),
"fetched open markets"
);
Ok(all_markets)
}
pub async fn get_market_trades(
&self,
ticker: &str,
limit: u32,
) -> anyhow::Result<Vec<ApiTrade>> {
self.rate_limit().await;
let url = format!(
"{}/markets/trades?ticker={}&limit={}",
self.base_url, ticker, limit
);
debug!(ticker = %ticker, "fetching trades");
let resp = match self
.request_with_retry(&url)
.await
{
Ok(r) => r,
Err(e) => {
warn!(
ticker = %ticker,
error = %e,
"failed to fetch trades"
);
return Ok(Vec::new());
}
};
let data: TradesResponse = resp.json().await?;
Ok(data.trades)
}
}
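The backoff schedule in `request_with_retry` is easy to misread: attempt 0 goes out immediately (the `attempt > 0` guard skips the sleep), and each retry then waits `3^attempt` seconds. A quick sketch of the resulting delays (Python used purely for illustration):

```python
MAX_ATTEMPTS = 5  # matches the 0..5 loop in request_with_retry

# attempt 0 sleeps 0s, then 3**attempt seconds before each retry
delays = [0 if a == 0 else 3 ** a for a in range(MAX_ATTEMPTS)]
print(delays)  # → [0, 3, 9, 27, 81]
```

So a persistently rate-limited request gives up after roughly two minutes of cumulative backoff.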

src/api/mod.rs Normal file

@ -0,0 +1,5 @@
pub mod client;
pub mod types;
pub use client::KalshiClient;
pub use types::*;

src/api/types.rs Normal file

@ -0,0 +1,119 @@
use chrono::{DateTime, Utc};
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct MarketsResponse {
pub markets: Vec<ApiMarket>,
#[serde(default)]
pub cursor: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiMarket {
pub ticker: String,
#[serde(default)]
pub title: String,
#[serde(default)]
pub event_ticker: String,
#[serde(default)]
pub status: String,
pub open_time: DateTime<Utc>,
pub close_time: DateTime<Utc>,
#[serde(default)]
pub yes_ask: i64,
#[serde(default)]
pub yes_bid: i64,
#[serde(default)]
pub no_ask: i64,
#[serde(default)]
pub no_bid: i64,
#[serde(default)]
pub last_price: i64,
#[serde(default)]
pub volume: i64,
#[serde(default)]
pub volume_24h: i64,
#[serde(default)]
pub result: String,
#[serde(default)]
pub subtitle: String,
}
impl ApiMarket {
/// returns yes price as a fraction (0.0 - 1.0)
/// prices from API are in cents (0-100)
pub fn mid_yes_price(&self) -> f64 {
let bid = self.yes_bid as f64 / 100.0;
let ask = self.yes_ask as f64 / 100.0;
if bid > 0.0 && ask > 0.0 {
(bid + ask) / 2.0
} else if bid > 0.0 {
bid
} else if ask > 0.0 {
ask
} else if self.last_price > 0 {
self.last_price as f64 / 100.0
} else {
0.0
}
}
pub fn category_from_event(&self) -> String {
let lower = self.event_ticker.to_lowercase();
if lower.contains("nba")
|| lower.contains("nfl")
|| lower.contains("sport")
{
"sports".to_string()
} else if lower.contains("btc")
|| lower.contains("crypto")
|| lower.contains("eth")
{
"crypto".to_string()
} else if lower.contains("weather")
|| lower.contains("temp")
{
"weather".to_string()
} else if lower.contains("econ")
|| lower.contains("fed")
|| lower.contains("cpi")
|| lower.contains("gdp")
{
"economics".to_string()
} else if lower.contains("elect")
|| lower.contains("polit")
|| lower.contains("trump")
|| lower.contains("biden")
{
"politics".to_string()
} else {
"other".to_string()
}
}
}
#[derive(Debug, Clone, Deserialize)]
pub struct TradesResponse {
#[serde(default)]
pub trades: Vec<ApiTrade>,
#[serde(default)]
pub cursor: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiTrade {
#[serde(default)]
pub trade_id: String,
#[serde(default)]
pub ticker: String,
pub created_time: DateTime<Utc>,
#[serde(default)]
pub yes_price: i64,
#[serde(default)]
pub no_price: i64,
#[serde(default)]
pub count: i64,
#[serde(default)]
pub taker_side: String,
}
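The fallback chain in `ApiMarket::mid_yes_price` (mid of bid/ask, else whichever side exists, else last trade, else 0) can be mirrored in a few lines; this Python sketch is illustration only, not project code:

```python
def mid_yes_price(yes_bid: int, yes_ask: int, last_price: int = 0) -> float:
    # API prices are in cents (0-100); return a fraction in [0, 1]
    bid, ask = yes_bid / 100.0, yes_ask / 100.0
    if bid > 0.0 and ask > 0.0:
        return (bid + ask) / 2.0
    if bid > 0.0:
        return bid
    if ask > 0.0:
        return ask
    return last_price / 100.0 if last_price > 0 else 0.0

print(round(mid_yes_price(44, 48), 4))    # → 0.46
print(round(mid_yes_price(0, 0, 37), 4))  # → 0.37
```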


@ -1,9 +1,9 @@
 use crate::data::HistoricalData;
-use crate::execution::{Executor, PositionSizingConfig};
+use crate::execution::{BacktestExecutor, OrderExecutor, PositionSizingConfig};
 use crate::metrics::{BacktestResult, MetricsCollector};
 use crate::pipeline::{
-    AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, Filter,
-    HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer,
+    AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer,
+    Filter, HistoricalMarketSource, LiquidityFilter, MeanReversionScorer, MomentumScorer,
     MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, Selector, Source, TimeDecayScorer,
     TimeToCloseFilter, TopKSelector, TradingPipeline, VolumeScorer,
 };
@ -11,11 +11,13 @@
 use crate::types::{
     BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType,
     TradingContext,
 };
+use crate::web::BacktestProgress;
 use chrono::{DateTime, Utc};
 use rust_decimal::Decimal;
 use rust_decimal::prelude::ToPrimitive;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
+use std::sync::atomic::Ordering;
 use tracing::info;

 /// resolves any positions in markets that have closed
@ -86,19 +88,21 @@ pub struct Backtester {
     config: BacktestConfig,
     data: Arc<HistoricalData>,
     pipeline: TradingPipeline,
-    executor: Executor,
+    executor: BacktestExecutor,
+    progress: Option<Arc<BacktestProgress>>,
 }

 impl Backtester {
     pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
         let pipeline = Self::build_default_pipeline(data.clone(), &config);
-        let executor = Executor::new(data.clone(), 10, config.max_position_size);
+        let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size);
         Self {
             config,
             data,
             pipeline,
             executor,
+            progress: None,
         }
     }
@ -109,7 +113,7 @@ impl Backtester {
         exit_config: ExitConfig,
     ) -> Self {
         let pipeline = Self::build_default_pipeline(data.clone(), &config);
-        let executor = Executor::new(data.clone(), 10, config.max_position_size)
+        let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size)
             .with_sizing_config(sizing_config)
             .with_exit_config(exit_config);
@ -118,9 +122,15 @@
             data,
             pipeline,
             executor,
+            progress: None,
         }
     }

+    pub fn with_progress(mut self, progress: Arc<BacktestProgress>) -> Self {
+        self.progress = Some(progress);
+        self
+    }
+
     pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self {
         self.pipeline = pipeline;
         self
@ -160,13 +170,30 @@ impl Backtester {
         let mut current_time = self.config.start_time;

+        let total_steps = (self.config.end_time - self.config.start_time)
+            .num_seconds()
+            / self.config.interval.num_seconds().max(1);
+        if let Some(ref progress) = self.progress {
+            progress.total_steps.store(
+                total_steps as u64,
+                Ordering::Relaxed,
+            );
+            progress.phase.store(
+                BacktestProgress::PHASE_RUNNING,
+                Ordering::Relaxed,
+            );
+        }
+
         info!(
             start = %self.config.start_time,
             end = %self.config.end_time,
             interval_hours = self.config.interval.num_hours(),
+            total_steps = total_steps,
             "starting backtest"
         );

+        let mut step: u64 = 0;
         while current_time < self.config.end_time {
             context.timestamp = current_time;
             context.request_id = uuid::Uuid::new_v4().to_string();
@ -232,7 +259,12 @@ impl Backtester {
         let signals = self.executor.generate_signals(&result.selected_candidates, &context);
         for signal in signals {
-            if let Some(fill) = self.executor.execute_signal(&signal, &context) {
+            // enforce max_positions limit
+            if context.portfolio.positions.len() >= self.config.max_positions {
+                break;
+            }
+            if let Some(fill) = self.executor.execute_signal(&signal, &context).await {
                 info!(
                     ticker = %fill.ticker,
                     side = ?fill.side,
@ -267,6 +299,13 @@ impl Backtester {
         let market_prices = self.get_current_prices(current_time);
         metrics.record(current_time, &context.portfolio, &market_prices);

+        step += 1;
+        if let Some(ref progress) = self.progress {
+            progress
+                .current_step
+                .store(step, Ordering::Relaxed);
+        }
+
         current_time = current_time + self.config.interval;
     }

src/config.rs Normal file

@ -0,0 +1,124 @@
use serde::Deserialize;
use std::path::Path;
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RunMode {
Backtest,
Paper,
Live,
}
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
pub mode: RunMode,
pub kalshi: KalshiConfig,
pub trading: TradingConfig,
pub persistence: PersistenceConfig,
pub web: WebConfig,
pub circuit_breaker: CircuitBreakerConfig,
}
#[derive(Debug, Clone, Deserialize)]
pub struct KalshiConfig {
pub base_url: String,
pub poll_interval_secs: u64,
pub rate_limit_per_sec: u32,
}
#[derive(Debug, Clone, Deserialize)]
pub struct TradingConfig {
pub initial_capital: f64,
pub max_positions: usize,
pub kelly_fraction: f64,
pub max_position_pct: f64,
pub take_profit_pct: Option<f64>,
pub stop_loss_pct: Option<f64>,
pub max_hold_hours: Option<i64>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct PersistenceConfig {
pub db_path: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct WebConfig {
pub enabled: bool,
pub bind_addr: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct CircuitBreakerConfig {
pub max_drawdown_pct: f64,
pub max_daily_loss_pct: f64,
pub max_positions: Option<usize>,
pub max_single_position_pct: Option<f64>,
pub max_consecutive_errors: Option<u32>,
pub max_fills_per_hour: Option<u32>,
pub max_fills_per_day: Option<u32>,
}
impl AppConfig {
pub fn load(path: &Path) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Self = toml::from_str(&content)?;
Ok(config)
}
}
impl Default for KalshiConfig {
fn default() -> Self {
Self {
base_url: "https://api.elections.kalshi.com/trade-api/v2"
.to_string(),
poll_interval_secs: 300,
rate_limit_per_sec: 5,
}
}
}
impl Default for TradingConfig {
fn default() -> Self {
Self {
initial_capital: 10000.0,
max_positions: 100,
kelly_fraction: 0.25,
max_position_pct: 0.10,
take_profit_pct: Some(0.50),
stop_loss_pct: Some(0.99),
max_hold_hours: Some(48),
}
}
}
impl Default for PersistenceConfig {
fn default() -> Self {
Self {
db_path: "kalshi-paper.db".to_string(),
}
}
}
impl Default for WebConfig {
fn default() -> Self {
Self {
enabled: true,
bind_addr: "127.0.0.1:3030".to_string(),
}
}
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
max_drawdown_pct: 0.15,
max_daily_loss_pct: 0.05,
max_positions: Some(100),
max_single_position_pct: Some(0.10),
max_consecutive_errors: Some(5),
max_fills_per_hour: Some(50),
max_fills_per_day: Some(200),
}
}
}
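`AppConfig::load` parses a TOML file whose sections mirror the structs above. A plausible config sketch follows; the field names come straight from the structs, the values echo the defaults plus the commit's exit settings, and the whole file is illustrative rather than a shipped `config.toml`:

```toml
mode = "paper"

[kalshi]
base_url = "https://api.elections.kalshi.com/trade-api/v2"
poll_interval_secs = 300
rate_limit_per_sec = 5

[trading]
initial_capital = 10000.0
max_positions = 100
kelly_fraction = 0.40
max_position_pct = 0.30
take_profit_pct = 0.50
# stop_loss_pct omitted: the Option<f64> deserializes as None (disabled)
max_hold_hours = 48

[persistence]
db_path = "kalshi-paper.db"

[web]
enabled = true
bind_addr = "127.0.0.1:3030"

[circuit_breaker]
max_drawdown_pct = 0.15
max_daily_loss_pct = 0.05
```

Note the `Option` fields (`take_profit_pct`, the per-rule circuit-breaker limits) can simply be omitted; serde fills them with `None`.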


@ -0,0 +1,240 @@
use crate::config::CircuitBreakerConfig;
use crate::store::SqliteStore;
use chrono::{Duration, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::sync::Arc;
use tracing::warn;
#[derive(Debug, Clone, PartialEq)]
pub enum CbStatus {
Ok,
Tripped(String),
}
pub struct CircuitBreaker {
config: CircuitBreakerConfig,
store: Arc<SqliteStore>,
consecutive_errors: u32,
}
impl CircuitBreaker {
pub fn new(
config: CircuitBreakerConfig,
store: Arc<SqliteStore>,
) -> Self {
Self {
config,
store,
consecutive_errors: 0,
}
}
pub fn record_error(&mut self) {
self.consecutive_errors += 1;
}
pub fn record_success(&mut self) {
self.consecutive_errors = 0;
}
pub async fn check(
&self,
current_equity: Decimal,
peak_equity: Decimal,
positions_count: usize,
) -> CbStatus {
if let Some(status) =
self.check_drawdown(current_equity, peak_equity)
{
return status;
}
if let Some(status) = self.check_daily_loss().await {
return status;
}
if let Some(status) =
self.check_max_positions(positions_count)
{
return status;
}
if let Some(status) = self.check_consecutive_errors() {
return status;
}
if let Some(status) = self.check_fill_rate().await {
return status;
}
CbStatus::Ok
}
fn check_drawdown(
&self,
current_equity: Decimal,
peak_equity: Decimal,
) -> Option<CbStatus> {
if peak_equity <= Decimal::ZERO {
return None;
}
let drawdown = (peak_equity - current_equity)
.to_f64()
.unwrap_or(0.0)
/ peak_equity.to_f64().unwrap_or(1.0);
if drawdown >= self.config.max_drawdown_pct {
let msg = format!(
"drawdown {:.1}% exceeds max {:.1}%",
drawdown * 100.0,
self.config.max_drawdown_pct * 100.0
);
warn!(rule = "max_drawdown", %msg);
self.log_event("max_drawdown", &msg, "pause");
return Some(CbStatus::Tripped(msg));
}
None
}
async fn check_daily_loss(&self) -> Option<CbStatus> {
let today_start =
Utc::now().date_naive().and_hms_opt(0, 0, 0)?;
let today_utc = chrono::TimeZone::from_utc_datetime(
&Utc,
&today_start,
);
let fills = self
.store
.get_recent_fills(1000)
.await
.unwrap_or_default();
let daily_pnl: f64 = fills
.iter()
.filter(|f| f.timestamp >= today_utc)
.filter_map(|f| {
f.pnl.as_ref()?.to_f64()
})
.sum();
let peak = self
.store
.get_peak_equity()
.await
.ok()
.flatten()
.unwrap_or(Decimal::new(10000, 0));
let peak_f64 = peak.to_f64().unwrap_or(10000.0);
if peak_f64 <= 0.0 {
return None;
}
let daily_loss_pct = (-daily_pnl) / peak_f64;
if daily_loss_pct >= self.config.max_daily_loss_pct {
let msg = format!(
"daily loss {:.1}% exceeds max {:.1}%",
daily_loss_pct * 100.0,
self.config.max_daily_loss_pct * 100.0
);
warn!(rule = "max_daily_loss", %msg);
self.log_event("max_daily_loss", &msg, "pause");
return Some(CbStatus::Tripped(msg));
}
None
}
fn check_max_positions(
&self,
count: usize,
) -> Option<CbStatus> {
let max = self.config.max_positions.unwrap_or(100);
if count >= max {
let msg = format!(
"positions {} at max {}",
count, max
);
warn!(rule = "max_positions", %msg);
return Some(CbStatus::Tripped(msg));
}
None
}
fn check_consecutive_errors(&self) -> Option<CbStatus> {
let max = self.config.max_consecutive_errors.unwrap_or(5);
if self.consecutive_errors >= max {
let msg = format!(
"{} consecutive errors (max {})",
self.consecutive_errors, max
);
warn!(rule = "consecutive_errors", %msg);
self.log_event(
"consecutive_errors",
&msg,
"pause",
);
return Some(CbStatus::Tripped(msg));
}
None
}
async fn check_fill_rate(&self) -> Option<CbStatus> {
let max_hourly =
self.config.max_fills_per_hour.unwrap_or(50);
let max_daily =
self.config.max_fills_per_day.unwrap_or(200);
let one_hour_ago = Utc::now() - Duration::hours(1);
let hourly_fills = self
.store
.get_fills_since(one_hour_ago)
.await
.unwrap_or(0);
if hourly_fills >= max_hourly {
let msg = format!(
"{} fills/hour exceeds max {}",
hourly_fills, max_hourly
);
warn!(rule = "fill_rate_hourly", %msg);
self.log_event("fill_rate_hourly", &msg, "pause");
return Some(CbStatus::Tripped(msg));
}
// rolling 24h window, not calendar-day aligned
let one_day_ago = Utc::now() - Duration::hours(24);
let daily_fills = self
.store
.get_fills_since(one_day_ago)
.await
.unwrap_or(0);
if daily_fills >= max_daily {
let msg = format!(
"{} fills/day exceeds max {}",
daily_fills, max_daily
);
warn!(rule = "fill_rate_daily", %msg);
self.log_event("fill_rate_daily", &msg, "pause");
return Some(CbStatus::Tripped(msg));
}
None
}
fn log_event(&self, rule: &str, details: &str, action: &str) {
let store = self.store.clone();
let rule = rule.to_string();
let details = details.to_string();
let action = action.to_string();
tokio::spawn(async move {
let _ = store
.record_circuit_breaker_event(
&rule, &details, &action,
)
.await;
});
}
}
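The drawdown rule measures the fraction below the equity peak, not an absolute dollar loss. A minimal sketch of the same arithmetic (Python for illustration; the 0.15 threshold is the `CircuitBreakerConfig` default):

```python
def drawdown_tripped(current_equity: float, peak_equity: float,
                     max_drawdown_pct: float = 0.15) -> bool:
    # mirrors check_drawdown: distance below the equity peak, as a fraction
    if peak_equity <= 0.0:
        return False
    drawdown = (peak_equity - current_equity) / peak_equity
    return drawdown >= max_drawdown_pct

print(drawdown_tripped(8400.0, 10000.0))  # → True  (16% below peak)
print(drawdown_tripped(9000.0, 10000.0))  # → False (10% below peak)
```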

src/engine/mod.rs Normal file

@ -0,0 +1,506 @@
pub mod circuit_breaker;
pub mod state;
pub use circuit_breaker::{CbStatus, CircuitBreaker};
pub use state::EngineState;
use crate::api::KalshiClient;
use crate::config::AppConfig;
use crate::execution::OrderExecutor;
use crate::paper_executor::PaperExecutor;
use crate::pipeline::{
AlreadyPositionedFilter, BollingerMeanReversionScorer,
CategoryWeightedScorer, Filter, LiveKalshiSource,
LiquidityFilter, MeanReversionScorer, MomentumScorer,
MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer,
Selector, Source, TimeDecayScorer, TimeToCloseFilter,
TopKSelector, TradingPipeline, VolumeScorer,
};
use crate::store::SqliteStore;
use crate::types::{Portfolio, Trade, TradeType, TradingContext};
use chrono::Utc;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::{error, info, warn};
pub struct EngineStatus {
pub state: EngineState,
pub uptime_secs: u64,
pub last_tick: Option<chrono::DateTime<Utc>>,
pub ticks_completed: u64,
}
pub struct PaperTradingEngine {
config: AppConfig,
store: Arc<SqliteStore>,
executor: Arc<PaperExecutor>,
pipeline: Mutex<TradingPipeline>,
circuit_breaker: Mutex<CircuitBreaker>,
state: RwLock<EngineState>,
context: RwLock<TradingContext>,
shutdown_tx: broadcast::Sender<()>,
start_time: Instant,
ticks: RwLock<u64>,
last_tick: RwLock<Option<chrono::DateTime<Utc>>>,
}
impl PaperTradingEngine {
pub async fn new(
config: AppConfig,
store: Arc<SqliteStore>,
executor: Arc<PaperExecutor>,
client: Arc<KalshiClient>,
) -> anyhow::Result<Self> {
let (shutdown_tx, _) = broadcast::channel(1);
let pipeline =
Self::build_pipeline(client, &config);
let circuit_breaker = CircuitBreaker::new(
config.circuit_breaker.clone(),
store.clone(),
);
let initial_capital = Decimal::try_from(
config.trading.initial_capital,
)
.unwrap_or(Decimal::new(10000, 0));
let portfolio = store
.load_portfolio()
.await?
.unwrap_or_else(|| Portfolio::new(initial_capital));
let mut ctx = TradingContext::new(
portfolio.initial_capital,
Utc::now(),
);
ctx.portfolio = portfolio;
Ok(Self {
config,
store,
executor,
pipeline: Mutex::new(pipeline),
circuit_breaker: Mutex::new(circuit_breaker),
state: RwLock::new(EngineState::Starting),
context: RwLock::new(ctx),
shutdown_tx,
start_time: Instant::now(),
ticks: RwLock::new(0),
last_tick: RwLock::new(None),
})
}
fn build_pipeline(
client: Arc<KalshiClient>,
config: &AppConfig,
) -> TradingPipeline {
let sources: Vec<Box<dyn Source>> =
vec![Box::new(LiveKalshiSource::new(client))];
let max_pos_size =
(config.trading.initial_capital
* config.trading.max_position_pct)
as u64;
let filters: Vec<Box<dyn Filter>> = vec![
Box::new(TimeToCloseFilter::new(2, Some(720))),
Box::new(AlreadyPositionedFilter::new(
max_pos_size.max(100),
)),
];
let scorers: Vec<Box<dyn Scorer>> = vec![
Box::new(MomentumScorer::new(6)),
Box::new(
MultiTimeframeMomentumScorer::default_windows(),
),
Box::new(MeanReversionScorer::new(24)),
Box::new(
BollingerMeanReversionScorer::default_config(),
),
Box::new(VolumeScorer::new(6)),
Box::new(OrderFlowScorer::new()),
Box::new(TimeDecayScorer::new()),
Box::new(CategoryWeightedScorer::with_defaults()),
];
let max_positions = config.trading.max_positions;
let selector: Box<dyn Selector> =
Box::new(TopKSelector::new(max_positions));
TradingPipeline::new(
sources,
filters,
scorers,
selector,
max_positions,
)
}
pub fn shutdown_handle(&self) -> broadcast::Sender<()> {
self.shutdown_tx.clone()
}
pub async fn get_status(&self) -> EngineStatus {
EngineStatus {
state: self.state.read().await.clone(),
uptime_secs: self.start_time.elapsed().as_secs(),
last_tick: *self.last_tick.read().await,
ticks_completed: *self.ticks.read().await,
}
}
pub async fn get_context(&self) -> TradingContext {
self.context.read().await.clone()
}
pub async fn pause(&self, reason: String) {
let mut state = self.state.write().await;
*state = EngineState::Paused(reason);
}
pub async fn resume(&self) {
let mut state = self.state.write().await;
if matches!(*state, EngineState::Paused(_)) {
*state = EngineState::Running;
}
}
pub async fn run(&self) -> anyhow::Result<()> {
let mut shutdown_rx = self.shutdown_tx.subscribe();
let poll_interval = std::time::Duration::from_secs(
self.config.kalshi.poll_interval_secs,
);
{
let mut state = self.state.write().await;
*state = EngineState::Recovering;
}
info!("recovering state from SQLite");
if let Ok(Some(portfolio)) =
self.store.load_portfolio().await
{
let mut ctx = self.context.write().await;
ctx.portfolio = portfolio;
info!(
positions = ctx.portfolio.positions.len(),
cash = %ctx.portfolio.cash,
"state recovered"
);
}
{
let mut state = self.state.write().await;
*state = EngineState::Running;
}
info!(
interval_secs = self.config.kalshi.poll_interval_secs,
"engine running"
);
// run first tick immediately
self.tick().await;
loop {
tokio::select! {
_ = shutdown_rx.recv() => {
info!("shutdown signal received");
let mut state = self.state.write().await;
*state = EngineState::ShuttingDown;
break;
}
_ = tokio::time::sleep(poll_interval) => {
let current_state =
self.state.read().await.clone();
match current_state {
EngineState::Running => {
self.tick().await;
}
EngineState::Paused(ref reason) => {
info!(
reason = %reason,
"engine paused, skipping tick"
);
}
_ => {}
}
}
}
}
info!("persisting final state");
let ctx = self.context.read().await;
if let Err(e) =
self.store.save_portfolio(&ctx.portfolio).await
{
error!(error = %e, "failed to persist final state");
}
info!("engine shutdown complete");
Ok(())
}
async fn tick(&self) {
let tick_start = Instant::now();
let now = Utc::now();
let context_snapshot = {
let mut ctx = self.context.write().await;
ctx.timestamp = now;
ctx.request_id =
uuid::Uuid::new_v4().to_string();
ctx.clone()
};
let result = {
let pipeline = self.pipeline.lock().await;
pipeline.execute(context_snapshot.clone()).await
};
let candidates_fetched =
result.retrieved_candidates.len();
let candidates_filtered =
result.filtered_candidates.len();
let candidates_selected =
result.selected_candidates.len();
let candidate_scores: HashMap<String, f64> = result
.selected_candidates
.iter()
.map(|c| (c.ticker.clone(), c.final_score))
.collect();
let current_prices: HashMap<String, Decimal> = result
.selected_candidates
.iter()
.map(|c| (c.ticker.clone(), c.current_yes_price))
.collect();
self.executor.update_prices(current_prices).await;
let exit_signals = self.executor.generate_exit_signals(
&context_snapshot,
&candidate_scores,
);
let mut ctx = self.context.write().await;
let mut fills_executed = 0u32;
for exit in &exit_signals {
if let Some(position) = ctx
.portfolio
.positions
.get(&exit.ticker)
.cloned()
{
let pnl = ctx.portfolio.close_position(
&exit.ticker,
exit.current_price,
);
info!(
ticker = %exit.ticker,
reason = ?exit.reason,
pnl = ?pnl,
"paper exit"
);
let exit_fill = crate::types::Fill {
ticker: exit.ticker.clone(),
side: position.side,
quantity: position.quantity,
price: exit.current_price,
timestamp: now,
};
let reason_str =
format!("{:?}", exit.reason);
let _ = self
.store
.record_fill(
&exit_fill,
pnl,
Some(&reason_str),
)
.await;
ctx.trading_history.push(Trade {
ticker: exit.ticker.clone(),
side: position.side,
quantity: position.quantity,
price: exit.current_price,
timestamp: now,
trade_type: TradeType::Close,
});
fills_executed += 1;
}
}
let signals = self.executor.generate_signals(
&result.selected_candidates,
&*ctx,
);
let signals_generated = signals.len();
let peak_equity = self
.store
.get_peak_equity()
.await
.ok()
.flatten()
.unwrap_or(ctx.portfolio.initial_capital);
let positions_value: Decimal = ctx
.portfolio
.positions
.values()
.map(|p| {
p.avg_entry_price * Decimal::from(p.quantity)
})
.sum();
let current_equity = ctx.portfolio.cash + positions_value;
let cb_status = {
let cb = self.circuit_breaker.lock().await;
cb.check(
current_equity,
peak_equity,
ctx.portfolio.positions.len(),
)
.await
};
if let CbStatus::Tripped(reason) = cb_status {
warn!(
reason = %reason,
"circuit breaker tripped, pausing"
);
drop(ctx);
let mut state = self.state.write().await;
*state = EngineState::Paused(reason);
return;
}
{
let mut cb = self.circuit_breaker.lock().await;
cb.record_success();
}
for signal in signals {
if ctx.portfolio.positions.len()
>= self.config.trading.max_positions
{
break;
}
let context_for_exec = (*ctx).clone();
if let Some(fill) = self
.executor
.execute_signal(&signal, &context_for_exec)
.await
{
info!(
ticker = %fill.ticker,
side = ?fill.side,
qty = fill.quantity,
price = %fill.price,
"paper fill"
);
ctx.portfolio.apply_fill(&fill);
ctx.trading_history.push(Trade {
ticker: fill.ticker.clone(),
side: fill.side,
quantity: fill.quantity,
price: fill.price,
timestamp: fill.timestamp,
trade_type: TradeType::Open,
});
fills_executed += 1;
}
}
if let Err(e) =
self.store.save_portfolio(&ctx.portfolio).await
{
error!(error = %e, "failed to persist portfolio");
let mut cb = self.circuit_breaker.lock().await;
cb.record_error();
}
let positions_value: Decimal = ctx
.portfolio
.positions
.values()
.map(|p| {
p.avg_entry_price * Decimal::from(p.quantity)
})
.sum();
let equity = ctx.portfolio.cash + positions_value;
let drawdown = if peak_equity > Decimal::ZERO {
((peak_equity - equity)
.to_f64()
.unwrap_or(0.0))
/ peak_equity.to_f64().unwrap_or(1.0)
} else {
0.0
};
let _ = self
.store
.snapshot_equity(
now,
equity,
ctx.portfolio.cash,
positions_value,
drawdown.max(0.0),
)
.await;
let duration_ms =
tick_start.elapsed().as_millis() as u64;
let _ = self
.store
.record_pipeline_run(
now,
duration_ms,
candidates_fetched,
candidates_filtered,
candidates_selected,
signals_generated,
fills_executed as usize,
None,
)
.await;
{
let mut ticks = self.ticks.write().await;
*ticks += 1;
}
{
let mut last = self.last_tick.write().await;
*last = Some(now);
}
info!(
fetched = candidates_fetched,
filtered = candidates_filtered,
selected = candidates_selected,
signals = signals_generated,
fills = fills_executed,
equity = %equity,
duration_ms = duration_ms,
"tick complete"
);
}
}
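One detail worth noticing in `tick()`: the equity snapshot values each open position at `avg_entry_price * quantity`, i.e. at cost basis rather than mark-to-market. A sketch of that bookkeeping (Python, illustrative; the ticker name is a placeholder):

```python
def equity_at_cost(cash: float, positions: dict) -> float:
    # mirrors the tick() snapshot: each position contributes its
    # average entry price times quantity (cost, not current market price)
    positions_value = sum(p["avg_entry_price"] * p["quantity"]
                          for p in positions.values())
    return cash + positions_value

positions = {"TICKER-A": {"avg_entry_price": 0.50, "quantity": 200}}
print(equity_at_cost(5000.0, positions))  # → 5100.0
```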

src/engine/state.rs Normal file

@ -0,0 +1,25 @@
use serde::Serialize;
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EngineState {
Starting,
Recovering,
Running,
Paused(String),
ShuttingDown,
}
impl std::fmt::Display for EngineState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Starting => write!(f, "starting"),
Self::Recovering => write!(f, "recovering"),
Self::Running => write!(f, "running"),
Self::Paused(reason) => {
write!(f, "paused: {}", reason)
}
Self::ShuttingDown => write!(f, "shutting_down"),
}
}
}
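As a sanity check, the Display impl above can be exercised standalone; this sketch is illustrative only (not part of the diff) and omits the serde derive:

```rust
// Minimal copy of EngineState to exercise the Display impl above.
// Illustrative only; the real type also derives Serialize.
#[derive(Debug, Clone, PartialEq)]
pub enum EngineState {
    Starting,
    Recovering,
    Running,
    Paused(String),
    ShuttingDown,
}

impl std::fmt::Display for EngineState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Starting => write!(f, "starting"),
            Self::Recovering => write!(f, "recovering"),
            Self::Running => write!(f, "running"),
            Self::Paused(reason) => write!(f, "paused: {}", reason),
            Self::ShuttingDown => write!(f, "shutting_down"),
        }
    }
}

fn main() {
    let state = EngineState::Paused("daily drawdown limit".into());
    // Display flattens the reason into the status string
    println!("{}", state); // paused: daily drawdown limit
}
```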


@ -1,7 +1,12 @@
 use crate::data::HistoricalData;
-use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext};
+use crate::types::{
+    ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side,
+    Signal, TradingContext,
+};
+use async_trait::async_trait;
 use rust_decimal::Decimal;
 use rust_decimal::prelude::ToPrimitive;
+use std::collections::HashMap;
 use std::sync::Arc;

 #[derive(Debug, Clone)]
@ -15,8 +20,8 @@ pub struct PositionSizingConfig {
 impl Default for PositionSizingConfig {
     fn default() -> Self {
         Self {
-            kelly_fraction: 0.25,
-            max_position_pct: 0.25,
+            kelly_fraction: 0.40,
+            max_position_pct: 0.30,
             min_position_size: 10,
             max_position_size: 1000,
         }
@ -43,13 +48,33 @@ impl PositionSizingConfig {
     }
 }
#[async_trait]
pub trait OrderExecutor: Send + Sync {
async fn execute_signal(
&self,
signal: &Signal,
context: &TradingContext,
) -> Option<Fill>;
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal>;
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal>;
}
 /// maps scoring edge [-inf, +inf] to win probability [0, 1]
-/// tanh squashes extreme values smoothly; (+1)/2 shifts from [-1,1] to [0,1]
-fn edge_to_win_probability(edge: f64) -> f64 {
+pub fn edge_to_win_probability(edge: f64) -> f64 {
     (1.0 + edge.tanh()) / 2.0
 }

-fn kelly_size(
+pub fn kelly_size(
     edge: f64,
     price: f64,
     bankroll: f64,
@ -68,13 +93,165 @@ fn kelly_size(
     let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
     let safe_kelly = (kelly * config.kelly_fraction).max(0.0);
-    let position_value = bankroll * safe_kelly.min(config.max_position_pct);
+    let position_value =
+        bankroll * safe_kelly.min(config.max_position_pct);
     let shares = (position_value / price).floor() as u64;
-    shares.max(config.min_position_size).min(config.max_position_size)
+    shares
+        .max(config.min_position_size)
+        .min(config.max_position_size)
 }
-pub struct Executor {
+pub fn candidate_to_signal(
candidate: &MarketCandidate,
context: &TradingContext,
sizing_config: &PositionSizingConfig,
max_position_size: u64,
) -> Option<Signal> {
let current_position =
context.portfolio.get_position(&candidate.ticker);
let current_qty =
current_position.map(|p| p.quantity).unwrap_or(0);
if current_qty >= max_position_size {
return None;
}
let yes_price =
candidate.current_yes_price.to_f64().unwrap_or(0.5);
let side = if candidate.final_score > 0.0 {
if yes_price < 0.5 {
Side::Yes
} else {
Side::No
}
} else if candidate.final_score < 0.0 {
if yes_price > 0.5 {
Side::No
} else {
Side::Yes
}
} else {
return None;
};
let price = match side {
Side::Yes => candidate.current_yes_price,
Side::No => candidate.current_no_price,
};
let available_cash =
context.portfolio.cash.to_f64().unwrap_or(0.0);
let price_f64 = price.to_f64().unwrap_or(0.5);
if price_f64 <= 0.0 {
return None;
}
let kelly_qty = kelly_size(
candidate.final_score,
price_f64,
available_cash,
sizing_config,
);
let max_affordable = (available_cash / price_f64) as u64;
let quantity = kelly_qty
.min(max_affordable)
.min(max_position_size - current_qty);
if quantity < sizing_config.min_position_size {
return None;
}
Some(Signal {
ticker: candidate.ticker.clone(),
side,
quantity,
limit_price: Some(price),
reason: format!(
"score={:.3}, side={:?}, price={:.2}",
candidate.final_score, side, price_f64
),
})
}
pub fn compute_exit_signals(
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
exit_config: &ExitConfig,
price_lookup: &dyn Fn(&str) -> Option<Decimal>,
) -> Vec<ExitSignal> {
let mut exits = Vec::new();
for (ticker, position) in &context.portfolio.positions {
let current_price = match price_lookup(ticker) {
Some(p) => p,
None => continue,
};
let effective_price = match position.side {
Side::Yes => current_price,
Side::No => Decimal::ONE - current_price,
};
let entry_price_f64 =
position.avg_entry_price.to_f64().unwrap_or(0.5);
let current_price_f64 =
effective_price.to_f64().unwrap_or(0.5);
if entry_price_f64 <= 0.0 {
continue;
}
let pnl_pct = (current_price_f64 - entry_price_f64)
/ entry_price_f64;
if pnl_pct >= exit_config.take_profit_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TakeProfit { pnl_pct },
current_price,
});
continue;
}
if pnl_pct <= -exit_config.stop_loss_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::StopLoss { pnl_pct },
current_price,
});
continue;
}
let hours_held =
(context.timestamp - position.entry_time).num_hours();
if hours_held >= exit_config.max_hold_hours {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TimeStop { hours_held },
current_price,
});
continue;
}
if let Some(&new_score) = candidate_scores.get(ticker) {
if new_score < exit_config.score_reversal_threshold {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::ScoreReversal { new_score },
current_price,
});
}
}
}
exits
}
+pub struct BacktestExecutor {
     data: Arc<HistoricalData>,
     slippage_bps: u32,
     max_position_size: u64,
@ -82,8 +259,12 @@ pub struct Executor {
     exit_config: ExitConfig,
 }

-impl Executor {
-    pub fn new(data: Arc<HistoricalData>, slippage_bps: u32, max_position_size: u64) -> Self {
+impl BacktestExecutor {
+    pub fn new(
+        data: Arc<HistoricalData>,
+        slippage_bps: u32,
+        max_position_size: u64,
+    ) -> Self {
         Self {
             data,
             slippage_bps,
@ -93,7 +274,10 @@ impl Executor {
         }
     }

-    pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self {
+    pub fn with_sizing_config(
+        mut self,
+        config: PositionSizingConfig,
+    ) -> Self {
         self.sizing_config = config;
         self
     }
@ -102,168 +286,33 @@ impl Executor {
         self.exit_config = config;
         self
     }
+}

-    pub fn generate_exit_signals(
-        &self,
-        context: &TradingContext,
-        candidate_scores: &std::collections::HashMap<String, f64>,
-    ) -> Vec<ExitSignal> {
+#[async_trait]
+impl OrderExecutor for BacktestExecutor {
+    async fn execute_signal(
let mut exits = Vec::new();
for (ticker, position) in &context.portfolio.positions {
let current_price = match self.data.get_current_price(ticker, context.timestamp) {
Some(p) => p,
None => continue,
};
let effective_price = match position.side {
Side::Yes => current_price,
Side::No => Decimal::ONE - current_price,
};
let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);
if entry_price_f64 <= 0.0 {
continue;
}
let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;
if pnl_pct >= self.exit_config.take_profit_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TakeProfit { pnl_pct },
current_price,
});
continue;
}
if pnl_pct <= -self.exit_config.stop_loss_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::StopLoss { pnl_pct },
current_price,
});
continue;
}
let hours_held = (context.timestamp - position.entry_time).num_hours();
if hours_held >= self.exit_config.max_hold_hours {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TimeStop { hours_held },
current_price,
});
continue;
}
if let Some(&new_score) = candidate_scores.get(ticker) {
if new_score < self.exit_config.score_reversal_threshold {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::ScoreReversal { new_score },
current_price,
});
}
}
}
exits
}
pub fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| self.candidate_to_signal(c, context))
.collect()
}
fn candidate_to_signal(
&self,
candidate: &MarketCandidate,
context: &TradingContext,
) -> Option<Signal> {
let current_position = context.portfolio.get_position(&candidate.ticker);
let current_qty = current_position.map(|p| p.quantity).unwrap_or(0);
if current_qty >= self.max_position_size {
return None;
}
let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5);
// positive score = bullish signal, so buy the cheaper side (better risk/reward)
// negative score = bearish signal, so buy against the expensive side
let side = if candidate.final_score > 0.0 {
if yes_price < 0.5 { Side::Yes } else { Side::No }
} else if candidate.final_score < 0.0 {
if yes_price > 0.5 { Side::No } else { Side::Yes }
} else {
return None;
};
let price = match side {
Side::Yes => candidate.current_yes_price,
Side::No => candidate.current_no_price,
};
let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0);
let price_f64 = price.to_f64().unwrap_or(0.5);
if price_f64 <= 0.0 {
return None;
}
let kelly_qty = kelly_size(
candidate.final_score,
price_f64,
available_cash,
&self.sizing_config,
);
let max_affordable = (available_cash / price_f64) as u64;
let quantity = kelly_qty
.min(max_affordable)
.min(self.max_position_size - current_qty);
if quantity < self.sizing_config.min_position_size {
return None;
}
Some(Signal {
ticker: candidate.ticker.clone(),
side,
quantity,
limit_price: Some(price),
reason: format!(
"score={:.3}, side={:?}, price={:.2}",
candidate.final_score, side, price_f64
),
})
}
-    pub fn execute_signal(
         &self,
         signal: &Signal,
         context: &TradingContext,
     ) -> Option<Fill> {
-        let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?;
+        let market_price = self
+            .data
+            .get_current_price(&signal.ticker, context.timestamp)?;
         let effective_price = match signal.side {
             Side::Yes => market_price,
             Side::No => Decimal::ONE - market_price,
         };
-        let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000);
-        let fill_price = effective_price * (Decimal::ONE + slippage);
+        let slippage =
+            Decimal::from(self.slippage_bps) / Decimal::from(10000);
+        let fill_price =
+            effective_price * (Decimal::ONE + slippage);
         if let Some(limit) = signal.limit_price {
-            if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) {
+            if fill_price
+                > limit * (Decimal::ONE + slippage * Decimal::from(2))
+            {
                 return None;
             }
         }
@ -294,6 +343,39 @@ impl Executor {
             timestamp: context.timestamp,
         })
     }
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| {
candidate_to_signal(
c,
context,
&self.sizing_config,
self.max_position_size,
)
})
.collect()
}
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal> {
let data = self.data.clone();
let timestamp = context.timestamp;
compute_exit_signals(
context,
candidate_scores,
&self.exit_config,
&|ticker| data.get_current_price(ticker, timestamp),
)
}
} }
 pub fn simple_signal_generator(
@ -306,7 +388,8 @@ pub fn simple_signal_generator(
         .filter(|c| c.final_score > 0.0)
         .filter(|c| !context.portfolio.has_position(&c.ticker))
         .map(|c| {
-            let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5);
+            let yes_price =
+                c.current_yes_price.to_f64().unwrap_or(0.5);
             let (side, price) = if yes_price < 0.5 {
                 (Side::Yes, c.current_yes_price)
             } else {
@ -318,7 +401,10 @@ pub fn simple_signal_generator(
                 side,
                 quantity: position_size,
                 limit_price: Some(price),
-                reason: format!("simple: score={:.3}", c.final_score),
+                reason: format!(
+                    "simple: score={:.3}",
+                    c.final_score
+                ),
             }
         })
         .collect()
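Taken together, the sizing path in this file reduces to a few lines. The following standalone sketch mirrors edge_to_win_probability and kelly_size, with the new defaults (kelly_fraction 0.40, max_position_pct 0.30, size bounds 10/1000) hard-coded for illustration rather than passed via PositionSizingConfig:

```rust
// Standalone sketch of the fractional-Kelly sizing used in this file.
// Config defaults (0.40 / 0.30 / 10 / 1000) are hard-coded for illustration.
fn edge_to_win_probability(edge: f64) -> f64 {
    // tanh squashes edge into [-1, 1]; (x + 1) / 2 shifts to [0, 1]
    (1.0 + edge.tanh()) / 2.0
}

fn kelly_size(edge: f64, price: f64, bankroll: f64) -> u64 {
    let (kelly_fraction, max_position_pct) = (0.40, 0.30);
    let (min_size, max_size) = (10u64, 1000u64);
    let win_prob = edge_to_win_probability(edge);
    // a binary contract bought at `price` pays 1.0 → odds = (1 - p) / p
    let odds = (1.0 - price) / price;
    let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
    let safe_kelly = (kelly * kelly_fraction).max(0.0);
    let position_value = bankroll * safe_kelly.min(max_position_pct);
    let shares = (position_value / price).floor() as u64;
    shares.max(min_size).min(max_size)
}

fn main() {
    // edge 0.5 → win_prob ≈ 0.731; at price 0.40 and a $1,000 bankroll
    println!("{}", kelly_size(0.5, 0.40, 1_000.0)); // 551 shares
}
```

Note the clamp to min_size means even a zero edge yields 10 shares here; in the full code, candidate_to_signal rejects any signal whose final quantity falls below min_position_size before it reaches execution.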


@ -1,16 +1,22 @@
+mod api;
 mod backtest;
+mod config;
 mod data;
+mod engine;
 mod execution;
 mod metrics;
+mod paper_executor;
 mod pipeline;
+mod store;
 mod types;
+mod web;

 use anyhow::{Context, Result};
 use backtest::{Backtester, RandomBaseline};
 use chrono::{DateTime, NaiveDate, TimeZone, Utc};
 use clap::{Parser, Subcommand};
 use data::HistoricalData;
-use execution::{Executor, PositionSizingConfig};
+use execution::PositionSizingConfig;
 use rust_decimal::Decimal;
 use std::path::PathBuf;
 use std::sync::Arc;
@ -44,7 +50,8 @@ enum Commands {
         #[arg(long, default_value = "100")]
         max_position: u64,

-        #[arg(long, default_value = "5")]
+        /// max concurrent positions (higher = more diversified)
+        #[arg(long, default_value = "100")]
         max_positions: usize,

         #[arg(long, default_value = "1")]
@ -56,35 +63,48 @@
         #[arg(long)]
         compare_random: bool,

-        #[arg(long, default_value = "0.25")]
+        /// kelly fraction for position sizing
+        #[arg(long, default_value = "0.40")]
         kelly_fraction: f64,

-        #[arg(long, default_value = "0.25")]
+        /// max portfolio % per position
+        #[arg(long, default_value = "0.30")]
         max_position_pct: f64,

-        #[arg(long, default_value = "0.20")]
+        /// take profit threshold
+        #[arg(long, default_value = "0.50")]
         take_profit: f64,

-        #[arg(long, default_value = "0.15")]
+        /// stop loss threshold
+        #[arg(long, default_value = "0.99")]
         stop_loss: f64,

-        #[arg(long, default_value = "72")]
+        /// max hours to hold a position
+        #[arg(long, default_value = "48")]
         max_hold_hours: i64,
     },
+    Paper {
+        /// path to config.toml
+        #[arg(short, long, default_value = "config.toml")]
+        config: PathBuf,
+    },
     Summary {
         #[arg(short, long)]
         results_file: PathBuf,
     },
 }

-fn parse_date(s: &str) -> Result<DateTime<Utc>> {
+pub fn parse_date(s: &str) -> Result<DateTime<Utc>> {
     if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
         return Ok(dt.with_timezone(&Utc));
     }
     if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
-        return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap()));
+        return Ok(Utc.from_utc_datetime(
+            &date.and_hms_opt(0, 0, 0).unwrap(),
+        ));
     }
     Err(anyhow::anyhow!("could not parse date: {}", s))
@ -95,7 +115,9 @@ async fn main() -> Result<()> {
     tracing_subscriber::registry()
         .with(
             tracing_subscriber::EnvFilter::try_from_default_env()
-                .unwrap_or_else(|_| "kalshi_backtest=info".into()),
+                .unwrap_or_else(|_| {
+                    "kalshi_backtest=info".into()
+                }),
         )
         .with(tracing_subscriber::fmt::layer())
         .init();
@ -119,8 +141,10 @@ async fn main() -> Result<()> {
             stop_loss,
             max_hold_hours,
         } => {
-            let start_time = parse_date(&start).context("parsing start date")?;
-            let end_time = parse_date(&end).context("parsing end date")?;
+            let start_time =
+                parse_date(&start).context("parsing start date")?;
+            let end_time =
+                parse_date(&end).context("parsing end date")?;

             info!(
                 data_dir = %data_dir.display(),
@ -131,7 +155,8 @@ async fn main() -> Result<()> {
             );

             let data = Arc::new(
-                HistoricalData::load(&data_dir).context("loading historical data")?,
+                HistoricalData::load(&data_dir)
+                    .context("loading historical data")?,
             );

             info!(
@ -144,7 +169,8 @@ async fn main() -> Result<()> {
                 start_time,
                 end_time,
                 interval: chrono::Duration::hours(interval_hours),
-                initial_capital: Decimal::try_from(capital).unwrap(),
+                initial_capital: Decimal::try_from(capital)
+                    .unwrap(),
                 max_position_size: max_position,
                 max_positions,
             };
@ -163,16 +189,25 @@ async fn main() -> Result<()> {
                 score_reversal_threshold: -0.3,
             };

-            let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config);
+            let backtester = Backtester::with_configs(
+                config.clone(),
+                data.clone(),
+                sizing_config,
+                exit_config,
+            );
             let result = backtester.run().await;
             println!("{}", result.summary());

             std::fs::create_dir_all(&output_dir)?;
-            let result_path = output_dir.join("backtest_result.json");
+            let result_path =
+                output_dir.join("backtest_result.json");
             let json = serde_json::to_string_pretty(&result)?;
             std::fs::write(&result_path, json)?;
-            info!(path = %result_path.display(), "results saved");
+            info!(
+                path = %result_path.display(),
+                "results saved"
+            );

             if compare_random {
                 println!("\n--- random baseline ---\n");
@ -180,18 +215,23 @@ async fn main() -> Result<()> {
                 let baseline_result = baseline.run().await;
                 println!("{}", baseline_result.summary());

-                let baseline_path = output_dir.join("baseline_result.json");
-                let json = serde_json::to_string_pretty(&baseline_result)?;
+                let baseline_path =
+                    output_dir.join("baseline_result.json");
+                let json = serde_json::to_string_pretty(
+                    &baseline_result,
+                )?;
                 std::fs::write(&baseline_path, json)?;

                 println!("\n--- comparison ---\n");
                 println!(
                     "strategy return: {:.2}% vs baseline: {:.2}%",
-                    result.total_return_pct, baseline_result.total_return_pct
+                    result.total_return_pct,
+                    baseline_result.total_return_pct
                 );
                 println!(
                     "strategy sharpe: {:.3} vs baseline: {:.3}",
-                    result.sharpe_ratio, baseline_result.sharpe_ratio
+                    result.sharpe_ratio,
+                    baseline_result.sharpe_ratio
                 );
                 println!(
                     "strategy win rate: {:.1}% vs baseline: {:.1}%",
@ -202,11 +242,16 @@ async fn main() -> Result<()> {
             Ok(())
         }
+        Commands::Paper { config: config_path } => {
+            run_paper(config_path).await
+        }
         Commands::Summary { results_file } => {
             let content = std::fs::read_to_string(&results_file)
                 .context("reading results file")?;
             let result: metrics::BacktestResult =
-                serde_json::from_str(&content).context("parsing results")?;
+                serde_json::from_str(&content)
+                    .context("parsing results")?;
             println!("{}", result.summary());
@ -214,3 +259,124 @@ async fn main() -> Result<()> {
         }
     }
 }
async fn run_paper(config_path: PathBuf) -> Result<()> {
let app_config = config::AppConfig::load(&config_path)
.context("loading config")?;
info!(
mode = ?app_config.mode,
poll_secs = app_config.kalshi.poll_interval_secs,
capital = app_config.trading.initial_capital,
"starting paper trading"
);
let store = Arc::new(
store::SqliteStore::new(&app_config.persistence.db_path)
.await
.context("initializing SQLite store")?,
);
let client = Arc::new(api::KalshiClient::new(
&app_config.kalshi,
));
let sizing_config = PositionSizingConfig {
kelly_fraction: app_config.trading.kelly_fraction,
max_position_pct: app_config.trading.max_position_pct,
min_position_size: 10,
max_position_size: 1000,
};
let exit_config = ExitConfig {
take_profit_pct: app_config
.trading
.take_profit_pct
.unwrap_or(0.50),
stop_loss_pct: app_config
.trading
.stop_loss_pct
.unwrap_or(0.99),
max_hold_hours: app_config
.trading
.max_hold_hours
.unwrap_or(48),
score_reversal_threshold: -0.3,
};
let executor = Arc::new(paper_executor::PaperExecutor::new(
1000,
sizing_config,
exit_config,
store.clone(),
));
let engine = engine::PaperTradingEngine::new(
app_config.clone(),
store.clone(),
executor,
client,
)
.await
.context("initializing engine")?;
let shutdown_tx = engine.shutdown_handle();
let engine = Arc::new(engine);
if app_config.web.enabled {
let web_state = Arc::new(web::AppState {
engine: engine.clone(),
store: store.clone(),
shutdown_tx: shutdown_tx.clone(),
backtest: Arc::new(tokio::sync::Mutex::new(
web::BacktestState {
status: web::BacktestRunStatus::Idle,
progress: None,
result: None,
error: None,
},
)),
data_dir: PathBuf::from("data"),
});
let router = web::build_router(web_state);
let bind_addr = app_config.web.bind_addr.clone();
info!(addr = %bind_addr, "starting web dashboard");
match tokio::net::TcpListener::bind(&bind_addr).await {
Ok(listener) => {
tokio::spawn(async move {
if let Err(e) =
axum::serve(listener, router).await
{
tracing::error!(
error = %e,
"web server error"
);
}
});
}
Err(e) => {
tracing::warn!(
addr = %bind_addr,
error = %e,
"web dashboard disabled (port in use)"
);
}
}
}
let shutdown_tx_clone = shutdown_tx.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
info!("ctrl+c received, shutting down");
let _ = shutdown_tx_clone.send(());
});
engine.run().await?;
info!("paper trading session ended");
Ok(())
}
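run_paper expects a config.toml shaped roughly like the following. Field names are inferred from the accessors above (app_config.kalshi.poll_interval_secs, trading.kelly_fraction, web.bind_addr, and so on); the actual schema lives in src/config.rs, which is not shown in this diff, so treat this as a hypothetical sketch:

```toml
# Hypothetical config.toml matching the fields run_paper reads.
mode = "paper"

[kalshi]
poll_interval_secs = 60

[trading]
initial_capital = 10000.0
kelly_fraction = 0.40
max_position_pct = 0.30
take_profit_pct = 0.50   # optional; run_paper defaults to 0.50
stop_loss_pct = 0.99     # optional; 0.99 effectively disables the stop
max_hold_hours = 48      # optional; run_paper defaults to 48

[persistence]
db_path = "paper.db"

[web]
enabled = true
bind_addr = "127.0.0.1:8080"
```

The optional trading fields mirror the `.unwrap_or(...)` defaults in run_paper, so omitting them yields the same exit behavior the backtests settled on.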

src/paper_executor.rs (new file, 143 lines)

@ -0,0 +1,143 @@
use crate::execution::{
candidate_to_signal, compute_exit_signals, OrderExecutor,
PositionSizingConfig,
};
use crate::store::SqliteStore;
use crate::types::{
ExitConfig, ExitSignal, Fill, MarketCandidate, Signal, Side,
TradingContext,
};
use async_trait::async_trait;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
pub struct PaperExecutor {
max_position_size: u64,
sizing_config: PositionSizingConfig,
exit_config: ExitConfig,
store: Arc<SqliteStore>,
current_prices: Arc<RwLock<HashMap<String, Decimal>>>,
}
impl PaperExecutor {
pub fn new(
max_position_size: u64,
sizing_config: PositionSizingConfig,
exit_config: ExitConfig,
store: Arc<SqliteStore>,
) -> Self {
Self {
max_position_size,
sizing_config,
exit_config,
store,
current_prices: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn update_prices(
&self,
prices: HashMap<String, Decimal>,
) {
let mut current = self.current_prices.write().await;
*current = prices;
}
}
#[async_trait]
impl OrderExecutor for PaperExecutor {
async fn execute_signal(
&self,
signal: &Signal,
context: &TradingContext,
) -> Option<Fill> {
let prices = self.current_prices.read().await;
let market_price = prices.get(&signal.ticker).copied()?;
let effective_price = match signal.side {
Side::Yes => market_price,
Side::No => Decimal::ONE - market_price,
};
if let Some(limit) = signal.limit_price {
let tolerance = Decimal::new(5, 2);
if effective_price > limit * (Decimal::ONE + tolerance) {
return None;
}
}
let cost = effective_price * Decimal::from(signal.quantity);
let quantity = if cost > context.portfolio.cash {
let affordable = (context.portfolio.cash
/ effective_price)
.to_u64()
.unwrap_or(0);
if affordable == 0 {
return None;
}
affordable
} else {
signal.quantity
};
let fill = Fill {
ticker: signal.ticker.clone(),
side: signal.side,
quantity,
price: effective_price,
timestamp: context.timestamp,
};
if let Err(e) =
self.store.record_fill(&fill, None, None).await
{
tracing::error!(
error = %e,
"failed to persist fill"
);
}
Some(fill)
}
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| {
candidate_to_signal(
c,
context,
&self.sizing_config,
self.max_position_size,
)
})
.collect()
}
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal> {
let prices = self.current_prices.try_read();
match prices {
Ok(prices) => {
let prices_ref = prices.clone();
compute_exit_signals(
context,
candidate_scores,
&self.exit_config,
&|ticker| prices_ref.get(ticker).copied(),
)
}
Err(_) => Vec::new(),
}
}
}


@ -5,7 +5,6 @@ mod scorers;
 mod selector;
 mod sources;

-pub use correlation_scorer::*;
 pub use filters::*;
 pub use ml_scorer::*;
 pub use scorers::*;

File diff suppressed because it is too large


@ -1,7 +1,9 @@
+use crate::api::KalshiClient;
 use crate::data::HistoricalData;
 use crate::pipeline::Source;
-use crate::types::{MarketCandidate, TradingContext};
+use crate::types::{MarketCandidate, PricePoint, TradingContext};
 use async_trait::async_trait;
+use chrono::Utc;
 use rust_decimal::Decimal;
 use std::collections::HashMap;
 use std::sync::Arc;
@ -67,3 +69,78 @@ impl Source for HistoricalMarketSource {
         Ok(candidates)
     }
 }
pub struct LiveKalshiSource {
client: Arc<KalshiClient>,
}
impl LiveKalshiSource {
pub fn new(client: Arc<KalshiClient>) -> Self {
Self { client }
}
}
#[async_trait]
impl Source for LiveKalshiSource {
fn name(&self) -> &'static str {
"LiveKalshiSource"
}
async fn get_candidates(
&self,
_context: &TradingContext,
) -> Result<Vec<MarketCandidate>, String> {
let markets = self
.client
.get_open_markets()
.await
.map_err(|e| format!("API error: {}", e))?;
let now = Utc::now();
let mut candidates = Vec::with_capacity(markets.len());
for market in markets {
let yes_price = market.mid_yes_price();
if yes_price <= 0.0 || yes_price >= 1.0 {
continue;
}
let yes_dec = Decimal::try_from(yes_price)
.unwrap_or(Decimal::new(50, 2));
let no_dec = Decimal::ONE - yes_dec;
let volume_24h = market.volume_24h.max(0) as u64;
let total_volume = market.volume.max(0) as u64;
// skip per-market trade fetching to avoid N+1
// the market list already provides volume data
// for filtering; scorers will work with what's available
let price_history = Vec::new();
let buy_vol = 0u64;
let sell_vol = 0u64;
let category = market.category_from_event();
candidates.push(MarketCandidate {
ticker: market.ticker,
title: market.title,
category,
current_yes_price: yes_dec,
current_no_price: no_dec,
volume_24h,
total_volume,
buy_volume_24h: buy_vol,
sell_volume_24h: sell_vol,
open_time: market.open_time,
close_time: market.close_time,
result: None,
price_history,
scores: HashMap::new(),
final_score: 0.0,
});
}
Ok(candidates)
}
}

src/store/mod.rs (new file, 4 lines)

@ -0,0 +1,4 @@
mod queries;
mod schema;
pub use queries::*;

src/store/queries.rs (new file, 372 lines)

@ -0,0 +1,372 @@
use crate::types::{Fill, Portfolio, Position, Side};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::str::FromStr;
use super::schema::MIGRATIONS;
pub struct SqliteStore {
pool: SqlitePool,
}
impl SqliteStore {
pub async fn new(db_path: &str) -> anyhow::Result<Self> {
let url = format!("sqlite:{}?mode=rwc", db_path);
let pool = SqlitePool::connect(&url).await?;
let store = Self { pool };
store.run_migrations().await?;
Ok(store)
}
async fn run_migrations(&self) -> anyhow::Result<()> {
sqlx::raw_sql(MIGRATIONS).execute(&self.pool).await?;
Ok(())
}
pub async fn load_portfolio(
&self,
) -> anyhow::Result<Option<Portfolio>> {
let row = sqlx::query_as::<_, (String, String)>(
"SELECT cash, initial_capital FROM portfolio_state \
WHERE id = 1",
)
.fetch_optional(&self.pool)
.await?;
let Some((cash_str, capital_str)) = row else {
return Ok(None);
};
let cash = Decimal::from_str(&cash_str)?;
let initial_capital = Decimal::from_str(&capital_str)?;
let position_rows = sqlx::query_as::<
_,
(String, String, i64, String, String),
>(
"SELECT ticker, side, quantity, avg_entry_price, \
entry_time FROM positions",
)
.fetch_all(&self.pool)
.await?;
let mut positions = HashMap::new();
for (ticker, side_str, qty, price_str, time_str) in
position_rows
{
let side = match side_str.as_str() {
"Yes" => Side::Yes,
_ => Side::No,
};
let price = Decimal::from_str(&price_str)?;
let entry_time: DateTime<Utc> =
time_str.parse::<DateTime<Utc>>()?;
positions.insert(
ticker.clone(),
Position {
ticker,
side,
quantity: qty as u64,
avg_entry_price: price,
entry_time,
},
);
}
Ok(Some(Portfolio {
positions,
cash,
initial_capital,
}))
}
pub async fn save_portfolio(
&self,
portfolio: &Portfolio,
) -> anyhow::Result<()> {
let cash = portfolio.cash.to_string();
let capital = portfolio.initial_capital.to_string();
sqlx::query(
"INSERT INTO portfolio_state (id, cash, initial_capital, \
updated_at) VALUES (1, ?1, ?2, datetime('now')) \
ON CONFLICT(id) DO UPDATE SET \
cash = ?1, initial_capital = ?2, \
updated_at = datetime('now')",
)
.bind(&cash)
.bind(&capital)
.execute(&self.pool)
.await?;
sqlx::query("DELETE FROM positions")
.execute(&self.pool)
.await?;
for pos in portfolio.positions.values() {
let side = match pos.side {
Side::Yes => "Yes",
Side::No => "No",
};
sqlx::query(
"INSERT INTO positions \
(ticker, side, quantity, avg_entry_price, \
entry_time) VALUES (?1, ?2, ?3, ?4, ?5)",
)
.bind(&pos.ticker)
.bind(side)
.bind(pos.quantity as i64)
.bind(pos.avg_entry_price.to_string())
.bind(pos.entry_time.to_rfc3339())
.execute(&self.pool)
.await?;
}
Ok(())
}
pub async fn record_fill(
&self,
fill: &Fill,
pnl: Option<Decimal>,
exit_reason: Option<&str>,
) -> anyhow::Result<()> {
let side = match fill.side {
Side::Yes => "Yes",
Side::No => "No",
};
sqlx::query(
"INSERT INTO fills \
(ticker, side, quantity, price, timestamp, pnl, \
exit_reason) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
)
.bind(&fill.ticker)
.bind(side)
.bind(fill.quantity as i64)
.bind(fill.price.to_string())
.bind(fill.timestamp.to_rfc3339())
.bind(pnl.map(|p| p.to_string()))
.bind(exit_reason)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn snapshot_equity(
&self,
timestamp: DateTime<Utc>,
equity: Decimal,
cash: Decimal,
positions_value: Decimal,
drawdown_pct: f64,
) -> anyhow::Result<()> {
sqlx::query(
"INSERT INTO equity_snapshots \
(timestamp, equity, cash, positions_value, \
drawdown_pct) VALUES (?1, ?2, ?3, ?4, ?5)",
)
.bind(timestamp.to_rfc3339())
.bind(equity.to_string())
.bind(cash.to_string())
.bind(positions_value.to_string())
.bind(drawdown_pct)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn record_circuit_breaker_event(
&self,
rule: &str,
details: &str,
action: &str,
) -> anyhow::Result<()> {
sqlx::query(
"INSERT INTO circuit_breaker_events \
(timestamp, rule, details, action) \
VALUES (datetime('now'), ?1, ?2, ?3)",
)
.bind(rule)
.bind(details)
.bind(action)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn record_pipeline_run(
&self,
timestamp: DateTime<Utc>,
duration_ms: u64,
candidates_fetched: usize,
candidates_filtered: usize,
candidates_selected: usize,
signals_generated: usize,
fills_executed: usize,
errors: Option<&str>,
) -> anyhow::Result<()> {
sqlx::query(
"INSERT INTO pipeline_runs \
(timestamp, duration_ms, candidates_fetched, \
candidates_filtered, candidates_selected, \
signals_generated, fills_executed, errors) \
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
.bind(timestamp.to_rfc3339())
.bind(duration_ms as i64)
.bind(candidates_fetched as i64)
.bind(candidates_filtered as i64)
.bind(candidates_selected as i64)
.bind(signals_generated as i64)
.bind(fills_executed as i64)
.bind(errors)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_equity_curve(
&self,
) -> anyhow::Result<Vec<EquitySnapshot>> {
let rows = sqlx::query_as::<_, (String, String, String, String, f64)>(
"SELECT timestamp, equity, cash, positions_value, \
drawdown_pct FROM equity_snapshots \
ORDER BY timestamp ASC",
)
.fetch_all(&self.pool)
.await?;
let mut snapshots = Vec::with_capacity(rows.len());
for (ts, eq, cash, pv, dd) in rows {
snapshots.push(EquitySnapshot {
timestamp: ts.parse::<DateTime<Utc>>()?,
equity: Decimal::from_str(&eq)?,
cash: Decimal::from_str(&cash)?,
positions_value: Decimal::from_str(&pv)?,
drawdown_pct: dd,
});
}
Ok(snapshots)
}
pub async fn get_recent_fills(
&self,
limit: u32,
) -> anyhow::Result<Vec<FillRecord>> {
let rows = sqlx::query_as::<
_,
(String, String, i64, String, String, Option<String>, Option<String>),
>(
"SELECT ticker, side, quantity, price, timestamp, \
pnl, exit_reason FROM fills \
ORDER BY id DESC LIMIT ?1",
)
.bind(limit as i64)
.fetch_all(&self.pool)
.await?;
let mut fills = Vec::with_capacity(rows.len());
for (ticker, side, qty, price, ts, pnl, reason) in rows {
fills.push(FillRecord {
ticker,
side: match side.as_str() {
"Yes" => Side::Yes,
_ => Side::No,
},
quantity: qty as u64,
price: Decimal::from_str(&price)?,
timestamp: ts.parse::<DateTime<Utc>>()?,
pnl: pnl
.map(|p| Decimal::from_str(&p))
.transpose()?,
exit_reason: reason,
});
}
Ok(fills)
}
pub async fn get_fills_since(
&self,
since: DateTime<Utc>,
) -> anyhow::Result<u32> {
let row = sqlx::query_as::<_, (i64,)>(
"SELECT COUNT(*) FROM fills \
WHERE timestamp >= ?1",
)
.bind(since.to_rfc3339())
.fetch_one(&self.pool)
.await?;
Ok(row.0 as u32)
}
pub async fn get_circuit_breaker_events(
&self,
limit: u32,
) -> anyhow::Result<Vec<CbEvent>> {
let rows = sqlx::query_as::<_, (String, String, String, String)>(
"SELECT timestamp, rule, details, action \
FROM circuit_breaker_events \
ORDER BY id DESC LIMIT ?1",
)
.bind(limit as i64)
.fetch_all(&self.pool)
.await?;
let mut events = Vec::with_capacity(rows.len());
for (ts, rule, details, action) in rows {
events.push(CbEvent {
timestamp: ts.parse::<DateTime<Utc>>()?,
rule,
details,
action,
});
}
Ok(events)
}
pub async fn get_peak_equity(
&self,
) -> anyhow::Result<Option<Decimal>> {
let row = sqlx::query_as::<_, (Option<String>,)>(
"SELECT MAX(equity) FROM equity_snapshots",
)
.fetch_one(&self.pool)
.await?;
match row.0 {
Some(s) => Ok(Some(Decimal::from_str(&s)?)),
None => Ok(None),
}
}
}
#[derive(Debug, Clone)]
pub struct EquitySnapshot {
pub timestamp: DateTime<Utc>,
pub equity: Decimal,
pub cash: Decimal,
pub positions_value: Decimal,
pub drawdown_pct: f64,
}
#[derive(Debug, Clone)]
pub struct FillRecord {
pub ticker: String,
pub side: Side,
pub quantity: u64,
pub price: Decimal,
pub timestamp: DateTime<Utc>,
pub pnl: Option<Decimal>,
pub exit_reason: Option<String>,
}
#[derive(Debug, Clone)]
pub struct CbEvent {
pub timestamp: DateTime<Utc>,
pub rule: String,
pub details: String,
pub action: String,
}

56
src/store/schema.rs Normal file
View File

@ -0,0 +1,56 @@
pub const MIGRATIONS: &str = r#"
CREATE TABLE IF NOT EXISTS portfolio_state (
id INTEGER PRIMARY KEY CHECK (id = 1),
cash TEXT NOT NULL,
initial_capital TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS positions (
ticker TEXT PRIMARY KEY,
side TEXT NOT NULL,
quantity INTEGER NOT NULL,
avg_entry_price TEXT NOT NULL,
entry_time TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS fills (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ticker TEXT NOT NULL,
side TEXT NOT NULL,
quantity INTEGER NOT NULL,
price TEXT NOT NULL,
timestamp TEXT NOT NULL,
pnl TEXT,
exit_reason TEXT
);
CREATE TABLE IF NOT EXISTS equity_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
equity TEXT NOT NULL,
cash TEXT NOT NULL,
positions_value TEXT NOT NULL,
drawdown_pct REAL NOT NULL DEFAULT 0.0
);
CREATE TABLE IF NOT EXISTS circuit_breaker_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
rule TEXT NOT NULL,
details TEXT NOT NULL,
action TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS pipeline_runs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
duration_ms INTEGER NOT NULL,
candidates_fetched INTEGER NOT NULL DEFAULT 0,
candidates_filtered INTEGER NOT NULL DEFAULT 0,
candidates_selected INTEGER NOT NULL DEFAULT 0,
signals_generated INTEGER NOT NULL DEFAULT 0,
fills_executed INTEGER NOT NULL DEFAULT 0,
errors TEXT
);
"#;

@@ -266,11 +266,14 @@ pub struct ExitConfig {
 impl Default for ExitConfig {
     fn default() -> Self {
+        // optimized for prediction markets based on iteration 3 testing
+        // - 50% take profit balances locking gains vs letting winners run
+        // - stop loss disabled (prices gap through, doesn't help)
         Self {
-            take_profit_pct: 0.20,
-            stop_loss_pct: 0.15,
-            max_hold_hours: 72,
-            score_reversal_threshold: -0.3,
+            take_profit_pct: 0.50,
+            stop_loss_pct: 0.99, // effectively disabled
+            max_hold_hours: 48,
+            score_reversal_threshold: -0.5,
         }
     }
 }
@@ -293,6 +296,20 @@ impl ExitConfig {
             score_reversal_threshold: -0.5,
         }
     }
+
+    /// optimized for prediction markets with binary outcomes
+    /// - disables mechanical stop loss (prices gap through anyway)
+    /// - raises take profit to 100% (let winners run)
+    /// - relies on signal reversal for early exits
+    /// - position sizing limits max loss per trade
+    pub fn prediction_market() -> Self {
+        Self {
+            take_profit_pct: 1.00,          // only exit at +100% (doubled)
+            stop_loss_pct: 0.99,            // effectively disabled
+            max_hold_hours: 48,             // shorter for 2-day backtest
+            score_reversal_threshold: -0.5, // exit on strong signal reversal
+        }
+    }
 }
 #[derive(Debug, Clone, Serialize, Deserialize)]

src/web/handlers.rs Normal file (+514 lines)

@@ -0,0 +1,514 @@
use super::{AppState, BacktestRunStatus};
use crate::backtest::Backtester;
use crate::data::HistoricalData;
use crate::execution::PositionSizingConfig;
use crate::types::{BacktestConfig, ExitConfig};
use axum::extract::State;
use axum::http::StatusCode;
use axum::Json;
use chrono::Utc;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{error, info};
#[derive(Serialize)]
pub struct StatusResponse {
pub state: String,
pub uptime_secs: u64,
pub last_tick: Option<String>,
pub ticks_completed: u64,
}
#[derive(Serialize)]
pub struct PortfolioResponse {
pub cash: f64,
pub equity: f64,
pub initial_capital: f64,
pub return_pct: f64,
pub drawdown_pct: f64,
pub positions_count: usize,
}
#[derive(Serialize)]
pub struct PositionResponse {
pub ticker: String,
pub side: String,
pub quantity: u64,
pub entry_price: f64,
pub entry_time: String,
pub unrealized_pnl: f64,
}
#[derive(Serialize)]
pub struct TradeResponse {
pub ticker: String,
pub side: String,
pub quantity: u64,
pub price: f64,
pub timestamp: String,
pub pnl: Option<f64>,
pub exit_reason: Option<String>,
}
#[derive(Serialize)]
pub struct EquityPoint {
pub timestamp: String,
pub equity: f64,
pub cash: f64,
pub positions_value: f64,
pub drawdown_pct: f64,
}
#[derive(Serialize)]
pub struct CbResponse {
pub status: String,
pub events: Vec<CbEventResponse>,
}
#[derive(Serialize)]
pub struct CbEventResponse {
pub timestamp: String,
pub rule: String,
pub details: String,
pub action: String,
}
pub async fn get_status(
State(state): State<Arc<AppState>>,
) -> Json<StatusResponse> {
let status = state.engine.get_status().await;
Json(StatusResponse {
state: format!("{}", status.state),
uptime_secs: status.uptime_secs,
last_tick: status
.last_tick
.map(|t| t.to_rfc3339()),
ticks_completed: status.ticks_completed,
})
}
pub async fn get_portfolio(
State(state): State<Arc<AppState>>,
) -> Json<PortfolioResponse> {
let ctx = state.engine.get_context().await;
let portfolio = &ctx.portfolio;
let positions_value: f64 = portfolio
.positions
.values()
.map(|p| {
p.avg_entry_price.to_f64().unwrap_or(0.0)
* p.quantity as f64
})
.sum();
let cash = portfolio.cash.to_f64().unwrap_or(0.0);
let equity = cash + positions_value;
let initial = portfolio
.initial_capital
.to_f64()
.unwrap_or(10000.0);
let return_pct = if initial > 0.0 {
(equity - initial) / initial * 100.0
} else {
0.0
};
let peak = state
.store
.get_peak_equity()
.await
.ok()
.flatten()
.and_then(|p| p.to_f64())
.unwrap_or(equity);
let drawdown_pct = if peak > 0.0 {
((peak - equity) / peak * 100.0).max(0.0)
} else {
0.0
};
Json(PortfolioResponse {
cash,
equity,
initial_capital: initial,
return_pct,
drawdown_pct,
positions_count: portfolio.positions.len(),
})
}
pub async fn get_positions(
State(state): State<Arc<AppState>>,
) -> Json<Vec<PositionResponse>> {
let ctx = state.engine.get_context().await;
let positions: Vec<PositionResponse> = ctx
.portfolio
.positions
.values()
.map(|p| {
let entry = p.avg_entry_price.to_f64().unwrap_or(0.0);
PositionResponse {
ticker: p.ticker.clone(),
side: format!("{:?}", p.side),
quantity: p.quantity,
entry_price: entry,
entry_time: p.entry_time.to_rfc3339(),
unrealized_pnl: 0.0,
}
})
.collect();
Json(positions)
}
pub async fn get_trades(
State(state): State<Arc<AppState>>,
) -> Json<Vec<TradeResponse>> {
let fills = state
.store
.get_recent_fills(100)
.await
.unwrap_or_default();
let trades: Vec<TradeResponse> = fills
.into_iter()
.map(|f| TradeResponse {
ticker: f.ticker,
side: format!("{:?}", f.side),
quantity: f.quantity,
price: f.price.to_f64().unwrap_or(0.0),
timestamp: f.timestamp.to_rfc3339(),
pnl: f.pnl.and_then(|p| p.to_f64()),
exit_reason: f.exit_reason,
})
.collect();
Json(trades)
}
pub async fn get_equity(
State(state): State<Arc<AppState>>,
) -> Json<Vec<EquityPoint>> {
let snapshots = state
.store
.get_equity_curve()
.await
.unwrap_or_default();
let points: Vec<EquityPoint> = snapshots
.into_iter()
.map(|s| EquityPoint {
timestamp: s.timestamp.to_rfc3339(),
equity: s.equity.to_f64().unwrap_or(0.0),
cash: s.cash.to_f64().unwrap_or(0.0),
positions_value: s.positions_value.to_f64().unwrap_or(0.0),
drawdown_pct: s.drawdown_pct,
})
.collect();
Json(points)
}
pub async fn get_circuit_breaker(
State(state): State<Arc<AppState>>,
) -> Json<CbResponse> {
let engine_status = state.engine.get_status().await;
let cb_status = match engine_status.state {
crate::engine::EngineState::Paused(ref reason) => {
format!("tripped: {}", reason)
}
_ => "ok".to_string(),
};
let events = state
.store
.get_circuit_breaker_events(20)
.await
.unwrap_or_default();
let event_responses: Vec<CbEventResponse> = events
.into_iter()
.map(|e| CbEventResponse {
timestamp: e.timestamp.to_rfc3339(),
rule: e.rule,
details: e.details,
action: e.action,
})
.collect();
Json(CbResponse {
status: cb_status,
events: event_responses,
})
}
pub async fn post_pause(
State(state): State<Arc<AppState>>,
) -> StatusCode {
state
.engine
.pause("manual pause via API".to_string())
.await;
StatusCode::OK
}
pub async fn post_resume(
State(state): State<Arc<AppState>>,
) -> StatusCode {
state.engine.resume().await;
StatusCode::OK
}
#[derive(Deserialize)]
pub struct BacktestRequest {
pub start: String,
pub end: String,
pub data_dir: Option<String>,
pub capital: Option<f64>,
pub max_positions: Option<usize>,
pub max_position: Option<u64>,
pub interval_hours: Option<i64>,
pub kelly_fraction: Option<f64>,
pub max_position_pct: Option<f64>,
pub take_profit: Option<f64>,
pub stop_loss: Option<f64>,
pub max_hold_hours: Option<i64>,
}
#[derive(Serialize)]
pub struct BacktestStatusResponse {
pub status: String,
pub elapsed_secs: Option<u64>,
pub error: Option<String>,
pub phase: Option<String>,
pub current_step: Option<u64>,
pub total_steps: Option<u64>,
pub progress_pct: Option<f64>,
}
#[derive(Serialize)]
pub struct BacktestErrorResponse {
pub error: String,
}
pub async fn post_backtest_run(
State(state): State<Arc<AppState>>,
Json(req): Json<BacktestRequest>,
) -> Result<StatusCode, (StatusCode, Json<BacktestErrorResponse>)> {
{
let guard = state.backtest.lock().await;
if matches!(guard.status, BacktestRunStatus::Running { .. }) {
return Err((
StatusCode::CONFLICT,
Json(BacktestErrorResponse {
error: "backtest already running".into(),
}),
));
}
}
let start_time = crate::parse_date(&req.start).map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(BacktestErrorResponse {
error: format!("invalid start date: {}", e),
}),
)
})?;
let end_time = crate::parse_date(&req.end).map_err(|e| {
(
StatusCode::BAD_REQUEST,
Json(BacktestErrorResponse {
error: format!("invalid end date: {}", e),
}),
)
})?;
let data_dir = if let Some(ref dir) = req.data_dir {
PathBuf::from(dir)
} else {
state.data_dir.clone()
};
if !data_dir.exists() {
return Err((
StatusCode::BAD_REQUEST,
Json(BacktestErrorResponse {
error: format!(
"data directory not found: {}",
data_dir.display()
),
}),
));
}
info!(
start = %start_time,
end = %end_time,
data_dir = %data_dir.display(),
"starting backtest from web UI"
);
let progress = Arc::new(super::BacktestProgress::new(0));
{
let mut guard = state.backtest.lock().await;
guard.status =
BacktestRunStatus::Running { started_at: Utc::now() };
guard.progress = Some(progress.clone());
guard.result = None;
guard.error = None;
}
let backtest_state = state.backtest.clone();
let progress = progress.clone();
let capital = req.capital.unwrap_or(10000.0);
let max_positions = req.max_positions.unwrap_or(100);
let max_position = req.max_position.unwrap_or(100);
let interval_hours = req.interval_hours.unwrap_or(1);
let kelly_fraction = req.kelly_fraction.unwrap_or(0.40);
let max_position_pct = req.max_position_pct.unwrap_or(0.30);
let take_profit = req.take_profit.unwrap_or(0.50);
let stop_loss = req.stop_loss.unwrap_or(0.99);
let max_hold_hours = req.max_hold_hours.unwrap_or(48);
tokio::spawn(async move {
let data = match tokio::task::spawn_blocking(move || {
HistoricalData::load(&data_dir)
})
.await
{
Ok(Ok(d)) => Arc::new(d),
Ok(Err(e)) => {
let mut guard = backtest_state.lock().await;
guard.status = BacktestRunStatus::Failed;
guard.error =
Some(format!("failed to load data: {}", e));
error!(error = %e, "backtest data load failed");
return;
}
Err(e) => {
let mut guard = backtest_state.lock().await;
guard.status = BacktestRunStatus::Failed;
guard.error =
Some(format!("task join error: {}", e));
return;
}
};
let config = BacktestConfig {
start_time,
end_time,
interval: chrono::Duration::hours(interval_hours),
initial_capital: Decimal::try_from(capital)
.unwrap_or(Decimal::new(10000, 0)),
max_position_size: max_position,
max_positions,
};
let sizing_config = PositionSizingConfig {
kelly_fraction,
max_position_pct,
min_position_size: 10,
max_position_size: max_position,
};
let exit_config = ExitConfig {
take_profit_pct: take_profit,
stop_loss_pct: stop_loss,
max_hold_hours,
score_reversal_threshold: -0.3,
};
let backtester = Backtester::with_configs(
config,
data,
sizing_config,
exit_config,
)
.with_progress(progress);
let result = backtester.run().await;
let mut guard = backtest_state.lock().await;
guard.status = BacktestRunStatus::Complete;
guard.result = Some(result);
});
Ok(StatusCode::OK)
}
pub async fn get_backtest_status(
State(state): State<Arc<AppState>>,
) -> Json<BacktestStatusResponse> {
let guard = state.backtest.lock().await;
let (status_str, elapsed, error) = match &guard.status {
BacktestRunStatus::Idle => {
("idle".to_string(), None, None)
}
BacktestRunStatus::Running { started_at } => {
let elapsed = Utc::now()
.signed_duration_since(*started_at)
.num_seconds()
.max(0) as u64;
("running".to_string(), Some(elapsed), None)
}
BacktestRunStatus::Complete => {
("complete".to_string(), None, None)
}
BacktestRunStatus::Failed => {
("failed".to_string(), None, guard.error.clone())
}
};
let (phase, current_step, total_steps, progress_pct) =
if let Some(ref p) = guard.progress {
let current = p.current_step.load(
std::sync::atomic::Ordering::Relaxed,
);
let total = p.total_steps.load(
std::sync::atomic::Ordering::Relaxed,
);
let pct = if total > 0 {
current as f64 / total as f64 * 100.0
} else {
0.0
};
(
Some(p.phase_name().to_string()),
Some(current),
Some(total),
Some(pct),
)
} else {
(None, None, None, None)
};
Json(BacktestStatusResponse {
status: status_str,
elapsed_secs: elapsed,
error,
phase,
current_step,
total_steps,
progress_pct,
})
}
pub async fn get_backtest_result(
State(state): State<Arc<AppState>>,
) -> Result<
Json<crate::metrics::BacktestResult>,
StatusCode,
> {
let guard = state.backtest.lock().await;
match &guard.result {
Some(result) => Ok(Json(result.clone())),
None => Err(StatusCode::NOT_FOUND),
}
}

src/web/mod.rs Normal file (+103 lines)

@@ -0,0 +1,103 @@
mod handlers;
use crate::engine::PaperTradingEngine;
use crate::metrics::BacktestResult;
use crate::store::SqliteStore;
use axum::routing::{get, post};
use axum::Router;
use chrono::{DateTime, Utc};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::broadcast;
use tower_http::services::ServeDir;
pub enum BacktestRunStatus {
Idle,
Running { started_at: DateTime<Utc> },
Complete,
Failed,
}
pub struct BacktestProgress {
pub phase: std::sync::atomic::AtomicU8,
pub current_step: std::sync::atomic::AtomicU64,
pub total_steps: std::sync::atomic::AtomicU64,
}
impl BacktestProgress {
pub const PHASE_LOADING: u8 = 0;
pub const PHASE_RUNNING: u8 = 1;
pub fn new(total_steps: u64) -> Self {
Self {
phase: std::sync::atomic::AtomicU8::new(
Self::PHASE_LOADING,
),
current_step: std::sync::atomic::AtomicU64::new(0),
total_steps: std::sync::atomic::AtomicU64::new(
total_steps,
),
}
}
pub fn phase_name(&self) -> &'static str {
match self
.phase
.load(std::sync::atomic::Ordering::Relaxed)
{
Self::PHASE_LOADING => "loading data",
Self::PHASE_RUNNING => "simulating",
_ => "unknown",
}
}
}
pub struct BacktestState {
pub status: BacktestRunStatus,
pub progress: Option<Arc<BacktestProgress>>,
pub result: Option<BacktestResult>,
pub error: Option<String>,
}
pub struct AppState {
pub engine: Arc<PaperTradingEngine>,
pub store: Arc<SqliteStore>,
pub shutdown_tx: broadcast::Sender<()>,
pub backtest: Arc<tokio::sync::Mutex<BacktestState>>,
pub data_dir: PathBuf,
}
pub fn build_router(state: Arc<AppState>) -> Router {
Router::new()
.route("/api/status", get(handlers::get_status))
.route("/api/portfolio", get(handlers::get_portfolio))
.route("/api/positions", get(handlers::get_positions))
.route("/api/trades", get(handlers::get_trades))
.route("/api/equity", get(handlers::get_equity))
.route(
"/api/circuit-breaker",
get(handlers::get_circuit_breaker),
)
.route(
"/api/control/pause",
post(handlers::post_pause),
)
.route(
"/api/control/resume",
post(handlers::post_resume),
)
.route(
"/api/backtest/run",
post(handlers::post_backtest_run),
)
.route(
"/api/backtest/status",
get(handlers::get_backtest_status),
)
.route(
"/api/backtest/result",
get(handlers::get_backtest_result),
)
.fallback_service(ServeDir::new("static"))
.with_state(state)
}

static/index.html Normal file (+1124 lines)

File diff suppressed because it is too large.