import json
import random
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from collections import Counter
from typing import Any, Dict, Iterator, NoReturn, List

from .distribution import NERDistribution, DatasetSplits, PrintDistribution

random.seed(41)


def _read_notes(input_file: str) -> Iterator[Dict[str, Any]]:
    """
    Yield the notes stored in a jsonl file, one parsed JSON object per line.
    The file is opened in a context manager so the handle is closed as soon
    as the iteration finishes. Whitespace-only lines (e.g. a trailing blank
    line) are skipped instead of being handed to the JSON parser.
    Args:
        input_file (str): Path to the jsonl file that contains the notes
    Returns:
        (Iterator[Dict[str, Any]]): The parsed notes, in file order
    """
    with open(input_file, 'r') as file:
        for line in file:
            if not line.strip():
                continue
            yield json.loads(line)


class DatasetSplitter(object):
    """
    Prepare dataset splits - training, validation & testing splits
    Compute ner distributions in our dataset. Compute ner distributions
    based on which we create and store a dictionary which will
    contain information about which notes (in a dataset) belong to which split.
    Based on this distribution and whether we want to keep certain notes
    grouped (e.g by patient) we assign notes to a split, such that the
    final ner type distribution in each split is similar.
    """

    # The only split names accepted by set_split and created by assign_splits
    _SPLIT_NAMES = ('train', 'validation', 'test')

    def __init__(
            self,
            train_proportion: int = 70,
            validation_proportion: int = 15,
            test_proportion: int = 15
    ) -> None:
        """
        Initialize the proportions of the splits.
        Args:
            train_proportion (int): Ratio of train dataset
            validation_proportion (int): Ratio of validation dataset
            test_proportion (int): Ratio of test dataset
        """
        self._train_proportion = train_proportion
        self._validation_proportion = validation_proportion
        self._test_proportion = test_proportion
        # The split currently being checked/processed (see set_split)
        self._split = None
        # Maps split name -> {group key: 1}. Starts with empty buckets so that
        # get_split/check_note work (and report nothing) before assign_splits
        # has been run. A dict (not a set) keeps the keys in insertion order.
        self._lookup_split = {split: dict() for split in self._SPLIT_NAMES}

    def get_split(self, split: str) -> List[str]:
        """
        Get the keys (e.g note_ids) that have been assigned to a split.
        Args:
            split (str): The name of the split - train, validation or test
        Returns:
            (List[str]): The keys assigned to the split, in the order in which
                         they were first assigned
        Raises:
            KeyError: If split is not one of the known split names
        """
        return list(self._lookup_split[split])

    def set_split(self, split: str) -> None:
        """
        Set the split that you are currently checking/processing.
        Subsequent calls to check_note are answered with respect to this split.
        Args:
            split (str): The split - one of train, validation or test
        Raises:
            ValueError: If split is not one of train, validation or test
        """
        if split not in self._SPLIT_NAMES:
            raise ValueError('Invalid split')
        self._split = split

    def __update_split(self, key: str) -> None:
        """
        Update the hash map where we have the keys (e.g note_id) that
        belong to the currently set split. This hashmap can then be used
        to check if a particular note belongs to this split.
        Args:
            key (str): The key that identify the note belonging to the split
        """
        self._lookup_split[self._split][key] = 1

    def check_note(self, key: str) -> bool:
        """
        Use the hash map populated by assign_splits to check if the note
        identified by key belongs to the currently set split
        (train, validation or test). If it does, return true, else false
        Args:
            key (str): The key that identify the note belonging to the split
        Returns:
            (bool): True if the note belongs to the split, false otherwise
        Raises:
            ValueError: If no split has been set via set_split
        """
        if self._split is None:
            raise ValueError('Split not set')
        # Stored values are always 1, so membership is equivalent to a
        # truthiness check on the stored value
        return key in self._lookup_split[self._split]

    def assign_splits(
            self,
            input_file: str,
            spans_key: str = 'spans',
            metadata_key: str = 'meta',
            group_key: str = 'note_id',
            margin: float = 0.3
    ) -> None:
        """
        Assign every note group in the input file to a dataset split -
        training, validation or testing - based on the NER distribution and
        on whether we want to keep certain notes grouped (e.g by patient).
        The assignment is stored on this object and can afterwards be queried
        with get_split, or with set_split followed by check_note. Any
        assignment from a previous call is discarded.
        Args:
            input_file (str): The input file
            spans_key (str): The key where the note spans are present
            metadata_key (str): The key where the note metadata is present
            group_key (str): The key where the note group (e.g note_id or patient id etc) is present.
                             This field is what the notes will be grouped by, and all notes belonging
                             to this grouping will be in the same split
            margin (float): Margin of error when maintaining proportions in the splits
        """
        self._lookup_split = {split: dict() for split in self._SPLIT_NAMES}
        # First pass: compute the distribution of NER types in the grouped
        # notes. For example the distribution of NER types in all notes
        # belonging to a particular patient
        ner_distribution = NERDistribution()
        for note in _read_notes(input_file):
            key = note[metadata_key][group_key]
            ner_distribution.update_distribution(spans=note[spans_key], key=key)
        # Initialize the dataset splits object
        dataset_splits = DatasetSplits(
            ner_distribution=ner_distribution,
            train_proportion=self._train_proportion,
            validation_proportion=self._validation_proportion,
            test_proportion=self._test_proportion,
            margin=margin
        )
        # Second pass: check the note and assign it to a split.
        # get_split is deliberately queried once per note (not once per
        # distinct key) - it may be stateful, so the call pattern is kept
        # exactly as it was.
        for note in _read_notes(input_file):
            key = note[metadata_key][group_key]
            split = dataset_splits.get_split(key=key)
            self.set_split(split)
            self.__update_split(key)


