@faustomilletari
Last active March 2, 2021 18:15

Revisions

  1. faustomilletari revised this gist Jun 8, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion brats_challenge_starterkit.py
    @@ -56,7 +56,7 @@
     os.makedirs(PATH_ARTIFACTS, exist_ok=True)

    -USE_GPU = False
    +USE_GPU = True

     TRAINING = True

  2. faustomilletari created this gist Jun 8, 2020.
    206 changes: 206 additions & 0 deletions brats_challenge_starterkit.py
    @@ -0,0 +1,206 @@
    """
    Eisen BraTS2020 challenge starter kit
    NOTE: you need to register for the challenge, then download and unpack the data
    in order to run the following example.
    Find more info here: https://www.med.upenn.edu/cbica/brats2020/data.html
    Information about Eisen can be found at http://eisen.ai -- Join the community on Slack https://bit.ly/2L7i6OL
    This is released under MIT license. Do what you want with this code.
    """
    import os

    from eisen.datasets import Brats2020
    from eisen.models.segmentation import VNet
    from eisen.io import LoadNiftiFromFilename
    from eisen.transforms import (
        ResampleNiftiVolumes,
        NiftiToNumpy,
        CropCenteredSubVolumes,
        StackImagesChannelwise,
        MapValues,
        FilterFields,
        LabelMapToOneHot
    )
    from eisen.ops.losses import DiceLoss
    from eisen.ops.metrics import DiceMetric
    from eisen.utils import EisenModuleWrapper
    from eisen.utils.workflows import Training
    from eisen.utils.logging import LoggingHook
    from eisen.utils.logging import TensorboardSummaryHook
    from eisen.utils.artifacts import SaveTorchModelHook

    from torchvision.transforms import Compose
    from torch.utils.data import DataLoader
    from torch.optim import Adam

    """
    <<< SEGMENTATION TASK >>>
    This code is meant to provide an example of how to train a DL network on BraTS2020 data.
    Its results won't be optimal.
    """

    """
    Constants defining important parameters of the algorithm.
    CHANGE THESE VALUES TO FIT YOUR CONFIGURATION.
    This code will save Tensorboard summaries and model snapshots, and print output to the console.
    You can watch the progress of your training job by pointing a tensorboard process to the output folder.
    """

    # Defining some constants
    PATH_DATA = './MICCAI_BraTS2020_TrainingData' # path of data as unpacked from the challenge files
    PATH_ARTIFACTS = './results' # path for model results

    os.makedirs(PATH_ARTIFACTS, exist_ok=True)

    USE_GPU = False

    TRAINING = True

    NUM_EPOCHS = 100
    BATCH_SIZE = 2

    VOLUMES_RESOLUTION = [2, 2, 2]
    VOLUMES_PIXEL_SIZE = [128, 128, 128]
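    # note (added for clarity): with 2 mm spacing and 128 voxels per axis, each resampled volume
    # spans roughly 256 mm per axis, which should cover the original BraTS field of view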

    LABELS = [1, 2, 4]

    INPUT_CHANNELS = 4 # T1, T1ce, T2, FLAIR
    OUTPUT_CHANNELS = len(LABELS) # different label set can be achieved by transforming the labels

    """
    Define Readers and Transforms
    In order to load the data and prepare it for the network, we need to perform
    I/O operations and define transforms to standardize the data.
    You can add transforms or change the existing ones by editing this section.
    """

    # readers: for images and labels
    read_tform = LoadNiftiFromFilename(['t1', 't1ce', 't2', 'flair', 'label'], PATH_DATA)

    # Image manipulation transforms. Here we declare components of the transform chain

    # we want to resample images to a common resolution so that they are all comparable and each voxel has
    # the same physical size in millimeters
    resample_tform_img = ResampleNiftiVolumes(
        ['t1', 't1ce', 't2', 'flair'],
        VOLUMES_RESOLUTION,
        'linear'
    )

    # the labels are interpolated with 'nearest' because they are discrete
    # and we should not create weird interpolation artifacts
    resample_tform_lbl = ResampleNiftiVolumes(
        ['label'],
        VOLUMES_RESOLUTION,
        'nearest'
    )

    # We bring the data from Nifti to numpy so we can process it further
    to_numpy_tform = NiftiToNumpy(['t1', 't1ce', 't2', 'flair', 'label'])

    # Cropping the resampled images so that they all have the same size in voxels
    crop = CropCenteredSubVolumes(fields=['t1', 't1ce', 't2', 'flair', 'label'], size=VOLUMES_PIXEL_SIZE)

    # normalization of intensities. there may be more than one valid way to accomplish this
    map_intensities = MapValues(['t1', 't1ce', 't2', 'flair'], min_value=0.0, max_value=1.0)

    # labels are integers, but can be mapped to a one-hot encoding to be used during learning
    map_labels = LabelMapToOneHot(['label'], LABELS)
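    # e.g. with LABELS = [1, 2, 4] as above, each voxel is mapped to a 3-channel
    # binary mask, one channel per listed label value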

    # we compose a multi-channel image from the t1, t1ce, t2 and flair volumes. we call this new field 'image'
    stack_modalities = StackImagesChannelwise(['t1', 't1ce', 't2', 'flair'], 'image')

    # the various transforms have created a lot of information. we keep only 'image' and 'label' because in this
    # case they are the only things we need for training
    preserve_only_fields = FilterFields(['image', 'label'])

    # create a transform to manipulate and load data
    tform = Compose([
        read_tform,
        resample_tform_img,
        resample_tform_lbl,
        to_numpy_tform,
        crop,
        map_intensities,
        map_labels,
        stack_modalities,
        preserve_only_fields
    ])

    # create a dataset from the training set of the BraTS2020 data
    dataset = Brats2020(
        PATH_DATA,
        training=True,
        transform=tform
    )

    # Data loader: a pytorch DataLoader is used here to loop through the data as provided by the dataset.
    data_loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4
    )

    """
    Building blocks: we define here:
    * model
    * loss
    * metric
    * optimizer
    These components are used during training.
    These blocks will be joined together in a workflow (e.g. a training workflow).
    """

    # specify model and loss (building blocks)

    model = EisenModuleWrapper(
        module=VNet(input_channels=INPUT_CHANNELS, output_channels=OUTPUT_CHANNELS),
        input_names=['image'],
        output_names=['predictions']
    )

    loss = EisenModuleWrapper(
        module=DiceLoss(dim=[2, 3, 4]),
        input_names=['predictions', 'label'],
        output_names=['dice_loss']
    )

    metric = EisenModuleWrapper(
        module=DiceMetric(dim=[2, 3, 4]),
        input_names=['predictions', 'label'],
        output_names=['dice_metric']
    )

    optimizer = Adam(model.parameters(), 0.001)

    # join all blocks into a workflow (training workflow)
    training_workflow = Training(
        model=model,
        losses=[loss],
        data_loader=data_loader,
        optimizer=optimizer,
        metrics=[metric],
        gpu=USE_GPU
    )

    # create hooks to monitor training and save models
    training_logging_hook = LoggingHook(training_workflow.id, 'Training', PATH_ARTIFACTS)

    training_summary_hook = TensorboardSummaryHook(training_workflow.id, 'Training', PATH_ARTIFACTS)

    save_model_hook = SaveTorchModelHook(training_workflow.id, 'Training', PATH_ARTIFACTS)

    # run optimization for NUM_EPOCHS
    for i in range(NUM_EPOCHS):
        training_workflow.run()

    # todo: VALIDATION and INFERENCE code
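    # Below is a possible starting point for the validation part. It is NOT part of the
    # original gist: a minimal sketch that assumes eisen.utils.workflows.Validation takes
    # the same arguments as Training minus the optimizer (model, losses, data_loader,
    # metrics, gpu); verify this against the Eisen documentation before relying on it.
    # Since the official BraTS2020 validation data ships without ground-truth labels,
    # the labeled training set is split into train/validation subsets here instead
    # (the Training workflow above would then need to be built on train_loader).
    #
    # from torch.utils.data import random_split
    # from eisen.utils.workflows import Validation
    #
    # num_val = len(dataset) // 10
    # train_subset, val_subset = random_split(dataset, [len(dataset) - num_val, num_val])
    #
    # train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    # val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    #
    # validation_workflow = Validation(
    #     model=model,
    #     losses=[loss],
    #     data_loader=val_loader,
    #     metrics=[metric],
    #     gpu=USE_GPU
    # )
    #
    # validation_logging_hook = LoggingHook(validation_workflow.id, 'Validation', PATH_ARTIFACTS)
    # validation_summary_hook = TensorboardSummaryHook(validation_workflow.id, 'Validation', PATH_ARTIFACTS)
    #
    # for i in range(NUM_EPOCHS):
    #     training_workflow.run()
    #     validation_workflow.run()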