API Reference
Models
Base Model
Base class for election forecasting models.
This version is generalized so that the election year and election date are not hard-coded to 2016. The date is taken from src.utils.data_utils.get_current_election_date(), which in turn is controlled by set_election_config(…).
- class src.models.base_model.ElectionForecastModel(name: str, seed: int | None = None)
Bases: ABC
Abstract base class for election forecasting models.
- __init__(name: str, seed: int | None = None) → None
Initialize the model.
- Parameters:
name – Model name.
seed – Random seed for reproducibility (default: None for non-deterministic).
- abstractmethod fit_and_forecast(state_polls: DataFrame, forecast_date: Timestamp, election_date: Timestamp, actual_margin: float, rng: Generator | None = None) → Dict[str, float]
Fit model on polls up to forecast_date and predict election outcome.
- Must return a dict with keys:
“win_probability”
“predicted_margin”
optionally “margin_std”
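A minimal sketch of a concrete subclass that honours this contract. To keep it runnable without pandas, it uses a hypothetical stand-in base class (ForecastModelSketch) and represents polls as (date, margin) tuples instead of a DataFrame; the toy model (MeanOfPollsModel) and its normal-approximation win probability are illustrative only, and it assumes a positive margin means the candidate of interest wins.

```python
import math
from abc import ABC, abstractmethod
from datetime import date
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-in for a poll row: (release date, margin).
Poll = Tuple[date, float]


class ForecastModelSketch(ABC):
    """Hypothetical stand-in mirroring the documented fit_and_forecast signature."""

    @abstractmethod
    def fit_and_forecast(self, state_polls: List[Poll], forecast_date: date,
                         election_date: date, actual_margin: float,
                         rng: Optional[object] = None) -> Dict[str, float]:
        """Return win_probability, predicted_margin and optionally margin_std."""


def normal_cdf(x: float) -> float:
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


class MeanOfPollsModel(ForecastModelSketch):
    """Toy model: unweighted mean of polls released on or before forecast_date."""

    def fit_and_forecast(self, state_polls, forecast_date, election_date,
                         actual_margin, rng=None):
        # actual_margin is deliberately ignored: using it would leak the outcome.
        margins = [m for d, m in state_polls if d <= forecast_date]
        if not margins:
            raise ValueError("no polls available before forecast_date")
        mean = sum(margins) / len(margins)
        var = sum((m - mean) ** 2 for m in margins) / max(len(margins) - 1, 1)
        std = max(math.sqrt(var), 1e-3)  # floor avoids division by zero
        return {
            "win_probability": normal_cdf(mean / std),  # P(margin > 0) under N(mean, std)
            "predicted_margin": mean,
            "margin_std": std,
        }
```

Polls dated after forecast_date are filtered out, which is what "fit model on polls up to forecast_date" requires.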
- load_data() → Tuple[DataFrame, Dict[str, float]]
Load polling and election results data.
- Returns:
tuple of (polls DataFrame, actual_margin dict)
- run_forecast(forecast_dates: List[Timestamp] | None = None, min_polls: int = 10, verbose: bool = False, n_workers: int | None = None) → DataFrame
Run forecast across multiple dates and states.
- Parameters:
forecast_dates – List of forecast dates. If None, use four default dates in October/November of the election year.
min_polls – Minimum number of polls required to forecast a state.
verbose – If True, log per-state progress.
n_workers – If None or <=1, run sequentially; otherwise use ProcessPoolExecutor with the given number of workers.
- Returns:
DataFrame of predictions with columns state, forecast_date, win_probability, predicted_margin, margin_std, actual_margin
- Return type:
DataFrame
Poll Average Model
Simple Poll-of-Polls Average Model.
Weighted average of recent polls with empirical uncertainty.
- class src.models.poll_average.PollAverageModel(seed=None)
Bases: ElectionForecastModel
Simple weighted poll average baseline.
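The weighting scheme is not documented here, so the following is only a sketch under an assumed choice: exponential recency weights with a hypothetical half_life_days parameter, and the weighted spread of the polls around the average as the empirical uncertainty.

```python
import math
from typing import List, Tuple


def weighted_poll_average(polls: List[Tuple[float, float]],
                          half_life_days: float = 14.0) -> Tuple[float, float]:
    """polls: (days_before_forecast, margin) pairs. Returns (mean, std).

    Weights halve every half_life_days (a hypothetical choice, not the model's).
    """
    weights = [0.5 ** (age / half_life_days) for age, _ in polls]
    total = sum(weights)
    mean = sum(w * m for w, (_, m) in zip(weights, polls)) / total
    # Weighted spread of the polls around the average as an empirical uncertainty.
    var = sum(w * (m - mean) ** 2 for w, (_, m) in zip(weights, polls)) / total
    return mean, math.sqrt(var)
```

For example, a poll released 14 days earlier counts half as much as one released on the forecast date.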
Kalman Diffusion Model
Kalman Filter Diffusion Model with Improved Regularization.
Brownian motion with drift + pollster biases + fundamentals prior.
- class src.models.kalman_diffusion.KalmanDiffusionModel(seed=None)
Bases: ElectionForecastModel
Improved diffusion model with Kalman filter/RTS smoother.
- __init__(seed=None)
Initialize the model.
- Parameters:
seed – Random seed for reproducibility (default: None for non-deterministic).
- kalman_filter_smoother(dates, observations, obs_variance, mu, sigma2)
Kalman filter + RTS smoother for Brownian motion with drift
- Parameters:
dates – Array of time points (in days)
observations – Array of poll margins
obs_variance – Array of observation variances
mu – Drift parameter
sigma2 – Diffusion variance
- Returns:
smoothed state estimates and variances
- Return type:
tuple of (x_smooth, P_smooth)
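A scalar sketch of a Kalman filter with a Rauch-Tung-Striebel backward pass for Brownian motion with drift, where the margin evolves as x_t = x_s + mu·(t − s) + noise with variance sigma2·(t − s) and each poll observes x_t with its own variance. It uses plain lists instead of NumPy arrays, and the initial prior (m0, P0) is a hypothetical choice; the model's own initialisation may differ.

```python
from typing import List, Optional, Sequence, Tuple


def kalman_rts(dates: Sequence[float], observations: Sequence[float],
               obs_variance: Sequence[float], mu: float, sigma2: float,
               m0: Optional[float] = None, P0: float = 1.0
               ) -> Tuple[List[float], List[float]]:
    """Return (x_smooth, P_smooth). m0/P0 are a hypothetical initial prior
    (default: mean at the first observation, variance 1.0)."""
    n = len(observations)
    x_pred, P_pred = [0.0] * n, [0.0] * n
    x_filt, P_filt = [0.0] * n, [0.0] * n
    for k in range(n):
        if k == 0:
            x_pred[k] = observations[0] if m0 is None else m0
            P_pred[k] = P0
        else:
            dt = dates[k] - dates[k - 1]
            x_pred[k] = x_filt[k - 1] + mu * dt      # drift moves the mean
            P_pred[k] = P_filt[k - 1] + sigma2 * dt  # diffusion adds variance
        # Measurement update with the poll's own observation variance.
        gain = P_pred[k] / (P_pred[k] + obs_variance[k])
        x_filt[k] = x_pred[k] + gain * (observations[k] - x_pred[k])
        P_filt[k] = (1.0 - gain) * P_pred[k]

    # Rauch-Tung-Striebel backward pass: refine each estimate with later polls.
    x_s, P_s = x_filt[:], P_filt[:]
    for k in range(n - 2, -1, -1):
        c = P_filt[k] / P_pred[k + 1]
        x_s[k] = x_filt[k] + c * (x_s[k + 1] - x_pred[k + 1])
        P_s[k] = P_filt[k] + c * c * (P_s[k + 1] - P_pred[k + 1])
    return x_s, P_s
```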
- fit_state_diffusion(state_polls, prior_mean=0.0, max_iter=10)
Fit diffusion model with EM algorithm
- Parameters:
state_polls – DataFrame of polls for a single state
prior_mean – Prior mean for fundamentals
max_iter – Maximum number of EM iterations
- Returns:
tuple of (mu, sigma2, pollster_bias, x_smooth, P_smooth, dates)
- simulate_forward(x_start, P_start, mu, sigma2, days, N=2000, rng=None)
Simulate forward with Euler-Maruyama method
- Parameters:
x_start – Initial state estimate
P_start – Initial state variance
mu – Drift parameter
sigma2 – Diffusion variance
days – Number of days to simulate forward
N – Number of simulation samples
rng – NumPy random generator (default: None uses default_rng)
- Returns:
Array of final margin values (length N), clipped to [-1, 1]
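A sketch of the Euler-Maruyama scheme described above, with a daily step (dt = 1): each path starts from a draw of N(x_start, P_start) and adds mu·dt plus Gaussian noise with variance sigma2·dt per day, and the final values are clipped to [-1, 1]. Python's random.Random stands in for the NumPy Generator the method accepts; the per-day step size is an assumption.

```python
import math
import random
from typing import List, Optional


def simulate_forward(x_start: float, P_start: float, mu: float, sigma2: float,
                     days: int, N: int = 2000,
                     rng: Optional[random.Random] = None) -> List[float]:
    """Euler-Maruyama for dX = mu dt + sqrt(sigma2) dW with dt = 1 day."""
    rng = rng or random.Random()
    dt = 1.0
    finals = []
    for _ in range(N):
        # Start from a draw of the current state posterior N(x_start, P_start).
        x = rng.gauss(x_start, math.sqrt(P_start))
        for _ in range(days):
            x += mu * dt + math.sqrt(sigma2 * dt) * rng.gauss(0.0, 1.0)
        finals.append(min(1.0, max(-1.0, x)))  # clip to the valid margin range
    return finals
```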
Improved Kalman Model
Improved Kalman Diffusion Model
Key improvements over the basic Kalman model:
- Increased minimum diffusion variance
- Better-regularized pollster biases
- Smaller forecast-horizon uncertainty
- More conservative probability clipping
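Two of these improvements can be illustrated with a short sketch: a floor on the fitted diffusion variance and clipping of the win probability away from 0 and 1. The floor and bounds below (sigma2_floor, p_min) are hypothetical illustrative values; the actual settings are not given here.

```python
from typing import Tuple


def regularize(sigma2: float, win_probability: float,
               sigma2_floor: float = 1e-5, p_min: float = 0.01) -> Tuple[float, float]:
    """Apply a diffusion-variance floor and clip the win probability.

    sigma2_floor and p_min are illustrative values, not the model's settings.
    """
    sigma2 = max(sigma2, sigma2_floor)               # never let the random walk freeze
    p = min(max(win_probability, p_min), 1.0 - p_min)  # avoid overconfident 0/1 calls
    return sigma2, p
```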
- class src.models.improved_kalman.ImprovedKalmanModel(seed=None)
Bases: ElectionForecastModel
Improved Kalman filter diffusion model.
- __init__(seed=None)
Initialize the model.
- Parameters:
seed – Random seed for reproducibility (default: None for non-deterministic).
- kalman_filter_rts(dates, observations, obs_variance, mu, sigma2)
Kalman filter + RTS smoother
- fit_and_forecast(state_polls, forecast_date, election_date, actual_margin, rng=None)
Fit improved Kalman diffusion and forecast
- simulate_forward(x_start, P_start, mu, sigma2, days, N=2000, rng=None)
Simulate forward with Euler-Maruyama
- Parameters:
x_start – Initial state estimate
P_start – Initial state variance
mu – Drift parameter
sigma2 – Diffusion variance
days – Number of days to simulate forward
N – Number of simulation samples
rng – NumPy random generator (default: None uses default_rng)
- Returns:
Array of final margin values (length N)
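The simulated margins can be reduced to the dictionary that fit_and_forecast must return. This is a sketch assuming a positive margin means the candidate of interest wins; how the model actually summarises its samples is not stated here.

```python
import statistics
from typing import Dict, Sequence


def summarize_samples(margins: Sequence[float]) -> Dict[str, float]:
    """Map simulated final margins to win_probability / predicted_margin / margin_std.

    Assumes a positive margin is a win for the candidate of interest.
    """
    wins = sum(1 for m in margins if m > 0)
    return {
        "win_probability": wins / len(margins),  # share of paths ending in a win
        "predicted_margin": statistics.fmean(margins),
        "margin_std": statistics.pstdev(margins),
    }
```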
Hierarchical Bayes Model
Hierarchical Bayesian Ensemble with Systematic Bias Adjustment (HBE-SBA)
Combines:
1. Fundamentals prior from historical results
2. Kalman-filtered polls with house effects
3. Adaptive systematic bias correction
4. Proper uncertainty quantification
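One standard way to merge a fundamentals prior (item 1) with a poll-based estimate (item 2) is the normal-normal conjugate update, i.e. a precision-weighted average. The sketch below shows only that step under the assumption that both inputs are Gaussian; it is not confirmed to be HBE-SBA's exact combination rule, and the systematic bias correction is not shown.

```python
from typing import Tuple


def combine_normal(prior_mean: float, prior_var: float,
                   poll_mean: float, poll_var: float) -> Tuple[float, float]:
    """Normal-normal update: precision-weighted average of prior and poll estimate."""
    w_prior, w_poll = 1.0 / prior_var, 1.0 / poll_var  # precisions
    post_var = 1.0 / (w_prior + w_poll)
    post_mean = post_var * (w_prior * prior_mean + w_poll * poll_mean)
    return post_mean, post_var
```

With equal variances the result is the plain average; as polls become more precise (small poll_var), the posterior moves toward the poll estimate.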
- class src.models.hierarchical_bayes.HierarchicalBayesModel(seed=None)
Bases: ElectionForecastModel
Hierarchical Bayesian ensemble with bias correction.
- __init__(seed=None)
Initialize the model.
- Parameters:
seed – Random seed for reproducibility (default: None for non-deterministic).
- estimate_house_effects(all_polls, lambda_shrink=10)
Estimate pollster house effects with hierarchical shrinkage
- Parameters:
all_polls – DataFrame of all polling data
lambda_shrink – Shrinkage parameter (higher = more shrinkage to zero)
- Returns:
dict mapping pollster name to estimated house effect
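A sketch of hierarchical shrinkage toward zero: each pollster's house effect is its summed residual divided by (number of polls + lambda_shrink). This is the posterior mean of a pollster's bias under a zero-mean normal prior whose variance is the per-poll noise variance divided by lambda_shrink, so pollsters with few polls are pulled strongly toward zero. The input format (pollster, residual) pairs is an assumption, and how the reference margin for each residual is obtained is left open.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def estimate_house_effects(residuals: Iterable[Tuple[str, float]],
                           lambda_shrink: float = 10.0) -> Dict[str, float]:
    """Shrunken mean residual per pollster: sum(r) / (n + lambda_shrink).

    residuals: hypothetical (pollster, poll_margin - reference_margin) pairs.
    """
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for pollster, r in residuals:
        sums[pollster] += r
        counts[pollster] += 1
    # lambda_shrink = 0 gives the raw mean; larger values shrink harder toward 0.
    return {p: sums[p] / (counts[p] + lambda_shrink) for p in sums}
```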
- kalman_filter_rts(dates, observations, obs_variance, mu, sigma2)
Kalman filter with Rauch-Tung-Striebel (RTS) backward smoother
- Parameters:
dates – Array of time points (in days)
observations – Array of poll margins
obs_variance – Array of observation variances
mu – Drift parameter
sigma2 – Diffusion variance
- Returns:
smoothed state estimates and variances
- Return type:
tuple of (x_smooth, P_smooth)
Utilities
Data Utilities
Shared data loading and preprocessing utilities.
This version supports multiple election cycles (e.g. 2012, 2016, 2020) and both the original 2016 timeseries file and FiveThirtyEight-style long polls files (like 2020_president_polls.csv).
- src.utils.data_utils.set_election_config(year: int = 2016, polls_file: str | None = None) → None
Configure which election cycle the rest of the module should use.
- Parameters:
year – Election year (e.g. 2012, 2016, 2020).
polls_file –
Optional path to a FiveThirtyEight-style polls CSV. If None, we:
use the original 2016 timeseries file for year=2016
otherwise fall back to f"data/polls/{year}_president_polls.csv"
- src.utils.data_utils.get_election_date(year: int) → str
Return the election day (YYYY-MM-DD) for a given year.
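How get_election_date obtains the date is not stated (it may well use a lookup table). One self-contained way to derive it is the statutory rule for US federal elections: the Tuesday after the first Monday in November. The function name election_day is hypothetical.

```python
from datetime import date, timedelta


def election_day(year: int) -> str:
    """US federal Election Day: the Tuesday after the first Monday in November."""
    nov1 = date(year, 11, 1)
    # weekday(): Monday == 0, so this is the number of days until the first Monday.
    first_monday = nov1 + timedelta(days=(7 - nov1.weekday()) % 7)
    return (first_monday + timedelta(days=1)).isoformat()
```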
- src.utils.data_utils.get_current_election_date() → str
Convenience wrapper that uses the currently configured election year.
- src.utils.data_utils.load_polling_data() → DataFrame
Load polling data for the currently configured election.
- Behaviour:
If CURRENT_ELECTION_YEAR == 2016 and CURRENT_POLLS_FILE is None, this uses the original _load_polling_data_2016() to preserve backwards compatibility.
Otherwise, it expects a FiveThirtyEight-style CSV (either provided via CURRENT_POLLS_FILE or inferred as data/polls/{year}_president_polls.csv) and parses it with _load_polling_data_fte_long.
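The configuration and dispatch rules described for set_election_config and load_polling_data can be sketched as follows. Instead of parsing files, the hypothetical helper choose_polls_source returns a label naming the branch that would run; the real set_election_config may do more than store the two values.

```python
from typing import Optional

# Module-level configuration, mirroring CURRENT_ELECTION_YEAR / CURRENT_POLLS_FILE.
CURRENT_ELECTION_YEAR = 2016
CURRENT_POLLS_FILE: Optional[str] = None


def set_election_config(year: int = 2016, polls_file: Optional[str] = None) -> None:
    global CURRENT_ELECTION_YEAR, CURRENT_POLLS_FILE
    CURRENT_ELECTION_YEAR = year
    CURRENT_POLLS_FILE = polls_file


def choose_polls_source() -> str:
    """Hypothetical helper: report which loading branch would be used."""
    if CURRENT_ELECTION_YEAR == 2016 and CURRENT_POLLS_FILE is None:
        return "legacy-2016"  # stands in for the _load_polling_data_2016() branch
    path = CURRENT_POLLS_FILE or f"data/polls/{CURRENT_ELECTION_YEAR}_president_polls.csv"
    return f"fte-long:{path}"  # stands in for the _load_polling_data_fte_long branch
```

Note that an explicit polls_file takes precedence even for 2016, which matches the backwards-compatibility rule only applying when no file is given.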
- src.utils.data_utils.load_election_results() → Dict[str, float]
Public wrapper used by models.
Uses the currently configured election year.
- src.utils.data_utils.load_fundamentals() → Dict[str, Dict[str, float]]
Load historical election results for fundamentals prior.
Computes weighted average of 2012 (70%) and 2008 (30%) results.
NOTE: This is still the same 2016-oriented prior as in the original project. If you want a 2012- or 2020-specific fundamentals prior, you can generalise this function further (e.g. use (2008, 2004) for 2012, or (2016, 2012) for 2020).
- Returns:
dict mapping state code to a fundamentals dict with keys margin, margin_2012, margin_2008
- Return type:
Dict[str, Dict[str, float]]
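The 70/30 weighting can be sketched directly. Here the per-year margins are passed in as dicts (a hypothetical input format; the real function loads them from data files), and states missing from either year are skipped, which is also an assumption.

```python
from typing import Dict


def fundamentals_prior(margins_2012: Dict[str, float],
                       margins_2008: Dict[str, float]) -> Dict[str, Dict[str, float]]:
    """70/30 weighted average of the 2012 and 2008 state margins."""
    prior = {}
    for state in margins_2012.keys() & margins_2008.keys():
        m12, m08 = margins_2012[state], margins_2008[state]
        prior[state] = {"margin": 0.7 * m12 + 0.3 * m08,
                        "margin_2012": m12, "margin_2008": m08}
    return prior
```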
- src.utils.data_utils.get_state_list(polls: DataFrame, actual_results: Dict[str, float]) → List[str]
Get list of states with sufficient polling data.
- Parameters:
polls – DataFrame of polling data
actual_results – dict of actual election results
- Returns:
list of state codes
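The threshold for "sufficient polling data" is not specified for this function, so the sketch below takes a hypothetical min_polls argument (defaulting to 10, the same default run_forecast uses). It receives one state code per poll as a stand-in for the polls DataFrame's state column, and keeps only states that also have an actual result.

```python
from collections import Counter
from typing import Dict, Iterable, List


def get_state_list(poll_states: Iterable[str], actual_results: Dict[str, float],
                   min_polls: int = 10) -> List[str]:
    """States with at least min_polls polls and a known actual result, sorted.

    poll_states and min_polls are assumptions made for this sketch.
    """
    counts = Counter(poll_states)
    return sorted(s for s, n in counts.items()
                  if n >= min_polls and s in actual_results)
```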
- src.utils.data_utils.compute_metrics(predictions_df: DataFrame) → DataFrame
Compute evaluation metrics from predictions.
- Parameters:
predictions_df – DataFrame with columns: forecast_date, win_probability, predicted_margin, actual_margin
- Returns:
forecast_date, n_states, brier_score, log_loss, mae_margin
- Return type:
DataFrame with columns
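A sketch of the per-date metrics using plain tuples instead of a DataFrame. It assumes a positive actual_margin means the forecast candidate won, and clips probabilities by a small eps before taking logarithms so the log loss stays finite; both choices are assumptions, not documented behaviour.

```python
import math
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

# (forecast_date, win_probability, predicted_margin, actual_margin)
Row = Tuple[str, float, float, float]


def compute_metrics(rows: Iterable[Row], eps: float = 1e-6) -> List[dict]:
    groups: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        groups[row[0]].append(row)
    results = []
    for fdate in sorted(groups):
        g = groups[fdate]
        brier = log_loss = mae = 0.0
        for _, p, pred, actual in g:
            y = 1.0 if actual > 0 else 0.0      # assumption: positive margin = win
            q = min(max(p, eps), 1.0 - eps)     # clip so log() stays finite
            brier += (p - y) ** 2
            log_loss -= y * math.log(q) + (1.0 - y) * math.log(1.0 - q)
            mae += abs(pred - actual)
        n = len(g)
        results.append({"forecast_date": fdate, "n_states": n,
                        "brier_score": brier / n, "log_loss": log_loss / n,
                        "mae_margin": mae / n})
    return results
```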
Scripts
Run All Models
- src.scripts.run_all_models.discover_models()
Auto-discover all model classes using importlib.resources.
- Returns:
List of tuples (model_class_name, model_class) sorted by name.
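A sketch of class discovery with importlib.resources: list the .py modules inside a package, import each one, and collect non-abstract subclasses of a base class that are defined in that module. The package name and base class are parameters here because the actual values used by discover_models are not shown; filtering out abstract classes and re-exported imports is an assumption.

```python
import importlib
import importlib.resources
import inspect
from types import ModuleType
from typing import List, Tuple


def module_names(package: str) -> List[str]:
    """Names of the .py modules shipped inside a package (excluding __init__)."""
    return sorted(
        entry.name[:-3]
        for entry in importlib.resources.files(package).iterdir()
        if entry.name.endswith(".py") and entry.name != "__init__.py"
    )


def concrete_subclasses(module: ModuleType, base: type) -> List[Tuple[str, type]]:
    """(name, class) pairs for non-abstract subclasses of base defined in module."""
    return sorted(
        ((name, cls) for name, cls in inspect.getmembers(module, inspect.isclass)
         if issubclass(cls, base) and cls is not base
         and not inspect.isabstract(cls)
         and cls.__module__ == module.__name__),  # skip classes merely imported
        key=lambda pair: pair[0],
    )


def discover(package: str, base: type) -> List[Tuple[str, type]]:
    found = []
    for name in module_names(package):
        module = importlib.import_module(f"{package}.{name}")
        found.extend(concrete_subclasses(module, base))
    return sorted(found, key=lambda pair: pair[0])
```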
- src.scripts.run_all_models.generate_forecast_dates(n_dates: int, election_date: str | None = None, start_date: str | None = None) → List[Timestamp]
Generate n evenly-spaced forecast dates between start_date and election_date.
- Parameters:
n_dates – Number of forecast dates to generate.
election_date – Election day as a string (YYYY-MM-DD). If None, use the currently configured election date.
start_date – Earliest date to start forecasting from. If None, default to September 1 of the election year.
- Returns:
List of pd.Timestamp forecast dates.
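A sketch of evenly spaced forecast dates. It assumes both endpoints are included (like a linspace over day counts, rounded to whole days), requires election_date explicitly rather than reading the configured date, and returns datetime.date values in place of pd.Timestamp.

```python
from datetime import date, timedelta
from typing import List, Optional


def generate_forecast_dates(n_dates: int, election_date: str,
                            start_date: Optional[str] = None) -> List[date]:
    """n evenly spaced dates from start_date to election_date, both included."""
    end = date.fromisoformat(election_date)
    # Documented default: September 1 of the election year.
    start = date.fromisoformat(start_date) if start_date else date(end.year, 9, 1)
    if n_dates == 1:
        return [end]
    span = (end - start).days
    # Round each fractional offset to a whole day.
    return [start + timedelta(days=round(i * span / (n_dates - 1)))
            for i in range(n_dates)]
```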
Compare Models
Compare all forecasting models
Generates comparison tables, rankings, and plots for all models
Generate Plots
Generate state-level plots for all models
- Usage:
election-plot  # Default: plot key swing states (for 2016)
election-plot --all  # Plot all states with sufficient data
election-plot --states FL PA MI WI  # Plot specific states
election-plot --year 2020 --all  # Plot all states for 2020
election-plot --year 2020 --polls-file data/polls/2020_president_polls.csv --all
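The flags in these usage lines map naturally onto argparse. This is a sketch of a parser that accepts them; the default year of 2016 is inferred from the usage comments, and whether the real script makes --all and --states mutually exclusive is not stated, so no such constraint is added.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="election-plot")
    parser.add_argument("--all", action="store_true",
                        help="plot all states with sufficient data")
    parser.add_argument("--states", nargs="+", metavar="STATE",
                        help="plot only these state codes")
    parser.add_argument("--year", type=int, default=2016,
                        help="election year (default inferred as 2016)")
    parser.add_argument("--polls-file", default=None,
                        help="FiveThirtyEight-style polls CSV")
    return parser
```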
Run All Pipeline
- src.scripts.run_all.run_with_temp_argv(argv, func)
Temporarily override sys.argv to call a subcommand.
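A sketch of the temporary sys.argv override, assuming func takes no arguments and reads its options from sys.argv (as argparse does by default). The try/finally restores the original list even if the subcommand raises, including the SystemExit that argparse raises on bad arguments. By convention argv[0] is the program name.

```python
import sys
from typing import Callable, List, TypeVar

T = TypeVar("T")


def run_with_temp_argv(argv: List[str], func: Callable[[], T]) -> T:
    """Call func() with sys.argv temporarily replaced, restoring it afterwards."""
    saved = sys.argv
    sys.argv = list(argv)
    try:
        return func()
    finally:
        sys.argv = saved  # restored even if func raises
```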