Source code for elasticdeform.tf

import numpy
import tensorflow
import elasticdeform

def deform_grid(X, displacement, *args, **kwargs):
    """
    Elastic deformation with a deformation grid, wrapped in a TensorFlow Op.

    This function wraps the ``elasticdeform.deform_grid`` function in a
    TensorFlow Op with a custom gradient. Both the forward pass and the
    gradient are computed in NumPy, via ``elasticdeform.deform_grid`` and
    ``elasticdeform.deform_grid_gradient`` respectively.

    Parameters
    ----------
    X : Tensor or list of Tensors
        input image or list of input images
    displacement : Tensor or numpy array
        displacement vectors for each control point

    Returns
    -------
    Tensor
        the deformed image, or a list of deformed images (a list is returned
        whenever ``X`` is a list or tuple)

    Notes
    -----
    Gradients are propagated to ``X`` only. The gradient with respect to
    ``displacement`` is not implemented: it is returned as an array of NaN
    with the same shape as ``displacement``.

    Any additional positional and keyword arguments are forwarded unchanged
    to both ``elasticdeform.deform_grid`` and
    ``elasticdeform.deform_grid_gradient``.

    See Also
    --------
    elasticdeform.deform_grid : for the other parameters
    """
    # tf.py_func only exists at the top level in TensorFlow 1; TensorFlow 2
    # provides tf.py_function, whose callbacks receive tensors that are
    # converted with .numpy() below.
    use_tf_v1 = hasattr(tensorflow, 'py_func')

    def py_op(func, inputs, output_dtypes, name):
        # Run `func` as a Python/NumPy op using whichever API is available.
        if use_tf_v1:
            # TensorFlow 1
            return tensorflow.py_func(func, inputs, output_dtypes,
                                      stateful=False, name=name)
        # TensorFlow 2
        return tensorflow.py_function(func, inputs, output_dtypes, name=name)

    @tensorflow.custom_gradient
    def f(displacement, *xs):
        n_inputs = len(xs)

        def fwd(displacement, *xs):
            if not use_tf_v1:
                xs = [x.numpy() for x in xs]
                displacement = displacement.numpy()
            return elasticdeform.deform_grid(list(xs), displacement,
                                             *args, **kwargs)

        def bwd(*dys):
            def grad(*dys_disp_xs):
                # Packed as (dy_0..dy_{n-1}, displacement, x_0..x_{n-1}); the
                # x_i are passed in only so their shapes are available here.
                dys = list(dys_disp_xs[:n_inputs])
                displacement = dys_disp_xs[n_inputs]
                # Plain tuples so the gradient sees the same type under TF1
                # (NumPy .shape) and TF2 (TensorShape on eager tensors).
                X_shape = [tuple(x.shape)
                           for x in dys_disp_xs[n_inputs + 1:]]
                if not use_tf_v1:
                    dys = [dy.numpy() for dy in dys]
                    displacement = displacement.numpy()
                dXs = elasticdeform.deform_grid_gradient(
                    dys, displacement, *args, X_shape=X_shape, **kwargs)
                # The gradient w.r.t. the displacement is not implemented;
                # emit NaNs shaped like `displacement` to fill that output slot.
                return [numpy.nan * displacement] + dXs

            grad_inputs = dys + (displacement,) + xs
            grad_output_dtypes = [displacement.dtype] + [x.dtype for x in xs]
            return py_op(grad, grad_inputs, grad_output_dtypes,
                         'DeformGridGrad')

        inputs = (displacement,) + xs
        output_dtypes = [x.dtype for x in xs]
        y = py_op(fwd, inputs, output_dtypes, 'DeformGrid')
        return y, bwd

    if isinstance(X, (list, tuple)):
        return f(displacement, *X)
    return f(displacement, X)[0]