Created
December 3, 2020 16:49
-
-
Save asifr/043dd6da6cb95b096aae921c6ee0d3c4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
def pad_sequences(
    sequences, maxlen=None, dtype="int32", padding="pre", truncating="pre", value=0.0
):
    """Pad and/or truncate a list of sequences to a common length.

    Returns a NumPy array of shape ``(len(sequences), maxlen) + sample_shape``,
    where ``sample_shape`` is the shape of one element of the first non-empty
    sequence (``()`` for flat sequences of scalars). Positions not covered by
    a sequence hold ``value``.

    Args:
        sequences: Sized iterable whose items are sized iterables (lists,
            arrays, ...).
        maxlen: Target length. Defaults to the length of the longest
            sequence (0 when ``sequences`` is empty).
        dtype: dtype of the returned array; every sequence is cast to it.
        padding: ``"pre"`` fills ``value`` before each sequence, ``"post"``
            fills it after.
        truncating: For sequences longer than ``maxlen``, ``"pre"`` drops
            elements from the start (keeps the last ``maxlen``), ``"post"``
            drops them from the end (keeps the first ``maxlen``).
        value: Fill value for padded positions.

    Returns:
        numpy.ndarray of dtype ``dtype``.

    Raises:
        ValueError: if ``sequences`` or one of its items has no length, if
            ``padding``/``truncating`` is not ``"pre"`` or ``"post"``, if
            ``maxlen`` is negative, or if a sequence's element shape differs
            from the first non-empty sequence's.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")
    lengths = []
    for seq in sequences:
        if not hasattr(seq, "__len__"):
            raise ValueError(
                "`sequences` must be a list of iterables. "
                "Found non-iterable: " + str(seq)
            )
        lengths.append(len(seq))

    # Validate the modes up front so bad arguments are reported even when
    # every sequence is empty (the loop below would never reach them).
    if truncating not in ("pre", "post"):
        raise ValueError('Truncating type "%s" not understood' % truncating)
    if padding not in ("pre", "post"):
        raise ValueError('Padding type "%s" not understood' % padding)

    num_samples = len(sequences)
    if maxlen is None:
        # `default=0` handles an empty `sequences`; np.max([]) would raise.
        maxlen = max(lengths, default=0)
    if maxlen < 0:
        raise ValueError("`maxlen` must be non-negative, got %s" % maxlen)

    # Take the per-element shape from the first non-empty sequence; the main
    # loop checks every other sequence against it.
    sample_shape = tuple()
    for seq in sequences:
        if len(seq) > 0:
            sample_shape = np.asarray(seq).shape[1:]
            break

    x = np.full((num_samples, maxlen) + sample_shape, value).astype(dtype)
    for idx, seq in enumerate(sequences):
        n = len(seq)
        if n == 0:
            continue  # empty list/array: the row stays all `value`
        if truncating == "pre":
            # Not `seq[-maxlen:]`: with maxlen == 0 that is `seq[0:]`, which
            # would keep the whole sequence instead of none of it.
            trunc = seq[max(n - maxlen, 0):]
        else:
            trunc = seq[:maxlen]
        if len(trunc) == 0:
            continue  # only reachable when maxlen == 0: nothing to write
        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                "Shape of sample %s of sequence at position %s "
                "is different from expected shape %s"
                % (trunc.shape[1:], idx, sample_shape)
            )
        if padding == "post":
            x[idx, : len(trunc)] = trunc
        else:
            # Explicit start index instead of `-len(trunc):` so the slice
            # bound never depends on the sign of a possibly-zero length.
            x[idx, maxlen - len(trunc) :] = trunc
    return x
def pad(x, maxpadlen, value=np.nan):
    """Pad sequences to ``maxpadlen`` as a float array, filling with ``value``.

    A thin wrapper over ``pad_sequences`` that pads at the end (``"post"``)
    and truncates from the start (``"pre"``). Sequences longer than
    ``maxpadlen`` therefore keep their last ``maxpadlen`` elements. The
    output dtype is ``float``, so the default NaN fill value is representable.
    """
    layout = {"padding": "post", "truncating": "pre"}
    return pad_sequences(x, maxlen=maxpadlen, dtype=float, value=value, **layout)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment