Last active
September 16, 2020 06:36
-
-
Save Deepayan137/2c6c546afb63cc0d95f9418051823937 to your computer and use it in GitHub Desktop.
Revisions
-
Deepayan137 revised this gist
Sep 16, 2020. 1 changed file with 19 additions and 20 deletions. There are no files selected for viewing
# Revised version of the gist (diff hunk @@ -2,31 +2,30 @@).
from torch.utils.data import Dataset
import numpy as np


class DummyDataset(Dataset):
    """Synthetic dataset of integer (source, target) pairs.

    Source values are drawn once, at construction time, with
    ``np.random.choice(vocab_size, nSamples)``.  The target is computed
    lazily on every access: ``func2`` is applied when a fresh uniform draw
    is ``<= prob`` and ``func1`` otherwise, so reading the same index twice
    may yield different targets.

    Keyword Args:
        prob: Threshold compared against ``np.random.random()`` when
            choosing between ``func1`` and ``func2``.
        vocab_size: Passed as the population argument of
            ``np.random.choice``.
        nSamples: Number of samples in the dataset.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
    """

    def __init__(self, **kwargs):
        self.prob = kwargs['prob']
        self.vocab_size = kwargs['vocab_size']
        self.nSamples = kwargs['nSamples']
        self.src_data = np.random.choice(self.vocab_size, self.nSamples)

    def __len__(self):
        """Return the number of samples."""
        return self.nSamples

    def func1(self, x):
        """Return ``x`` floor-divided by two."""
        return x // 2

    def func2(self, x):
        """Return ``2 * x + 1``."""
        return 2 * x + 1

    def get_target(self, x):
        """Return ``func1(x)`` if a uniform draw exceeds ``prob``, else ``func2(x)``."""
        if np.random.random() > self.prob:
            return self.func1(x)
        return self.func2(x)

    def __getitem__(self, index):
        """Return ``{'src': x, 'tgt': y}`` for the sample at ``index``.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # Raise IndexError rather than asserting: the sequence protocol
        # (``for item in dataset`` / ``list(dataset)``) relies on IndexError
        # to stop iterating, and asserts are stripped under ``python -O``.
        if not index < self.nSamples:
            raise IndexError(
                'index {} is out of range for dataset of size {}'.format(
                    index, self.nSamples))
        x = self.src_data[index]
        y = self.get_target(x)
        return {'src': x, 'tgt': y}
-
Deepayan137 created this gist
Sep 16, 2020. There are no files selected for viewing
# Initial version of the gist (diff hunk @@ -0,0 +1,32 @@).
import torch
from torch.utils.data import Dataset
import numpy as np


class GenerateDummyData(object):
    """Callable that maps a source value to a target value.

    On each call a uniform random number is drawn; ``func1`` is applied
    when the draw exceeds ``prob`` and ``func2`` otherwise.

    Args:
        prob: Threshold compared against ``np.random.random()``.
    """

    def __init__(self, prob):
        print('Generating Dummy Data')
        self.prob = prob

    def func1(self, x):
        """Placeholder transform; not implemented yet, so it returns None."""
        # TODO
        pass

    def func2(self, x):
        """Placeholder transform; not implemented yet, so it returns None."""
        # TODO
        pass

    def __call__(self, x):
        """Return ``func1(x)`` if a uniform draw exceeds ``prob``, else ``func2(x)``."""
        if np.random.random() > self.prob:
            return self.func1(x)
        return self.func2(x)


# Build `nSamples` (source, target) pairs from randomly drawn source values.
src_vocab_size = 300
nSamples = 1000
prob = 0.9

pairs = []
genObj = GenerateDummyData(prob)
for x in np.random.choice(src_vocab_size, nSamples):
    y = genObj(x)
    # list.append takes exactly one argument, so the pair is stored as a tuple.
    pairs.append((x, y))