@gtLara
Forked from endolith/FIR_filter_NN.py
Created May 14, 2022 07:24
Revisions

  1. @endolith revised this gist Apr 9, 2022. 1 changed file with 2 additions and 4 deletions.
     6 changes: 2 additions & 4 deletions README.md

     ```diff
     @@ -4,12 +4,10 @@ My first experiment with Keras.
      
      ![](https://gist.github.com/endolith/1de9a8700f72b97974a2e93b0fba316a/raw/08d6ddf4ce234c7c6a559a628b009fe5c4bce6ea/FIR.png)
      
     -Which is the same structure as a neural net (assuming no activation function),
     -
     -
     +Which is the same structure as a neural net (assuming no activation function):
      
      ![](https://gist.github.com/endolith/1de9a8700f72b97974a2e93b0fba316a/raw/08d6ddf4ce234c7c6a559a628b009fe5c4bce6ea/ANN.png)
      
     -so it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right?
     +So it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right?
      
      **Conclusion:** Yep, it works great.
     ```
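
     The README above argues that one FIR output sample is a weighted sum of an input chunk, i.e. exactly one linear neuron. A quick check of that equivalence (illustrative only, not part of the gist; the signal and filter here are made up):

     ```python
     # One FIR output sample = dot product of an input chunk with the reversed
     # taps, which is exactly what a single Dense(1, use_bias=False) neuron computes.
     import numpy as np
     from scipy import signal

     rng = np.random.default_rng(0)
     sig = rng.standard_normal(1000)
     b = signal.firwin(51, 0.3)                        # any FIR filter

     filtered = signal.convolve(sig, b, mode='valid')  # reference filter output

     i = 123                                           # any output index
     chunk = sig[i:i + 51]                             # the input chunk feeding that output
     print(np.isclose(filtered[i], chunk @ b[::-1]))   # True: a weighted sum of the chunk
     ```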
  2. @endolith revised this gist Apr 9, 2022. 1 changed file with 2 additions and 2 deletions.
     4 changes: 2 additions & 2 deletions README.md

     ```diff
     @@ -2,13 +2,13 @@ My first experiment with Keras.
      
      **Hypothesis:** Each output sample of an FIR filter is just a sum of weighted input samples taken from a small chunk of the input:
      
     -![](https://upload.wikimedia.org/wikipedia/commons/thumb/9/9b/FIR_Filter.svg/320px-FIR_Filter.svg.png)
     +![](https://gist.github.com/endolith/1de9a8700f72b97974a2e93b0fba316a/raw/08d6ddf4ce234c7c6a559a628b009fe5c4bce6ea/FIR.png)
      
      Which is the same structure as a neural net (assuming no activation function),
      
      
      
     -![](https://upload.wikimedia.org/wikipedia/commons/thumb/6/6a/Perceptron-unit.svg/266px-Perceptron-unit.svg.png)
     +![](https://gist.github.com/endolith/1de9a8700f72b97974a2e93b0fba316a/raw/08d6ddf4ce234c7c6a559a628b009fe5c4bce6ea/ANN.png)
      
      so it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right?
      
     ```
  3. @endolith revised this gist Apr 9, 2022. 1 changed file with 1 addition and 1 deletion.
     2 changes: 1 addition & 1 deletion FIR_filter_NN.py

     ```diff
     @@ -106,7 +106,7 @@ def rolling_window(a, window):
      Node-level graph
      """
      
      # Broken, requires old keras.
     -# Replace with https://github.com/Dicksonchin93/keras-architecture-visualizer/
     +# Working version: https://github.com/endolith/ann-visualizer
      # from ann_visualizer.visualize import ann_viz
      # ann_viz(model, title="Learned FIR filter")
      
     ```
  4. @endolith revised this gist Mar 30, 2022. 1 changed file with 15 additions and 1 deletion.
     16 changes: 15 additions & 1 deletion README.md

     ```diff
     @@ -1 +1,15 @@
     -My first experiment with keras.
     +My first experiment with Keras.
     +
     +**Hypothesis:** Each output sample of an FIR filter is just a sum of weighted input samples taken from a small chunk of the input:
     +
     +![](https://upload.wikimedia.org/wikipedia/commons/thumb/9/9b/FIR_Filter.svg/320px-FIR_Filter.svg.png)
     +
     +Which is the same structure as a neural net (assuming no activation function),
     +
     +
     +
     +![](https://upload.wikimedia.org/wikipedia/commons/thumb/6/6a/Perceptron-unit.svg/266px-Perceptron-unit.svg.png)
     +
     +so it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right?
     +
     +**Conclusion:** Yep, it works great.
     ```
  5. @endolith revised this gist Mar 26, 2022. 1 changed file with 1 addition and 1 deletion.
     2 changes: 1 addition & 1 deletion FIR_filter_NN.py

     ```diff
     @@ -52,7 +52,7 @@ def rolling_window(a, window):
      
      
      X = rolling_window(sig, numtaps)
     -Y = filtered
     +Y = filtered  # Filter outputs 1 sample for each chunk of input samples
      
      # plt.plot(X[0])
      # plt.plot(Y[0])
     ```
  6. @endolith revised this gist Mar 26, 2022. 1 changed file with 2 additions and 9 deletions.
     11 changes: 2 additions & 9 deletions FIR_filter_NN.py

     ```diff
     @@ -124,21 +124,14 @@ def rolling_window(a, window):
      class LossHistory(Callback):
          def on_train_begin(self, logs={}):
              self.losses = []
     -        # self.n = 0
     -        # maxx = np.amax(np.abs(model.get_weights()[0]))
     -        # plt.imsave(str(self.n)+'.png', model.get_weights()[0],
     -        #            cmap=cmap, vmin=-maxx, vmax=maxx)
     -        # self.n+=1
     +        # Could plot the convergence here
      
          def on_batch_end(self, batch, logs={}):
              self.losses.append(logs.get('loss'))
      
          def on_epoch_end(self, batch, logs={}):
              pass
     -        # maxx = np.amax(np.abs(model.get_weights()[0]))
     -        # plt.imsave(str(self.n)+'.png', model.get_weights()[0],
     -        #            cmap=cmap, vmin=-maxx, vmax=maxx)
     -        # self.n+=1
     +        # Could plot the convergence here
      
      
      history = LossHistory()
     ```
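
     The "# Could plot the convergence here" comments replace code that saved an image of the weights after every epoch. A minimal sketch of one way to do that with a Keras callback (this assumes the gist's TF2-style imports and its `model`; the `KernelHistory`/`kernel_history` names are made up):

     ```python
     from tensorflow.keras.callbacks import Callback
     import numpy as np

     class KernelHistory(Callback):
         """Snapshot the Dense kernel each epoch so convergence can be plotted later."""

         def on_train_begin(self, logs=None):
             self.kernels = []

         def on_epoch_end(self, epoch, logs=None):
             # weights[0] has shape (numtaps, 1) for Dense(1, input_dim=numtaps)
             self.kernels.append(self.model.get_weights()[0].ravel().copy())

     # kernel_history = KernelHistory()
     # model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history, kernel_history])
     # plt.imshow(np.array(kernel_history.kernels).T, aspect='auto')  # taps vs. epoch
     ```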
  7. @endolith revised this gist Mar 26, 2022. 1 changed file with 0 additions and 5 deletions.
     5 changes: 0 additions & 5 deletions FIR_filter_NN.py

     ```diff
     @@ -85,11 +85,6 @@ def rolling_window(a, window):
      """
      Make block diagram of network (not from tutorial)
      """
     -
     -# Workaround for some packaging bug
     -import os
     -os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz'
     -
      from tensorflow.keras.utils import plot_model
      plot_model(model, to_file='model.png', show_shapes=True)
      
     ```
  8. @endolith revised this gist Mar 26, 2022. 1 changed file with 3 additions and 1 deletion.
     4 changes: 3 additions & 1 deletion FIR_filter_NN.py

     ```diff
     @@ -15,8 +15,10 @@
      Generate or load a signal to use as input data
      """
      # Only learns at the frequencies present in the signal
     -# Learns at all frequencies with white noise
     +# https://en.wikipedia.org/wiki/File:Short-beaked_Echidna.ogg
      sig, fs = read('echidna.wav')
     +
     +# Learns at all frequencies with white noise
      # sig, fs = np.random.randn(10000), 10000
      
      """
     ```
  9. @endolith revised this gist Mar 26, 2022. 1 changed file with 5 additions and 4 deletions.
     9 changes: 5 additions & 4 deletions FIR_filter_NN.py

     ```diff
     @@ -109,13 +109,14 @@ def rolling_window(a, window):
      Node-level graph
      """
      
     -from ann_visualizer.visualize import ann_viz
     -
      # Broken, requires old keras.
      # Replace with https://github.com/Dicksonchin93/keras-architecture-visualizer/
     +# from ann_visualizer.visualize import ann_viz
      # ann_viz(model, title="Learned FIR filter")
      
      
     +# https://github.com/Dicksonchin93/keras-architecture-visualizer/
     +# from keras_architecture_visualizer import KerasArchitectureVisualizer
     +# vis = KerasArchitectureVisualizer()
     +# vis.visualize(model)
      
      # Compile model
      model.compile(loss='mean_squared_error',
     ```
  10. @endolith revised this gist Mar 26, 2022. 1 changed file with 3 additions and 2 deletions.
     5 changes: 3 additions & 2 deletions FIR_filter_NN.py

     ```diff
     @@ -25,8 +25,9 @@
      numtaps = 51
      # b = signal.firwin(numtaps, 1, fs=fs)
      # b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
     -#                   pass_zero=False)
     -b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs)
     +#                   pass_zero=False)
     +b = signal.firwin(numtaps, cutoff=[6000, 11000], fs=fs,
     +                  window='blackmanharris', pass_zero=False)
      
      # TODO: Use an IIR filter and have ANN approximate it as best it can
      
     ```
  11. @endolith revised this gist Mar 26, 2022. 1 changed file with 10 additions and 12 deletions.
     22 changes: 10 additions & 12 deletions FIR_filter_NN.py

     ```diff
     @@ -11,18 +11,6 @@
      import matplotlib.pyplot as plt
      from soundfile import read
      
     -"""
     -Create the FIR filter for the ANN to copy
     -"""
     -# TODO: Use an IIR filter and have ANN approximate it as best it can
     -numtaps = 51
     -#b = signal.firwin(numtaps, 1, fs=fs)
     -b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
     -                  pass_zero=False)
     -#b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs)
     -
     -
     -
      """
      Generate or load a signal to use as input data
      """
     @@ -31,6 +19,16 @@
      sig, fs = read('echidna.wav')
      # sig, fs = np.random.randn(10000), 10000
      
     +"""
     +Create the FIR filter for the ANN to copy
     +"""
     +numtaps = 51
     +# b = signal.firwin(numtaps, 1, fs=fs)
     +# b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
     +#                   pass_zero=False)
     +b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs)
     +
     +# TODO: Use an IIR filter and have ANN approximate it as best it can
      
      """
      Training data is chunks of input and output of FIR filter
     ```
  12. @endolith revised this gist Mar 26, 2022. 1 changed file with 11 additions and 9 deletions.
     20 changes: 11 additions & 9 deletions FIR_filter_NN.py

     ```diff
     @@ -158,15 +158,17 @@ def on_epoch_end(self, batch, logs={}):
      
      final = model.get_weights()
      
     -plt.figure('kernel')
     -plt.plot(b, '.-', label='Filter')
     -plt.plot(initial[0], '.-', label='Initial')
     -plt.plot(final[0], '.-', label='Learned')
     -plt.grid(True, color='0.7', linestyle='-', which='major')
     -plt.grid(True, color='0.9', linestyle='-', which='minor')
     -plt.title('Kernel')
     -plt.legend()
     -
     +fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, num='kernel', sharex=True)
     +ax1.plot(b, '.-', label='Filter', alpha=0.5, c='gray')
     +ax0.plot(initial[0], '.-', label='Initial')
     +ax1.plot(final[0], '.-', label='Learned')
     +ax0.grid(True, color='0.7', linestyle='-', which='major')
     +ax0.grid(True, color='0.9', linestyle='-', which='minor')
     +ax1.grid(True, color='0.7', linestyle='-', which='major')
     +ax1.grid(True, color='0.9', linestyle='-', which='minor')
     +ax0.set_title('Kernel')
     +ax0.legend()
     +ax1.legend()
      
      plt.figure('frequency response')
      w, h = signal.freqz(b, [1.0])
     ```
  13. @endolith revised this gist Mar 26, 2022. 1 changed file with 3 additions and 3 deletions.
     6 changes: 3 additions & 3 deletions FIR_filter_NN.py

     ```diff
     @@ -152,9 +152,9 @@ def on_epoch_end(self, batch, logs={}):
      model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history])
      
      # evaluate the model
     -#print("Evaluating...")
     -#scores = model.evaluate(X, Y)
     -#print(scores[1]*100)
     +print("Evaluating...")
     +scores = model.evaluate(X, Y)
     +print(scores*100)  # percent??
      
      final = model.get_weights()
     ```
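
     On the "# percent??" question: compiled with only loss='mean_squared_error' and no metrics, model.evaluate() returns a single scalar, the MSE loss, so multiplying it by 100 does not make it a percentage. A hedged sketch of how to get named numbers instead (assumes the gist's model, X and Y are in scope):

     ```python
     # evaluate() returns just the loss here, because compile() was given no metrics.
     mse = model.evaluate(X, Y, verbose=0)
     print(f'MSE: {mse:.3e}')

     # To get an additional, named metric, ask for it at compile time:
     # model.compile(loss='mean_squared_error', optimizer='adam',
     #               metrics=['mean_absolute_error'])
     # loss, mae = model.evaluate(X, Y, verbose=0)
     ```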

  14. @endolith revised this gist Mar 26, 2022. 1 changed file with 10 additions and 1 deletion.
     11 changes: 10 additions & 1 deletion FIR_filter_NN.py

     ```diff
     @@ -59,9 +59,18 @@ def rolling_window(a, window):
      
      """
      Create model
     +
     +Initializer matters because signal might have missing areas of spectrum, and
     +model will not learn there. So if the initial guess is all zeros, those areas
     +of the spectrum will stay silenced, while the passband is "built up" for
     +frequencies that are present.
      """
      model = Sequential([
     -    Dense(1, input_dim=numtaps, use_bias=False)
     +    Dense(1, input_dim=numtaps, use_bias=False,
     +          # kernel_initializer='random_normal',  # typical usage
     +          # kernel_initializer='ones',  # boxcar window = running average
     +          kernel_initializer='zeros',  # nothing (good for non-white input)
     +          )
          ])
      
      model.summary()
     ```
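
     The docstring added above claims that, under MSE training, frequencies absent from the input leave the corresponding part of the kernel untouched, so a 'zeros' start keeps those regions silent. A small illustrative NumPy check of that claim (not part of the gist; the signal, target and names here are made up): for a linear model the gradient is X.T @ (X @ w - y), a combination of the input windows, so a single-tone input can only push the weights at that tone's frequency.

     ```python
     import numpy as np

     fs = 1000
     n = np.arange(5000)
     sig = np.sin(2 * np.pi * 100 * n / fs)        # input energy only at 100 Hz

     numtaps = 51
     X = np.lib.stride_tricks.sliding_window_view(sig, numtaps)   # NumPy >= 1.20
     y = X @ np.random.default_rng(0).standard_normal(numtaps)    # any linear target
     w = np.zeros(numtaps)                         # the 'zeros' initializer

     grad = X.T @ (X @ w - y) / len(y)             # MSE gradient at the initial weights
     spectrum = np.abs(np.fft.rfft(grad, 4096))
     peak_hz = np.argmax(spectrum) * fs / 4096
     print(f'Gradient energy is concentrated near {peak_hz:.0f} Hz')  # ~100 Hz
     ```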
  15. @endolith revised this gist Mar 26, 2022. 1 changed file with 9 additions and 8 deletions.
     17 changes: 9 additions & 8 deletions FIR_filter_NN.py

     ```diff
     @@ -28,8 +28,8 @@
      """
      # Only learns at the frequencies present in the signal
      # Learns at all frequencies with white noise
     -#sig, fs = read('echidna.wav')
     -sig = np.random.randn(10000)
     +sig, fs = read('echidna.wav')
     +# sig, fs = np.random.randn(10000), 10000
      
      
      """
     @@ -161,16 +161,17 @@ def on_epoch_end(self, batch, logs={}):
      
      plt.figure('frequency response')
      w, h = signal.freqz(b, [1.0])
     -plt.semilogx(w, 20*np.log10(abs(h)), label='Filter')
     -
     +plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Filter',
     +             alpha=0.5, c='gray')
      w, h = signal.freqz(initial[0], [1.0])
     -plt.semilogx(w, 20*np.log10(abs(h)), label='Initial')
     -
     +plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Initial')
      w, h = signal.freqz(final[0], [1.0])
     -plt.semilogx(w, 20*np.log10(abs(h)), label='Learned')
     -
     +plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Learned', alpha=0.5)
      plt.grid(True, color='0.7', linestyle='-', which='major')
      plt.grid(True, color='0.9', linestyle='-', which='minor')
     +plt.xlabel('Frequency [Hz]')
     +plt.ylabel('Response [dB]')
     +plt.xlim(None, fs/2)
      plt.title('Frequency response')
      plt.legend()
     ```
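
     The revision above converts freqz's rad/sample axis to Hz by hand with w*fs/(2*np.pi). Equivalently (SciPy 1.2+), freqz accepts fs directly and returns w in Hz; a one-line sketch reusing the gist's b and fs:

     ```python
     from scipy import signal

     w, h = signal.freqz(b, fs=fs)   # w is already in Hz; no manual w*fs/(2*pi) needed
     ```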

  16. @endolith revised this gist Mar 26, 2022. 1 changed file with 0 additions and 3 deletions.
     3 changes: 0 additions & 3 deletions FIR_filter_NN.py

     ```diff
     @@ -21,9 +21,6 @@
                        pass_zero=False)
      #b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs)
      
     -# Plot filter response
     -#w, h = signal.freqz(b, [1.0], 2000)
     -#plt.plot(w/np.pi*fs/2, 20*np.log10(abs(h)), label='Filter')
      
      
      """
     ```
  17. @endolith revised this gist Mar 26, 2022. 1 changed file with 1 addition and 1 deletion.
     2 changes: 1 addition & 1 deletion FIR_filter_NN.py

     ```diff
     @@ -15,7 +15,7 @@
      Create the FIR filter for the ANN to copy
      """
      # TODO: Use an IIR filter and have ANN approximate it as best it can
     -numtaps = 1001
     +numtaps = 51
      #b = signal.firwin(numtaps, 1, fs=fs)
      b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
                        pass_zero=False)
     ```
  18. @endolith revised this gist Mar 25, 2022. 1 changed file with 6 additions and 3 deletions.
     9 changes: 6 additions & 3 deletions FIR_filter_NN.py

     ```diff
     @@ -152,16 +152,17 @@ def on_epoch_end(self, batch, logs={}):
      
      final = model.get_weights()
      
     -plt.figure()
     +plt.figure('kernel')
      plt.plot(b, '.-', label='Filter')
      plt.plot(initial[0], '.-', label='Initial')
      plt.plot(final[0], '.-', label='Learned')
      plt.grid(True, color='0.7', linestyle='-', which='major')
      plt.grid(True, color='0.9', linestyle='-', which='minor')
     +plt.title('Kernel')
      plt.legend()
      
      
     -plt.figure()
     +plt.figure('frequency response')
      w, h = signal.freqz(b, [1.0])
      plt.semilogx(w, 20*np.log10(abs(h)), label='Filter')
      
     @@ -173,10 +174,12 @@ def on_epoch_end(self, batch, logs={}):
      
      plt.grid(True, color='0.7', linestyle='-', which='major')
      plt.grid(True, color='0.9', linestyle='-', which='minor')
     +plt.title('Frequency response')
      plt.legend()
      
     -plt.figure()
     +plt.figure('loss')
      plt.semilogy(history.losses)
      plt.xlabel('Batch')
      plt.ylabel('Loss')
      plt.grid(True, which="both")
     +plt.title('Loss')
     ```
  19. @endolith revised this gist Mar 24, 2022. 1 changed file with 10 additions and 7 deletions.
     17 changes: 10 additions & 7 deletions FIR_filter_NN.py

     ```diff
     @@ -3,9 +3,9 @@
      Created on Fri Aug 3 15:00:40 2018
      """
     -from keras.models import Sequential
     -from keras.layers import Dense
     -from keras.callbacks import Callback
     +from tensorflow.keras.models import Sequential
     +from tensorflow.keras.layers import Dense
     +from tensorflow.keras.callbacks import Callback
      import numpy as np
      from scipy import signal
      import matplotlib.pyplot as plt
     @@ -83,7 +83,7 @@ def rolling_window(a, window):
      import os
      os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz'
      
     -from keras.utils import plot_model
     +from tensorflow.keras.utils import plot_model
      plot_model(model, to_file='model.png', show_shapes=True)
      
      
     @@ -94,8 +94,9 @@ def rolling_window(a, window):
      """
      import tensorflow as tf
      
     -with tf.Session() as sess:
     -    writer = tf.summary.FileWriter('logs', sess.graph)
     +# TODO: This isn't working like it used to. Replace with TF2.0 conventions.
     +with tf.compat.v1.Session() as sess:
     +    writer = tf.compat.v1.summary.FileWriter('logs', sess.graph)
          writer.close()
      
      
     @@ -105,7 +106,9 @@ def rolling_window(a, window):
      
      from ann_visualizer.visualize import ann_viz
      
     -#ann_viz(model, title="Learned FIR filter")
     +# Broken, requires old keras.
     +# Replace with https://github.com/Dicksonchin93/keras-architecture-visualizer/
     +# ann_viz(model, title="Learned FIR filter")
     ```
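
     The "# TODO: ... Replace with TF2.0 conventions" above refers to writing the graph for TensorBoard; the tf.compat.v1.Session/FileWriter route is a TF1 leftover. A hedged sketch of the usual TF2/Keras alternative (not the gist's code; it reuses the gist's model.fit arguments):

     ```python
     # The Keras TensorBoard callback writes the graph (and training scalars)
     # during fit(), replacing the TF1 Session + FileWriter pattern.
     from tensorflow.keras.callbacks import TensorBoard

     tb = TensorBoard(log_dir='logs', write_graph=True)
     # model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history, tb])
     # View with:  tensorboard --logdir logs
     ```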



  20. @endolith revised this gist Mar 24, 2022. 1 changed file with 179 additions and 0 deletions.
     179 changes: 179 additions & 0 deletions FIR_filter_NN.py

     ```python
     """
     Train a neural network to learn an FIR filter.
     Created on Fri Aug 3 15:00:40 2018
     """
     from keras.models import Sequential
     from keras.layers import Dense
     from keras.callbacks import Callback
     import numpy as np
     from scipy import signal
     import matplotlib.pyplot as plt
     from soundfile import read

     """
     Create the FIR filter for the ANN to copy
     """
     # TODO: Use an IIR filter and have ANN approximate it as best it can
     numtaps = 1001
     #b = signal.firwin(numtaps, 1, fs=fs)
     b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris',
                       pass_zero=False)
     #b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs)

     # Plot filter response
     #w, h = signal.freqz(b, [1.0], 2000)
     #plt.plot(w/np.pi*fs/2, 20*np.log10(abs(h)), label='Filter')


     """
     Generate or load a signal to use as input data
     """
     # Only learns at the frequencies present in the signal
     # Learns at all frequencies with white noise
     #sig, fs = read('echidna.wav')
     sig = np.random.randn(10000)


     """
     Training data is chunks of input and output of FIR filter
     """
     # filtered = signal.lfilter(b, 1.0, sig)
     filtered = signal.convolve(sig, b, mode='valid')


     def rolling_window(a, window):
         """
         Return chunks of signal `a` of size `window`, incremented by 1 each time.
         https://gist.github.com/codehacken/708f19ae746784cef6e68b037af65788
         """
         shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
         strides = a.strides + (a.strides[-1],)
         return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)


     X = rolling_window(sig, numtaps)
     Y = filtered

     # plt.plot(X[0])
     # plt.plot(Y[0])


     """
     Create model
     """
     model = Sequential([
         Dense(1, input_dim=numtaps, use_bias=False)
         ])

     model.summary()

     initial = model.get_weights()

     print('Initial weights:')
     print(initial)


     """
     Make block diagram of network (not from tutorial)
     """

     # Workaround for some packaging bug
     import os
     os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz'

     from keras.utils import plot_model
     plot_model(model, to_file='model.png', show_shapes=True)


     """
     Make graph diagram of network (not from tutorial)
     Viewable with `tensorboard --logdir="logs"`
     """
     import tensorflow as tf

     with tf.Session() as sess:
         writer = tf.summary.FileWriter('logs', sess.graph)
         writer.close()


     """
     Node-level graph
     """

     from ann_visualizer.visualize import ann_viz

     #ann_viz(model, title="Learned FIR filter")



     # Compile model
     model.compile(loss='mean_squared_error',
                   optimizer='adam',
                   )


     class LossHistory(Callback):
         def on_train_begin(self, logs={}):
             self.losses = []
             # self.n = 0
             # maxx = np.amax(np.abs(model.get_weights()[0]))
             # plt.imsave(str(self.n)+'.png', model.get_weights()[0],
             #            cmap=cmap, vmin=-maxx, vmax=maxx)
             # self.n+=1

         def on_batch_end(self, batch, logs={}):
             self.losses.append(logs.get('loss'))

         def on_epoch_end(self, batch, logs={}):
             pass
             # maxx = np.amax(np.abs(model.get_weights()[0]))
             # plt.imsave(str(self.n)+'.png', model.get_weights()[0],
             #            cmap=cmap, vmin=-maxx, vmax=maxx)
             # self.n+=1


     history = LossHistory()


     # Fit the model
     print("Fitting...")
     model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history])

     # evaluate the model
     #print("Evaluating...")
     #scores = model.evaluate(X, Y)
     #print(scores[1]*100)

     final = model.get_weights()

     plt.figure()
     plt.plot(b, '.-', label='Filter')
     plt.plot(initial[0], '.-', label='Initial')
     plt.plot(final[0], '.-', label='Learned')
     plt.grid(True, color='0.7', linestyle='-', which='major')
     plt.grid(True, color='0.9', linestyle='-', which='minor')
     plt.legend()


     plt.figure()
     w, h = signal.freqz(b, [1.0])
     plt.semilogx(w, 20*np.log10(abs(h)), label='Filter')

     w, h = signal.freqz(initial[0], [1.0])
     plt.semilogx(w, 20*np.log10(abs(h)), label='Initial')

     w, h = signal.freqz(final[0], [1.0])
     plt.semilogx(w, 20*np.log10(abs(h)), label='Learned')

     plt.grid(True, color='0.7', linestyle='-', which='major')
     plt.grid(True, color='0.9', linestyle='-', which='minor')
     plt.legend()

     plt.figure()
     plt.semilogy(history.losses)
     plt.xlabel('Batch')
     plt.ylabel('Loss')
     plt.grid(True, which="both")
     ```
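
     The rolling_window() helper above builds the strided view by hand with as_strided. On NumPy 1.20+ the same chunking exists as a library function, which is less error-prone; an equivalent one-liner (a sketch, reusing the gist's sig and numtaps):

     ```python
     import numpy as np

     # Same (len(sig) - numtaps + 1, numtaps) array of overlapping chunks:
     X = np.lib.stride_tricks.sliding_window_view(sig, numtaps)
     ```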
  21. @endolith created this gist Mar 27, 2022.
     1 change: 1 addition & 0 deletions README.md

     ```diff
     @@ -0,0 +1 @@
     +My first experiment with keras.
     ```