Source code for elasticdeform.torch
import numpy
import torch
import elasticdeform
class ElasticDeform(torch.autograd.Function):
    """
    Autograd bridge around the NumPy-based ``elasticdeform`` routines.

    The forward pass calls ``elasticdeform.deform_grid`` and the backward pass
    calls ``elasticdeform.deform_grid_gradient``. Both operate on NumPy copies
    of the tensors (detached and moved to the CPU) and wrap the results in new
    tensors placed on the device of the matching input (forward) or incoming
    gradient (backward).

    Only the images in ``xs`` receive a gradient: ``backward`` returns ``None``
    for ``displacement``, ``deform_args`` and ``deform_kwargs``.
    """

    @staticmethod
    def forward(ctx, displacement, deform_args, deform_kwargs, *xs):
        """
        Deform every tensor in ``xs`` with the given displacement grid.

        Parameters
        ----------
        ctx : autograd context
            receives the displacement, the extra arguments and the input
            shapes, all of which are read again in ``backward``
        displacement : torch.Tensor
            displacement passed on to ``elasticdeform.deform_grid``
        deform_args : sequence
            extra positional arguments for ``elasticdeform.deform_grid``
        deform_kwargs : dict
            extra keyword arguments for ``elasticdeform.deform_grid``
        *xs : torch.Tensor
            the images to deform

        Returns
        -------
        tuple of torch.Tensor
            one deformed tensor per input, on that input's device
        """
        # Keep everything backward() needs to call the gradient routine.
        ctx.save_for_backward(displacement)
        ctx.deform_args = deform_args
        ctx.deform_kwargs = deform_kwargs
        ctx.x_shapes = [x.shape for x in xs]

        # elasticdeform works on NumPy arrays, so hand it CPU copies.
        images_np = []
        for x in xs:
            images_np.append(x.detach().cpu().numpy())
        displacement_np = displacement.detach().cpu().numpy()

        deformed_np = elasticdeform.deform_grid(
            images_np, displacement_np, *deform_args, **deform_kwargs
        )

        # Put each result back on the device its input came from.
        outputs = []
        for source, deformed in zip(xs, deformed_np):
            outputs.append(torch.tensor(deformed, device=source.device))
        return tuple(outputs)

    @staticmethod
    def backward(ctx, *dys):
        """
        Propagate the output gradients ``dys`` back to the input images.

        Returns
        -------
        tuple
            three ``None`` entries (for ``displacement``, ``deform_args`` and
            ``deform_kwargs``) followed by one gradient tensor per input
            image, each on the device of the corresponding entry of ``dys``
        """
        (saved_displacement,) = ctx.saved_tensors
        extra_args = ctx.deform_args
        extra_kwargs = ctx.deform_kwargs
        input_shapes = ctx.x_shapes

        grads_out_np = []
        for dy in dys:
            grads_out_np.append(dy.detach().cpu().numpy())
        displacement_np = saved_displacement.detach().cpu().numpy()

        grads_in_np = elasticdeform.deform_grid_gradient(
            grads_out_np,
            displacement_np,
            *extra_args,
            X_shape=input_shapes,
            **extra_kwargs
        )

        image_grads = []
        for grad_in, dy in zip(grads_in_np, dys):
            image_grads.append(torch.tensor(grad_in, device=dy.device))

        # The first three forward() arguments get no gradient.
        return (None, None, None) + tuple(image_grads)
def deform_grid(X, displacement, *args, **kwargs):
    """
    Elastic deformation with a deformation grid, wrapped for PyTorch.

    This function wraps the ``elasticdeform.deform_grid`` function in a PyTorch
    function with a custom gradient.

    Parameters
    ----------
    X : torch.Tensor or list of torch.Tensors
        input image, or a list (or tuple) of input images
    displacement : torch.Tensor
        displacement vectors for each control point; converted with
        ``torch.as_tensor`` if it is not already a tensor
    *args, **kwargs
        passed through unchanged to ``ElasticDeform.apply``

    Returns
    -------
    torch.Tensor or tuple of torch.Tensors
        the deformed image when ``X`` is a single tensor, otherwise the
        tuple of deformed images (one per entry of ``X``)

    See Also
    --------
    elasticdeform.deform_grid : for the other parameters
    """
    # A lone image is wrapped so the autograd function always sees a sequence.
    single_input = not isinstance(X, (list, tuple))
    images = [X] if single_input else X

    displacement = torch.as_tensor(displacement)
    deformed = ElasticDeform.apply(displacement, args, kwargs, *images)

    # Unwrap again so the return value mirrors the shape of the input.
    return deformed[0] if single_input else deformed