Skip to content

Instantly share code, notes, and snippets.

@YodaEmbedding
Last active January 24, 2025 04:29
Show Gist options
  • Save YodaEmbedding/8803d95de072f12b4ff14ffd2b5bd7e5 to your computer and use it in GitHub Desktop.
Pregenerated Vimeo90K NumPy memmap dataset
import argparse
from pathlib import Path
import numpy as np
import torch
from compressai.datasets import Vimeo90kDataset
from torch.utils.data import DataLoader
from torchvision import transforms
# Usage instructions printed at startup (see main()).
MESSAGE = """
Download and extract the Vimeo90k dataset first:
mkdir -p vimeo90k
cd vimeo90k
wget http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
unzip vimeo_triplet.zip
cd ..
Then, run one of the following:
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=image --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=video --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy_video"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=image --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=video --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy_video"
If the mode is "image", each frame is treated separately, and may
undergo different transformations.
If the mode is "video", all frames undergo the same transformation.
"""
# Side length, in pixels, of the square patch cropped from each frame.
PATCH_LENGTH = 256
PATCH_SIZE = (PATCH_LENGTH, PATCH_LENGTH)
# Maps a dataset split name to the stem of its output .npy file.
FILENAMES = {
    "train": "training",
    "valid": "validation",
}
def get_dataset(dataset_path, split, tuplet, mode):
    """Build a Vimeo90k dataset plus a deterministic (non-shuffling) loader.

    Training uses random crops; any other split uses a center crop so the
    output is reproducible. Returns a ``(dataset, loader)`` pair.
    """
    if split == "train":
        crop = transforms.RandomCrop(PATCH_SIZE)
    else:
        crop = transforms.CenterCrop(PATCH_SIZE)

    # Convert PIL image -> uint8 tensor while keeping the HWC layout.
    # (transforms.ToTensor would convert HWC -> CHW, which we don't want.)
    def to_hwc_tensor(img):
        return torch.from_numpy(np.array(img))

    transform = transforms.Compose([crop, to_hwc_tensor])
    dataset = Vimeo90kDataset(
        root=dataset_path,
        transform=transform,
        split=split,
        tuplet=tuplet,
        # mode is experimental. Old versions of CompressAI do not have
        # this parameter, and behave as if mode="image".
        mode=mode,
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=8)
    return dataset, loader
def generate_npy_dataset(indir, outdir, split, tuplet, mode, epochs):
    """Write `epochs` passes over one dataset split into a uint8 memmap file.

    The output file holds ``epochs * len(dataset)`` patches of shape
    ``(*PATCH_SIZE, 3)`` in HWC uint8 layout, written sequentially.
    Note: despite the .npy extension, this is a raw memmap (no npy header).
    """
    dataset, loader = get_dataset(indir, split, tuplet, mode)
    out_filepath = Path(f"{outdir}/{FILENAMES[split]}.npy")
    # parents=True: create the full output directory tree, not just the leaf.
    # (Previously a nested --outdir would raise FileNotFoundError.)
    out_filepath.parent.mkdir(parents=True, exist_ok=True)
    print(f"Writing to {out_filepath}...")
    total_items = epochs * len(dataset)
    x_out = np.memmap(
        out_filepath,
        dtype="uint8",
        mode="w+",
        shape=(total_items, *PATCH_SIZE, 3),
    )
    offset = 0
    for epoch in range(epochs):
        for i, x in enumerate(loader):
            # Loader yields floats in [0, 1]; quantize to uint8 for storage.
            x = (x * 255).to(torch.uint8)
            print(
                f"{split} | "
                f"{epoch} / {epochs} epochs | "
                # Compare offset against the full multi-epoch total so the
                # progress readout stays correct when epochs > 1.
                f"{offset:6d} / {total_items} items | "
                f"{i:5d} / {len(loader)} batches | "
                # For ensuring that random output is stable:
                f"checksum: {x.min():3.0f} {x.max():3.0f} {x.to(float).mean():3.0f}"
            )
            x_out[offset : offset + len(x)] = x.numpy()
            offset += len(x)
    # Ensure all pages are written to disk, then release the mapping.
    x_out.flush()
    del x_out
def parse_args():
    """Parse command-line options; defaults target the triplet dataset."""
    parser = argparse.ArgumentParser(description="Generate Vimeo90k dataset")
    parser.add_argument("--indir", default="vimeo90k/vimeo_triplet")
    parser.add_argument("--outdir", default="vimeo90k/vimeo_triplet_npy")
    parser.add_argument("--tuplet", type=int, default=3)
    parser.add_argument("--mode", default="image", choices=["image", "video"])
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--epochs", type=int, default=1)
    args = parser.parse_args()
    return args
def main():
    """Entry point: generate memmap files for both dataset splits."""
    print(MESSAGE)
    args = parse_args()
    # Seed once so RandomCrop positions are reproducible across runs.
    torch.manual_seed(args.seed)
    for split in ("train", "valid"):
        generate_npy_dataset(
            indir=args.indir,
            outdir=args.outdir,
            split=split,
            tuplet=args.tuplet,
            mode=args.mode,
            epochs=args.epochs,
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment