# Source code for pyhanami.diags.Diagnostics

import copy
import cmocean
import warnings
# Report every occurrence of a warning instead of only the first one per call site.
# NOTE: this changes the process-wide warnings filter as a side effect of importing this module.
warnings.simplefilter(action="always")

import numpy as np
import xarray as xr
import concurrent.futures
import multiprocessing as mp

from tqdm import tqdm
from scipy.stats import ttest_ind
from collections.abc import Iterable
from matplotlib.colors import LinearSegmentedColormap

from pyhanami.config import config_params
from pyhanami.diags.Simulations import SimulationData
from pyhanami.diags.Observations import ObservationData
from pyhanami.utils import data_general, plot, statistics


class DataDiagnostics:
    """
    Perform diagnostic comparisons between climate simulation ensembles.

    This class provides functionality for computing and visualizing differences in climate
    variables between simulation ensembles, and biases of an ensemble against observations.
    It includes methods for computing time series, absolute differences, effect sizes,
    significant differences and biases at grid point level.

    Parameters
    ----------
    datasets : SimulationData or Iterable[SimulationData], optional
        Ensemble or list of ensembles containing simulation data and metadata.

    Attributes
    ----------
    datasets : list[SimulationData]
        List of ensembles containing simulation data and metadata.
    variables : dict
        Configuration dictionary mapping variable names to display metadata.
    max_workers_grid : int
        Number of parallel workers used for grid-level computations.
    """

    def __init__(self, datasets=None):
        # Normalize the input into a new list owned by this object
        self.datasets = [] if datasets is None else self._as_dataset_list(datasets)
        # Load config parameters once
        self.variables = data_general.load_yaml_file(config_params.VARIABLES_PATH)
        self.max_workers_grid = config_params.MAX_WORKERS_GRID

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_dataset_list(datasets):
        """
        Normalize a SimulationData object or an iterable of them into a new list.

        The iterable is materialized before it is validated so that one-shot iterables
        (e.g. generators) are not exhausted by the type check.

        Raises
        ------
        TypeError
            If the input is neither a SimulationData nor an iterable of SimulationData objects.
        """
        if isinstance(datasets, SimulationData):
            return [datasets]
        if isinstance(datasets, Iterable) and not isinstance(datasets, (str, bytes)):
            items = list(datasets)
            if all(isinstance(ds, SimulationData) for ds in items):
                return items
        raise TypeError("Input must be a SimulationData object or an iterable of SimulationData objects.")

    @staticmethod
    def _validate_time_match(name_1, data_1, name_2, data_2):
        """Raise ValueError (reporting both time spans) if the two datasets' time coordinates differ."""
        if not data_1.time.equals(data_2.time):
            raise ValueError(
                f"Time coordinates of the two datasets do not match:\n"
                f" {name_1} has time from {str(data_1.time.min().values)[:19]} to {str(data_1.time.max().values)[:19]}\n"
                f" {name_2} has time from {str(data_2.time.min().values)[:19]} to {str(data_2.time.max().values)[:19]}"
            )

    @staticmethod
    def _validate_sim_pair(var_name, data_plot, require_realization=False):
        """
        Check that `data_plot` is a list of exactly two SimulationData objects containing `var_name`.

        If `require_realization` is True, each dataset must also carry a 'realization' coordinate.
        Checks run per dataset in order (variable first, then realization).
        """
        if not isinstance(data_plot, list) or len(data_plot) != 2 \
                or not all(isinstance(ds, SimulationData) for ds in data_plot):
            raise TypeError("'data_plot' must be a non-empty list with two SimulationData instances.")
        for dataset in data_plot:
            if var_name not in dataset.data.data_vars:
                raise ValueError(f"Variable '{var_name}' not found in the simulated dataset '{dataset.name}'. "
                                 f"Available variables: {list(dataset.data.data_vars.keys())}")
            if require_realization and 'realization' not in dataset.data.coords:
                raise ValueError(f"Dataset '{dataset.name}' must contain a 'realization' coordinate for ensemble computations.")

    def _select_pair(self, var_name, data_names, require_realization=False):
        """
        Resolve the two datasets used by the spatial comparison plots.

        Parameters
        ----------
        var_name : str
            Variable that must be present in both datasets.
        data_names : list[str] or None
            Names of the two datasets; if None, the first two stored datasets are used.
        require_realization : bool
            Also require a 'realization' coordinate in each dataset.

        Returns
        -------
        (data_plot, data_names) : tuple[list[SimulationData], list[str]]
        """
        if data_names is None:
            if len(self.datasets) < 2:
                raise ValueError("At least two datasets are required for spatial plots. Please add more datasets.")
            data_plot = [self.datasets[0], self.datasets[1]]
            data_names = [ds.name for ds in data_plot]
        elif isinstance(data_names, list) and len(data_names) == 2 \
                and all(isinstance(name, str) for name in data_names):
            existing_names = [ds.name for ds in self.datasets]
            missing_names = [name for name in data_names if name not in existing_names]
            if missing_names:
                raise ValueError(f"The following dataset names were not found in the DataDiagnostics object: {missing_names}.")
            data_plot = [next(ds for ds in self.datasets if ds.name == name) for name in data_names]
        else:
            raise TypeError("'data_names' must be a list of two strings representing dataset names.")
        for dataset in data_plot:
            if var_name not in dataset.data.data_vars:
                raise ValueError(f"Variable '{var_name}' not found in the simulated dataset {dataset.name}. "
                                 f"Available variables: {list(dataset.data.data_vars.keys())}")
            if require_realization and 'realization' not in dataset.data.coords:
                raise ValueError(f"Dataset '{dataset.name}' must contain a 'realization' coordinate for ensemble computations.")
        return data_plot, data_names

    @staticmethod
    def _filter_years(data_plot, start_year, end_year, process_name):
        """
        Validate the year range against every dataset and return time-sliced copies.

        Returns
        -------
        (data_plot_filtered, start_year, end_year)
            Copies of the datasets restricted to [start_year, end_year], plus the validated years.
        """
        for dataset in data_plot:
            start_year, end_year = data_general.validate_year_range(dataset, start_year, end_year,
                                                                     process_name=process_name)
        time_slice = slice(str(start_year), str(end_year))
        data_plot_filtered = []
        for dataset in data_plot:
            # A shallow copy suffices: only `data` is replaced, so deep-copying the full dataset
            # (which is discarded right away) would only cost time and memory.
            dataset_copy = copy.copy(dataset)
            dataset_copy.data = dataset.data.sel(time=time_slice)
            data_plot_filtered.append(dataset_copy)
        return data_plot_filtered, start_year, end_year

    @staticmethod
    def _symmetric_levels(values, n_levels=13):
        """
        Return `n_levels` levels evenly spaced in [-limit, limit], where limit = max(|value|).

        Non-finite entries (e.g. NaN at masked grid points) are ignored. If no finite non-zero
        value exists, a unit limit is used so the levels stay strictly increasing (an all-zero
        field would otherwise yield identical levels).
        """
        abs_values = np.abs(np.asarray(values, dtype=float))
        finite = abs_values[np.isfinite(abs_values)]
        limit = finite.max() if finite.size else 0.0
        if limit == 0:
            limit = 1.0
        return np.linspace(-limit, limit, n_levels)

    @staticmethod
    def _realization_by_point(data, var_name):
        """
        Return the time mean of `var_name` as an array of shape (n_realizations, n_lat * n_lon).

        Dimensions are transposed explicitly to ('realization', 'lat', 'lon') before flattening, so
        columns follow the lat-major grid order whatever the stored dimension order. `.values` is
        used so both dask- and NumPy-backed data work (a NumPy array has no `.compute()`).
        """
        averaged = data[var_name].mean('time').transpose('realization', 'lat', 'lon')
        return averaged.values.reshape((data.sizes['realization'], -1))

    # ------------------------------------------------------------------
    # Computations
    # ------------------------------------------------------------------

    def _compute_time_series(self, var_name, data_plot, time_freq='1YS'):
        """
        Compute area-weighted mean time series for the given datasets and variable.

        Each dataset is averaged over 'lat'/'lon' using weights from `statistics.area_weights`,
        then resampled in time with the selected frequency.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_plot : list[SimulationData and/or ObservationData]
            List of ensembles/observations to compute time series for.
        time_freq : str
            Resampling frequency for averaging (default: '1YS').

        Returns
        -------
        time_series : list[xr.DataArray]
            List of mean time series, in the same order as `data_plot`.

        Raises
        ------
        TypeError
            If `data_plot` is not a non-empty list of SimulationData/ObservationData objects.
        ValueError
            If `var_name` is missing from any dataset.
        """
        # Validate inputs
        if not isinstance(data_plot, list) or len(data_plot) == 0 \
                or not all(isinstance(ds, (SimulationData, ObservationData)) for ds in data_plot):
            raise TypeError("'data_plot' must be a non-empty list of SimulationData and/or ObservationData instances.")
        for dataset in data_plot:
            if var_name not in dataset.data.data_vars:
                raise ValueError(f"Variable '{var_name}' not found in the simulated dataset '{dataset.name}'. "
                                 f"Available variables: {list(dataset.data.data_vars.keys())}")
        # Compute area-weighted mean time series for each dataset
        time_series = []
        for dataset in data_plot:
            data_var = dataset.data[var_name]
            weights = statistics.area_weights(data_var)
            data_area_mean = data_var.weighted(weights).mean(['lat', 'lon'])
            time_series.append(data_area_mean.resample(time=time_freq).mean().compute())
        return time_series

    def _compute_abs_diff(self, var_name, data_plot):
        """
        Compute the absolute average difference between two simulation ensembles at grid point level.

        Both datasets are averaged over 'time' (and over 'realization' when present) and the
        second mean is subtracted from the first.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_plot : list[SimulationData]
            List of two simulation ensembles to compute the absolute difference for.

        Returns
        -------
        data_diff : xr.DataArray
            Difference (first - second) between the two ensembles.

        Raises
        ------
        TypeError
            If `data_plot` is not a list of two SimulationData objects.
        ValueError
            If the variable is missing, or the time coordinates or averaged shapes differ.
        """
        self._validate_sim_pair(var_name, data_plot)
        # Validate time coordinates match
        data_sim_1 = data_plot[0].data.persist()
        data_sim_2 = data_plot[1].data.persist()
        self._validate_time_match(data_plot[0].name, data_sim_1, data_plot[1].name, data_sim_2)
        # Compute mean absolute difference
        data_sim_mean_1 = data_sim_1[var_name].mean(['time'])
        data_sim_mean_2 = data_sim_2[var_name].mean(['time'])
        if 'realization' in data_sim_1.coords:
            data_sim_mean_1 = data_sim_mean_1.mean(['realization'])
        if 'realization' in data_sim_2.coords:
            data_sim_mean_2 = data_sim_mean_2.mean(['realization'])
        if data_sim_mean_1.shape != data_sim_mean_2.shape:
            raise ValueError(f"Averaged data shapes of the two datasets do not match ({data_sim_mean_1.shape} vs {data_sim_mean_2.shape}).")
        data_diff = (data_sim_mean_1 - data_sim_mean_2).compute()
        if var_name in ['siconc', 'sos', 'tos']:
            # Values which are exactly 0 are painted in white, not with the corresponding colorbar color for 0
            data_diff = xr.where((np.isnan(data_diff)) | (data_diff == 0), 10**-6, data_diff)
        print(f"Computed absolute difference for variable '{var_name}' between '{data_plot[0].name}' and '{data_plot[1].name}'.", flush=True)
        return data_diff

    def _compute_eff_size_ens(self, var_name, data_plot):
        """
        Compute the average effect size (Cohen's d) between two ensembles, in parallel, per grid point.

        For every (lat, lon) point, the time-mean values of all realizations of each ensemble are
        passed to `statistics.cp_effect_size_bootstrap` in a process pool.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_plot : list[SimulationData]
            List of two simulation ensembles to compute the effect size for.

        Returns
        -------
        data_effect_size : xr.DataArray
            Effect size between the two ensembles on the (lat, lon) grid.

        Raises
        ------
        TypeError
            If `data_plot` is not a list of two SimulationData objects.
        ValueError
            If the variable or 'realization' coordinate is missing, or time coordinates/shapes differ.
        """
        self._validate_sim_pair(var_name, data_plot, require_realization=True)
        # Validate time coordinates match
        data_sim_1 = data_plot[0].data.persist()
        data_sim_2 = data_plot[1].data.persist()
        self._validate_time_match(data_plot[0].name, data_sim_1, data_plot[1].name, data_sim_2)
        # Prepare data
        data_sim_flat_1 = data_sim_1[var_name].mean('time').stack(ngrid=['lat', 'lon']).compute()
        data_sim_flat_2 = data_sim_2[var_name].mean('time').stack(ngrid=['lat', 'lon']).compute()
        if data_sim_flat_1.shape != data_sim_flat_2.shape:
            raise ValueError(f"Averaged data shapes of the two datasets do not match ({data_sim_flat_1.shape} vs {data_sim_flat_2.shape}).")
        # One (sample_1, sample_2) pair per grid point. Positional rows along the stacked 'ngrid'
        # dimension keep the lat-major order used by the final reshape and avoid a slow
        # per-point label lookup on the MultiIndex.
        values_1 = data_sim_flat_1.transpose('ngrid', ...).values
        values_2 = data_sim_flat_2.transpose('ngrid', ...).values
        tasks = list(zip(values_1, values_2))
        effect_size = np.empty(len(tasks))
        # Compute effect sizes in parallel
        with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers_grid,
                                                    mp_context=mp.get_context("spawn")) as executor:
            results = executor.map(statistics.cp_effect_size_bootstrap, tasks)
            for idx, value in enumerate(tqdm(results, total=len(tasks),
                                             desc=f"Computing effect sizes for variable '{var_name}'",
                                             unit=" grid points")):
                effect_size[idx] = value
        # Convert to xarray.DataArray
        effect_size_reshaped = effect_size.reshape(data_sim_1.sizes['lat'], data_sim_1.sizes['lon'])
        data_effect_size = xr.DataArray(
            effect_size_reshaped,
            dims=["lat", "lon"],
            coords={"lat": data_sim_1["lat"], "lon": data_sim_1["lon"]},
            name=var_name,
        )
        print(f"Computed effect size for variable '{var_name}' between '{data_plot[0].name}' and '{data_plot[1].name}'.", flush=True)
        return data_effect_size

    def _compute_significant_diff(self, var_name, data_plot, alpha=0.05, stat=ttest_ind):
        """
        Compute significant differences between two ensembles, in parallel, per grid point.

        For every (lat, lon) point, the time-mean values of all realizations of each ensemble are
        passed, together with `alpha` and `stat`, to `statistics.significant_diff` in a process pool.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_plot : list[SimulationData]
            List of two simulation ensembles to compute the significant differences for.
        alpha : float
            Significance level for the statistical test (default: 0.05).
        stat : Callable
            Statistical test function to use (default: ttest_ind).

        Returns
        -------
        significant : np.ndarray
            Float array of shape (n_lat, n_lon) holding the per-point result of
            `statistics.significant_diff` (presumably 1 where the difference is significant).

        Raises
        ------
        TypeError
            If `data_plot` is invalid, `alpha` is not numeric or `stat` is not callable.
        ValueError
            If `alpha` is outside [0, 1], or variables/coordinates/shapes do not match.
        """
        # Validate input
        self._validate_sim_pair(var_name, data_plot, require_realization=True)
        if not isinstance(alpha, (int, float)):
            raise TypeError("The significance level 'alpha' must be numeric.")
        if not (0 <= alpha <= 1):
            raise ValueError("'alpha' must be between 0 and 1.")
        if not callable(stat):
            raise TypeError("'stat' must be callable.")
        data_sim_1 = data_plot[0].data.persist()
        data_sim_2 = data_plot[1].data.persist()
        self._validate_time_match(data_plot[0].name, data_sim_1, data_plot[1].name, data_sim_2)
        # Prepare data: one row per realization, one column per grid point
        n_lats = data_sim_1.sizes['lat']
        n_lons = data_sim_1.sizes['lon']
        n_points = n_lats * n_lons
        d1 = self._realization_by_point(data_sim_1, var_name)
        d2 = self._realization_by_point(data_sim_2, var_name)
        if d1.shape != d2.shape:
            raise ValueError(f"Averaged data shapes of the two datasets do not match ({d1.shape} vs {d2.shape}).")
        # Compute significant differences in parallel
        tasks = [(d1[:, n], d2[:, n], alpha, stat) for n in range(n_points)]
        significant = np.zeros(n_points)
        with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers_grid,
                                                    mp_context=mp.get_context("spawn")) as executor:
            for idx, value in enumerate(executor.map(statistics.significant_diff, tasks)):
                significant[idx] = value
        significant_reshaped = significant.reshape((n_lats, n_lons))
        print(f"Computed significant difference for variable '{var_name}' between '{data_plot[0].name}' and '{data_plot[1].name}'.", flush=True)
        return significant_reshaped

    def _compute_bias(self, var_name, data_plot):
        """
        Compute the average bias between a simulation ensemble and observations per grid point.

        Both datasets are averaged over 'time' (the simulation also over 'realization' when
        present) and the observational mean is subtracted from the simulated mean.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_plot : list[SimulationData and ObservationData]
            List containing the simulation ensemble and the observational dataset, in that order.

        Returns
        -------
        data_bias : xr.DataArray
            Bias (simulation - observations).

        Raises
        ------
        TypeError
            If `data_plot` is not [SimulationData, ObservationData].
        ValueError
            If the variable is missing, or the time coordinates or averaged shapes differ.
        """
        # Validate inputs
        if not isinstance(data_plot, list) or len(data_plot) != 2 \
                or not isinstance(data_plot[0], SimulationData) \
                or not isinstance(data_plot[1], ObservationData):
            raise TypeError("'data_plot' must be a list with a SimulationData instance and an ObservationData instance.")
        for dataset in data_plot:
            if var_name not in dataset.data.data_vars:
                raise ValueError(f"Variable '{var_name}' not found in the dataset '{dataset.name}'. "
                                 f"Available variables: {list(dataset.data.data_vars.keys())}")
        # Validate time coordinates match
        data_sim = data_plot[0].data.persist()
        data_obs = data_plot[1].data.persist()
        self._validate_time_match(data_plot[0].name, data_sim, data_plot[1].name, data_obs)
        # Compute mean bias
        data_sim_mean = data_sim[var_name].mean(['time'])
        data_obs_mean = data_obs[var_name].mean(['time'])
        if 'realization' in data_sim.coords:
            data_sim_mean = data_sim_mean.mean(['realization'])
        if data_sim_mean.shape != data_obs_mean.shape:
            raise ValueError(f"Averaged data shapes of the two datasets do not match ({data_sim_mean.shape} vs {data_obs_mean.shape}).")
        return (data_sim_mean - data_obs_mean).compute()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def add_datasets(self, datasets):
        """
        Add new datasets to the DataDiagnostics object.

        Datasets whose name already exists in the object are skipped with a warning.

        Parameters
        ----------
        datasets : SimulationData or Iterable[SimulationData]
            Ensemble or list of ensembles containing simulation data and metadata to add.

        Raises
        ------
        TypeError
            If the input is neither a SimulationData nor an iterable of SimulationData objects.
        """
        for dataset in self._as_dataset_list(datasets):
            # Check for duplicate datasets (including ones added earlier in this same call)
            if any(ds.name == dataset.name for ds in self.datasets):
                warnings.warn(f"\nDataset with name '{dataset.name}' already exists in the DataDiagnostics object. Skipping addition.")
            else:
                self.datasets.append(dataset)

    def time_series_plot(self, var_name, data_names=None, output_path=None, obs=False, obs_paths=None,
                         obs_names=None, time_freq='annual', start_year=None, end_year=None, plot_ens=False):
        """
        Generate a time series plot for the given datasets and variable for the selected period and frequency.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_names : str or list[str], optional
            Name or list of names of simulation ensembles to plot. If None, all datasets in the
            DataDiagnostics object are used.
        output_path : str, optional
            Path to save the time series plot.
        obs : bool
            If True, also plot observational data (default: False).
        obs_paths : str or list[str], optional
            Path(s) to the observations database(s). Required when `obs` is True.
        obs_names : str or list[str], optional
            Name(s) of the observational dataset(s). Required when `obs` is True.
        time_freq : str
            Resampling frequency: 'annual', 'monthly' or 'daily' (default: 'annual').
        start_year : int
            Start year to plot.
        end_year : int
            End year to plot.
        plot_ens : bool
            Whether to plot individual ensemble members trajectories (default: False).
        """
        # Resolve the datasets. Every branch builds fresh lists because observation datasets and
        # names are appended below; this must not leak into `self.datasets` or the caller's list.
        if data_names is None:
            data_plot = list(self.datasets)
            data_names = [ds.name for ds in data_plot]
        elif isinstance(data_names, str):
            data_plot = [ds for ds in self.datasets if ds.name == data_names]
            if not data_plot:
                raise ValueError(f"Dataset with name '{data_names}' not found in the DataDiagnostics object.")
            data_names = [data_names]
        elif isinstance(data_names, list) and all(isinstance(name, str) for name in data_names):
            existing_names = [ds.name for ds in self.datasets]
            missing_names = [name for name in data_names if name not in existing_names]
            if missing_names:
                raise ValueError(f"The following dataset names were not found in the DataDiagnostics object: {missing_names}.")
            data_plot = [next(ds for ds in self.datasets if ds.name == name) for name in data_names]
            data_names = list(data_names)
        else:
            raise TypeError("'data_names' must be a string or a list of strings representing dataset names.")
        for dataset in data_plot:
            if var_name not in dataset.data.data_vars:
                raise ValueError(f"Variable '{var_name}' not found in the simulated dataset '{dataset.name}'. "
                                 f"Available variables: {list(dataset.data.data_vars.keys())}")
        if obs:
            if obs_paths is None or obs_names is None:
                raise NotImplementedError('Automatic selection of observations is not implemented yet. '
                                          'Please provide at least one path and one name for the observations database.')
            if isinstance(obs_paths, str):
                if not isinstance(obs_names, str):
                    raise TypeError("'obs_names' must be a string if 'obs_paths' is a string.")
                obs_paths = [obs_paths]
                obs_names = [obs_names]
            elif isinstance(obs_names, str):
                raise TypeError("'obs_paths' must be a string if 'obs_names' is a string.")
            if len(obs_paths) != len(obs_names):
                raise ValueError("'obs_paths' and 'obs_names' must have the same length.")
            data_obs = [
                ObservationData(path, data_plot[0].data[[var_name]], name)
                for path, name in zip(obs_paths, obs_names)
            ]
            data_plot.extend(data_obs)
            data_names.extend(obs_names)
        # Validate year range and restrict every dataset to it
        data_plot_filtered, start_year, end_year = self._filter_years(data_plot, start_year, end_year,
                                                                      'time series')
        if time_freq == 'annual':
            time_freq_unit = '1YS'
        elif time_freq == 'monthly':
            time_freq_unit = '1MS'
        elif time_freq == 'daily':
            time_freq_unit = '1D'
        else:
            raise ValueError("Incorrect time frequency, supported values are 'annual', 'monthly' and 'daily'")
        # Compute and plot time series
        time_series = self._compute_time_series(var_name, data_plot_filtered, time_freq_unit)
        figure, _ = plot.plot_time_series(time_series,
                                          title=f"{time_freq.capitalize()} mean time series of {self.variables[var_name]['long_name']}",
                                          y_label=f"{var_name} ({self.variables[var_name]['units']})",
                                          labels=data_names, time_freq=time_freq, start_year=start_year,
                                          end_year=end_year, plot_ens=plot_ens)
        # Save plot to path if given
        data_names_str = "_".join([name.replace(' ', '-') for name in data_names])
        ens_suffix = "_all_members" if plot_ens else ""
        plot.save_or_show_plot(figure, output_path,
                               plot_filename=f"{time_freq}_time_series_{var_name}_{data_names_str}_{start_year}-{end_year}{ens_suffix}",
                               plot_name=f"{time_freq.capitalize()} mean time series plot")

    def abs_diff_plot(self, var_name, data_names=None, output_path=None, start_year=None, end_year=None, clon=0):
        """
        Generate an absolute difference plot for the given datasets and variable.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_names : list[str], optional
            List of names of two simulation ensembles to compare. If None, the first two datasets
            in the diagnostics object are used.
        output_path : str, optional
            Path to save the spatial plots.
        start_year : int
            Start year to plot.
        end_year : int
            End year to plot.
        clon : int
            Central longitude for the spatial map (default: 0).
        """
        data_plot, data_names = self._select_pair(var_name, data_names)
        # Validate year range and restrict both datasets to it
        data_plot_filtered, start_year, end_year = self._filter_years(data_plot, start_year, end_year,
                                                                      'absolute difference')
        # Compute and plot absolute difference
        abs_diff = self._compute_abs_diff(var_name, data_plot_filtered)
        levels = self._symmetric_levels(abs_diff.values)
        year_range = f"{start_year}-{end_year}"
        figure, _ = plot.plot_spatial(abs_diff, clon=clon,
                                      title=f"Difference in {self.variables[var_name]['long_name']} for {year_range} ({data_names[0]} - {data_names[1]})",
                                      cb_label=f"difference in {var_name} ({self.variables[var_name]['units']})",
                                      cmap=cmocean.cm.thermal, levels=levels)
        # Save plot to path if given
        data_names_str = "_".join([name.replace(' ', '-') for name in data_names])
        plot.save_or_show_plot(figure, output_path,
                               plot_filename=f"abs_diff_{var_name}_{data_names_str}_{year_range}_clon_{clon}",
                               plot_name="Absolute difference plot")

    def eff_size_plot(self, var_name, data_names=None, output_path=None, start_year=None, end_year=None,
                      clon=0, alpha=0.05, stat=ttest_ind):
        """
        Generate an effect size plot for the given datasets and variable, marking significant grid points.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_names : list[str], optional
            List of names of two simulation ensembles to compare. If None, the first two datasets
            in the diagnostics object are used.
        output_path : str, optional
            Path to save the spatial plots.
        start_year : int
            Start year to plot.
        end_year : int
            End year to plot.
        clon : int
            Central longitude for the spatial map (default: 0).
        alpha : float
            Significance level for the statistical test (default: 0.05).
        stat : Callable
            Statistical test function to use for significance testing (default: ttest_ind).
        """
        data_plot, data_names = self._select_pair(var_name, data_names, require_realization=True)
        if not isinstance(alpha, (int, float)) or not (0 <= alpha <= 1):
            raise TypeError("The significance level 'alpha' must be a numeric value between 0 and 1.")
        if not callable(stat):
            raise TypeError("'stat' must be callable.")
        # Validate year range and restrict both datasets to it
        data_plot_filtered, start_year, end_year = self._filter_years(data_plot, start_year, end_year,
                                                                      'effect size')
        # Compute and plot effect size with significant differences
        eff_size = self._compute_eff_size_ens(var_name, data_plot_filtered)
        significant = self._compute_significant_diff(var_name, data_plot_filtered, alpha, stat)
        levels = [-2, -1.2, -0.8, -0.5, -0.2, -0.01, 0.01, 0.2, 0.5, 0.8, 1.2, 2.0]  # Use Cohen's limits for effect size
        year_range = f"{start_year}-{end_year}"
        figure, _ = plot.plot_spatial(eff_size, clon=clon,
                                      title=f"Effect size ($d$) for {self.variables[var_name]['long_name']} for {year_range} ({data_names[0]} - {data_names[1]})",
                                      cb_label=f"$d$ for {var_name} (-)", cmap=cmocean.cm.diff, levels=levels,
                                      significant=significant)
        # Save plot to path if given
        data_names_str = "_".join([name.replace(' ', '-') for name in data_names])
        plot.save_or_show_plot(figure, output_path,
                               plot_filename=f"eff_size_{var_name}_{data_names_str}_{year_range}_clon_{clon}",
                               plot_name="Effect size plot")

    def bias_plot(self, var_name, data_name=None, output_path=None, obs_path=None, obs_name=None,
                  start_year=None, end_year=None, clon=0):
        """
        Generate a bias plot for the given dataset and variable against observations.

        Parameters
        ----------
        var_name : str
            Climate variable name.
        data_name : str, optional
            Name of the simulation ensemble to plot. If None, the first dataset in the
            DataDiagnostics object is used.
        output_path : str, optional
            Path to save the spatial plot.
        obs_path : str
            Path to the observations database (required).
        obs_name : str
            Name of the observational dataset (required).
        start_year : int
            Start year to plot.
        end_year : int
            End year to plot.
        clon : int
            Central longitude for the spatial map (default: 0).
        """
        # Validate inputs
        if data_name is None:
            if len(self.datasets) < 1:
                raise ValueError("At least one dataset is required for bias plots. Please add a dataset.")
            data_plot = [self.datasets[0]]
            data_name = data_plot[0].name
        elif isinstance(data_name, str):
            data_plot = [ds for ds in self.datasets if ds.name == data_name]
            if not data_plot:
                raise ValueError(f"Dataset with name '{data_name}' not found in the DataDiagnostics object.")
            if len(data_plot) > 1:
                raise ValueError(f"Multiple datasets with name '{data_name}' found in the DataDiagnostics object.")
        else:
            raise TypeError("'data_name' must be a string representing a dataset name.")
        if var_name not in data_plot[0].data.data_vars:
            raise ValueError(f"Variable '{var_name}' not found in the simulated dataset '{data_plot[0].name}'. "
                             f"Available variables: {list(data_plot[0].data.data_vars.keys())}")
        if obs_path is None or obs_name is None:
            raise ValueError('Automatic selection of observations is not implemented yet. '
                             'Please provide a path and a name for the observations database.')
        if not isinstance(obs_path, str) or not isinstance(obs_name, str):
            raise TypeError("'obs_path' and 'obs_name' must be strings representing the observations database path and name, respectively.")
        data_plot.append(ObservationData(obs_path, data_plot[0].data[[var_name]], obs_name))
        data_names = [data_name, obs_name]
        # Validate year range and restrict both datasets to it
        data_plot_filtered, start_year, end_year = self._filter_years(data_plot, start_year, end_year, 'bias')
        # Compute and plot bias; NaN-aware levels because observations may contain missing values
        bias = self._compute_bias(var_name, data_plot_filtered)
        levels = self._symmetric_levels(bias.values)
        cmap = LinearSegmentedColormap.from_list("RedGreen", ['tab:red', 'white', 'tab:green'])
        year_range = f"{start_year}-{end_year}"
        figure, _ = plot.plot_spatial(bias, clon=clon,
                                      title=f"Bias in {self.variables[var_name]['long_name']} for {year_range} ({data_names[0]} - {data_names[1]})",
                                      cb_label=f"bias in {var_name} ({self.variables[var_name]['units']})",
                                      cmap=cmap, levels=levels)
        # Save plot to path if given
        data_names_str = "_".join([name.replace(' ', '-') for name in data_names])
        plot.save_or_show_plot(figure, output_path,
                               plot_filename=f"bias_{var_name}_{data_names_str}_{year_range}_clon_{clon}",
                               plot_name="Bias plot")

    # NOTE: a former combined `spatial_plots` method (absolute difference + effect size in one call)
    # was split into `abs_diff_plot` and `eff_size_plot`. Its commented-out copy was removed because it
    # referenced `Path` and `plt`, which this module does not import; see version control history.