diff --git a/Cargo.toml b/Cargo.toml index 5bd8ad8..5d94ade 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,13 @@ uuid = { version = "1", features = ["v4"] } rust_decimal = { version = "1", features = ["serde"] } rust_decimal_macros = "1" +reqwest = { version = "0.12", features = ["json"] } +sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono"] } +axum = "0.8" +tower-http = { version = "0.6", features = ["cors", "fs"] } +toml = "0.8" +governor = "0.8" + ort = { version = "2.0.0-rc.11", optional = true } ndarray = { version = "0.16", optional = true } diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..2e1a77f --- /dev/null +++ b/config.toml @@ -0,0 +1,31 @@ +mode = "paper" + +[kalshi] +base_url = "https://api.elections.kalshi.com/trade-api/v2" +poll_interval_secs = 60 +rate_limit_per_sec = 2 + +[trading] +initial_capital = 10000.0 +max_positions = 100 +kelly_fraction = 0.25 +max_position_pct = 0.10 +take_profit_pct = 0.50 +stop_loss_pct = 0.99 +max_hold_hours = 48 + +[persistence] +db_path = "kalshi-paper.db" + +[web] +enabled = true +bind_addr = "127.0.0.1:3030" + +[circuit_breaker] +max_drawdown_pct = 0.15 +max_daily_loss_pct = 0.05 +max_positions = 100 +max_single_position_pct = 0.10 +max_consecutive_errors = 5 +max_fills_per_hour = 200 +max_fills_per_day = 1000 diff --git a/kalshi-paper.db b/kalshi-paper.db new file mode 100644 index 0000000..415337a Binary files /dev/null and b/kalshi-paper.db differ diff --git a/src/api/client.rs b/src/api/client.rs new file mode 100644 index 0000000..7c08537 --- /dev/null +++ b/src/api/client.rs @@ -0,0 +1,168 @@ +use super::types::{ApiMarket, ApiTrade, MarketsResponse, TradesResponse}; +use crate::config::KalshiConfig; +use governor::{Quota, RateLimiter}; +use std::num::NonZeroU32; +use std::sync::Arc; +use tracing::{debug, warn}; + +pub struct KalshiClient { + http: reqwest::Client, + base_url: String, + limiter: Arc< + RateLimiter< + governor::state::NotKeyed, + governor::state::InMemoryState, + governor::clock::DefaultClock, + >, + >, +} + +impl KalshiClient { + pub fn new(config: &KalshiConfig) -> Self { + let quota = Quota::per_second( + NonZeroU32::new(config.rate_limit_per_sec).unwrap_or( + NonZeroU32::new(5).unwrap(), + ), + ); + let limiter = Arc::new(RateLimiter::direct(quota)); + + let http = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build() + .expect("failed to build http client"); + + Self { + http, + base_url: config.base_url.clone(), + limiter, + } + } + + async fn rate_limit(&self) { + self.limiter.until_ready().await; + } + + async fn request_with_retry( + &self, + url: &str, + ) -> anyhow::Result { + for attempt in 0..5u32 { + if attempt > 0 { + let delay = std::time::Duration::from_secs( + 3u64.pow(attempt), + ); + debug!( + attempt = attempt, + delay_secs = delay.as_secs(), + "retrying after rate limit" + ); + tokio::time::sleep(delay).await; + } + + self.rate_limit().await; + let resp = self.http.get(url).send().await?; + + if resp.status() == reqwest::StatusCode::TOO_MANY_REQUESTS { + warn!( + attempt = attempt, + "rate limited (429), backing off" + ); + continue; + } + + if !resp.status().is_success() { + let status = resp.status(); + let body = + resp.text().await.unwrap_or_default(); + anyhow::bail!( + "kalshi API error {}: {}", + status, + body + ); + } + + return Ok(resp); + } + + anyhow::bail!("exhausted retries after rate limiting") + } + + pub async fn get_open_markets( + &self, + ) -> anyhow::Result> { + let mut all_markets = Vec::new(); + let mut cursor: Option = None; + let max_pages = 5; // cap at 1000 markets to avoid rate limiting + + for page_num in 0..max_pages { + self.rate_limit().await; + + let mut url = format!( + "{}/markets?status=open&limit=200", + self.base_url + ); + if let Some(ref c) = cursor { + url.push_str(&format!("&cursor={}", c)); + } + + debug!( + url = %url, + page = page_num, + "fetching markets" + ); + + let resp = self + .request_with_retry(&url) + .await?; + + let page: MarketsResponse = resp.json().await?; + let count = page.markets.len(); + all_markets.extend(page.markets); + + if !page.cursor.is_empty() && count > 0 { + cursor = Some(page.cursor); + } else { + break; + } + } + + debug!( + total = all_markets.len(), + "fetched open markets" + ); + Ok(all_markets) + } + + pub async fn get_market_trades( + &self, + ticker: &str, + limit: u32, + ) -> anyhow::Result> { + self.rate_limit().await; + + let url = format!( + "{}/markets/trades?ticker={}&limit={}", + self.base_url, ticker, limit + ); + + debug!(ticker = %ticker, "fetching trades"); + + let resp = match self + .request_with_retry(&url) + .await + { + Ok(r) => r, + Err(e) => { + warn!( + ticker = %ticker, + error = %e, + "failed to fetch trades" + ); + return Ok(Vec::new()); + } + }; + + let data: TradesResponse = resp.json().await?; + Ok(data.trades) + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..f89331d --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,5 @@ +pub mod client; +pub mod types; + +pub use client::KalshiClient; +pub use types::*; diff --git a/src/api/types.rs b/src/api/types.rs new file mode 100644 index 0000000..e6887b3 --- /dev/null +++ b/src/api/types.rs @@ -0,0 +1,119 @@ +use chrono::{DateTime, Utc}; +use serde::Deserialize; + +#[derive(Debug, Clone, Deserialize)] +pub struct MarketsResponse { + pub markets: Vec, + #[serde(default)] + pub cursor: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ApiMarket { + pub ticker: String, + #[serde(default)] + pub title: String, + #[serde(default)] + pub event_ticker: String, + #[serde(default)] + pub status: String, + pub open_time: DateTime, + pub close_time: DateTime, + #[serde(default)] + pub yes_ask: i64, + #[serde(default)] + pub yes_bid: i64, + #[serde(default)] + pub no_ask: i64, + #[serde(default)] + pub no_bid: i64, + #[serde(default)] + pub last_price: i64, + #[serde(default)] + pub volume: i64, + #[serde(default)] + pub volume_24h: i64, + #[serde(default)] + pub result: String, + #[serde(default)] + pub subtitle: String, +} + +impl ApiMarket { + /// returns yes price as a fraction (0.0 - 1.0) + /// prices from API are in cents (0-100) + pub fn mid_yes_price(&self) -> f64 { + let bid = self.yes_bid as f64 / 100.0; + let ask = self.yes_ask as f64 / 100.0; + + if bid > 0.0 && ask > 0.0 { + (bid + ask) / 2.0 + } else if bid > 0.0 { + bid + } else if ask > 0.0 { + ask + } else if self.last_price > 0 { + self.last_price as f64 / 100.0 + } else { + 0.0 + } + } + + pub fn category_from_event(&self) -> String { + let lower = self.event_ticker.to_lowercase(); + if lower.contains("nba") + || lower.contains("nfl") + || lower.contains("sport") + { + "sports".to_string() + } else if lower.contains("btc") + || lower.contains("crypto") + || lower.contains("eth") + { + "crypto".to_string() + } else if lower.contains("weather") + || lower.contains("temp") + { + "weather".to_string() + } else if lower.contains("econ") + || lower.contains("fed") + || lower.contains("cpi") + || lower.contains("gdp") + { + "economics".to_string() + } else if lower.contains("elect") + || lower.contains("polit") + || lower.contains("trump") + || lower.contains("biden") + { + "politics".to_string() + } else { + "other".to_string() + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TradesResponse { + #[serde(default)] + pub trades: Vec, + #[serde(default)] + pub cursor: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct ApiTrade { + #[serde(default)] + pub trade_id: String, + #[serde(default)] + pub ticker: String, + pub created_time: DateTime, + #[serde(default)] + pub yes_price: i64, + #[serde(default)] + pub no_price: i64, + #[serde(default)] + pub count: i64, + #[serde(default)] + pub taker_side: String, +} diff --git a/src/backtest.rs b/src/backtest.rs index 43292c3..258a767 100644 --- a/src/backtest.rs +++ b/src/backtest.rs @@ -1,5 +1,5 @@ use crate::data::HistoricalData; -use crate::execution::{Executor, PositionSizingConfig}; +use crate::execution::{BacktestExecutor, OrderExecutor, PositionSizingConfig}; use crate::metrics::{BacktestResult, MetricsCollector}; use crate::pipeline::{ AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, @@ -11,11 +11,13 @@ use crate::types::{ BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType, TradingContext, }; +use crate::web::BacktestProgress; use chrono::{DateTime, Utc}; use rust_decimal::Decimal; use rust_decimal::prelude::ToPrimitive; use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use std::sync::atomic::Ordering; use tracing::info; /// resolves any positions in markets that have closed @@ -86,19 +88,21 @@ pub struct Backtester { config: BacktestConfig, data: Arc, pipeline: TradingPipeline, - executor: Executor, + executor: BacktestExecutor, + progress: Option>, } impl Backtester { pub fn new(config: BacktestConfig, data: Arc) -> Self { let pipeline = Self::build_default_pipeline(data.clone(), &config); - let executor = Executor::new(data.clone(), 10, config.max_position_size); + let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size); Self { config, data, pipeline, executor, + progress: None, } } @@ -109,7 +113,7 @@ impl Backtester { exit_config: ExitConfig, ) -> Self { let pipeline = Self::build_default_pipeline(data.clone(), &config); - let executor = Executor::new(data.clone(), 10, config.max_position_size) + let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size) .with_sizing_config(sizing_config) .with_exit_config(exit_config); @@ -118,9 +122,15 @@ impl Backtester { data, pipeline, executor, + progress: None, } } + pub fn with_progress(mut self, progress: Arc) -> Self { + self.progress = Some(progress); + self + } + pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self { self.pipeline = pipeline; self @@ -160,13 +170,30 @@ impl Backtester { let mut current_time = self.config.start_time; + let total_steps = (self.config.end_time - self.config.start_time) + .num_seconds() + / self.config.interval.num_seconds().max(1); + + if let Some(ref progress) = self.progress { + progress.total_steps.store( + total_steps as u64, + Ordering::Relaxed, + ); + progress.phase.store( + BacktestProgress::PHASE_RUNNING, + Ordering::Relaxed, + ); + } + info!( start = %self.config.start_time, end = %self.config.end_time, interval_hours = self.config.interval.num_hours(), + total_steps = total_steps, "starting backtest" ); + let mut step: u64 = 0; while current_time < self.config.end_time { context.timestamp = current_time; context.request_id = uuid::Uuid::new_v4().to_string(); @@ -237,7 +264,7 @@ impl Backtester { break; } - if let Some(fill) = self.executor.execute_signal(&signal, &context) { + if let Some(fill) = self.executor.execute_signal(&signal, &context).await { info!( ticker = %fill.ticker, side = ?fill.side, @@ -272,6 +299,13 @@ impl Backtester { let market_prices = self.get_current_prices(current_time); metrics.record(current_time, &context.portfolio, &market_prices); + step += 1; + if let Some(ref progress) = self.progress { + progress + .current_step + .store(step, Ordering::Relaxed); + } + current_time = current_time + self.config.interval; } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..bc5e0cb --- /dev/null +++ b/src/config.rs @@ -0,0 +1,124 @@ +use serde::Deserialize; +use std::path::Path; + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RunMode { + Backtest, + Paper, + Live, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct AppConfig { + pub mode: RunMode, + pub kalshi: KalshiConfig, + pub trading: TradingConfig, + pub persistence: PersistenceConfig, + pub web: WebConfig, + pub circuit_breaker: CircuitBreakerConfig, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct KalshiConfig { + pub base_url: String, + pub poll_interval_secs: u64, + pub rate_limit_per_sec: u32, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct TradingConfig { + pub initial_capital: f64, + pub max_positions: usize, + pub kelly_fraction: f64, + pub max_position_pct: f64, + pub take_profit_pct: Option, + pub stop_loss_pct: Option, + pub max_hold_hours: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct PersistenceConfig { + pub db_path: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct WebConfig { + pub enabled: bool, + pub bind_addr: String, +} + +#[derive(Debug, Clone, Deserialize)] +pub struct CircuitBreakerConfig { + pub max_drawdown_pct: f64, + pub max_daily_loss_pct: f64, + pub max_positions: Option, + pub max_single_position_pct: Option, + pub max_consecutive_errors: Option, + pub max_fills_per_hour: Option, + pub max_fills_per_day: Option, +} + +impl AppConfig { + pub fn load(path: &Path) -> anyhow::Result { + let content = std::fs::read_to_string(path)?; + let config: Self = toml::from_str(&content)?; + Ok(config) + } +} + +impl Default for KalshiConfig { + fn default() -> Self { + Self { + base_url: "https://api.elections.kalshi.com/trade-api/v2" + .to_string(), + poll_interval_secs: 300, + rate_limit_per_sec: 5, + } + } +} + +impl Default for TradingConfig { + fn default() -> Self { + Self { + initial_capital: 10000.0, + max_positions: 100, + kelly_fraction: 0.25, + max_position_pct: 0.10, + take_profit_pct: Some(0.50), + stop_loss_pct: Some(0.99), + max_hold_hours: Some(48), + } + } +} + +impl Default for PersistenceConfig { + fn default() -> Self { + Self { + db_path: "kalshi-paper.db".to_string(), + } + } +} + +impl Default for WebConfig { + fn default() -> Self { + Self { + enabled: true, + bind_addr: "127.0.0.1:3030".to_string(), + } + } +} + +impl Default for CircuitBreakerConfig { + fn default() -> Self { + Self { + max_drawdown_pct: 0.15, + max_daily_loss_pct: 0.05, + max_positions: Some(100), + max_single_position_pct: Some(0.10), + max_consecutive_errors: Some(5), + max_fills_per_hour: Some(50), + max_fills_per_day: Some(200), + } + } +} diff --git a/src/engine/circuit_breaker.rs b/src/engine/circuit_breaker.rs new file mode 100644 index 0000000..f9f132e --- /dev/null +++ b/src/engine/circuit_breaker.rs @@ -0,0 +1,240 @@ +use crate::config::CircuitBreakerConfig; +use crate::store::SqliteStore; +use chrono::{Duration, Utc}; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use std::sync::Arc; +use tracing::warn; + +#[derive(Debug, Clone, PartialEq)] +pub enum CbStatus { + Ok, + Tripped(String), +} + +pub struct CircuitBreaker { + config: CircuitBreakerConfig, + store: Arc, + consecutive_errors: u32, +} + +impl CircuitBreaker { + pub fn new( + config: CircuitBreakerConfig, + store: Arc, + ) -> Self { + Self { + config, + store, + consecutive_errors: 0, + } + } + + pub fn record_error(&mut self) { + self.consecutive_errors += 1; + } + + pub fn record_success(&mut self) { + self.consecutive_errors = 0; + } + + pub async fn check( + &self, + current_equity: Decimal, + peak_equity: Decimal, + positions_count: usize, + ) -> CbStatus { + if let Some(status) = + self.check_drawdown(current_equity, peak_equity) + { + return status; + } + + if let Some(status) = self.check_daily_loss().await { + return status; + } + + if let Some(status) = + self.check_max_positions(positions_count) + { + return status; + } + + if let Some(status) = self.check_consecutive_errors() { + return status; + } + + if let Some(status) = self.check_fill_rate().await { + return status; + } + + CbStatus::Ok + } + + fn check_drawdown( + &self, + current_equity: Decimal, + peak_equity: Decimal, + ) -> Option { + if peak_equity <= Decimal::ZERO { + return None; + } + + let drawdown = (peak_equity - current_equity) + .to_f64() + .unwrap_or(0.0) + / peak_equity.to_f64().unwrap_or(1.0); + + if drawdown >= self.config.max_drawdown_pct { + let msg = format!( + "drawdown {:.1}% exceeds max {:.1}%", + drawdown * 100.0, + self.config.max_drawdown_pct * 100.0 + ); + warn!(rule = "max_drawdown", %msg); + self.log_event("max_drawdown", &msg, "pause"); + return Some(CbStatus::Tripped(msg)); + } + None + } + + async fn check_daily_loss(&self) -> Option { + let today_start = + Utc::now().date_naive().and_hms_opt(0, 0, 0)?; + let today_utc = chrono::TimeZone::from_utc_datetime( + &Utc, + &today_start, + ); + + let fills = self + .store + .get_recent_fills(1000) + .await + .unwrap_or_default(); + + let daily_pnl: f64 = fills + .iter() + .filter(|f| f.timestamp >= today_utc) + .filter_map(|f| { + f.pnl.as_ref()?.to_f64() + }) + .sum(); + + let peak = self + .store + .get_peak_equity() + .await + .ok() + .flatten() + .unwrap_or(Decimal::new(10000, 0)); + + let peak_f64 = peak.to_f64().unwrap_or(10000.0); + if peak_f64 <= 0.0 { + return None; + } + + let daily_loss_pct = (-daily_pnl) / peak_f64; + if daily_loss_pct >= self.config.max_daily_loss_pct { + let msg = format!( + "daily loss {:.1}% exceeds max {:.1}%", + daily_loss_pct * 100.0, + self.config.max_daily_loss_pct * 100.0 + ); + warn!(rule = "max_daily_loss", %msg); + self.log_event("max_daily_loss", &msg, "pause"); + return Some(CbStatus::Tripped(msg)); + } + None + } + + fn check_max_positions( + &self, + count: usize, + ) -> Option { + let max = self.config.max_positions.unwrap_or(100); + if count >= max { + let msg = format!( + "positions {} at max {}", + count, max + ); + warn!(rule = "max_positions", %msg); + return Some(CbStatus::Tripped(msg)); + } + None + } + + fn check_consecutive_errors(&self) -> Option { + let max = self.config.max_consecutive_errors.unwrap_or(5); + if self.consecutive_errors >= max { + let msg = format!( + "{} consecutive errors (max {})", + self.consecutive_errors, max + ); + warn!(rule = "consecutive_errors", %msg); + self.log_event( + "consecutive_errors", + &msg, + "pause", + ); + return Some(CbStatus::Tripped(msg)); + } + None + } + + async fn check_fill_rate(&self) -> Option { + let max_hourly = + self.config.max_fills_per_hour.unwrap_or(50); + let max_daily = + self.config.max_fills_per_day.unwrap_or(200); + + let one_hour_ago = Utc::now() - Duration::hours(1); + let hourly_fills = self + .store + .get_fills_since(one_hour_ago) + .await + .unwrap_or(0); + + if hourly_fills >= max_hourly { + let msg = format!( + "{} fills/hour exceeds max {}", + hourly_fills, max_hourly + ); + warn!(rule = "fill_rate_hourly", %msg); + self.log_event("fill_rate_hourly", &msg, "pause"); + return Some(CbStatus::Tripped(msg)); + } + + let today_start = Utc::now() - Duration::hours(24); + let daily_fills = self + .store + .get_fills_since(today_start) + .await + .unwrap_or(0); + + if daily_fills >= max_daily { + let msg = format!( + "{} fills/day exceeds max {}", + daily_fills, max_daily + ); + warn!(rule = "fill_rate_daily", %msg); + self.log_event("fill_rate_daily", &msg, "pause"); + return Some(CbStatus::Tripped(msg)); + } + + None + } + + fn log_event(&self, rule: &str, details: &str, action: &str) { + let store = self.store.clone(); + let rule = rule.to_string(); + let details = details.to_string(); + let action = action.to_string(); + tokio::spawn(async move { + let _ = store + .record_circuit_breaker_event( + &rule, &details, &action, + ) + .await; + }); + } +} diff --git a/src/engine/mod.rs b/src/engine/mod.rs new file mode 100644 index 0000000..38e52c2 --- /dev/null +++ b/src/engine/mod.rs @@ -0,0 +1,506 @@ +pub mod circuit_breaker; +pub mod state; + +pub use circuit_breaker::{CbStatus, CircuitBreaker}; +pub use state::EngineState; + +use crate::api::KalshiClient; +use crate::config::AppConfig; +use crate::execution::OrderExecutor; +use crate::paper_executor::PaperExecutor; +use crate::pipeline::{ + AlreadyPositionedFilter, BollingerMeanReversionScorer, + CategoryWeightedScorer, Filter, LiveKalshiSource, + LiquidityFilter, MeanReversionScorer, MomentumScorer, + MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer, + Selector, Source, TimeDecayScorer, TimeToCloseFilter, + TopKSelector, TradingPipeline, VolumeScorer, +}; +use crate::store::SqliteStore; +use crate::types::{Portfolio, Trade, TradeType, TradingContext}; +use chrono::Utc; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; +use tokio::sync::{broadcast, Mutex, RwLock}; +use tracing::{error, info, warn}; + +pub struct EngineStatus { + pub state: EngineState, + pub uptime_secs: u64, + pub last_tick: Option>, + pub ticks_completed: u64, +} + +pub struct PaperTradingEngine { + config: AppConfig, + store: Arc, + executor: Arc, + pipeline: Mutex, + circuit_breaker: Mutex, + state: RwLock, + context: RwLock, + shutdown_tx: broadcast::Sender<()>, + start_time: Instant, + ticks: RwLock, + last_tick: RwLock>>, +} + +impl PaperTradingEngine { + pub async fn new( + config: AppConfig, + store: Arc, + executor: Arc, + client: Arc, + ) -> anyhow::Result { + let (shutdown_tx, _) = broadcast::channel(1); + + let pipeline = + Self::build_pipeline(client, &config); + + let circuit_breaker = CircuitBreaker::new( + config.circuit_breaker.clone(), + store.clone(), + ); + + let initial_capital = Decimal::try_from( + config.trading.initial_capital, + ) + .unwrap_or(Decimal::new(10000, 0)); + + let portfolio = store + .load_portfolio() + .await? + .unwrap_or_else(|| Portfolio::new(initial_capital)); + + let mut ctx = TradingContext::new( + portfolio.initial_capital, + Utc::now(), + ); + ctx.portfolio = portfolio; + + Ok(Self { + config, + store, + executor, + pipeline: Mutex::new(pipeline), + circuit_breaker: Mutex::new(circuit_breaker), + state: RwLock::new(EngineState::Starting), + context: RwLock::new(ctx), + shutdown_tx, + start_time: Instant::now(), + ticks: RwLock::new(0), + last_tick: RwLock::new(None), + }) + } + + fn build_pipeline( + client: Arc, + config: &AppConfig, + ) -> TradingPipeline { + let sources: Vec> = + vec![Box::new(LiveKalshiSource::new(client))]; + + let max_pos_size = + (config.trading.initial_capital + * config.trading.max_position_pct) + as u64; + + let filters: Vec> = vec![ + Box::new(TimeToCloseFilter::new(2, Some(720))), + Box::new(AlreadyPositionedFilter::new( + max_pos_size.max(100), + )), + ]; + + let scorers: Vec> = vec![ + Box::new(MomentumScorer::new(6)), + Box::new( + MultiTimeframeMomentumScorer::default_windows(), + ), + Box::new(MeanReversionScorer::new(24)), + Box::new( + BollingerMeanReversionScorer::default_config(), + ), + Box::new(VolumeScorer::new(6)), + Box::new(OrderFlowScorer::new()), + Box::new(TimeDecayScorer::new()), + Box::new(CategoryWeightedScorer::with_defaults()), + ]; + + let max_positions = config.trading.max_positions; + let selector: Box = + Box::new(TopKSelector::new(max_positions)); + + TradingPipeline::new( + sources, + filters, + scorers, + selector, + max_positions, + ) + } + + pub fn shutdown_handle(&self) -> broadcast::Sender<()> { + self.shutdown_tx.clone() + } + + pub async fn get_status(&self) -> EngineStatus { + EngineStatus { + state: self.state.read().await.clone(), + uptime_secs: self.start_time.elapsed().as_secs(), + last_tick: *self.last_tick.read().await, + ticks_completed: *self.ticks.read().await, + } + } + + pub async fn get_context(&self) -> TradingContext { + self.context.read().await.clone() + } + + pub async fn pause(&self, reason: String) { + let mut state = self.state.write().await; + *state = EngineState::Paused(reason); + } + + pub async fn resume(&self) { + let mut state = self.state.write().await; + if matches!(*state, EngineState::Paused(_)) { + *state = EngineState::Running; + } + } + + pub async fn run(&self) -> anyhow::Result<()> { + let mut shutdown_rx = self.shutdown_tx.subscribe(); + let poll_interval = std::time::Duration::from_secs( + self.config.kalshi.poll_interval_secs, + ); + + { + let mut state = self.state.write().await; + *state = EngineState::Recovering; + } + + info!("recovering state from SQLite"); + if let Ok(Some(portfolio)) = + self.store.load_portfolio().await + { + let mut ctx = self.context.write().await; + ctx.portfolio = portfolio; + info!( + positions = ctx.portfolio.positions.len(), + cash = %ctx.portfolio.cash, + "state recovered" + ); + } + + { + let mut state = self.state.write().await; + *state = EngineState::Running; + } + + info!( + interval_secs = self.config.kalshi.poll_interval_secs, + "engine running" + ); + + // run first tick immediately + self.tick().await; + + loop { + tokio::select! { + _ = shutdown_rx.recv() => { + info!("shutdown signal received"); + let mut state = self.state.write().await; + *state = EngineState::ShuttingDown; + break; + } + _ = tokio::time::sleep(poll_interval) => { + let current_state = + self.state.read().await.clone(); + match current_state { + EngineState::Running => { + self.tick().await; + } + EngineState::Paused(ref reason) => { + info!( + reason = %reason, + "engine paused, skipping tick" + ); + } + _ => {} + } + } + } + } + + info!("persisting final state"); + let ctx = self.context.read().await; + if let Err(e) = + self.store.save_portfolio(&ctx.portfolio).await + { + error!(error = %e, "failed to persist final state"); + } + + info!("engine shutdown complete"); + Ok(()) + } + + async fn tick(&self) { + let tick_start = Instant::now(); + let now = Utc::now(); + + let context_snapshot = { + let mut ctx = self.context.write().await; + ctx.timestamp = now; + ctx.request_id = + uuid::Uuid::new_v4().to_string(); + ctx.clone() + }; + + let result = { + let pipeline = self.pipeline.lock().await; + pipeline.execute(context_snapshot.clone()).await + }; + + let candidates_fetched = + result.retrieved_candidates.len(); + let candidates_filtered = + result.filtered_candidates.len(); + let candidates_selected = + result.selected_candidates.len(); + + let candidate_scores: HashMap = result + .selected_candidates + .iter() + .map(|c| (c.ticker.clone(), c.final_score)) + .collect(); + + let current_prices: HashMap = result + .selected_candidates + .iter() + .map(|c| (c.ticker.clone(), c.current_yes_price)) + .collect(); + self.executor.update_prices(current_prices).await; + + let exit_signals = self.executor.generate_exit_signals( + &context_snapshot, + &candidate_scores, + ); + + let mut ctx = self.context.write().await; + let mut fills_executed = 0u32; + + for exit in &exit_signals { + if let Some(position) = ctx + .portfolio + .positions + .get(&exit.ticker) + .cloned() + { + let pnl = ctx.portfolio.close_position( + &exit.ticker, + exit.current_price, + ); + + info!( + ticker = %exit.ticker, + reason = ?exit.reason, + pnl = ?pnl, + "paper exit" + ); + + let exit_fill = crate::types::Fill { + ticker: exit.ticker.clone(), + side: position.side, + quantity: position.quantity, + price: exit.current_price, + timestamp: now, + }; + + let reason_str = + format!("{:?}", exit.reason); + let _ = self + .store + .record_fill( + &exit_fill, + pnl, + Some(&reason_str), + ) + .await; + + ctx.trading_history.push(Trade { + ticker: exit.ticker.clone(), + side: position.side, + quantity: position.quantity, + price: exit.current_price, + timestamp: now, + trade_type: TradeType::Close, + }); + + fills_executed += 1; + } + } + + let signals = self.executor.generate_signals( + &result.selected_candidates, + &*ctx, + ); + let signals_generated = signals.len(); + + let peak_equity = self + .store + .get_peak_equity() + .await + .ok() + .flatten() + .unwrap_or(ctx.portfolio.initial_capital); + + let positions_value: Decimal = ctx + .portfolio + .positions + .values() + .map(|p| { + p.avg_entry_price * Decimal::from(p.quantity) + }) + .sum(); + let current_equity = ctx.portfolio.cash + positions_value; + + let cb_status = { + let cb = self.circuit_breaker.lock().await; + cb.check( + current_equity, + peak_equity, + ctx.portfolio.positions.len(), + ) + .await + }; + + if let CbStatus::Tripped(reason) = cb_status { + warn!( + reason = %reason, + "circuit breaker tripped, pausing" + ); + drop(ctx); + let mut state = self.state.write().await; + *state = EngineState::Paused(reason); + return; + } + + { + let mut cb = self.circuit_breaker.lock().await; + cb.record_success(); + } + + for signal in signals { + if ctx.portfolio.positions.len() + >= self.config.trading.max_positions + { + break; + } + + let context_for_exec = (*ctx).clone(); + if let Some(fill) = self + .executor + .execute_signal(&signal, &context_for_exec) + .await + { + info!( + ticker = %fill.ticker, + side = ?fill.side, + qty = fill.quantity, + price = %fill.price, + "paper fill" + ); + + ctx.portfolio.apply_fill(&fill); + ctx.trading_history.push(Trade { + ticker: fill.ticker.clone(), + side: fill.side, + quantity: fill.quantity, + price: fill.price, + timestamp: fill.timestamp, + trade_type: TradeType::Open, + }); + + fills_executed += 1; + } + } + + if let Err(e) = + self.store.save_portfolio(&ctx.portfolio).await + { + error!(error = %e, "failed to persist portfolio"); + let mut cb = self.circuit_breaker.lock().await; + cb.record_error(); + } + + let positions_value: Decimal = ctx + .portfolio + .positions + .values() + .map(|p| { + p.avg_entry_price * Decimal::from(p.quantity) + }) + .sum(); + let equity = ctx.portfolio.cash + positions_value; + let drawdown = if peak_equity > Decimal::ZERO { + ((peak_equity - equity) + .to_f64() + .unwrap_or(0.0)) + / peak_equity.to_f64().unwrap_or(1.0) + } else { + 0.0 + }; + + let _ = self + .store + .snapshot_equity( + now, + equity, + ctx.portfolio.cash, + positions_value, + drawdown.max(0.0), + ) + .await; + + let duration_ms = + tick_start.elapsed().as_millis() as u64; + + let _ = self + .store + .record_pipeline_run( + now, + duration_ms, + candidates_fetched, + candidates_filtered, + candidates_selected, + signals_generated, + fills_executed as usize, + None, + ) + .await; + + { + let mut ticks = self.ticks.write().await; + *ticks += 1; + } + { + let mut last = self.last_tick.write().await; + *last = Some(now); + } + + info!( + fetched = candidates_fetched, + filtered = candidates_filtered, + selected = candidates_selected, + signals = signals_generated, + fills = fills_executed, + equity = %equity, + duration_ms = duration_ms, + "tick complete" + ); + } +} diff --git a/src/engine/state.rs b/src/engine/state.rs new file mode 100644 index 0000000..1c8a6ac --- /dev/null +++ b/src/engine/state.rs @@ -0,0 +1,25 @@ +use serde::Serialize; + +#[derive(Debug, Clone, PartialEq, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum EngineState { + Starting, + Recovering, + Running, + Paused(String), + ShuttingDown, +} + +impl std::fmt::Display for EngineState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Starting => write!(f, "starting"), + Self::Recovering => write!(f, "recovering"), + Self::Running => write!(f, "running"), + Self::Paused(reason) => { + write!(f, "paused: {}", reason) + } + Self::ShuttingDown => write!(f, "shutting_down"), + } + } +} diff --git a/src/execution.rs b/src/execution.rs index 5372665..3abb36e 100644 --- a/src/execution.rs +++ b/src/execution.rs @@ -1,7 +1,12 @@ use crate::data::HistoricalData; -use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext}; +use crate::types::{ + ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, + Signal, TradingContext, +}; +use async_trait::async_trait; use rust_decimal::Decimal; use rust_decimal::prelude::ToPrimitive; +use std::collections::HashMap; use std::sync::Arc; #[derive(Debug, Clone)] @@ -14,9 +19,6 @@ pub struct PositionSizingConfig { impl Default for PositionSizingConfig { fn default() -> Self { - // iteration 4: increased kelly from 0.25 to 0.40 - // research shows half-kelly to full-kelly range works well - // with 100% win rate on closed trades, we can be more aggressive Self { kelly_fraction: 0.40, max_position_pct: 0.30, @@ -46,13 +48,33 @@ impl PositionSizingConfig { } } +#[async_trait] +pub trait OrderExecutor: Send + Sync { + async fn execute_signal( + &self, + signal: &Signal, + context: &TradingContext, + ) -> Option; + + fn generate_signals( + &self, + candidates: &[MarketCandidate], + context: &TradingContext, + ) -> Vec; + + fn generate_exit_signals( + &self, + context: &TradingContext, + candidate_scores: &HashMap, + ) -> Vec; +} + /// maps scoring edge [-inf, +inf] to win probability [0, 1] -/// tanh squashes extreme values smoothly; +1)/2 shifts from [-1,1] to [0,1] -fn edge_to_win_probability(edge: f64) -> f64 { +pub fn edge_to_win_probability(edge: f64) -> f64 { (1.0 + edge.tanh()) / 2.0 } -fn kelly_size( +pub fn kelly_size( edge: f64, price: f64, bankroll: f64, @@ -71,13 +93,165 @@ fn kelly_size( let kelly = (odds * win_prob - (1.0 - win_prob)) / odds; let safe_kelly = (kelly * config.kelly_fraction).max(0.0); - let position_value = bankroll * safe_kelly.min(config.max_position_pct); + let position_value = + bankroll * safe_kelly.min(config.max_position_pct); let shares = (position_value / price).floor() as u64; - shares.max(config.min_position_size).min(config.max_position_size) + shares + .max(config.min_position_size) + .min(config.max_position_size) } -pub struct Executor { +pub fn candidate_to_signal( + candidate: &MarketCandidate, + context: &TradingContext, + sizing_config: &PositionSizingConfig, + max_position_size: u64, +) -> Option { + let current_position = + context.portfolio.get_position(&candidate.ticker); + let current_qty = + current_position.map(|p| p.quantity).unwrap_or(0); + + if current_qty >= max_position_size { + return None; + } + + let yes_price = + candidate.current_yes_price.to_f64().unwrap_or(0.5); + + let side = if candidate.final_score > 0.0 { + if yes_price < 0.5 { + Side::Yes + } else { + Side::No + } + } else if candidate.final_score < 0.0 { + if yes_price > 0.5 { + Side::No + } else { + Side::Yes + } + } else { + return None; + }; + + let price = match side { + Side::Yes => candidate.current_yes_price, + Side::No => candidate.current_no_price, + }; + + let available_cash = + context.portfolio.cash.to_f64().unwrap_or(0.0); + let price_f64 = price.to_f64().unwrap_or(0.5); + + if price_f64 <= 0.0 { + return None; + } + + let kelly_qty = kelly_size( + candidate.final_score, + price_f64, + available_cash, + sizing_config, + ); + + let max_affordable = (available_cash / price_f64) as u64; + let quantity = kelly_qty + .min(max_affordable) + .min(max_position_size - current_qty); + + if quantity < sizing_config.min_position_size { + return None; + } + + Some(Signal { + ticker: candidate.ticker.clone(), + side, + quantity, + limit_price: Some(price), + reason: format!( + "score={:.3}, side={:?}, price={:.2}", + candidate.final_score, side, price_f64 + ), + }) +} + +pub fn compute_exit_signals( + context: &TradingContext, + candidate_scores: &HashMap, + exit_config: &ExitConfig, + price_lookup: &dyn Fn(&str) -> Option, +) -> Vec { + let mut exits = Vec::new(); + + for (ticker, position) in &context.portfolio.positions { + let current_price = match price_lookup(ticker) { + Some(p) => p, + None => continue, + }; + + let effective_price = match position.side { + Side::Yes => current_price, + Side::No => Decimal::ONE - current_price, + }; + + let entry_price_f64 = + position.avg_entry_price.to_f64().unwrap_or(0.5); + let current_price_f64 = + effective_price.to_f64().unwrap_or(0.5); + + if entry_price_f64 <= 0.0 { + continue; + } + + let pnl_pct = (current_price_f64 - entry_price_f64) + / entry_price_f64; + + if pnl_pct >= exit_config.take_profit_pct { + exits.push(ExitSignal { + ticker: ticker.clone(), + reason: ExitReason::TakeProfit { pnl_pct }, + current_price, + }); + continue; + } + + if pnl_pct <= -exit_config.stop_loss_pct { + exits.push(ExitSignal { + ticker: ticker.clone(), + reason: ExitReason::StopLoss { pnl_pct }, + current_price, + }); + continue; + } + + let hours_held = + (context.timestamp - position.entry_time).num_hours(); + if hours_held >= exit_config.max_hold_hours { + exits.push(ExitSignal { + ticker: ticker.clone(), + reason: ExitReason::TimeStop { hours_held }, + current_price, + }); + continue; + } + + if let Some(&new_score) = candidate_scores.get(ticker) { + if new_score < exit_config.score_reversal_threshold { + exits.push(ExitSignal { + ticker: ticker.clone(), + reason: ExitReason::ScoreReversal { new_score }, + current_price, + }); + } + } + } + + exits +} + +pub struct BacktestExecutor { data: Arc, slippage_bps: u32, max_position_size: u64, @@ -85,8 +259,12 @@ pub struct Executor { exit_config: ExitConfig, } -impl Executor { - pub fn new(data: Arc, slippage_bps: u32, max_position_size: u64) -> Self { +impl BacktestExecutor { + pub fn new( + data: Arc, + slippage_bps: u32, + max_position_size: u64, + ) -> Self { Self { data, slippage_bps, @@ -96,7 +274,10 @@ impl Executor { } } - pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self { + pub fn with_sizing_config( + mut self, + config: PositionSizingConfig, + ) -> Self { self.sizing_config = config; self } @@ -105,168 +286,33 @@ impl Executor { self.exit_config = config; self } +} - pub fn generate_exit_signals( - &self, - context: &TradingContext, - candidate_scores: &std::collections::HashMap, - ) -> Vec { - let mut exits = Vec::new(); - - for (ticker, position) in &context.portfolio.positions { - let current_price = match self.data.get_current_price(ticker, context.timestamp) { - Some(p) => p, - None => continue, - }; - - let effective_price = match position.side { - Side::Yes => current_price, - Side::No => Decimal::ONE - current_price, - }; - - let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5); - let current_price_f64 = effective_price.to_f64().unwrap_or(0.5); - - if entry_price_f64 <= 0.0 { - continue; - } - - let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64; - - if pnl_pct >= self.exit_config.take_profit_pct { - exits.push(ExitSignal { - ticker: ticker.clone(), - reason: ExitReason::TakeProfit { pnl_pct }, - current_price, - }); - continue; - } - - if pnl_pct <= -self.exit_config.stop_loss_pct { - exits.push(ExitSignal { - ticker: ticker.clone(), - reason: ExitReason::StopLoss { pnl_pct }, - current_price, - }); - continue; - } - - let hours_held = (context.timestamp - position.entry_time).num_hours(); - if hours_held >= self.exit_config.max_hold_hours { - exits.push(ExitSignal { - ticker: ticker.clone(), - reason: ExitReason::TimeStop { hours_held }, - current_price, - }); - continue; - } - - if let Some(&new_score) = candidate_scores.get(ticker) { - if new_score < self.exit_config.score_reversal_threshold { - exits.push(ExitSignal { - ticker: ticker.clone(), - reason: ExitReason::ScoreReversal { new_score }, - current_price, - }); - } - } - } - - exits - } - - pub fn generate_signals( - &self, - candidates: &[MarketCandidate], - context: &TradingContext, - ) -> Vec { - candidates - .iter() - .filter_map(|c| self.candidate_to_signal(c, context)) - .collect() - } - - fn candidate_to_signal( - &self, - candidate: &MarketCandidate, - context: &TradingContext, - ) -> Option { - let current_position = context.portfolio.get_position(&candidate.ticker); - let current_qty = current_position.map(|p| p.quantity).unwrap_or(0); - - if current_qty >= self.max_position_size { - return None; - } - - let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5); - - // positive score = bullish signal, so buy the cheaper side (better risk/reward) - // negative score = bearish signal, so buy against the expensive side - let side = if candidate.final_score > 0.0 { - if yes_price < 0.5 { Side::Yes } else { Side::No } - } else if candidate.final_score < 0.0 { - if yes_price > 0.5 { Side::No } else { Side::Yes } - } else { - return None; - }; - - let price = match side { - Side::Yes => candidate.current_yes_price, - Side::No => candidate.current_no_price, - }; - - let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0); - let price_f64 = price.to_f64().unwrap_or(0.5); - - if price_f64 <= 0.0 { - return None; - } - - let kelly_qty = kelly_size( - candidate.final_score, - price_f64, - available_cash, - &self.sizing_config, - ); - - let max_affordable = (available_cash / price_f64) as u64; - let quantity = kelly_qty - .min(max_affordable) - .min(self.max_position_size - current_qty); - - if quantity < self.sizing_config.min_position_size { - return None; - } - - Some(Signal { - ticker: candidate.ticker.clone(), - side, - quantity, - limit_price: Some(price), - reason: format!( - "score={:.3}, side={:?}, price={:.2}", - candidate.final_score, side, price_f64 - ), - }) - } - - pub fn execute_signal( +#[async_trait] +impl OrderExecutor for BacktestExecutor { + async fn execute_signal( &self, signal: &Signal, context: &TradingContext, ) -> Option { - let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?; + let market_price = self + .data + .get_current_price(&signal.ticker, context.timestamp)?; let effective_price = match signal.side { Side::Yes => market_price, Side::No => Decimal::ONE - market_price, }; - let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000); - let fill_price = effective_price * (Decimal::ONE + slippage); + let slippage = + Decimal::from(self.slippage_bps) / Decimal::from(10000); + let fill_price = + effective_price * (Decimal::ONE + slippage); if let Some(limit) = signal.limit_price { - if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) { + if fill_price + > limit * (Decimal::ONE + slippage * Decimal::from(2)) + { return None; } } @@ -297,6 +343,39 @@ impl Executor { timestamp: context.timestamp, }) } + + fn generate_signals( + &self, + candidates: &[MarketCandidate], + context: &TradingContext, + ) -> Vec { + candidates + .iter() + .filter_map(|c| { + candidate_to_signal( + c, + context, + &self.sizing_config, + self.max_position_size, + ) + }) + .collect() + } + + fn generate_exit_signals( + &self, + context: &TradingContext, + candidate_scores: &HashMap, + ) -> Vec { + let data = self.data.clone(); + let timestamp = context.timestamp; + compute_exit_signals( + context, + candidate_scores, + &self.exit_config, + &|ticker| data.get_current_price(ticker, timestamp), + ) + } } pub fn simple_signal_generator( @@ -309,7 +388,8 @@ pub fn simple_signal_generator( .filter(|c| c.final_score > 0.0) .filter(|c| !context.portfolio.has_position(&c.ticker)) .map(|c| { - let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5); + let yes_price = + c.current_yes_price.to_f64().unwrap_or(0.5); let (side, price) = if yes_price < 0.5 { (Side::Yes, c.current_yes_price) } else { @@ -321,7 +401,10 @@ pub fn simple_signal_generator( side, quantity: position_size, limit_price: Some(price), - reason: format!("simple: score={:.3}", c.final_score), + reason: format!( + "simple: score={:.3}", + c.final_score + ), } }) .collect() diff --git a/src/main.rs b/src/main.rs index f86997a..0e59423 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,16 +1,22 @@ +mod api; mod backtest; +mod config; mod data; +mod engine; mod execution; mod metrics; +mod paper_executor; mod pipeline; +mod store; mod types; +mod web; use anyhow::{Context, Result}; use backtest::{Backtester, RandomBaseline}; use chrono::{DateTime, NaiveDate, TimeZone, Utc}; use clap::{Parser, Subcommand}; use data::HistoricalData; -use execution::{Executor, PositionSizingConfig}; +use execution::PositionSizingConfig; use rust_decimal::Decimal; use std::path::PathBuf; use std::sync::Arc; @@ -57,7 +63,7 @@ enum Commands { #[arg(long)] compare_random: bool, - /// kelly fraction for position sizing (0.40 = 40% of kelly optimal) + /// kelly fraction for position sizing #[arg(long, default_value = "0.40")] kelly_fraction: f64, @@ -65,11 +71,11 @@ enum Commands { #[arg(long, default_value = "0.30")] max_position_pct: f64, - /// take profit threshold (0.50 = +50%) + /// take profit threshold #[arg(long, default_value = "0.50")] take_profit: f64, - /// stop loss threshold (0.99 = disabled for prediction markets) + /// stop loss threshold #[arg(long, default_value = "0.99")] stop_loss: f64, @@ -78,19 +84,27 @@ enum Commands { max_hold_hours: i64, }, + Paper { + /// path to config.toml + #[arg(short, long, default_value = "config.toml")] + config: PathBuf, + }, + Summary { #[arg(short, long)] results_file: PathBuf, }, } -fn parse_date(s: &str) -> Result> { +pub fn parse_date(s: &str) -> Result> { if let Ok(dt) = DateTime::parse_from_rfc3339(s) { return Ok(dt.with_timezone(&Utc)); } if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") { - return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap())); + return Ok(Utc.from_utc_datetime( + &date.and_hms_opt(0, 0, 0).unwrap(), + )); } Err(anyhow::anyhow!("could not parse date: {}", s)) @@ -101,7 +115,9 @@ async fn main() -> Result<()> { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "kalshi_backtest=info".into()), + .unwrap_or_else(|_| { + "kalshi_backtest=info".into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -125,8 +141,10 @@ async fn main() -> Result<()> { stop_loss, max_hold_hours, } => { - let start_time = parse_date(&start).context("parsing start date")?; - let end_time = parse_date(&end).context("parsing end date")?; + let start_time = + parse_date(&start).context("parsing start date")?; + let end_time = + parse_date(&end).context("parsing end date")?; info!( data_dir = %data_dir.display(), @@ -137,7 +155,8 @@ async fn main() -> Result<()> { ); let data = Arc::new( - HistoricalData::load(&data_dir).context("loading historical data")?, + HistoricalData::load(&data_dir) + .context("loading historical data")?, ); info!( @@ -150,7 +169,8 @@ async fn main() -> Result<()> { start_time, end_time, interval: chrono::Duration::hours(interval_hours), - initial_capital: Decimal::try_from(capital).unwrap(), + initial_capital: Decimal::try_from(capital) + .unwrap(), max_position_size: max_position, max_positions, }; @@ -169,16 +189,25 @@ async fn main() -> Result<()> { score_reversal_threshold: -0.3, }; - let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config); + let backtester = Backtester::with_configs( + config.clone(), + data.clone(), + sizing_config, + exit_config, + ); let result = backtester.run().await; println!("{}", result.summary()); std::fs::create_dir_all(&output_dir)?; - let result_path = output_dir.join("backtest_result.json"); + let result_path = + output_dir.join("backtest_result.json"); let json = serde_json::to_string_pretty(&result)?; std::fs::write(&result_path, json)?; - info!(path = %result_path.display(), "results saved"); + info!( + path = %result_path.display(), + "results saved" + ); if compare_random { println!("\n--- random baseline ---\n"); @@ -186,18 +215,23 @@ async fn main() -> Result<()> { let baseline_result = baseline.run().await; println!("{}", baseline_result.summary()); - let baseline_path = output_dir.join("baseline_result.json"); - let json = serde_json::to_string_pretty(&baseline_result)?; + let baseline_path = + output_dir.join("baseline_result.json"); + let json = serde_json::to_string_pretty( + &baseline_result, + )?; std::fs::write(&baseline_path, json)?; println!("\n--- comparison ---\n"); println!( "strategy return: {:.2}% vs baseline: {:.2}%", - result.total_return_pct, baseline_result.total_return_pct + result.total_return_pct, + baseline_result.total_return_pct ); println!( "strategy sharpe: {:.3} vs baseline: {:.3}", - result.sharpe_ratio, baseline_result.sharpe_ratio + result.sharpe_ratio, + baseline_result.sharpe_ratio ); println!( "strategy win rate: {:.1}% vs baseline: {:.1}%", @@ -208,11 +242,16 @@ async fn main() -> Result<()> { Ok(()) } + Commands::Paper { config: config_path } => { + run_paper(config_path).await + } + Commands::Summary { results_file } => { let content = std::fs::read_to_string(&results_file) .context("reading results file")?; let result: metrics::BacktestResult = - serde_json::from_str(&content).context("parsing results")?; + serde_json::from_str(&content) + .context("parsing results")?; println!("{}", result.summary()); @@ -220,3 +259,124 @@ async fn main() -> Result<()> { } } } + +async fn run_paper(config_path: PathBuf) -> Result<()> { + let app_config = config::AppConfig::load(&config_path) + .context("loading config")?; + + info!( + mode = ?app_config.mode, + poll_secs = app_config.kalshi.poll_interval_secs, + capital = app_config.trading.initial_capital, + "starting paper trading" + ); + + let store = Arc::new( + store::SqliteStore::new(&app_config.persistence.db_path) + .await + .context("initializing SQLite store")?, + ); + + let client = Arc::new(api::KalshiClient::new( + &app_config.kalshi, + )); + + let sizing_config = PositionSizingConfig { + kelly_fraction: app_config.trading.kelly_fraction, + max_position_pct: app_config.trading.max_position_pct, + min_position_size: 10, + max_position_size: 1000, + }; + + let exit_config = ExitConfig { + take_profit_pct: app_config + .trading + .take_profit_pct + .unwrap_or(0.50), + stop_loss_pct: app_config + .trading + .stop_loss_pct + .unwrap_or(0.99), + max_hold_hours: app_config + .trading + .max_hold_hours + .unwrap_or(48), + score_reversal_threshold: -0.3, + }; + + let executor = Arc::new(paper_executor::PaperExecutor::new( + 1000, + sizing_config, + exit_config, + store.clone(), + )); + + let engine = engine::PaperTradingEngine::new( + app_config.clone(), + store.clone(), + executor, + client, + ) + .await + .context("initializing engine")?; + + let shutdown_tx = engine.shutdown_handle(); + + let engine = Arc::new(engine); + + if app_config.web.enabled { + let web_state = Arc::new(web::AppState { + engine: engine.clone(), + store: store.clone(), + shutdown_tx: shutdown_tx.clone(), + backtest: Arc::new(tokio::sync::Mutex::new( + web::BacktestState { + status: web::BacktestRunStatus::Idle, + progress: None, + result: None, + error: None, + }, + )), + data_dir: PathBuf::from("data"), + }); + + let router = web::build_router(web_state); + let bind_addr = app_config.web.bind_addr.clone(); + + info!(addr = %bind_addr, "starting web dashboard"); + + match tokio::net::TcpListener::bind(&bind_addr).await { + Ok(listener) => { + tokio::spawn(async move { + if let Err(e) = + axum::serve(listener, router).await + { + tracing::error!( + error = %e, + "web server error" + ); + } + }); + } + Err(e) => { + tracing::warn!( + addr = %bind_addr, + error = %e, + "web dashboard disabled (port in use)" + ); + } + } + } + + let shutdown_tx_clone = shutdown_tx.clone(); + tokio::spawn(async move { + tokio::signal::ctrl_c().await.ok(); + info!("ctrl+c received, shutting down"); + let _ = shutdown_tx_clone.send(()); + }); + + engine.run().await?; + + info!("paper trading session ended"); + Ok(()) +} diff --git a/src/paper_executor.rs b/src/paper_executor.rs new file mode 100644 index 0000000..daebd71 --- /dev/null +++ b/src/paper_executor.rs @@ -0,0 +1,143 @@ +use crate::execution::{ + candidate_to_signal, compute_exit_signals, OrderExecutor, + PositionSizingConfig, +}; +use crate::store::SqliteStore; +use crate::types::{ + ExitConfig, ExitSignal, Fill, MarketCandidate, Signal, Side, + TradingContext, +}; +use async_trait::async_trait; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +pub struct PaperExecutor { + max_position_size: u64, + sizing_config: PositionSizingConfig, + exit_config: ExitConfig, + store: Arc, + current_prices: Arc>>, +} + +impl PaperExecutor { + pub fn new( + max_position_size: u64, + sizing_config: PositionSizingConfig, + exit_config: ExitConfig, + store: Arc, + ) -> Self { + Self { + max_position_size, + sizing_config, + exit_config, + store, + current_prices: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub async fn update_prices( + &self, + prices: HashMap, + ) { + let mut current = self.current_prices.write().await; + *current = prices; + } +} + +#[async_trait] +impl OrderExecutor for PaperExecutor { + async fn execute_signal( + &self, + signal: &Signal, + context: &TradingContext, + ) -> Option { + let prices = self.current_prices.read().await; + let market_price = prices.get(&signal.ticker).copied()?; + + let effective_price = match signal.side { + Side::Yes => market_price, + Side::No => Decimal::ONE - market_price, + }; + + if let Some(limit) = signal.limit_price { + let tolerance = Decimal::new(5, 2); + if effective_price > limit * (Decimal::ONE + tolerance) { + return None; + } + } + + let cost = effective_price * Decimal::from(signal.quantity); + let quantity = if cost > context.portfolio.cash { + let affordable = (context.portfolio.cash + / effective_price) + .to_u64() + .unwrap_or(0); + if affordable == 0 { + return None; + } + affordable + } else { + signal.quantity + }; + + let fill = Fill { + ticker: signal.ticker.clone(), + side: signal.side, + quantity, + price: effective_price, + timestamp: context.timestamp, + }; + + if let Err(e) = + self.store.record_fill(&fill, None, None).await + { + tracing::error!( + error = %e, + "failed to persist fill" + ); + } + + Some(fill) + } + + fn generate_signals( + &self, + candidates: &[MarketCandidate], + context: &TradingContext, + ) -> Vec { + candidates + .iter() + .filter_map(|c| { + candidate_to_signal( + c, + context, + &self.sizing_config, + self.max_position_size, + ) + }) + .collect() + } + + fn generate_exit_signals( + &self, + context: &TradingContext, + candidate_scores: &HashMap, + ) -> Vec { + let prices = self.current_prices.try_read(); + match prices { + Ok(prices) => { + let prices_ref = prices.clone(); + compute_exit_signals( + context, + candidate_scores, + &self.exit_config, + &|ticker| prices_ref.get(ticker).copied(), + ) + } + Err(_) => Vec::new(), + } + } +} diff --git a/src/pipeline/sources.rs b/src/pipeline/sources.rs index ea34056..5beb771 100644 --- a/src/pipeline/sources.rs +++ b/src/pipeline/sources.rs @@ -1,7 +1,9 @@ +use crate::api::KalshiClient; use crate::data::HistoricalData; use crate::pipeline::Source; -use crate::types::{MarketCandidate, TradingContext}; +use crate::types::{MarketCandidate, PricePoint, TradingContext}; use async_trait::async_trait; +use chrono::Utc; use rust_decimal::Decimal; use std::collections::HashMap; use std::sync::Arc; @@ -67,3 +69,78 @@ impl Source for HistoricalMarketSource { Ok(candidates) } } + +pub struct LiveKalshiSource { + client: Arc, +} + +impl LiveKalshiSource { + pub fn new(client: Arc) -> Self { + Self { client } + } +} + +#[async_trait] +impl Source for LiveKalshiSource { + fn name(&self) -> &'static str { + "LiveKalshiSource" + } + + async fn get_candidates( + &self, + _context: &TradingContext, + ) -> Result, String> { + let markets = self + .client + .get_open_markets() + .await + .map_err(|e| format!("API error: {}", e))?; + + let now = Utc::now(); + + let mut candidates = Vec::with_capacity(markets.len()); + + for market in markets { + let yes_price = market.mid_yes_price(); + if yes_price <= 0.0 || yes_price >= 1.0 { + continue; + } + + let yes_dec = Decimal::try_from(yes_price) + .unwrap_or(Decimal::new(50, 2)); + let no_dec = Decimal::ONE - yes_dec; + + let volume_24h = market.volume_24h.max(0) as u64; + let total_volume = market.volume.max(0) as u64; + + // skip per-market trade fetching to avoid N+1 + // the market list already provides volume data + // for filtering; scorers will work with what's available + let price_history = Vec::new(); + let buy_vol = 0u64; + let sell_vol = 0u64; + + let category = market.category_from_event(); + + candidates.push(MarketCandidate { + ticker: market.ticker, + title: market.title, + category, + current_yes_price: yes_dec, + current_no_price: no_dec, + volume_24h, + total_volume, + buy_volume_24h: buy_vol, + sell_volume_24h: sell_vol, + open_time: market.open_time, + close_time: market.close_time, + result: None, + price_history, + scores: HashMap::new(), + final_score: 0.0, + }); + } + + Ok(candidates) + } +} diff --git a/src/store/mod.rs b/src/store/mod.rs new file mode 100644 index 0000000..5c6b0b9 --- /dev/null +++ b/src/store/mod.rs @@ -0,0 +1,4 @@ +mod queries; +mod schema; + +pub use queries::*; diff --git a/src/store/queries.rs b/src/store/queries.rs new file mode 100644 index 0000000..ca7d81d --- /dev/null +++ b/src/store/queries.rs @@ -0,0 +1,372 @@ +use crate::types::{Fill, Portfolio, Position, Side}; +use chrono::{DateTime, Utc}; +use rust_decimal::Decimal; +use sqlx::SqlitePool; +use std::collections::HashMap; +use std::str::FromStr; + +use super::schema::MIGRATIONS; + +pub struct SqliteStore { + pool: SqlitePool, +} + +impl SqliteStore { + pub async fn new(db_path: &str) -> anyhow::Result { + let url = format!("sqlite:{}?mode=rwc", db_path); + let pool = SqlitePool::connect(&url).await?; + let store = Self { pool }; + store.run_migrations().await?; + Ok(store) + } + + async fn run_migrations(&self) -> anyhow::Result<()> { + sqlx::raw_sql(MIGRATIONS).execute(&self.pool).await?; + Ok(()) + } + + pub async fn load_portfolio( + &self, + ) -> anyhow::Result> { + let row = sqlx::query_as::<_, (String, String)>( + "SELECT cash, initial_capital FROM portfolio_state \ + WHERE id = 1", + ) + .fetch_optional(&self.pool) + .await?; + + let Some((cash_str, capital_str)) = row else { + return Ok(None); + }; + + let cash = Decimal::from_str(&cash_str)?; + let initial_capital = Decimal::from_str(&capital_str)?; + + let position_rows = sqlx::query_as::< + _, + (String, String, i64, String, String), + >( + "SELECT ticker, side, quantity, avg_entry_price, \ + entry_time FROM positions", + ) + .fetch_all(&self.pool) + .await?; + + let mut positions = HashMap::new(); + for (ticker, side_str, qty, price_str, time_str) in + position_rows + { + let side = match side_str.as_str() { + "Yes" => Side::Yes, + _ => Side::No, + }; + let price = Decimal::from_str(&price_str)?; + let entry_time: DateTime = + time_str.parse::>()?; + + positions.insert( + ticker.clone(), + Position { + ticker, + side, + quantity: qty as u64, + avg_entry_price: price, + entry_time, + }, + ); + } + + Ok(Some(Portfolio { + positions, + cash, + initial_capital, + })) + } + + pub async fn save_portfolio( + &self, + portfolio: &Portfolio, + ) -> anyhow::Result<()> { + let cash = portfolio.cash.to_string(); + let capital = portfolio.initial_capital.to_string(); + + sqlx::query( + "INSERT INTO portfolio_state (id, cash, initial_capital, \ + updated_at) VALUES (1, ?1, ?2, datetime('now')) \ + ON CONFLICT(id) DO UPDATE SET \ + cash = ?1, initial_capital = ?2, \ + updated_at = datetime('now')", + ) + .bind(&cash) + .bind(&capital) + .execute(&self.pool) + .await?; + + sqlx::query("DELETE FROM positions") + .execute(&self.pool) + .await?; + + for pos in portfolio.positions.values() { + let side = match pos.side { + Side::Yes => "Yes", + Side::No => "No", + }; + sqlx::query( + "INSERT INTO positions \ + (ticker, side, quantity, avg_entry_price, \ + entry_time) VALUES (?1, ?2, ?3, ?4, ?5)", + ) + .bind(&pos.ticker) + .bind(side) + .bind(pos.quantity as i64) + .bind(pos.avg_entry_price.to_string()) + .bind(pos.entry_time.to_rfc3339()) + .execute(&self.pool) + .await?; + } + + Ok(()) + } + + pub async fn record_fill( + &self, + fill: &Fill, + pnl: Option, + exit_reason: Option<&str>, + ) -> anyhow::Result<()> { + let side = match fill.side { + Side::Yes => "Yes", + Side::No => "No", + }; + sqlx::query( + "INSERT INTO fills \ + (ticker, side, quantity, price, timestamp, pnl, \ + exit_reason) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + ) + .bind(&fill.ticker) + .bind(side) + .bind(fill.quantity as i64) + .bind(fill.price.to_string()) + .bind(fill.timestamp.to_rfc3339()) + .bind(pnl.map(|p| p.to_string())) + .bind(exit_reason) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub async fn snapshot_equity( + &self, + timestamp: DateTime, + equity: Decimal, + cash: Decimal, + positions_value: Decimal, + drawdown_pct: f64, + ) -> anyhow::Result<()> { + sqlx::query( + "INSERT INTO equity_snapshots \ + (timestamp, equity, cash, positions_value, \ + drawdown_pct) VALUES (?1, ?2, ?3, ?4, ?5)", + ) + .bind(timestamp.to_rfc3339()) + .bind(equity.to_string()) + .bind(cash.to_string()) + .bind(positions_value.to_string()) + .bind(drawdown_pct) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub async fn record_circuit_breaker_event( + &self, + rule: &str, + details: &str, + action: &str, + ) -> anyhow::Result<()> { + sqlx::query( + "INSERT INTO circuit_breaker_events \ + (timestamp, rule, details, action) \ + VALUES (datetime('now'), ?1, ?2, ?3)", + ) + .bind(rule) + .bind(details) + .bind(action) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub async fn record_pipeline_run( + &self, + timestamp: DateTime, + duration_ms: u64, + candidates_fetched: usize, + candidates_filtered: usize, + candidates_selected: usize, + signals_generated: usize, + fills_executed: usize, + errors: Option<&str>, + ) -> anyhow::Result<()> { + sqlx::query( + "INSERT INTO pipeline_runs \ + (timestamp, duration_ms, candidates_fetched, \ + candidates_filtered, candidates_selected, \ + signals_generated, fills_executed, errors) \ + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(timestamp.to_rfc3339()) + .bind(duration_ms as i64) + .bind(candidates_fetched as i64) + .bind(candidates_filtered as i64) + .bind(candidates_selected as i64) + .bind(signals_generated as i64) + .bind(fills_executed as i64) + .bind(errors) + .execute(&self.pool) + .await?; + Ok(()) + } + + pub async fn get_equity_curve( + &self, + ) -> anyhow::Result> { + let rows = sqlx::query_as::<_, (String, String, String, String, f64)>( + "SELECT timestamp, equity, cash, positions_value, \ + drawdown_pct FROM equity_snapshots \ + ORDER BY timestamp ASC", + ) + .fetch_all(&self.pool) + .await?; + + let mut snapshots = Vec::with_capacity(rows.len()); + for (ts, eq, cash, pv, dd) in rows { + snapshots.push(EquitySnapshot { + timestamp: ts.parse::>()?, + equity: Decimal::from_str(&eq)?, + cash: Decimal::from_str(&cash)?, + positions_value: Decimal::from_str(&pv)?, + drawdown_pct: dd, + }); + } + Ok(snapshots) + } + + pub async fn get_recent_fills( + &self, + limit: u32, + ) -> anyhow::Result> { + let rows = sqlx::query_as::< + _, + (String, String, i64, String, String, Option, Option), + >( + "SELECT ticker, side, quantity, price, timestamp, \ + pnl, exit_reason FROM fills \ + ORDER BY id DESC LIMIT ?1", + ) + .bind(limit as i64) + .fetch_all(&self.pool) + .await?; + + let mut fills = Vec::with_capacity(rows.len()); + for (ticker, side, qty, price, ts, pnl, reason) in rows { + fills.push(FillRecord { + ticker, + side: match side.as_str() { + "Yes" => Side::Yes, + _ => Side::No, + }, + quantity: qty as u64, + price: Decimal::from_str(&price)?, + timestamp: ts.parse::>()?, + pnl: pnl + .map(|p| Decimal::from_str(&p)) + .transpose()?, + exit_reason: reason, + }); + } + Ok(fills) + } + + pub async fn get_fills_since( + &self, + since: DateTime, + ) -> anyhow::Result { + let row = sqlx::query_as::<_, (i64,)>( + "SELECT COUNT(*) FROM fills \ + WHERE timestamp >= ?1", + ) + .bind(since.to_rfc3339()) + .fetch_one(&self.pool) + .await?; + Ok(row.0 as u32) + } + + pub async fn get_circuit_breaker_events( + &self, + limit: u32, + ) -> anyhow::Result> { + let rows = sqlx::query_as::<_, (String, String, String, String)>( + "SELECT timestamp, rule, details, action \ + FROM circuit_breaker_events \ + ORDER BY id DESC LIMIT ?1", + ) + .bind(limit as i64) + .fetch_all(&self.pool) + .await?; + + let mut events = Vec::with_capacity(rows.len()); + for (ts, rule, details, action) in rows { + events.push(CbEvent { + timestamp: ts.parse::>()?, + rule, + details, + action, + }); + } + Ok(events) + } + + pub async fn get_peak_equity( + &self, + ) -> anyhow::Result> { + let row = sqlx::query_as::<_, (Option,)>( + "SELECT MAX(equity) FROM equity_snapshots", + ) + .fetch_one(&self.pool) + .await?; + + match row.0 { + Some(s) => Ok(Some(Decimal::from_str(&s)?)), + None => Ok(None), + } + } +} + +#[derive(Debug, Clone)] +pub struct EquitySnapshot { + pub timestamp: DateTime, + pub equity: Decimal, + pub cash: Decimal, + pub positions_value: Decimal, + pub drawdown_pct: f64, +} + +#[derive(Debug, Clone)] +pub struct FillRecord { + pub ticker: String, + pub side: Side, + pub quantity: u64, + pub price: Decimal, + pub timestamp: DateTime, + pub pnl: Option, + pub exit_reason: Option, +} + +#[derive(Debug, Clone)] +pub struct CbEvent { + pub timestamp: DateTime, + pub rule: String, + pub details: String, + pub action: String, +} diff --git a/src/store/schema.rs b/src/store/schema.rs new file mode 100644 index 0000000..52935ba --- /dev/null +++ b/src/store/schema.rs @@ -0,0 +1,56 @@ +pub const MIGRATIONS: &str = r#" +CREATE TABLE IF NOT EXISTS portfolio_state ( + id INTEGER PRIMARY KEY CHECK (id = 1), + cash TEXT NOT NULL, + initial_capital TEXT NOT NULL, + updated_at TEXT NOT NULL DEFAULT (datetime('now')) +); + +CREATE TABLE IF NOT EXISTS positions ( + ticker TEXT PRIMARY KEY, + side TEXT NOT NULL, + quantity INTEGER NOT NULL, + avg_entry_price TEXT NOT NULL, + entry_time TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS fills ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ticker TEXT NOT NULL, + side TEXT NOT NULL, + quantity INTEGER NOT NULL, + price TEXT NOT NULL, + timestamp TEXT NOT NULL, + pnl TEXT, + exit_reason TEXT +); + +CREATE TABLE IF NOT EXISTS equity_snapshots ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + equity TEXT NOT NULL, + cash TEXT NOT NULL, + positions_value TEXT NOT NULL, + drawdown_pct REAL NOT NULL DEFAULT 0.0 +); + +CREATE TABLE IF NOT EXISTS circuit_breaker_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + rule TEXT NOT NULL, + details TEXT NOT NULL, + action TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS pipeline_runs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + duration_ms INTEGER NOT NULL, + candidates_fetched INTEGER NOT NULL DEFAULT 0, + candidates_filtered INTEGER NOT NULL DEFAULT 0, + candidates_selected INTEGER NOT NULL DEFAULT 0, + signals_generated INTEGER NOT NULL DEFAULT 0, + fills_executed INTEGER NOT NULL DEFAULT 0, + errors TEXT +); +"#; diff --git a/src/web/handlers.rs b/src/web/handlers.rs new file mode 100644 index 0000000..695c35a --- /dev/null +++ b/src/web/handlers.rs @@ -0,0 +1,514 @@ +use super::{AppState, BacktestRunStatus}; +use crate::backtest::Backtester; +use crate::data::HistoricalData; +use crate::execution::PositionSizingConfig; +use crate::types::{BacktestConfig, ExitConfig}; +use axum::extract::State; +use axum::http::StatusCode; +use axum::Json; +use chrono::Utc; +use rust_decimal::Decimal; +use rust_decimal::prelude::ToPrimitive; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::Arc; +use tracing::{error, info}; + +#[derive(Serialize)] +pub struct StatusResponse { + pub state: String, + pub uptime_secs: u64, + pub last_tick: Option, + pub ticks_completed: u64, +} + +#[derive(Serialize)] +pub struct PortfolioResponse { + pub cash: f64, + pub equity: f64, + pub initial_capital: f64, + pub return_pct: f64, + pub drawdown_pct: f64, + pub positions_count: usize, +} + +#[derive(Serialize)] +pub struct PositionResponse { + pub ticker: String, + pub side: String, + pub quantity: u64, + pub entry_price: f64, + pub entry_time: String, + pub unrealized_pnl: f64, +} + +#[derive(Serialize)] +pub struct TradeResponse { + pub ticker: String, + pub side: String, + pub quantity: u64, + pub price: f64, + pub timestamp: String, + pub pnl: Option, + pub exit_reason: Option, +} + +#[derive(Serialize)] +pub struct EquityPoint { + pub timestamp: String, + pub equity: f64, + pub cash: f64, + pub positions_value: f64, + pub drawdown_pct: f64, +} + +#[derive(Serialize)] +pub struct CbResponse { + pub status: String, + pub events: Vec, +} + +#[derive(Serialize)] +pub struct CbEventResponse { + pub timestamp: String, + pub rule: String, + pub details: String, + pub action: String, +} + +pub async fn get_status( + State(state): State>, +) -> Json { + let status = state.engine.get_status().await; + Json(StatusResponse { + state: format!("{}", status.state), + uptime_secs: status.uptime_secs, + last_tick: status + .last_tick + .map(|t| t.to_rfc3339()), + ticks_completed: status.ticks_completed, + }) +} + +pub async fn get_portfolio( + State(state): State>, +) -> Json { + let ctx = state.engine.get_context().await; + let portfolio = &ctx.portfolio; + + let positions_value: f64 = portfolio + .positions + .values() + .map(|p| { + p.avg_entry_price.to_f64().unwrap_or(0.0) + * p.quantity as f64 + }) + .sum(); + + let cash = portfolio.cash.to_f64().unwrap_or(0.0); + let equity = cash + positions_value; + let initial = portfolio + .initial_capital + .to_f64() + .unwrap_or(10000.0); + let return_pct = if initial > 0.0 { + (equity - initial) / initial * 100.0 + } else { + 0.0 + }; + + let peak = state + .store + .get_peak_equity() + .await + .ok() + .flatten() + .and_then(|p| p.to_f64()) + .unwrap_or(equity); + + let drawdown_pct = if peak > 0.0 { + ((peak - equity) / peak * 100.0).max(0.0) + } else { + 0.0 + }; + + Json(PortfolioResponse { + cash, + equity, + initial_capital: initial, + return_pct, + drawdown_pct, + positions_count: portfolio.positions.len(), + }) +} + +pub async fn get_positions( + State(state): State>, +) -> Json> { + let ctx = state.engine.get_context().await; + let positions: Vec = ctx + .portfolio + .positions + .values() + .map(|p| { + let entry = p.avg_entry_price.to_f64().unwrap_or(0.0); + PositionResponse { + ticker: p.ticker.clone(), + side: format!("{:?}", p.side), + quantity: p.quantity, + entry_price: entry, + entry_time: p.entry_time.to_rfc3339(), + unrealized_pnl: 0.0, + } + }) + .collect(); + + Json(positions) +} + +pub async fn get_trades( + State(state): State>, +) -> Json> { + let fills = state + .store + .get_recent_fills(100) + .await + .unwrap_or_default(); + + let trades: Vec = fills + .into_iter() + .map(|f| TradeResponse { + ticker: f.ticker, + side: format!("{:?}", f.side), + quantity: f.quantity, + price: f.price.to_f64().unwrap_or(0.0), + timestamp: f.timestamp.to_rfc3339(), + pnl: f.pnl.and_then(|p| p.to_f64()), + exit_reason: f.exit_reason, + }) + .collect(); + + Json(trades) +} + +pub async fn get_equity( + State(state): State>, +) -> Json> { + let snapshots = state + .store + .get_equity_curve() + .await + .unwrap_or_default(); + + let points: Vec = snapshots + .into_iter() + .map(|s| EquityPoint { + timestamp: s.timestamp.to_rfc3339(), + equity: s.equity.to_f64().unwrap_or(0.0), + cash: s.cash.to_f64().unwrap_or(0.0), + positions_value: s.positions_value.to_f64().unwrap_or(0.0), + drawdown_pct: s.drawdown_pct, + }) + .collect(); + + Json(points) +} + +pub async fn get_circuit_breaker( + State(state): State>, +) -> Json { + let engine_status = state.engine.get_status().await; + let cb_status = match engine_status.state { + crate::engine::EngineState::Paused(ref reason) => { + format!("tripped: {}", reason) + } + _ => "ok".to_string(), + }; + + let events = state + .store + .get_circuit_breaker_events(20) + .await + .unwrap_or_default(); + + let event_responses: Vec = events + .into_iter() + .map(|e| CbEventResponse { + timestamp: e.timestamp.to_rfc3339(), + rule: e.rule, + details: e.details, + action: e.action, + }) + .collect(); + + Json(CbResponse { + status: cb_status, + events: event_responses, + }) +} + +pub async fn post_pause( + State(state): State>, +) -> StatusCode { + state + .engine + .pause("manual pause via API".to_string()) + .await; + StatusCode::OK +} + +pub async fn post_resume( + State(state): State>, +) -> StatusCode { + state.engine.resume().await; + StatusCode::OK +} + +#[derive(Deserialize)] +pub struct BacktestRequest { + pub start: String, + pub end: String, + pub data_dir: Option, + pub capital: Option, + pub max_positions: Option, + pub max_position: Option, + pub interval_hours: Option, + pub kelly_fraction: Option, + pub max_position_pct: Option, + pub take_profit: Option, + pub stop_loss: Option, + pub max_hold_hours: Option, +} + +#[derive(Serialize)] +pub struct BacktestStatusResponse { + pub status: String, + pub elapsed_secs: Option, + pub error: Option, + pub phase: Option, + pub current_step: Option, + pub total_steps: Option, + pub progress_pct: Option, +} + +#[derive(Serialize)] +pub struct BacktestErrorResponse { + pub error: String, +} + +pub async fn post_backtest_run( + State(state): State>, + Json(req): Json, +) -> Result)> { + { + let guard = state.backtest.lock().await; + if matches!(guard.status, BacktestRunStatus::Running { .. }) { + return Err(( + StatusCode::CONFLICT, + Json(BacktestErrorResponse { + error: "backtest already running".into(), + }), + )); + } + } + + let start_time = crate::parse_date(&req.start).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(BacktestErrorResponse { + error: format!("invalid start date: {}", e), + }), + ) + })?; + let end_time = crate::parse_date(&req.end).map_err(|e| { + ( + StatusCode::BAD_REQUEST, + Json(BacktestErrorResponse { + error: format!("invalid end date: {}", e), + }), + ) + })?; + + let data_dir = if let Some(ref dir) = req.data_dir { + PathBuf::from(dir) + } else { + state.data_dir.clone() + }; + + if !data_dir.exists() { + return Err(( + StatusCode::BAD_REQUEST, + Json(BacktestErrorResponse { + error: format!( + "data directory not found: {}", + data_dir.display() + ), + }), + )); + } + + info!( + start = %start_time, + end = %end_time, + data_dir = %data_dir.display(), + "starting backtest from web UI" + ); + + let progress = Arc::new(super::BacktestProgress::new(0)); + + { + let mut guard = state.backtest.lock().await; + guard.status = + BacktestRunStatus::Running { started_at: Utc::now() }; + guard.progress = Some(progress.clone()); + guard.result = None; + guard.error = None; + } + + let backtest_state = state.backtest.clone(); + let progress = progress.clone(); + + let capital = req.capital.unwrap_or(10000.0); + let max_positions = req.max_positions.unwrap_or(100); + let max_position = req.max_position.unwrap_or(100); + let interval_hours = req.interval_hours.unwrap_or(1); + let kelly_fraction = req.kelly_fraction.unwrap_or(0.40); + let max_position_pct = req.max_position_pct.unwrap_or(0.30); + let take_profit = req.take_profit.unwrap_or(0.50); + let stop_loss = req.stop_loss.unwrap_or(0.99); + let max_hold_hours = req.max_hold_hours.unwrap_or(48); + + tokio::spawn(async move { + let data = match tokio::task::spawn_blocking(move || { + HistoricalData::load(&data_dir) + }) + .await + { + Ok(Ok(d)) => Arc::new(d), + Ok(Err(e)) => { + let mut guard = backtest_state.lock().await; + guard.status = BacktestRunStatus::Failed; + guard.error = + Some(format!("failed to load data: {}", e)); + error!(error = %e, "backtest data load failed"); + return; + } + Err(e) => { + let mut guard = backtest_state.lock().await; + guard.status = BacktestRunStatus::Failed; + guard.error = + Some(format!("task join error: {}", e)); + return; + } + }; + + let config = BacktestConfig { + start_time, + end_time, + interval: chrono::Duration::hours(interval_hours), + initial_capital: Decimal::try_from(capital) + .unwrap_or(Decimal::new(10000, 0)), + max_position_size: max_position, + max_positions, + }; + + let sizing_config = PositionSizingConfig { + kelly_fraction, + max_position_pct, + min_position_size: 10, + max_position_size: max_position, + }; + + let exit_config = ExitConfig { + take_profit_pct: take_profit, + stop_loss_pct: stop_loss, + max_hold_hours, + score_reversal_threshold: -0.3, + }; + + let backtester = Backtester::with_configs( + config, + data, + sizing_config, + exit_config, + ) + .with_progress(progress); + let result = backtester.run().await; + + let mut guard = backtest_state.lock().await; + guard.status = BacktestRunStatus::Complete; + guard.result = Some(result); + }); + + Ok(StatusCode::OK) +} + +pub async fn get_backtest_status( + State(state): State>, +) -> Json { + let guard = state.backtest.lock().await; + let (status_str, elapsed, error) = match &guard.status { + BacktestRunStatus::Idle => { + ("idle".to_string(), None, None) + } + BacktestRunStatus::Running { started_at } => { + let elapsed = Utc::now() + .signed_duration_since(*started_at) + .num_seconds() + .max(0) as u64; + ("running".to_string(), Some(elapsed), None) + } + BacktestRunStatus::Complete => { + ("complete".to_string(), None, None) + } + BacktestRunStatus::Failed => { + ("failed".to_string(), None, guard.error.clone()) + } + }; + + let (phase, current_step, total_steps, progress_pct) = + if let Some(ref p) = guard.progress { + let current = p.current_step.load( + std::sync::atomic::Ordering::Relaxed, + ); + let total = p.total_steps.load( + std::sync::atomic::Ordering::Relaxed, + ); + let pct = if total > 0 { + current as f64 / total as f64 * 100.0 + } else { + 0.0 + }; + ( + Some(p.phase_name().to_string()), + Some(current), + Some(total), + Some(pct), + ) + } else { + (None, None, None, None) + }; + + Json(BacktestStatusResponse { + status: status_str, + elapsed_secs: elapsed, + error, + phase, + current_step, + total_steps, + progress_pct, + }) +} + +pub async fn get_backtest_result( + State(state): State>, +) -> Result< + Json, + StatusCode, +> { + let guard = state.backtest.lock().await; + match &guard.result { + Some(result) => Ok(Json(result.clone())), + None => Err(StatusCode::NOT_FOUND), + } +} diff --git a/src/web/mod.rs b/src/web/mod.rs new file mode 100644 index 0000000..92e55d7 --- /dev/null +++ b/src/web/mod.rs @@ -0,0 +1,103 @@ +mod handlers; + +use crate::engine::PaperTradingEngine; +use crate::metrics::BacktestResult; +use crate::store::SqliteStore; +use axum::routing::{get, post}; +use axum::Router; +use chrono::{DateTime, Utc}; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::sync::broadcast; +use tower_http::services::ServeDir; + +pub enum BacktestRunStatus { + Idle, + Running { started_at: DateTime }, + Complete, + Failed, +} + +pub struct BacktestProgress { + pub phase: std::sync::atomic::AtomicU8, + pub current_step: std::sync::atomic::AtomicU64, + pub total_steps: std::sync::atomic::AtomicU64, +} + +impl BacktestProgress { + pub const PHASE_LOADING: u8 = 0; + pub const PHASE_RUNNING: u8 = 1; + + pub fn new(total_steps: u64) -> Self { + Self { + phase: std::sync::atomic::AtomicU8::new( + Self::PHASE_LOADING, + ), + current_step: std::sync::atomic::AtomicU64::new(0), + total_steps: std::sync::atomic::AtomicU64::new( + total_steps, + ), + } + } + + pub fn phase_name(&self) -> &'static str { + match self + .phase + .load(std::sync::atomic::Ordering::Relaxed) + { + Self::PHASE_LOADING => "loading data", + Self::PHASE_RUNNING => "simulating", + _ => "unknown", + } + } +} + +pub struct BacktestState { + pub status: BacktestRunStatus, + pub progress: Option>, + pub result: Option, + pub error: Option, +} + +pub struct AppState { + pub engine: Arc, + pub store: Arc, + pub shutdown_tx: broadcast::Sender<()>, + pub backtest: Arc>, + pub data_dir: PathBuf, +} + +pub fn build_router(state: Arc) -> Router { + Router::new() + .route("/api/status", get(handlers::get_status)) + .route("/api/portfolio", get(handlers::get_portfolio)) + .route("/api/positions", get(handlers::get_positions)) + .route("/api/trades", get(handlers::get_trades)) + .route("/api/equity", get(handlers::get_equity)) + .route( + "/api/circuit-breaker", + get(handlers::get_circuit_breaker), + ) + .route( + "/api/control/pause", + post(handlers::post_pause), + ) + .route( + "/api/control/resume", + post(handlers::post_resume), + ) + .route( + "/api/backtest/run", + post(handlers::post_backtest_run), + ) + .route( + "/api/backtest/status", + get(handlers::get_backtest_status), + ) + .route( + "/api/backtest/result", + get(handlers::get_backtest_result), + ) + .fallback_service(ServeDir::new("static")) + .with_state(state) +} diff --git a/static/index.html b/static/index.html new file mode 100644 index 0000000..ea41460 --- /dev/null +++ b/static/index.html @@ -0,0 +1,1124 @@ + + + + + + kalshi paper trading + + + + + +
+
kalshi paper
+
+ loading... +
+
+
+
+ +
+
+ + + +
+
+ +
+
+
+
+
+ +
+
+
positions
+

loading...

+
+ +
+
recent trades
+

loading...

+
+
+
+ +
+ +
+
backtest
+ +
+
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+ +
+
+ + + + + + + + +
+
+ + + +