Source code for src.scripts.run_all_models

#!/usr/bin/env python3
import importlib
import inspect
import argparse
import traceback
import cProfile
import pandas as pd  # type: ignore[import-untyped]
from datetime import timedelta
from importlib import resources

import src.models as models_package
from src.models.base_model import ElectionForecastModel
from src.utils.logging_config import setup_logging, get_logger
from src.utils.data_utils import set_election_config
from src.utils.data_utils import get_current_election_date
from typing import List, Optional


# Module-level logger, obtained from the project's logging helper and named
# after this module; shared by every function in this file.
logger = get_logger(__name__)


def discover_models():
    """
    Find every concrete forecast model shipped in the ``src.models`` package.

    Each public ``*.py`` file in the package (anything not starting with
    ``_`` and not ``base_model.py``) is imported, and every class defined in
    that module that subclasses ``ElectionForecastModel`` is collected.
    A module that fails to import is logged as a warning and skipped; a
    failure while listing the package is logged as an error.

    Returns:
        List of ``(class_name, class)`` tuples, sorted by class name.
    """
    found = []
    try:
        for entry in resources.files(models_package).iterdir():
            # Skip directories, non-Python files, private modules and the
            # abstract base module.
            if (
                not entry.is_file()
                or not entry.name.endswith(".py")
                or entry.name.startswith("_")
                or entry.name == "base_model.py"
            ):
                continue

            # Drop the trailing ".py" to get the importable module path.
            module_name = "src.models." + entry.name[:-3]
            try:
                module = importlib.import_module(module_name)
                for cls_name, cls in inspect.getmembers(module, inspect.isclass):
                    if not issubclass(cls, ElectionForecastModel):
                        continue
                    # Exclude the base class itself and classes that were
                    # merely imported into this module from elsewhere.
                    if cls is ElectionForecastModel or cls.__module__ != module_name:
                        continue
                    found.append((cls_name, cls))
            except Exception as e:
                logger.warning(f"Could not import {module_name}: {e}")
    except Exception as e:
        logger.error(f"Error discovering models: {e}")

    found.sort(key=lambda pair: pair[0])
    return found
def _default_election_and_start_dates(year: int) -> tuple[str, str]:
    """
    Return the default ``(election_date, start_date)`` pair for a cycle.

    The election date comes from a small table of known cycles (2012, 2016,
    2020); any other year falls back to November 1 of that year. The start
    date is always September 1 of the given year. Both values are
    ``YYYY-MM-DD`` strings.
    """
    known_election_days = {
        2012: "2012-11-06",
        2016: "2016-11-08",
        2020: "2020-11-03",
    }

    if year in known_election_days:
        election_day = known_election_days[year]
    else:
        # Unknown cycle: approximate with the first of November.
        election_day = f"{year}-11-01"

    return election_day, f"{year}-09-01"
