Sophia Tang
Initial commit
b55bace
raw
history blame contribute delete
292 Bytes
import sys
import torch
class flow_model_torch_wrapper(torch.nn.Module):
    """Adapt a ``model(t, x)`` callable to a torchdyn-compatible signature.

    The wrapped model is stored as the ``model`` attribute. When it is an
    ``nn.Module``, it is therefore registered as a submodule under that name.
    """

    def __init__(self, model):
        """Store *model*, which is later called as ``model(t, x)``."""
        super(flow_model_torch_wrapper, self).__init__()
        self.model = model

    def forward(self, t, x, *_extra_args, **_extra_kwargs):
        """Return ``self.model(t, x)``.

        Any additional positional or keyword arguments are accepted so the
        call signature matches what the caller passes. They are not
        forwarded to the model.
        """
        wrapped = self.model
        return wrapped(t, x)