Last active
May 17, 2021 23:05
-
-
Save keunwoochoi/c9592922a17d71b745d47dc8eb7f0538 to your computer and use it in GitHub Desktop.
Revisions
-
keunwoochoi revised this gist
Sep 28, 2019 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def _tf_log10(x):
        )
        mag_stfts = tf.abs(stfts)
        melgrams = tf.tensordot(tf.square(mag_stfts), self.lin_to_mel_matrix, axes=[2, 0])
        log_melgrams = _tf_log10(melgrams + EPS)
        return tf.expand_dims(log_melgrams, 3)
-
keunwoochoi renamed this gist
Sep 28, 2019 . 1 changed file with 0 additions and 0 deletions.There are no files selected for viewing
File renamed without changes. -
keunwoochoi created this gist
Sep 28, 2019 .There are no files selected for viewing
# This file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters. Learn more
# about bidirectional Unicode characters.
# Original file line number | Diff line number | Diff line change
# @@ -0,0 +1,81 @@
import tensorflow as tf

# assuming num_fft = 512
NUM_FFT = 512
NUM_FREQS = 257

# some tentative constants
NUM_MEL = 60
SAMPLE_RATE = 44100
F_MIN = 0
F_MAX = 12000
# The two names below are used further down but were never defined in this
# file; the values are tentative placeholders and should be tuned for the task.
HOP_LENGTH = NUM_FFT // 2  # 50% overlap between successive frames
EPS = 1e-7  # added before the log so all-zero mel bins stay finite


def _tf_log10(x):
    """Return the element-wise base-10 logarithm of tensor ``x``."""
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator


class LogMelgramLayer(tf.keras.layers.Layer):
    """Keras layer that turns a batch of waveforms into log mel-spectrograms.

    The mel filterbank is built once in ``__init__`` from the module-level
    constants ``NUM_MEL``, ``NUM_FREQS``, ``SAMPLE_RATE``, ``F_MIN`` and
    ``F_MAX``; ``call`` applies an STFT, projects the power spectrogram onto
    the filterbank and takes ``log10``.
    """

    def __init__(self, num_fft, hop_length, **kwargs):
        """Store the STFT settings and build the linear-to-mel matrix.

        Args:
            num_fft: STFT frame length; ``num_fft // 2 + 1`` must equal
                ``NUM_FREQS`` because the filterbank is sized from it.
            hop_length: STFT frame step.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: If ``num_fft`` is inconsistent with ``NUM_FREQS``.
        """
        super().__init__(**kwargs)
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if num_fft // 2 + 1 != NUM_FREQS:
            raise ValueError(
                'num_fft // 2 + 1 must equal NUM_FREQS (%d), got num_fft=%r'
                % (NUM_FREQS, num_fft)
            )
        self.num_fft = num_fft
        self.hop_length = hop_length

        self.lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=NUM_MEL,
            num_spectrogram_bins=NUM_FREQS,
            sample_rate=SAMPLE_RATE,
            lower_edge_hertz=F_MIN,
            upper_edge_hertz=F_MAX,
        )

    def build(self, input_shape):
        """Finish layer construction for the given ``input_shape``."""
        # NOTE(review): this appears intended to register the filterbank as a
        # non-trainable weight. On tf.keras the property seems to return a
        # freshly built list, so the append likely has no lasting effect --
        # confirm against the TensorFlow version in use.
        self.non_trainable_weights.append(self.lin_to_mel_matrix)
        super().build(input_shape)

    def call(self, input):
        """Compute log mel-spectrograms.

        Args:
            input (tensor): Batch of mono waveform, shape: (None, N)

        Returns:
            log_melgrams (tensor): Batch of log mel-spectrograms,
                shape: (None, num_frame, mel_bins, channel=1)
        """
        # NOTE: the parameter name shadows the ``input`` builtin; it is kept
        # unchanged so the method signature stays the same for callers.
        # tf.signal.stft seems to be applied along the last axis
        stfts = tf.signal.stft(
            input, frame_length=self.num_fft, frame_step=self.hop_length
        )
        mag_stfts = tf.abs(stfts)
        # Project the *power* spectrogram (squared magnitude) onto the mel
        # filterbank, matching the later revision of this gist.
        melgrams = tf.tensordot(
            tf.square(mag_stfts), self.lin_to_mel_matrix, axes=[2, 0]
        )
        log_melgrams = _tf_log10(melgrams + EPS)
        # Append a trailing channel axis.
        return tf.expand_dims(log_melgrams, 3)

    def get_config(self):
        """Return the layer config including ``num_fft`` and ``hop_length``."""
        config = {'num_fft': self.num_fft, 'hop_length': self.hop_length}
        base_config = super().get_config()
        # Base-class entries win on key collisions, as in the original merge.
        return {**config, **base_config}


# in the model
def model():
    """Build an example model: waveform -> log mel-spectrogram -> network."""
    # ...
    input_shape = (SAMPLE_RATE * 10,)  # 10-sec mono audio input
    inputs = tf.keras.layers.Input(shape=input_shape, name='audio_waveform')

    log_melgram_layer = LogMelgramLayer(
        num_fft=NUM_FFT,
        hop_length=HOP_LENGTH,
    )
    log_melgrams = log_melgram_layer(inputs)

    # Placeholder: ``get_your_network`` is not defined here and must be
    # supplied by the user of this snippet.
    some_network = get_your_network()
    out = some_network(log_melgrams)

    # The original passed the undefined name ``outputs`` here.
    return tf.keras.models.Model(inputs=inputs, outputs=out)