import argparse
from pathlib import Path

import numpy as np
import torch
from compressai.datasets import Vimeo90kDataset
from torch.utils.data import DataLoader
from torchvision import transforms

MESSAGE = """
Download and extract the Vimeo90k dataset first:

mkdir -p vimeo90k
cd vimeo90k
wget http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip
unzip vimeo_triplet.zip
cd ..

Then, run one of the following:

python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=image --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=video --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy_video"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=image --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy"
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=video --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy_video"

If the mode is "image", each frame is treated separately, and may undergo different transformations.
If the mode is "video", all frames undergo the same transformation.
"""

# Side length of the square crop taken from each frame.
PATCH_LENGTH = 256
PATCH_SIZE = (PATCH_LENGTH, PATCH_LENGTH)

# Output .npy file stem for each dataset split.
FILENAMES = {
    "train": "training",
    "valid": "validation",
}


def get_dataset(dataset_path, split, tuplet, mode):
    """Build a Vimeo90k dataset and a deterministic DataLoader over it.

    Args:
        dataset_path: Root directory of the extracted Vimeo90k dataset.
        split: "train" (random crops) or "valid" (center crops).
        tuplet: Number of frames per sample (3 for triplet, 7 for septuplet).
        mode: "image" (per-frame transforms) or "video" (shared transform).

    Returns:
        (dataset, loader) tuple. The loader does NOT shuffle, so output
        order — and hence the printed checksums — is reproducible.
    """
    # Random crops during training for augmentation; deterministic
    # center crops for validation.
    crop = (
        transforms.RandomCrop(PATCH_SIZE)
        if split == "train"
        else transforms.CenterCrop(PATCH_SIZE)
    )
    transform = transforms.Compose(
        [
            crop,
            # Keep HWC uint8 layout so samples can be written straight
            # into the uint8 memmap below.
            lambda img: torch.from_numpy(np.array(img)),
            # transforms.ToTensor(),  # NOTE: Converts HWC -> CHW.
        ]
    )
    dataset = Vimeo90kDataset(
        root=dataset_path,
        transform=transform,
        split=split,
        tuplet=tuplet,
        # mode is experimental. Old versions of CompressAI do not have
        # this parameter, and behave as if mode="image".
        mode=mode,
    )
    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=8)
    return dataset, loader


def generate_npy_dataset(indir, outdir, split, tuplet, mode, epochs):
    """Write `epochs` passes of one dataset split into a single .npy memmap.

    The output file at ``{outdir}/{FILENAMES[split]}.npy`` has shape
    ``(epochs * len(dataset), *PATCH_SIZE, 3)`` and dtype uint8. Samples
    are normalized from [0, 1] floats back to [0, 255] bytes.

    Args:
        indir: Root directory of the extracted Vimeo90k dataset.
        outdir: Directory to write the .npy file into (created if missing).
        split: "train" or "valid".
        tuplet: Number of frames per sample.
        mode: "image" or "video" (see MESSAGE).
        epochs: Number of passes over the dataset to record; each training
            pass yields different random crops.
    """
    dataset, loader = get_dataset(indir, split, tuplet, mode)

    out_filepath = Path(f"{outdir}/{FILENAMES[split]}.npy")
    # FIX: parents=True — the default outdir is nested (e.g.
    # "vimeo90k/vimeo_triplet_npy"), and mkdir without parents raises
    # FileNotFoundError when intermediate directories do not exist.
    out_filepath.parent.mkdir(parents=True, exist_ok=True)
    print(f"Writing to {out_filepath}...")

    # Memory-mapped output so the full dataset never needs to fit in RAM.
    x_out = np.memmap(
        out_filepath,
        dtype="uint8",
        mode="w+",
        shape=(epochs * len(dataset), *PATCH_SIZE, 3),
    )

    offset = 0

    for epoch in range(epochs):
        for i, x in enumerate(loader):
            # Dataset transform yields [0, 1] floats; store as bytes.
            x = (x * 255).to(torch.uint8)
            print(
                f"{split} | "
                f"{epoch} / {epochs} epochs | "
                f"{offset:6d} / {len(dataset)} items | "
                f"{i:5d} / {len(loader)} batches | "
                # For ensuring that random output is stable:
                f"checksum: {x.min():3.0f} {x.max():3.0f} {x.to(float).mean():3.0f}"
            )
            x_out[offset : offset + len(x)] = x.numpy()
            offset += len(x)

    x_out.flush()
    del x_out


def parse_args():
    """Parse command-line arguments; see MESSAGE for example invocations."""
    parser = argparse.ArgumentParser(description="Generate Vimeo90k dataset")
    parser.add_argument("--indir", default="vimeo90k/vimeo_triplet")
    parser.add_argument("--outdir", default="vimeo90k/vimeo_triplet_npy")
    parser.add_argument("--tuplet", type=int, default=3)
    parser.add_argument("--mode", default="image", choices=["image", "video"])
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--epochs", type=int, default=1)
    return parser.parse_args()


def main():
    """Generate .npy files for both the train and valid splits."""
    print(MESSAGE)
    args = parse_args()
    # Fixed seed keeps RandomCrop output (and the printed checksums)
    # reproducible across runs.
    torch.manual_seed(args.seed)
    for split in ["train", "valid"]:
        generate_npy_dataset(
            indir=args.indir,
            outdir=args.outdir,
            split=split,
            tuplet=args.tuplet,
            mode=args.mode,
            epochs=args.epochs,
        )


if __name__ == "__main__":
    main()