Skip to content

Instantly share code, notes, and snippets.

@Deepayan137
Last active September 16, 2020 06:36
Show Gist options
  • Save Deepayan137/2c6c546afb63cc0d95f9418051823937 to your computer and use it in GitHub Desktop.
Save Deepayan137/2c6c546afb63cc0d95f9418051823937 to your computer and use it in GitHub Desktop.

Revisions

  1. Deepayan137 revised this gist Sep 16, 2020. 1 changed file with 19 additions and 20 deletions.
    39 changes: 19 additions & 20 deletions dummyData.py
    Original file line number Diff line number Diff line change
    @@ -2,31 +2,30 @@
    from torch.utils.data import Dataset
    import numpy as np



    class GenerateDummyData(object):
    def __init__(self, prob):
    print('Generating Dummy Data')
    self.prob = prob
    class DummyDataset(Dataset):
    def __init__(self, **kwargs):
    self.prob = kwargs['prob']
    self.vocab_size = kwargs['vocab_size']
    self.nSamples = kwargs['nSamples']
    self.src_data = np.random.choice(self.vocab_size,
    self.nSamples)
    def __len__(self):
    return self.nSamples

    def func1(self, x):
    #TODO
    pass
    return x//2

    def func2(self, x):
    #TODO
    pass
    return 2*x + 1

    def __call__(self, x):
    def get_target(self, x):
    if np.random.random() > self.prob:
    return self.func1(x)
    return self.func2(x)

    src_vocab_size = 300
    nSamples = 1000
    prob = 0.9
    pairs = []
    genObj = GenerateDummyData(prob)
    for x in np.random.choice(src_vocab_size, nSamples):
    y = genObj(x)
    pairs.append(x, y)

    def __getitem__(self, index):
    assert index < self.nSamples
    x = self.src_data[index]
    y = self.get_target(x)
    return {'src':x, 'tgt':y}

  2. Deepayan137 created this gist Sep 16, 2020.
    32 changes: 32 additions & 0 deletions dummyData.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,32 @@
    import torch
    from torch.utils.data import Dataset
    import numpy as np



    class GenerateDummyData(object):
    def __init__(self, prob):
    print('Generating Dummy Data')
    self.prob = prob

    def func1(self, x):
    #TODO
    pass

    def func2(self, x):
    #TODO
    pass

    def __call__(self, x):
    if np.random.random() > self.prob:
    return self.func1(x)
    return self.func2(x)

    src_vocab_size = 300
    nSamples = 1000
    prob = 0.9
    pairs = []
    genObj = GenerateDummyData(prob)
    for x in np.random.choice(src_vocab_size, nSamples):
    y = genObj(x)
    pairs.append(x, y)