Multi-Model Evaluation
Comparing multiple AI weather prediction (AIWP) models against a shared
target is one of the most common EWB workflows. Create one
EvaluationObject per model, all sharing the same target and metric
list, then pass them together to a single ewb.evaluation call. The
results DataFrame carries a forecast_source column that labels each
row by model, making comparison straightforward. See
Usage for the single-model baseline workflow.
Example — Comparing four CIRA MLWP models on heat waves
The CIRA MLWP icechunk store contains eight models, each in its own
zarr group. ewb.inputs.get_cira_icechunk is the convenience wrapper
that returns an XarrayForecast for any model name.
import extremeweatherbench as ewb
from extremeweatherbench import inputs
model_names = [
    "FOUR_v200_IFS",
    "FOUR_v200_GFS",
    "GRAP_v100_IFS",
    "AURO_v100_IFS",
]
# One EvaluationObject per model; target and metrics are shared
target = ewb.ERA5(variables=["surface_air_temperature"])
metrics_list = [
    ewb.metrics.MeanAbsoluteError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
    ewb.metrics.MaximumMeanAbsoluteError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
    ewb.metrics.RootMeanSquaredError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
]
eval_objects = [
    ewb.EvaluationObject(
        event_type="heat_wave",
        metric_list=metrics_list,
        target=target,
        forecast=inputs.get_cira_icechunk(model_name=name),
    )
    for name in model_names
]
cases = ewb.load_cases()
runner = ewb.evaluation(case_metadata=cases, evaluation_objects=eval_objects)
outputs = runner.run_evaluation()
outputs.to_csv("multi_model_heatwave.csv", index=False)
Comparing models from different sources
Mix ZarrForecast, XarrayForecast, and CIRA models in the same run.
Each must have its own name to appear distinctly in the output:
import extremeweatherbench as ewb
from extremeweatherbench import inputs
hres = ewb.ZarrForecast(
    source=(
        "gs://weatherbench2/datasets/hres/"
        "2016-2022-0012-1440x721.zarr"
    ),
    name="HRES",
    variable_mapping=ewb.HRES_metadata_variable_mapping,
    storage_options={"remote_options": {"anon": True}},
)
fcnv2_ifs = inputs.get_cira_icechunk("FOUR_v200_IFS")
pangu_ifs = inputs.get_cira_icechunk("PANG_v100_IFS")
target = ewb.ERA5(variables=["surface_air_temperature"])
metrics_list = [
    ewb.metrics.MeanAbsoluteError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
]
eval_objects = [
    ewb.EvaluationObject(
        event_type="heat_wave",
        metric_list=metrics_list,
        target=target,
        forecast=model,
    )
    for model in [hres, fcnv2_ifs, pangu_ifs]
]
cases = ewb.load_cases()
runner = ewb.evaluation(case_metadata=cases, evaluation_objects=eval_objects)
outputs = runner.run_evaluation()
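Because the name is the only thing that distinguishes models in the output, it can be worth checking for accidental duplicates before starting a long run. The following is a minimal sketch, assuming each forecast object exposes its label as a name attribute (an assumption, not confirmed on this page). find_duplicate_names is a hypothetical helper, not part of EWB, and SimpleNamespace objects stand in for real forecasts so the block runs on its own:

```python
from collections import Counter
from types import SimpleNamespace


def find_duplicate_names(forecasts):
    """Return the sorted names shared by more than one forecast.

    Assumes each object exposes a ``name`` attribute (hypothetical here).
    """
    counts = Counter(getattr(f, "name", None) for f in forecasts)
    return sorted(str(name) for name, n in counts.items() if n > 1)


# Stand-ins for real forecast objects, for illustration only
stand_ins = [
    SimpleNamespace(name="HRES"),
    SimpleNamespace(name="FOUR_v200_IFS"),
    SimpleNamespace(name="FOUR_v200_IFS"),  # accidental duplicate
]

duplicates = find_duplicate_names(stand_ins)
if duplicates:
    print(f"Duplicate forecast names: {duplicates}")
```

With real forecasts, the same check would be applied to the list of models before building eval_objects.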
Filtering and plotting results
The output DataFrame has a forecast_source column that matches each
model's name. Use it to pivot or group results for comparison:
import pandas as pd
outputs = pd.read_csv("multi_model_heatwave.csv")
# Mean MAE per model and lead time
mae = outputs[outputs["metric"] == "MeanAbsoluteError"]
pivot = (
    mae.groupby(["forecast_source", "lead_time"])["value"]
    .mean()
    .unstack("forecast_source")
)
print(pivot)
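The same pivot can be plotted directly, with one line per model. This is a minimal sketch, assuming the output has the forecast_source, lead_time, metric, and value columns used above. The small synthetic DataFrame (with made-up values and model names) replaces the CSV so the block runs on its own, and matplotlib is an assumed extra dependency:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import pandas as pd

# Synthetic stand-in for the evaluation output (values are made up)
outputs = pd.DataFrame(
    {
        "forecast_source": ["model_a"] * 3 + ["model_b"] * 3,
        "lead_time": [24, 48, 72] * 2,
        "metric": ["MeanAbsoluteError"] * 6,
        "value": [1.0, 1.5, 2.0, 1.2, 1.4, 2.4],
    }
)

mae = outputs[outputs["metric"] == "MeanAbsoluteError"]
pivot = (
    mae.groupby(["forecast_source", "lead_time"])["value"]
    .mean()
    .unstack("forecast_source")
)

# One line per model, lead time on the x-axis
ax = pivot.plot(marker="o")
ax.set_xlabel("lead_time")
ax.set_ylabel("Mean MAE")
ax.set_title("MAE by lead time and model")
```

Call ax.figure.savefig with a filename of your choice to write the figure to disk.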
Parallel execution
For large model comparisons, enable parallel execution by passing a
parallel_config dictionary to runner.run_evaluation(). The configuration is
forwarded to joblib.Parallel:
parallel_config = {
    "backend": "loky",
    "n_jobs": 4,
}
outputs = runner.run_evaluation(parallel_config=parallel_config)
Detailed Explanation: Each (case, EvaluationObject) pair becomes a
CaseOperator that EWB can execute in parallel. With four models and
337 cases you have up to 1,348 operators. Memory scales with n_jobs
since each worker holds its own copy of the forecast slice for the
active case. Set n_jobs to the number of CPU cores available and
adjust downward if you run out of memory. On a cloud VM with a fast
network link, n_jobs=8 is a good starting point for comparing four to
eight models.
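The arithmetic above can be sketched in a few lines. This is an illustration only: operator_count and suggest_n_jobs are hypothetical helpers, not part of EWB, and the memory figures are made-up placeholders that you would replace with measurements from your own runs.

```python
import os


def operator_count(n_models, n_cases):
    """Upper bound on the number of (case, EvaluationObject) pairs."""
    return n_models * n_cases


def suggest_n_jobs(available_gb, per_worker_gb, cpu_count=None):
    """Cap n_jobs by both CPU cores and a rough per-worker memory budget."""
    cores = cpu_count or os.cpu_count() or 1
    by_memory = max(1, int(available_gb // per_worker_gb))
    return max(1, min(cores, by_memory))


print(operator_count(n_models=4, n_cases=337))  # 1348

# 32 GB available, ~6 GB per worker (placeholder figures), 8 cores
print(suggest_n_jobs(available_gb=32, per_worker_gb=6, cpu_count=8))  # 5
```

The result would then go into the "n_jobs" entry of parallel_config.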
Subsetting to a geographic region
Use RegionSubsetter to restrict which cases from the full list are
included in the run, for example to focus on North American events only:
import extremeweatherbench as ewb
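# Longitudes are in degrees east on a 0-360 scale (230 = 130°W, 310 = 50°W)
subsetter = ewb.RegionSubsetter(
    latitude_min=15.0,
    latitude_max=75.0,
    longitude_min=230.0,
    longitude_max=310.0,
)
# eval_objects is the list built in the examples above
runner = ewb.evaluation(
    case_metadata=ewb.load_cases(),
    evaluation_objects=eval_objects,
    region_subsetter=subsetter,
)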
outputs = runner.run_evaluation()
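The longitude bounds in this example are expressed in degrees east on a 0 to 360 scale. If your bounds start out in the -180 to 180 convention, a small conversion avoids selecting the wrong hemisphere. This is a sketch only: to_degrees_east and in_box are hypothetical helpers, not part of EWB, and how RegionSubsetter itself decides which cases fall inside the box is not shown on this page.

```python
def to_degrees_east(lon):
    """Map a longitude in [-180, 180] to the [0, 360) convention."""
    return lon % 360.0


def in_box(lat, lon, lat_min, lat_max, lon_min, lon_max):
    """Point-in-box test with box longitudes given in [0, 360).

    Assumes lon_min <= lon_max, i.e. the box does not wrap past 0/360.
    """
    lon_east = to_degrees_east(lon)
    return lat_min <= lat <= lat_max and lon_min <= lon_east <= lon_max


# Bounds copied from the North American box in the example above
box = dict(lat_min=15.0, lat_max=75.0, lon_min=230.0, lon_max=310.0)

print(to_degrees_east(-130.0))      # 230.0
print(in_box(39.7, -105.0, **box))  # True  (a point near Denver)
print(in_box(48.9, 2.35, **box))    # False (a point near Paris)
```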
Complete Example
Four CIRA MLWP models compared on a single demo heat wave case, with a
lead-time pivot table printed to stdout. The run is serial as written;
pass a parallel_config to run_evaluation() to parallelize it.
import datetime
import extremeweatherbench as ewb
from extremeweatherbench import inputs
from extremeweatherbench.cases import IndividualCase
from extremeweatherbench.regions import BoundingBoxRegion
demo_case = IndividualCase(
    case_id_number=9010,
    title="2020 SW US Heat Wave (demo)",
    start_date=datetime.datetime(2020, 8, 15),
    end_date=datetime.datetime(2020, 8, 18),
    location=BoundingBoxRegion.create_region(
        latitude_min=34.0,
        latitude_max=40.0,
        longitude_min=242.0,
        longitude_max=248.0,
    ),
    event_type="heat_wave",
)
cases = [demo_case]
model_names = [
    "FOUR_v200_IFS",
    "FOUR_v200_GFS",
    "GRAP_v100_IFS",
    "AURO_v100_IFS",
]
target = ewb.ERA5(
    variables=["surface_air_temperature"]
)
metrics_list = [
    ewb.metrics.MeanAbsoluteError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
    ewb.metrics.MaximumMeanAbsoluteError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
    ewb.metrics.RootMeanSquaredError(
        forecast_variable="surface_air_temperature",
        target_variable="surface_air_temperature",
    ),
]
eval_objects = [
    ewb.EvaluationObject(
        event_type="heat_wave",
        metric_list=metrics_list,
        target=target,
        forecast=inputs.get_cira_icechunk(
            model_name=name
        ),
    )
    for name in model_names
]
runner = ewb.evaluation(
    case_metadata=cases,
    evaluation_objects=eval_objects,
)
outputs = runner.run_evaluation()
mae = outputs[outputs["metric"] == "MeanAbsoluteError"]
pivot = (
    mae.groupby(["forecast_source", "lead_time"])["value"]
    .mean()
    .unstack("forecast_source")
)
print(pivot)