"""Dataset loading utilities for UK Biobank, Morpho-MNIST and MIMIC-CXR."""
import gzip
import os
import struct
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as TF
from PIL import Image
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm


def log_standardize(x):
    log_x = torch.log(x.clamp(min=1e-12))
    return (log_x - log_x.mean()) / log_x.std().clamp(min=1e-12)  # mean=0, std=1


def normalize(x, x_min=None, x_max=None, zero_one=False):
    if x_min is None:
        x_min = x.min()
    if x_max is None:
        x_max = x.max()
    print(f"max: {x_max}, min: {x_min}")
    x = (x - x_min) / (x_max - x_min)  # [0,1]
    return x if zero_one else 2 * x - 1  # else [-1,1]


class UKBBDataset(Dataset):
    def __init__(
        self, root, csv_file, transform=None, columns=None, norm=None, concat_pa=True
    ):
        super().__init__()
        self.root = root
        self.transform = transform
        self.concat_pa = concat_pa  # return concatenated parents

        print(f"\nLoading csv data: {csv_file}")
        self.df = pd.read_csv(csv_file)
        self.columns = columns
        if self.columns is None:
            # ['eid', 'sex', 'age', 'brain_volume', 'ventricle_volume', 'mri_seq']
            self.columns = list(self.df.columns)  # return all
            self.columns.pop(0)  # remove redundant 'index' column
        print(f"columns: {self.columns}")
        self.samples = {i: torch.as_tensor(self.df[i]).float() for i in self.columns}

        for k in ["age", "brain_volume", "ventricle_volume"]:
            print(f"{k} normalization: {norm}")
            if k in self.columns:
                if norm == "[-1,1]":
                    self.samples[k] = normalize(self.samples[k])
                elif norm == "[0,1]":
                    self.samples[k] = normalize(self.samples[k], zero_one=True)
                elif norm == "log_standard":
                    self.samples[k] = log_standardize(self.samples[k])
                elif norm is None:
                    pass
                else:
                    raise NotImplementedError(f"{norm} not implemented.")
        print(f"#samples: {len(self.df)}")

        self.return_x = "eid" in self.columns  # need subject id to locate scans

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = {k: v[idx] for k, v in self.samples.items()}

        if self.return_x:
            mri_seq = "T1" if sample["mri_seq"] == 0.0 else "T2_FLAIR"
            # Load scan
            filename = (
                f'{int(sample["eid"])}_' + mri_seq + "_unbiased_brain_rigid_to_mni.png"
            )
            x = Image.open(os.path.join(self.root, "thumbs_192x192", filename))

            if self.transform is not None:
                sample["x"] = self.transform(x)
            sample.pop("eid", None)

        if self.concat_pa:
            sample["pa"] = torch.cat(
                [torch.tensor([sample[k]]) for k in self.columns if k != "eid"], dim=0
            )
        return sample


def get_attr_max_min(attr):
    # some UKBB dataset (max, min) stats
    if attr == "age":
        return 73, 44
    elif attr == "brain_volume":
        return 1629520, 841919
    elif attr == "ventricle_volume":
        return 157075, 7613.27001953125
    else:
        raise NotImplementedError(f"{attr} not implemented.")


def ukbb(args):
    csv_dir = args.data_dir
    augmentation = {
        "train": TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
                TF.RandomCrop(
                    size=(args.input_res, args.input_res),
                    padding=[2 * args.pad, args.pad],
                ),
                TF.RandomHorizontalFlip(p=args.hflip),
                TF.PILToTensor(),
            ]
        ),
        "eval": TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
                TF.PILToTensor(),
            ]
        ),
    }

    datasets = {}
    # for split in ['train', 'valid', 'test']:
    for split in ["test"]:
        datasets[split] = UKBBDataset(
            root=args.data_dir,
            csv_file=os.path.join(csv_dir, split + ".csv"),
            transform=augmentation["eval" if split != "train" else split],
            columns=(None if not args.parents_x else ["eid"] + args.parents_x),
            norm=(None if not hasattr(args, "context_norm") else args.context_norm),
            concat_pa=False,
        )
    return datasets
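
# Usage sketch (an illustration, not part of the original pipeline): wrapping
# the UKBB test split in a DataLoader. The `args` fields below are assumed
# values for the argparse namespace this module expects; `data_dir` is a
# hypothetical path that must contain test.csv and the thumbs_192x192/ folder.
def _example_ukbb_loader():
    from argparse import Namespace

    from torch.utils.data import DataLoader

    args = Namespace(
        data_dir="path/to/ukbb",  # hypothetical
        input_res=192,
        pad=3,  # only used by the 'train' augmentation
        hflip=0.5,  # only used by the 'train' augmentation
        parents_x=["mri_seq", "age", "brain_volume", "ventricle_volume"],
        context_norm="[-1,1]",
    )
    loader = DataLoader(ukbb(args)["test"], batch_size=32, shuffle=False)
    # each batch is a dict: 'x' is (B, 1, 192, 192) uint8, each parent is (B,)
    return next(iter(loader))
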
def _load_uint8(f):
    """Read the payload of an open IDX file as a uint8 array."""
    # header: two zero bytes, a dtype code, then the number of dimensions
    idx_dtype, ndim = struct.unpack("BBBB", f.read(4))[2:]
    shape = struct.unpack(">" + "I" * ndim, f.read(4 * ndim))  # big-endian uint32s
    buffer_length = int(np.prod(shape))
    data = np.frombuffer(f.read(buffer_length), dtype=np.uint8).reshape(shape)
    return data


def load_idx(path: str) -> np.ndarray:
    """Reads an array in IDX format from disk.

    Parameters
    ----------
    path : str
        Path of the input file. Will uncompress with `gzip` if path ends in '.gz'.

    Returns
    -------
    np.ndarray
        Output array of dtype ``uint8``.

    References
    ----------
    http://yann.lecun.com/exdb/mnist/
    """
    open_fcn = gzip.open if path.endswith(".gz") else open
    with open_fcn(path, "rb") as f:
        return _load_uint8(f)


def _get_paths(root_dir, train):
    prefix = "train" if train else "t10k"
    images_filename = prefix + "-images-idx3-ubyte.gz"
    labels_filename = prefix + "-labels-idx1-ubyte.gz"
    metrics_filename = prefix + "-morpho.csv"
    images_path = os.path.join(root_dir, images_filename)
    labels_path = os.path.join(root_dir, labels_filename)
    metrics_path = os.path.join(root_dir, metrics_filename)
    return images_path, labels_path, metrics_path


def load_morphomnist_like(
    root_dir, train: bool = True, columns=None
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Args:
        root_dir: path to data directory
        train: whether to load the training subset (``True``, ``'train-*'`` files)
            or the test subset (``False``, ``'t10k-*'`` files)
        columns: list of morphometrics to load; by default (``None``) loads the
            image index and all available metrics: area, length, thickness,
            slant, width, and height
    Returns:
        images, labels, metrics
    """
    images_path, labels_path, metrics_path = _get_paths(root_dir, train)
    images = load_idx(images_path)
    labels = load_idx(labels_path)

    if columns is not None and "index" not in columns:
        usecols = ["index"] + list(columns)
    else:
        usecols = columns
    metrics = pd.read_csv(metrics_path, usecols=usecols, index_col="index")
    return images, labels, metrics


class MorphoMNIST(Dataset):
    def __init__(
        self,
        root_dir,
        train=True,
        transform=None,
        columns=None,
        norm=None,
        concat_pa=True,
    ):
        self.train = train
        self.transform = transform
        self.columns = columns
        self.concat_pa = concat_pa
        self.norm = norm

        cols_not_digit = (
            None if self.columns is None else [c for c in self.columns if c != "digit"]
        )
        images, labels, metrics_df = load_morphomnist_like(
            root_dir, train, cols_not_digit
        )
        self.images = torch.from_numpy(np.array(images)).unsqueeze(1)
        self.labels = F.one_hot(
            torch.from_numpy(np.array(labels)).long(), num_classes=10
        )

        if self.columns is None:  # default to all available metrics
            self.columns = metrics_df.columns
            cols_not_digit = list(self.columns)
        self.samples = {k: torch.tensor(metrics_df[k]) for k in cols_not_digit}

        self.min_max = {
            "thickness": [0.87598526, 6.255515],
            "intensity": [66.601204, 254.90317],
        }

        for k, v in self.samples.items():  # optional preprocessing
            print(f"{k} normalization: {norm}")
            if norm == "[-1,1]":
                self.samples[k] = normalize(
                    v, x_min=self.min_max[k][0], x_max=self.min_max[k][1]
                )
            elif norm == "[0,1]":
                self.samples[k] = normalize(
                    v, x_min=self.min_max[k][0], x_max=self.min_max[k][1], zero_one=True
                )
            elif norm is None:
                pass
            else:
                raise NotImplementedError(f"{norm} not implemented.")
        print(f"#samples: {len(metrics_df)}\n")

        self.samples.update({"digit": self.labels})

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = {}
        sample["x"] = self.images[idx]

        if self.transform is not None:
            sample["x"] = self.transform(sample["x"])

        if self.concat_pa:
            sample["pa"] = torch.cat(
                [
                    v[idx] if k == "digit" else torch.tensor([v[idx]])
                    for k, v in self.samples.items()
                ],
                dim=0,
            )
        else:
            sample.update({k: v[idx] for k, v in self.samples.items()})
        return sample
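
# Usage sketch (an illustration with assumed values): loading the Morpho-MNIST
# test set with the two metrics this module provides min/max stats for.
# `root_dir` is a hypothetical path holding t10k-images-idx3-ubyte.gz,
# t10k-labels-idx1-ubyte.gz and t10k-morpho.csv.
def _example_morphomnist():
    dataset = MorphoMNIST(
        root_dir="path/to/morphomnist",  # hypothetical
        train=False,
        columns=["thickness", "intensity", "digit"],
        norm="[-1,1]",
        concat_pa=True,
    )
    sample = dataset[0]
    # 'x' is a (1, 28, 28) uint8 image; 'pa' concatenates the two normalized
    # metrics with the one-hot digit, giving shape (12,)
    return sample["x"].shape, sample["pa"].shape
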
def morphomnist(args):
    # Load data
    augmentation = {
        "train": TF.Compose(
            [
                TF.RandomCrop((args.input_res, args.input_res), padding=args.pad),
            ]
        ),
        "eval": TF.Compose(
            [
                TF.Pad(padding=2),  # (28, 28) -> (32, 32)
            ]
        ),
    }

    datasets = {}
    # for split in ['train', 'valid', 'test']:
    for split in ["test"]:
        datasets[split] = MorphoMNIST(
            root_dir=args.data_dir,
            train=(split == "train"),  # test set is valid set
            transform=augmentation["eval" if split != "train" else split],
            columns=args.parents_x,
            norm=args.context_norm,
            concat_pa=False,
        )
    return datasets


def preproc_mimic(batch):
    for k, v in batch.items():
        if k == "x":
            batch["x"] = (batch["x"].float() - 127.5) / 127.5  # [-1,1]
        elif k in ["age"]:
            batch[k] = batch[k].float().unsqueeze(-1)
            batch[k] = batch[k] / 100.0  # [0,1]
            batch[k] = batch[k] * 2 - 1  # [-1,1]
        elif k in ["race"]:
            batch[k] = F.one_hot(batch[k], num_classes=3).squeeze().float()
        elif k in ["finding"]:
            batch[k] = batch[k].unsqueeze(-1).float()
        else:
            batch[k] = batch[k].float().unsqueeze(-1)
    return batch


class MIMICDataset(Dataset):
    def __init__(
        self,
        root,
        csv_file,
        transform=None,
        columns=None,
        concat_pa=True,
        only_pleural_eff=True,
    ):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.disease_labels = [
            "No Finding",
            "Other",
            "Pleural Effusion",
            # "Lung Opacity",
        ]
        self.samples = {
            "age": [],
            "sex": [],
            "finding": [],
            "x": [],
            "race": [],
            # "lung_opacity": [],
            # "pleural_effusion": [],
        }

        for idx in tqdm(range(len(self.data)), desc="Loading MIMIC Data"):
            if only_pleural_eff and self.data.loc[idx, "disease"] == "Other":
                continue
            img_path = os.path.join(root, self.data.loc[idx, "path_preproc"])

            # lung_opacity = self.data.loc[idx, "Lung Opacity"]
            # self.samples["lung_opacity"].append(lung_opacity)
            # pleural_effusion = self.data.loc[idx, "Pleural Effusion"]
            # self.samples["pleural_effusion"].append(pleural_effusion)

            disease = self.data.loc[idx, "disease"]
            finding = 0 if disease == "No Finding" else 1

            self.samples["x"].append(img_path)
            self.samples["finding"].append(finding)
            self.samples["age"].append(self.data.loc[idx, "age"])
            self.samples["race"].append(self.data.loc[idx, "race_label"])
            self.samples["sex"].append(self.data.loc[idx, "sex_label"])

        self.columns = columns
        if self.columns is None:
            # ['age', 'race', 'sex']
            self.columns = list(self.data.columns)  # return all
            self.columns.pop(0)  # remove redundant 'index' column
        self.concat_pa = concat_pa

    def __len__(self):
        return len(self.samples["x"])

    def __getitem__(self, idx):
        sample = {k: v[idx] for k, v in self.samples.items()}
        # load the image lazily; add a leading channel dimension
        sample["x"] = imread(sample["x"]).astype(np.float32)[None, ...]

        for k, v in sample.items():
            sample[k] = torch.tensor(v)

        if self.transform:
            sample["x"] = self.transform(sample["x"])

        sample = preproc_mimic(sample)

        if self.concat_pa:
            sample["pa"] = torch.cat([sample[k] for k in self.columns], dim=0)
        return sample


def mimic(args):
    args.csv_dir = args.data_dir
    datasets = {}
    datasets["test"] = MIMICDataset(
        root=args.data_dir,
        csv_file=os.path.join(args.csv_dir, "mimic.sample.test.csv"),
        columns=args.parents_x,
        transform=TF.Compose(
            [
                TF.Resize((args.input_res, args.input_res), antialias=None),
            ]
        ),
        concat_pa=False,
    )
    return datasets
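
# Usage sketch (an illustration with assumed values): one preprocessed batch
# from the MIMIC-CXR test split. `data_dir` is a hypothetical path containing
# mimic.sample.test.csv and the preprocessed images referenced by its
# 'path_preproc' column.
def _example_mimic_loader():
    from argparse import Namespace

    from torch.utils.data import DataLoader

    args = Namespace(
        data_dir="path/to/mimic",  # hypothetical
        input_res=192,
        parents_x=["age", "race", "sex", "finding"],
    )
    loader = DataLoader(mimic(args)["test"], batch_size=16, shuffle=False)
    batch = next(iter(loader))
    # after preproc_mimic: 'x' in [-1,1] with shape (B, 1, res, res),
    # 'age' in [-1,1] with shape (B, 1), 'race' one-hot with shape (B, 3)
    return {k: tuple(v.shape) for k, v in batch.items()}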