Skip to content

Instantly share code, notes, and snippets.

@cadrev
Forked from fmder/elastic_transform.py
Created July 11, 2016 07:58
Show Gist options
  • Select an option

  • Save cadrev/5d8403f23b5728d73278e4b28032eabb to your computer and use it in GitHub Desktop.

Select an option

Save cadrev/5d8403f23b5728d73278e4b28032eabb to your computer and use it in GitHub Desktop.

Revisions

  1. @fmder fmder revised this gist Aug 21, 2015. No changes.
  2. @fmder fmder created this gist Aug 21, 2015.
    24 changes: 24 additions & 0 deletions elastic_transform.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,24 @@
    import numpy
    from scipy.ndimage.interpolation import map_coordinates
    from scipy.ndimage.filters import gaussian_filter

    def elastic_transform(image, alpha, sigma, random_state=None):
    """Elastic deformation of images as described in [Simard2003]_.
    .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
    Convolutional Neural Networks applied to Visual Document Analysis", in
    Proc. of the International Conference on Document Analysis and
    Recognition, 2003.
    """
    if random_state is None:
    random_state = numpy.random.RandomState(None)

    shape = image.shape
    dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
    dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha

    x, y = numpy.meshgrid(numpy.arange(shape[0]), numpy.arange(shape[1]))
    indices = numpy.reshape(y+dy, (-1, 1)), numpy.reshape(x+dx, (-1, 1))

    return map_coordinates(image, indices, order=1).reshape(shape)