@fepegar
Last active May 25, 2025 12:52

Revisions

  1. fepegar revised this gist Feb 18, 2025. 1 changed file with 120 additions and 50 deletions.
    170 changes: 120 additions & 50 deletions run_rad_dino.py
@@ -26,11 +26,14 @@
 
 import numpy as np
 import numpy.typing as npt
+import SimpleITK as sitk
 import torch
 import typer
 from einops import rearrange
 from loguru import logger
 from PIL import Image
+from procex import functional as F
+from procex.imgio import read_image
 from rich import print
 from rich.progress import (
     BarColumn,
@@ -44,7 +47,9 @@
 from tqdm.auto import tqdm
 from transformers import AutoModel, BitImageProcessor
 
-app = typer.Typer()
+app = typer.Typer(
+    no_args_is_help=True,
+)
 
 
 @enum.unique
@@ -55,33 +60,35 @@ class Model(str, enum.Enum):
 
 @app.command()
 def main(
-    image_path: Annotated[
-        Path | None,
-        typer.Option(
-            "--image",
-            "-i",
+    input: Annotated[
+        Path,
+        typer.Argument(
             help=(
-                "Input image. If it is a DICOM file, it will be temporarily"
-                " converted to PNG."
+                "Input image(s). If it is a DICOM file, it will be temporarily"
+                " converted to PNG. If it is a .txt file, it must contain paths"
+                " to images, one per line."
             ),
             show_default=False,
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            writable=False,
+            readable=True,
+            resolve_path=True,
         ),
-    ] = None,
-    images_path: Annotated[
-        Path | None,
-        typer.Option(
-            "--images",
-            help="Text file with paths to images, one per line.",
-            show_default=False,
-        ),
-    ] = None,
+    ],
     features_path: Annotated[
         Path | None,
         typer.Option(
             "--features",
             "-f",
             help="Output features file.",
             show_default=False,
+            exists=True,
+            file_okay=True,
+            dir_okay=False,
+            writable=True,
+            readable=False,
         ),
     ] = None,
     out_dir: Annotated[
@@ -90,6 +97,28 @@ def main(
             "--out-dir",
             help="Output directory for features files.",
             show_default=False,
+            exists=False,
+            file_okay=False,
+            dir_okay=True,
+            writable=True,
+            readable=False,
         ),
     ] = None,
+    in_dir: Annotated[
+        Path | None,
+        typer.Option(
+            "--in-dir",
+            help=(
+                "If passed, the path of the output relative to the output"
+                " directory will be the same as the input path relative to"
+                " this directory."
+            ),
+            show_default=False,
+            exists=True,
+            file_okay=False,
+            dir_okay=True,
+            writable=False,
+            readable=True,
+        ),
+    ] = None,
     model_name: Annotated[
@@ -125,26 +154,11 @@ def main(
         ),
     ] = True,
 ) -> None:
-    if image_path is not None and images_path is not None:
-        message = "You can only provide one of --image or --images"
-        logger.error(message)
-        raise typer.Abort
-
-    if image_path is not None:
-        input_paths = [image_path]
-    elif images_path is not None:
-        with images_path.open() as f:
-            input_paths = [Path(line.strip()) for line in f]
+    input_paths = _get_input_paths(input)
+    output_paths = _get_output_paths(input_paths, features_path, out_dir, in_dir)
+    import sys
 
-    if features_path is not None and out_dir is not None:
-        message = "You can only provide one of --features or --out-dir"
-        logger.error(message)
-        raise typer.Abort
-    elif features_path is not None:
-        output_paths = [features_path]
-    elif out_dir is not None:
-        out_dir.mkdir(parents=True, exist_ok=True)
-        output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]
+    sys.exit()
 
     device = _get_device()
 
@@ -158,22 +172,65 @@ def main(
     input_batches = batched(input_paths, batch_size)
     output_batches = batched(output_paths, batch_size)
     iterable = list(zip(input_batches, output_batches))
-    num_batches = len(iterable)
-    if num_batches == 1:
-        _process_batch(
-            *iterable[0],
-            model,
-            processor,
-            device,
-            save_cls=cls,
-            save_patch=patch,
-        )
-        raise typer.Exit()
 
+    _process_batches(
+        iterable,
+        model,
+        processor,
+        device,
+        cls=cls,
+        patch=patch,
+    )
+
+
+def _get_input_paths(
+    input_path: Path,
+) -> list[Path]:
+    if input_path.suffix == ".txt":
+        with input_path.open() as f:
+            input_paths = [Path(line.strip()) for line in f]
+    else:
+        input_paths = [input_path]
+
+    return input_paths
+
+
+def _get_output_paths(
+    input_paths: list[Path],
+    features_path: Path | None,
+    out_dir: Path | None,
+    in_dir: Path | None,
+) -> list[Path]:
+    if features_path is not None and out_dir is not None:
+        message = "You can only provide one of --features or --out-dir"
+        logger.error(message)
+        raise typer.Abort
+    elif features_path is not None:
+        output_paths = [features_path]
+    elif out_dir is not None:
+        if in_dir is None:
+            output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]
+        else:
+            output_paths = [
+                out_dir / p.relative_to(in_dir).with_suffix(".npz") for p in input_paths
+            ]
+    print(output_paths)
+    return output_paths
+
+
+def _process_batches(
+    in_paths_out_paths: list[tuple[tuple[Path, ...], tuple[Path, ...]]],
+    model: AutoModel,
+    processor: BitImageProcessor,
+    device: torch.device,
+    *,
+    cls: bool,
+    patch: bool,
+):
     message = "Processing batches..."
     with BarProgress(transient=True) as progress:
-        task = progress.add_task(message, total=num_batches)
-        for inputs_batch, outputs_batch in iterable:
+        task = progress.add_task(message, total=len(in_paths_out_paths))
+        for inputs_batch, outputs_batch in in_paths_out_paths:
             _process_batch(
                 inputs_batch,
                 outputs_batch,
@@ -204,6 +261,7 @@ def _process_batch(
     cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
     zipped = zip(output_paths, cls_embeddings, patch_embeddings)
     for features_path, cls_embedding, patch_embedding in zipped:
+        features_path.parent.mkdir(parents=True, exist_ok=True)
         kwargs = {}
         if save_cls:
             kwargs["cls_embeddings"] = cls_embedding
@@ -261,7 +319,19 @@ def _get_model_and_processor(
 
 
 def _load_image(image_path: Path) -> Image.Image:
-    return Image.open(image_path)
+    if image_path.suffix == ".dcm":
+        image = _load_dicom(image_path)
+    else:
+        image = Image.open(image_path)
+    return image
+
+
+def _load_dicom(image_path: Path) -> Image.Image:
+    image = read_image(image_path)
+    image = F.enhance_contrast(image, num_bits=8)
+    array = sitk.GetArrayFromImage(image)
+    array = np.squeeze(array)
+    return Image.fromarray(array)
 
 
 def _get_device() -> torch.device:
  2. fepegar created this gist Feb 18, 2025.
    299 changes: 299 additions & 0 deletions run_rad_dino.py
    @@ -0,0 +1,299 @@
    # /// script
    # requires-python = ">=3.12"
    # dependencies = [
    # "einops",
    # "loguru",
    # "numpy",
    # "pillow",
    # "procex",
    # "torch",
    # "transformers",
    # "typer",
    # ]
    #
    # [tool.uv.sources]
    # procex = { git = "https://github.com/fepegar/procex" }
    # [tool.ruff.lint.isort]
    # force-single-line = true
    # ///
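# NOTE: the block above is PEP 723 inline script metadata. With uv installed,
# `uv run run_rad_dino.py` should resolve the dependencies (including procex
# from the Git source listed above) before launching the script.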
    from __future__ import annotations

    import enum
    from functools import cache
    from itertools import batched
    from pathlib import Path
    from typing import Annotated, Sequence

    import numpy as np
    import numpy.typing as npt
    import torch
    import typer
    from einops import rearrange
    from loguru import logger
    from PIL import Image
    from rich import print
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
)
    from torch import nn
    from tqdm.auto import tqdm
    from transformers import AutoModel, BitImageProcessor

    app = typer.Typer()


    @enum.unique
    class Model(str, enum.Enum):
    RAD_DINO = "rad-dino"
    RAD_DINO_MAIRA_2 = "rad-dino-maira-2"


    @app.command()
def main(
    image_path: Annotated[
        Path | None,
        typer.Option(
            "--image",
            "-i",
            help=(
                "Input image. If it is a DICOM file, it will be temporarily"
                " converted to PNG."
            ),
            show_default=False,
        ),
    ] = None,
    images_path: Annotated[
        Path | None,
        typer.Option(
            "--images",
            help="Text file with paths to images, one per line.",
            show_default=False,
        ),
    ] = None,
    features_path: Annotated[
        Path | None,
        typer.Option(
            "--features",
            "-f",
            help="Output features file.",
            show_default=False,
        ),
    ] = None,
    out_dir: Annotated[
        Path | None,
        typer.Option(
            "--out-dir",
            help="Output directory for features files.",
            show_default=False,
        ),
    ] = None,
    model_name: Annotated[
        Model,
        typer.Option(
            "--model",
            "-m",
            help="Model to use.",
            show_default=False,
        ),
    ] = Model.RAD_DINO,
    batch_size: Annotated[
        int,
        typer.Option(
            "--batch-size",
            "-b",
            help="Batch size.",
            show_default=False,
        ),
    ] = 1,
    cls: Annotated[
        bool,
        typer.Option(
            help="Whether to include the CLS token.",
            show_default=False,
        ),
    ] = True,
    patch: Annotated[
        bool,
        typer.Option(
            help="Whether to include the patch tokens.",
            show_default=False,
        ),
    ] = True,
) -> None:
    if image_path is not None and images_path is not None:
        message = "You can only provide one of --image or --images"
        logger.error(message)
        raise typer.Abort

    if image_path is not None:
        input_paths = [image_path]
    elif images_path is not None:
        with images_path.open() as f:
            input_paths = [Path(line.strip()) for line in f]

    if features_path is not None and out_dir is not None:
        message = "You can only provide one of --features or --out-dir"
        logger.error(message)
        raise typer.Abort
    elif features_path is not None:
        output_paths = [features_path]
    elif out_dir is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        output_paths = [out_dir / p.with_suffix(".npz").name for p in input_paths]

    device = _get_device()

    with BarlessProgress() as progress:
        task = progress.add_task("Loading model...", total=1)
        model, processor = _get_model_and_processor(model_name.value, device)
        progress.update(task, advance=1)

    print(f'Running inference on device: "{device}"')

    input_batches = batched(input_paths, batch_size)
    output_batches = batched(output_paths, batch_size)
    iterable = list(zip(input_batches, output_batches))
    num_batches = len(iterable)
    if num_batches == 1:
        _process_batch(
            *iterable[0],
            model,
            processor,
            device,
            save_cls=cls,
            save_patch=patch,
        )
        raise typer.Exit()

    message = "Processing batches..."
    with BarProgress(transient=True) as progress:
        task = progress.add_task(message, total=num_batches)
        for inputs_batch, outputs_batch in iterable:
            _process_batch(
                inputs_batch,
                outputs_batch,
                model,
                processor,
                device,
                save_cls=cls,
                save_patch=patch,
            )
            progress.update(task, advance=1)


def _process_batch(
    input_paths: Sequence[Path],
    output_paths: Sequence[Path],
    model: AutoModel,
    processor: BitImageProcessor,
    device: torch.device,
    *,
    save_cls: bool,
    save_patch: bool,
):
    if not save_cls and not save_patch:
        message = "You must save at least one of the CLS token or the patch tokens."
        logger.error(message)
        raise typer.Abort
    images = [_load_image(p) for p in input_paths]
    cls_embeddings, patch_embeddings = _infer(images, model, processor, device)
    zipped = zip(output_paths, cls_embeddings, patch_embeddings)
    for features_path, cls_embedding, patch_embedding in zipped:
        kwargs = {}
        if save_cls:
            kwargs["cls_embeddings"] = cls_embedding
        if save_patch:
            kwargs["patch_embeddings"] = patch_embedding
        with features_path.open("wb") as f:
            np.savez(f, **kwargs)


def _infer(
    images: list[Image.Image],
    model: nn.Module,
    processor: BitImageProcessor,
    device: torch.device,
) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
    processed = processor(images, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**processed)

    cls_embeddings = outputs.pooler_output

    flat_patch_embeddings = outputs.last_hidden_state[:, 1:]  # first token is CLS
    reshaped_patch_embeddings = _reshape_patch_embeddings(
        flat_patch_embeddings,
        image_size=processor.crop_size["height"],
        patch_size=model.config.patch_size,
    )

    return cls_embeddings.cpu().numpy(), reshaped_patch_embeddings.cpu().numpy()


def _reshape_patch_embeddings(
    flat_tokens: torch.Tensor,
    *,
    image_size: int,
    patch_size: int,
) -> torch.Tensor:
    """Reshape flat list of patch tokens into a nice grid."""
    embeddings_size = image_size // patch_size
    patches_grid = rearrange(flat_tokens, "b (h w) c -> b c h w", h=embeddings_size)
    return patches_grid
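
# Illustrative shape check for _reshape_patch_embeddings (generic numbers, not
# necessarily this checkpoint's configuration): with image_size=224 and
# patch_size=16, embeddings_size = 224 // 16 = 14, so flat tokens of shape
# (b, 196, c) are rearranged into a (b, c, 14, 14) grid.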


@cache
def _get_model_and_processor(
    model_name: str,
    device: torch.device,
) -> tuple[AutoModel, BitImageProcessor]:
    repo = f"microsoft/{model_name}"
    model = AutoModel.from_pretrained(repo).to(device).eval()
    processor = BitImageProcessor.from_pretrained(repo)

    return model, processor


def _load_image(image_path: Path) -> Image.Image:
    return Image.open(image_path)


def _get_device() -> torch.device:
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    return device


class BarlessProgress(Progress):
    def __init__(self, *args, **kwargs):
        columns = [
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            TimeElapsedColumn(),
        ]
        super().__init__(*columns, *args, **kwargs)


class BarProgress(Progress):
    def __init__(self, *args, **kwargs):
        columns = [
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        ]
        super().__init__(*columns, *args, **kwargs)


if __name__ == "__main__":
    app()
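

Example: reading the saved features (a minimal sketch; "image.npz" is a
hypothetical output name, and both keys are present only when --cls and
--patch keep their default of True, matching the np.savez call in
_process_batch):

import numpy as np

with np.load("image.npz") as features:
    cls_embedding = features["cls_embeddings"]  # per-image CLS vector, shape (hidden_size,)
    patch_embeddings = features["patch_embeddings"]  # per-image grid, shape (hidden_size, h, w)
print(cls_embedding.shape, patch_embeddings.shape)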