Last active: May 25, 2025
Revisions
fepegar revised this gist on Feb 18, 2025. 1 changed file with 120 additions and 50 deletions.
```diff
@@ -26,11 +26,14 @@
 import numpy as np
 import numpy.typing as npt
+import SimpleITK as sitk
 import torch
 import typer
 from einops import rearrange
 from loguru import logger
 from PIL import Image
+from procex import functional as F
+from procex.imgio import read_image
 from rich import print
 from rich.progress import (
     BarColumn,
@@ -44,7 +47,9 @@
 from tqdm.auto import tqdm
 from transformers import AutoModel, BitImageProcessor
 
 
-app = typer.Typer()
+app = typer.Typer(
+    no_args_is_help=True,
+)
 
 
@@ -55,33 +60,35 @@ class Model(str, enum.Enum):
 
 @app.command()
 def main(
-    image_path: Annotated[
-        Path | None,
-        typer.Option(
-            "--image",
-            "-i",
-            help=(
-                "Input image. If it is a DICOM file, it will be temporarily"
-                " converted to PNG."
-            ),
-            show_default=False,
-        ),
-    ] = None,
-    images_path: Annotated[
-        Path | None,
-        typer.Option(
-            "--images",
-            help="Text file with paths to images, one per line.",
-            show_default=False,
-        ),
-    ] = None,
+    input: Annotated[
+        Path,
+        typer.Argument(
+            help=(
+                "Input image(s). If it is a DICOM file, it will be temporarily"
+                " converted to PNG. If it is a .txt file, it must contain paths"
+                " to images, one per line."
+            ),
+            show_default=False,
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            writable=False,
+            readable=True,
+            resolve_path=True,
+        ),
+    ],
     features_path: Annotated[
         Path | None,
         typer.Option(
             "--features",
             "-f",
             help="Output features file.",
             show_default=False,
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            writable=True,
+            readable=False,
         ),
     ] = None,
     out_dir: Annotated[
@@ -90,6 +97,28 @@ def main(
             "--out-dir",
             help="Output directory for features files.",
             show_default=False,
+            exists=False,
+            file_okay=False,
+            dir_okay=True,
+            writable=True,
+            readable=False,
         ),
     ] = None,
+    in_dir: Annotated[
+        Path | None,
+        typer.Option(
+            "--in-dir",
+            help=(
+                "If passed, the path of the output relative to the output"
+                " directory will be the same as the input path relative to"
+                " this directory."
+            ),
+            show_default=False,
+            exists=True,
+            file_okay=False,
+            dir_okay=True,
+            writable=False,
+            readable=True,
+        ),
+    ] = None,
     model_name: Annotated[
@@ -125,26 +154,11 @@ def main(
         ),
     ] = True,
 ) -> None:
-    if image_path is not None and images_path is not None:
-        message = "You can only provide one of --image or --images"
-        logger.error(message)
-        raise typer.Abort
-    if image_path is not None:
-        input_paths = [image_path]
-    elif images_path is not None:
-        with images_path.open() as f:
-            input_paths = [Path(line.strip()) for line in f]
-
-    if features_path is not None and out_dir is not None:
-        message = "You can only provide one of --features or --out-dir"
-        logger.error(message)
-        raise typer.Abort
-    elif features_path is not None:
-        output_paths = [features_path]
-    elif out_dir is not None:
-        out_dir.mkdir(parents=True, exist_ok=True)
-        output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]
+    input_paths = _get_input_paths(input)
+    output_paths = _get_output_paths(input_paths, features_path, out_dir, in_dir)
+
+    import sys
+
+    sys.exit()
 
     device = _get_device()
@@ -158,22 +172,65 @@ def main(
     input_batches = batched(input_paths, batch_size)
     output_batches = batched(output_paths, batch_size)
     iterable = list(zip(input_batches, output_batches))
-    num_batches = len(iterable)
-    if num_batches == 1:
-        _process_batch(
-            *iterable[0],
-            model,
-            processor,
-            device,
-            save_cls=cls,
-            save_patch=patch,
-        )
-        raise typer.Exit()
-
-    message = "Processing batches..."
-    with BarProgress(transient=True) as progress:
-        task = progress.add_task(message, total=num_batches)
-        for inputs_batch, outputs_batch in iterable:
-            _process_batch(
-                inputs_batch,
-                outputs_batch,
+    _process_batches(
+        iterable,
+        model,
+        processor,
+        device,
+        cls=cls,
+        patch=patch,
+    )
+
+
+def _get_input_paths(
+    input_path: Path,
+) -> list[Path]:
+    if input_path.suffix == ".txt":
+        with input_path.open() as f:
+            input_paths = [Path(line.strip()) for line in f]
+    else:
+        input_paths = [input_path]
+    return input_paths
+
+
+def _get_output_paths(
+    input_paths: list[Path],
+    features_path: Path | None,
+    out_dir: Path | None,
+    in_dir: Path | None,
+) -> list[Path]:
+    if features_path is not None and out_dir is not None:
+        message = "You can only provide one of --features or --out-dir"
+        logger.error(message)
+        raise typer.Abort
+    elif features_path is not None:
+        output_paths = [features_path]
+    elif out_dir is not None:
+        if in_dir is None:
+            output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]
+        else:
+            output_paths = [
+                out_dir / p.relative_to(in_dir).with_suffix(".npz")
+                for p in input_paths
+            ]
+    print(output_paths)
+    return output_paths
+
+
+def _process_batches(
+    in_paths_out_paths: list[tuple[tuple[Path, ...], tuple[Path, ...]]],
+    model: AutoModel,
+    processor: BitImageProcessor,
+    device: torch.device,
+    *,
+    cls: bool,
+    patch: bool,
+):
+    message = "Processing batches..."
+    with BarProgress(transient=True) as progress:
+        task = progress.add_task(message, total=len(in_paths_out_paths))
+        for inputs_batch, outputs_batch in in_paths_out_paths:
+            _process_batch(
+                inputs_batch,
+                outputs_batch,
@@ -204,6 +261,7 @@ def _process_batch(
     cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
     zipped = zip(output_paths, cls_embeddings, patch_embeddings)
     for features_path, cls_embedding, patch_embedding in zipped:
+        features_path.parent.mkdir(parents=True, exist_ok=True)
         kwargs = {}
         if save_cls:
             kwargs["cls_embeddings"] = cls_embedding
@@ -261,7 +319,19 @@ def _get_model_and_processor(
 
 
 def _load_image(image_path: Path) -> Image.Image:
-    return Image.open(image_path)
+    if image_path.suffix == ".dcm":
+        image = _load_dicom(image_path)
+    else:
+        image = Image.open(image_path)
+    return image
+
+
+def _load_dicom(image_path: Path) -> Image.Image:
+    image = read_image(image_path)
+    image = F.enhance_contrast(image, num_bits=8)
+    array = sitk.GetArrayFromImage(image)
+    array = np.squeeze(array)
+    return Image.fromarray(array)
 
 
 def _get_device() -> torch.device:
```
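The `--in-dir` option introduced in this revision mirrors the input tree under the output directory. Below is a minimal sketch of the path mapping that `_get_output_paths` performs; all paths are hypothetical, chosen only for illustration:

```python
from pathlib import Path

in_dir = Path("/data/cxr")        # hypothetical --in-dir
out_dir = Path("/data/features")  # hypothetical --out-dir
image = Path("/data/cxr/study1/frontal.dcm")

# Same mapping as _get_output_paths: keep the layout relative to --in-dir
# and swap the extension for .npz.
target = out_dir / image.relative_to(in_dir).with_suffix(".npz")
print(target)  # /data/features/study1/frontal.npz
```

This is why the revision also adds `features_path.parent.mkdir(parents=True, exist_ok=True)` in `_process_batch`: output files may now land in subdirectories that do not exist yet.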
fepegar created this gist on Feb 18, 2025.
```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "einops",
#     "loguru",
#     "numpy",
#     "pillow",
#     "procex",
#     "torch",
#     "transformers",
#     "typer",
# ]
#
# [tool.uv.sources]
# procex = { git = "https://github.com/fepegar/procex" }
#
# [tool.ruff.lint.isort]
# force-single-line = true
# ///

from __future__ import annotations

import enum
from functools import cache
from itertools import batched
from pathlib import Path
from typing import Annotated, Sequence

import numpy as np
import numpy.typing as npt
import torch
import typer
from einops import rearrange
from loguru import logger
from PIL import Image
from rich import print
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoModel, BitImageProcessor

app = typer.Typer()


@enum.unique
class Model(str, enum.Enum):
    RAD_DINO = "rad-dino"
    RAD_DINO_MAIRA_2 = "rad-dino-maira-2"


@app.command()
def main(
    image_path: Annotated[
        Path | None,
        typer.Option(
            "--image",
            "-i",
            help=(
                "Input image. If it is a DICOM file, it will be temporarily"
                " converted to PNG."
            ),
            show_default=False,
        ),
    ] = None,
    images_path: Annotated[
        Path | None,
        typer.Option(
            "--images",
            help="Text file with paths to images, one per line.",
            show_default=False,
        ),
    ] = None,
    features_path: Annotated[
        Path | None,
        typer.Option(
            "--features",
            "-f",
            help="Output features file.",
            show_default=False,
        ),
    ] = None,
    out_dir: Annotated[
        Path | None,
        typer.Option(
            "--out-dir",
            help="Output directory for features files.",
            show_default=False,
        ),
    ] = None,
    model_name: Annotated[
        Model,
        typer.Option(
            "--model",
            "-m",
            help="Model to use.",
            show_default=False,
        ),
    ] = Model.RAD_DINO,
    batch_size: Annotated[
        int,
        typer.Option(
            "--batch-size",
            "-b",
            help="Batch size.",
            show_default=False,
        ),
    ] = 1,
    cls: Annotated[
        bool,
        typer.Option(
            help="Whether to include the CLS token.",
            show_default=False,
        ),
    ] = True,
    patch: Annotated[
        bool,
        typer.Option(
            help="Whether to include the patch tokens.",
            show_default=False,
        ),
    ] = True,
) -> None:
    if image_path is not None and images_path is not None:
        message = "You can only provide one of --image or --images"
        logger.error(message)
        raise typer.Abort
    if image_path is not None:
        input_paths = [image_path]
    elif images_path is not None:
        with images_path.open() as f:
            input_paths = [Path(line.strip()) for line in f]

    if features_path is not None and out_dir is not None:
        message = "You can only provide one of --features or --out-dir"
        logger.error(message)
        raise typer.Abort
    elif features_path is not None:
        output_paths = [features_path]
    elif out_dir is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]

    device = _get_device()
    with BarlessProgress() as progress:
        task = progress.add_task("Loading model...", total=1)
        model, processor = _get_model_and_processor(model_name.value, device)
        progress.update(task, advance=1)
    print(f'Running inference on device: "{device}"')

    input_batches = batched(input_paths, batch_size)
    output_batches = batched(output_paths, batch_size)
    iterable = list(zip(input_batches, output_batches))
    num_batches = len(iterable)
    if num_batches == 1:
        _process_batch(
            *iterable[0],
            model,
            processor,
            device,
            save_cls=cls,
            save_patch=patch,
        )
        raise typer.Exit()

    message = "Processing batches..."
    with BarProgress(transient=True) as progress:
        task = progress.add_task(message, total=num_batches)
        for inputs_batch, outputs_batch in iterable:
            _process_batch(
                inputs_batch,
                outputs_batch,
                model,
                processor,
                device,
                save_cls=cls,
                save_patch=patch,
            )
            progress.update(task, advance=1)


def _process_batch(
    input_paths: Sequence[Path],
    output_paths: Sequence[Path],
    model: AutoModel,
    processor: BitImageProcessor,
    device: torch.device,
    *,
    save_cls: bool,
    save_patch: bool,
):
    if not save_cls and not save_patch:
        message = "You must save at least one of the CLS token or the patch tokens."
        logger.error(message)
        raise typer.Abort
    images = [_load_image(p) for p in input_paths]
    cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
    zipped = zip(output_paths, cls_embeddings, patch_embeddings)
    for features_path, cls_embedding, patch_embedding in zipped:
        kwargs = {}
        if save_cls:
            kwargs["cls_embeddings"] = cls_embedding
        if save_patch:
            kwargs["patch_embeddings"] = patch_embedding
        with features_path.open("wb") as f:
            np.savez(f, **kwargs)


def _infer(
    images: list[Image.Image],
    model: nn.Module,
    processor: BitImageProcessor,
    device: torch.device,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    processed = processor(images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**processed)
    cls_embeddings = outputs.pooler_output
    flat_patch_embeddings = outputs.last_hidden_state[:, 1:]  # first token is CLS
    reshaped_patch_embeddings = _reshape_patch_embeddings(
        flat_patch_embeddings,
        image_size=processor.crop_size["height"],
        patch_size=model.config.patch_size,
    )
    return cls_embeddings.cpu().numpy(), reshaped_patch_embeddings.cpu().numpy()


def _reshape_patch_embeddings(
    flat_tokens: torch.Tensor,
    *,
    image_size: int,
    patch_size: int,
) -> torch.Tensor:
    """Reshape flat list of patch tokens into a nice grid."""
    embeddings_size = image_size // patch_size
    patches_grid = rearrange(flat_tokens, "b (h w) c -> b c h w", h=embeddings_size)
    return patches_grid


@cache
def _get_model_and_processor(
    model_name: str,
    device: torch.device,
) -> tuple[AutoModel, BitImageProcessor]:
    repo = f"microsoft/{model_name}"
    model = AutoModel.from_pretrained(repo).to(device).eval()
    processor = BitImageProcessor.from_pretrained(repo)
    return model, processor


def _load_image(image_path: Path) -> Image.Image:
    return Image.open(image_path)


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device


class BarlessProgress(Progress):
    def __init__(self, *args, **kwargs):
        columns = [
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(),
        ]
        super().__init__(*columns, *args, **kwargs)


class BarProgress(Progress):
    def __init__(self, *args, **kwargs):
        columns = [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        ]
        super().__init__(*columns, *args, **kwargs)


if __name__ == "__main__":
    app()
```
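The saved `.npz` archives can be read back with NumPy. A minimal sketch; the file name is hypothetical, and the shapes assume the `microsoft/rad-dino` defaults (a ViT-B backbone with 14-pixel patches on 518-pixel crops, giving a 37 by 37 grid of 768-dimensional tokens):

```python
import numpy as np

# "frontal.npz" is a hypothetical output of the script above.
features = np.load("frontal.npz")
cls_embedding = features["cls_embeddings"]  # shape: (768,)
patch_grid = features["patch_embeddings"]   # shape: (768, 37, 37), i.e. (C, H, W)

# Example: average-pool the patch grid into a single image-level vector.
pooled = patch_grid.mean(axis=(1, 2))  # shape: (768,)
```

Because the script embeds PEP 723 inline metadata (including the Git source for procex), it should be runnable directly with `uv run`, which resolves the listed dependencies on the fly.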