graphcast_working_inference for my project
import argparse
import dataclasses
import functools
import json
import logging
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import Optional

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import faiss
import haiku as hk
import jax
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import requests
import tqdm
import xarray  # both `xarray` and the `xr` alias are used below
import xarray as xr
from graphcast import (
    autoregressive,
    casting,
    checkpoint,
    data_utils,
    graphcast,
    normalization,
    rollout,
    xarray_jax,
    xarray_tree,
)
from graphcast.graphcast import ModelConfig, TaskConfig
# Define constants
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/graphcast/"

# Create the logs directory if it doesn't exist
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/plot_g_at_surface_{datetime.now().strftime('%B_%d_')}_modular_geopotential_level_1000__13levels__2.log",
)
def parse_file_parts(file_name):
    """
    Parse a dataset file name into its key-value parts.

    :param file_name: The file name to parse (without the ".nc" suffix).
    :return: A dictionary mapping each part's key to its value.
    """
    return dict(part.split("-", 1) for part in file_name.split("_"))
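

# Example, using the dataset naming convention from the CLI defaults below:
#   parse_file_parts("source-era5_date-2022-01-01_res-1.0_levels-13_steps-04")
#   -> {"source": "era5", "date": "2022-01-01", "res": "1.0", "levels": "13", "steps": "04"}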
def select_data(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None,
) -> xarray.Dataset:
    """
    Select specific data from the dataset based on the given criteria.

    :param data: The dataset to select from.
    :param variable: The variable to select.
    :param level: The pressure level to select (optional).
    :param max_steps: The maximum number of time steps to keep (optional).
    :return: The selected dataset.
    """
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        logging.warning(f" *** level: {level}, data_coords {data.coords} ")
        data = data.sel(level=level, method="nearest")
    return data
def scale_data(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
    """
    Scale the data for plotting.

    :param data: The dataset to scale.
    :param center: The center value for scaling (optional).
    :param robust: Whether to use robust (2nd-98th percentile) scaling (optional).
    :return: A tuple of (data, normalization, colormap name).
    """
    vmin = np.nanpercentile(data, (2 if robust else 0))
    vmax = np.nanpercentile(data, (98 if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (
        data,
        matplotlib.colors.Normalize(vmin, vmax),
        ("RdBu_r" if center is not None else "viridis"),
    )
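

# A minimal sketch of the scaling behaviour (values are illustrative, not from the data):
#   _, norm, cmap = scale_data(xr.DataArray([0.0, 5.0, 10.0]), center=0, robust=False)
#   norm.vmin, norm.vmax -> (-10.0, 10.0)   # range made symmetric around the center
#   cmap -> "RdBu_r"                        # diverging map whenever a center is given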
def save_animation_as_gif(ani, filename, fps=1):
    """
    Save a Matplotlib animation as a GIF to a local directory.

    Parameters:
        ani (matplotlib.animation.FuncAnimation): The animation to save.
        filename (str): The filename to save the GIF as.
        fps (int, optional): The frames per second for the GIF. Defaults to 1.
    """
    ani.save(filename, writer="pillow", fps=fps)
def plot_data(
    data: dict[str, xarray.Dataset],
    variable_name: str,
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4,
) -> None:
    """
    Plot the data as an animated map and save it as a GIF.

    :param data: A dictionary mapping titles to (data, norm, cmap) tuples.
    :param variable_name: The variable name, used in the output file name.
    :param fig_title: The title of the figure.
    :param plot_size: The size of each subplot (optional).
    :param robust: Whether robust scaling was used (optional).
    :param cols: The number of columns (optional).
    """
    logging.info(f"Plotting data with title: {fig_title} and type data {type(data)}")
    for d in data.values():
        logging.warning(f" type_d: {type(d)} {len(d)} {type(d[0])} size data[0]: {d[0].dims} {type(d[1])} {type(d[2])} \n ")
    logging.warning("\n *******************************\n")

    first_data = next(iter(data.values()))[0]
    logging.info(f"*********** \nPlotting {fig_title} with {type(first_data)} first_data_length: {len(first_data)} \n ")
    logging.info(f" *** First_data {type(first_data)} ")

    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    images = []
    for i, (title, (plot_arr, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(
            rows, cols, i + 1,
            projection=ccrs.Mercator(
                central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0,
                globe=None, latitude_true_scale=0.0,
            ),
        )
        # Add features to the map
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.OCEAN)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        im = ax.imshow(
            plot_arr.isel(time=0, missing_dims="ignore"),
            norm=norm,
            transform=ccrs.Mercator(
                central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0,
                globe=None, latitude_true_scale=0.0,
            ),
            origin="lower",
            cmap=cmap,
        )
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            extend=("both" if robust else "neither"),
        )
        images.append(im)

    def update(frame):
        if "time" in first_data.dims:
            td = timedelta(microseconds=first_data["time"][frame].item() / 1000)
            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
        for im, (plot_arr, norm, cmap) in zip(images, data.values()):
            im.set_data(plot_arr.isel(time=frame, missing_dims="ignore"))

    ani = matplotlib.animation.FuncAnimation(fig=figure, func=update, frames=max_steps, interval=250)

    # Make sure the output directory exists before ani.save() writes into it.
    os.makedirs("./output_animation", exist_ok=True)
    gif_filename = f"./output_animation/{variable_name}_{datetime.now().strftime('%B_%d_')}_{fig_title}_6hr_1fps.gif"
    logging.warning(f" saved_animation: {gif_filename} ")
    save_animation_as_gif(ani, gif_filename)
    plt.show()
def data_valid_for_model(
    file_name: str,
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig,
) -> bool:
    """
    Check whether a dataset file is valid for the given model configuration.

    :param file_name: The name of the file to check.
    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :return: Whether the data is valid.
    """
    logging.info(f"Checking data validity for file: {file_name}")
    file_parts = parse_file_parts(file_name.removesuffix(".nc"))
    return (
        model_config.resolution in (0, float(file_parts["res"]))
        and len(task_config.pressure_levels) == int(file_parts["levels"])
        and (
            (
                "total_precipitation_6hr" in task_config.input_variables
                and file_parts["source"] in ("era5", "fake")
            )
            or (
                "total_precipitation_6hr" not in task_config.input_variables
                and file_parts["source"] in ("hres", "fake")
            )
        )
    )
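

# Sketch of the validity rules above, applied to the default dataset name
# "source-era5_..._res-1.0_levels-13_...": the file is valid only if the model
# resolution is 0 (meaning "any") or 1.0, the task uses 13 pressure levels, and
# the precipitation convention matches the source: "era5" requires
# "total_precipitation_6hr" among the inputs, "hres" requires it to be absent,
# and "fake" passes either way.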
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig
) -> graphcast.GraphCast:
    """
    Construct and wrap the GraphCast predictor.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :return: The wrapped GraphCast predictor.
    """
    logging.info("Constructing wrapped GraphCast predictor")

    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to/from
    # float32 to/from BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Load normalization statistics.
    with open("./stats/diffs_stddev_by_level.nc", "rb") as f:
        diffs_stddev_by_level = xr.load_dataset(f).compute()
    with open("./stats/mean_by_level.nc", "rb") as f:
        mean_by_level = xr.load_dataset(f).compute()
    with open("./stats/stddev_by_level.nc", "rb") as f:
        stddev_by_level = xr.load_dataset(f).compute()

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from BFloat16
    # happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )

    # Wrap everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """
    Run the forward pass of the model.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets_template: The target template.
    :param forcings: The forcing data.
    :return: The output of the forward pass.
    """
    logging.info("Running forward pass")
    predictor = construct_wrapped_graphcast(model_config, task_config)
    return predictor(inputs, targets_template=targets_template, forcings=forcings)
@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """
    Compute the loss of the model.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets: The target data.
    :param forcings: The forcing data.
    :return: The loss and diagnostics.
    """
    logging.info("Computing loss")
    predictor = construct_wrapped_graphcast(model_config, task_config)
    loss, diagnostics = predictor.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(
        lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
        (loss, diagnostics),
    )
def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
    """
    Compute the gradients of the model.

    :param params: The model parameters.
    :param state: The model state.
    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets: The target data.
    :param forcings: The forcing data.
    :return: The loss, diagnostics, next state, and gradients.
    """
    logging.info("Computing gradients")

    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0), model_config, task_config, i, t, f
        )
        return loss, (diagnostics, next_state)

    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(
        params, state, inputs, targets, forcings
    )
    return loss, diagnostics, next_state, grads
# Jax doesn't seem to like passing configs as args through the jit. Passing them in
# via partial (instead of capture by closure) forces jax to invalidate the jit cache
# if you change configs.
def with_configs(fn):
    return functools.partial(fn, model_config=model_config, task_config=task_config)


# Always pass params and state, so the usages below are simpler.
def with_params(fn):
    return functools.partial(fn, params=params, state=state)


# Our models aren't stateful, so the state is always empty; drop it from the
# (outputs, state) tuple and just return the predictions. This is required by
# our rollout code, and generally simpler.
def drop_state(fn):
    return lambda **kw: fn(**kw)[0]
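

# Minimal sketch of how these wrappers compose (mirroring the usage in __main__ below):
#   run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))
#   predictions = run_forward_jitted(rng=..., inputs=..., targets_template=..., forcings=...)
# Binding the configs via functools.partial (rather than closing over them) means a
# config change produces a new partial object, so jax retraces instead of reusing a
# stale compiled function.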
| """ | |
| geopotential_at_surface | |
| land_sea_mask | |
| 2m_temperature | |
| mean_sea_level_pressure | |
| 10m_v_component_of_wind | |
| 10m_u_component_of_wind | |
| total_precipitation_6hr | |
| toa_incident_solar_radiation | |
| temperature | |
| geopotential | |
| u_component_of_wind | |
| v_component_of_wind | |
| vertical_velocity | |
| specific_humidity | |
| """ | |
def plot_all_variables(data: xr.Dataset, plot_size: float = 5):
    """
    Plot all input variables from an xarray Dataset.

    :param data: The xarray Dataset containing the variables.
    :param plot_size: The size of each subplot (optional).
    """
    variable_names = list(data.data_vars)
    num_vars = len(variable_names)
    cols = min(4, num_vars)
    rows = (num_vars + cols - 1) // cols
    # squeeze=False so `axes` is always a 2D array, even for a single subplot.
    fig, axes = plt.subplots(
        rows, cols, figsize=(plot_size * cols, plot_size * rows),
        subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=False,
    )
    axes = axes.flatten()
    for i, var_name in enumerate(variable_names):
        if var_name in data:
            # missing_dims="ignore" handles static variables without a time dimension.
            data[var_name].isel(time=0, missing_dims="ignore").plot(ax=axes[i], transform=ccrs.PlateCarree())
            axes[i].set_title(var_name)
    # Hide any remaining empty subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])
    plt.tight_layout()
    plt.show()
def safe_select_data(data: xr.Dataset, time_index: int):
    """Select a time step by index, clamping out-of-bounds indices to the last step."""
    # Check if 'time' is a dimension
    if "time" in data.dims:
        max_time_index = data.sizes["time"] - 1
        if time_index > max_time_index:
            logging.warning(f"time_index {time_index} is out of bounds; using {max_time_index} instead.")
            time_index = max_time_index
        # Safely select data
        return data.isel(time=time_index)
    else:
        raise ValueError("The dataset does not have a 'time' dimension.")
def plot_geopotential_at_surface(data: xr.Dataset):
    """
    Plot the geopotential_at_surface variable with land and ocean features.

    :param data: The xarray Dataset containing the geopotential_at_surface variable.
    """
    # Check if the dataset contains 'geopotential_at_surface'
    if "geopotential_at_surface" not in data:
        raise ValueError("Dataset does not contain 'geopotential_at_surface' variable")

    # Select the 'geopotential_at_surface' variable
    geopotential = data["geopotential_at_surface"]
    logging.warning(f" geopotential_dimensions {type(geopotential)} shape: {geopotential.shape} ")

    # Create a figure with a specific projection
    fig, ax = plt.subplots(1, 1, figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()})

    # Plot the geopotential data
    geopotential.plot(ax=ax, transform=ccrs.PlateCarree(), cmap="viridis", cbar_kwargs={"shrink": 0.5})

    # Add geographical features
    ax.add_feature(cfeature.LAND, zorder=1, edgecolor="black")
    ax.add_feature(cfeature.OCEAN, zorder=0)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    ax.add_feature(cfeature.LAKES, alpha=0.5)

    # Set title and gridlines
    ax.set_title("Geopotential at Surface with Land and Ocean Features")
    ax.gridlines(draw_labels=True)
    plt.show()


# Example usage:
#   dataset = xr.open_dataset("path_to_your_dataset.nc")
#   plot_geopotential_at_surface(dataset)
def select_data_at_time(
    data: xr.Dataset, variable: str, level: Optional[int] = None,
    time_step: Optional[int] = None,
) -> xr.DataArray:
    """
    Select a single time step of a variable from the dataset.

    Note: unlike `select_data` above, which keeps up to `max_steps` time steps,
    this variant picks one step by index. It must not share the name
    `select_data`, since the __main__ block below relies on the max_steps variant.

    :param data: The dataset to select from.
    :param variable: The variable to select.
    :param level: The level to select (optional).
    :param time_step: The time step index to select (optional).
    :return: The selected data array.
    """
    selected_data = data[variable]
    if "batch" in selected_data.dims:
        selected_data = selected_data.isel(batch=0)
    if time_step is not None and "time" in selected_data.dims:
        selected_data = selected_data.isel(time=time_step)
    if level is not None and "level" in selected_data.coords:
        selected_data = selected_data.sel(level=level, method="nearest")
    return selected_data
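

# Hypothetical usage (variable and level mirror the CLI defaults below):
#   geo_1000 = select_data_at_time(example_batch, "geopotential", level=1000, time_step=0)
# This returns a (lat, lon) DataArray at the nearest available pressure level.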
def plot_variable_on_world_map(
    data, variable_name, lat_range=None, lon_range=None, level=None, time_index=0,
    output_dir="/scratch/gilbreth/gupt1075/graphcast/input_data_plots/",
):
    """
    Plot a variable from an xarray dataset on a world map within the given latitude
    and longitude ranges, at the specified level and time index if applicable, and
    save the plot to the output directory.

    Parameters:
    - data: xarray.Dataset containing the data variables.
    - variable_name: str, name of the variable to plot.
    - lat_range: tuple of (min_lat, max_lat), latitude range to plot.
    - lon_range: tuple of (min_lon, max_lon), longitude range to plot.
    - level: int or None, specific level to plot if the variable has a level dimension.
    - time_index: int, index of the time dimension to plot.
    - output_dir: str, directory where the plot will be saved.

    Returns:
    - None
    """
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Select the variable
    var = data[variable_name]
    # logging.warning(f" min_value {var.min(dim=('lat','lon'))} max values: {var.max(dim=('lat','lon'))} \n\n")

    # Slice the latitude and longitude ranges if specified
    if lat_range is not None:
        var = var.sel(lat=slice(lat_range[0], lat_range[1]))
    if lon_range is not None:
        var = var.sel(lon=slice(lon_range[0], lon_range[1]))

    # Handle levels if applicable
    if "level" in var.dims and level is not None:
        var = var.sel(level=level)

    # Select the specific time index
    if "time" in var.dims:
        var = var.isel(time=time_index)

    # Plotting
    logging.warning(f" min_value {var.min()} max values: {var.max()} \n\n")
    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Add features to the map
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    # Plot the data. NOTE: the hard-coded vmin/vmax appear tuned for
    # geopotential_at_surface and will clip other variables.
    if "lat" in var.dims and "lon" in var.dims:
        var.plot(ax=ax, transform=ccrs.PlateCarree(), cmap="RdBu", vmin=-1600, vmax=56000, cbar_kwargs={"shrink": 0.5})

    plt.title(f"{variable_name} at time index {time_index}")

    # Save the plot to the specified directory
    file_path = os.path.join(output_dir, f"{variable_name}_time{time_index}.png")
    plt.savefig(file_path)
    logging.info(f"Plot saved to {file_path}")
def map_geopotential_indices(D1, D2):
    """
    For each (lat, lon) grid point, look up where the surface geopotential from D1
    appears among the pressure-level geopotential values in D2, and record the
    matched value (NaN where no exact match exists).
    """
    geopotential_at_surface = D1["geopotential_at_surface"].values  # (181, 360)
    geopotential = D2["geopotential"].values  # (batch, time, level, 181, 360), as assumed by the indexing below

    # Initialize an array to store mapped values with the dimensions of D2
    mapped_values = np.empty((1, 2, 181, 360), dtype=geopotential.dtype)

    # Iterate over each point on the 181 x 360 grid
    for i in range(181):
        for j in range(360):
            surface_value = geopotential_at_surface[i, j]
            # Find (time, level) indices across the 13 levels where the value matches
            idx = np.argwhere(geopotential[0, :, :, i, j] == surface_value)
            if idx.size > 0:
                # NOTE: each iteration overwrites mapped_values[0, 0, i, j], so only
                # the final (k = 12) iteration determines the stored value.
                for k in range(13):
                    if k < len(idx):
                        # A match is found at level k
                        mapped_values[0, 0, i, j] = geopotential[0, idx[k][0], k, i, j]
                    else:
                        # No match is found at level k
                        mapped_values[0, 0, i, j] = np.nan
            else:
                # Handle the no-match case with NaN
                mapped_values[0, 0, i, j] = np.nan
    return mapped_values
def log_xarray_details(ds):
    """Log the dimensions, coordinates, and per-variable details of an xarray Dataset."""
    # Log general dataset information
    logging.warning(f"Dataset: {ds}")
    logging.warning(f"Dimensions: {ds.dims}")
    logging.warning(f"Coordinates: {ds.coords}")
    logging.warning(f"Data Variables: {list(ds.data_vars.keys())} Number of Data Variables: {len(ds.data_vars)}")

    # Iterate through all data variables in the dataset
    for var_name, var in ds.data_vars.items():
        logging.warning(f"Variable: {var_name}")
        logging.warning(f"Shape: {var.shape}")
        logging.warning(f"Dimensions: {var.dims}")
        logging.warning(f"Data Type: {var.dtype}")
        logging.warning("Attributes:")
        for attr_name, attr_value in var.attrs.items():
            logging.warning(f"  {attr_name}: {attr_value}")
        logging.warning("Coordinates:")
        for coord_name, coord in var.coords.items():
            logging.warning(f"  {coord_name}: {coord}")
        logging.warning("")  # Empty line for better readability
# `scale` is identical to `scale_data` above; keep it as an alias so the calls
# in __main__ below continue to work.
scale = scale_data
def truncate_path(path, max_length=255):
    """Truncate a path whose UTF-8 encoding exceeds max_length bytes, appending '.tmp'."""
    if len(path.encode("utf-8")) > max_length:
        # Truncate the path to fit within the maximum length
        return path[: max_length - len(".tmp")] + ".tmp"
    return path
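

# Illustrative behaviour (hypothetical input). Note the length check counts UTF-8
# bytes while the slice counts characters, so multi-byte paths may still exceed
# max_length after truncation:
#   truncate_path("a" * 300)  ->  "a" * 251 + ".tmp"   # 255 characters total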
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="GraphCast Model Runner") | |
| parser.add_argument( | |
| "--dataset", | |
| # default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc", | |
| default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc", | |
| help="Path to the dataset file", | |
| ) | |
| parser.add_argument( | |
| "--params", | |
| #default="/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_ERA5_1979-2017_pressure_levels_37.npz", | |
| default = "/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz", | |
| help="Path to the parameters file", | |
| ) | |
| parser.add_argument( | |
| "--source", | |
| choices=["Random", "Checkpoint"], | |
| default="Checkpoint", | |
| help="Source of the model parameters", | |
| ) | |
| parser.add_argument( | |
| "--mesh-size", type=int, default=4, help="Mesh size for the random model" | |
| ) | |
| parser.add_argument( | |
| "--gnn-msg-steps", | |
| type=int, | |
| default=4, | |
| help="GNN message steps for the random model", | |
| ) | |
| parser.add_argument( | |
| "--latent-size", type=int, default=32, help="Latent size for the random model" | |
| ) | |
| parser.add_argument( | |
| "--pressure-levels", | |
| type=int, | |
| default=13, | |
| help="Pressure levels for the random model", | |
| ) | |
| parser.add_argument("--variable", default="geopotential", help="Variable to plot") | |
| parser.add_argument("--level", type=int, default=1000, help="Level to plot") | |
| parser.add_argument( | |
| "--robust", action="store_true", help="Use robust scaling for plotting" | |
| ) | |
| parser.add_argument("--max-steps", type=int, help="Maximum steps to plot") | |
| args = parser.parse_args() | |
| dataset_file_value = args.dataset | |
| params_file_value = args.params | |
| source = args.source | |
| if source == "Random": | |
| params = None # Filled in below | |
| state = {} | |
| model_config = graphcast.ModelConfig( | |
| resolution=0, | |
| mesh_size=args.mesh_size, | |
| latent_size=args.latent_size, | |
| gnn_msg_steps=args.gnn_msg_steps, | |
| hidden_layers=1, | |
| radius_query_fraction_edge_length=0.6, | |
| ) | |
| task_config = graphcast.TaskConfig( | |
| input_variables=graphcast.TASK.input_variables, | |
| target_variables=graphcast.TASK.target_variables, | |
| forcing_variables=graphcast.TASK.forcing_variables, | |
| pressure_levels=graphcast.PRESSURE_LEVELS[args.pressure_levels], | |
| input_duration=graphcast.TASK.input_duration, | |
| ) | |
| else: | |
| assert source == "Checkpoint" | |
| with open(params_file_value, "rb") as f: | |
| ckpt = checkpoint.load(f, graphcast.CheckPoint) | |
| params = ckpt.params | |
| state = {} | |
| model_config = ckpt.model_config | |
| task_config = ckpt.task_config | |
| logging.info(f"Model description:\n {ckpt.description} \n") | |
| logging.info(f"Model license:\n {ckpt.license} \n") | |
    # truncated_path = truncate_path(dataset_file_value)
    with xarray.open_dataset(dataset_file_value) as example_batch:
        example_batch = example_batch.compute()

    log_xarray_details(example_batch)

    # Example usage:
    #   plot_all_variables(example_batch)
    #   safe_data = safe_select_data(example_batch, 4)
    #   plot_geopotential_at_surface(example_batch)
    plot_variable_on_world_map(example_batch, "geopotential_at_surface", lat_range=(-60, 60), lon_range=(0, 360), time_index=0)

    assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets

    # Parse the file-name parts from the basename, not the full path.
    logging.warning(", ".join(f"{k}: {v}" for k, v in parse_file_parts(os.path.basename(dataset_file_value).removesuffix(".nc")).items()))
    # init_jitted = jax.jit(with_configs(run_forward.init))
    logging.warning(f" ******* example_batch time dimension: {example_batch.dims['time']} (train_steps=1, eval_steps=4) ")

    # Explicit train_steps value of 1 and eval_steps value of 4.
    (
        train_inputs,
        train_targets,
        train_forcings,
    ) = data_utils.extract_inputs_targets_forcings(
        example_batch,
        target_lead_times=slice("6h", f"{min(1, example_batch.dims['time']) * 6}h"),
        **dataclasses.asdict(task_config),
    )
    (
        eval_inputs,
        eval_targets,
        eval_forcings,
    ) = data_utils.extract_inputs_targets_forcings(
        example_batch,
        target_lead_times=slice("6h", f"{(example_batch.dims['time'] - 2) * 6}h"),
        **dataclasses.asdict(task_config),
    )

    logging.warning("All Examples:   %s", example_batch.dims.mapping)
    logging.warning("Train Inputs:   %s", train_inputs.dims.mapping)
    logging.warning("Train Targets:  %s", train_targets.dims.mapping)
    logging.warning("Train Forcings: %s", train_forcings.dims.mapping)
    logging.warning("Eval Inputs:    %s", eval_inputs.dims.mapping)
    logging.warning("Eval Targets:   %s", eval_targets.dims.mapping)
    logging.warning("Eval Forcings:  %s", eval_forcings.dims.mapping)
    if params is None:
        init_jitted = jax.jit(with_configs(run_forward.init))
        params, state = init_jitted(
            rng=jax.random.PRNGKey(0),
            inputs=train_inputs,
            targets_template=train_targets,
            forcings=train_forcings,
        )

    loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
    grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
    run_forward_jitted = drop_state(
        with_params(jax.jit(with_configs(run_forward.apply)))
    )

    logging.warning(f" EVAL_INPUTS: {type(eval_inputs)} eval_inputs: {eval_inputs}")
    predictions = rollout.chunked_prediction(
        run_forward_jitted,
        rng=jax.random.PRNGKey(0),
        inputs=eval_inputs,
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings,
    )

    logging.warning(f" {'*' * 20} PREDICTIONS: {'*' * 20} ")
    log_xarray_details(predictions)
    logging.warning(f" {'*' * 20} \n\n\n {'*' * 20} ")
    mapped_values = map_geopotential_indices(example_batch, predictions)

    plot_max_steps = args.max_steps if args.max_steps else predictions.dims["time"]
    logging.warning(f" plot_max_steps: {plot_max_steps} ")

    data = {
        "Targets": scale(
            select_data(eval_targets, args.variable, args.level, plot_max_steps),
            robust=args.robust,
        ),
        "Predictions": scale(
            select_data(predictions, args.variable, args.level, plot_max_steps),
            robust=args.robust,
        ),
        "Diff": scale(
            (
                select_data(eval_targets, args.variable, args.level, plot_max_steps)
                - select_data(predictions, args.variable, args.level, plot_max_steps)
            ),
            robust=args.robust,
            center=0,
        ),
    }

    fig_title = args.variable
    if "level" in predictions[args.variable].coords:
        fig_title += f"_at_{args.level}_hPa_"

    plot_data(data, variable_name=args.variable, fig_title=fig_title, plot_size=5, robust=args.robust)

# Example invocation:
#   CUDA_VISIBLE_DEVICES=0 python ./new_inference_modular.py --max-steps=2