def _write_split(
        dataset_splitter: DatasetSplitter,
        split: str,
        input_file: str,
        output_file: str,
        metadata_key: str,
        group_key: str
) -> None:
    """
    Write the notes of the input file that belong to the given split to the
    output file, as one JSON object per line.
    Args:
        dataset_splitter (DatasetSplitter): Splitter on which assign_splits has already been run
        split (str): The split to write - train, validation or test
        input_file (str): The jsonl file that contains all the notes
        output_file (str): The file to store the notes of the split
        metadata_key (str): The key where the note metadata is present
        group_key (str): The key (inside the metadata) the notes were grouped by
    """
    # The split is the same for the whole file, so select it once up front
    dataset_splitter.set_split(split)
    with open(output_file, 'w') as file:
        for note in _read_notes(input_file):
            key = note[metadata_key][group_key]
            if dataset_splitter.check_note(key):
                file.write(json.dumps(note) + '\n')


def main() -> None:
    """
    Prepare dataset splits - training, validation & testing splits
    Compute ner distributions in our dataset. Based on this distribution
    and whether we want to keep certain notes grouped (e.g by patient)
    we assign notes to a split, such that the final ner type distribution
    in each split is similar. The notes of every split that has a
    proportion greater than 0 are written to the file given for that split,
    and the label distribution of the splits is optionally printed.
    """
    # The following code sets up the arguments to be passed via CLI
    cli_parser = ArgumentParser(
        description='configuration arguments provided at run time from the CLI',
        formatter_class=ArgumentDefaultsHelpFormatter
    )
    cli_parser.add_argument(
        '--input_file',
        type=str,
        required=True,
        help='the the jsonl file that contains the notes'
    )
    cli_parser.add_argument(
        '--spans_key',
        type=str,
        default='spans',
        help='the key where the note spans is present in the json object'
    )
    cli_parser.add_argument(
        '--metadata_key',
        type=str,
        default='meta',
        help='the key where the note metadata is present in the json object'
    )
    cli_parser.add_argument(
        '--group_key',
        type=str,
        default='note_id',
        help='the key to group notes by in the json object'
    )
    cli_parser.add_argument(
        '--train_proportion',
        type=int,
        default=70,
        help='ratio of train dataset'
    )
    cli_parser.add_argument(
        '--train_file',
        type=str,
        default=None,
        help='The file to store the train data'
    )
    cli_parser.add_argument(
        '--validation_proportion',
        type=int,
        default=15,
        help='ratio of validation dataset'
    )
    cli_parser.add_argument(
        '--validation_file',
        type=str,
        default=None,
        help='The file to store the validation data'
    )
    cli_parser.add_argument(
        '--test_proportion',
        type=int,
        default=15,
        help='ratio of test dataset'
    )
    cli_parser.add_argument(
        '--test_file',
        type=str,
        default=None,
        help='The file to store the test data'
    )
    cli_parser.add_argument(
        '--margin',
        type=float,
        default=0.3,
        help='margin of error when maintaining proportions in the splits'
    )
    cli_parser.add_argument(
        '--print_dist',
        action='store_true',
        help='whether to print the label distribution in the splits'
    )
    args = cli_parser.parse_args()

    # (split name, proportion, output file) for every split that may be written
    split_outputs = (
        ('train', args.train_proportion, args.train_file),
        ('validation', args.validation_proportion, args.validation_file),
        ('test', args.test_proportion, args.test_file),
    )
    # Fail fast with a usage error (rather than a TypeError from open(None)
    # after the splits were computed) when a split that will be written has
    # no output file
    for split, proportion, output_file in split_outputs:
        if proportion > 0 and output_file is None:
            cli_parser.error(
                '--{0}_file is required when --{0}_proportion is greater than 0'.format(split)
            )

    dataset_splitter = DatasetSplitter(
        train_proportion=args.train_proportion,
        validation_proportion=args.validation_proportion,
        test_proportion=args.test_proportion
    )
    dataset_splitter.assign_splits(
        input_file=args.input_file,
        spans_key=args.spans_key,
        metadata_key=args.metadata_key,
        group_key=args.group_key,
        margin=args.margin
    )

    # Write the notes of every split that was requested (proportion > 0)
    for split, proportion, output_file in split_outputs:
        if proportion > 0:
            _write_split(
                dataset_splitter=dataset_splitter,
                split=split,
                input_file=args.input_file,
                output_file=output_file,
                metadata_key=args.metadata_key,
                group_key=args.group_key
            )

    if args.print_dist:
        # Read the input file again and compute the NER type distribution
        # and the number of notes per group key
        key_counts = Counter()
        ner_distribution = NERDistribution()
        for note in _read_notes(args.input_file):
            key = note[args.metadata_key][args.group_key]
            key_counts[key] += 1
            ner_distribution.update_distribution(spans=note[args.spans_key], key=key)
        print_distribution = PrintDistribution(ner_distribution=ner_distribution, key_counts=key_counts)
        train_splits = dataset_splitter.get_split('train')
        validation_splits = dataset_splitter.get_split('validation')
        test_splits = dataset_splitter.get_split('test')
        all_splits = train_splits + validation_splits + test_splits
        # Print distribution for each split
        print_distribution.split_distribution(split='total', split_info=all_splits)
        print_distribution.split_distribution(split='train', split_info=train_splits)
        print_distribution.split_distribution(split='validation', split_info=validation_splits)
        print_distribution.split_distribution(split='test', split_info=test_splits)


if __name__ == "__main__":
    main()