Skip to content

Instantly share code, notes, and snippets.

@braman09
Forked from mblondel/check_convex.py
Created December 22, 2019 17:26
Show Gist options
  • Select an option

  • Save braman09/ac3fc8bc6e1e8e50eebacf7eef56eb0a to your computer and use it in GitHub Desktop.

Select an option

Save braman09/ac3fc8bc6e1e8e50eebacf7eef56eb0a to your computer and use it in GitHub Desktop.

Revisions

  1. @mblondel mblondel revised this gist Dec 21, 2019. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion check_convex.py
    Original file line number Diff line number Diff line change
    @@ -32,7 +32,7 @@ def check_convex(func, gen, max_iter=1000, max_inner=10,
    input function.
    If answers "not convex", a counter-example has been found and
    the function is guaranteed to be non-convex. Don't loose time proving its
    the function is guaranteed to be non-convex. Don't lose time proving its
    convexity!
    If answers "could be convex", you can't completely rule out the possibility
  2. @mblondel mblondel revised this gist Dec 21, 2019. 1 changed file with 0 additions and 1 deletion.
    1 change: 0 additions & 1 deletion check_convex.py
    Original file line number Diff line number Diff line change
    @@ -2,7 +2,6 @@
    # License: BSD 3 clause

    import numpy as np
    from scipy.linalg import eigh


    def _gen_pairs(gen, max_iter, max_inner, random_state, verbose):
  3. @mblondel mblondel revised this gist Dec 21, 2019. 1 changed file with 0 additions and 2 deletions.
    2 changes: 0 additions & 2 deletions check_convex.py
    Original file line number Diff line number Diff line change
    @@ -93,8 +93,6 @@ def check_convex(func, gen, max_iter=1000, max_inner=10,
    print("could be convex")

    if __name__ == "__main__":
    import sys

    def sqnorm(x):
    return np.dot(x, x)

  4. @mblondel mblondel revised this gist Dec 21, 2019. 1 changed file with 3 additions and 2 deletions.
    5 changes: 3 additions & 2 deletions check_convex.py
    Original file line number Diff line number Diff line change
    @@ -49,8 +49,9 @@ def check_convex(func, gen, max_iter=1000, max_inner=10,
    func:
    Function func(M) to be tested.
    shape: tuple
    Shape of the function argument M. Small arrays are recommended.
    gen: tuple or function
    If tuple, shape of the function argument M. Small arrays are recommended.
    If function, function for generating M.
    max_iter: int
    Max number of trials.
  5. @mblondel mblondel created this gist Dec 21, 2019.
    101 changes: 101 additions & 0 deletions check_convex.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,101 @@
    # Authors: Mathieu Blondel, Vlad Niculae
    # License: BSD 3 clause

    import numpy as np
    from scipy.linalg import eigh


    def _gen_pairs(gen, max_iter, max_inner, random_state, verbose):
    rng = np.random.RandomState(random_state)

    # if tuple, interpret as randn
    if isinstance(gen, tuple):
    shape = gen
    gen = lambda rng: rng.randn(*shape)

    for it in range(max_iter):
    if verbose:
    print("iter", it + 1)

    M1 = gen(rng)
    M2 = gen(rng)

    for t in np.linspace(0.01, 0.99, max_inner):
    M = t * M1 + (1 - t) * M2

    yield M, M1, M2, t


    def check_convex(func, gen, max_iter=1000, max_inner=10,
    quasi=False, random_state=None, eps=1e-9, verbose=0):
    """
    Numerically check whether the definition of a convex function holds for the
    input function.
    If answers "not convex", a counter-example has been found and
    the function is guaranteed to be non-convex. Don't loose time proving its
    convexity!
    If answers "could be convex", you can't completely rule out the possibility
    that the function is non-convex. To be completely sure, this needs to be
    proved analytically.
    This approach was explained by S. Boyd in his convex analysis lectures at
    Stanford.
    Parameters
    ----------
    func:
    Function func(M) to be tested.
    shape: tuple
    Shape of the function argument M. Small arrays are recommended.
    max_iter: int
    Max number of trials.
    max_inner: int
    Max number of values between [0, 1] to be tested for the definition of
    convexity.
    quasi: bool (default=False)
    If True, use quasi-convex definition instead of convex.
    random_state: None or int
    Random seed to be used.
    eps: float
    Tolerance.
    verbose: int
    Verbosity level.
    """

    for M, M1, M2, t in _gen_pairs(gen, max_iter, max_inner,
    random_state, verbose):
    if quasi:
    # quasi-convex if f(M) <= max(f(M1), f(M2))
    # not quasi convex if f(M) > max(f(M1), f(M2))
    diff = func(M) - max(func(M1), func(M2))
    else:
    # convex if f(M) <= t * f(M1) + (1 - t) * f(M2)
    # non-convex if f(M) > t * f(M1) + (1 - t) * f(M2)
    diff = func(M) - (t * func(M1) + (1 - t) * func(M2))

    if diff > eps:
    # We found a counter-example.
    print("not convex (diff=%f)" % diff)
    return

    # To be completely sure, this needs to be proved analytically.
    print("could be convex")

    if __name__ == "__main__":
    import sys

    def sqnorm(x):
    return np.dot(x, x)

    check_convex(sqnorm, gen=(5,), max_iter=10000, max_inner=10,
    random_state=0)