Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import torch.nn | |
import torch | |
def get_param(model) -> torch.nn.Parameter:
    """
    Return the first parameter yielded by ``model.parameters()``.

    If ``model`` has a ``model`` attribute that itself exposes
    ``parameters()`` (e.g. a descriptor wrapping a Torch module), that inner
    object is searched instead of ``model``.

    Raises:
        ValueError: if the (possibly unwrapped) object yields no parameters.
    """
    # getattr with a None default: hasattr(None, "parameters") is False, so a
    # missing or None ``.model`` attribute leaves ``model`` untouched.
    wrapped = getattr(model, "model", None)
    if hasattr(wrapped, "parameters"):
        # Unpeel a model descriptor to get at the actual Torch module.
        model = wrapped

    # A private sentinel distinguishes "iterator exhausted" from any value
    # the iterator might yield.
    missing = object()
    first = next(iter(model.parameters()), missing)
    if first is missing:
        raise ValueError(f"No parameters found in model {model!r}")
    return first
def float64(t: torch.Tensor):
    """
    Return the dtype to use for double-precision work on ``t``'s device.

    Gives ``torch.float32`` when ``t.device.type`` is ``"mps"`` or ``"xpu"``,
    and ``torch.float64`` for any other device type.
    """
    # NOTE(review): mps/xpu are downgraded to float32, presumably because those
    # backends lack (full) float64 support — confirm against the backend docs.
    return torch.float32 if t.device.type in ("mps", "xpu") else torch.float64