Source code for src.scripts.generate_plots

#!/usr/bin/env python3
"""
Generate state-level plots for all models

Usage:
    election-plot                                  # Default: plot key swing states (for 2016)
    election-plot --all                            # Plot all states with sufficient data
    election-plot --states FL PA MI WI             # Plot specific states
    election-plot --year 2020 --all                # Plot all states for 2020
    election-plot --year 2020 --polls-file data/polls/2020_president_polls.csv --all
"""

import importlib
import inspect
import argparse
import traceback
from pathlib import Path
from importlib import resources

import pandas as pd  # type: ignore[import-untyped]

import src.models as models_package
from src.models.base_model import ElectionForecastModel
from src.utils.logging_config import setup_logging, get_logger
from src.utils.data_utils import load_polling_data, set_election_config

logger = get_logger(__name__)


[docs] def discover_models(): """Auto-discover all model classes using importlib.resources.""" models = [] try: for item in resources.files(models_package).iterdir(): if not item.is_file(): continue if not item.name.endswith(".py"): continue if item.name.startswith("_") or item.name == "base_model.py": continue module_name = f"src.models.{item.name[:-3]}" try: module = importlib.import_module(module_name) for name, obj in inspect.getmembers(module, inspect.isclass): if ( issubclass(obj, ElectionForecastModel) and obj is not ElectionForecastModel and obj.__module__ == module_name ): models.append((name, obj)) except Exception as e: logger.info(f"Warning: Could not import {module_name}: {e}") except Exception as e: logger.info(f"Error discovering models: {e}") return sorted(models, key=lambda x: x[0])
[docs] def main(): parser = argparse.ArgumentParser( description="Generate state-level forecast plots for all models", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: election-plot # Plot key swing states (default year 2016) election-plot --all # Plot all states election-plot --states FL PA MI # Plot specific states election-plot --year 2020 --all # Plot all states (2020) election-plot --year 2020 --polls-file data/polls/2020_president_polls.csv --all """, ) parser.add_argument( "--all", action="store_true", help="Generate plots for all states with sufficient polling data", ) parser.add_argument( "--states", nargs="+", help="Specific state codes to plot (e.g., FL PA MI WI)", ) parser.add_argument( "--year", "-y", type=int, default=2016, help="Election year to plot (e.g. 2012, 2016, 2020). Default: 2016.", ) parser.add_argument( "--polls-file", type=str, default=None, help=( "Optional path to a FiveThirtyEight-style polls CSV. " "If omitted, the loader uses the built-in 2016 timeseries for year=2016 " "or data/polls/{year}_president_polls.csv for other years." ), ) parser.add_argument( "--verbose", "-v", action="store_true", help="Enable verbose logging.", ) args = parser.parse_args() # Setup logging setup_logging(__name__, level="DEBUG" if args.verbose else "INFO") # IMPORTANT: make plotting use the same election year / polls file # as the forecasting step. set_election_config(year=args.year, polls_file=args.polls_file) # Load polls for the configured election; used only to decide which states exist polls = load_polling_data() # Determine which states to plot if args.states: states_to_plot = [s.upper() for s in args.states] logger.info(f"Plotting {len(states_to_plot)} specified states") elif args.all: # All states with polling data (drop NaNs to avoid str/float sorting error) states_to_plot = sorted( s for s in polls["state_code"].dropna().unique().tolist() ) logger.info(f"Plotting all {len(states_to_plot)} states with polling data") else: # Default: key swing states default_states = ["FL", "PA", "MI", "WI", "NC", "AZ", "NV", "GA", "OH", "VA"] # Only keep those that actually appear in the current polls available = set(polls["state_code"].dropna().unique().tolist()) states_to_plot = [s for s in default_states if s in available] logger.info(f"Plotting {len(states_to_plot)} key swing states") if not states_to_plot: logger.info("No states to plot (none found in polling data).") return logger.info(f"States: {', '.join(states_to_plot)}\n") # Discover models logger.info("Discovering models...") model_classes = discover_models() if not model_classes: logger.info("No models found in src.models") return logger.info(f"Found {len(model_classes)} model(s):") for name, _ in model_classes: logger.info(f" - {name}") # Ensure base plots directory exists Path("plots").mkdir(parents=True, exist_ok=True) # Generate plots for each model total_plots = 0 for model_name, ModelClass in model_classes: logger.info(f"\nGenerating plots for {model_name}...") try: model = ModelClass() # Load predictions from CSV if they exist pred_file = Path(f"predictions/{model.name}.csv") if pred_file.exists(): pred_df = pd.read_csv(pred_file) if pred_df.empty: logger.info(f" Warning: Predictions file {pred_file} is empty") continue # Convert forecast_date to datetime for correct plotting if "forecast_date" in pred_df.columns: pred_df["forecast_date"] = pd.to_datetime(pred_df["forecast_date"]) # Attach predictions in the same structure run_forecast() uses model.predictions = pred_df.to_dict("records") else: logger.info(f" Warning: No predictions found at {pred_file}") logger.info(" Run 'election-forecast' first to generate predictions") continue for state in states_to_plot: try: model.plot_state(state) total_plots += 1 except Exception as e: logger.info(f" Warning: Could not plot {state}: {e}") logger.info(f" ✓ Saved to plots/{model.name}/") except Exception as e: logger.info(f" ERROR: {e}") traceback.print_exc() logger.info(f"\n✓ Generated {total_plots} plots total") logger.info(" Plots saved in plots/ directory (organized by model)")
if __name__ == "__main__": main()