Skip to content

Instantly share code, notes, and snippets.

@amroamroamro
Last active February 24, 2025 18:17
Show Gist options
  • Select an option

  • Save amroamroamro/1db8d69b4b65e8bc66a6 to your computer and use it in GitHub Desktop.

Select an option

Save amroamroamro/1db8d69b4b65e8bc66a6 to your computer and use it in GitHub Desktop.

Revisions

  1. amroamroamro revised this gist Jun 16, 2023. 3 changed files with 138 additions and 2 deletions.
    13 changes: 11 additions & 2 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -1,9 +1,18 @@
    Python version of the MATLAB code in this Stack Overflow post:
    http://stackoverflow.com/a/18648210/97160
    https://stackoverflow.com/a/18648210/97160

    The example shows how to determine the best-fit plane/surface
    (1st or higher order polynomial) over a set of three-dimensional points.

    Implemented in Python + NumPy + SciPy + matplotlib.

    ![quadratic_surface](http://i.imgur.com/hquieGA.png)
    ![quadratic_surface](https://i.imgur.com/hquieGA.png)

    ---

    ### EDIT (2023-06-16)

    I added a new example `fit.py` that shows polynomial fitting of any n-th order,
    as well as the same thing but using scikit-learn functions `fit-sklearn.py`.

    ![peaks](https://i.imgur.com/XMMbCxH.png)
    60 changes: 60 additions & 0 deletions fit-sklearn.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,60 @@
    import numpy as np
    from sklearn.preprocessing import PolynomialFeatures
    from sklearn.linear_model import LinearRegression
    from sklearn.pipeline import make_pipeline
    import matplotlib.pyplot as plt

    def generateData(n = 30):
    # similar to peaks() function in MATLAB
    g = np.linspace(-3.0, 3.0, n)
    X, Y = np.meshgrid(g, g)
    X, Y = X.reshape(-1,1), Y.reshape(-1,1)
    Z = 3 * (1 - X)**2 * np.exp(- X**2 - (Y+1)**2) \
    - 10 * (X/5 - X**3 - Y**5) * np.exp(- X**2 - Y**2) \
    - 1/3 * np.exp(- (X+1)**2 - Y**2)
    return X, Y, Z

    def names2model(names):
    # C[i] * X^n * Y^m
    return ' + '.join([
    f"C[{i}]*{n.replace(' ','*')}"
    for i,n in enumerate(names)])

    # generate some random 3-dim points
    X, Y, Z = generateData()

    # 1=linear, 2=quadratic, 3=cubic, ..., nth degree
    order = 11

    # best-fit polynomial surface
    model = make_pipeline(
    PolynomialFeatures(degree=order),
    LinearRegression(fit_intercept=False))
    model.fit(np.c_[X, Y], Z)

    m = names2model(model[0].get_feature_names_out(['X', 'Y']))
    C = model[1].coef_.T # coefficients
    r2 = model.score(np.c_[X, Y], Z) # R-squared

    # print summary
    print(f'data = {Z.size}x3')
    print(f'model = {m}')
    print(f'coefficients =\n{C}')
    print(f'R2 = {r2}')

    # uniform grid covering the domain of the data
    XX,YY = np.meshgrid(np.linspace(X.min(), X.max(), 20), np.linspace(Y.min(), Y.max(), 20))

    # evaluate model on grid
    ZZ = model.predict(np.c_[XX.flatten(), YY.flatten()]).reshape(XX.shape)

    # plot points and fitted surface
    ax = plt.figure().add_subplot(projection='3d')
    ax.scatter(X, Y, Z, c='r', s=2)
    ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
    ax.axis('tight')
    ax.view_init(azim=-60.0, elev=30.0)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
    67 changes: 67 additions & 0 deletions fit.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,67 @@
    import numpy as np
    from scipy.linalg import lstsq
    import matplotlib.pyplot as plt

    def generateData(n = 30):
    # similar to peaks() function in MATLAB
    g = np.linspace(-3.0, 3.0, n)
    X, Y = np.meshgrid(g, g)
    X, Y = X.reshape(-1,1), Y.reshape(-1,1)
    Z = 3 * (1 - X)**2 * np.exp(- X**2 - (Y+1)**2) \
    - 10 * (X/5 - X**3 - Y**5) * np.exp(- X**2 - Y**2) \
    - 1/3 * np.exp(- (X+1)**2 - Y**2)
    return X, Y, Z

    def exp2model(e):
    # C[i] * X^n * Y^m
    return ' + '.join([
    f'C[{i}]' +
    ('*' if x>0 or y>0 else '') +
    (f'X^{x}' if x>1 else 'X' if x==1 else '') +
    ('*' if x>0 and y>0 else '') +
    (f'Y^{y}' if y>1 else 'Y' if y==1 else '')
    for i,(x,y) in enumerate(e)
    ])

    # generate some random 3-dim points
    X, Y, Z = generateData()

    # 1=linear, 2=quadratic, 3=cubic, ..., nth degree
    order = 11

    # calculate exponents of design matrix
    #e = [(x,y) for x in range(0,order+1) for y in range(0,order-x+1)]
    e = [(x,y) for n in range(0,order+1) for y in range(0,n+1) for x in range(0,n+1) if x+y==n]
    eX = np.asarray([[x] for x,_ in e]).T
    eY = np.asarray([[y] for _,y in e]).T

    # best-fit polynomial surface
    A = (X ** eX) * (Y ** eY)
    C,resid,_,_ = lstsq(A, Z) # coefficients

    # calculate R-squared from residual error
    r2 = 1 - resid[0] / (Z.size * Z.var())

    # print summary
    print(f'data = {Z.size}x3')
    print(f'model = {exp2model(e)}')
    print(f'coefficients =\n{C}')
    print(f'R2 = {r2}')

    # uniform grid covering the domain of the data
    XX,YY = np.meshgrid(np.linspace(X.min(), X.max(), 20), np.linspace(Y.min(), Y.max(), 20))

    # evaluate model on grid
    A = (XX.reshape(-1,1) ** eX) * (YY.reshape(-1,1) ** eY)
    ZZ = np.dot(A, C).reshape(XX.shape)

    # plot points and fitted surface
    ax = plt.figure().add_subplot(projection='3d')
    ax.scatter(X, Y, Z, c='r', s=2)
    ax.plot_surface(XX, YY, ZZ, rstride=1, cstride=1, alpha=0.2, linewidth=0.5, edgecolor='b')
    ax.axis('tight')
    ax.view_init(azim=-60.0, elev=30.0)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    plt.show()
  2. amroamroamro revised this gist Feb 20, 2015. 1 changed file with 2 additions and 0 deletions.
    2 changes: 2 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -5,3 +5,5 @@ The example shows how to determine the best-fit plane/surface
    (1st or higher order polynomial) over a set of three-dimensional points.

    Implemented in Python + NumPy + SciPy + matplotlib.

    ![quadratic_surface](http://i.imgur.com/hquieGA.png)
  3. amroamroamro created this gist Feb 20, 2015.
    7 changes: 7 additions & 0 deletions README.md
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,7 @@
    Python version of the MATLAB code in this Stack Overflow post:
    http://stackoverflow.com/a/18648210/97160

    The example shows how to determine the best-fit plane/surface
    (1st or higher order polynomial) over a set of three-dimensional points.

    Implemented in Python + NumPy + SciPy + matplotlib.
    48 changes: 48 additions & 0 deletions curve_fitting.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,48 @@
    #!/usr/bin/evn python

    import numpy as np
    import scipy.linalg
    from mpl_toolkits.mplot3d import Axes3D
    import matplotlib.pyplot as plt

    # some 3-dim points
    mean = np.array([0.0,0.0,0.0])
    cov = np.array([[1.0,-0.5,0.8], [-0.5,1.1,0.0], [0.8,0.0,1.0]])
    data = np.random.multivariate_normal(mean, cov, 50)

    # regular grid covering the domain of the data
    X,Y = np.meshgrid(np.arange(-3.0, 3.0, 0.5), np.arange(-3.0, 3.0, 0.5))
    XX = X.flatten()
    YY = Y.flatten()

    order = 1 # 1: linear, 2: quadratic
    if order == 1:
    # best-fit linear plane
    A = np.c_[data[:,0], data[:,1], np.ones(data.shape[0])]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2]) # coefficients

    # evaluate it on grid
    Z = C[0]*X + C[1]*Y + C[2]

    # or expressed using matrix/vector product
    #Z = np.dot(np.c_[XX, YY, np.ones(XX.shape)], C).reshape(X.shape)

    elif order == 2:
    # best-fit quadratic curve
    A = np.c_[np.ones(data.shape[0]), data[:,:2], np.prod(data[:,:2], axis=1), data[:,:2]**2]
    C,_,_,_ = scipy.linalg.lstsq(A, data[:,2])

    # evaluate it on a grid
    Z = np.dot(np.c_[np.ones(XX.shape), XX, YY, XX*YY, XX**2, YY**2], C).reshape(X.shape)

    # plot points and fitted surface
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.2)
    ax.scatter(data[:,0], data[:,1], data[:,2], c='r', s=50)
    plt.xlabel('X')
    plt.ylabel('Y')
    ax.set_zlabel('Z')
    ax.axis('equal')
    ax.axis('tight')
    plt.show()