Skip to content

Instantly share code, notes, and snippets.

@ssingh1187
Forked from colejhudson/gradient_descent.py
Created February 28, 2019 19:02
Show Gist options
  • Save ssingh1187/9cfa9ebeff3a62f2a3014a4c701398b0 to your computer and use it in GitHub Desktop.
Save ssingh1187/9cfa9ebeff3a62f2a3014a4c701398b0 to your computer and use it in GitHub Desktop.
Gradient descent implemented from scratch over R^N to R^M
import numpy as np
import collections
# These functions are designed with NumPy in mind: they take arrays as
# input and return arrays as output.
def derivative(f):
    """Return a numerical derivative of the one-argument function ``f``.

    The returned ``df(x, h=0.1e-10)`` evaluates the symmetric difference
    quotient ``(f(x + h) - f(x - h)) / (2 * h)`` and wraps the result in a
    NumPy array.
    """
    def df(x, h=0.1e-10):
        ahead = f(x + h)
        behind = f(x - h)
        slope = (ahead - behind) / (2 * h)
        return np.array(slope)
    return df
def partial_derivative(f, nth=0, h=0.1e-10):
    """Return a forward-difference estimate of the partial derivative of ``f``.

    Parameters
    ----------
    f : callable
        Function taking one positional value per coordinate.
    nth : int
        Index of the argument to differentiate with respect to.
    h : float
        Nominal forward step added to argument ``nth``.

    Returns
    -------
    callable
        ``partial(*xs)``, which evaluates ``f`` at ``xs`` and at ``xs`` with
        argument ``nth`` shifted by ``h``. It returns the difference quotient
        as a NumPy array.
    """
    def partial(*xs):
        shifted = np.array(xs, dtype=np.float64)
        shifted[nth] = xs[nth] + h
        step = h
        if np.ndim(xs[nth]) == 0:
            # ``xs[nth] + h`` is rounded to the nearest float64, so the step
            # actually taken generally differs from ``h``. Dividing by the
            # realised step removes that representation error (the quotient
            # is exact for linear ``f``). If rounding swallowed the step
            # entirely, keep the nominal ``h`` to avoid a 0/0.
            realised = shifted[nth] - xs[nth]
            if realised != 0:
                step = realised
        return np.array((f(*shifted) - f(*xs)) / step)
    return partial
def gradient(f):
    """Return a numerical gradient of ``f``.

    The returned ``grad(*xs)`` takes one positional value per coordinate.
    With exactly one argument it delegates to ``derivative`` (central
    difference). Otherwise it gathers ``partial_derivative(f, i)(*xs)`` for
    every coordinate ``i`` into a single NumPy array.
    """
    def grad(*xs):
        arity = len(xs)
        if arity != 1:
            return np.array(
                [partial_derivative(f, index)(*xs) for index in range(arity)]
            )
        # One argument is special-cased. Without this, the result
        # degenerates and errors are thrown as numpy.ndarrays keep nesting.
        return np.array(derivative(f)(*xs))
    return grad
def gradient_descent(func, learning_rate=0.1, convergence_threshold=0.05, max_steps=None):
    """Build a minimizer that runs gradient descent on ``func``.

    This function only validates its settings and returns a closure. The
    descent itself runs when that closure is called with a starting point.

    Parameters
    ----------
    func : callable
        Objective called as ``func(*point)`` with one positional value per
        coordinate. Each trajectory row is built as ``[*point, func(*point)]``,
        so scalar coordinates and a scalar return value are expected.
    learning_rate : float, optional
        Factor applied to the numerical gradient at each step. Must be > 0.
    convergence_threshold : float, optional
        Descent stops once the Euclidean distance between two consecutive
        points is no longer greater than this value. Must be > 0.
    max_steps : int or None, optional
        Maximum number of descent steps. ``None`` means unlimited.

    Returns
    -------
    callable
        ``optimized_func(*xs)``, which starts at ``xs`` and returns a 2-D
        array. The array has one row ``[*point, func(*point)]`` per visited
        point, with the starting point first and the final point last.

    Raises
    ------
    ValueError
        If ``learning_rate`` or ``convergence_threshold`` is not greater than 0.

    Notes
    -----
    With ``max_steps=None`` the loop ends only when the step distance drops
    to the threshold (or becomes NaN). Iterates that oscillate or diverge
    without overflowing keep it running, so pass ``max_steps`` when unsure.
    """
    # Explicit checks instead of ``assert``: assertions are stripped when
    # Python runs with ``-O``, which would silently accept bad settings.
    if not 0 < learning_rate:
        raise ValueError("Learning rate must be greater than 0.")
    if not 0 < convergence_threshold:
        raise ValueError("Convergence threshold must be greater than zero.")

    def optimized_func(*xs):
        grad = gradient(func)
        # Rows are collected in a list and stacked once at the end.
        # Re-stacking the whole history on every step is quadratic in the
        # number of steps.
        row = np.array([*xs, func(*xs)])
        rows = [row]
        n = 0
        distance = float('inf')
        while convergence_threshold < distance:
            if max_steps is not None and max_steps <= n:
                break
            previous_step = row[:-1]
            step = previous_step - learning_rate * grad(*previous_step)
            row = np.array([*step, func(*step)])
            rows.append(row)
            n += 1
            # Distance is measured between consecutive points in input space.
            # As the number of coordinates grows, it becomes a less
            # meaningful convergence measure.
            distance = np.linalg.norm(previous_step - row[:-1])
        return np.vstack(rows)

    return optimized_func
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment