backup for jake

This commit is contained in:
Nicholai Vogel 2026-01-25 01:20:44 -07:00
parent 3621d93643
commit 5dc05ba185
21 changed files with 4084 additions and 189 deletions

View File

@ -19,6 +19,13 @@ uuid = { version = "1", features = ["v4"] }
rust_decimal = { version = "1", features = ["serde"] } rust_decimal = { version = "1", features = ["serde"] }
rust_decimal_macros = "1" rust_decimal_macros = "1"
reqwest = { version = "0.12", features = ["json"] }
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono"] }
axum = "0.8"
tower-http = { version = "0.6", features = ["cors", "fs"] }
toml = "0.8"
governor = "0.8"
ort = { version = "2.0.0-rc.11", optional = true } ort = { version = "2.0.0-rc.11", optional = true }
ndarray = { version = "0.16", optional = true } ndarray = { version = "0.16", optional = true }

31
config.toml Normal file
View File

@ -0,0 +1,31 @@
# runtime mode: "backtest" | "paper" | "live" (deserialized into RunMode in src/config.rs)
mode = "paper"
[kalshi]
base_url = "https://api.elections.kalshi.com/trade-api/v2"
# seconds between engine polling ticks
poll_interval_secs = 60
# API request budget per second (token-bucket limiter)
rate_limit_per_sec = 2
[trading]
# starting paper bankroll in dollars
initial_capital = 10000.0
max_positions = 100
# fraction of full-Kelly used for position sizing
kelly_fraction = 0.25
# hard cap per position as a fraction of capital
max_position_pct = 0.10
take_profit_pct = 0.50
stop_loss_pct = 0.99
max_hold_hours = 48
[persistence]
db_path = "kalshi-paper.db"
[web]
enabled = true
bind_addr = "127.0.0.1:3030"
[circuit_breaker]
max_drawdown_pct = 0.15
max_daily_loss_pct = 0.05
max_positions = 100
max_single_position_pct = 0.10
max_consecutive_errors = 5
# NOTE(review): these fill-rate limits (200/hr, 1000/day) are far looser than
# the code defaults in CircuitBreakerConfig::default (50/hr, 200/day) — confirm intended
max_fills_per_hour = 200
max_fills_per_day = 1000

BIN
kalshi-paper.db Normal file

Binary file not shown.

168
src/api/client.rs Normal file
View File

@ -0,0 +1,168 @@
use super::types::{ApiMarket, ApiTrade, MarketsResponse, TradesResponse};
use crate::config::KalshiConfig;
use governor::{Quota, RateLimiter};
use std::num::NonZeroU32;
use std::sync::Arc;
use tracing::{debug, warn};
/// Rate-limited HTTP client for the Kalshi trade API.
pub struct KalshiClient {
    // shared reqwest client; 30s request timeout is set in new()
    http: reqwest::Client,
    // API root, e.g. https://api.elections.kalshi.com/trade-api/v2
    base_url: String,
    // in-process (NotKeyed) token-bucket limiter shared by all calls
    limiter: Arc<
        RateLimiter<
            governor::state::NotKeyed,
            governor::state::InMemoryState,
            governor::clock::DefaultClock,
        >,
    >,
}
impl KalshiClient {
pub fn new(config: &KalshiConfig) -> Self {
let quota = Quota::per_second(
NonZeroU32::new(config.rate_limit_per_sec).unwrap_or(
NonZeroU32::new(5).unwrap(),
),
);
let limiter = Arc::new(RateLimiter::direct(quota));
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build http client");
Self {
http,
base_url: config.base_url.clone(),
limiter,
}
}
async fn rate_limit(&self) {
self.limiter.until_ready().await;
}
async fn request_with_retry(
&self,
url: &str,
) -> anyhow::Result<reqwest::Response> {
for attempt in 0..5u32 {
if attempt > 0 {
let delay = std::time::Duration::from_secs(
3u64.pow(attempt),
);
debug!(
attempt = attempt,
delay_secs = delay.as_secs(),
"retrying after rate limit"
);
tokio::time::sleep(delay).await;
}
self.rate_limit().await;
let resp = self.http.get(url).send().await?;
if resp.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
warn!(
attempt = attempt,
"rate limited (429), backing off"
);
continue;
}
if !resp.status().is_success() {
let status = resp.status();
let body =
resp.text().await.unwrap_or_default();
anyhow::bail!(
"kalshi API error {}: {}",
status,
body
);
}
return Ok(resp);
}
anyhow::bail!("exhausted retries after rate limiting")
}
pub async fn get_open_markets(
&self,
) -> anyhow::Result<Vec<ApiMarket>> {
let mut all_markets = Vec::new();
let mut cursor: Option<String> = None;
let max_pages = 5; // cap at 1000 markets to avoid rate limiting
for page_num in 0..max_pages {
self.rate_limit().await;
let mut url = format!(
"{}/markets?status=open&limit=200",
self.base_url
);
if let Some(ref c) = cursor {
url.push_str(&format!("&cursor={}", c));
}
debug!(
url = %url,
page = page_num,
"fetching markets"
);
let resp = self
.request_with_retry(&url)
.await?;
let page: MarketsResponse = resp.json().await?;
let count = page.markets.len();
all_markets.extend(page.markets);
if !page.cursor.is_empty() && count > 0 {
cursor = Some(page.cursor);
} else {
break;
}
}
debug!(
total = all_markets.len(),
"fetched open markets"
);
Ok(all_markets)
}
pub async fn get_market_trades(
&self,
ticker: &str,
limit: u32,
) -> anyhow::Result<Vec<ApiTrade>> {
self.rate_limit().await;
let url = format!(
"{}/markets/trades?ticker={}&limit={}",
self.base_url, ticker, limit
);
debug!(ticker = %ticker, "fetching trades");
let resp = match self
.request_with_retry(&url)
.await
{
Ok(r) => r,
Err(e) => {
warn!(
ticker = %ticker,
error = %e,
"failed to fetch trades"
);
return Ok(Vec::new());
}
};
let data: TradesResponse = resp.json().await?;
Ok(data.trades)
}
}

5
src/api/mod.rs Normal file
View File

@ -0,0 +1,5 @@
pub mod client;
pub mod types;
pub use client::KalshiClient;
pub use types::*;

119
src/api/types.rs Normal file
View File

@ -0,0 +1,119 @@
use chrono::{DateTime, Utc};
use serde::Deserialize;
/// Response envelope for GET /markets.
#[derive(Debug, Clone, Deserialize)]
pub struct MarketsResponse {
    pub markets: Vec<ApiMarket>,
    // pagination cursor; empty string (the default) means last page
    #[serde(default)]
    pub cursor: String,
}
/// One market row from GET /markets.
///
/// All price fields are integer cents (0-100); `#[serde(default)]`
/// keeps deserialization tolerant of fields the API omits.
#[derive(Debug, Clone, Deserialize)]
pub struct ApiMarket {
    pub ticker: String,
    #[serde(default)]
    pub title: String,
    // parent event ticker; used for category heuristics below
    #[serde(default)]
    pub event_ticker: String,
    #[serde(default)]
    pub status: String,
    pub open_time: DateTime<Utc>,
    pub close_time: DateTime<Utc>,
    // cents, 0 when absent
    #[serde(default)]
    pub yes_ask: i64,
    #[serde(default)]
    pub yes_bid: i64,
    #[serde(default)]
    pub no_ask: i64,
    #[serde(default)]
    pub no_bid: i64,
    #[serde(default)]
    pub last_price: i64,
    #[serde(default)]
    pub volume: i64,
    #[serde(default)]
    pub volume_24h: i64,
    // settlement result; empty until the market resolves — TODO confirm
    #[serde(default)]
    pub result: String,
    #[serde(default)]
    pub subtitle: String,
}
impl ApiMarket {
    /// Mid yes price as a fraction in [0.0, 1.0].
    ///
    /// API prices arrive as integer cents. Uses the bid/ask midpoint
    /// when both sides are quoted, a one-sided quote when only one is,
    /// the last trade when neither is, and 0.0 as a final fallback.
    pub fn mid_yes_price(&self) -> f64 {
        let cents = |c: i64| c as f64 / 100.0;
        let (bid, ask) = (cents(self.yes_bid), cents(self.yes_ask));
        match (bid > 0.0, ask > 0.0) {
            (true, true) => (bid + ask) / 2.0,
            (true, false) => bid,
            (false, true) => ask,
            (false, false) if self.last_price > 0 => cents(self.last_price),
            (false, false) => 0.0,
        }
    }

    /// Coarse category label derived from keyword matches against the
    /// lowercased event ticker; the first matching rule wins and
    /// anything unmatched is "other".
    pub fn category_from_event(&self) -> String {
        // ordered rule table: (category, substrings that imply it)
        const RULES: [(&str, &[&str]); 5] = [
            ("sports", &["nba", "nfl", "sport"]),
            ("crypto", &["btc", "crypto", "eth"]),
            ("weather", &["weather", "temp"]),
            ("economics", &["econ", "fed", "cpi", "gdp"]),
            ("politics", &["elect", "polit", "trump", "biden"]),
        ];
        let lower = self.event_ticker.to_lowercase();
        RULES
            .iter()
            .find(|(_, keys)| keys.iter().any(|k| lower.contains(k)))
            .map(|(cat, _)| (*cat).to_string())
            .unwrap_or_else(|| "other".to_string())
    }
}
/// Response envelope for GET /markets/trades.
#[derive(Debug, Clone, Deserialize)]
pub struct TradesResponse {
    #[serde(default)]
    pub trades: Vec<ApiTrade>,
    // pagination cursor; empty when there are no further pages
    #[serde(default)]
    pub cursor: String,
}
/// One executed trade from GET /markets/trades.
/// Prices are integer cents; `#[serde(default)]` tolerates omitted fields.
#[derive(Debug, Clone, Deserialize)]
pub struct ApiTrade {
    #[serde(default)]
    pub trade_id: String,
    #[serde(default)]
    pub ticker: String,
    pub created_time: DateTime<Utc>,
    #[serde(default)]
    pub yes_price: i64,
    #[serde(default)]
    pub no_price: i64,
    // number of contracts in this trade
    #[serde(default)]
    pub count: i64,
    // which side the taker was on — presumably "yes"/"no"; verify against API docs
    #[serde(default)]
    pub taker_side: String,
}

View File

