import multiprocessing
import os

# A Manager dict acts as a hash set shared across all filter worker processes.
manager = multiprocessing.Manager()
all_hashes_set = manager.dict()

def deduplicate(examples, all_hashes_set):
    print(len(all_hashes_set))  # optional: rough progress indicator
    input_ids = examples['input_ids']
    # Tuples of ints hash deterministically across processes
    # (PYTHONHASHSEED randomization only affects str/bytes).
    hashes = [hash(tuple(input_ids[i])) for i in range(len(input_ids))]

    # Dataset.filter keeps rows where the mask is True, so return True for
    # first occurrences and False for rows whose hash was already seen.
    should_keep = []
    for val in hashes:
        if val in all_hashes_set:
            should_keep.append(False)
        else:
            should_keep.append(True)
            all_hashes_set[val] = 1
    return should_keep

original_len = len(dataset)
dataset = dataset.filter(
    deduplicate,
    batched=True,
    num_proc=os.cpu_count(),
    fn_kwargs={"all_hashes_set": all_hashes_set},
)
print(f"Removed {original_len - len(dataset)} duplicates")
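As a quick sanity check, the same deduplicate function can be run single-process on a tiny in-memory dataset; the toy rows and the Dataset.from_dict construction below are illustrative, not part of the original pipeline. With num_proc omitted, filter runs in the main process, so a plain dict suffices as the shared hash set.

from datasets import Dataset

# Hypothetical toy data: rows 0 and 2 are identical, so one should be dropped.
toy = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6], [1, 2, 3]]})
toy_deduped = toy.filter(
    deduplicate,
    batched=True,
    fn_kwargs={"all_hashes_set": {}},  # plain dict: no workers to share with
)
assert len(toy_deduped) == 2  # the duplicate row was removed

One caveat on the multiprocessing version: the membership check and the insert on the Manager dict are not one atomic operation, so with num_proc > 1 two workers can occasionally both keep the same duplicate. If exact deduplication matters, run a second single-process pass or verify the result afterwards.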