import argparse import dataclasses import functools import json import logging import math import os import re import time from datetime import datetime, timedelta from typing import Optional import cartopy.crs as ccrs import faiss import jax import matplotlib.pyplot as plt import matplotlib import numpy as np, sys import cartopy.crs as ccrs import faiss import haiku as hk import matplotlib.animation as animation import jax import matplotlib.pyplot as plt import numpy as np, sys import tqdm import xarray as xr from graphcast import ( autoregressive, casting, checkpoint, data_utils, graphcast, normalization, rollout, xarray_jax, xarray_tree, ) import matplotlib.pyplot as plt import xarray as xr import cartopy.crs as ccrs import cartopy.feature as cfeature from graphcast.graphcast import ModelConfig, TaskConfig import requests import tqdm import xarray # Define constants GLOBAL_PATH = "/scratch/gilbreth/gupt1075/graphcast/" # Create logs directory if it doesn't exist os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True) # Set up logging logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, filename=f"{GLOBAL_PATH}/logs/plot_g_at_surface_{datetime.now().strftime('%B_%d_')}_modular_geopotential_level_1000__13levels__2.log", ) def parse_file_parts(file_name): """ Parse the file name to extract relevant parts. :param file_name: The name of the file to parse. :return: A dictionary containing the parsed parts. """ return dict(part.split("-", 1) for part in file_name.split("_")) def select_data( data: xarray.Dataset, variable: str, level: Optional[int] = None, max_steps: Optional[int] = None, ) -> xarray.Dataset: """ Select specific data from the dataset based on the given criteria. :param data: The dataset to select from. :param variable: The variable to select. :param level: The level to select (optional). :param max_steps: The maximum steps to select (optional). :return: The selected dataset. """ data = data[variable] if "batch" in data.dims: data = data.isel(batch=0) if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]: data = data.isel(time=range(0, max_steps)) if level is not None and "level" in data.coords: logging.warning(f" *** level: {level}, data_coords {data.coords} ") data = data.sel(level=level, method="nearest",) return data def scale_data( data: xarray.Dataset, center: Optional[float] = None, robust: bool = False, ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]: """ Scale the data for plotting. :param data: The dataset to scale. :param center: The center value for scaling (optional). :param robust: Whether to use robust scaling (optional). :return: A tuple containing the scaled data, normalization, and colormap. """ vmin = np.nanpercentile(data, (2 if robust else 0)) vmax = np.nanpercentile(data, (98 if robust else 100)) if center is not None: diff = max(vmax - center, center - vmin) vmin = center - diff vmax = center + diff return ( data, matplotlib.colors.Normalize(vmin, vmax), ("RdBu_r" if center is not None else "viridis"), ) def save_animation_as_gif(ani, filename, fps=1): """ Save a Matplotlib animation as a GIF to a local directory. Parameters: ani (matplotlib.animation.FuncAnimation): The animation to save. filename (str): The filename to save the GIF as. fps (int, optional): The frames per second for the GIF. Defaults to 10. """ ani.save(filename, writer='pillow', fps=fps) def plot_data( data: dict[str, xarray.Dataset], variable_name: str, fig_title: str, plot_size: float = 5, robust: bool = False, cols: int = 4, ) -> None: """ Plot the data. :param data: A dictionary of datasets to plot. :param fig_title: The title of the figure. :param plot_size: The size of the plot (optional). :param robust: Whether to use robust scaling (optional). :param cols: The number of columns (optional). """ logging.info(f"Plotting data with title: {fig_title} and type data {type(data)}") for d in data.values(): logging.warning(f" type_d: {type(d)} {len(d)} { type(d[0])} size data[0]: {d[0].dims} {type(d[1])} {type(d[2])} \n ") logging.warning(f"\n *******************************\n") first_data = next(iter(data.values()))[0] logging.info( f"*********** \nPlotting {fig_title} with {type(first_data)} first_data_length: {len(first_data)} \n " ) # first_data = next(iter(data.values())) logging.info(f" *** First_data {type(first_data)} ") max_steps = first_data.sizes.get("time", 1) assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values()) cols = min(cols, len(data)) rows = math.ceil(len(data) / cols) figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows)) figure.suptitle(fig_title, fontsize=16) figure.subplots_adjust(wspace=0, hspace=0) figure.tight_layout() images = [] for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()): ax = figure.add_subplot(rows, cols, i + 1, projection=ccrs.Mercator(central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0, globe=None, latitude_true_scale=0.0),) # Added features to the map ax.add_feature(cfeature.LAND) ax.add_feature(cfeature.OCEAN) ax.add_feature(cfeature.COASTLINE) ax.add_feature(cfeature.BORDERS, linestyle=':') ax.set_xticks([]) ax.set_yticks([]) ax.set_title(title) im = ax.imshow( plot_data.isel(time=0, missing_dims="ignore"), norm=norm, transform=ccrs.Mercator(central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0, globe=None, latitude_true_scale=0.0), origin="lower", cmap=cmap, ) plt.colorbar( mappable=im, ax=ax, orientation="vertical", pad=0.02, aspect=16, shrink=0.75, cmap=cmap, extend=("both" if robust else "neither"), ) images.append(im) def update(frame): if "time" in first_data.dims: td = timedelta( microseconds=first_data["time"][frame].item() / 1000 ) figure.suptitle(f"{fig_title}, {td}", fontsize=16) else: figure.suptitle(fig_title, fontsize=16) for im, (plot_data, norm, cmap) in zip(images, data.values()): im.set_data(plot_data.isel(time=frame, missing_dims="ignore")) ani = matplotlib.animation.FuncAnimation( fig=figure, func=update, frames=max_steps, interval=250 ) gif_filename = f"./output_animation/{variable_name}_{datetime.now().strftime('%B_%d_')}_{fig_title}_6hr_1fps.gif" logging.warning(f" saved_animation: {gif_filename} ") save_animation_as_gif(ani, gif_filename) plt.show() def data_valid_for_model( file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig, ) -> bool: """ Check if the data is valid for the given model configuration. :param file_name: The name of the file to check. :param model_config: The model configuration. :param task_config: The task configuration. :return: Whether the data is valid. """ logging.info(f"Checking data validity for file: {file_name}") file_parts = parse_file_parts(file_name.removesuffix(".nc")) return ( model_config.resolution in (0, float(file_parts["res"])) and len(task_config.pressure_levels) == int(file_parts["levels"]) and ( ( "total_precipitation_6hr" in task_config.input_variables and file_parts["source"] in ("era5", "fake") ) or ( "total_precipitation_6hr" not in task_config.input_variables and file_parts["source"] in ("hres", "fake") ) ) ) def construct_wrapped_graphcast( model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig ) -> graphcast.GraphCast: """ Construct and wrap the GraphCast predictor. :param model_config: The model configuration. :param task_config: The task configuration. :return: The wrapped GraphCast predictor. """ logging.info("Constructing wrapped GraphCast predictor") # Deeper one-step predictor. predictor = graphcast.GraphCast(model_config, task_config) # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to/from float32 to/from BFloat16. predictor = casting.Bfloat16Cast(predictor) # Load normalization data with open("./stats/diffs_stddev_by_level.nc", "rb") as f: diffs_stddev_by_level = xr.load_dataset(f).compute() with open("./stats/mean_by_level.nc", "rb") as f: mean_by_level = xr.load_dataset(f).compute() with open("./stats/stddev_by_level.nc", "rb") as f: stddev_by_level = xr.load_dataset(f).compute() # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from BFloat16 happens after applying normalization to the inputs/targets. predictor = normalization.InputsAndResiduals( predictor, diffs_stddev_by_level=diffs_stddev_by_level, mean_by_level=mean_by_level, stddev_by_level=stddev_by_level, ) # Wraps everything so the one-step model can produce trajectories. predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True) return predictor @hk.transform_with_state def run_forward( model_config, task_config, inputs, targets_template, forcings ): """ Run the forward pass of the model. :param model_config: The model configuration. :param task_config: The task configuration. :param inputs: The input data. :param targets_template: The target template. :param forcings: The forcing data. :return: The output of the forward pass. """ logging.info("Running forward pass") predictor = construct_wrapped_graphcast(model_config, task_config) return predictor(inputs, targets_template=targets_template, forcings=forcings) @hk.transform_with_state def loss_fn( model_config, task_config, inputs, targets, forcings ): """ Compute the loss of the model. :param model_config: The model configuration. :param task_config: The task configuration. :param inputs: The input data. :param targets: The target data. :param forcings: The forcing data. :return: The loss and diagnostics. """ logging.info("Computing loss") predictor = construct_wrapped_graphcast(model_config, task_config) loss, diagnostics = predictor.loss(inputs, targets, forcings) return xarray_tree.map_structure( lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True), (loss, diagnostics), ) def grads_fn( params, state, model_config, task_config, inputs, targets, forcings ): """ Compute the gradients of the model. :param params: The model parameters. :param state: The model state. :param model_config: The model configuration. :param task_config: The task configuration. :param inputs: The input data. :param targets: The target data. :param forcings: The forcing data. :return: The loss, diagnostics, next state, and gradients. """ logging.info("Computing gradients") def _aux(params, state, i, t, f): (loss, diagnostics), next_state = loss_fn.apply( params, state, jax.random.PRNGKey(0), model_config, task_config, i, t, f ) return loss, (diagnostics, next_state) (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)( params, state, inputs, targets, forcings ) return loss, diagnostics, next_state, grads # Jax doesn't seem to like passing configs as args through the jit. Passing it in via partial (instead of capture by closure) forces jax to invalidate the jit cache if you change configs. def with_configs(fn): return functools.partial(fn, model_config=model_config, task_config=task_config) # Always pass params and state, so the usage below are simpler def with_params(fn): return functools.partial(fn, params=params, state=state) # Our models aren't stateful, so the state is always empty, so just return the predictions. This is required by our rollout code, and generally simpler. def drop_state(fn): return lambda **kw: fn(**kw) """ geopotential_at_surface land_sea_mask 2m_temperature mean_sea_level_pressure 10m_v_component_of_wind 10m_u_component_of_wind total_precipitation_6hr toa_incident_solar_radiation temperature geopotential u_component_of_wind v_component_of_wind vertical_velocity specific_humidity """ def plot_geopotential_at_surface(data: xr.Dataset): """ Plot the geopotential_at_surface variable with land and ocean features. :param data: The xarray Dataset containing the geopotential_at_surface variable. """ # Check if the dataset contains 'geopotential_at_surface' if 'geopotential_at_surface' not in data: raise ValueError("Dataset does not contain 'geopotential_at_surface' variable") # Select the 'geopotential_at_surface' variable geopotential = data['geopotential_at_surface'] # Create a figure with a specific projection fig, ax = plt.subplots(1, 1, figsize=(10, 5), subplot_kw={'projection': ccrs.PlateCarree()}) # Plot the geopotential data geopotential.plot(ax=ax, transform=ccrs.PlateCarree(), cmap='viridis', cbar_kwargs={'shrink': 0.5}) # Add geographical features ax.add_feature(cfeature.LAND, zorder=1, edgecolor='black') ax.add_feature(cfeature.OCEAN, zorder=0) ax.add_feature(cfeature.COASTLINE) ax.add_feature(cfeature.BORDERS, linestyle=':') ax.add_feature(cfeature.LAKES, alpha=0.5) # Set title and gridlines ax.set_title('Geopotential at Surface with Land and Ocean Features') ax.gridlines(draw_labels=True) plt.show() def plot_all_variables(data: xr.Dataset, plot_size: float = 5): """ Plot all input variables from an xarray Dataset. :param data: The xarray Dataset containing the variables. :param plot_size: The size of each subplot (optional). """ variable_names = list(data.data_vars) num_vars = len(variable_names) cols = min(4, num_vars) rows = (num_vars + cols - 1) // cols fig, axes = plt.subplots(rows, cols, figsize=(plot_size * cols, plot_size * rows), subplot_kw={'projection': ccrs.PlateCarree()}) axes = axes.flatten() for i, var_name in enumerate(variable_names): if var_name in data: data[var_name].isel(time=0).plot(ax=axes[i], transform=ccrs.PlateCarree()) axes[i].set_title(var_name) # Hide any remaining empty subplots for j in range(i + 1, len(axes)): fig.delaxes(axes[j]) plt.tight_layout() plt.show() def safe_select_data(data: xr.Dataset, time_index: int): # Check if 'time' is a dimension if 'time' in data.dims: max_time_index = data.sizes['time'] - 1 if time_index > max_time_index: print(f"Warning: time_index {time_index} is out of bounds. Using {max_time_index} instead.") time_index = max_time_index # Safely select data selected_data = data.isel(time=time_index) return selected_data else: raise ValueError("The dataset does not have a 'time' dimension.") def plot_geopotential_at_surface(data: xr.Dataset): """Plot the geopotential_at_surface variable with land and ocean features.""" # Check if the dataset contains 'geopotential_at_surface' if 'geopotential_at_surface' not in data: raise ValueError("Dataset does not contain 'geopotential_at_surface' variable") # Select the 'geopotential_at_surface' variable geopotential = data['geopotential_at_surface'] logging.warning(f" geopotential_dimensions{ type(geopotential)} shape: {geopotential.shape} ") # Create a figure with a specific projection fig, ax = plt.subplots(1, 1, figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()}) # Plot the geopotential data geopotential.plot(ax=ax, transform=ccrs.PlateCarree(), cmap='viridis', cbar_kwargs={'shrink': 0.5}) # Add geographical features ax.add_feature(cfeature.LAND, zorder=1, edgecolor='black') ax.add_feature(cfeature.OCEAN, zorder=0) ax.add_feature(cfeature.COASTLINE) ax.add_feature(cfeature.BORDERS, linestyle=':') ax.add_feature(cfeature.LAKES, alpha=0.5) # Set title and gridlines ax.set_title('Geopotential at Surface with Land and Ocean Features') ax.gridlines(draw_labels=True) plt.show() # # Example usage # dataset = xr.open_dataset('path_to_your_dataset.nc') # plot_geopotential_at_surface(dataset) def select_data(data: xr.Dataset, variable: str, level: Optional[int] = None, time_step: Optional[int] = None) -> xr.DataArray: """ Select specific data from the dataset based on the given criteria. :param data: The dataset to select from. :param variable: The variable to select. :param level: The level to select (optional). :param time_step: The time step to select (optional). :return: The selected data array. """ selected_data = data[variable] if "batch" in selected_data.dims: selected_data = selected_data.isel(batch=0) if time_step is not None and "time" in selected_data.dims: selected_data = selected_data.isel(time=time_step) if level is not None and "level" in selected_data.coords: selected_data = selected_data.sel(level=level, method="nearest") return selected_data def plot_variable_on_world_map(data, variable_name, lat_range=None, lon_range=None, level=None, time_index=0, output_dir="/scratch/gilbreth/gupt1075/graphcast/input_data_plots/"): """ Plots a specified variable from an xarray dataset on a world map within given latitude and longitude ranges, and at a specified level and time index if applicable. Saves the plot to the specified output directory. Parameters: - data: xarray.Dataset containing the data variables. - variable_name: str, name of the variable to plot. - lat_range: tuple of (min_lat, max_lat), latitude range to plot. - lon_range: tuple of (min_lon, max_lon), longitude range to plot. - level: int or None, specific level to plot if the variable has a level dimension. - time_index: int, index of the time dimension to plot. - output_dir: str, directory where the plot will be saved. Returns: - None """ # Ensure output directory exists os.makedirs(output_dir, exist_ok=True) # Select the variable var = data[variable_name] # logging.warning(f" min_value {var.min(dim=('lat','lon'))} max values: {var.max(dim=('lat','lon'))} \n\n") # Slice latitude and longitude range if specified if lat_range is not None: var = var.sel(lat=slice(lat_range[0], lat_range[1])) if lon_range is not None: var = var.sel(lon=slice(lon_range[0], lon_range[1])) # Handle levels if applicable if 'level' in var.dims and level is not None: var = var.sel(level=level) # Select the specific time index if 'time' in var.dims: var = var.isel(time=time_index) # Plotting logging.warning(f" min_value {var.min()} max values: {var.max() } \n\n") fig = plt.figure(figsize=(12, 8)) ax = plt.axes(projection=ccrs.PlateCarree()) # Add features to the map ax.add_feature(cfeature.LAND) ax.add_feature(cfeature.OCEAN) ax.add_feature(cfeature.COASTLINE) ax.add_feature(cfeature.BORDERS, linestyle=':') # Plot the data if 'lat' in var.dims and 'lon' in var.dims: var.plot(ax=ax, transform=ccrs.PlateCarree(), cmap = 'RdBu', vmin = -1600, vmax = 56000, cbar_kwargs={'shrink': 0.5} ) plt.title(f"{variable_name} at time index {time_index}") # Save the plot to the specified directory file_path = os.path.join(output_dir, f"{variable_name}_time{time_index}.png") plt.savefig(file_path) # Log saving information print(f"Plot saved to {file_path}") def map_geopotential_indices(D1, D2): geopotential_at_surface = D1['geopotential_at_surface'].values geopotential = D2['geopotential'].values # Initialize an array to store mapped values with the dimensions of D2 mapped_values = np.empty((1, 2, 181, 360), dtype=geopotential.dtype) # Iterate over each point in (181, 360) for i in range(181): for j in range(360): surface_value = geopotential_at_surface[i, j] # Find indices across the 13 levels where the value matches idx = np.argwhere(geopotential[0, :, :, i, j] == surface_value) if idx.size > 0: # Set values across all found indices in the mapped_values for k in range(13): if k < len(idx): # If a match is found at level k mapped_values[0, 0, i, j] = geopotential[0, idx[k][0], k, i, j] else: # If no match is found at level k mapped_values[0, 0, i, j] = np.nan else: # Handle no match case with NaN or some default value mapped_values[0, 0, i, j] = np.nan return mapped_values def log_xarray_details(ds): # Log general dataset information logging.warning(f"Dataset: {ds}") logging.warning(f"Dimensions: {ds.dims}") logging.warning(f"Coordinates: {ds.coords}") logging.warning(f"Data Variables: {list(ds.data_vars.keys())} Number of Data variables : {len(list(ds.data_vars.keys())) } ") # Iterate through all data variables in the dataset for var_name, var in ds.data_vars.items(): logging.warning(f"Variable: {var_name}") logging.warning(f"Shape: {var.shape}") logging.warning(f"Dimensions: {var.dims}") logging.warning(f"Data Type: {var.dtype}") logging.warning("Attributes:") for attr_name, attr_value in var.attrs.items(): logging.warning(f" {attr_name}: {attr_value}") logging.warning("Coordinates:") for coord_name, coord in var.coords.items(): logging.warning(f" {coord_name}: {coord}") logging.warning("") # Empty line for better readability def scale(data: xarray.Dataset, center: Optional[float] = None, robust: bool = False,) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]: vmin = np.nanpercentile(data, (2 if robust else 0)) vmax = np.nanpercentile(data, (98 if robust else 100)) if center is not None: diff = max(vmax - center, center - vmin) vmin = center - diff vmax = center + diff return (data, matplotlib.colors.Normalize(vmin, vmax), ("RdBu_r" if center is not None else "viridis")) def truncate_path(path, max_length=255): if len(path.encode('utf-8')) > max_length: # Truncate the path to fit within the maximum length truncated_path = path[:max_length - len('.tmp')].encode('utf-8').decode('utf-8') + '.tmp' return truncated_path else: return path if __name__ == "__main__": parser = argparse.ArgumentParser(description="GraphCast Model Runner") parser.add_argument( "--dataset", # default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc", default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc", help="Path to the dataset file", ) parser.add_argument( "--params", #default="/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_ERA5_1979-2017_pressure_levels_37.npz", default = "/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz", help="Path to the parameters file", ) parser.add_argument( "--source", choices=["Random", "Checkpoint"], default="Checkpoint", help="Source of the model parameters", ) parser.add_argument( "--mesh-size", type=int, default=4, help="Mesh size for the random model" ) parser.add_argument( "--gnn-msg-steps", type=int, default=4, help="GNN message steps for the random model", ) parser.add_argument( "--latent-size", type=int, default=32, help="Latent size for the random model" ) parser.add_argument( "--pressure-levels", type=int, default=13, help="Pressure levels for the random model", ) parser.add_argument("--variable", default="geopotential", help="Variable to plot") parser.add_argument("--level", type=int, default=1000, help="Level to plot") parser.add_argument( "--robust", action="store_true", help="Use robust scaling for plotting" ) parser.add_argument("--max-steps", type=int, help="Maximum steps to plot") args = parser.parse_args() dataset_file_value = args.dataset params_file_value = args.params source = args.source if source == "Random": params = None # Filled in below state = {} model_config = graphcast.ModelConfig( resolution=0, mesh_size=args.mesh_size, latent_size=args.latent_size, gnn_msg_steps=args.gnn_msg_steps, hidden_layers=1, radius_query_fraction_edge_length=0.6, ) task_config = graphcast.TaskConfig( input_variables=graphcast.TASK.input_variables, target_variables=graphcast.TASK.target_variables, forcing_variables=graphcast.TASK.forcing_variables, pressure_levels=graphcast.PRESSURE_LEVELS[args.pressure_levels], input_duration=graphcast.TASK.input_duration, ) else: assert source == "Checkpoint" with open(params_file_value, "rb") as f: ckpt = checkpoint.load(f, graphcast.CheckPoint) params = ckpt.params state = {} model_config = ckpt.model_config task_config = ckpt.task_config logging.info(f"Model description:\n {ckpt.description} \n") logging.info(f"Model license:\n {ckpt.license} \n") # truncated_path = truncate_path(dataset_file_value) with xarray.open_dataset(dataset_file_value) as example_batch: example_batch = example_batch.compute() log_xarray_details(example_batch) # Example usage # dataset = xr.open_dataset('path_to_your_dataset.nc') # plot_all_variables(example_batch) # safe_data = safe_select_data(example_batch, 4) # logging.warning(f" SAFE_DATA: {safe_data} ") # logging.warning(f" safe_data: {safe_data.shape} keys: {list(safe_data.keys())} ") # plot_geopotential_at_surface(example_batch) plot_variable_on_world_map(example_batch, "geopotential_at_surface" , lat_range=(-60, 60), lon_range=(0, 360), time_index=0) assert example_batch.dims["time"] >= 3 # 2 for input, >=1 for targets def parse_file_parts(file_name): return dict(part.split("-", 1) for part in file_name.split("_")) logging.warning(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file_value.removesuffix(".nc")).items()])) # init_jitted = jax.jit(with_configs(run_forward.init)) logging.warning(f" ******* example_batch_time: {example_batch.dims['time'] } \n train_steps.value, eval_steps.value (1,4) eval_steps.value ") # added explicit train_steps_value as 1 and eval_steps_value as 4 ( train_inputs, train_targets, train_forcings, ) = data_utils.extract_inputs_targets_forcings( example_batch, target_lead_times=slice("6h", f"{ (min(1,example_batch.dims['time'])) *6}h"), **dataclasses.asdict(task_config), ) ( eval_inputs, eval_targets, eval_forcings, ) = data_utils.extract_inputs_targets_forcings( example_batch, target_lead_times=slice("6h", f"{ (example_batch.dims['time']-2)*6}h"), **dataclasses.asdict(task_config), ) # logging.warning(" train Inputs: train_inputs ", train_inputs.dims.mapping) # logging.warning("train Targets: train_targets ", train_targets.dims.mapping) # logging.warning("train Forcings: train_forcings ", train_forcings.dims.mapping) logging.warning("All Examples: %s", example_batch.dims.mapping) logging.warning("Train Inputs: %s", train_inputs.dims.mapping) logging.warning("Train Targets: %s", train_targets.dims.mapping) logging.warning(f" {train_inputs.dims}") logging.warning("Train Forcings: %s", train_forcings.dims.mapping) logging.warning("Eval Inputs: %s", eval_inputs.dims.mapping) logging.warning("Eval Targets: %s", eval_targets.dims.mapping) logging.warning("Eval Forcings: %s", eval_forcings.dims.mapping) if params is None: init_jitted = jax.jit(with_configs(run_forward.init)) params, state = init_jitted( rng=jax.random.PRNGKey(0), inputs=train_inputs, targets_template=train_targets, forcings=train_forcings, ) loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply)))) grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn))) run_forward_jitted = drop_state( with_params(jax.jit(with_configs(run_forward.apply))) ) logging.warning(f" EVAL_INPUTS: {type(eval_inputs)} eval_inputs: {eval_inputs}") predictions = rollout.chunked_prediction( run_forward_jitted, rng=jax.random.PRNGKey(0), inputs=eval_inputs, targets_template=eval_targets * np.nan, forcings=eval_forcings, ) logging.warning(f" {'*'*20} PREIDCTIONS: {'*'*20} ") log_xarray_details(predictions) logging.warning(f" {'*'*20} \n\n\n {'*'*20} ") mapped_values = map_geopotential_indices(example_batch, predictions) plot_max_steps = args.max_steps if args.max_steps else predictions.dims["time"] logging.warning(f" plot_max_steps : {plot_max_steps} ") data = { "Targets": scale( select_data(eval_targets, args.variable, args.level, plot_max_steps), robust=args.robust, ), "Predictions": scale( select_data(predictions, args.variable, args.level, plot_max_steps), robust=args.robust, ), "Diff": scale( ( select_data(eval_targets, args.variable, args.level, plot_max_steps) - select_data(predictions, args.variable, args.level, plot_max_steps) ), robust=args.robust, center=0, ), } fig_title = args.variable if "level" in predictions[args.variable].coords: fig_title += f"_at_{args.level}_hPa_" plot_data(data, variable_name=args.variable,fig_title=fig_title, plot_size=5, robust=args.robust) # (/home/gupt1075/.conda/envs/cent7/5.3.1-py37/pytorch) gupt1075@gilbreth-i003:/scratch/gilbreth/gupt1075/graphcast$ CUDA_VISIBLE_DEVICES=0 python ./new_inference_modular.py --max-steps=2