@konpatp
Created March 24, 2021 06:28
mimic_img_text.py

import os
import random
from dataclasses import dataclass

import cv2
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import (BertTokenizerFast, PreTrainedTokenizerFast,
                          RobertaTokenizerFast)

from .base_cls import *
from .cxr_augment import *
from .mimic_cls_data import IGNORE_FILES, SPLITS, VIEWS
from .nih14_data import cv2_loader

cv2.setNumThreads(1)
here = os.path.dirname(__file__)
DATASET_PATH = f'{here}/datasets/mimic-cxr'

REPORT_TYPES = {
    'finding_only': f'{DATASET_PATH}/reports_findings_only.csv',
    'finding_impression': f'{DATASET_PATH}/reports_findings_or_impressions.csv',
    'abnormal': f'{DATASET_PATH}/reports_abnormal.csv',
    'abnormal_nodiff': f'{DATASET_PATH}/reports_abnormal_nodiff.csv',
}


@dataclass
class MimicTextDataConfig(DatasetConfig):
    dataset: str = 'mimic-text'
    tokenizer: PreTrainedTokenizerFast = None
    n_max_sentence_length: int = None
    lower_case: bool = True
    blank_line_for_empty_report: bool = True
    img_dir: str = f'{DATASET_PATH}/images512'
    split: str = 'v1'
    view: str = 'front'
    report: str = 'finding_impression'
    trans_conf: TransformConfig = TransformConfig()

    @property
    def name(self):
        name = f'{self.split}-{self.view}-{self.report}'
        if self.n_max_sentence_length is not None:
            name += f'-len{self.n_max_sentence_length}'
        if not self.lower_case:
            name += '-upper'
        name += f'_{self.trans_conf.name}'
        return name


class MimicTextCombinedDataset:
    def __init__(self, conf: MimicTextDataConfig):
        train_csv, val_csv, test_csv = SPLITS[conf.split]
        train_transform = make_transform('train', conf.trans_conf)
        eval_transform = make_transform('eval', conf.trans_conf)
        self.train_data = MimicTextDataset(f'{DATASET_PATH}/{train_csv}', conf,
                                           train_transform)
        self.val_data = MimicTextDataset(f'{DATASET_PATH}/{val_csv}', conf,
                                         eval_transform)
        self.test_data = MimicTextDataset(f'{DATASET_PATH}/{test_csv}', conf,
                                          eval_transform)


class MimicTextDataset(Dataset):
    def __init__(
        self,
        split_csv,
        conf: MimicTextDataConfig,
        transform=None,
    ):
        self.conf = conf
        # make the df
        split_df = pd.read_csv(split_csv)
        report_df = pd.read_csv(REPORT_TYPES[conf.report])
        df = pd.read_csv(VIEWS[conf.view])

        # keep only studies mentioned in the split
        df = df[df['study_id'].isin(split_df['study_id'])]
        # keep only studies for which we have reports
        report_study_id = set(report_df['study_id'])
        df = df[df['study_id'].isin(report_study_id)]
        # keep only records with readable images
        df = df[~df['dicom_id'].isin(IGNORE_FILES)].reset_index(drop=True)

        self.report_df = report_df
        self.record_df = df

        self.transform = transform

    def __len__(self):
        return len(self.record_df)

    def __getitem__(self, idx):
        ############
        # REPORT
        study_id = self.record_df.loc[idx, 'study_id']
        text = self.report_df[self.report_df['study_id'] ==
                              study_id].iloc[0]['text']
        if text != text:
            # nan = empty report
            text = ''

        if self.conf.lower_case:
            text = text.lower()

        if not self.conf.blank_line_for_empty_report and text == '':
            # empty report
            input_ids = []
        else:
            lines = text.split('\n')
            # shuffle, then take the first line (i.e. pick one line at random)
            random.shuffle(lines)
            lines = lines[0]

            if isinstance(self.conf.tokenizer, BertTokenizerFast):
                # bert tokenizer bug:
                # [''] => error
                # so we need to force add special tokens
                res = self.conf.tokenizer(lines,
                                          return_attention_mask=False,
                                          add_special_tokens=True)
            elif isinstance(self.conf.tokenizer, RobertaTokenizerFast):
                res = self.conf.tokenizer(lines,
                                          return_attention_mask=False,
                                          add_special_tokens=True)
            else:
                raise NotImplementedError()

            input_ids = res['input_ids']

        # clip the length
        if self.conf.n_max_sentence_length is not None:
            input_ids = input_ids[:self.conf.n_max_sentence_length]

        ###########
        # IMAGE
        # we use the png files
        img_path = self.record_df.loc[idx, 'path'].replace('.dcm', '.png')
        # remove the prefix files/
        img_path = img_path.replace('files/', '')
        img_path = f'{self.conf.img_dir}/{img_path}'
        img = cv2_loader(img_path)

        if self.transform:
            _res = self.transform(image=img, bboxes=[])
            img = _res['image']

        return {
            'img': img,
            'input_ids': input_ids,
            'study_id': study_id,
        }


def add_bos_token(sentences, bos_token_id):
    out = []
    for each in sentences:
        out.append([bos_token_id] + each)
    return out


def add_eos_token(sentences, eos_token_id):
    out = []
    for each in sentences:
        if each[-1] != eos_token_id:
            each = each + [eos_token_id]
        out.append(each)
    return out
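
# Note: add_bos_token / add_eos_token are not called elsewhere in this file.
# A minimal sketch of how they might be applied to tokenized reports before
# padding (the token ids below are made up purely for illustration):
#
#     sentences = [[713, 16, 10, 1472], [9064, 16012]]
#     sentences = add_bos_token(sentences, bos_token_id=0)
#     sentences = add_eos_token(sentences, eos_token_id=2)
#     # -> [[0, 713, 16, 10, 1472, 2], [0, 9064, 16012, 2]]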


class MimicTextCollator:
    def __init__(self, conf):
        self.conf = conf

    def __call__(self, data):
        out = {'img': [], 'input_ids': []}
        max_length = max(len(each['input_ids']) for each in data)
        for each in data:
            out['img'].append(each['img'])
            # pad every sequence in the batch to the longest one
            n_pad = max_length - len(each['input_ids'])
            pad = [self.conf.tokenizer.pad_token_id] * n_pad
            out['input_ids'].append(each['input_ids'] + pad)

        out['img'] = torch.stack(out['img'])
        out['input_ids'] = torch.LongTensor(out['input_ids'])
        return out
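

# --- Usage sketch (not part of the original gist) ---
# A minimal, hedged example of wiring the config, dataset, and collator into a
# torch DataLoader. The tokenizer checkpoint, batch size, and max length are
# assumptions for illustration; it also assumes the MIMIC-CXR csv/image files
# exist under DATASET_PATH and that the transforms return tensors (required by
# torch.stack in the collator).
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    conf = MimicTextDataConfig(
        tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased'),
        n_max_sentence_length=128,
    )
    data = MimicTextCombinedDataset(conf)
    loader = DataLoader(
        data.train_data,
        batch_size=16,
        shuffle=True,
        collate_fn=MimicTextCollator(conf),
    )
    batch = next(iter(loader))
    # img: (B, C, H, W) float tensor, input_ids: (B, max_len) LongTensor
    print(batch['img'].shape, batch['input_ids'].shape)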