from local

This commit is contained in:
dohe0342 2022-12-09 17:22:49 +09:00
parent ab49e98a36
commit fac31c76e9
2 changed files with 1 additions and 1 deletions

View File

@ -672,7 +672,7 @@ def compute_loss(
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
assert feature.ndim == 2 or feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]