@ -1,5 +1,5 @@
use crate::data::HistoricalData; use crate::data::HistoricalData;
use crate::execution::{Executor, PositionSizingConfig}; use crate::execution::{BacktestExecutor, OrderExecutor, PositionSizingConfig};
use crate::metrics::{BacktestResult, MetricsCollector}; use crate::metrics::{BacktestResult, MetricsCollector};
use crate::pipeline::{ use crate::pipeline::{
AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer, AlreadyPositionedFilter, BollingerMeanReversionScorer, CategoryWeightedScorer,
@ -11,11 +11,13 @@ use crate::types::{
BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType, BacktestConfig, ExitConfig, Fill, MarketResult, Portfolio, Side, Trade, TradeType,
TradingContext, TradingContext,
}; };
use crate::web::BacktestProgress;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use rust_decimal::Decimal; use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive; use rust_decimal::prelude::ToPrimitive;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::Ordering;
use tracing::info; use tracing::info;
/// resolves any positions in markets that have closed /// resolves any positions in markets that have closed
@ -86,19 +88,21 @@ pub struct Backtester {
config: BacktestConfig, config: BacktestConfig,
data: Arc<HistoricalData>, data: Arc<HistoricalData>,
pipeline: TradingPipeline, pipeline: TradingPipeline,
executor: Executor, executor: BacktestExecutor,
progress: Option<Arc<BacktestProgress>>,
} }
impl Backtester { impl Backtester {
pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self { pub fn new(config: BacktestConfig, data: Arc<HistoricalData>) -> Self {
let pipeline = Self::build_default_pipeline(data.clone(), &config); let pipeline = Self::build_default_pipeline(data.clone(), &config);
let executor = Executor::new(data.clone(), 10, config.max_position_size); let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size);
Self { Self {
config, config,
data, data,
pipeline, pipeline,
executor, executor,
progress: None,
} }
} }
@ -109,7 +113,7 @@ impl Backtester {
exit_config: ExitConfig, exit_config: ExitConfig,
) -> Self { ) -> Self {
let pipeline = Self::build_default_pipeline(data.clone(), &config); let pipeline = Self::build_default_pipeline(data.clone(), &config);
let executor = Executor::new(data.clone(), 10, config.max_position_size) let executor = BacktestExecutor::new(data.clone(), 10, config.max_position_size)
.with_sizing_config(sizing_config) .with_sizing_config(sizing_config)
.with_exit_config(exit_config); .with_exit_config(exit_config);
@ -118,9 +122,15 @@ impl Backtester {
data, data,
pipeline, pipeline,
executor, executor,
progress: None,
} }
} }
pub fn with_progress(mut self, progress: Arc<BacktestProgress>) -> Self {
self.progress = Some(progress);
self
}
pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self { pub fn with_pipeline(mut self, pipeline: TradingPipeline) -> Self {
self.pipeline = pipeline; self.pipeline = pipeline;
self self
@ -160,13 +170,30 @@ impl Backtester {
let mut current_time = self.config.start_time; let mut current_time = self.config.start_time;
let total_steps = (self.config.end_time - self.config.start_time)
.num_seconds()
/ self.config.interval.num_seconds().max(1);
if let Some(ref progress) = self.progress {
progress.total_steps.store(
total_steps as u64,
Ordering::Relaxed,
);
progress.phase.store(
BacktestProgress::PHASE_RUNNING,
Ordering::Relaxed,
);
}
info!( info!(
start = %self.config.start_time, start = %self.config.start_time,
end = %self.config.end_time, end = %self.config.end_time,
interval_hours = self.config.interval.num_hours(), interval_hours = self.config.interval.num_hours(),
total_steps = total_steps,
"starting backtest" "starting backtest"
); );
let mut step: u64 = 0;
while current_time < self.config.end_time { while current_time < self.config.end_time {
context.timestamp = current_time; context.timestamp = current_time;
context.request_id = uuid::Uuid::new_v4().to_string(); context.request_id = uuid::Uuid::new_v4().to_string();
@ -237,7 +264,7 @@ impl Backtester {
break; break;
} }
if let Some(fill) = self.executor.execute_signal(&signal, &context) { if let Some(fill) = self.executor.execute_signal(&signal, &context).await {
info!( info!(
ticker = %fill.ticker, ticker = %fill.ticker,
side = ?fill.side, side = ?fill.side,
@ -272,6 +299,13 @@ impl Backtester {
let market_prices = self.get_current_prices(current_time); let market_prices = self.get_current_prices(current_time);
metrics.record(current_time, &context.portfolio, &market_prices); metrics.record(current_time, &context.portfolio, &market_prices);
step += 1;
if let Some(ref progress) = self.progress {
progress
.current_step
.store(step, Ordering::Relaxed);
}
current_time = current_time + self.config.interval; current_time = current_time + self.config.interval;
} }

124
src/config.rs Normal file
View File

@ -0,0 +1,124 @@
use serde::Deserialize;
use std::path::Path;
/// Top-level run mode, selected by the `mode` key in config.toml
/// (deserialized lowercase: "backtest" | "paper" | "live").
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RunMode {
    Backtest,
    Paper,
    Live,
}
/// Root configuration, mirroring the sections of config.toml.
/// All sections are required in the file (no serde defaults here).
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
    pub mode: RunMode,
    pub kalshi: KalshiConfig,
    pub trading: TradingConfig,
    pub persistence: PersistenceConfig,
    pub web: WebConfig,
    pub circuit_breaker: CircuitBreakerConfig,
}
/// `[kalshi]` section: API endpoint and request pacing.
#[derive(Debug, Clone, Deserialize)]
pub struct KalshiConfig {
    // API root URL, e.g. https://api.elections.kalshi.com/trade-api/v2
    pub base_url: String,
    // seconds between engine ticks
    pub poll_interval_secs: u64,
    // request budget per second; 0 falls back to 5 in KalshiClient::new
    pub rate_limit_per_sec: u32,
}
/// `[trading]` section: bankroll, sizing, and exit parameters.
#[derive(Debug, Clone, Deserialize)]
pub struct TradingConfig {
    // starting bankroll in dollars
    pub initial_capital: f64,
    // maximum concurrent open positions
    pub max_positions: usize,
    // fraction of full-Kelly used for sizing
    pub kelly_fraction: f64,
    // per-position cap as a fraction of capital
    pub max_position_pct: f64,
    // optional exits; None disables the corresponding rule
    pub take_profit_pct: Option<f64>,
    pub stop_loss_pct: Option<f64>,
    pub max_hold_hours: Option<i64>,
}
/// `[persistence]` section: where state is stored.
#[derive(Debug, Clone, Deserialize)]
pub struct PersistenceConfig {
    // path to the SQLite database file
    pub db_path: String,
}
/// `[web]` section: dashboard HTTP server settings.
#[derive(Debug, Clone, Deserialize)]
pub struct WebConfig {
    pub enabled: bool,
    // listen address, e.g. "127.0.0.1:3030"
    pub bind_addr: String,
}
/// `[circuit_breaker]` section: risk limits that pause the engine.
/// Optional fields fall back to hard-coded values inside CircuitBreaker.
#[derive(Debug, Clone, Deserialize)]
pub struct CircuitBreakerConfig {
    // trip when (peak - equity) / peak reaches this fraction
    pub max_drawdown_pct: f64,
    // trip when today's realized loss reaches this fraction of peak equity
    pub max_daily_loss_pct: f64,
    pub max_positions: Option<usize>,
    pub max_single_position_pct: Option<f64>,
    pub max_consecutive_errors: Option<u32>,
    pub max_fills_per_hour: Option<u32>,
    pub max_fills_per_day: Option<u32>,
}
impl AppConfig {
pub fn load(path: &Path) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Self = toml::from_str(&content)?;
Ok(config)
}
}
impl Default for KalshiConfig {
fn default() -> Self {
Self {
base_url: "https://api.elections.kalshi.com/trade-api/v2"
.to_string(),
poll_interval_secs: 300,
rate_limit_per_sec: 5,
}
}
}
impl Default for TradingConfig {
fn default() -> Self {
Self {
initial_capital: 10000.0,
max_positions: 100,
kelly_fraction: 0.25,
max_position_pct: 0.10,
take_profit_pct: Some(0.50),
stop_loss_pct: Some(0.99),
max_hold_hours: Some(48),
}
}
}
impl Default for PersistenceConfig {
fn default() -> Self {
Self {
db_path: "kalshi-paper.db".to_string(),
}
}
}
impl Default for WebConfig {
fn default() -> Self {
Self {
enabled: true,
bind_addr: "127.0.0.1:3030".to_string(),
}
}
}
impl Default for CircuitBreakerConfig {
fn default() -> Self {
Self {
max_drawdown_pct: 0.15,
max_daily_loss_pct: 0.05,
max_positions: Some(100),
max_single_position_pct: Some(0.10),
max_consecutive_errors: Some(5),
max_fills_per_hour: Some(50),
max_fills_per_day: Some(200),
}
}
}

View File

@ -0,0 +1,240 @@
use crate::config::CircuitBreakerConfig;
use crate::store::SqliteStore;
use chrono::{Duration, Utc};
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::sync::Arc;
use tracing::warn;
/// Outcome of a circuit-breaker check.
#[derive(Debug, Clone, PartialEq)]
pub enum CbStatus {
    // all rules passed
    Ok,
    // a rule fired; the string is a human-readable reason
    Tripped(String),
}
/// Evaluates risk limits against persisted trading state and trips
/// (pauses) the engine when any rule fires.
pub struct CircuitBreaker {
    config: CircuitBreakerConfig,
    // SQLite store used to read fills / equity and to log trip events
    store: Arc<SqliteStore>,
    // running count of back-to-back failures, reset by record_success()
    consecutive_errors: u32,
}
impl CircuitBreaker {
    /// Creates a breaker over `store` with the given limits; the
    /// consecutive-error counter starts at zero.
    pub fn new(
        config: CircuitBreakerConfig,
        store: Arc<SqliteStore>,
    ) -> Self {
        Self {
            config,
            store,
            consecutive_errors: 0,
        }
    }

    /// Records one failed operation (e.g. a persistence error).
    pub fn record_error(&mut self) {
        self.consecutive_errors += 1;
    }

    /// Resets the consecutive-error counter after a clean tick.
    pub fn record_success(&mut self) {
        self.consecutive_errors = 0;
    }

    /// Runs every rule and returns the first trip, if any, in order:
    /// drawdown, daily loss, position count, consecutive errors,
    /// fill rate.
    pub async fn check(
        &self,
        current_equity: Decimal,
        peak_equity: Decimal,
        positions_count: usize,
    ) -> CbStatus {
        if let Some(status) = self.check_drawdown(current_equity, peak_equity) {
            return status;
        }
        if let Some(status) = self.check_daily_loss().await {
            return status;
        }
        if let Some(status) = self.check_max_positions(positions_count) {
            return status;
        }
        if let Some(status) = self.check_consecutive_errors() {
            return status;
        }
        if let Some(status) = self.check_fill_rate().await {
            return status;
        }
        CbStatus::Ok
    }

    /// Trips when (peak - current) / peak >= max_drawdown_pct.
    /// A non-positive peak disables the rule (no meaningful baseline).
    fn check_drawdown(
        &self,
        current_equity: Decimal,
        peak_equity: Decimal,
    ) -> Option<CbStatus> {
        if peak_equity <= Decimal::ZERO {
            return None;
        }
        let drawdown = (peak_equity - current_equity).to_f64().unwrap_or(0.0)
            / peak_equity.to_f64().unwrap_or(1.0);
        if drawdown >= self.config.max_drawdown_pct {
            let msg = format!(
                "drawdown {:.1}% exceeds max {:.1}%",
                drawdown * 100.0,
                self.config.max_drawdown_pct * 100.0
            );
            warn!(rule = "max_drawdown", %msg);
            self.log_event("max_drawdown", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        None
    }

    /// Trips when today's realized loss (sum of fill pnl since UTC
    /// midnight) reaches max_daily_loss_pct of peak equity.
    ///
    /// NOTE(review): only the most recent 1000 fills are scanned, so an
    /// extremely active day could under-count losses — confirm the cap.
    async fn check_daily_loss(&self) -> Option<CbStatus> {
        let today_start = Utc::now().date_naive().and_hms_opt(0, 0, 0)?;
        let today_utc = chrono::TimeZone::from_utc_datetime(&Utc, &today_start);
        let fills = self
            .store
            .get_recent_fills(1000)
            .await
            .unwrap_or_default();
        // realized pnl of today's fills; fills without pnl are skipped
        let daily_pnl: f64 = fills
            .iter()
            .filter(|f| f.timestamp >= today_utc)
            .filter_map(|f| f.pnl.as_ref()?.to_f64())
            .sum();
        // missing peak falls back to the nominal $10k starting bankroll
        let peak = self
            .store
            .get_peak_equity()
            .await
            .ok()
            .flatten()
            .unwrap_or(Decimal::new(10000, 0));
        let peak_f64 = peak.to_f64().unwrap_or(10000.0);
        if peak_f64 <= 0.0 {
            return None;
        }
        let daily_loss_pct = (-daily_pnl) / peak_f64;
        if daily_loss_pct >= self.config.max_daily_loss_pct {
            let msg = format!(
                "daily loss {:.1}% exceeds max {:.1}%",
                daily_loss_pct * 100.0,
                self.config.max_daily_loss_pct * 100.0
            );
            warn!(rule = "max_daily_loss", %msg);
            self.log_event("max_daily_loss", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        None
    }

    /// Trips when the open-position count reaches the configured max
    /// (default 100).
    fn check_max_positions(&self, count: usize) -> Option<CbStatus> {
        let max = self.config.max_positions.unwrap_or(100);
        if count >= max {
            let msg = format!("positions {} at max {}", count, max);
            warn!(rule = "max_positions", %msg);
            // fix: this was the only tripping rule that did not persist
            // an event, leaving gaps in the audit trail
            self.log_event("max_positions", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        None
    }

    /// Trips after N back-to-back recorded errors (default 5).
    fn check_consecutive_errors(&self) -> Option<CbStatus> {
        let max = self.config.max_consecutive_errors.unwrap_or(5);
        if self.consecutive_errors >= max {
            let msg = format!(
                "{} consecutive errors (max {})",
                self.consecutive_errors, max
            );
            warn!(rule = "consecutive_errors", %msg);
            self.log_event("consecutive_errors", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        None
    }

    /// Trips when fill counts exceed the hourly (default 50) or
    /// rolling-24h (default 200) ceilings — a runaway-trading guard.
    async fn check_fill_rate(&self) -> Option<CbStatus> {
        let max_hourly = self.config.max_fills_per_hour.unwrap_or(50);
        let max_daily = self.config.max_fills_per_day.unwrap_or(200);
        let one_hour_ago = Utc::now() - Duration::hours(1);
        let hourly_fills = self
            .store
            .get_fills_since(one_hour_ago)
            .await
            .unwrap_or(0);
        if hourly_fills >= max_hourly {
            let msg = format!(
                "{} fills/hour exceeds max {}",
                hourly_fills, max_hourly
            );
            warn!(rule = "fill_rate_hourly", %msg);
            self.log_event("fill_rate_hourly", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        // "daily" here is a rolling 24h window, not a calendar day
        let today_start = Utc::now() - Duration::hours(24);
        let daily_fills = self
            .store
            .get_fills_since(today_start)
            .await
            .unwrap_or(0);
        if daily_fills >= max_daily {
            let msg = format!(
                "{} fills/day exceeds max {}",
                daily_fills, max_daily
            );
            warn!(rule = "fill_rate_daily", %msg);
            self.log_event("fill_rate_daily", &msg, "pause");
            return Some(CbStatus::Tripped(msg));
        }
        None
    }

    /// Persists a trip event asynchronously (fire-and-forget so a slow
    /// or failing store never blocks risk checks).
    fn log_event(&self, rule: &str, details: &str, action: &str) {
        let store = self.store.clone();
        let rule = rule.to_string();
        let details = details.to_string();
        let action = action.to_string();
        tokio::spawn(async move {
            let _ = store
                .record_circuit_breaker_event(&rule, &details, &action)
                .await;
        });
    }
}

506
src/engine/mod.rs Normal file
View File

@ -0,0 +1,506 @@
pub mod circuit_breaker;
pub mod state;
pub use circuit_breaker::{CbStatus, CircuitBreaker};
pub use state::EngineState;
use crate::api::KalshiClient;
use crate::config::AppConfig;
use crate::execution::OrderExecutor;
use crate::paper_executor::PaperExecutor;
use crate::pipeline::{
AlreadyPositionedFilter, BollingerMeanReversionScorer,
CategoryWeightedScorer, Filter, LiveKalshiSource,
LiquidityFilter, MeanReversionScorer, MomentumScorer,
MultiTimeframeMomentumScorer, OrderFlowScorer, Scorer,
Selector, Source, TimeDecayScorer, TimeToCloseFilter,
TopKSelector, TradingPipeline, VolumeScorer,
};
use crate::store::SqliteStore;
use crate::types::{Portfolio, Trade, TradeType, TradingContext};
use chrono::Utc;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{broadcast, Mutex, RwLock};
use tracing::{error, info, warn};
/// Snapshot of engine health, assembled by get_status().
pub struct EngineStatus {
    // current lifecycle state
    pub state: EngineState,
    // seconds since the engine struct was constructed
    pub uptime_secs: u64,
    // wall-clock time of the most recent completed tick, if any
    pub last_tick: Option<chrono::DateTime<Utc>>,
    // total ticks completed since start
    pub ticks_completed: u64,
}
/// Long-running paper-trading engine: polls Kalshi through the
/// pipeline, simulates fills, and persists state to SQLite.
pub struct PaperTradingEngine {
    config: AppConfig,
    store: Arc<SqliteStore>,
    // simulated order executor (prices, entries, exits)
    executor: Arc<PaperExecutor>,
    // tokio Mutex: held across awaits during pipeline.execute()
    pipeline: Mutex<TradingPipeline>,
    circuit_breaker: Mutex<CircuitBreaker>,
    // lifecycle state; written by run()/pause()/resume()
    state: RwLock<EngineState>,
    // portfolio + trade history shared with status/web readers
    context: RwLock<TradingContext>,
    // broadcast so multiple holders can signal shutdown
    shutdown_tx: broadcast::Sender<()>,
    start_time: Instant,
    ticks: RwLock<u64>,
    last_tick: RwLock<Option<chrono::DateTime<Utc>>>,
}
impl PaperTradingEngine {
    /// Builds the engine: wires the pipeline and circuit breaker, then
    /// restores the portfolio from SQLite (falling back to a fresh
    /// portfolio at the configured initial capital).
    ///
    /// # Errors
    /// Propagates store failures from the portfolio load.
    pub async fn new(
        config: AppConfig,
        store: Arc<SqliteStore>,
        executor: Arc<PaperExecutor>,
        client: Arc<KalshiClient>,
    ) -> anyhow::Result<Self> {
        let (shutdown_tx, _) = broadcast::channel(1);
        let pipeline = Self::build_pipeline(client, &config);
        let circuit_breaker = CircuitBreaker::new(
            config.circuit_breaker.clone(),
            store.clone(),
        );
        // fall back to $10k if the configured f64 cannot become a Decimal
        let initial_capital = Decimal::try_from(config.trading.initial_capital)
            .unwrap_or(Decimal::new(10000, 0));
        let portfolio = store
            .load_portfolio()
            .await?
            .unwrap_or_else(|| Portfolio::new(initial_capital));
        let mut ctx = TradingContext::new(portfolio.initial_capital, Utc::now());
        ctx.portfolio = portfolio;
        Ok(Self {
            config,
            store,
            executor,
            pipeline: Mutex::new(pipeline),
            circuit_breaker: Mutex::new(circuit_breaker),
            state: RwLock::new(EngineState::Starting),
            context: RwLock::new(ctx),
            shutdown_tx,
            start_time: Instant::now(),
            ticks: RwLock::new(0),
            last_tick: RwLock::new(None),
        })
    }

    /// Assembles the live pipeline: Kalshi source, liquidity/time
    /// filters, the full scorer stack, and a top-K selector capped at
    /// the configured max position count.
    fn build_pipeline(
        client: Arc<KalshiClient>,
        config: &AppConfig,
    ) -> TradingPipeline {
        let sources: Vec<Box<dyn Source>> =
            vec![Box::new(LiveKalshiSource::new(client))];
        // dollar cap per market, derived from capital * per-position pct
        let max_pos_size = (config.trading.initial_capital
            * config.trading.max_position_pct) as u64;
        let filters: Vec<Box<dyn Filter>> = vec![
            // at least 2h to close, at most 720h (30 days) — TODO confirm units
            Box::new(TimeToCloseFilter::new(2, Some(720))),
            Box::new(AlreadyPositionedFilter::new(max_pos_size.max(100))),
        ];
        let scorers: Vec<Box<dyn Scorer>> = vec![
            Box::new(MomentumScorer::new(6)),
            Box::new(MultiTimeframeMomentumScorer::default_windows()),
            Box::new(MeanReversionScorer::new(24)),
            Box::new(BollingerMeanReversionScorer::default_config()),
            Box::new(VolumeScorer::new(6)),
            Box::new(OrderFlowScorer::new()),
            Box::new(TimeDecayScorer::new()),
            Box::new(CategoryWeightedScorer::with_defaults()),
        ];
        let max_positions = config.trading.max_positions;
        let selector: Box<dyn Selector> =
            Box::new(TopKSelector::new(max_positions));
        TradingPipeline::new(sources, filters, scorers, selector, max_positions)
    }

    /// Clone of the shutdown sender; send () on it to stop run().
    pub fn shutdown_handle(&self) -> broadcast::Sender<()> {
        self.shutdown_tx.clone()
    }

    /// Current state/uptime/tick counters for the web layer.
    pub async fn get_status(&self) -> EngineStatus {
        EngineStatus {
            state: self.state.read().await.clone(),
            uptime_secs: self.start_time.elapsed().as_secs(),
            last_tick: *self.last_tick.read().await,
            ticks_completed: *self.ticks.read().await,
        }
    }

    /// Snapshot of the trading context (portfolio + history).
    pub async fn get_context(&self) -> TradingContext {
        self.context.read().await.clone()
    }

    /// Pauses ticking; `reason` is surfaced in status and logs.
    pub async fn pause(&self, reason: String) {
        let mut state = self.state.write().await;
        *state = EngineState::Paused(reason);
    }

    /// Resumes ticking, but only from the Paused state — other states
    /// (Starting, ShuttingDown, ...) are left untouched.
    pub async fn resume(&self) {
        let mut state = self.state.write().await;
        if matches!(*state, EngineState::Paused(_)) {
            *state = EngineState::Running;
        }
    }

    /// Main loop: recover persisted state, tick immediately, then tick
    /// every `poll_interval_secs` while Running, until a shutdown
    /// signal arrives. The final portfolio is persisted on exit.
    ///
    /// # Errors
    /// Currently always returns Ok; persistence failures on shutdown
    /// are logged, not propagated.
    pub async fn run(&self) -> anyhow::Result<()> {
        let mut shutdown_rx = self.shutdown_tx.subscribe();
        let poll_interval = std::time::Duration::from_secs(
            self.config.kalshi.poll_interval_secs,
        );
        {
            let mut state = self.state.write().await;
            *state = EngineState::Recovering;
        }
        info!("recovering state from SQLite");
        // second load (after new()) picks up any state written in between
        if let Ok(Some(portfolio)) = self.store.load_portfolio().await {
            let mut ctx = self.context.write().await;
            ctx.portfolio = portfolio;
            info!(
                positions = ctx.portfolio.positions.len(),
                cash = %ctx.portfolio.cash,
                "state recovered"
            );
        }
        {
            let mut state = self.state.write().await;
            *state = EngineState::Running;
        }
        info!(
            interval_secs = self.config.kalshi.poll_interval_secs,
            "engine running"
        );
        // run first tick immediately
        self.tick().await;
        loop {
            tokio::select! {
                _ = shutdown_rx.recv() => {
                    info!("shutdown signal received");
                    let mut state = self.state.write().await;
                    *state = EngineState::ShuttingDown;
                    break;
                }
                _ = tokio::time::sleep(poll_interval) => {
                    let current_state = self.state.read().await.clone();
                    match current_state {
                        EngineState::Running => {
                            self.tick().await;
                        }
                        EngineState::Paused(ref reason) => {
                            // paused ticks are skipped, not queued
                            info!(
                                reason = %reason,
                                "engine paused, skipping tick"
                            );
                        }
                        _ => {}
                    }
                }
            }
        }
        info!("persisting final state");
        let ctx = self.context.read().await;
        if let Err(e) = self.store.save_portfolio(&ctx.portfolio).await {
            error!(error = %e, "failed to persist final state");
        }
        info!("engine shutdown complete");
        Ok(())
    }

    /// One trading cycle: run the pipeline, apply exits, risk-check,
    /// apply entries, persist, and record metrics. Pauses the engine
    /// (without trading) if the circuit breaker trips.
    async fn tick(&self) {
        let tick_start = Instant::now();
        let now = Utc::now();
        // stamp and clone the context so the pipeline runs on a snapshot
        // while the write lock is released
        let context_snapshot = {
            let mut ctx = self.context.write().await;
            ctx.timestamp = now;
            ctx.request_id = uuid::Uuid::new_v4().to_string();
            ctx.clone()
        };
        let result = {
            let pipeline = self.pipeline.lock().await;
            pipeline.execute(context_snapshot.clone()).await
        };
        let candidates_fetched = result.retrieved_candidates.len();
        let candidates_filtered = result.filtered_candidates.len();
        let candidates_selected = result.selected_candidates.len();
        let candidate_scores: HashMap<String, f64> = result
            .selected_candidates
            .iter()
            .map(|c| (c.ticker.clone(), c.final_score))
            .collect();
        // NOTE(review): only selected candidates get price updates, so
        // positions in unselected markets keep stale prices — confirm.
        let current_prices: HashMap<String, Decimal> = result
            .selected_candidates
            .iter()
            .map(|c| (c.ticker.clone(), c.current_yes_price))
            .collect();
        self.executor.update_prices(current_prices).await;
        let exit_signals = self.executor.generate_exit_signals(
            &context_snapshot,
            &candidate_scores,
        );
        // hold the context write lock for the rest of the tick
        let mut ctx = self.context.write().await;
        let mut fills_executed = 0u32;
        // process exits first so freed cash is available for entries
        for exit in &exit_signals {
            if let Some(position) =
                ctx.portfolio.positions.get(&exit.ticker).cloned()
            {
                let pnl = ctx.portfolio.close_position(
                    &exit.ticker,
                    exit.current_price,
                );
                info!(
                    ticker = %exit.ticker,
                    reason = ?exit.reason,
                    pnl = ?pnl,
                    "paper exit"
                );
                let exit_fill = crate::types::Fill {
                    ticker: exit.ticker.clone(),
                    side: position.side,
                    quantity: position.quantity,
                    price: exit.current_price,
                    timestamp: now,
                };
                let reason_str = format!("{:?}", exit.reason);
                // best-effort persistence; failures are ignored here
                let _ = self
                    .store
                    .record_fill(&exit_fill, pnl, Some(&reason_str))
                    .await;
                ctx.trading_history.push(Trade {
                    ticker: exit.ticker.clone(),
                    side: position.side,
                    quantity: position.quantity,
                    price: exit.current_price,
                    timestamp: now,
                    trade_type: TradeType::Close,
                });
                fills_executed += 1;
            }
        }
        let signals = self.executor.generate_signals(
            &result.selected_candidates,
            &*ctx,
        );
        let signals_generated = signals.len();
        let peak_equity = self
            .store
            .get_peak_equity()
            .await
            .ok()
            .flatten()
            .unwrap_or(ctx.portfolio.initial_capital);
        // NOTE(review): positions are valued at avg entry price here,
        // not marked to market — equity/drawdown lag actual prices.
        let positions_value: Decimal = ctx
            .portfolio
            .positions
            .values()
            .map(|p| p.avg_entry_price * Decimal::from(p.quantity))
            .sum();
        let current_equity = ctx.portfolio.cash + positions_value;
        // risk check BEFORE opening any new positions this tick
        let cb_status = {
            let cb = self.circuit_breaker.lock().await;
            cb.check(
                current_equity,
                peak_equity,
                ctx.portfolio.positions.len(),
            )
            .await
        };
        if let CbStatus::Tripped(reason) = cb_status {
            warn!(
                reason = %reason,
                "circuit breaker tripped, pausing"
            );
            // release the context lock before taking the state lock
            drop(ctx);
            let mut state = self.state.write().await;
            *state = EngineState::Paused(reason);
            return;
        }
        {
            let mut cb = self.circuit_breaker.lock().await;
            cb.record_success();
        }
        for signal in signals {
            if ctx.portfolio.positions.len()
                >= self.config.trading.max_positions
            {
                break;
            }
            // clone so the executor sees fills applied earlier this loop
            let context_for_exec = (*ctx).clone();
            if let Some(fill) = self
                .executor
                .execute_signal(&signal, &context_for_exec)
                .await
            {
                info!(
                    ticker = %fill.ticker,
                    side = ?fill.side,
                    qty = fill.quantity,
                    price = %fill.price,
                    "paper fill"
                );
                ctx.portfolio.apply_fill(&fill);
                ctx.trading_history.push(Trade {
                    ticker: fill.ticker.clone(),
                    side: fill.side,
                    quantity: fill.quantity,
                    price: fill.price,
                    timestamp: fill.timestamp,
                    trade_type: TradeType::Open,
                });
                fills_executed += 1;
            }
        }
        // persist; a failure counts toward the consecutive-error breaker
        if let Err(e) = self.store.save_portfolio(&ctx.portfolio).await {
            error!(error = %e, "failed to persist portfolio");
            let mut cb = self.circuit_breaker.lock().await;
            cb.record_error();
        }
        // recompute equity after this tick's entries
        let positions_value: Decimal = ctx
            .portfolio
            .positions
            .values()
            .map(|p| p.avg_entry_price * Decimal::from(p.quantity))
            .sum();
        let equity = ctx.portfolio.cash + positions_value;
        let drawdown = if peak_equity > Decimal::ZERO {
            ((peak_equity - equity).to_f64().unwrap_or(0.0))
                / peak_equity.to_f64().unwrap_or(1.0)
        } else {
            0.0
        };
        let _ = self
            .store
            .snapshot_equity(
                now,
                equity,
                ctx.portfolio.cash,
                positions_value,
                drawdown.max(0.0),
            )
            .await;
        let duration_ms = tick_start.elapsed().as_millis() as u64;
        let _ = self
            .store
            .record_pipeline_run(
                now,
                duration_ms,
                candidates_fetched,
                candidates_filtered,
                candidates_selected,
                signals_generated,
                fills_executed as usize,
                None,
            )
            .await;
        {
            let mut ticks = self.ticks.write().await;
            *ticks += 1;
        }
        {
            let mut last = self.last_tick.write().await;
            *last = Some(now);
        }
        info!(
            fetched = candidates_fetched,
            filtered = candidates_filtered,
            selected = candidates_selected,
            signals = signals_generated,
            fills = fills_executed,
            equity = %equity,
            duration_ms = duration_ms,
            "tick complete"
        );
    }
}

25
src/engine/state.rs Normal file
View File

@ -0,0 +1,25 @@
use serde::Serialize;
/// Lifecycle state of the paper-trading engine.
// NOTE(review): rename_all = "lowercase" serializes ShuttingDown as
// "shuttingdown", while the Display impl prints "shutting_down" —
// confirm whether "snake_case" was intended for consistency.
#[derive(Debug, Clone, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EngineState {
    // constructed, run() not yet started
    Starting,
    // restoring portfolio state from SQLite
    Recovering,
    // ticking on the poll interval
    Running,
    // trading halted; the payload is the human-readable reason
    Paused(String),
    // shutdown signal received, persisting final state
    ShuttingDown,
}
impl std::fmt::Display for EngineState {
    /// Log-friendly rendering; Paused includes its reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Starting => f.write_str("starting"),
            Self::Recovering => f.write_str("recovering"),
            Self::Running => f.write_str("running"),
            Self::Paused(reason) => write!(f, "paused: {}", reason),
            Self::ShuttingDown => f.write_str("shutting_down"),
        }
    }
}

View File

@ -1,7 +1,12 @@
use crate::data::HistoricalData; use crate::data::HistoricalData;
use crate::types::{ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side, Signal, TradingContext}; use crate::types::{
ExitConfig, ExitReason, ExitSignal, Fill, MarketCandidate, Side,
Signal, TradingContext,
};
use async_trait::async_trait;
use rust_decimal::Decimal; use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive; use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -14,9 +19,6 @@ pub struct PositionSizingConfig {
impl Default for PositionSizingConfig { impl Default for PositionSizingConfig {
fn default() -> Self { fn default() -> Self {
// iteration 4: increased kelly from 0.25 to 0.40
// research shows half-kelly to full-kelly range works well
// with 100% win rate on closed trades, we can be more aggressive
Self { Self {
kelly_fraction: 0.40, kelly_fraction: 0.40,
max_position_pct: 0.30, max_position_pct: 0.30,
@ -46,13 +48,33 @@ impl PositionSizingConfig {
} }
} }
#[async_trait]
pub trait OrderExecutor: Send + Sync {
async fn execute_signal(
&self,
signal: &Signal,
context: &TradingContext,
) -> Option<Fill>;
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal>;
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal>;
}
/// maps scoring edge [-inf, +inf] to win probability [0, 1] /// maps scoring edge [-inf, +inf] to win probability [0, 1]
/// tanh squashes extreme values smoothly; +1)/2 shifts from [-1,1] to [0,1] pub fn edge_to_win_probability(edge: f64) -> f64 {
fn edge_to_win_probability(edge: f64) -> f64 {
(1.0 + edge.tanh()) / 2.0 (1.0 + edge.tanh()) / 2.0
} }
fn kelly_size( pub fn kelly_size(
edge: f64, edge: f64,
price: f64, price: f64,
bankroll: f64, bankroll: f64,
@ -71,13 +93,165 @@ fn kelly_size(
let kelly = (odds * win_prob - (1.0 - win_prob)) / odds; let kelly = (odds * win_prob - (1.0 - win_prob)) / odds;
let safe_kelly = (kelly * config.kelly_fraction).max(0.0); let safe_kelly = (kelly * config.kelly_fraction).max(0.0);
let position_value = bankroll * safe_kelly.min(config.max_position_pct); let position_value =
bankroll * safe_kelly.min(config.max_position_pct);
let shares = (position_value / price).floor() as u64; let shares = (position_value / price).floor() as u64;
shares.max(config.min_position_size).min(config.max_position_size) shares
.max(config.min_position_size)
.min(config.max_position_size)
} }
pub struct Executor { pub fn candidate_to_signal(
candidate: &MarketCandidate,
context: &TradingContext,
sizing_config: &PositionSizingConfig,
max_position_size: u64,
) -> Option<Signal> {
let current_position =
context.portfolio.get_position(&candidate.ticker);
let current_qty =
current_position.map(|p| p.quantity).unwrap_or(0);
if current_qty >= max_position_size {
return None;
}
let yes_price =
candidate.current_yes_price.to_f64().unwrap_or(0.5);
let side = if candidate.final_score > 0.0 {
if yes_price < 0.5 {
Side::Yes
} else {
Side::No
}
} else if candidate.final_score < 0.0 {
if yes_price > 0.5 {
Side::No
} else {
Side::Yes
}
} else {
return None;
};
let price = match side {
Side::Yes => candidate.current_yes_price,
Side::No => candidate.current_no_price,
};
let available_cash =
context.portfolio.cash.to_f64().unwrap_or(0.0);
let price_f64 = price.to_f64().unwrap_or(0.5);
if price_f64 <= 0.0 {
return None;
}
let kelly_qty = kelly_size(
candidate.final_score,
price_f64,
available_cash,
sizing_config,
);
let max_affordable = (available_cash / price_f64) as u64;
let quantity = kelly_qty
.min(max_affordable)
.min(max_position_size - current_qty);
if quantity < sizing_config.min_position_size {
return None;
}
Some(Signal {
ticker: candidate.ticker.clone(),
side,
quantity,
limit_price: Some(price),
reason: format!(
"score={:.3}, side={:?}, price={:.2}",
candidate.final_score, side, price_f64
),
})
}
/// scans every open position and emits an exit signal when one of the
/// configured exit rules fires, checked in priority order:
/// take-profit, stop-loss, time stop, then score reversal.
/// positions with no available price or a non-positive entry price are skipped.
pub fn compute_exit_signals(
    context: &TradingContext,
    candidate_scores: &HashMap<String, f64>,
    exit_config: &ExitConfig,
    price_lookup: &dyn Fn(&str) -> Option<Decimal>,
) -> Vec<ExitSignal> {
    let mut signals = Vec::new();
    for (ticker, position) in &context.portfolio.positions {
        let Some(market_price) = price_lookup(ticker) else {
            continue;
        };
        // price from the holder's perspective: a No position gains as
        // the yes price falls
        let held_price = match position.side {
            Side::Yes => market_price,
            Side::No => Decimal::ONE - market_price,
        };
        let entry = position.avg_entry_price.to_f64().unwrap_or(0.5);
        let current = held_price.to_f64().unwrap_or(0.5);
        if entry <= 0.0 {
            continue;
        }
        let pnl_pct = (current - entry) / entry;
        let hours_held = (context.timestamp - position.entry_time).num_hours();
        // evaluate the rules in priority order; first match wins
        let reason = if pnl_pct >= exit_config.take_profit_pct {
            Some(ExitReason::TakeProfit { pnl_pct })
        } else if pnl_pct <= -exit_config.stop_loss_pct {
            Some(ExitReason::StopLoss { pnl_pct })
        } else if hours_held >= exit_config.max_hold_hours {
            Some(ExitReason::TimeStop { hours_held })
        } else {
            candidate_scores
                .get(ticker)
                .copied()
                .filter(|&score| score < exit_config.score_reversal_threshold)
                .map(|new_score| ExitReason::ScoreReversal { new_score })
        };
        if let Some(reason) = reason {
            signals.push(ExitSignal {
                ticker: ticker.clone(),
                reason,
                current_price: market_price,
            });
        }
    }
    signals
}
pub struct BacktestExecutor {
data: Arc<HistoricalData>, data: Arc<HistoricalData>,
slippage_bps: u32, slippage_bps: u32,
max_position_size: u64, max_position_size: u64,
@ -85,8 +259,12 @@ pub struct Executor {
exit_config: ExitConfig, exit_config: ExitConfig,
} }
impl Executor { impl BacktestExecutor {
pub fn new(data: Arc<HistoricalData>, slippage_bps: u32, max_position_size: u64) -> Self { pub fn new(
data: Arc<HistoricalData>,
slippage_bps: u32,
max_position_size: u64,
) -> Self {
Self { Self {
data, data,
slippage_bps, slippage_bps,
@ -96,7 +274,10 @@ impl Executor {
} }
} }
pub fn with_sizing_config(mut self, config: PositionSizingConfig) -> Self { pub fn with_sizing_config(
mut self,
config: PositionSizingConfig,
) -> Self {
self.sizing_config = config; self.sizing_config = config;
self self
} }
@ -105,168 +286,33 @@ impl Executor {
self.exit_config = config; self.exit_config = config;
self self
} }
}
pub fn generate_exit_signals( #[async_trait]
&self, impl OrderExecutor for BacktestExecutor {
context: &TradingContext, async fn execute_signal(
candidate_scores: &std::collections::HashMap<String, f64>,
) -> Vec<ExitSignal> {
let mut exits = Vec::new();
for (ticker, position) in &context.portfolio.positions {
let current_price = match self.data.get_current_price(ticker, context.timestamp) {
Some(p) => p,
None => continue,
};
let effective_price = match position.side {
Side::Yes => current_price,
Side::No => Decimal::ONE - current_price,
};
let entry_price_f64 = position.avg_entry_price.to_f64().unwrap_or(0.5);
let current_price_f64 = effective_price.to_f64().unwrap_or(0.5);
if entry_price_f64 <= 0.0 {
continue;
}
let pnl_pct = (current_price_f64 - entry_price_f64) / entry_price_f64;
if pnl_pct >= self.exit_config.take_profit_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TakeProfit { pnl_pct },
current_price,
});
continue;
}
if pnl_pct <= -self.exit_config.stop_loss_pct {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::StopLoss { pnl_pct },
current_price,
});
continue;
}
let hours_held = (context.timestamp - position.entry_time).num_hours();
if hours_held >= self.exit_config.max_hold_hours {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::TimeStop { hours_held },
current_price,
});
continue;
}
if let Some(&new_score) = candidate_scores.get(ticker) {
if new_score < self.exit_config.score_reversal_threshold {
exits.push(ExitSignal {
ticker: ticker.clone(),
reason: ExitReason::ScoreReversal { new_score },
current_price,
});
}
}
}
exits
}
pub fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| self.candidate_to_signal(c, context))
.collect()
}
fn candidate_to_signal(
&self,
candidate: &MarketCandidate,
context: &TradingContext,
) -> Option<Signal> {
let current_position = context.portfolio.get_position(&candidate.ticker);
let current_qty = current_position.map(|p| p.quantity).unwrap_or(0);
if current_qty >= self.max_position_size {
return None;
}
let yes_price = candidate.current_yes_price.to_f64().unwrap_or(0.5);
// positive score = bullish signal, so buy the cheaper side (better risk/reward)
// negative score = bearish signal, so buy against the expensive side
let side = if candidate.final_score > 0.0 {
if yes_price < 0.5 { Side::Yes } else { Side::No }
} else if candidate.final_score < 0.0 {
if yes_price > 0.5 { Side::No } else { Side::Yes }
} else {
return None;
};
let price = match side {
Side::Yes => candidate.current_yes_price,
Side::No => candidate.current_no_price,
};
let available_cash = context.portfolio.cash.to_f64().unwrap_or(0.0);
let price_f64 = price.to_f64().unwrap_or(0.5);
if price_f64 <= 0.0 {
return None;
}
let kelly_qty = kelly_size(
candidate.final_score,
price_f64,
available_cash,
&self.sizing_config,
);
let max_affordable = (available_cash / price_f64) as u64;
let quantity = kelly_qty
.min(max_affordable)
.min(self.max_position_size - current_qty);
if quantity < self.sizing_config.min_position_size {
return None;
}
Some(Signal {
ticker: candidate.ticker.clone(),
side,
quantity,
limit_price: Some(price),
reason: format!(
"score={:.3}, side={:?}, price={:.2}",
candidate.final_score, side, price_f64
),
})
}
pub fn execute_signal(
&self, &self,
signal: &Signal, signal: &Signal,
context: &TradingContext, context: &TradingContext,
) -> Option<Fill> { ) -> Option<Fill> {
let market_price = self.data.get_current_price(&signal.ticker, context.timestamp)?; let market_price = self
.data
.get_current_price(&signal.ticker, context.timestamp)?;
let effective_price = match signal.side { let effective_price = match signal.side {
Side::Yes => market_price, Side::Yes => market_price,
Side::No => Decimal::ONE - market_price, Side::No => Decimal::ONE - market_price,
}; };
let slippage = Decimal::from(self.slippage_bps) / Decimal::from(10000); let slippage =
let fill_price = effective_price * (Decimal::ONE + slippage); Decimal::from(self.slippage_bps) / Decimal::from(10000);
let fill_price =
effective_price * (Decimal::ONE + slippage);
if let Some(limit) = signal.limit_price { if let Some(limit) = signal.limit_price {
if fill_price > limit * (Decimal::ONE + slippage * Decimal::from(2)) { if fill_price
> limit * (Decimal::ONE + slippage * Decimal::from(2))
{
return None; return None;
} }
} }
@ -297,6 +343,39 @@ impl Executor {
timestamp: context.timestamp, timestamp: context.timestamp,
}) })
} }
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| {
candidate_to_signal(
c,
context,
&self.sizing_config,
self.max_position_size,
)
})
.collect()
}
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal> {
let data = self.data.clone();
let timestamp = context.timestamp;
compute_exit_signals(
context,
candidate_scores,
&self.exit_config,
&|ticker| data.get_current_price(ticker, timestamp),
)
}
} }
pub fn simple_signal_generator( pub fn simple_signal_generator(
@ -309,7 +388,8 @@ pub fn simple_signal_generator(
.filter(|c| c.final_score > 0.0) .filter(|c| c.final_score > 0.0)
.filter(|c| !context.portfolio.has_position(&c.ticker)) .filter(|c| !context.portfolio.has_position(&c.ticker))
.map(|c| { .map(|c| {
let yes_price = c.current_yes_price.to_f64().unwrap_or(0.5); let yes_price =
c.current_yes_price.to_f64().unwrap_or(0.5);
let (side, price) = if yes_price < 0.5 { let (side, price) = if yes_price < 0.5 {
(Side::Yes, c.current_yes_price) (Side::Yes, c.current_yes_price)
} else { } else {
@ -321,7 +401,10 @@ pub fn simple_signal_generator(
side, side,
quantity: position_size, quantity: position_size,
limit_price: Some(price), limit_price: Some(price),
reason: format!("simple: score={:.3}", c.final_score), reason: format!(
"simple: score={:.3}",
c.final_score
),
} }
}) })
.collect() .collect()

View File

@ -1,16 +1,22 @@
mod api;
mod backtest; mod backtest;
mod config;
mod data; mod data;
mod engine;
mod execution; mod execution;
mod metrics; mod metrics;
mod paper_executor;
mod pipeline; mod pipeline;
mod store;
mod types; mod types;
mod web;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use backtest::{Backtester, RandomBaseline}; use backtest::{Backtester, RandomBaseline};
use chrono::{DateTime, NaiveDate, TimeZone, Utc}; use chrono::{DateTime, NaiveDate, TimeZone, Utc};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use data::HistoricalData; use data::HistoricalData;
use execution::{Executor, PositionSizingConfig}; use execution::PositionSizingConfig;
use rust_decimal::Decimal; use rust_decimal::Decimal;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
@ -57,7 +63,7 @@ enum Commands {
#[arg(long)] #[arg(long)]
compare_random: bool, compare_random: bool,
/// kelly fraction for position sizing (0.40 = 40% of kelly optimal) /// kelly fraction for position sizing
#[arg(long, default_value = "0.40")] #[arg(long, default_value = "0.40")]
kelly_fraction: f64, kelly_fraction: f64,
@ -65,11 +71,11 @@ enum Commands {
#[arg(long, default_value = "0.30")] #[arg(long, default_value = "0.30")]
max_position_pct: f64, max_position_pct: f64,
/// take profit threshold (0.50 = +50%) /// take profit threshold
#[arg(long, default_value = "0.50")] #[arg(long, default_value = "0.50")]
take_profit: f64, take_profit: f64,
/// stop loss threshold (0.99 = disabled for prediction markets) /// stop loss threshold
#[arg(long, default_value = "0.99")] #[arg(long, default_value = "0.99")]
stop_loss: f64, stop_loss: f64,
@ -78,19 +84,27 @@ enum Commands {
max_hold_hours: i64, max_hold_hours: i64,
}, },
Paper {
/// path to config.toml
#[arg(short, long, default_value = "config.toml")]
config: PathBuf,
},
Summary { Summary {
#[arg(short, long)] #[arg(short, long)]
results_file: PathBuf, results_file: PathBuf,
}, },
} }
fn parse_date(s: &str) -> Result<DateTime<Utc>> { pub fn parse_date(s: &str) -> Result<DateTime<Utc>> {
if let Ok(dt) = DateTime::parse_from_rfc3339(s) { if let Ok(dt) = DateTime::parse_from_rfc3339(s) {
return Ok(dt.with_timezone(&Utc)); return Ok(dt.with_timezone(&Utc));
} }
if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") { if let Ok(date) = NaiveDate::parse_from_str(s, "%Y-%m-%d") {
return Ok(Utc.from_utc_datetime(&date.and_hms_opt(0, 0, 0).unwrap())); return Ok(Utc.from_utc_datetime(
&date.and_hms_opt(0, 0, 0).unwrap(),
));
} }
Err(anyhow::anyhow!("could not parse date: {}", s)) Err(anyhow::anyhow!("could not parse date: {}", s))
@ -101,7 +115,9 @@ async fn main() -> Result<()> {
tracing_subscriber::registry() tracing_subscriber::registry()
.with( .with(
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "kalshi_backtest=info".into()), .unwrap_or_else(|_| {
"kalshi_backtest=info".into()
}),
) )
.with(tracing_subscriber::fmt::layer()) .with(tracing_subscriber::fmt::layer())
.init(); .init();
@ -125,8 +141,10 @@ async fn main() -> Result<()> {
stop_loss, stop_loss,
max_hold_hours, max_hold_hours,
} => { } => {
let start_time = parse_date(&start).context("parsing start date")?; let start_time =
let end_time = parse_date(&end).context("parsing end date")?; parse_date(&start).context("parsing start date")?;
let end_time =
parse_date(&end).context("parsing end date")?;
info!( info!(
data_dir = %data_dir.display(), data_dir = %data_dir.display(),
@ -137,7 +155,8 @@ async fn main() -> Result<()> {
); );
let data = Arc::new( let data = Arc::new(
HistoricalData::load(&data_dir).context("loading historical data")?, HistoricalData::load(&data_dir)
.context("loading historical data")?,
); );
info!( info!(
@ -150,7 +169,8 @@ async fn main() -> Result<()> {
start_time, start_time,
end_time, end_time,
interval: chrono::Duration::hours(interval_hours), interval: chrono::Duration::hours(interval_hours),
initial_capital: Decimal::try_from(capital).unwrap(), initial_capital: Decimal::try_from(capital)
.unwrap(),
max_position_size: max_position, max_position_size: max_position,
max_positions, max_positions,
}; };
@ -169,16 +189,25 @@ async fn main() -> Result<()> {
score_reversal_threshold: -0.3, score_reversal_threshold: -0.3,
}; };
let backtester = Backtester::with_configs(config.clone(), data.clone(), sizing_config, exit_config); let backtester = Backtester::with_configs(
config.clone(),
data.clone(),
sizing_config,
exit_config,
);
let result = backtester.run().await; let result = backtester.run().await;
println!("{}", result.summary()); println!("{}", result.summary());
std::fs::create_dir_all(&output_dir)?; std::fs::create_dir_all(&output_dir)?;
let result_path = output_dir.join("backtest_result.json"); let result_path =
output_dir.join("backtest_result.json");
let json = serde_json::to_string_pretty(&result)?; let json = serde_json::to_string_pretty(&result)?;
std::fs::write(&result_path, json)?; std::fs::write(&result_path, json)?;
info!(path = %result_path.display(), "results saved"); info!(
path = %result_path.display(),
"results saved"
);
if compare_random { if compare_random {
println!("\n--- random baseline ---\n"); println!("\n--- random baseline ---\n");
@ -186,18 +215,23 @@ async fn main() -> Result<()> {
let baseline_result = baseline.run().await; let baseline_result = baseline.run().await;
println!("{}", baseline_result.summary()); println!("{}", baseline_result.summary());
let baseline_path = output_dir.join("baseline_result.json"); let baseline_path =
let json = serde_json::to_string_pretty(&baseline_result)?; output_dir.join("baseline_result.json");
let json = serde_json::to_string_pretty(
&baseline_result,
)?;
std::fs::write(&baseline_path, json)?; std::fs::write(&baseline_path, json)?;
println!("\n--- comparison ---\n"); println!("\n--- comparison ---\n");
println!( println!(
"strategy return: {:.2}% vs baseline: {:.2}%", "strategy return: {:.2}% vs baseline: {:.2}%",
result.total_return_pct, baseline_result.total_return_pct result.total_return_pct,
baseline_result.total_return_pct
); );
println!( println!(
"strategy sharpe: {:.3} vs baseline: {:.3}", "strategy sharpe: {:.3} vs baseline: {:.3}",
result.sharpe_ratio, baseline_result.sharpe_ratio result.sharpe_ratio,
baseline_result.sharpe_ratio
); );
println!( println!(
"strategy win rate: {:.1}% vs baseline: {:.1}%", "strategy win rate: {:.1}% vs baseline: {:.1}%",
@ -208,11 +242,16 @@ async fn main() -> Result<()> {
Ok(()) Ok(())
} }
Commands::Paper { config: config_path } => {
run_paper(config_path).await
}
Commands::Summary { results_file } => { Commands::Summary { results_file } => {
let content = std::fs::read_to_string(&results_file) let content = std::fs::read_to_string(&results_file)
.context("reading results file")?; .context("reading results file")?;
let result: metrics::BacktestResult = let result: metrics::BacktestResult =
serde_json::from_str(&content).context("parsing results")?; serde_json::from_str(&content)
.context("parsing results")?;
println!("{}", result.summary()); println!("{}", result.summary());
@ -220,3 +259,124 @@ async fn main() -> Result<()> {
} }
} }
} }
async fn run_paper(config_path: PathBuf) -> Result<()> {
let app_config = config::AppConfig::load(&config_path)
.context("loading config")?;
info!(
mode = ?app_config.mode,
poll_secs = app_config.kalshi.poll_interval_secs,
capital = app_config.trading.initial_capital,
"starting paper trading"
);
let store = Arc::new(
store::SqliteStore::new(&app_config.persistence.db_path)
.await
.context("initializing SQLite store")?,
);
let client = Arc::new(api::KalshiClient::new(
&app_config.kalshi,
));
let sizing_config = PositionSizingConfig {
kelly_fraction: app_config.trading.kelly_fraction,
max_position_pct: app_config.trading.max_position_pct,
min_position_size: 10,
max_position_size: 1000,
};
let exit_config = ExitConfig {
take_profit_pct: app_config
.trading
.take_profit_pct
.unwrap_or(0.50),
stop_loss_pct: app_config
.trading
.stop_loss_pct
.unwrap_or(0.99),
max_hold_hours: app_config
.trading
.max_hold_hours
.unwrap_or(48),
score_reversal_threshold: -0.3,
};
let executor = Arc::new(paper_executor::PaperExecutor::new(
1000,
sizing_config,
exit_config,
store.clone(),
));
let engine = engine::PaperTradingEngine::new(
app_config.clone(),
store.clone(),
executor,
client,
)
.await
.context("initializing engine")?;
let shutdown_tx = engine.shutdown_handle();
let engine = Arc::new(engine);
if app_config.web.enabled {
let web_state = Arc::new(web::AppState {
engine: engine.clone(),
store: store.clone(),
shutdown_tx: shutdown_tx.clone(),
backtest: Arc::new(tokio::sync::Mutex::new(
web::BacktestState {
status: web::BacktestRunStatus::Idle,
progress: None,
result: None,
error: None,
},
)),
data_dir: PathBuf::from("data"),
});
let router = web::build_router(web_state);
let bind_addr = app_config.web.bind_addr.clone();
info!(addr = %bind_addr, "starting web dashboard");
match tokio::net::TcpListener::bind(&bind_addr).await {
Ok(listener) => {
tokio::spawn(async move {
if let Err(e) =
axum::serve(listener, router).await
{
tracing::error!(
error = %e,
"web server error"
);
}
});
}
Err(e) => {
tracing::warn!(
addr = %bind_addr,
error = %e,
"web dashboard disabled (port in use)"
);
}
}
}
let shutdown_tx_clone = shutdown_tx.clone();
tokio::spawn(async move {
tokio::signal::ctrl_c().await.ok();
info!("ctrl+c received, shutting down");
let _ = shutdown_tx_clone.send(());
});
engine.run().await?;
info!("paper trading session ended");
Ok(())
}

143
src/paper_executor.rs Normal file
View File

@ -0,0 +1,143 @@
use crate::execution::{
candidate_to_signal, compute_exit_signals, OrderExecutor,
PositionSizingConfig,
};
use crate::store::SqliteStore;
use crate::types::{
ExitConfig, ExitSignal, Fill, MarketCandidate, Signal, Side,
TradingContext,
};
use async_trait::async_trait;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Simulated (paper) order executor: fills signals against the most
/// recently polled market prices and persists every fill to SQLite.
pub struct PaperExecutor {
    // per-ticker cap on total contracts held
    max_position_size: u64,
    sizing_config: PositionSizingConfig,
    exit_config: ExitConfig,
    // persistence layer; fills are recorded here as they happen
    store: Arc<SqliteStore>,
    // latest price snapshot keyed by ticker; replaced wholesale by update_prices
    current_prices: Arc<RwLock<HashMap<String, Decimal>>>,
}
impl PaperExecutor {
    /// Builds a paper executor backed by the given persistence store.
    pub fn new(
        max_position_size: u64,
        sizing_config: PositionSizingConfig,
        exit_config: ExitConfig,
        store: Arc<SqliteStore>,
    ) -> Self {
        let current_prices = Arc::new(RwLock::new(HashMap::new()));
        Self {
            max_position_size,
            sizing_config,
            exit_config,
            store,
            current_prices,
        }
    }

    /// Replaces the cached price snapshot with the latest poll results.
    pub async fn update_prices(
        &self,
        prices: HashMap<String, Decimal>,
    ) {
        *self.current_prices.write().await = prices;
    }
}
#[async_trait]
impl OrderExecutor for PaperExecutor {
async fn execute_signal(
&self,
signal: &Signal,
context: &TradingContext,
) -> Option<Fill> {
let prices = self.current_prices.read().await;
let market_price = prices.get(&signal.ticker).copied()?;
let effective_price = match signal.side {
Side::Yes => market_price,
Side::No => Decimal::ONE - market_price,
};
if let Some(limit) = signal.limit_price {
let tolerance = Decimal::new(5, 2);
if effective_price > limit * (Decimal::ONE + tolerance) {
return None;
}
}
let cost = effective_price * Decimal::from(signal.quantity);
let quantity = if cost > context.portfolio.cash {
let affordable = (context.portfolio.cash
/ effective_price)
.to_u64()
.unwrap_or(0);
if affordable == 0 {
return None;
}
affordable
} else {
signal.quantity
};
let fill = Fill {
ticker: signal.ticker.clone(),
side: signal.side,
quantity,
price: effective_price,
timestamp: context.timestamp,
};
if let Err(e) =
self.store.record_fill(&fill, None, None).await
{
tracing::error!(
error = %e,
"failed to persist fill"
);
}
Some(fill)
}
fn generate_signals(
&self,
candidates: &[MarketCandidate],
context: &TradingContext,
) -> Vec<Signal> {
candidates
.iter()
.filter_map(|c| {
candidate_to_signal(
c,
context,
&self.sizing_config,
self.max_position_size,
)
})
.collect()
}
fn generate_exit_signals(
&self,
context: &TradingContext,
candidate_scores: &HashMap<String, f64>,
) -> Vec<ExitSignal> {
let prices = self.current_prices.try_read();
match prices {
Ok(prices) => {
let prices_ref = prices.clone();
compute_exit_signals(
context,
candidate_scores,
&self.exit_config,
&|ticker| prices_ref.get(ticker).copied(),
)
}
Err(_) => Vec::new(),
}
}
}

View File

@ -1,7 +1,9 @@
use crate::api::KalshiClient;
use crate::data::HistoricalData; use crate::data::HistoricalData;
use crate::pipeline::Source; use crate::pipeline::Source;
use crate::types::{MarketCandidate, TradingContext}; use crate::types::{MarketCandidate, PricePoint, TradingContext};
use async_trait::async_trait; use async_trait::async_trait;
use chrono::Utc;
use rust_decimal::Decimal; use rust_decimal::Decimal;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
@ -67,3 +69,78 @@ impl Source for HistoricalMarketSource {
Ok(candidates) Ok(candidates)
} }
} }
/// Candidate source that pulls open markets from the live Kalshi API.
pub struct LiveKalshiSource {
    client: Arc<KalshiClient>,
}
impl LiveKalshiSource {
    /// Wraps an API client for use as a pipeline source.
    pub fn new(client: Arc<KalshiClient>) -> Self {
        Self { client }
    }
}
#[async_trait]
impl Source for LiveKalshiSource {
    fn name(&self) -> &'static str {
        "LiveKalshiSource"
    }

    /// Fetches all open markets from the API and converts them into
    /// unscored candidates. Markets whose mid yes price is pinned at or
    /// outside (0, 1) are skipped as untradeable.
    async fn get_candidates(
        &self,
        _context: &TradingContext,
    ) -> Result<Vec<MarketCandidate>, String> {
        let markets = self
            .client
            .get_open_markets()
            .await
            .map_err(|e| format!("API error: {}", e))?;
        let candidates = markets
            .into_iter()
            .filter_map(|market| {
                let yes_price = market.mid_yes_price();
                if yes_price <= 0.0 || yes_price >= 1.0 {
                    return None;
                }
                let yes_dec = Decimal::try_from(yes_price)
                    .unwrap_or(Decimal::new(50, 2));
                Some(MarketCandidate {
                    // derive category before the move of ticker/title below
                    category: market.category_from_event(),
                    ticker: market.ticker,
                    title: market.title,
                    current_yes_price: yes_dec,
                    current_no_price: Decimal::ONE - yes_dec,
                    volume_24h: market.volume_24h.max(0) as u64,
                    total_volume: market.volume.max(0) as u64,
                    // per-market trade history is intentionally not fetched
                    // here (avoids an N+1 request pattern); scorers work
                    // with the list-level volume data only
                    buy_volume_24h: 0,
                    sell_volume_24h: 0,
                    open_time: market.open_time,
                    close_time: market.close_time,
                    result: None,
                    price_history: Vec::new(),
                    scores: HashMap::new(),
                    final_score: 0.0,
                })
            })
            .collect();
        Ok(candidates)
    }
}

4
src/store/mod.rs Normal file
View File

@ -0,0 +1,4 @@
mod queries;
mod schema;
pub use queries::*;

372
src/store/queries.rs Normal file
View File

@ -0,0 +1,372 @@
use crate::types::{Fill, Portfolio, Position, Side};
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use sqlx::SqlitePool;
use std::collections::HashMap;
use std::str::FromStr;
use super::schema::MIGRATIONS;
/// SQLite-backed persistence for the paper-trading engine: portfolio
/// state, fills, equity snapshots, circuit-breaker and pipeline telemetry.
pub struct SqliteStore {
    pool: SqlitePool,
}
impl SqliteStore {
    /// Opens (creating if absent) the SQLite database at `db_path` and
    /// applies the idempotent schema migrations.
    pub async fn new(db_path: &str) -> anyhow::Result<Self> {
        // mode=rwc: read-write, create the file when it does not exist
        let url = format!("sqlite:{}?mode=rwc", db_path);
        let pool = SqlitePool::connect(&url).await?;
        let store = Self { pool };
        store.run_migrations().await?;
        Ok(store)
    }

    /// Executes the CREATE TABLE IF NOT EXISTS migration script.
    async fn run_migrations(&self) -> anyhow::Result<()> {
        sqlx::raw_sql(MIGRATIONS).execute(&self.pool).await?;
        Ok(())
    }

    /// Loads the persisted portfolio (cash, initial capital, all open
    /// positions), or `Ok(None)` when no state has been saved yet.
    ///
    /// Decimals are stored as TEXT and round-tripped through `from_str`
    /// to avoid floating-point precision loss.
    pub async fn load_portfolio(
        &self,
    ) -> anyhow::Result<Option<Portfolio>> {
        let row = sqlx::query_as::<_, (String, String)>(
            "SELECT cash, initial_capital FROM portfolio_state \
             WHERE id = 1",
        )
        .fetch_optional(&self.pool)
        .await?;
        let Some((cash_str, capital_str)) = row else {
            return Ok(None);
        };
        let cash = Decimal::from_str(&cash_str)?;
        let initial_capital = Decimal::from_str(&capital_str)?;
        let position_rows = sqlx::query_as::<
            _,
            (String, String, i64, String, String),
        >(
            "SELECT ticker, side, quantity, avg_entry_price, \
             entry_time FROM positions",
        )
        .fetch_all(&self.pool)
        .await?;
        let mut positions = HashMap::new();
        for (ticker, side_str, qty, price_str, time_str) in position_rows {
            // anything other than "Yes" is treated as the No side
            let side = match side_str.as_str() {
                "Yes" => Side::Yes,
                _ => Side::No,
            };
            let price = Decimal::from_str(&price_str)?;
            let entry_time: DateTime<Utc> =
                time_str.parse::<DateTime<Utc>>()?;
            positions.insert(
                ticker.clone(),
                Position {
                    ticker,
                    side,
                    quantity: qty as u64,
                    avg_entry_price: price,
                    entry_time,
                },
            );
        }
        Ok(Some(Portfolio {
            positions,
            cash,
            initial_capital,
        }))
    }

    /// Persists the full portfolio snapshot atomically.
    ///
    /// Everything runs inside a single transaction so a crash between the
    /// position DELETE and the re-INSERTs cannot leave the cash row saved
    /// while the positions table is wiped.
    pub async fn save_portfolio(
        &self,
        portfolio: &Portfolio,
    ) -> anyhow::Result<()> {
        let mut tx = self.pool.begin().await?;
        sqlx::query(
            "INSERT INTO portfolio_state (id, cash, initial_capital, \
             updated_at) VALUES (1, ?1, ?2, datetime('now')) \
             ON CONFLICT(id) DO UPDATE SET \
             cash = ?1, initial_capital = ?2, \
             updated_at = datetime('now')",
        )
        .bind(portfolio.cash.to_string())
        .bind(portfolio.initial_capital.to_string())
        .execute(&mut *tx)
        .await?;
        // the positions table mirrors the in-memory map: wipe and rewrite
        sqlx::query("DELETE FROM positions")
            .execute(&mut *tx)
            .await?;
        for pos in portfolio.positions.values() {
            let side = match pos.side {
                Side::Yes => "Yes",
                Side::No => "No",
            };
            sqlx::query(
                "INSERT INTO positions \
                 (ticker, side, quantity, avg_entry_price, \
                 entry_time) VALUES (?1, ?2, ?3, ?4, ?5)",
            )
            .bind(&pos.ticker)
            .bind(side)
            .bind(pos.quantity as i64)
            .bind(pos.avg_entry_price.to_string())
            .bind(pos.entry_time.to_rfc3339())
            .execute(&mut *tx)
            .await?;
        }
        tx.commit().await?;
        Ok(())
    }

    /// Appends one fill row. `pnl` and `exit_reason` are provided for
    /// closing fills and NULL for entries.
    pub async fn record_fill(
        &self,
        fill: &Fill,
        pnl: Option<Decimal>,
        exit_reason: Option<&str>,
    ) -> anyhow::Result<()> {
        let side = match fill.side {
            Side::Yes => "Yes",
            Side::No => "No",
        };
        sqlx::query(
            "INSERT INTO fills \
             (ticker, side, quantity, price, timestamp, pnl, \
             exit_reason) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
        )
        .bind(&fill.ticker)
        .bind(side)
        .bind(fill.quantity as i64)
        .bind(fill.price.to_string())
        .bind(fill.timestamp.to_rfc3339())
        .bind(pnl.map(|p| p.to_string()))
        .bind(exit_reason)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Appends one equity-curve snapshot row.
    pub async fn snapshot_equity(
        &self,
        timestamp: DateTime<Utc>,
        equity: Decimal,
        cash: Decimal,
        positions_value: Decimal,
        drawdown_pct: f64,
    ) -> anyhow::Result<()> {
        sqlx::query(
            "INSERT INTO equity_snapshots \
             (timestamp, equity, cash, positions_value, \
             drawdown_pct) VALUES (?1, ?2, ?3, ?4, ?5)",
        )
        .bind(timestamp.to_rfc3339())
        .bind(equity.to_string())
        .bind(cash.to_string())
        .bind(positions_value.to_string())
        .bind(drawdown_pct)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Records a triggered circuit-breaker rule and the action taken.
    pub async fn record_circuit_breaker_event(
        &self,
        rule: &str,
        details: &str,
        action: &str,
    ) -> anyhow::Result<()> {
        sqlx::query(
            "INSERT INTO circuit_breaker_events \
             (timestamp, rule, details, action) \
             VALUES (datetime('now'), ?1, ?2, ?3)",
        )
        .bind(rule)
        .bind(details)
        .bind(action)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Records per-cycle pipeline telemetry (stage counts and duration).
    // a telemetry row naturally has many scalar columns; a params struct
    // would add ceremony for a single internal call site
    #[allow(clippy::too_many_arguments)]
    pub async fn record_pipeline_run(
        &self,
        timestamp: DateTime<Utc>,
        duration_ms: u64,
        candidates_fetched: usize,
        candidates_filtered: usize,
        candidates_selected: usize,
        signals_generated: usize,
        fills_executed: usize,
        errors: Option<&str>,
    ) -> anyhow::Result<()> {
        sqlx::query(
            "INSERT INTO pipeline_runs \
             (timestamp, duration_ms, candidates_fetched, \
             candidates_filtered, candidates_selected, \
             signals_generated, fills_executed, errors) \
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
        )
        .bind(timestamp.to_rfc3339())
        .bind(duration_ms as i64)
        .bind(candidates_fetched as i64)
        .bind(candidates_filtered as i64)
        .bind(candidates_selected as i64)
        .bind(signals_generated as i64)
        .bind(fills_executed as i64)
        .bind(errors)
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Returns every equity snapshot in chronological order.
    pub async fn get_equity_curve(
        &self,
    ) -> anyhow::Result<Vec<EquitySnapshot>> {
        let rows = sqlx::query_as::<_, (String, String, String, String, f64)>(
            "SELECT timestamp, equity, cash, positions_value, \
             drawdown_pct FROM equity_snapshots \
             ORDER BY timestamp ASC",
        )
        .fetch_all(&self.pool)
        .await?;
        let mut snapshots = Vec::with_capacity(rows.len());
        for (ts, eq, cash, pv, dd) in rows {
            snapshots.push(EquitySnapshot {
                timestamp: ts.parse::<DateTime<Utc>>()?,
                equity: Decimal::from_str(&eq)?,
                cash: Decimal::from_str(&cash)?,
                positions_value: Decimal::from_str(&pv)?,
                drawdown_pct: dd,
            });
        }
        Ok(snapshots)
    }

    /// Returns the most recent `limit` fills, newest first.
    pub async fn get_recent_fills(
        &self,
        limit: u32,
    ) -> anyhow::Result<Vec<FillRecord>> {
        let rows = sqlx::query_as::<
            _,
            (String, String, i64, String, String, Option<String>, Option<String>),
        >(
            "SELECT ticker, side, quantity, price, timestamp, \
             pnl, exit_reason FROM fills \
             ORDER BY id DESC LIMIT ?1",
        )
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await?;
        let mut fills = Vec::with_capacity(rows.len());
        for (ticker, side, qty, price, ts, pnl, reason) in rows {
            fills.push(FillRecord {
                ticker,
                side: match side.as_str() {
                    "Yes" => Side::Yes,
                    _ => Side::No,
                },
                quantity: qty as u64,
                price: Decimal::from_str(&price)?,
                timestamp: ts.parse::<DateTime<Utc>>()?,
                pnl: pnl
                    .map(|p| Decimal::from_str(&p))
                    .transpose()?,
                exit_reason: reason,
            });
        }
        Ok(fills)
    }

    /// Counts fills recorded at or after `since`.
    ///
    /// NOTE(review): this compares RFC 3339 strings lexicographically;
    /// that is correct only while every stored timestamp uses the same
    /// offset format (which `to_rfc3339` on `Utc` values does) — confirm
    /// if timestamps from other sources are ever inserted.
    pub async fn get_fills_since(
        &self,
        since: DateTime<Utc>,
    ) -> anyhow::Result<u32> {
        let row = sqlx::query_as::<_, (i64,)>(
            "SELECT COUNT(*) FROM fills \
             WHERE timestamp >= ?1",
        )
        .bind(since.to_rfc3339())
        .fetch_one(&self.pool)
        .await?;
        Ok(row.0 as u32)
    }

    /// Returns the most recent `limit` circuit-breaker events, newest first.
    pub async fn get_circuit_breaker_events(
        &self,
        limit: u32,
    ) -> anyhow::Result<Vec<CbEvent>> {
        let rows = sqlx::query_as::<_, (String, String, String, String)>(
            "SELECT timestamp, rule, details, action \
             FROM circuit_breaker_events \
             ORDER BY id DESC LIMIT ?1",
        )
        .bind(limit as i64)
        .fetch_all(&self.pool)
        .await?;
        let mut events = Vec::with_capacity(rows.len());
        for (ts, rule, details, action) in rows {
            events.push(CbEvent {
                timestamp: ts.parse::<DateTime<Utc>>()?,
                rule,
                details,
                action,
            });
        }
        Ok(events)
    }

    /// Returns the highest equity value ever snapshotted, or `None` when
    /// no snapshots exist.
    ///
    /// Equity is stored as TEXT, so a plain `MAX(equity)` would compare
    /// lexicographically (e.g. "9000" > "10000"). Instead, order
    /// numerically via a REAL cast and return the original string so the
    /// decimal value round-trips without precision loss.
    pub async fn get_peak_equity(
        &self,
    ) -> anyhow::Result<Option<Decimal>> {
        let row = sqlx::query_as::<_, (String,)>(
            "SELECT equity FROM equity_snapshots \
             ORDER BY CAST(equity AS REAL) DESC LIMIT 1",
        )
        .fetch_optional(&self.pool)
        .await?;
        match row {
            Some((s,)) => Ok(Some(Decimal::from_str(&s)?)),
            None => Ok(None),
        }
    }
}
/// One row of the `equity_snapshots` table, decoded into native types.
#[derive(Debug, Clone)]
pub struct EquitySnapshot {
    // Time the snapshot was recorded (stored as TEXT, parsed RFC 3339).
    pub timestamp: DateTime<Utc>,
    // Account equity at snapshot time.
    pub equity: Decimal,
    // Cash component of equity.
    pub cash: Decimal,
    // Value attributed to open positions.
    pub positions_value: Decimal,
    // Drawdown at snapshot time; stored as REAL in the schema.
    pub drawdown_pct: f64,
}
/// One row of the `fills` table, decoded into native types.
#[derive(Debug, Clone)]
pub struct FillRecord {
    pub ticker: String,
    // Decoded from TEXT; any value other than "Yes" maps to Side::No.
    pub side: Side,
    pub quantity: u64,
    pub price: Decimal,
    pub timestamp: DateTime<Utc>,
    // Nullable in the schema — presumably only set on closing fills;
    // confirm against the writer.
    pub pnl: Option<Decimal>,
    // Nullable free-text reason accompanying an exit fill.
    pub exit_reason: Option<String>,
}
/// One row of the `circuit_breaker_events` table.
#[derive(Debug, Clone)]
pub struct CbEvent {
    pub timestamp: DateTime<Utc>,
    // Name of the rule that fired.
    pub rule: String,
    // Human-readable detail text recorded with the event.
    pub details: String,
    // Action taken in response (as recorded by the writer).
    pub action: String,
}

56
src/store/schema.rs Normal file
View File

@ -0,0 +1,56 @@
/// Idempotent schema bootstrap for the SQLite store.
///
/// Every statement uses `CREATE TABLE IF NOT EXISTS`, so this script is
/// safe to execute on every startup against both fresh and existing
/// databases. Monetary values are stored as TEXT (decoded via
/// `Decimal::from_str` on read) to avoid floating-point rounding;
/// timestamps are TEXT as written by `to_rfc3339`.
pub const MIGRATIONS: &str = r#"
CREATE TABLE IF NOT EXISTS portfolio_state (
    id INTEGER PRIMARY KEY CHECK (id = 1),
    cash TEXT NOT NULL,
    initial_capital TEXT NOT NULL,
    updated_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE TABLE IF NOT EXISTS positions (
    ticker TEXT PRIMARY KEY,
    side TEXT NOT NULL,
    quantity INTEGER NOT NULL,
    avg_entry_price TEXT NOT NULL,
    entry_time TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS fills (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    ticker TEXT NOT NULL,
    side TEXT NOT NULL,
    quantity INTEGER NOT NULL,
    price TEXT NOT NULL,
    timestamp TEXT NOT NULL,
    pnl TEXT,
    exit_reason TEXT
);
CREATE TABLE IF NOT EXISTS equity_snapshots (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    equity TEXT NOT NULL,
    cash TEXT NOT NULL,
    positions_value TEXT NOT NULL,
    drawdown_pct REAL NOT NULL DEFAULT 0.0
);
CREATE TABLE IF NOT EXISTS circuit_breaker_events (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    rule TEXT NOT NULL,
    details TEXT NOT NULL,
    action TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS pipeline_runs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TEXT NOT NULL,
    duration_ms INTEGER NOT NULL,
    candidates_fetched INTEGER NOT NULL DEFAULT 0,
    candidates_filtered INTEGER NOT NULL DEFAULT 0,
    candidates_selected INTEGER NOT NULL DEFAULT 0,
    signals_generated INTEGER NOT NULL DEFAULT 0,
    fills_executed INTEGER NOT NULL DEFAULT 0,
    errors TEXT
);
"#;

514
src/web/handlers.rs Normal file
View File

@ -0,0 +1,514 @@
use super::{AppState, BacktestRunStatus};
use crate::backtest::Backtester;
use crate::data::HistoricalData;
use crate::execution::PositionSizingConfig;
use crate::types::{BacktestConfig, ExitConfig};
use axum::extract::State;
use axum::http::StatusCode;
use axum::Json;
use chrono::Utc;
use rust_decimal::Decimal;
use rust_decimal::prelude::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tracing::{error, info};
/// JSON payload for GET /api/status.
#[derive(Serialize)]
pub struct StatusResponse {
    // Engine state rendered through its Display impl.
    pub state: String,
    pub uptime_secs: u64,
    // RFC 3339 timestamp of the last completed tick, if any.
    pub last_tick: Option<String>,
    pub ticks_completed: u64,
}
/// JSON payload for GET /api/portfolio.
#[derive(Serialize)]
pub struct PortfolioResponse {
    pub cash: f64,
    // cash + positions valued at entry price (see get_portfolio).
    pub equity: f64,
    pub initial_capital: f64,
    // Percent return relative to initial capital.
    pub return_pct: f64,
    // Percent drawdown from the historical peak equity.
    pub drawdown_pct: f64,
    pub positions_count: usize,
}
/// JSON payload element for GET /api/positions.
#[derive(Serialize)]
pub struct PositionResponse {
    pub ticker: String,
    // Debug-rendered side (e.g. the enum variant name).
    pub side: String,
    pub quantity: u64,
    pub entry_price: f64,
    // RFC 3339 entry timestamp.
    pub entry_time: String,
    // Currently always 0.0 — not computed by the handler.
    pub unrealized_pnl: f64,
}
/// JSON payload element for GET /api/trades.
#[derive(Serialize)]
pub struct TradeResponse {
    pub ticker: String,
    // Debug-rendered side.
    pub side: String,
    pub quantity: u64,
    pub price: f64,
    // RFC 3339 fill timestamp.
    pub timestamp: String,
    // Realized PnL when recorded for the fill.
    pub pnl: Option<f64>,
    pub exit_reason: Option<String>,
}
/// JSON payload element for GET /api/equity (one equity-curve sample).
#[derive(Serialize)]
pub struct EquityPoint {
    // RFC 3339 snapshot timestamp.
    pub timestamp: String,
    pub equity: f64,
    pub cash: f64,
    pub positions_value: f64,
    pub drawdown_pct: f64,
}
/// JSON payload for GET /api/circuit-breaker.
#[derive(Serialize)]
pub struct CbResponse {
    // "ok", or "tripped: <reason>" when the engine is paused.
    pub status: String,
    pub events: Vec<CbEventResponse>,
}
/// JSON representation of a single circuit-breaker event.
#[derive(Serialize)]
pub struct CbEventResponse {
    // RFC 3339 event timestamp.
    pub timestamp: String,
    pub rule: String,
    pub details: String,
    pub action: String,
}
/// GET /api/status — engine lifecycle info for the dashboard.
///
/// Pulls the live status from the engine and converts timestamps to
/// RFC 3339 strings for JSON transport.
pub async fn get_status(
    State(state): State<Arc<AppState>>,
) -> Json<StatusResponse> {
    let status = state.engine.get_status().await;
    Json(StatusResponse {
        // `to_string()` goes through the same Display impl as
        // `format!("{}", ..)` but is the idiomatic spelling.
        state: status.state.to_string(),
        uptime_secs: status.uptime_secs,
        last_tick: status
            .last_tick
            .map(|t| t.to_rfc3339()),
        ticks_completed: status.ticks_completed,
    })
}
/// GET /api/portfolio — headline account numbers.
///
/// Positions are valued at their average entry price (cost basis), so
/// `equity` reported here does not include unrealized market moves.
pub async fn get_portfolio(
    State(state): State<Arc<AppState>>,
) -> Json<PortfolioResponse> {
    let ctx = state.engine.get_context().await;
    let portfolio = &ctx.portfolio;

    // Sum of quantity * entry price across all open positions.
    let mut positions_value = 0.0_f64;
    for position in portfolio.positions.values() {
        let entry = position.avg_entry_price.to_f64().unwrap_or(0.0);
        positions_value += entry * position.quantity as f64;
    }

    let cash = portfolio.cash.to_f64().unwrap_or(0.0);
    let equity = cash + positions_value;
    let initial = portfolio
        .initial_capital
        .to_f64()
        .unwrap_or(10000.0);

    let return_pct = if initial > 0.0 {
        (equity - initial) / initial * 100.0
    } else {
        0.0
    };

    // Drawdown is measured against the persisted historical peak; when
    // no peak is available we fall back to current equity, i.e. zero
    // drawdown.
    let peak = state
        .store
        .get_peak_equity()
        .await
        .ok()
        .flatten()
        .and_then(|p| p.to_f64())
        .unwrap_or(equity);
    let drawdown_pct = if peak > 0.0 {
        ((peak - equity) / peak * 100.0).max(0.0)
    } else {
        0.0
    };

    Json(PortfolioResponse {
        cash,
        equity,
        initial_capital: initial,
        return_pct,
        drawdown_pct,
        positions_count: portfolio.positions.len(),
    })
}
/// GET /api/positions — all open positions.
pub async fn get_positions(
    State(state): State<Arc<AppState>>,
) -> Json<Vec<PositionResponse>> {
    let ctx = state.engine.get_context().await;
    let mut out = Vec::with_capacity(ctx.portfolio.positions.len());
    for position in ctx.portfolio.positions.values() {
        out.push(PositionResponse {
            ticker: position.ticker.clone(),
            side: format!("{:?}", position.side),
            quantity: position.quantity,
            entry_price: position.avg_entry_price.to_f64().unwrap_or(0.0),
            entry_time: position.entry_time.to_rfc3339(),
            // Not computed in this handler; always reported as 0.0.
            unrealized_pnl: 0.0,
        });
    }
    Json(out)
}
/// GET /api/trades — the 100 most recent fills, newest first.
pub async fn get_trades(
    State(state): State<Arc<AppState>>,
) -> Json<Vec<TradeResponse>> {
    // Persistence errors degrade to an empty list rather than a 500.
    let fills = state
        .store
        .get_recent_fills(100)
        .await
        .unwrap_or_default();
    let mut trades = Vec::with_capacity(fills.len());
    for fill in fills {
        trades.push(TradeResponse {
            ticker: fill.ticker,
            side: format!("{:?}", fill.side),
            quantity: fill.quantity,
            price: fill.price.to_f64().unwrap_or(0.0),
            timestamp: fill.timestamp.to_rfc3339(),
            pnl: fill.pnl.and_then(|p| p.to_f64()),
            exit_reason: fill.exit_reason,
        });
    }
    Json(trades)
}
/// GET /api/equity — full equity curve for charting.
pub async fn get_equity(
    State(state): State<Arc<AppState>>,
) -> Json<Vec<EquityPoint>> {
    // Persistence errors degrade to an empty curve rather than a 500.
    let snapshots = state
        .store
        .get_equity_curve()
        .await
        .unwrap_or_default();
    let mut points = Vec::with_capacity(snapshots.len());
    for snapshot in snapshots {
        points.push(EquityPoint {
            timestamp: snapshot.timestamp.to_rfc3339(),
            equity: snapshot.equity.to_f64().unwrap_or(0.0),
            cash: snapshot.cash.to_f64().unwrap_or(0.0),
            positions_value: snapshot.positions_value.to_f64().unwrap_or(0.0),
            drawdown_pct: snapshot.drawdown_pct,
        });
    }
    Json(points)
}
/// GET /api/circuit-breaker — breaker status plus recent trip events.
///
/// The status string is derived from the engine state: a paused engine
/// is reported as tripped with its pause reason.
pub async fn get_circuit_breaker(
    State(state): State<Arc<AppState>>,
) -> Json<CbResponse> {
    let engine_status = state.engine.get_status().await;
    let cb_status = if let crate::engine::EngineState::Paused(ref reason) =
        engine_status.state
    {
        format!("tripped: {}", reason)
    } else {
        "ok".to_string()
    };
    // Persistence errors degrade to an empty event list.
    let events = state
        .store
        .get_circuit_breaker_events(20)
        .await
        .unwrap_or_default();
    let mut event_responses = Vec::with_capacity(events.len());
    for event in events {
        event_responses.push(CbEventResponse {
            timestamp: event.timestamp.to_rfc3339(),
            rule: event.rule,
            details: event.details,
            action: event.action,
        });
    }
    Json(CbResponse {
        status: cb_status,
        events: event_responses,
    })
}
/// POST /api/control/pause — manually halt the trading loop.
pub async fn post_pause(
    State(state): State<Arc<AppState>>,
) -> StatusCode {
    let reason = String::from("manual pause via API");
    state.engine.pause(reason).await;
    StatusCode::OK
}
/// POST /api/control/resume — resume the trading loop after a pause.
pub async fn post_resume(
    State(state): State<Arc<AppState>>,
) -> StatusCode {
    state.engine.resume().await;
    StatusCode::OK
}
/// Request body for POST /api/backtest/run.
///
/// Only `start` and `end` are required; every other field falls back to
/// the defaults applied in `post_backtest_run`.
#[derive(Deserialize)]
pub struct BacktestRequest {
    // Start/end dates, parsed via crate::parse_date.
    pub start: String,
    pub end: String,
    // Overrides the server's default historical-data directory.
    pub data_dir: Option<String>,
    // Default 10000.0.
    pub capital: Option<f64>,
    // Default 100.
    pub max_positions: Option<usize>,
    // Default 100 (per-position contract cap).
    pub max_position: Option<u64>,
    // Simulation step, default 1 hour.
    pub interval_hours: Option<i64>,
    // Default 0.40.
    pub kelly_fraction: Option<f64>,
    // Default 0.30.
    pub max_position_pct: Option<f64>,
    // Default 0.50.
    pub take_profit: Option<f64>,
    // Default 0.99.
    pub stop_loss: Option<f64>,
    // Default 48.
    pub max_hold_hours: Option<i64>,
}
/// JSON payload for GET /api/backtest/status.
#[derive(Serialize)]
pub struct BacktestStatusResponse {
    // One of: "idle", "running", "complete", "failed".
    pub status: String,
    // Seconds since the run started; only set while running.
    pub elapsed_secs: Option<u64>,
    // Failure message; only set when status is "failed".
    pub error: Option<String>,
    // Progress fields; only set when a progress tracker exists.
    pub phase: Option<String>,
    pub current_step: Option<u64>,
    pub total_steps: Option<u64>,
    pub progress_pct: Option<f64>,
}
/// JSON error body returned by backtest endpoints on 4xx responses.
#[derive(Serialize)]
pub struct BacktestErrorResponse {
    pub error: String,
}
/// POST /api/backtest/run — validate the request and spawn a backtest
/// in a background task.
///
/// Returns 409 when a run is already in progress, 400 for an invalid
/// date or missing data directory, and 200 as soon as the task has been
/// spawned; clients poll /api/backtest/status for progress and fetch
/// /api/backtest/result when complete.
pub async fn post_backtest_run(
    State(state): State<Arc<AppState>>,
    Json(req): Json<BacktestRequest>,
) -> Result<StatusCode, (StatusCode, Json<BacktestErrorResponse>)> {
    // Reject concurrent runs. The guard is scoped so the mutex is not
    // held across the validation and spawn below.
    {
        let guard = state.backtest.lock().await;
        if matches!(guard.status, BacktestRunStatus::Running { .. }) {
            return Err((
                StatusCode::CONFLICT,
                Json(BacktestErrorResponse {
                    error: "backtest already running".into(),
                }),
            ));
        }
    }
    // Validate the date range up front so failures never leave the
    // status stuck in Running.
    let start_time = crate::parse_date(&req.start).map_err(|e| {
        (
            StatusCode::BAD_REQUEST,
            Json(BacktestErrorResponse {
                error: format!("invalid start date: {}", e),
            }),
        )
    })?;
    let end_time = crate::parse_date(&req.end).map_err(|e| {
        (
            StatusCode::BAD_REQUEST,
            Json(BacktestErrorResponse {
                error: format!("invalid end date: {}", e),
            }),
        )
    })?;
    // Caller-supplied data directory overrides the server default.
    let data_dir = if let Some(ref dir) = req.data_dir {
        PathBuf::from(dir)
    } else {
        state.data_dir.clone()
    };
    if !data_dir.exists() {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(BacktestErrorResponse {
                error: format!(
                    "data directory not found: {}",
                    data_dir.display()
                ),
            }),
        ));
    }
    info!(
        start = %start_time,
        end = %end_time,
        data_dir = %data_dir.display(),
        "starting backtest from web UI"
    );
    // total_steps starts at 0; presumably the backtester sets it once
    // the step count is known — confirm in Backtester::with_progress.
    let progress = Arc::new(super::BacktestProgress::new(0));
    // Flip shared state to Running and clear any previous result/error
    // before spawning, so a status poll never sees stale data.
    {
        let mut guard = state.backtest.lock().await;
        guard.status =
            BacktestRunStatus::Running { started_at: Utc::now() };
        guard.progress = Some(progress.clone());
        guard.result = None;
        guard.error = None;
    }
    let backtest_state = state.backtest.clone();
    let progress = progress.clone();
    // Apply request defaults. NOTE(review): kelly 0.40 and
    // max_position_pct 0.30 differ from config.toml's trading defaults
    // (0.25 / 0.10) — confirm the divergence is intentional.
    let capital = req.capital.unwrap_or(10000.0);
    let max_positions = req.max_positions.unwrap_or(100);
    let max_position = req.max_position.unwrap_or(100);
    let interval_hours = req.interval_hours.unwrap_or(1);
    let kelly_fraction = req.kelly_fraction.unwrap_or(0.40);
    let max_position_pct = req.max_position_pct.unwrap_or(0.30);
    let take_profit = req.take_profit.unwrap_or(0.50);
    let stop_loss = req.stop_loss.unwrap_or(0.99);
    let max_hold_hours = req.max_hold_hours.unwrap_or(48);
    tokio::spawn(async move {
        // Data loading is file I/O, so it runs on the blocking pool.
        let data = match tokio::task::spawn_blocking(move || {
            HistoricalData::load(&data_dir)
        })
        .await
        {
            Ok(Ok(d)) => Arc::new(d),
            // Load failed: record the error and mark the run Failed.
            Ok(Err(e)) => {
                let mut guard = backtest_state.lock().await;
                guard.status = BacktestRunStatus::Failed;
                guard.error =
                    Some(format!("failed to load data: {}", e));
                error!(error = %e, "backtest data load failed");
                return;
            }
            // The blocking task itself panicked or was cancelled.
            Err(e) => {
                let mut guard = backtest_state.lock().await;
                guard.status = BacktestRunStatus::Failed;
                guard.error =
                    Some(format!("task join error: {}", e));
                return;
            }
        };
        let config = BacktestConfig {
            start_time,
            end_time,
            interval: chrono::Duration::hours(interval_hours),
            // Falls back silently to 10000 if the f64 cannot be
            // represented as a Decimal.
            initial_capital: Decimal::try_from(capital)
                .unwrap_or(Decimal::new(10000, 0)),
            max_position_size: max_position,
            max_positions,
        };
        let sizing_config = PositionSizingConfig {
            kelly_fraction,
            max_position_pct,
            min_position_size: 10,
            max_position_size: max_position,
        };
        let exit_config = ExitConfig {
            take_profit_pct: take_profit,
            stop_loss_pct: stop_loss,
            max_hold_hours,
            score_reversal_threshold: -0.3,
        };
        let backtester = Backtester::with_configs(
            config,
            data,
            sizing_config,
            exit_config,
        )
        .with_progress(progress);
        let result = backtester.run().await;
        // Publish the finished result for /api/backtest/result.
        let mut guard = backtest_state.lock().await;
        guard.status = BacktestRunStatus::Complete;
        guard.result = Some(result);
    });
    Ok(StatusCode::OK)
}
/// GET /api/backtest/status — poll endpoint for run state and progress.
pub async fn get_backtest_status(
    State(state): State<Arc<AppState>>,
) -> Json<BacktestStatusResponse> {
    let guard = state.backtest.lock().await;

    // Derive the public status string plus its optional companions.
    let mut status_str = "idle".to_string();
    let mut elapsed = None;
    let mut error = None;
    match &guard.status {
        BacktestRunStatus::Idle => {}
        BacktestRunStatus::Running { started_at } => {
            status_str = "running".to_string();
            // Clamp at zero in case of clock skew.
            elapsed = Some(
                Utc::now()
                    .signed_duration_since(*started_at)
                    .num_seconds()
                    .max(0) as u64,
            );
        }
        BacktestRunStatus::Complete => {
            status_str = "complete".to_string();
        }
        BacktestRunStatus::Failed => {
            status_str = "failed".to_string();
            error = guard.error.clone();
        }
    }

    // Progress fields are present only when a tracker was attached.
    let mut phase = None;
    let mut current_step = None;
    let mut total_steps = None;
    let mut progress_pct = None;
    if let Some(p) = guard.progress.as_ref() {
        let current = p.current_step.load(
            std::sync::atomic::Ordering::Relaxed,
        );
        let total = p.total_steps.load(
            std::sync::atomic::Ordering::Relaxed,
        );
        phase = Some(p.phase_name().to_string());
        current_step = Some(current);
        total_steps = Some(total);
        // Avoid division by zero before the step count is known.
        progress_pct = Some(if total > 0 {
            current as f64 / total as f64 * 100.0
        } else {
            0.0
        });
    }

    Json(BacktestStatusResponse {
        status: status_str,
        elapsed_secs: elapsed,
        error,
        phase,
        current_step,
        total_steps,
        progress_pct,
    })
}
pub async fn get_backtest_result(
State(state): State<Arc<AppState>>,
) -> Result<
Json<crate::metrics::BacktestResult>,
StatusCode,
> {
let guard = state.backtest.lock().await;
match &guard.result {
Some(result) => Ok(Json(result.clone())),
None => Err(StatusCode::NOT_FOUND),
}
}

103
src/web/mod.rs Normal file
View File

@ -0,0 +1,103 @@
mod handlers;
use crate::engine::PaperTradingEngine;
use crate::metrics::BacktestResult;
use crate::store::SqliteStore;
use axum::routing::{get, post};
use axum::Router;
use chrono::{DateTime, Utc};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::broadcast;
use tower_http::services::ServeDir;
/// Lifecycle of the single web-triggered backtest slot.
pub enum BacktestRunStatus {
    // No run has been started (or the last result was consumed).
    Idle,
    // A run is in flight; `started_at` drives the elapsed-time field.
    Running { started_at: DateTime<Utc> },
    // Finished successfully; the result is available.
    Complete,
    // Aborted; see BacktestState::error for the message.
    Failed,
}
/// Lock-free progress counters shared between a running backtest task
/// and the status endpoint that polls it.
pub struct BacktestProgress {
    // Current phase; one of the PHASE_* constants below.
    pub phase: std::sync::atomic::AtomicU8,
    pub current_step: std::sync::atomic::AtomicU64,
    pub total_steps: std::sync::atomic::AtomicU64,
}
impl BacktestProgress {
    pub const PHASE_LOADING: u8 = 0;
    pub const PHASE_RUNNING: u8 = 1;

    /// Build a tracker in the loading phase with zero completed steps.
    pub fn new(total_steps: u64) -> Self {
        let phase =
            std::sync::atomic::AtomicU8::new(Self::PHASE_LOADING);
        let current_step = std::sync::atomic::AtomicU64::new(0);
        let total_steps =
            std::sync::atomic::AtomicU64::new(total_steps);
        Self { phase, current_step, total_steps }
    }

    /// Human-readable label for the current phase.
    pub fn phase_name(&self) -> &'static str {
        let phase = self
            .phase
            .load(std::sync::atomic::Ordering::Relaxed);
        if phase == Self::PHASE_LOADING {
            "loading data"
        } else if phase == Self::PHASE_RUNNING {
            "simulating"
        } else {
            "unknown"
        }
    }
}
/// Mutable state of the single backtest slot, shared behind a mutex in
/// `AppState`.
pub struct BacktestState {
    pub status: BacktestRunStatus,
    // Atomic counters for live progress reporting, when attached.
    pub progress: Option<Arc<BacktestProgress>>,
    // Set when a run completes successfully.
    pub result: Option<BacktestResult>,
    // Set when a run fails.
    pub error: Option<String>,
}
/// Shared application state handed to every axum handler.
pub struct AppState {
    pub engine: Arc<PaperTradingEngine>,
    pub store: Arc<SqliteStore>,
    // Broadcast channel used to signal shutdown.
    pub shutdown_tx: broadcast::Sender<()>,
    // The single backtest slot guarded by an async mutex.
    pub backtest: Arc<tokio::sync::Mutex<BacktestState>>,
    // Default historical-data directory for backtests.
    pub data_dir: PathBuf,
}
/// Assemble the HTTP router.
///
/// JSON endpoints live under `/api/...`; any other path falls through
/// to static files served from the `static/` directory.
pub fn build_router(state: Arc<AppState>) -> Router {
    Router::new()
        // Read-only dashboard endpoints.
        .route("/api/status", get(handlers::get_status))
        .route("/api/portfolio", get(handlers::get_portfolio))
        .route("/api/positions", get(handlers::get_positions))
        .route("/api/trades", get(handlers::get_trades))
        .route("/api/equity", get(handlers::get_equity))
        .route("/api/circuit-breaker", get(handlers::get_circuit_breaker))
        // Engine control.
        .route("/api/control/pause", post(handlers::post_pause))
        .route("/api/control/resume", post(handlers::post_resume))
        // Backtest lifecycle.
        .route("/api/backtest/run", post(handlers::post_backtest_run))
        .route("/api/backtest/status", get(handlers::get_backtest_status))
        .route("/api/backtest/result", get(handlers::get_backtest_result))
        .fallback_service(ServeDir::new("static"))
        .with_state(state)
}

1124
static/index.html Normal file

File diff suppressed because it is too large Load Diff