import os
import random
from dataclasses import dataclass, field

import cv2
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import (BertTokenizerFast, PreTrainedTokenizerFast,
                          RobertaTokenizerFast)

from .base_cls import *
from .cxr_augment import *
from .mimic_cls_data import IGNORE_FILES, SPLITS, VIEWS
from .nih14_data import cv2_loader

# avoid thread oversubscription when used with multi-worker DataLoaders
cv2.setNumThreads(1)

here = os.path.dirname(__file__)
DATASET_PATH = f'{here}/datasets/mimic-cxr'

# report variants: which sections of the radiology report are kept
REPORT_TYPES = {
    'finding_only': f'{DATASET_PATH}/reports_findings_only.csv',
    'finding_impression': f'{DATASET_PATH}/reports_findings_or_impressions.csv',
    'abnormal': f'{DATASET_PATH}/reports_abnormal.csv',
    'abnormal_nodiff': f'{DATASET_PATH}/reports_abnormal_nodiff.csv',
}


@dataclass
class MimicTextDataConfig(DatasetConfig):
    dataset: str = 'mimic-text'
    tokenizer: PreTrainedTokenizerFast = None
    n_max_sentence_length: int = None
    lower_case: bool = True
    blank_line_for_empty_report: bool = True
    img_dir: str = f'{DATASET_PATH}/images512'
    split: str = 'v1'
    view: str = 'front'
    report: str = 'finding_impression'
    # default_factory avoids sharing one mutable TransformConfig across instances
    trans_conf: TransformConfig = field(default_factory=TransformConfig)

    @property
    def name(self):
        name = f'{self.split}-{self.view}-{self.report}'
        if self.n_max_sentence_length is not None:
            name += f'-len{self.n_max_sentence_length}'
        if not self.lower_case:
            name += '-upper'
        name += f'_{self.trans_conf.name}'
        return name


class MimicTextCombinedDataset:
    """Bundles the train/val/test splits of MimicTextDataset."""
    def __init__(self, conf: MimicTextDataConfig):
        train_csv, val_csv, test_csv = SPLITS[conf.split]
        train_transform = make_transform('train', conf.trans_conf)
        eval_transform = make_transform('eval', conf.trans_conf)
        self.train_data = MimicTextDataset(f'{DATASET_PATH}/{train_csv}', conf,
                                           train_transform)
        self.val_data = MimicTextDataset(f'{DATASET_PATH}/{val_csv}', conf,
                                         eval_transform)
        self.test_data = MimicTextDataset(f'{DATASET_PATH}/{test_csv}', conf,
                                          eval_transform)
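# Example (a sketch, not part of the original gist): constructing a config and
# the three splits. Assumes the MIMIC-CXR CSVs/PNGs exist under DATASET_PATH
# and that SPLITS/make_transform come from the repo-internal star imports above.
#
#   tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
#   conf = MimicTextDataConfig(tokenizer=tokenizer, n_max_sentence_length=64)
#   data = MimicTextCombinedDataset(conf)
#   print(conf.name)  # e.g. 'v1-front-finding_impression-len64_<transform name>'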
class MimicTextDataset(Dataset):
    """Pairs a chest x-ray image with one random line of its report."""
    def __init__(self, split_csv, conf: MimicTextDataConfig, transform=None):
        self.conf = conf
        # build the record dataframe
        split_df = pd.read_csv(split_csv)
        report_df = pd.read_csv(REPORT_TYPES[conf.report])
        df = pd.read_csv(VIEWS[conf.view])
        # keep only studies in this split
        df = df[df['study_id'].isin(split_df['study_id'])]
        # keep only studies that have a report
        report_study_id = set(report_df['study_id'])
        df = df[df['study_id'].isin(report_study_id)]
        # keep only records with readable images
        df = df[~df['dicom_id'].isin(IGNORE_FILES)].reset_index(drop=True)
        self.report_df = report_df
        self.record_df = df
        self.transform = transform

    def __len__(self):
        return len(self.record_df)

    def __getitem__(self, idx):
        ############
        # REPORT
        study_id = self.record_df.loc[idx, 'study_id']
        text = self.report_df[self.report_df['study_id'] ==
                              study_id].iloc[0]['text']
        if text != text:
            # NaN (self-inequality) marks an empty report
            text = ''
        if self.conf.lower_case:
            text = text.lower()

        if not self.conf.blank_line_for_empty_report and text == '':
            # empty report => empty token sequence
            input_ids = []
        else:
            # shuffle and take one random line of the report
            lines = text.split('\n')
            random.shuffle(lines)
            line = lines[0]

            if isinstance(self.conf.tokenizer, BertTokenizerFast):
                # BertTokenizerFast errors on [''] without special tokens,
                # so special tokens are forced on
                res = self.conf.tokenizer(line,
                                          return_attention_mask=False,
                                          add_special_tokens=True)
            elif isinstance(self.conf.tokenizer, RobertaTokenizerFast):
                res = self.conf.tokenizer(line,
                                          return_attention_mask=False,
                                          add_special_tokens=True)
            else:
                raise NotImplementedError()
            input_ids = res['input_ids']

            # clip to the maximum sentence length
            if self.conf.n_max_sentence_length is not None:
                input_ids = input_ids[:self.conf.n_max_sentence_length]

        ###########
        # IMAGE
        # the images are stored as png files
        img_path = self.record_df.loc[idx, 'path'].replace('.dcm', '.png')
        # strip the 'files/' prefix from the MIMIC path
        img_path = img_path.replace('files/', '')
        img_path = f'{self.conf.img_dir}/{img_path}'
        img = cv2_loader(img_path)
        if self.transform:
            _res = self.transform(image=img, bboxes=[])
            img = _res['image']

        return {
            'img': img,
            'input_ids': input_ids,
            'study_id': study_id,
        }


def add_bos_token(sentences, bos_token_id):
    """Prepend the BOS token id to every token sequence."""
    out = []
    for each in sentences:
        out.append([bos_token_id] + each)
    return out


def add_eos_token(sentences, eos_token_id):
    """Append the EOS token id to every sequence that lacks one."""
    out = []
    for each in sentences:
        if each[-1] != eos_token_id:
            each = each + [eos_token_id]
        out.append(each)
    return out


class MimicTextCollator:
    """Pads variable-length token sequences and stacks images into a batch."""
    def __init__(self, conf):
        self.conf = conf

    def __call__(self, data):
        out = {'img': [], 'input_ids': []}
        # pad every sequence to the longest one in the batch
        max_length = max(len(each['input_ids']) for each in data)
        for each in data:
            out['img'].append(each['img'])
            n_pad = max_length - len(each['input_ids'])
            pad = [self.conf.tokenizer.pad_token_id] * n_pad
            out['input_ids'].append(each['input_ids'] + pad)
        out['img'] = torch.stack(out['img'])
        out['input_ids'] = torch.LongTensor(out['input_ids'])
        return out
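# Minimal usage sketch (not part of the original gist): wiring one split into a
# DataLoader with the padding collator. 'train.csv' is a hypothetical split
# file name; the real names come from SPLITS[conf.split]. Assumes the CSVs and
# 512px PNGs are in place under DATASET_PATH and the transform yields tensors.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    conf = MimicTextDataConfig(
        tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased'))
    dataset = MimicTextDataset(f'{DATASET_PATH}/train.csv', conf,
                               transform=make_transform('eval', conf.trans_conf))
    loader = DataLoader(dataset, batch_size=8, shuffle=True,
                        collate_fn=MimicTextCollator(conf))
    batch = next(iter(loader))
    # batch['img']: stacked image tensor; batch['input_ids']: (8, max_len) LongTensor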