Skip to content

Instantly share code, notes, and snippets.

@asifr
Created December 3, 2020 16:49
Show Gist options
  • Save asifr/043dd6da6cb95b096aae921c6ee0d3c4 to your computer and use it in GitHub Desktop.
Save asifr/043dd6da6cb95b096aae921c6ee0d3c4 to your computer and use it in GitHub Desktop.
import numpy as np
def pad_sequences(
    sequences, maxlen=None, dtype="int32", padding="pre", truncating="pre", value=0.0
):
    """Pad or truncate every sequence to one common length and stack them.

    Args:
        sequences: Sized collection (it must support ``len``) whose items are
            themselves sized sequences, e.g. a list of lists or of arrays.
        maxlen: Length of axis 1 of the result. ``None`` means the length of
            the longest sequence (0 when ``sequences`` is empty).
        dtype: dtype of the returned array.
        padding: ``"pre"`` puts the fill values before each sequence,
            ``"post"`` puts them after it.
        truncating: For sequences longer than ``maxlen``, ``"pre"`` keeps the
            last ``maxlen`` items and ``"post"`` keeps the first ``maxlen``.
        value: Fill value used for the padded positions.

    Returns:
        A NumPy array of shape ``(len(sequences), maxlen) + sample_shape``
        where ``sample_shape`` is the trailing shape of the first non-empty
        sequence. Empty sequences produce rows made only of ``value``.

    Raises:
        ValueError: If ``sequences`` or one of its items has no length, if
            ``padding`` or ``truncating`` is not ``"pre"``/``"post"``, or if
            a sequence's trailing shape differs from ``sample_shape``.
    """
    if not hasattr(sequences, "__len__"):
        raise ValueError("`sequences` must be iterable.")

    lengths = []
    for seq in sequences:
        if not hasattr(seq, "__len__"):
            raise ValueError(
                "`sequences` must be a list of iterables. "
                "Found non-iterable: " + str(seq)
            )
        lengths.append(len(seq))

    # Validate the modes once, up front: checking them inside the fill loop
    # would let a bad value slip through whenever every sequence is empty.
    if truncating not in ("pre", "post"):
        raise ValueError('Truncating type "%s" not understood' % truncating)
    if padding not in ("pre", "post"):
        raise ValueError('Padding type "%s" not understood' % padding)

    if maxlen is None:
        # `default=0` keeps an empty `sequences` from raising inside the
        # reduction; the result is then an array of shape (0, 0).
        maxlen = max(lengths, default=0)

    # Take the sample shape from the first non-empty sequence; consistency
    # of the remaining sequences is checked in the fill loop below.
    sample_shape = ()
    for seq in sequences:
        if len(seq) > 0:
            sample_shape = np.asarray(seq).shape[1:]
            break

    padded = np.full((len(sequences), maxlen) + sample_shape, value, dtype=dtype)

    # Nothing can be copied into zero-length rows. Returning here also avoids
    # the `seq[-0:]` pitfall, which selects the WHOLE sequence instead of
    # an empty slice.
    if maxlen == 0:
        return padded

    for idx, seq in enumerate(sequences):
        if not len(seq):
            continue  # empty list/array: the row stays filled with `value`

        kept = seq[-maxlen:] if truncating == "pre" else seq[:maxlen]
        kept = np.asarray(kept, dtype=dtype)

        if kept.shape[1:] != sample_shape:
            raise ValueError(
                "Shape of sample %s of sequence at position %s "
                "is different from expected shape %s"
                % (kept.shape[1:], idx, sample_shape)
            )

        if padding == "post":
            padded[idx, : len(kept)] = kept
        else:
            padded[idx, -len(kept):] = kept

    return padded
def pad(x, maxpadlen, value=np.nan):
    """Return `x` as a float array with every sequence sized to `maxpadlen`.

    Thin wrapper around `pad_sequences`: shorter sequences are filled with
    `value` at the end ("post" padding), longer ones keep only their last
    `maxpadlen` items ("pre" truncating).
    """
    options = {
        "maxlen": maxpadlen,
        "dtype": float,
        "padding": "post",
        "truncating": "pre",
        "value": value,
    }
    return pad_sequences(x, **options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment