Source code for src.scripts.compare_models

#!/usr/bin/env python3
"""
Compare all forecasting models

Generates comparison tables, rankings, and plots for all models
"""

import pandas as pd  # type: ignore[import-untyped]
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
from src.utils.logging_config import setup_logging, get_logger

# Module-level logger, requested from the project's logging helper under this module's name.
logger = get_logger(__name__)


def parse_metrics(filename):
    """
    Parse metrics from text file

    The file is scanned line by line. A line containing "Forecast Date:"
    starts a new record; lines containing "Brier Score:", "Log Loss:" or
    "MAE (Margin):" fill in the metrics of the record currently being
    built. Any other line is ignored.

    Args:
        filename: Path to metrics text file

    Returns:
        DataFrame with columns: date, brier, log_loss, mae (always in that
        order). One row per record; a metric that is absent from a record is
        NaN. A file without any recognised line yields an empty DataFrame
        that still carries the four columns.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the text after a metric label is not a valid float.
    """
    date_label = "Forecast Date:"
    # Label searched for in a line -> key it populates in the record.
    metric_labels = {
        "Brier Score:": "brier",
        "Log Loss:": "log_loss",
        "MAE (Margin):": "mae",
    }

    records = []
    current = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            if date_label in line:
                # A new forecast date closes the record collected so far.
                if current:
                    records.append(current)
                # Take everything after the label itself, so values that
                # contain ":" (e.g. "2024-01-05 12:30") are not truncated.
                current = {"date": line.partition(date_label)[2].strip()}
                continue
            for label, key in metric_labels.items():
                if label in line:
                    current[key] = float(line.partition(label)[2].strip())
                    break

    # The last record has no following "Forecast Date:" line to flush it.
    if current:
        records.append(current)

    # Fixing the columns keeps the documented schema even for empty files or
    # records that lack one of the metrics.
    return pd.DataFrame(records, columns=["date", "brier", "log_loss", "mae"])
def main():
    """Load all model metrics, compare performance, and generate visualizations

    Reads every ``metrics/*.txt`` file (relative to the current working
    directory), using the file stem as the model name. It then:

    * logs one date-by-model table per metric (Brier score, log loss, MAE),
    * logs the per-model averages and the resulting rankings (1 = best),
    * writes the combined comparison table to ``model_comparison.csv``,
    * saves a three-panel metric-over-time plot to ``model_comparison.png``.

    If no metrics files are found, a warning is logged and nothing is
    written.
    """
    setup_logging(__name__)

    # glob returns paths in filesystem order; sorting keeps the legend order
    # and the colour/marker assigned to each model stable between runs.
    metrics_files = sorted(glob.glob("metrics/*.txt"))
    if not metrics_files:
        logger.warning("No metrics files found. Run models first.")
        return

    frames = []
    for metrics_file in metrics_files:
        df = parse_metrics(metrics_file)
        df["model"] = Path(metrics_file).stem
        frames.append(df)
    all_metrics = pd.concat(frames, ignore_index=True)

    # Create comparison tables: one date x model table per metric.
    tables = [
        ("Brier Score (lower is better):", "brier"),
        ("\nLog Loss (lower is better):", "log_loss"),
        ("\nMAE Margin (lower is better):", "mae"),
    ]
    for heading, column in tables:
        logger.info(heading)
        pivot = all_metrics.pivot(index="date", columns="model", values=column)
        logger.info(f"\n{pivot.to_string()}")

    logger.info("\nAverage performance across all forecast dates:")
    summary = all_metrics.groupby("model")[["brier", "log_loss", "mae"]].mean()
    summary = summary.round(4)
    logger.info(f"\n{summary.to_string()}")

    # All three metrics are "lower is better", so ascending rank 1 is best.
    logger.info("\nModel rankings (1 = best)")
    rankings = pd.DataFrame(
        {
            "Brier Score": summary["brier"].rank(),
            "Log Loss": summary["log_loss"].rank(),
            "MAE": summary["mae"].rank(),
        }
    )
    rankings["Average Rank"] = rankings.mean(axis=1)
    rankings = rankings.sort_values("Average Rank")
    logger.info(f"\n{rankings.to_string()}")

    comparison_table = all_metrics.pivot_table(
        index="date", columns="model", values=["brier", "log_loss", "mae"]
    )
    comparison_table.to_csv("model_comparison.csv")

    # One panel per metric: (column, y-axis label, panel title).
    panels = [
        ("brier", "Brier Score", "Brier Score Over Time"),
        ("log_loss", "Log Loss", "Log Loss Over Time"),
        ("mae", "MAE (Margin)", "Margin Error Over Time"),
    ]
    models = all_metrics["model"].unique()
    markers = ["o", "s", "^", "d", "v", "*", "p"]
    colors = plt.cm.tab10(np.linspace(0, 1, len(models)))

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for i, model in enumerate(models):
        model_data = all_metrics[all_metrics["model"] == model]
        # Convert to datetimes *before* sorting so the line is drawn in
        # chronological order even when the date strings do not sort
        # lexicographically.
        model_data = model_data.assign(
            date=pd.to_datetime(model_data["date"])
        ).sort_values("date")

        for ax, (column, _, _) in zip(axes, panels):
            ax.plot(
                model_data["date"],
                model_data[column].values,
                marker=markers[i % len(markers)],  # cycle if > 7 models
                label=model,
                linewidth=2,
                color=colors[i],
                markersize=8,
            )

    for ax, (_, ylabel, title) in zip(axes, panels):
        ax.set_xlabel("Forecast Date")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(alpha=0.3)
        ax.tick_params(axis="x", rotation=45)

    fig.tight_layout()
    fig.savefig("model_comparison.png")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# Run the comparison only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()