#!/usr/bin/env python3
import importlib
import inspect
import argparse
import traceback
import cProfile
import pandas as pd # type: ignore[import-untyped]
from datetime import timedelta
from importlib import resources
import src.models as models_package
from src.models.base_model import ElectionForecastModel
from src.utils.logging_config import setup_logging, get_logger
# Module-level logger, obtained from the project's logging helper under this module's name.
logger = get_logger(__name__)
[docs]
def discover_models():
    """
    Find every forecast model class shipped in the models package

    Each public ``.py`` file directly inside the package (anything starting
    with "_" and ``base_model.py`` are skipped) is imported as
    ``src.models.<stem>`` and scanned for classes that subclass
    ElectionForecastModel and are defined in that very module. A module
    that fails to import or inspect is logged as a warning and skipped; a
    failure while listing the package is logged as an error.

    Returns:
        List of (class_name, class) tuples ordered by class name
    """
    found = []
    try:
        for entry in resources.files(models_package).iterdir():
            is_model_file = (
                entry.is_file()
                and entry.name.endswith(".py")
                and not entry.name.startswith("_")
                and entry.name != "base_model.py"
            )
            if not is_model_file:
                continue

            # Drop the ".py" suffix to get the importable module path
            stem = entry.name[: -len(".py")]
            module_name = "src.models." + stem
            try:
                module = importlib.import_module(module_name)
                for cls_name, cls in inspect.getmembers(module, inspect.isclass):
                    # Keep only subclasses defined in this module itself, so
                    # the base class and re-imported models are not counted
                    is_local_model = (
                        issubclass(cls, ElectionForecastModel)
                        and cls != ElectionForecastModel
                        and cls.__module__ == module_name
                    )
                    if is_local_model:
                        found.append((cls_name, cls))
            except Exception as e:
                logger.warning(f"Could not import {module_name}: {e}")
    except Exception as e:
        logger.error(f"Error discovering models: {e}")

    found.sort(key=lambda pair: pair[0])  # order by class name
    return found
[docs]
def generate_forecast_dates(
    n_dates, election_date="2016-11-08", start_date="2016-09-01"
):
    """
    Build n_dates forecast dates spread evenly over the pre-election window

    The window runs from start_date through the day before election_date.
    Day offsets are truncated to whole days, and dates come back earliest
    first when start_date precedes the election.

    Args:
        n_dates: Number of forecast dates to generate
        election_date: Election day
        start_date: Earliest date to start forecasting from
    Returns:
        List of pd.Timestamp forecast dates. Empty when n_dates < 1; only
        the day before the election when n_dates == 1.
    """
    election_day = pd.to_datetime(election_date)
    window_open = pd.to_datetime(start_date)
    # Forecasting stops the day before the election
    window_close = election_day - timedelta(days=1)
    window_days = (window_close - window_open).days

    if n_dates <= 1:
        # Nothing to space out: a lone date sits at the end of the window
        return [window_close - timedelta(days=0) for _ in range(n_dates)]

    # Count days back from the end of the window, largest offset first
    intervals = n_dates - 1
    offsets = [
        int(window_days * steps_left / intervals)
        for steps_left in range(intervals, -1, -1)
    ]
    return [window_close - timedelta(days=offset) for offset in offsets]
[docs]
def _build_arg_parser():
    """Build the command-line parser for the forecast runner."""
    parser = argparse.ArgumentParser(
        description="Run all election forecasting models",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
election-forecast # Default: 4 forecast dates, sequential
election-forecast --dates 8 # Use 8 forecast dates
election-forecast -n 16 # Use 16 forecast dates
election-forecast -v # Verbose output
election-forecast --parallel 4 # Use 4 parallel workers
election-forecast -w 8 # Use 8 parallel workers
""",
    )
    parser.add_argument(
        "--dates",
        "-n",
        type=int,
        default=4,
        help="Number of forecast dates to use (default: 4)",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Enable verbose output"
    )
    parser.add_argument(
        "--profile",
        "-p",
        type=str,
        metavar="FILE",
        help="Enable profiling and save to FILE (e.g., forecast.prof)",
    )
    parser.add_argument(
        "--seed",
        "-s",
        type=int,
        metavar="SEED",
        help="Random seed for reproducibility (default: None for non-deterministic)",
    )
    parser.add_argument(
        "--parallel",
        "-w",
        type=int,
        metavar="WORKERS",
        help="Number of parallel workers for state-level parallelization (default: None for sequential)",
    )
    return parser


def _run_all_models(args):
    """Generate the forecast dates, discover the models and run each one."""
    # Single source for the election day, used both to generate the dates
    # and to report how far each date is from the election.
    election_date = "2016-11-08"

    forecast_dates = generate_forecast_dates(args.dates, election_date=election_date)
    logger.info(f"Using {len(forecast_dates)} forecast dates")
    if args.verbose:
        election_day = pd.to_datetime(election_date)
        for date in forecast_dates:
            days_to_election = (election_day - date).days
            logger.info(f" - {date.date()} ({days_to_election} days before election)")

    logger.info("Looking for models...")
    model_classes = discover_models()
    if not model_classes:
        logger.warning("No models found in src.models")
        return

    logger.info(f"Found {len(model_classes)} model(s)")
    if args.verbose:
        for name, _ in model_classes:
            logger.info(f" - {name}")

    for model_name, ModelClass in model_classes:
        logger.info(f"\nRunning: {model_name}")
        # One failing model must not stop the remaining ones from running
        try:
            model = ModelClass(seed=args.seed)
            pred_df = model.run_forecast(
                forecast_dates=forecast_dates,
                verbose=args.verbose,
                n_workers=args.parallel,
            )
            metrics_df = model.save_results()
            if args.verbose:
                logger.info(f"Total predictions: {len(pred_df)}")
                logger.info(f"Metrics:\n{metrics_df.to_string(index=False)}")
        except Exception as e:
            logger.error(f"ERROR running {model_name}: {e}")
            traceback.print_exc()


def main():
    """
    Command-line entry point: run every discovered forecasting model

    Parses the CLI options, configures logging (DEBUG when --verbose is
    given, INFO otherwise), generates the forecast dates, discovers the
    model classes and runs each one. An error raised by one model is
    logged with its traceback and the remaining models still run.

    When --profile FILE is given the run is profiled with cProfile. The
    stats are written to FILE even if the run ends early (no models found)
    or an exception escapes.
    """
    args = _build_arg_parser().parse_args()

    profiler = None
    if args.profile:
        profiler = cProfile.Profile()
        profiler.enable()

    try:
        setup_logging(__name__, level="DEBUG" if args.verbose else "INFO")
        _run_all_models(args)
    finally:
        # Runs on every exit path so collected profiling data is never lost
        if profiler is not None:
            profiler.disable()
            profiler.dump_stats(args.profile)
            logger.info(f"\nProfiling data saved to {args.profile}")
            logger.info(f"View with: snakeviz {args.profile}")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()