def generate_forecast_dates(
    n_dates: int,
    election_date: Optional[str] = None,
    start_date: Optional[str] = None,
) -> List[pd.Timestamp]:
    """
    Build ``n_dates`` forecast dates spread evenly over the campaign window.

    The window runs from ``start_date`` to the day before ``election_date``.
    Dates are returned in ascending order: the first is ``start_date`` and
    the last is the day before the election. When ``n_dates`` is 1 only the
    day before the election is returned; when it is 0 or negative the result
    is empty.

    Args:
        n_dates: Number of forecast dates to generate.
        election_date: Election day (YYYY-MM-DD). If None, the value from
            ``get_current_election_date()`` is used.
        start_date: First day of the window (YYYY-MM-DD). If None,
            September 1 of the election year is used.

    Returns:
        List of pd.Timestamp forecast dates.
    """
    if election_date is None:
        election_date = get_current_election_date()
    election_day = pd.to_datetime(election_date)

    if start_date is None:
        start_date = f"{int(election_day.year)}-09-01"
    first_day = pd.to_datetime(start_date)

    # The final forecast is made the day before the election.
    final_day = election_day - timedelta(days=1)
    span_days = (final_day - first_day).days

    # Number of gaps between consecutive dates; zero (or less) means there
    # is nothing to interpolate and every date sits on the final day.
    gaps = n_dates - 1
    return [
        final_day
        - timedelta(days=int(span_days * (gaps - idx) / gaps) if gaps > 0 else 0)
        for idx in range(n_dates)
    ]
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the forecast runner."""
    parser = argparse.ArgumentParser(
        description="Run all election forecasting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  election-forecast                    # Default: 4 forecast dates on 2016
  election-forecast --dates 8          # Use 8 forecast dates
  election-forecast --year 2020        # Run on 2020 data (expects 2020_president_polls.csv)
  election-forecast --year 2012 --polls-file data/polls/2012_president_polls.csv
  election-forecast -n 16              # Use 16 forecast dates
  election-forecast -v                 # Verbose output
  election-forecast --parallel 4       # Use 4 parallel workers
  election-forecast -w 8               # Use 8 parallel workers
        """,
    )
    parser.add_argument(
        "--dates",
        "-n",
        type=int,
        default=4,
        help="Number of forecast dates to use (default: 4)",
    )
    parser.add_argument(
        "--year",
        "-y",
        type=int,
        default=2016,
        help="Election year to run (e.g. 2012, 2016, 2020). Default: 2016.",
    )
    parser.add_argument(
        "--polls-file",
        type=str,
        default=None,
        help=(
            "Optional path to a FiveThirtyEight-style polls CSV. "
            "If omitted, the loader uses the built-in 2016 timeseries for year=2016 "
            "or data/polls/{year}_president_polls.csv for other years."
        ),
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose output"
    )
    parser.add_argument(
        "--profile",
        "-p",
        type=str,
        metavar="FILE",
        help="Enable profiling and save to FILE (e.g., forecast.prof)",
    )
    parser.add_argument(
        "--seed",
        "-s",
        type=int,
        metavar="SEED",
        help="Random seed for reproducibility (default: None for non-deterministic)",
    )
    parser.add_argument(
        "--parallel",
        "-w",
        type=int,
        metavar="WORKERS",
        help="Number of parallel workers for state-level parallelization (default: None for sequential)",
    )
    return parser


def _run_forecasts(args: argparse.Namespace) -> None:
    """Configure the election cycle, then run and save every discovered model."""
    # Configure global election year / polls file for all downstream loaders
    set_election_config(year=args.year, polls_file=args.polls_file)

    setup_logging(__name__, level="DEBUG" if args.verbose else "INFO")

    election_date, start_date = _default_election_and_start_dates(args.year)
    forecast_dates = generate_forecast_dates(
        n_dates=args.dates,
        election_date=election_date,
        start_date=start_date,
    )

    logger.info(f"Using {len(forecast_dates)} forecast dates for year {args.year}")
    if args.verbose:
        election_day = pd.to_datetime(election_date)
        for date in forecast_dates:
            days_to_election = (election_day - date).days
            logger.info(f" - {date.date()} ({days_to_election} days before election)")

    logger.info("Looking for models...")
    model_classes = discover_models()
    if not model_classes:
        logger.warning("No models found in src.models")
        return

    logger.info(f"Found {len(model_classes)} model(s)")
    if args.verbose:
        for name, _ in model_classes:
            logger.info(f" - {name}")

    for model_name, ModelClass in model_classes:
        logger.info(f"\nRunning: {model_name}")
        # One failing model must not prevent the remaining models from running.
        try:
            model = ModelClass(seed=args.seed)
            pred_df = model.run_forecast(
                forecast_dates=forecast_dates,
                verbose=args.verbose,
                n_workers=args.parallel,
            )
            metrics_df = model.save_results()

            if args.verbose:
                logger.info(f"Total predictions: {len(pred_df)}")
                logger.info(f"Metrics:\n{metrics_df.to_string(index=False)}")
        except Exception as e:
            logger.error(f"ERROR running {model_name}: {e}")
            traceback.print_exc()


def main():
    """
    Command-line entry point: run every discovered forecasting model.

    Parses the CLI arguments, configures the election year / polls file and
    logging, generates the forecast dates, then runs each model returned by
    ``discover_models()`` and saves its results. A model that raises is
    logged (with traceback) and the remaining models still run.

    With ``--profile FILE`` the whole run is profiled with ``cProfile`` and
    the stats are written to FILE, including when the run ends early (no
    models found) or an exception propagates.
    """
    args = _build_arg_parser().parse_args()

    if not args.profile:
        _run_forecasts(args)
        return

    profiler = cProfile.Profile()
    profiler.enable()
    try:
        _run_forecasts(args)
    finally:
        # Always stop and persist the profile so an early return or an
        # error does not silently discard the collected data.
        profiler.disable()
        profiler.dump_stats(args.profile)
        logger.info(f"\nProfiling data saved to {args.profile}")
        logger.info(f"View with: snakeviz {args.profile}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__": main()