Skip to content

Instantly share code, notes, and snippets.

@keunwoochoi
Last active May 17, 2021 23:05
Show Gist options
  • Save keunwoochoi/c9592922a17d71b745d47dc8eb7f0538 to your computer and use it in GitHub Desktop.
Save keunwoochoi/c9592922a17d71b745d47dc8eb7f0538 to your computer and use it in GitHub Desktop.

Revisions

  1. keunwoochoi revised this gist Sep 28, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion log_melspectrogram.py
    Original file line number Diff line number Diff line change
    @@ -50,7 +50,7 @@ def _tf_log10(x):
    )
    mag_stfts = tf.abs(stfts)

    melgrams = tf.tensordot(mag_stfts, self.lin_to_mel_matrix, axes=[2, 0])
    melgrams = tf.tensordot(tf.square(mag_stfts), self.lin_to_mel_matrix, axes=[2, 0])
    log_melgrams = _tf_log10(melgrams + EPS)
    return tf.expand_dims(log_melgrams, 3)

  2. keunwoochoi renamed this gist Sep 28, 2019. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  3. keunwoochoi created this gist Sep 28, 2019.
    81 changes: 81 additions & 0 deletions log_mel.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,81 @@
    # assuming num_fft = 512
    NUM_FFT = 512
    NUM_FREQS = 257
    # some tentative constants
    NUM_MEL = 60
    SAMPLE_RATE = 44100
    F_MIN = 0
    F_MAX = 12000


    class LogMelgramLayer(tf.keras.layers.Layer):
    def __init__(self, num_fft, hop_length, **kwargs):
    super(LogMelgramLayer, self).__init__(**kwargs)
    self.num_fft = num_fft
    self.hop_length = hop_length

    assert num_fft // 2 + 1 == NUM_FREQS
    lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
    num_mel_bins=NUM_MEL,
    num_spectrogram_bins=NUM_FREQS,
    sample_rate=SAMPLE_RATE,
    lower_edge_hertz=F_MIN,
    upper_edge_hertz=F_MAX,
    )

    self.lin_to_mel_matrix = lin_to_mel_matrix

    def build(self, input_shape):
    self.non_trainable_weights.append(self.lin_to_mel_matrix)
    super(LogMelgramLayer, self).build(input_shape)

    def call(self, input):
    """
    Args:
    input (tensor): Batch of mono waveform, shape: (None, N)
    Returns:
    log_melgrams (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins, channel=1)
    """

    def _tf_log10(x):
    numerator = tf.math.log(x)
    denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
    return numerator / denominator

    # tf.signal.stft seems to be applied along the last axis
    stfts = tf.signal.stft(
    input, frame_length=self.num_fft, frame_step=self.hop_length
    )
    mag_stfts = tf.abs(stfts)

    melgrams = tf.tensordot(mag_stfts, self.lin_to_mel_matrix, axes=[2, 0])
    log_melgrams = _tf_log10(melgrams + EPS)
    return tf.expand_dims(log_melgrams, 3)

    def get_config(self):
    config = {'num_fft': self.num_fft, 'hop_length': self.hop_length}
    base_config = super(LogMelgramLayer, self).get_config()
    return dict(list(config.items()) + list(base_config.items()))



    # in the model
    def model():
    # ...
    input_shape = (44100 * 10, ) # 10-sec mono audio input
    inputs = Input(shape=input_shape, name='audio_waveform')

    log_melgram_layer = LogMelgramLayer(
    num_fft=NUM_FFT,
    hop_length=HOP_LENGTH,
    )

    log_melgrams = log_melgram_layer(inputs)

    some_network = get_your_network()
    out = some_network(log_melgrams)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return model