Skip to content

Instantly share code, notes, and snippets.

@abhishekkrthakur
Last active June 11, 2019 20:24
Show Gist options
  • Select an option

  • Save abhishekkrthakur/3346a3e1f47bc8a11b73078d4eca54e9 to your computer and use it in GitHub Desktop.

Select an option

Save abhishekkrthakur/3346a3e1f47bc8a11b73078d4eca54e9 to your computer and use it in GitHub Desktop.

Revisions

  1. abhishekkrthakur revised this gist Jun 11, 2019. 1 changed file with 5 additions and 0 deletions.
    5 changes: 5 additions & 0 deletions training_1.py
    Original file line number Diff line number Diff line change
    @@ -1,13 +1,15 @@
    import torch
    from torchvision import transforms

    # define some re-usable stuff
    IMAGE_SIZE = 224
    NUM_CLASSES = 1103
    BATCH_SIZE = 32
    device = torch.device("cuda:0")
    IMG_MEAN = model_ft.mean
    IMG_STD = model_ft.std

    # make some augmentations on training data
    train_transform = transforms.Compose([
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    @@ -16,13 +18,16 @@
    ])


    # use the collections dataset class we created earlier
    train_dataset = CollectionsDataset(csv_file='../input/folds.csv',
    root_dir='../input/train/',
    num_classes=NUM_CLASSES,
    transform=train_transform)

    # create the pytorch data loader
    train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4)
    # push model to device
    model_ft = model_ft.to(device)
  2. abhishekkrthakur created this gist Jun 11, 2019.
    28 changes: 28 additions & 0 deletions training_1.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,28 @@
    import torch
    from torchvision import transforms

    IMAGE_SIZE = 224
    NUM_CLASSES = 1103
    BATCH_SIZE = 32
    device = torch.device("cuda:0")
    IMG_MEAN = model_ft.mean
    IMG_STD = model_ft.std

    train_transform = transforms.Compose([
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMG_MEAN, IMG_STD)
    ])


    train_dataset = CollectionsDataset(csv_file='../input/folds.csv',
    root_dir='../input/train/',
    num_classes=NUM_CLASSES,
    transform=train_transform)

    train_dataset_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4)
    model_ft = model_ft.to(device)