@rishabh135
Created September 27, 2024 16:33
graphcast_working_inference for my project
import argparse
import dataclasses
import functools
import json
import logging
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import Optional

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import faiss
import haiku as hk
import jax
import matplotlib
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import requests
import tqdm
import xarray
import xarray as xr
from graphcast import (
    autoregressive,
    casting,
    checkpoint,
    data_utils,
    graphcast,
    normalization,
    rollout,
    xarray_jax,
    xarray_tree,
)
from graphcast.graphcast import ModelConfig, TaskConfig
# Define constants
GLOBAL_PATH = "/scratch/gilbreth/gupt1075/graphcast/"

# Create logs directory if it doesn't exist
os.makedirs(f"{GLOBAL_PATH}/logs/", exist_ok=True)

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    filename=f"{GLOBAL_PATH}/logs/plot_g_at_surface_{datetime.now().strftime('%B_%d_')}_modular_geopotential_level_1000__13levels__2.log",
)
def parse_file_parts(file_name):
    """
    Parse the file name to extract relevant parts.

    :param file_name: The name of the file to parse.
    :return: A dictionary containing the parsed parts.
    """
    return dict(part.split("-", 1) for part in file_name.split("_"))
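
# A minimal usage sketch, using the default dataset filename from __main__ below:
# parse_file_parts("source-era5_date-2022-01-01_res-1.0_levels-13_steps-04")
# -> {"source": "era5", "date": "2022-01-01", "res": "1.0", "levels": "13", "steps": "04"}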
def select_data(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None,
) -> xarray.Dataset:
    """
    Select specific data from the dataset based on the given criteria.

    :param data: The dataset to select from.
    :param variable: The variable to select.
    :param level: The level to select (optional).
    :param max_steps: The maximum number of time steps to select (optional).
    :return: The selected dataset.
    """
    data = data[variable]
    if "batch" in data.dims:
        data = data.isel(batch=0)
    if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
        data = data.isel(time=range(0, max_steps))
    if level is not None and "level" in data.coords:
        logging.warning(f" *** level: {level}, data_coords {data.coords} ")
        data = data.sel(level=level, method="nearest")
    return data
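
# A minimal usage sketch, assuming `predictions` is a Dataset with batch/time/level
# dims as produced by the rollout in __main__ below:
# geopotential_1000 = select_data(predictions, "geopotential", level=1000, max_steps=2)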
def scale_data(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
    """
    Scale the data for plotting.

    :param data: The dataset to scale.
    :param center: The center value for scaling (optional).
    :param robust: Whether to use robust scaling (optional).
    :return: A tuple of the data, a colour normalization, and a colormap name.
    """
    vmin = np.nanpercentile(data, (2 if robust else 0))
    vmax = np.nanpercentile(data, (98 if robust else 100))
    if center is not None:
        diff = max(vmax - center, center - vmin)
        vmin = center - diff
        vmax = center + diff
    return (
        data,
        matplotlib.colors.Normalize(vmin, vmax),
        ("RdBu_r" if center is not None else "viridis"),
    )
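
# A minimal sketch of the scaling behaviour (hypothetical array, not from the dataset):
# arr = xr.DataArray(np.linspace(-3.0, 5.0, 100))
# _, norm, cmap = scale_data(arr, center=0, robust=True)
# With `center=0`, norm is symmetric (vmin == -vmax) and cmap is "RdBu_r";
# without `center`, cmap would be "viridis".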
def save_animation_as_gif(ani, filename, fps=1):
    """
    Save a Matplotlib animation as a GIF to a local directory.

    Parameters:
        ani (matplotlib.animation.FuncAnimation): The animation to save.
        filename (str): The filename to save the GIF as.
        fps (int, optional): Frames per second for the GIF. Defaults to 1.
    """
    ani.save(filename, writer="pillow", fps=fps)
def plot_data(
    data: dict[str, tuple[xarray.Dataset, matplotlib.colors.Normalize, str]],
    variable_name: str,
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4,
) -> None:
    """
    Plot the data as an animated multi-panel map and save it as a GIF.

    :param data: A dictionary mapping panel titles to (data, norm, cmap) tuples.
    :param variable_name: The variable name, used in the output GIF filename.
    :param fig_title: The title of the figure.
    :param plot_size: The size of the plot (optional).
    :param robust: Whether to use robust scaling (optional).
    :param cols: The number of columns (optional).
    """
    logging.info(f"Plotting data with title: {fig_title} and type data {type(data)}")
    for d in data.values():
        logging.warning(f" type_d: {type(d)} {len(d)} {type(d[0])} size data[0]: {d[0].dims} {type(d[1])} {type(d[2])} \n ")
    logging.warning("\n *******************************\n")

    first_data = next(iter(data.values()))[0]
    logging.info(f"*********** \nPlotting {fig_title} with {type(first_data)} first_data_length: {len(first_data)} \n ")
    logging.info(f" *** First_data {type(first_data)} ")
    max_steps = first_data.sizes.get("time", 1)
    assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

    cols = min(cols, len(data))
    rows = math.ceil(len(data) / cols)
    figure = plt.figure(figsize=(plot_size * 2 * cols, plot_size * rows))
    figure.suptitle(fig_title, fontsize=16)
    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    images = []
    for i, (title, (plot_arr, norm, cmap)) in enumerate(data.items()):
        ax = figure.add_subplot(
            rows, cols, i + 1,
            projection=ccrs.Mercator(central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0, globe=None, latitude_true_scale=0.0),
        )
        # Add features to the map
        ax.add_feature(cfeature.LAND)
        ax.add_feature(cfeature.OCEAN)
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
        im = ax.imshow(
            plot_arr.isel(time=0, missing_dims="ignore"),
            norm=norm,
            transform=ccrs.Mercator(central_longitude=0.0, min_latitude=-80.0, max_latitude=84.0, globe=None, latitude_true_scale=0.0),
            origin="lower",
            cmap=cmap,
        )
        plt.colorbar(
            mappable=im,
            ax=ax,
            orientation="vertical",
            pad=0.02,
            aspect=16,
            shrink=0.75,
            extend=("both" if robust else "neither"),
        )
        images.append(im)

    def update(frame):
        if "time" in first_data.dims:
            # `time` is a timedelta64[ns]; convert nanoseconds to microseconds.
            td = timedelta(microseconds=first_data["time"][frame].item() / 1000)
            figure.suptitle(f"{fig_title}, {td}", fontsize=16)
        else:
            figure.suptitle(fig_title, fontsize=16)
        for im, (plot_arr, norm, cmap) in zip(images, data.values()):
            im.set_data(plot_arr.isel(time=frame, missing_dims="ignore"))

    ani = animation.FuncAnimation(fig=figure, func=update, frames=max_steps, interval=250)
    # Ensure the output directory exists before saving the GIF.
    os.makedirs("./output_animation", exist_ok=True)
    gif_filename = f"./output_animation/{variable_name}_{datetime.now().strftime('%B_%d_')}_{fig_title}_6hr_1fps.gif"
    logging.warning(f" saved_animation: {gif_filename} ")
    save_animation_as_gif(ani, gif_filename)
    plt.show()
def data_valid_for_model(
    file_name: str,
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig,
) -> bool:
    """
    Check if the data is valid for the given model configuration.

    :param file_name: The name of the file to check.
    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :return: Whether the data is valid.
    """
    logging.info(f"Checking data validity for file: {file_name}")
    file_parts = parse_file_parts(file_name.removesuffix(".nc"))
    return (
        model_config.resolution in (0, float(file_parts["res"]))
        and len(task_config.pressure_levels) == int(file_parts["levels"])
        and (
            (
                "total_precipitation_6hr" in task_config.input_variables
                and file_parts["source"] in ("era5", "fake")
            )
            or (
                "total_precipitation_6hr" not in task_config.input_variables
                and file_parts["source"] in ("hres", "fake")
            )
        )
    )
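
# A minimal usage sketch, assuming `model_config`/`task_config` were loaded from the
# checkpoint in __main__ below. That checkpoint takes precipitation as an input, so a
# matching "era5" file passes while an "hres" file of the same resolution would not:
# data_valid_for_model("source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc",
#                      model_config, task_config)  # -> True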
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig
) -> graphcast.GraphCast:
    """
    Construct and wrap the GraphCast predictor.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :return: The wrapped GraphCast predictor.
    """
    logging.info("Constructing wrapped GraphCast predictor")
    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to/from
    # float32 to/from BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Load normalization data.
    with open("./stats/diffs_stddev_by_level.nc", "rb") as f:
        diffs_stddev_by_level = xr.load_dataset(f).compute()
    with open("./stats/mean_by_level.nc", "rb") as f:
        mean_by_level = xr.load_dataset(f).compute()
    with open("./stats/stddev_by_level.nc", "rb") as f:
        stddev_by_level = xr.load_dataset(f).compute()

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from BFloat16
    # happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level,
    )

    # Wraps everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
    """
    Run the forward pass of the model.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets_template: The target template.
    :param forcings: The forcing data.
    :return: The output of the forward pass.
    """
    logging.info("Running forward pass")
    predictor = construct_wrapped_graphcast(model_config, task_config)
    return predictor(inputs, targets_template=targets_template, forcings=forcings)
@hk.transform_with_state
def loss_fn(model_config, task_config, inputs, targets, forcings):
    """
    Compute the loss of the model.

    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets: The target data.
    :param forcings: The forcing data.
    :return: The loss and diagnostics.
    """
    logging.info("Computing loss")
    predictor = construct_wrapped_graphcast(model_config, task_config)
    loss, diagnostics = predictor.loss(inputs, targets, forcings)
    return xarray_tree.map_structure(
        lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
        (loss, diagnostics),
    )
def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
    """
    Compute the gradients of the model.

    :param params: The model parameters.
    :param state: The model state.
    :param model_config: The model configuration.
    :param task_config: The task configuration.
    :param inputs: The input data.
    :param targets: The target data.
    :param forcings: The forcing data.
    :return: The loss, diagnostics, next state, and gradients.
    """
    logging.info("Computing gradients")

    def _aux(params, state, i, t, f):
        (loss, diagnostics), next_state = loss_fn.apply(
            params, state, jax.random.PRNGKey(0), model_config, task_config, i, t, f
        )
        return loss, (diagnostics, next_state)

    (loss, (diagnostics, next_state)), grads = jax.value_and_grad(_aux, has_aux=True)(
        params, state, inputs, targets, forcings
    )
    return loss, diagnostics, next_state, grads
# Jax doesn't seem to like passing configs as args through the jit. Passing it in via
# partial (instead of capture by closure) forces jax to invalidate the jit cache if
# you change configs.
def with_configs(fn):
    return functools.partial(fn, model_config=model_config, task_config=task_config)


# Always pass params and state, so the usage below is simpler.
def with_params(fn):
    return functools.partial(fn, params=params, state=state)


# Our models aren't stateful, so the state is always empty, so just return the
# predictions (index [0] drops the empty state). This is required by our rollout code,
# and generally simpler.
def drop_state(fn):
    return lambda **kw: fn(**kw)[0]
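
# These helpers compose, innermost-first; e.g. the jitted forward function built in
# __main__ below:
# run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))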
"""
geopotential_at_surface
land_sea_mask
2m_temperature
mean_sea_level_pressure
10m_v_component_of_wind
10m_u_component_of_wind
total_precipitation_6hr
toa_incident_solar_radiation
temperature
geopotential
u_component_of_wind
v_component_of_wind
vertical_velocity
specific_humidity
"""
def plot_all_variables(data: xr.Dataset, plot_size: float = 5):
    """
    Plot all input variables from an xarray Dataset.

    :param data: The xarray Dataset containing the variables.
    :param plot_size: The size of each subplot (optional).
    """
    variable_names = list(data.data_vars)
    num_vars = len(variable_names)
    cols = min(4, num_vars)
    rows = (num_vars + cols - 1) // cols
    # squeeze=False keeps a 2D array of axes even when there is a single subplot.
    fig, axes = plt.subplots(rows, cols, figsize=(plot_size * cols, plot_size * rows),
                             subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=False)
    axes = axes.flatten()
    for i, var_name in enumerate(variable_names):
        var = data[var_name]
        # Static fields (e.g. geopotential_at_surface) have no time dimension.
        if "time" in var.dims:
            var = var.isel(time=0)
        var.plot(ax=axes[i], transform=ccrs.PlateCarree())
        axes[i].set_title(var_name)
    # Hide any remaining empty subplots
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])
    plt.tight_layout()
    plt.show()
def safe_select_data(data: xr.Dataset, time_index: int):
    """Select a single time step, clamping the index to the valid range."""
    # Check if 'time' is a dimension
    if "time" in data.dims:
        max_time_index = data.sizes["time"] - 1
        if time_index > max_time_index:
            print(f"Warning: time_index {time_index} is out of bounds. Using {max_time_index} instead.")
            time_index = max_time_index
        # Safely select data
        return data.isel(time=time_index)
    else:
        raise ValueError("The dataset does not have a 'time' dimension.")
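
# Usage sketch (assuming `example_batch` has a 'time' dimension, as asserted in __main__):
# snapshot = safe_select_data(example_batch, 4)  # clamps to the last step if 4 is out of range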
def plot_geopotential_at_surface(data: xr.Dataset):
    """Plot the geopotential_at_surface variable with land and ocean features."""
    # Check if the dataset contains 'geopotential_at_surface'
    if "geopotential_at_surface" not in data:
        raise ValueError("Dataset does not contain 'geopotential_at_surface' variable")
    # Select the 'geopotential_at_surface' variable
    geopotential = data["geopotential_at_surface"]
    logging.warning(f" geopotential_dimensions {type(geopotential)} shape: {geopotential.shape} ")
    # Create a figure with a specific projection
    fig, ax = plt.subplots(1, 1, figsize=(12, 6), subplot_kw={"projection": ccrs.PlateCarree()})
    # Plot the geopotential data
    geopotential.plot(ax=ax, transform=ccrs.PlateCarree(), cmap="viridis", cbar_kwargs={"shrink": 0.5})
    # Add geographical features
    ax.add_feature(cfeature.LAND, zorder=1, edgecolor="black")
    ax.add_feature(cfeature.OCEAN, zorder=0)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    ax.add_feature(cfeature.LAKES, alpha=0.5)
    # Set title and gridlines
    ax.set_title("Geopotential at Surface with Land and Ocean Features")
    ax.gridlines(draw_labels=True)
    plt.show()
# Example usage:
# dataset = xr.open_dataset('path_to_your_dataset.nc')
# plot_geopotential_at_surface(dataset)
def plot_variable_on_world_map(data, variable_name, lat_range=None, lon_range=None, level=None, time_index=0,
                               output_dir="/scratch/gilbreth/gupt1075/graphcast/input_data_plots/"):
    """
    Plot a specified variable from an xarray dataset on a world map within the given
    latitude and longitude ranges, at a specified level and time index if applicable.
    Saves the plot to the specified output directory.

    Parameters:
    - data: xarray.Dataset containing the data variables.
    - variable_name: str, name of the variable to plot.
    - lat_range: tuple of (min_lat, max_lat), latitude range to plot.
    - lon_range: tuple of (min_lon, max_lon), longitude range to plot.
    - level: int or None, specific level to plot if the variable has a level dimension.
    - time_index: int, index of the time dimension to plot.
    - output_dir: str, directory where the plot will be saved.

    Returns:
    - None
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Select the variable
    var = data[variable_name]
    # logging.warning(f" min_value {var.min(dim=('lat','lon'))} max values: {var.max(dim=('lat','lon'))} \n\n")
    # Slice latitude and longitude range if specified
    if lat_range is not None:
        var = var.sel(lat=slice(lat_range[0], lat_range[1]))
    if lon_range is not None:
        var = var.sel(lon=slice(lon_range[0], lon_range[1]))
    # Handle levels if applicable
    if "level" in var.dims and level is not None:
        var = var.sel(level=level)
    # Select the specific time index
    if "time" in var.dims:
        var = var.isel(time=time_index)
    # Plotting
    logging.warning(f" min_value {var.min()} max values: {var.max()} \n\n")
    fig = plt.figure(figsize=(12, 8))
    ax = plt.axes(projection=ccrs.PlateCarree())
    # Add features to the map
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.OCEAN)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle=":")
    # Plot the data
    if "lat" in var.dims and "lon" in var.dims:
        var.plot(ax=ax, transform=ccrs.PlateCarree(), cmap="RdBu", vmin=-1600, vmax=56000, cbar_kwargs={"shrink": 0.5})
    plt.title(f"{variable_name} at time index {time_index}")
    # Save the plot to the specified directory
    file_path = os.path.join(output_dir, f"{variable_name}_time{time_index}.png")
    plt.savefig(file_path)
    # Log saving information
    print(f"Plot saved to {file_path}")
def map_geopotential_indices(D1, D2):
    """For each grid point, look for the surface geopotential value across the
    predicted pressure levels and record the matched value (NaN when absent)."""
    geopotential_at_surface = D1["geopotential_at_surface"].values  # (lat, lon)
    geopotential = D2["geopotential"].values  # (batch, time, level, lat, lon)
    # Initialize an array to store mapped values with the dimensions of D2
    mapped_values = np.empty((1, 2, 181, 360), dtype=geopotential.dtype)
    # Iterate over each point in (181, 360)
    for i in range(181):
        for j in range(360):
            surface_value = geopotential_at_surface[i, j]
            # Find indices across the 13 levels where the value matches
            idx = np.argwhere(geopotential[0, :, :, i, j] == surface_value)
            if idx.size > 0:
                # Set values across all found indices in the mapped_values
                for k in range(13):
                    if k < len(idx):
                        # If a match is found at level k
                        mapped_values[0, 0, i, j] = geopotential[0, idx[k][0], k, i, j]
                    else:
                        # If no match is found at level k
                        mapped_values[0, 0, i, j] = np.nan
            else:
                # Handle no match case with NaN or some default value
                mapped_values[0, 0, i, j] = np.nan
    return mapped_values
def log_xarray_details(ds):
    """Log the dimensions, coordinates, and per-variable metadata of a Dataset."""
    # Log general dataset information
    logging.warning(f"Dataset: {ds}")
    logging.warning(f"Dimensions: {ds.dims}")
    logging.warning(f"Coordinates: {ds.coords}")
    logging.warning(f"Data Variables: {list(ds.data_vars.keys())} Number of Data variables: {len(ds.data_vars)} ")
    # Iterate through all data variables in the dataset
    for var_name, var in ds.data_vars.items():
        logging.warning(f"Variable: {var_name}")
        logging.warning(f"Shape: {var.shape}")
        logging.warning(f"Dimensions: {var.dims}")
        logging.warning(f"Data Type: {var.dtype}")
        logging.warning("Attributes:")
        for attr_name, attr_value in var.attrs.items():
            logging.warning(f"  {attr_name}: {attr_value}")
        logging.warning("Coordinates:")
        for coord_name, coord in var.coords.items():
            logging.warning(f"  {coord_name}: {coord}")
        logging.warning("")  # Empty line for better readability
def truncate_path(path, max_length=255):
    """Truncate a path whose UTF-8 encoding exceeds `max_length` bytes, appending '.tmp'.

    Note: the slice below counts characters, not bytes, so paths containing multi-byte
    characters may still slightly exceed the limit.
    """
    if len(path.encode("utf-8")) > max_length:
        # Truncate the path to fit within the maximum length
        return path[: max_length - len(".tmp")] + ".tmp"
    else:
        return path
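
# Usage sketch (hypothetical over-long path):
# truncate_path("/scratch/" + "x" * 300 + ".nc")  # -> first 251 characters + ".tmp"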
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GraphCast Model Runner")
    parser.add_argument(
        "--dataset",
        # default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-0.25_levels-37_steps-12.nc",
        default="/scratch/gilbreth/gupt1075/graphcast/dataset/source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc",
        help="Path to the dataset file",
    )
    parser.add_argument(
        "--params",
        # default="/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_ERA5_1979-2017_pressure_levels_37.npz",
        default="/scratch/gilbreth/gupt1075/graphcast/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz",
        help="Path to the parameters file",
    )
    parser.add_argument(
        "--source",
        choices=["Random", "Checkpoint"],
        default="Checkpoint",
        help="Source of the model parameters",
    )
    parser.add_argument(
        "--mesh-size", type=int, default=4, help="Mesh size for the random model"
    )
    parser.add_argument(
        "--gnn-msg-steps",
        type=int,
        default=4,
        help="GNN message steps for the random model",
    )
    parser.add_argument(
        "--latent-size", type=int, default=32, help="Latent size for the random model"
    )
    parser.add_argument(
        "--pressure-levels",
        type=int,
        default=13,
        help="Pressure levels for the random model",
    )
    parser.add_argument("--variable", default="geopotential", help="Variable to plot")
    parser.add_argument("--level", type=int, default=1000, help="Level to plot")
    parser.add_argument(
        "--robust", action="store_true", help="Use robust scaling for plotting"
    )
    parser.add_argument("--max-steps", type=int, help="Maximum number of steps to plot")
    args = parser.parse_args()

    dataset_file_value = args.dataset
    params_file_value = args.params
    source = args.source

    if source == "Random":
        params = None  # Filled in below
        state = {}
        model_config = graphcast.ModelConfig(
            resolution=0,
            mesh_size=args.mesh_size,
            latent_size=args.latent_size,
            gnn_msg_steps=args.gnn_msg_steps,
            hidden_layers=1,
            radius_query_fraction_edge_length=0.6,
        )
        task_config = graphcast.TaskConfig(
            input_variables=graphcast.TASK.input_variables,
            target_variables=graphcast.TASK.target_variables,
            forcing_variables=graphcast.TASK.forcing_variables,
            pressure_levels=graphcast.PRESSURE_LEVELS[args.pressure_levels],
            input_duration=graphcast.TASK.input_duration,
        )
    else:
        assert source == "Checkpoint"
        with open(params_file_value, "rb") as f:
            ckpt = checkpoint.load(f, graphcast.CheckPoint)
        params = ckpt.params
        state = {}
        model_config = ckpt.model_config
        task_config = ckpt.task_config
        logging.info(f"Model description:\n {ckpt.description} \n")
        logging.info(f"Model license:\n {ckpt.license} \n")

    # truncated_path = truncate_path(dataset_file_value)
    with xarray.open_dataset(dataset_file_value) as example_batch:
        example_batch = example_batch.compute()

    log_xarray_details(example_batch)

    # Example usage:
    # plot_all_variables(example_batch)
    # safe_data = safe_select_data(example_batch, 4)
    # logging.warning(f" SAFE_DATA: {safe_data} ")
    # logging.warning(f" safe_data: {safe_data.shape} keys: {list(safe_data.keys())} ")
    # plot_geopotential_at_surface(example_batch)
    plot_variable_on_world_map(example_batch, "geopotential_at_surface", lat_range=(-60, 60), lon_range=(0, 360), time_index=0)

    assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets
logging.warning(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file_value.removesuffix(".nc")).items()]))
# init_jitted = jax.jit(with_configs(run_forward.init))
logging.warning(f" ******* example_batch_time: {example_batch.dims['time'] } \n train_steps.value, eval_steps.value (1,4) eval_steps.value ")
# added explicit train_steps_value as 1 and eval_steps_value as 4
(
train_inputs,
train_targets,
train_forcings,
) = data_utils.extract_inputs_targets_forcings(
example_batch,
target_lead_times=slice("6h", f"{ (min(1,example_batch.dims['time'])) *6}h"),
**dataclasses.asdict(task_config),
)
(
eval_inputs,
eval_targets,
eval_forcings,
) = data_utils.extract_inputs_targets_forcings(
example_batch,
target_lead_times=slice("6h", f"{ (example_batch.dims['time']-2)*6}h"),
**dataclasses.asdict(task_config),
)
# logging.warning(" train Inputs: train_inputs ", train_inputs.dims.mapping)
# logging.warning("train Targets: train_targets ", train_targets.dims.mapping)
# logging.warning("train Forcings: train_forcings ", train_forcings.dims.mapping)
logging.warning("All Examples: %s", example_batch.dims.mapping)
logging.warning("Train Inputs: %s", train_inputs.dims.mapping)
logging.warning("Train Targets: %s", train_targets.dims.mapping)
logging.warning(f" {train_inputs.dims}")
logging.warning("Train Forcings: %s", train_forcings.dims.mapping)
logging.warning("Eval Inputs: %s", eval_inputs.dims.mapping)
logging.warning("Eval Targets: %s", eval_targets.dims.mapping)
logging.warning("Eval Forcings: %s", eval_forcings.dims.mapping)
    if params is None:
        init_jitted = jax.jit(with_configs(run_forward.init))
        params, state = init_jitted(
            rng=jax.random.PRNGKey(0),
            inputs=train_inputs,
            targets_template=train_targets,
            forcings=train_forcings,
        )

    loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
    grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
    run_forward_jitted = drop_state(
        with_params(jax.jit(with_configs(run_forward.apply)))
    )

    logging.warning(f" EVAL_INPUTS: {type(eval_inputs)} eval_inputs: {eval_inputs}")
    predictions = rollout.chunked_prediction(
        run_forward_jitted,
        rng=jax.random.PRNGKey(0),
        inputs=eval_inputs,
        targets_template=eval_targets * np.nan,
        forcings=eval_forcings,
    )
    logging.warning(f" {'*' * 20} PREDICTIONS: {'*' * 20} ")
    log_xarray_details(predictions)
    logging.warning(f" {'*' * 20} \n\n\n {'*' * 20} ")

    mapped_values = map_geopotential_indices(example_batch, predictions)

    plot_max_steps = args.max_steps if args.max_steps else predictions.dims["time"]
    logging.warning(f" plot_max_steps: {plot_max_steps} ")
    data = {
        "Targets": scale_data(
            select_data(eval_targets, args.variable, args.level, plot_max_steps),
            robust=args.robust,
        ),
        "Predictions": scale_data(
            select_data(predictions, args.variable, args.level, plot_max_steps),
            robust=args.robust,
        ),
        "Diff": scale_data(
            (
                select_data(eval_targets, args.variable, args.level, plot_max_steps)
                - select_data(predictions, args.variable, args.level, plot_max_steps)
            ),
            robust=args.robust,
            center=0,
        ),
    }
    fig_title = args.variable
    if "level" in predictions[args.variable].coords:
        fig_title += f"_at_{args.level}_hPa_"
    plot_data(data, variable_name=args.variable, fig_title=fig_title, plot_size=5, robust=args.robust)
# Example run: CUDA_VISIBLE_DEVICES=0 python ./new_inference_modular.py --max-steps=2