Last active
August 26, 2025 15:07
g4g25_deeplearning_crnn_vfinal.ipynb
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "private_outputs": true, | |
| "provenance": [], | |
| "gpuType": "L4", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| }, | |
| "accelerator": "GPU" | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# G4G25 Machine Learning with Earth Engine and Vertex AI\n", | |
| "\n", | |
| "<a href=https://colab.research.google.com/gist/KMarkert/e83c7ff9a96db000524d8a8c78cd544e/g4g25_deeplearning_crnn_vfinal.ipynb>\n", | |
| " <img src=https://colab.research.google.com/assets/colab-badge.svg alt=\"Open in Colab\">\n", | |
| "</a>\n", | |
| "\n", | |
| "\n", | |
| "Welcome to the Geo for Good 2025 Machine Learning with Earth Engine and Vertex AI notebook! This demonstration shows how to run time-series predictions on geospatial data using a Convolutional Recurrent Neural Network (CRNN). The use case is forecasting drought in the continental United States." | |
| ], | |
| "metadata": { | |
| "id": "ZJlnnSEjKpxD" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## The Problem: Predicting Drought Conditions\n", | |
| "\n", | |
| "Drought is a slow-moving but devastating natural hazard with significant impacts on agriculture, water resources, ecosystems, and economies. Timely and accurate monitoring is crucial for mitigation and response planning. Traditionally, drought assessment involves a complex synthesis of various climate and environmental data by human experts.\n", | |
| "\n", | |
| "Our goal is to leverage deep learning to automate and potentially forecast drought conditions. We aim to build a model that can learn the temporal relationships between key environmental indicators and the resulting drought severity, effectively mimicking the expert assessment process but in a scalable, data-driven manner.\n", | |
| "\n", | |
| "### Existing Data Products\n", | |
| "\n", | |
| "To tackle this, we need two types of data: a reliable \"ground truth\" label of drought conditions and a set of physical predictor variables that influence those conditions.\n", | |
| "\n", | |
| "#### Target Variable: The \"Ground Truth\"\n", | |
| "\n", | |
| "* **[U.S. Drought Monitor (USDM)](https://developers.google.com/earth-engine/datasets/catalog/projects_sat-io_open-datasets_us-drought-monitor):** This dataset will serve as our target variable. It's a weekly, expert-assessed map of drought intensity across the United States. It's not a direct physical measurement but rather a consolidated judgment based on numerous data sources. It classifies drought into five categories, from D0 (Abnormally Dry) to D4 (Exceptional Drought). While it is the authoritative source for drought classification in the U.S., it is inherently retrospective.\n", | |
| "\n", | |
| "#### Predictor Variables: The Physical Drivers\n", | |
| "\n", | |
| "We hypothesize that drought conditions are a function of water availability and its impact on the landscape. We can measure these factors using various satellite-derived datasets.\n", | |
| "\n", | |
| "1. **Precipitation - [CHIRPS Daily](https://developers.google.com/earth-engine/datasets/catalog/UCSB-CHG_CHIRPS_DAILY):** Climate Hazards Group InfraRed Precipitation with Station data provides a daily, quasi-global rainfall dataset. It is a fundamental indicator of water input into the system. Persistent lack of precipitation is a primary driver of drought.\n", | |
| "2. **Soil Moisture - [SMAP](https://developers.google.com/earth-engine/datasets/catalog/NASA_SMAP_SPL4SMGP_008):** The Soil Moisture Active Passive (SMAP) mission provides global measurements of soil moisture. We will use the root-zone soil moisture product, which is critical for understanding water availability for plants.\n", | |
| "\n", | |
| "\n", | |
| "### Our Approach: A Time-Series Deep Learning Model\n", | |
| "\n", | |
| "**Hypothesis:** By analyzing a sequence of recent observations of precipitation, soil moisture, and previously observed USDM drought classes, a deep learning model can learn the complex, time-dependent patterns that lead to the drought classifications seen in the USDM.\n", | |
| "\n", | |
| "Drought is not an instantaneous event; it develops over time as deficits accumulate. A model that looks at a single point in time would miss this context, which makes the problem well suited to a **Convolutional Recurrent Neural Network (CRNN)**:\n", | |
| "\n", | |
| "* The **Convolutional (C)** part of the model processes the predictor bands (`PRECIP`, `SM`, and `DM`, the USDM class observed in that week) at each weekly time step, learning spatial patterns and how these variables interact with each other.\n", | |
| "* The **Recurrent (R)** part of the model processes the sequence of these weekly feature maps, capturing the temporal dynamics; for example, it can learn that a gradual decline in soil moisture over several weeks is a strong indicator of developing drought. In this notebook both parts are combined in Keras `ConvLSTM2D` layers, which apply convolutions inside LSTM cells.\n", | |
| "\n", | |
| "Our workflow builds this CRNN to predict the USDM class of every pixel in an image patch for a given week, using data from the preceding `SEQUENCE_LENGTH` weeks (four by default), to create an automated drought classification pipeline." | |
| ], | |
| "metadata": { | |
| "id": "vMqTvlOpKO16" | |
| } | |
| }, | |
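The windowing described above can be sketched in plain Python. `weekly_windows` is a hypothetical helper, not part of the notebook; it mirrors how the Earth Engine code later lays out the time steps: the label is the USDM map for the week starting on the label date, and predictor week `k` (for `k = 0 … SEQUENCE_LENGTH - 1`) aggregates the 7 days ending `k` weeks before that date.

```python
from datetime import date, timedelta


def weekly_windows(label_date: date, sequence_length: int = 4):
    """Return (start, end) pairs for each predictor week, newest first.

    Week k ends k weeks before the label date and covers the preceding
    7 days; the end date is exclusive, as in Earth Engine's filterDate.
    """
    windows = []
    for k in range(sequence_length):
        end = label_date - timedelta(weeks=k)
        start = end - timedelta(days=7)
        windows.append((start, end))
    return windows


label_day = date(2022, 1, 4)  # an arbitrary example label date
for start, end in weekly_windows(label_day):
    print(start.isoformat(), "->", end.isoformat())
```

Each window ends where the next-newer one starts, so the four windows tile the 28 days before the label date with no gap or overlap, and none of them includes the label week itself.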
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Part 1: Data Preparation & Export\n", | |
| "\n" | |
| ], | |
| "metadata": { | |
| "id": "hhcFRDGcF452" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 1.1 - Initialization and Setup" | |
| ], | |
| "metadata": { | |
| "id": "JUF8SjIXIT4J" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#@title Install packages\n", | |
| "%%capture\n", | |
| "!pip install apache_beam[gcp]" | |
| ], | |
| "metadata": { | |
| "id": "yE-aBNZ1ujop", | |
| "cellView": "form" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#@title Import packages\n", | |
| "import apache_beam as beam\n", | |
| "from apache_beam.options.pipeline_options import PipelineOptions\n", | |
| "import ee\n", | |
| "from google.api_core import exceptions, retry\n", | |
| "from google.colab import auth\n", | |
| "import io\n", | |
| "import geemap.core as geemap\n", | |
| "import geemap.colormaps as cmaps\n", | |
| "import math\n", | |
| "import numpy as np\n", | |
| "from pprint import pprint\n", | |
| "import tensorflow as tf\n", | |
| "import matplotlib.pyplot as plt\n" | |
| ], | |
| "metadata": { | |
| "cellView": "form", | |
| "id": "tqVzRlc9JqDo" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#@title Authenticate and initialize cloud resources\n", | |
| "\n", | |
| "auth.authenticate_user()\n", | |
| "\n", | |
| "# --- User-defined parameters ---\n", | |
| "\n", | |
| "# Google Cloud project ID (required; the project must be registered to use Earth Engine)\n", | |
| "PROJECT = \"\" # @param {type:\"string\"}\n", | |
| "\n", | |
| "# Bucket to store data in\n", | |
| "BUCKET = \"g4g25-mlai-session\" # @param {type:\"string\"}\n", | |
| "\n", | |
| "# GCP region, us-central1 is co-located with GEE\n", | |
| "REGION = \"us-central1\" # @param {type:\"string\"}\n", | |
| "\n", | |
| "\n", | |
| "# Initialize Earth Engine with the provided project ID\n", | |
| "ee.Initialize(project=PROJECT)" | |
| ], | |
| "metadata": { | |
| "cellView": "form", | |
| "id": "6fI9nrIdF_8a" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#@title Workflow parameters\n", | |
| "\n", | |
| "#@markdown Number of weeks to use as a time series for prediction\n", | |
| "SEQUENCE_LENGTH = 4 #@param {type:\"number\"}\n", | |
| "#@markdown ___\n", | |
| "\n", | |
| "\n", | |
| "#@markdown Number of sample points per USDM date (split evenly across the 3 hex-grid classes)\n", | |
| "N_SAMPLES_PER_TIME = 12 #@param {type:\"number\"}\n", | |
| "#@markdown ___\n", | |
| "\n", | |
| "\n", | |
| "#@markdown Years to generate training data for\n", | |
| "EXPORT_YEARS = [2018, 2019, 2020, 2021, 2022] #@param {type:\"raw\"}\n", | |
| "#@markdown ___\n", | |
| "\n", | |
| "\n", | |
| "#@markdown Prefix for exported training data on GCS\n", | |
| "EXPORT_FILE_PREFIX = 'drought_crnn_training' #@param {type:\"string\"}\n", | |
| "#@markdown ___\n", | |
| "\n", | |
| "#@markdown Size of the image patches to extract from Earth Engine\n", | |
| "PATCH_SIZE = 64 #@param {type:\"number\"}\n", | |
| "#@markdown ___\n", | |
| "\n", | |
| "#@markdown Scale in meters for extracting data out of Earth Engine\n", | |
| "SCALE = 5000 #@param {type:\"number\"}\n", | |
| "\n", | |
| "\n", | |
| "BEAM_OUTPUT_PREFIX = f'gs://{BUCKET}/{EXPORT_FILE_PREFIX}'" | |
| ], | |
| "metadata": { | |
| "cellView": "form", | |
| "id": "iTw7L_PDG0h-" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "#@title Create map panel\n", | |
| "\n", | |
| "# create a Map object with geemap to visualize EE results\n", | |
| "m = geemap.Map()\n", | |
| "\n", | |
| "# programmatically create a scratch cell for displaying the Map\n", | |
| "from google.colab import _frontend\n", | |
| "_frontend.create_scratch_cell(\"#@title Map\\nm\", False)" | |
| ], | |
| "metadata": { | |
| "id": "6sLmLq3jSK5L", | |
| "cellView": "form" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 1.2 - Load and Preprocess Data Collections" | |
| ], | |
| "metadata": { | |
| "id": "kT6DpU5jICq1" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Target Variable: US Drought Monitor (weekly)\n", | |
| "usdm = ee.ImageCollection(\"projects/sat-io/open-datasets/us-drought-monitor\")" | |
| ], | |
| "metadata": { | |
| "id": "vtOxvYuRIBM0" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Predictor Variables\n", | |
| "precip = ee.ImageCollection(\"UCSB-CHG/CHIRPS/DAILY\")\n", | |
| "soil_moisture = ee.ImageCollection(\"NASA/SMAP/SPL4SMGP/008\").select('sm_rootzone')" | |
| ], | |
| "metadata": { | |
| "id": "Yrws-FUwIJma" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 1.3 - Create Time Series Composites" | |
| ], | |
| "metadata": { | |
| "id": "pyCas5zxHq1h" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def create_weekly_composite(date):\n", | |
| "    \"\"\"Creates a 7-day composite of all predictor variables ending on `date` (exclusive).\"\"\"\n", | |
| "    end_date = ee.Date(date)\n", | |
| "    start_date = end_date.advance(-7, \"days\")\n", | |
| "\n", | |
| "    # Total precipitation (mm) over the week\n", | |
| "    precip_img = (\n", | |
| "        precip.filterDate(start_date, end_date)\n", | |
| "        .sum().rename(\"PRECIP\")\n", | |
| "    )\n", | |
| "    # Mean root-zone soil moisture over the week\n", | |
| "    sm_img = soil_moisture.filterDate(start_date, end_date).mean().rename(\"SM\")\n", | |
| "\n", | |
| "    # USDM map for the week: +1 so 0 = no drought and 1-5 = D0-D4;\n", | |
| "    # unmask(0) fills pixels outside any drought polygon\n", | |
| "    usdm_filtered = usdm.filterDate(start_date, end_date)\n", | |
| "    n = usdm_filtered.size()\n", | |
| "    usdm_img_ = usdm_filtered.first().add(1).unmask(0)\n", | |
| "    # Fall back to an all-zero band with the same name if no USDM map falls in the window\n", | |
| "    usdm_img = ee.Algorithms.If(n.gt(0), usdm_img_, ee.Image(0).rename(\"DM\"))\n", | |
| "\n", | |
| "    return ee.Image.cat([precip_img, sm_img, usdm_img]).set('system:time_start', start_date.millis())" | |
| ], | |
| "metadata": { | |
| "id": "a-VqfYgsIYPX" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
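The `.add(1).unmask(0)` step in `create_weekly_composite` shifts the USDM categories so that 0 can mean "no drought". A NumPy analogy, assuming (as the `unmask(0)` call implies) that USDM pixels outside any drought polygon are masked and that stored values 0-4 correspond to D0-D4; the tile values are made up:

```python
import numpy as np

# Hypothetical 2x3 USDM tile: values 0-4 are D0-D4, masked cells have no drought
raw = np.ma.masked_array(
    data=[[0, 4, 2], [1, 0, 3]],
    mask=[[False, False, True], [True, False, False]],
)

# Equivalent of image.add(1).unmask(0): shift classes up by one,
# then fill the masked (no-drought) cells with 0
encoded = (raw + 1).filled(0)
print(encoded)
```

The model therefore sees six classes, 0 for no drought and 1-5 for D0-D4, which is why `NUM_CLASSES = 6` later on.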
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Class names and colors used to build a legend and visualize the USDM layer on the map\n", | |
| "vis_dict = {\n", | |
| " 'names': [\n", | |
| " \"D0 Abnormally Dry\", # 1\n", | |
| " \"D1 Moderate Drought\", # 2\n", | |
| " \"D2 Severe Drought\", # 3\n", | |
| " \"D3 Extreme Drought\", # 4\n", | |
| " \"D4 Exceptional Drought\", # 5\n", | |
| " ],\n", | |
| " 'colors': [\"FFFF00\", \"FCD37F\", \"FFAA00\", \"E60000\", \"730000\"],\n", | |
| "}" | |
| ], | |
| "metadata": { | |
| "id": "GEwxHnBHSsug" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "date = '2022-01-01'\n", | |
| "composite = create_weekly_composite(date)\n", | |
| "\n", | |
| "m.set_center(-100, 40, 4)\n", | |
| "\n", | |
| "m.add_layer(composite, {'bands':'SM', 'min':0,'max':1, 'palette':cmaps.get_palette('viridis')}, f\"SM {date}\")\n", | |
| "m.add_layer(composite, {'bands':'PRECIP', 'min':0,'max':50, 'palette':cmaps.get_palette('Blues')}, f\"PRECIP {date}\")\n", | |
| "m.add_layer(composite.selfMask(), {'bands':'DM', 'min':1,'max':5, 'palette':vis_dict['colors']}, f\"USDM {date}\")" | |
| ], | |
| "metadata": { | |
| "id": "q0w4WxGxSTly" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 1.4 - Sample Generation for CRNN" | |
| ], | |
| "metadata": { | |
| "id": "PGW-XKevIq4z" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Define the region of interest\n", | |
| "conus_region = ee.FeatureCollection('USDOS/LSIB_SIMPLE/2017') \\\n", | |
| " .filter('country_na == \"United States\"') \\\n", | |
| " .geometry().dissolve().simplify(10000)" | |
| ], | |
| "metadata": { | |
| "id": "9Td-7JNFIlKJ" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
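| "# Hexagonal grid labelled 0-2, used below to stratify sample points spatially\n", | |
| "def conus_hex(hex_size=50_000) -> ee.Image:\n", | |
| " \"\"\"Return a one-band image ('hex_mask') assigning each pixel to one of 3 hex classes.\n", | |
| "\n", | |
| " hex_size is the approximate hexagon size, in EPSG:5070 meters.\n", | |
| " \"\"\"\n", | |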
| " # NAD83 / Conus Albers projection, equal area.\n", | |
| " proj = 'EPSG:5070'\n", | |
| "\n", | |
| " square_proj = ee.Projection(\n", | |
| " proj, [1, 0, 0, 0, math.sqrt(3.0)/2.0, 0])\n", | |
| " shear_l = ee.Projection(\n", | |
| " proj, [1, .5, 0, 0, math.sqrt(3.0)/2.0, 0])\n", | |
| " shear_r = ee.Projection(proj, [1, -.5, 0, 0, math.sqrt(3.0)/ 2.0, 0])\n", | |
| " base = 6\n", | |
| "\n", | |
| " def create_stripes(proj: ee.Projection, lat_lon_axis: int):\n", | |
| " stripes = ee.Image.pixelLonLat().changeProj(\n", | |
| " 'EPSG:4326', proj).select(lat_lon_axis)\n", | |
| " rows = stripes.divide(hex_size).floor().mod(6)\n", | |
| " return rows.add(base).mod(base)\n", | |
| "\n", | |
| " rows = create_stripes(square_proj, 1)\n", | |
| " left = create_stripes(shear_l, 0)\n", | |
| " right = create_stripes(shear_r, 0)\n", | |
| "\n", | |
| " left_sum_right = left.add(right)\n", | |
| " wheels = left_sum_right.add(rows.multiply(3)).mod(6)\n", | |
| " rows2 = wheels.gte(3).add(rows.mod(6))\n", | |
| " hex_mask = rows2.mod(3)\n", | |
| " return hex_mask.rename('hex_mask')" | |
| ], | |
| "metadata": { | |
| "id": "U7A7D81eI3yO" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "m.add_layer(conus_hex(75_000), {}, 'Hex grid', True, 0.35)" | |
| ], | |
| "metadata": { | |
| "id": "LRaDRm-TXHQW" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Generate random sample points for a given date, stratified by hex class\n", | |
| "def get_sample_points(date):\n", | |
| " if not isinstance(date, ee.Number):\n", | |
| " date = ee.Date(date).millis()\n", | |
| "\n", | |
| " n_hex_classes = 3\n", | |
| " samples_per_class = N_SAMPLES_PER_TIME//n_hex_classes\n", | |
| "\n", | |
| " points = conus_hex(75_000).int().stratifiedSample(\n", | |
| " numPoints=N_SAMPLES_PER_TIME,\n", | |
| " classBand='hex_mask',\n", | |
| " region=conus_region,\n", | |
| " scale=5000,\n", | |
| " classValues=ee.List.sequence(0, n_hex_classes-1),\n", | |
| " classPoints=ee.List([samples_per_class,]*n_hex_classes),\n", | |
| " geometries=True,\n", | |
| " seed=date.divide(1e3).toInt64()\n", | |
| " )\n", | |
| " return points.map(lambda f: f.set('date',date))" | |
| ], | |
| "metadata": { | |
| "id": "stqMaDoBI6Qf" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
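`get_sample_points` splits `N_SAMPLES_PER_TIME` evenly across the three hex classes with integer division and derives a reproducible seed from the date. The arithmetic in plain Python (the helper name `allocation` is hypothetical):

```python
from datetime import datetime, timezone


def allocation(n_samples_per_time: int, n_hex_classes: int = 3):
    """Points requested per hex class, and the total actually drawn."""
    per_class = n_samples_per_time // n_hex_classes
    return [per_class] * n_hex_classes, per_class * n_hex_classes


print(allocation(12))  # the notebook default
print(allocation(10))  # any remainder is silently dropped

# Seed: the date's epoch milliseconds divided by 1000 (i.e. epoch seconds),
# so every date gets a different but repeatable sample
millis = int(datetime(2022, 1, 4, tzinfo=timezone.utc).timestamp() * 1000)
seed = millis // 1000
print(seed)
```

Choosing `N_SAMPLES_PER_TIME` as a multiple of 3 avoids losing points to the integer division.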
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Prepare the label data by getting all unique USDM dates for the target years\n", | |
| "usdm_dates = usdm.filter(ee.Filter.calendarRange(min(EXPORT_YEARS), max(EXPORT_YEARS), 'year')) \\\n", | |
| " .aggregate_array(\"system:time_start\") \\\n", | |
| " .map(lambda t: ee.Date(t).format('YYYY-MM-dd')) \\\n", | |
| " .distinct()" | |
| ], | |
| "metadata": { | |
| "id": "mZok9nSbI411" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "usdm_dates" | |
| ], | |
| "metadata": { | |
| "id": "Yn-_L_mVcAkI" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "predictor_bands = ['PRECIP', 'SM','DM']\n", | |
| "time_series_bands = [f'{b}_{i}' for i in range(SEQUENCE_LENGTH) for b in predictor_bands]" | |
| ], | |
| "metadata": { | |
| "id": "lZRmMcuNhKk7" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
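The list comprehension nests the band loop inside the time loop, so names are grouped by week. `ImageCollection.toBands()` emits bands image by image (all of week 0's bands, then week 1's, and so on), so the positional `rename(time_series_bands)` in `get_training_patch` lines up. A quick check with the notebook's default values hard-coded (`channel_index` is a hypothetical helper):

```python
SEQUENCE_LENGTH = 4  # notebook default
predictor_bands = ['PRECIP', 'SM', 'DM']
time_series_bands = [f'{b}_{i}' for i in range(SEQUENCE_LENGTH) for b in predictor_bands]

print(time_series_bands[:6])
print(len(time_series_bands))  # SEQUENCE_LENGTH * len(predictor_bands)


def channel_index(i: int, b: str) -> int:
    """Channel position of band b at week i in the flattened stack."""
    return i * len(predictor_bands) + predictor_bands.index(b)


print(time_series_bands[channel_index(2, 'SM')])
```

This is the same `i * len(predictor_bands)` offset that `ReshapeLayer` uses to slice each week back out of the channel axis.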
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "training_pts = ee.FeatureCollection(usdm_dates.map(get_sample_points)).flatten()" | |
| ], | |
| "metadata": { | |
| "id": "40tAGdc3uYiZ" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "m.add_layer(training_pts,{'color':'red'},'training points')" | |
| ], | |
| "metadata": { | |
| "id": "kpPhVWX95Ow0" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 1.5 - Export Data to Google Cloud Storage" | |
| ], | |
| "metadata": { | |
| "id": "RZ9hVVugJDeA" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Request the points as a client-side object to pass into the Beam pipeline\n", | |
| "client_side_training_pts = training_pts.getInfo()" | |
| ], | |
| "metadata": { | |
| "id": "4oET37N05Wti" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Make a projection to discover the scale in degrees.\n", | |
| "proj = ee.Projection('EPSG:4326').atScale(SCALE).getInfo()\n", | |
| "\n", | |
| "# Get scales out of the transform.\n", | |
| "scale_x = proj['transform'][0]\n", | |
| "scale_y = -proj['transform'][4]" | |
| ], | |
| "metadata": { | |
| "id": "RkwKWxrPxZql" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
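`atScale(SCALE)` on EPSG:4326 yields a degree-based grid whose pixel size corresponds to roughly `SCALE` meters at the equator. A back-of-the-envelope version, assuming the conversion uses the equatorial circumference of the WGS84 ellipsoid (2π × 6,378,137 m / 360, about 111,319.5 m per degree):

```python
import math

SCALE = 5000      # meters, notebook default
PATCH_SIZE = 64   # pixels, notebook default

# Assumption: meters per degree of longitude at the equator (WGS84 semi-major axis)
METERS_PER_DEGREE = 2 * math.pi * 6378137 / 360

scale_deg = SCALE / METERS_PER_DEGREE
patch_extent_deg = PATCH_SIZE * scale_deg

print(f"pixel size  ~ {scale_deg:.4f} degrees")
print(f"patch width ~ {patch_extent_deg:.2f} degrees")
```

Away from the equator a degree of longitude is shorter (about cos(latitude) times the equatorial length), so a patch of fixed degree width covers fewer kilometers east-west than north-south.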
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Core function: fetch the predictor sequence and label patch for one sample point\n", | |
| "\n", | |
| "@retry.Retry(predicate=retry.if_exception_type(exceptions.GoogleAPICallError), deadline=60)\n", | |
| "def get_training_patch(feature):\n", | |
| "    coords = feature['geometry']['coordinates']\n", | |
| "    date = ee.Date(feature['properties']['date'])\n", | |
| "\n", | |
| "    # Label: USDM map for the week starting at `date` (+1 so 0 = no drought, 1-5 = D0-D4)\n", | |
| "    label = usdm.filterDate(date, date.advance(7, 'days')).first() \\\n", | |
| "        .add(1).unmask(0).select(['DM'], ['label'])\n", | |
| "\n", | |
| "    # Predictors: weekly composites for the SEQUENCE_LENGTH weeks *before* the label date\n", | |
| "    def get_weekly_image(week_offset):\n", | |
| "        target_date = date.advance(ee.Number(week_offset).multiply(-1), 'week')\n", | |
| "        return create_weekly_composite(target_date)\n", | |
| "\n", | |
| "    week_offsets = ee.List.sequence(0, SEQUENCE_LENGTH - 1)\n", | |
| "    predictor_collection = ee.ImageCollection.fromImages(week_offsets.map(get_weekly_image))\n", | |
| "\n", | |
| "    # Flatten the collection into one multi-band image (PRECIP_0, SM_0, DM_0, PRECIP_1, ...)\n", | |
| "    # and append the label band\n", | |
| "    example_image = predictor_collection.toBands().rename(time_series_bands).addBands(label)\n", | |
| "\n", | |
| "    # Request a PATCH_SIZE x PATCH_SIZE grid centered on the point. translateX/Y set the\n", | |
| "    # upper-left corner; scale_y is negative (north-up), so subtracting it moves north.\n", | |
| "    pixel_data = ee.data.computePixels({\n", | |
| "        'expression': example_image,\n", | |
| "        'grid': {\n", | |
| "            'dimensions': {\n", | |
| "                'width': PATCH_SIZE,\n", | |
| "                'height': PATCH_SIZE\n", | |
| "            },\n", | |
| "            'affineTransform': {\n", | |
| "                'scaleX': scale_x,\n", | |
| "                'shearX': 0,\n", | |
| "                'translateX': coords[0] - (PATCH_SIZE // 2) * scale_x,\n", | |
| "                'shearY': 0,\n", | |
| "                'scaleY': scale_y,\n", | |
| "                'translateY': coords[1] - (PATCH_SIZE // 2) * scale_y\n", | |
| "            },\n", | |
| "            'crsCode': proj['crs'],\n", | |
| "        },\n", | |
| "        'fileFormat': 'NPY',\n", | |
| "    })\n", | |
| "\n", | |
| "    # The result is NPY bytes; load it as a structured NumPy array with one field per band\n", | |
| "    return np.load(io.BytesIO(pixel_data))" | |
| ], | |
| "metadata": { | |
| "id": "zYlcpt3NI-FO" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
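`computePixels` places the grid by its upper-left corner (`translateX`, `translateY`), so centering a patch on a point means stepping half a patch west and half a patch north. A pure-Python sketch of that arithmetic, assuming a north-up grid where `scale_y` is negative (as produced by `scale_y = -proj['transform'][4]` earlier); `upper_left` and `pixel_center` are hypothetical helpers and the scales are illustrative:

```python
def upper_left(lon: float, lat: float, patch_size: int, scale_x: float, scale_y: float):
    """Upper-left corner of a patch_size x patch_size grid centered on (lon, lat).

    scale_y is negative for a north-up grid, so subtracting it moves north.
    """
    half = patch_size // 2
    return lon - half * scale_x, lat - half * scale_y


def pixel_center(col: int, row: int, tx: float, ty: float, scale_x: float, scale_y: float):
    """Map a pixel index to the coordinates of its center via the affine transform."""
    return tx + (col + 0.5) * scale_x, ty + (row + 0.5) * scale_y


scale_x, scale_y = 0.045, -0.045  # illustrative pixel sizes in degrees
tx, ty = upper_left(-100.0, 40.0, 64, scale_x, scale_y)
print(tx, ty)

# Pixel (32, 32) sits just south-east of the requested point
print(pixel_center(32, 32, tx, ty, scale_x, scale_y))
```

With an even `PATCH_SIZE` the point falls on the corner shared by the four central pixels rather than at a pixel center.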
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def serialize_example(patch):\n", | |
| " \"\"\"Converts a structured array from the previous step into a serialized TFRecord example.\"\"\"\n", | |
| " # For each band in the image, create a feature and add it to the dictionary.\n", | |
| " features = {\n", | |
| " name: tf.train.Feature(\n", | |
| " float_list=tf.train.FloatList(value=patch[name].astype(np.float32).flatten())\n", | |
| " )\n", | |
| " for name in patch.dtype.names\n", | |
| " }\n", | |
| " example = tf.train.Example(features=tf.train.Features(feature=features))\n", | |
| "\n", | |
| " return example.SerializeToString()" | |
| ], | |
| "metadata": { | |
| "id": "RuSX1fyj9gei" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
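The NPY payload from `computePixels` loads as a NumPy structured array: one named field per band, each holding a `PATCH_SIZE x PATCH_SIZE` grid. `serialize_example` loops over `patch.dtype.names` and flattens each field. The sketch below mimics that step with a hypothetical 2x2, two-band patch and plain Python lists in place of `tf.train.FloatList`, so it runs without TensorFlow:

```python
import numpy as np

# Hypothetical 2x2 patch with two bands, shaped like a computePixels NPY result
patch = np.zeros((2, 2), dtype=[('PRECIP_0', '<f8'), ('label', '<f8')])
patch['PRECIP_0'] = [[1.5, 0.0], [3.25, 2.0]]
patch['label'] = [[0, 1], [2, 5]]

# Same loop as serialize_example, minus the tf.train wrappers
features = {
    name: patch[name].astype(np.float32).flatten().tolist()
    for name in patch.dtype.names
}
print(features)

# Row-major flattening means a [H, W] FixedLenFeature restores the grid on parse
restored = np.array(features['PRECIP_0'], dtype=np.float32).reshape(2, 2)
print(np.array_equal(restored, patch['PRECIP_0'].astype(np.float32)))
```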
| { | |
| "cell_type": "code", | |
| "source": [ | |
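| "beam_options = PipelineOptions(\n", | |
| "    # Run locally on the DirectRunner using multiple threads\n", | |
| "    direct_running_mode='multi_threading',\n", | |
| "    direct_num_workers=0,  # 0 = use as many workers as there are CPU cores\n", | |
| "    # runner='DataflowRunner',  # uncomment this and the options below to run on Dataflow\n", | |
| "    project=PROJECT,\n", | |
| "    region=REGION,\n", | |
| "    # temp_location=f'gs://{BUCKET}/beam-temp',\n", | |
| "    # job_name='drought-crnn-data-generation'\n", | |
| ")" | |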
| ], | |
| "metadata": { | |
| "id": "h8nSW39fudTA" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "print(\"Starting Apache Beam pipeline...\")\n", | |
| "with beam.Pipeline(options=beam_options) as pipeline:\n", | |
| " # 1. Create an initial PCollection from the sample points (one per point and date)\n", | |
| " # 2. Fetch the predictor sequence and label patch for each point (the heavy lifting)\n", | |
| " # 3. Serialize each patch into a TFRecord example\n", | |
| " # 4. Write the examples to sharded, gzip-compressed TFRecord files on GCS\n", | |
| "\n", | |
| " arrays = (\n", | |
| " pipeline\n", | |
| " | 'CreateInputSamples' >> beam.Create(client_side_training_pts['features'])\n", | |
| " | 'GetTrainingData' >> beam.Map(get_training_patch)\n", | |
| " | 'Serialize' >> beam.Map(serialize_example)\n", | |
| " | 'WriteTFRecord' >> beam.io.WriteToTFRecord(\n", | |
| " f'{BEAM_OUTPUT_PREFIX}/data', file_name_suffix='.tfrecord.gz'\n", | |
| " )\n", | |
| " )\n", | |
| "\n", | |
| "print(\"Beam pipeline finished. Check your GCS bucket for the output files.\")" | |
| ], | |
| "metadata": { | |
| "id": "G8RageSkuahW" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Part 2: TensorFlow Model Building & Training" | |
| ], | |
| "metadata": { | |
| "id": "GF_fG4gdJOjM" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 2.1 - Load and Parse TFRecord Data" | |
| ], | |
| "metadata": { | |
| "id": "HAwiD-r8C3Gs" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "print(\"Listing exported TFRecord files...\")\n", | |
| "!gsutil ls gs://{BUCKET}/{EXPORT_FILE_PREFIX}*" | |
| ], | |
| "metadata": { | |
| "id": "Wy81sR-GC8d_" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "tfrecord_files = tf.io.gfile.glob(f'gs://{BUCKET}/{EXPORT_FILE_PREFIX}/*.tfrecord.gz')\n", | |
| "dataset = tf.data.TFRecordDataset(tfrecord_files, compression_type='GZIP')" | |
| ], | |
| "metadata": { | |
| "id": "FTJBsYSnC9p5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# --- Define the parsing schema ---\n", | |
| "# Each band was exported as a flat float list; declaring a [PATCH_SIZE, PATCH_SIZE]\n", | |
| "# shape makes the parser reshape it back into a 2D grid.\n", | |
| "feature_description = {}\n", | |
| "for i in range(SEQUENCE_LENGTH):\n", | |
| "    for band in predictor_bands:\n", | |
| "        feature_description[f'{band}_{i}'] = tf.io.FixedLenFeature([PATCH_SIZE, PATCH_SIZE], tf.float32)" | |
| ], | |
| "metadata": { | |
| "id": "dc1d_bWCDA6c" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "feature_description" | |
| ], | |
| "metadata": { | |
| "id": "72ZDQD0tOgRi" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "feature_description['label'] = tf.io.FixedLenFeature([PATCH_SIZE, PATCH_SIZE], tf.float32)\n", | |
| "NUM_CLASSES = 6 # 0 (No drought) to 5 (D4)\n", | |
| "NUM_BANDS = len(feature_description.keys()) -1" | |
| ], | |
| "metadata": { | |
| "id": "BI_27U2_DB_n" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def parse_tfrecord_fn(example_proto):\n", | |
| " \"\"\"Parses a TFRecord example and reshapes the features.\"\"\"\n", | |
| " parsed_features = tf.io.parse_single_example(example_proto, feature_description)\n", | |
| "\n", | |
| " # Stack the per-band (y, x) grids into a (y, x, c) tensor, week by week\n", | |
| " features_list = []\n", | |
| " for i in range(SEQUENCE_LENGTH):\n", | |
| " for band in predictor_bands:\n", | |
| " features_list.append(parsed_features[f'{band}_{i}'])\n", | |
| "\n", | |
| " features = tf.stack(features_list,axis=-1)\n", | |
| " label = tf.cast(parsed_features['label'], tf.int32)\n", | |
| "\n", | |
| " return features, tf.one_hot(label, NUM_CLASSES)\n", | |
| "\n", | |
| "# Create the parsed dataset\n", | |
| "parsed_dataset = dataset.map(parse_tfrecord_fn)" | |
| ], | |
| "metadata": { | |
| "id": "lzPNwSn0C677" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
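For a single example, `parse_tfrecord_fn` produces a `(PATCH_SIZE, PATCH_SIZE, NUM_BANDS)` feature tensor and a `(PATCH_SIZE, PATCH_SIZE, NUM_CLASSES)` one-hot label. A NumPy equivalent with random stand-in data (not real exports) makes the shapes concrete:

```python
import numpy as np

PATCH_SIZE, SEQUENCE_LENGTH, NUM_CLASSES = 64, 4, 6  # notebook defaults
predictor_bands = ['PRECIP', 'SM', 'DM']
rng = np.random.default_rng(0)

# Stand-in for parsed_features: one (PATCH_SIZE, PATCH_SIZE) grid per band name
parsed = {
    f'{b}_{i}': rng.random((PATCH_SIZE, PATCH_SIZE), dtype=np.float32)
    for i in range(SEQUENCE_LENGTH) for b in predictor_bands
}
parsed['label'] = rng.integers(0, NUM_CLASSES, (PATCH_SIZE, PATCH_SIZE)).astype(np.float32)

# Same band order as the parser: week-major, then band
features = np.stack(
    [parsed[f'{b}_{i}'] for i in range(SEQUENCE_LENGTH) for b in predictor_bands], axis=-1
)
# One-hot encode the integer label, like tf.one_hot(label, NUM_CLASSES)
label_onehot = np.eye(NUM_CLASSES, dtype=np.float32)[parsed['label'].astype(np.int32)]

print(features.shape, label_onehot.shape)
```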
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 2.2 - Split Data and Create Batches" | |
| ], | |
| "metadata": { | |
| "id": "DfLvbRVhDcC1" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "n_examples = len(client_side_training_pts['features'])" | |
| ], | |
| "metadata": { | |
| "id": "L__xVxGcQqFq" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "train_prop = 0.7\n", | |
| "test_prop = 0.2\n", | |
| "val_prop = 0.1\n", | |
| "\n", | |
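| "# Shuffle once with a fixed seed so take/skip give disjoint, stable splits across epochs\n", | |
| "ds = parsed_dataset.shuffle(n_examples, seed=42, reshuffle_each_iteration=False)\n", | |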
| "train_size = int(n_examples*train_prop)\n", | |
| "train = ds.take(train_size)\n", | |
| "\n", | |
| "test_size = int(n_examples*test_prop)\n", | |
| "test = ds.skip(train_size).take(test_size)\n", | |
| "\n", | |
| "val_size = int(n_examples*val_prop)\n", | |
| "val = ds.skip(train_size+test_size).take(val_size)" | |
| ], | |
| "metadata": { | |
| "id": "W8QxW2TcDgQM" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
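`tf.data.Dataset.shuffle` reshuffles on every iteration by default, so calling `take`/`skip` on a reshuffled dataset can hand the same example to different splits across epochs. The take/skip logic on a plain Python list, shuffled once with a seeded generator, shows what a stable split looks like (`split_indices` is a hypothetical helper):

```python
import random


def split_indices(n_examples, train_prop=0.7, test_prop=0.2, val_prop=0.1, seed=42):
    """Shuffle once, then carve contiguous train/test/val blocks (take/skip style)."""
    order = list(range(n_examples))
    random.Random(seed).shuffle(order)
    train_size = int(n_examples * train_prop)
    test_size = int(n_examples * test_prop)
    val_size = int(n_examples * val_prop)
    train = order[:train_size]
    test = order[train_size:train_size + test_size]
    val = order[train_size + test_size:train_size + test_size + val_size]
    return train, test, val


train, test, val = split_indices(1000)
print(len(train), len(test), len(val))
# No example appears in more than one split
assert not (set(train) & set(test) or set(train) & set(val) or set(test) & set(val))
```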
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Prepare datasets for training\n", | |
| "BATCH_SIZE = 64\n", | |
| "train_dataset = train.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n", | |
| "test_dataset = test.batch(1).prefetch(tf.data.AUTOTUNE)\n", | |
| "val_dataset = val.batch(1).prefetch(tf.data.AUTOTUNE)\n" | |
| ], | |
| "metadata": { | |
| "id": "pkVUx81TDfZg" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 2.3 - Build the CRNN Model" | |
| ], | |
| "metadata": { | |
| "id": "F2-q-FCdDj6B" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
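| "class ReshapeLayer(tf.keras.Layer):\n", | |
| " \"\"\"Rearrange (batch, y, x, time*bands) into (batch, time, y, x, bands) for ConvLSTM2D.\"\"\"\n", | |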
| " def call(self, x):\n", | |
| " x_slices = []\n", | |
| " for i in range(SEQUENCE_LENGTH):\n", | |
| " x_slices.append(x[:,:,:,i*len(predictor_bands):(i*len(predictor_bands))+len(predictor_bands)])\n", | |
| " return tf.keras.ops.moveaxis(\n", | |
| " tf.stack(x_slices, axis=-1), -1, 1\n", | |
| " )\n", | |
| "\n", | |
| "def build_crnn_model():\n", | |
| " input = tf.keras.layers.Input(shape=(PATCH_SIZE, PATCH_SIZE, NUM_BANDS),name='input')\n", | |
| " # Reshape from (y, x, time*bands) to (time, y, x, bands)\n", | |
| " x = ReshapeLayer()(input)\n", | |
| "\n", | |
| " x = tf.keras.layers.ConvLSTM2D(\n", | |
| " filters=32,\n", | |
| " kernel_size=(5, 5),\n", | |
| " padding=\"same\",\n", | |
| " return_sequences=True,\n", | |
| " activation=\"relu\",\n", | |
| " )(x)\n", | |
| " x = tf.keras.layers.BatchNormalization()(x)\n", | |
| " x = tf.keras.layers.ConvLSTM2D(\n", | |
| " filters=64,\n", | |
| " kernel_size=(3,3),\n", | |
| " padding=\"same\",\n", | |
| " return_sequences=True,\n", | |
| " activation=\"relu\",\n", | |
| " )(x)\n", | |
| " x = tf.keras.layers.BatchNormalization()(x)\n", | |
| " x=tf.keras.layers.TimeDistributed(\n", | |
| " tf.keras.layers.Conv2DTranspose(\n", | |
| " PATCH_SIZE,\n", | |
| " kernel_size=3,\n", | |
| " strides=1,\n", | |
| " activation=\"relu\",\n", | |
| " padding=\"same\"\n", | |
| " )\n", | |
| " )(x)\n", | |
| " x=tf.keras.layers.TimeDistributed(\n", | |
| " tf.keras.layers.BatchNormalization()\n", | |
| " )(x)\n", | |
| " x = tf.keras.layers.BatchNormalization()(x)\n", | |
| " x = tf.keras.layers.ConvLSTM2D(\n", | |
| " filters=32,\n", | |
| " kernel_size=(1,1),\n", | |
| " padding=\"same\",\n", | |
| " return_sequences=False,\n", | |
| " activation=\"relu\",\n", | |
| " )(x)\n", | |
| " y = tf.keras.layers.Conv2D(\n", | |
| " filters=NUM_CLASSES,\n", | |
| " kernel_size=(1,1,),\n", | |
| " activation=\"softmax\",\n", | |
| " padding=\"same\",\n", | |
| " name='output'\n", | |
| " )(x)\n", | |
| "\n", | |
| " return tf.keras.models.Model(input, y)\n", | |
| "\n", | |
| "model = build_crnn_model()\n", | |
| "model.compile(\n", | |
| " optimizer='adam',\n", | |
| " loss='categorical_crossentropy',\n", | |
| " metrics=['accuracy']\n", | |
| ")\n", | |
| "model.summary()" | |
| ], | |
| "metadata": { | |
| "id": "Pbs5PuzsDnWS" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
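`ReshapeLayer` slices the channel axis into `SEQUENCE_LENGTH` groups of `len(predictor_bands)` channels, stacks them on a new last axis, and moves that axis to position 1, turning `(batch, y, x, time*bands)` into the `(batch, time, y, x, bands)` layout `ConvLSTM2D` expects. Because the channels are week-major, a reshape plus transpose gives the same result; a NumPy check with small, illustrative sizes:

```python
import numpy as np

B, H, W, T, C = 2, 3, 3, 4, 3  # batch, height, width, weeks, bands (illustrative)
x = np.arange(B * H * W * T * C, dtype=np.float32).reshape(B, H, W, T * C)

# Mirror of ReshapeLayer.call: slice per week, stack on the last axis, move it to axis 1
slices = [x[:, :, :, i * C:(i + 1) * C] for i in range(T)]
via_slices = np.moveaxis(np.stack(slices, axis=-1), -1, 1)

# Equivalent: split channels into (T, C), then bring T forward
via_reshape = x.reshape(B, H, W, T, C).transpose(0, 3, 1, 2, 4)

print(via_slices.shape)
print(np.array_equal(via_slices, via_reshape))
```

An alternative in Keras, not tested here, would be `tf.keras.layers.Reshape((PATCH_SIZE, PATCH_SIZE, SEQUENCE_LENGTH, len(predictor_bands)))` followed by `tf.keras.layers.Permute((3, 1, 2, 4))`.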
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 2.4 - Train the Model" | |
| ], | |
| "metadata": { | |
| "id": "7spEbzkqDjxS" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Stop training when the validation loss stops improving and restore the best weights\n", | |
| "early_stopping = tf.keras.callbacks.EarlyStopping(monitor=\"val_loss\", patience=5, restore_best_weights=True, verbose=1)" | |
| ], | |
| "metadata": { | |
| "id": "RqRenKl6sHup" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "print(\"\\nStarting model training...\")\n", | |
| "history = model.fit(\n", | |
| " train_dataset,\n", | |
| " epochs=20,\n", | |
| " validation_data=test_dataset,\n", | |
| " callbacks=[early_stopping,],\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "_MYx1OUzDv4E" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "plt.figure(figsize=(8,5))\n", | |
| "plt.plot(history.history['loss'])\n", | |
| "plt.plot(history.history['val_loss'])\n", | |
| "plt.title('model loss')\n", | |
| "plt.ylabel('loss')\n", | |
| "plt.xlabel('epoch')\n", | |
| "plt.legend(['train', 'test'], loc='upper right');" | |
| ], | |
| "metadata": { | |
| "id": "aM5lovWBwSGw" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 2.5 - Evaluate the Model" | |
| ], | |
| "metadata": { | |
| "id": "ciYWJzO5D0Ha" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "model.evaluate(val_dataset)" | |
| ], | |
| "metadata": { | |
| "id": "B0a2DIcWXmKS" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
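| "# Draw one batch from the held-out validation set (not the training set) for a visual check.\n", |  |
| "x, y = next(iter(val_dataset))\n", |  |
| "pred = model(x, training=False)" |  |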
| ], | |
| "metadata": { | |
| "id": "Vz_cRGrGDzsX" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "i = np.random.randint(0,x.shape[0])\n", | |
| "\n", | |
| "fig, ax = plt.subplots(1,2,figsize=(15,5))\n", | |
| "\n", | |
| "ax[0].imshow(np.argmax(y[i,:,:,:],axis=-1),vmin=0,vmax=NUM_CLASSES-1)\n", |  |
| "ax[0].set_title('Observed')\n", |  |
| "\n", |  |
| "cb = ax[1].imshow(np.argmax(pred[i,:,:,:],axis=-1),vmin=0,vmax=NUM_CLASSES-1)\n", |  |
| "ax[1].set_title('Predicted')\n", | |
| "\n", | |
| "plt.colorbar(cb)\n", | |
| "\n", | |
| "plt.show()" | |
| ], | |
| "metadata": { | |
| "id": "KiNvrg0XJySC" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Part 3: Deploy Model to Vertex AI & Perform Inference\n" | |
| ], | |
| "metadata": { | |
| "id": "vp45m86tD7v9" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 3.1 - Save and Deploy Model to Vertex AI\n" | |
| ], | |
| "metadata": { | |
| "id": "MfzD5t88EEya" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "MODEL_DISPLAY_NAME = 'drought_crnn_v1'\n", | |
| "MODEL_DIR = f'gs://{BUCKET}/{MODEL_DISPLAY_NAME}'\n", | |
| "\n", | |
| "print(f\"\\nSaving model to {MODEL_DIR}...\")\n", | |
| "model.export(MODEL_DIR)\n", | |
| "\n", | |
| "print(\"\\nUploading model to Vertex AI...\")\n", | |
| "!gcloud ai models upload \\\n", | |
| " --project={PROJECT} \\\n", | |
| " --region={REGION} \\\n", | |
| " --container-grpc-ports=8500 --container-ports=8080 \\\n", | |
| " --display-name={MODEL_DISPLAY_NAME} \\\n", | |
| " --artifact-uri={MODEL_DIR} \\\n", | |
| " --container-image-uri=\"us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-15:latest\"\n", | |
| "\n", | |
| "# Get the Model ID\n", | |
| "MODEL_LIST_OUTPUT = !gcloud ai models list --project={PROJECT} --region={REGION} --filter=\"displayName={MODEL_DISPLAY_NAME}\" --format=\"value(name)\"\n", | |
| "MODEL_ID = MODEL_LIST_OUTPUT[1]\n", | |
| "print(f\"Model ID: {MODEL_ID}\")\n", | |
| "\n", | |
| "print(\"\\nCreating Vertex AI Endpoint...\")\n", | |
| "ENDPOINT_DISPLAY_NAME = f'{MODEL_DISPLAY_NAME}-endpoint'\n", | |
| "!gcloud ai endpoints create \\\n", | |
| " --project={PROJECT} \\\n", | |
| " --region={REGION} \\\n", | |
| " --display-name={ENDPOINT_DISPLAY_NAME}\n", | |
| "\n", | |
| "# Get the Endpoint ID\n", | |
| "ENDPOINT_LIST_OUTPUT = !gcloud ai endpoints list --project={PROJECT} --region={REGION} --filter=\"displayName={ENDPOINT_DISPLAY_NAME}\" --format=\"value(name)\"\n", | |
| "ENDPOINT_ID = ENDPOINT_LIST_OUTPUT[-1]\n", | |
| "print(f\"Endpoint ID: {ENDPOINT_ID}\")\n", | |
| "\n", | |
| "print(\"\\nDeploying model to endpoint...\")\n", | |
| "!gcloud ai endpoints deploy-model {ENDPOINT_ID} \\\n", | |
| " --project={PROJECT} \\\n", | |
| " --region={REGION} \\\n", | |
| " --model={MODEL_ID} \\\n", | |
| " --max-replica-count=5 \\\n", | |
| " --machine-type=c2-standard-4 \\\n", | |
| " --display-name={MODEL_DISPLAY_NAME}" | |
| ], | |
| "metadata": { | |
| "id": "G39lL7A2EGEK" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 3.2 - Inference with Earth Engine" | |
| ], | |
| "metadata": { | |
| "id": "iFvUHXr6ELo9" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Connect GEE to the deployed Vertex AI model\n", | |
| "vertex_model = ee.Model.fromVertexAi(\n", | |
| " endpoint=f'projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_ID}',\n", | |
| " # Each request covers a 64x64 pixel tile:\n", |  |
| " # a 56x56 pixel core plus 4 pixels of overlap on each side.\n", |  |
| " inputTileSize=[56,56],\n", | |
| " inputOverlapSize=[4,4],\n", | |
| " inputShapes={'array': [10]},\n", | |
| " outputBands={'output_0': {'type': ee.PixelType('float',dimensions=1)}},\n", | |
| " payloadFormat='GRPC_TF_TENSORS',\n", | |
| " fixInputProj=True,\n", | |
| " proj=ee.Projection('EPSG:4326').atScale(5000),\n", | |
| ")\n", | |
| "\n", | |
| "# Prepare an input image for a new date\n", | |
| "INFERENCE_DATE = '2023-09-01'\n", | |
| "inference_date = ee.Date(INFERENCE_DATE)\n", | |
| "\n", | |
| "# Generate the time series stack for the entire region\n", | |
| "week_offsets_inf = ee.List.sequence(0, SEQUENCE_LENGTH - 1)\n", | |
| "def get_inference_weekly_image(week_offset):\n", | |
| " target_date = inference_date.advance(ee.Number(week_offset).multiply(-7), 'day')\n", | |
| " return create_weekly_composite(target_date)\n", | |
| "\n", | |
| "inference_collection = ee.ImageCollection.fromImages(week_offsets_inf.map(get_inference_weekly_image))\n", | |
| "\n", | |
| "# Flatten the weekly composites into a single array-valued band, so each pixel\n", |  |
| "# holds the values of every band for every week in the sequence. Note that inputShapes\n", |  |
| "# in fromVertexAi declares per-pixel array lengths; it does not add a batch dimension.\n", |  |
| "input_image_for_model = inference_collection.toBands().toArray().toFloat().rename('input')\n", | |
| "\n", | |
| "\n", | |
| "print(\"\\nRunning inference in Earth Engine...\")\n", | |
| "prediction_arr = vertex_model.predictImage(input_image_for_model).arrayProject([0]).rename('drought_pred')\n", | |
| "\n", | |
| "class_probas = prediction_arr.arrayFlatten([['no_drought',]+[f'D{i}' for i in range(5)]])\n", | |
| "\n", | |
| "# The output is a probability vector for each class. Use arrayArgmax to get the most likely class.\n", | |
| "drought_prediction_class = prediction_arr.arrayArgmax().arrayGet([0])" | |
| ], | |
| "metadata": { | |
| "id": "jnRVYJwaEQS0" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### 3.3 - Visualize Inference Results\n" | |
| ], | |
| "metadata": { | |
| "id": "unLHS5mCEfm9" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Define visualization parameters for the USDM classes\n", | |
| "vis_params = {'min': 1, 'max': 6, 'palette': vis_dict[\"colors\"]}\n", | |
| "\n", | |
| "# Get the actual USDM data for comparison\n", | |
| "actual_usdm = usdm.filterDate(inference_date, inference_date.advance(7,'day')).first().select('DM')\n", | |
| "\n", | |
| "# Add layers to the map\n", | |
| "m.centerObject(conus_region, 4)\n", | |
| "m.addLayer(class_probas, {}, f'Predicted proba ({INFERENCE_DATE})')\n", | |
| "m.addLayer(actual_usdm.add(1), vis_params, f'Actual USDM ({INFERENCE_DATE})')\n", | |
| "m.addLayer(drought_prediction_class.selfMask(), vis_params, f'Predicted Drought ({INFERENCE_DATE})')\n" | |
| ], | |
| "metadata": { | |
| "id": "5XLzB1HpEer5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
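| "# Observed minus predicted class: positive values mean the model under-predicts\n", |  |
| "# drought severity, negative values mean it over-predicts.\n", |  |
| "diff = actual_usdm.add(1).unmask(0).subtract(drought_prediction_class)\n", |  |
| "m.addLayer(diff, {'min':-5, 'max': 5, 'palette':['red','white','blue']}, 'Difference')\n" |  |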
| ], | |
| "metadata": { | |
| "id": "zjQydd9kg5l-" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Conclusion\n", | |
| "\n", | |
| "This notebook provides a complete, end-to-end workflow for building and deploying a deep learning model to classify drought conditions. It addresses a complex environmental challenge by learning from historical patterns in satellite data.\n", | |
| "\n", | |
| "The methodology uses a Convolutional Recurrent Neural Network (CRNN). The convolutional component learns spatial patterns and the relationships between the input bands within each weekly time step, while the recurrent component (ConvLSTM layers, which apply convolutions inside the LSTM gates) learns the temporal sequence of these conditions, allowing the model to capture how drought develops over time.\n", |  |
| "\n", | |
| "This notebook is a useful example because it tackles the challenges of modeling dynamic, four-dimensional (latitude, longitude, time, and variables) Earth system processes. By capturing temporal dynamics, it serves as a blueprint for a wide range of similar problems. Many critical environmental processes, such as crop growth, deforestation, flood risk, or fire recovery, cannot be understood from a single snapshot in time. This workflow demonstrates how to structure data into time-series sequences and use recurrent neural networks to explicitly model these time-dependent patterns, moving beyond single-image classification toward modeling how conditions evolve over time.\n", |  |
| "\n", | |
| "### Opportunities for Improvement and Future Work\n", | |
| "While this notebook offers a blueprint, the workflow can be extended in several ways:\n", |  |
| "* Richer Feature Engineering:\n", |  |
| "  * Add More Predictors: Incorporate other variables known to influence drought, such as Land Surface Temperature (LST), Evapotranspiration (PET/AET) from datasets like TerraClimate or MODIS, and Snow Water Equivalent (SWE) in relevant areas.\n", |  |
| "  * Include Static Data: Add static features like elevation, slope, aspect, and soil type. These variables don't change over time but strongly influence how a region responds to water deficits.\n", |  |
| "\n", |  |
| "* Experiment with Model Hyperparameters:\n", |  |
| "  * Change the sequence length or patch size to better capture drought dynamics.\n", |  |
| "\n", |  |
| "* Model Architectures:\n", |  |
| "  * Transformer Models: Investigate Transformer architectures (e.g., Vision Transformer variants), which use attention mechanisms to weigh the importance of different time steps in the sequence and have achieved state-of-the-art results on many sequence modeling tasks.\n" |  |
| ], | |
| "metadata": { | |
| "id": "6Xyh6n9Z-3Tw" | |
| } | |
| } | |
| ] | |
| } |
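The post-processing in sections 3.2 and 3.3 depends on how the model's class indices line up with USDM codes. Below is a minimal NumPy sketch of that logic, run locally rather than in Earth Engine. It assumes the six-class order named in the notebook's `arrayFlatten` call (`no_drought`, then `D0` to `D4`), and that USDM `DM` values 0 to 4 correspond to D0 to D4 with no-drought pixels masked, as the notebook's `add(1)` and `unmask(0)` calls imply. The helper names (`predicted_class`, `usdm_to_class`, `difference`) and the pixel values are hypothetical, chosen for illustration only.

```python
import numpy as np

# Assumed class order, taken from the notebook's arrayFlatten call:
# index 0 = no drought, indices 1-5 = USDM categories D0-D4.
CLASS_NAMES = ["no_drought"] + [f"D{i}" for i in range(5)]


def predicted_class(probas: np.ndarray) -> np.ndarray:
    """Most likely class index per pixel (local analogue of arrayArgmax + arrayGet([0]))."""
    return np.argmax(probas, axis=-1)


def usdm_to_class(dm: np.ma.MaskedArray) -> np.ndarray:
    """Shift USDM DM codes (0-4 for D0-D4) up by one and fill masked
    (no-drought) pixels with 0, mirroring actual_usdm.add(1).unmask(0)."""
    return (dm + 1).filled(0)


def difference(dm: np.ma.MaskedArray, probas: np.ndarray) -> np.ndarray:
    """Observed minus predicted class index.

    Positive values: the model under-predicts drought severity.
    Negative values: the model over-predicts drought severity.
    """
    return usdm_to_class(dm) - predicted_class(probas)


# Three illustrative pixels (made-up probabilities, not real model output).
probas = np.array([
    [0.70, 0.10, 0.05, 0.05, 0.05, 0.05],  # most likely: no_drought
    [0.00, 0.10, 0.20, 0.60, 0.10, 0.00],  # most likely: D2
    [0.10, 0.50, 0.10, 0.10, 0.10, 0.10],  # most likely: D0
])
# Hypothetical USDM DM codes; the first pixel is masked (no drought reported).
dm = np.ma.array([0, 2, 4], mask=[True, False, False])

print([CLASS_NAMES[k] for k in predicted_class(probas)])  # ['no_drought', 'D2', 'D0']
print(difference(dm, probas))  # [0 0 4]
```

The third pixel has a difference of +4 (observed D4, predicted D0), so it would appear near the blue end of the notebook's red-white-blue `Difference` layer, flagging a strong under-prediction.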