import torch
from torchcubicspline import(natural_cubic_spline_coeffs, 
                             NaturalCubicSpline)

# Knot times: 7 evenly spaced points in [0, 1].
t = torch.linspace(0, 1, 7)
# 2 is the batch dimension. 7 is the time dimension
# (of the same length as t). 3 is the channel dimension.
x = torch.rand(2, 7, 3)
# Fit natural cubic spline coefficients through the samples (t, x).
coeffs = natural_cubic_spline_coeffs(t, x)

spline = NaturalCubicSpline(coeffs)

# 40 evenly spaced query times spanning the same [0, 1] range as t.
point = torch.linspace(0, 1, 40)
# Evaluating at a 1-D tensor of query times inserts a time dimension,
# so `out` should have shape (2, 40, 3): batch, query-time, and channel
# dimensions.
out = spline.evaluate(point)
