-
-
Save gtLara/99e98a7ff6065efe2f217d719c99059f to your computer and use it in GitHub Desktop.
Revisions
-
endolith revised this gist
Apr 9, 2022 . 1 changed file with 2 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -4,12 +4,10 @@ My first experiment with Keras.  Which is the same structure as a neural net (assuming no activation function):  So it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right? **Conclusion:** Yep, it works great. -
endolith revised this gist
Apr 9, 2022 . 1 changed file with 2 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,13 +2,13 @@ My first experiment with Keras. **Hypothesis:** Each output sample of an FIR filter is just a sum of weighted input samples taken from a small chunk of the input:  Which is the same structure as a neural net (assuming no activation function),  so it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right? -
endolith revised this gist
Apr 9, 2022 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -106,7 +106,7 @@ def rolling_window(a, window): Node-level graph """ # Working version: https://github.com/endolith/ann-visualizer # from ann_visualizer.visualize import ann_viz # ann_viz(model, title="Learned FIR filter") -
endolith revised this gist
Mar 30, 2022 . 1 changed file with 15 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1 +1,15 @@ My first experiment with Keras. **Hypothesis:** Each output sample of an FIR filter is just a sum of weighted input samples taken from a small chunk of the input:  Which is the same structure as a neural net (assuming no activation function),  so it should be able to learn the FIR coefficients by learning from chunks of signal before and after filtering, right? **Conclusion:** Yep, it works great. -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -52,7 +52,7 @@ def rolling_window(a, window): X = rolling_window(sig, numtaps) Y = filtered # Filter outputs 1 sample for each chunk of input samples # plt.plot(X[0]) # plt.plot(Y[0]) -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 2 additions and 9 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -124,21 +124,14 @@ def rolling_window(a, window): class LossHistory(Callback): def on_train_begin(self, logs={}): self.losses = [] # Could plot the convergence here def on_batch_end(self, batch, logs={}): self.losses.append(logs.get('loss')) def on_epoch_end(self, batch, logs={}): pass # Could plot the convergence here history = LossHistory() -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 0 additions and 5 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -85,11 +85,6 @@ def rolling_window(a, window): """ Make block diagram of network (not from tutorial) """ from tensorflow.keras.utils import plot_model plot_model(model, to_file='model.png', show_shapes=True) -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 3 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -15,8 +15,10 @@ Generate or load a signal to use as input data """ # Only learns at the frequencies present in the signal # https://en.wikipedia.org/wiki/File:Short-beaked_Echidna.ogg sig, fs = read('echidna.wav') # Learns at all frequencies with white noise # sig, fs = np.random.randn(10000), 10000 """ -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 5 additions and 4 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -109,13 +109,14 @@ def rolling_window(a, window): Node-level graph """ # Broken, requires old keras. # from ann_visualizer.visualize import ann_viz # ann_viz(model, title="Learned FIR filter") # https://github.com/Dicksonchin93/keras-architecture-visualizer/ # from keras_architecture_visualizer import KerasArchitectureVisualizer # vis = KerasArchitectureVisualizer() # vis.visualize(model) # Compile model model.compile(loss='mean_squared_error', -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 3 additions and 2 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -25,8 +25,9 @@ numtaps = 51 # b = signal.firwin(numtaps, 1, fs=fs) # b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris', # pass_zero=False) b = signal.firwin(numtaps, cutoff=[6000, 11000], fs=fs, window='blackmanharris', pass_zero=False) # TODO: Use an IIR filter and have ANN approximate it as best it can -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 10 additions and 12 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -11,18 +11,6 @@ import matplotlib.pyplot as plt from soundfile import read """ Generate or load a signal to use as input data """ @@ -31,6 +19,16 @@ sig, fs = read('echidna.wav') # sig, fs = np.random.randn(10000), 10000 """ Create the FIR filter for the ANN to copy """ numtaps = 51 # b = signal.firwin(numtaps, 1, fs=fs) # b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris', # pass_zero=False) b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs) # TODO: Use an IIR filter and have ANN approximate it as best it can """ Training data is chunks of input and output of FIR filter -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 11 additions and 9 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -158,15 +158,17 @@ def on_epoch_end(self, batch, logs={}): final = model.get_weights() fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, num='kernel', sharex=True) ax1.plot(b, '.-', label='Filter', alpha=0.5, c='gray') ax0.plot(initial[0], '.-', label='Initial') ax1.plot(final[0], '.-', label='Learned') ax0.grid(True, color='0.7', linestyle='-', which='major') ax0.grid(True, color='0.9', linestyle='-', which='minor') ax1.grid(True, color='0.7', linestyle='-', which='major') ax1.grid(True, color='0.9', linestyle='-', which='minor') ax0.set_title('Kernel') ax0.legend() ax1.legend() plt.figure('frequency response') w, h = signal.freqz(b, [1.0]) -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 3 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -152,9 +152,9 @@ def on_epoch_end(self, batch, logs={}): model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history]) # evaluate the model print("Evaluating...") scores = model.evaluate(X, Y) print(scores*100) # percent?? final = model.get_weights() -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 10 additions and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -59,9 +59,18 @@ def rolling_window(a, window): """ Create model Initializer matters because signal might have missing areas of spectrum, and model will not learn there. So if the initial guess is all zeros, those areas of the spectrum will stay silenced, while the passband is "built up" for frequencies that are present. """ model = Sequential([ Dense(1, input_dim=numtaps, use_bias=False, # kernel_initializer='random_normal', # typical usage # kernel_initializer='ones', # boxcar window = running average kernel_initializer='zeros', # nothing (good for non-white input) ) ]) model.summary() -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 9 additions and 8 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -28,8 +28,8 @@ """ # Only learns at the frequencies present in the signal # Learns at all frequencies with white noise sig, fs = read('echidna.wav') # sig, fs = np.random.randn(10000), 10000 """ @@ -161,16 +161,17 @@ def on_epoch_end(self, batch, logs={}): plt.figure('frequency response') w, h = signal.freqz(b, [1.0]) plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Filter', alpha=0.5, c='gray') w, h = signal.freqz(initial[0], [1.0]) plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Initial') w, h = signal.freqz(final[0], [1.0]) plt.semilogx(w*fs/(2*np.pi), 20*np.log10(abs(h)), label='Learned', alpha=0.5) plt.grid(True, color='0.7', linestyle='-', which='major') plt.grid(True, color='0.9', linestyle='-', which='minor') plt.xlabel('Frequency [Hz]') plt.ylabel('Response [dB]') plt.xlim(None, fs/2) plt.title('Frequency response') plt.legend() -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 0 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -21,9 +21,6 @@ pass_zero=False) #b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs) """ -
endolith revised this gist
Mar 26, 2022 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -15,7 +15,7 @@ Create the FIR filter for the ANN to copy """ # TODO: Use an IIR filter and have ANN approximate it as best it can numtaps = 51 #b = signal.firwin(numtaps, 1, fs=fs) b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris', pass_zero=False) -
endolith revised this gist
Mar 25, 2022 . 1 changed file with 6 additions and 3 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -152,16 +152,17 @@ def on_epoch_end(self, batch, logs={}): final = model.get_weights() plt.figure('kernel') plt.plot(b, '.-', label='Filter') plt.plot(initial[0], '.-', label='Initial') plt.plot(final[0], '.-', label='Learned') plt.grid(True, color='0.7', linestyle='-', which='major') plt.grid(True, color='0.9', linestyle='-', which='minor') plt.title('Kernel') plt.legend() plt.figure('frequency response') w, h = signal.freqz(b, [1.0]) plt.semilogx(w, 20*np.log10(abs(h)), label='Filter') @@ -173,10 +174,12 @@ def on_epoch_end(self, batch, logs={}): plt.grid(True, color='0.7', linestyle='-', which='major') plt.grid(True, color='0.9', linestyle='-', which='minor') plt.title('Frequency response') plt.legend() plt.figure('loss') plt.semilogy(history.losses) plt.xlabel('Batch') plt.ylabel('Loss') plt.grid(True, which="both") plt.title('Loss') -
endolith revised this gist
Mar 24, 2022 . 1 changed file with 10 additions and 7 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -3,9 +3,9 @@ Created on Fri Aug 3 15:00:40 2018 """ from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense from tensorflow.keras.callbacks import Callback import numpy as np from scipy import signal import matplotlib.pyplot as plt @@ -83,7 +83,7 @@ def rolling_window(a, window): import os os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz' from tensorflow.keras.utils import plot_model plot_model(model, to_file='model.png', show_shapes=True) @@ -94,8 +94,9 @@ def rolling_window(a, window): """ import tensorflow as tf # TODO: This isn't working like it used to. Replace with TF2.0 conventions. with tf.compat.v1.Session() as sess: writer = tf.compat.v1.summary.FileWriter('logs', sess.graph) writer.close() @@ -105,7 +106,9 @@ def rolling_window(a, window): from ann_visualizer.visualize import ann_viz # Broken, requires old keras. # Replace with https://github.com/Dicksonchin93/keras-architecture-visualizer/ # ann_viz(model, title="Learned FIR filter") -
endolith revised this gist
Mar 24, 2022 . 1 changed file with 179 additions and 0 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,179 @@ """ Train a neural network to learn an FIR filter. Created on Fri Aug 3 15:00:40 2018 """ from keras.models import Sequential from keras.layers import Dense from keras.callbacks import Callback import numpy as np from scipy import signal import matplotlib.pyplot as plt from soundfile import read """ Create the FIR filter for the ANN to copy """ # TODO: Use an IIR filter and have ANN approximate it as best it can numtaps = 1001 #b = signal.firwin(numtaps, 1, fs=fs) b = signal.firwin(numtaps, cutoff=[0.3, 0.5], window='blackmanharris', pass_zero=False) #b = signal.firwin(numtaps, [500, 1000], pass_zero=True, width=100, fs=fs) # Plot filter response #w, h = signal.freqz(b, [1.0], 2000) #plt.plot(w/np.pi*fs/2, 20*np.log10(abs(h)), label='Filter') """ Generate or load a signal to use as input data """ # Only learns at the frequencies present in the signal # Learns at all frequencies with white noise #sig, fs = read('echidna.wav') sig = np.random.randn(10000) """ Training data is chunks of input and output of FIR filter """ # filtered = signal.lfilter(b, 1.0, sig) filtered = signal.convolve(sig, b, mode='valid') def rolling_window(a, window): """ Return chunks of signal `a` of size `window`, incremented by 1 each time. https://gist.github.com/codehacken/708f19ae746784cef6e68b037af65788 """ shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) strides = a.strides + (a.strides[-1],) return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) X = rolling_window(sig, numtaps) Y = filtered # plt.plot(X[0]) # plt.plot(Y[0]) """ Create model """ model = Sequential([ Dense(1, input_dim=numtaps, use_bias=False) ]) model.summary() initial = model.get_weights() print('Initial weights:') print(initial) """ Make block diagram of network (not from tutorial) """ # Workaround for some packaging bug import os os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz' from keras.utils import plot_model plot_model(model, to_file='model.png', show_shapes=True) """ Make graph diagram of network (not from tutorial) Viewable with `tensorboard --logdir="logs"` """ import tensorflow as tf with tf.Session() as sess: writer = tf.summary.FileWriter('logs', sess.graph) writer.close() """ Node-level graph """ from ann_visualizer.visualize import ann_viz #ann_viz(model, title="Learned FIR filter") # Compile model model.compile(loss='mean_squared_error', optimizer='adam', ) class LossHistory(Callback): def on_train_begin(self, logs={}): self.losses = [] # self.n = 0 # maxx = np.amax(np.abs(model.get_weights()[0])) # plt.imsave(str(self.n)+'.png', model.get_weights()[0], # cmap=cmap, vmin=-maxx, vmax=maxx) # self.n+=1 def on_batch_end(self, batch, logs={}): self.losses.append(logs.get('loss')) def on_epoch_end(self, batch, logs={}): pass # maxx = np.amax(np.abs(model.get_weights()[0])) # plt.imsave(str(self.n)+'.png', model.get_weights()[0], # cmap=cmap, vmin=-maxx, vmax=maxx) # self.n+=1 history = LossHistory() # Fit the model print("Fitting...") model.fit(X, Y, epochs=35, batch_size=100, callbacks=[history]) # evaluate the model #print("Evaluating...") #scores = model.evaluate(X, Y) #print(scores[1]*100) final = model.get_weights() plt.figure() plt.plot(b, '.-', label='Filter') plt.plot(initial[0], '.-', label='Initial') plt.plot(final[0], '.-', label='Learned') plt.grid(True, color='0.7', linestyle='-', which='major') plt.grid(True, color='0.9', linestyle='-', which='minor') plt.legend() plt.figure() w, h = signal.freqz(b, [1.0]) plt.semilogx(w, 20*np.log10(abs(h)), label='Filter') w, h = signal.freqz(initial[0], [1.0]) plt.semilogx(w, 20*np.log10(abs(h)), label='Initial') w, h = signal.freqz(final[0], [1.0]) plt.semilogx(w, 20*np.log10(abs(h)), label='Learned') plt.grid(True, color='0.7', linestyle='-', which='major') plt.grid(True, color='0.9', linestyle='-', which='minor') plt.legend() plt.figure() plt.semilogy(history.losses) plt.xlabel('Batch') plt.ylabel('Loss') plt.grid(True, which="both") -
endolith created this gist
Mar 27, 2022 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1 @@ My first experiment with keras.