Fix Zipformer (#1132)

* Update model.py

* Update train.py

* Update decoder.py
Authored by Yifan Yang on 2023-06-15 17:52:14 +08:00; committed by GitHub
parent 947f0614c9
commit 0a465794a8
3 changed files with 2 additions and 3 deletions

decoder.py

@@ -58,7 +58,6 @@ class Decoder(nn.Module):
         self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=decoder_dim,
-            padding_idx=blank_id,
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
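
Note on the decoder.py change: in PyTorch, passing padding_idx to nn.Embedding pins that row of the weight matrix to zeros and excludes it from gradient updates, so padding_idx=blank_id kept the blank symbol's embedding frozen at zero. Dropping the argument makes the blank row an ordinary trainable embedding. A minimal sketch of the difference, with purely illustrative sizes (vocab_size, decoder_dim, blank_id below are placeholders, not values from the recipe):

import torch.nn as nn

vocab_size, decoder_dim, blank_id = 500, 512, 0  # illustrative values only

# Old behaviour: the blank row is pinned to zero and receives no gradient.
emb_old = nn.Embedding(num_embeddings=vocab_size, embedding_dim=decoder_dim, padding_idx=blank_id)
print(emb_old.weight[blank_id].abs().sum())  # prints a zero tensor

# New behaviour after this commit: the blank row is initialized and trained
# like every other entry of the table.
emb_new = nn.Embedding(num_embeddings=vocab_size, embedding_dim=decoder_dim)
print(emb_new.weight[blank_id].abs().sum())  # prints a non-zero value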

model.py

@@ -333,7 +333,7 @@ class AsrModel(nn.Module):
             simple_loss, pruned_loss = self.forward_transducer(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
-                y=y,
+                y=y.to(x.device),
                 y_lens=y_lens,
                 prune_range=prune_range,
                 am_scale=am_scale,
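
Note on the model.py change: the ragged label tensor is now moved onto the same device as the acoustic input right before the transducer losses are computed, so the caller no longer has to do it. A minimal sketch of the idea, not the actual AsrModel code, assuming x is the feature tensor and y a k2.RaggedTensor:

import k2
import torch

def labels_on_model_device(x: torch.Tensor, y: k2.RaggedTensor) -> k2.RaggedTensor:
    # Ensure the labels live on the same device as the input features
    # before they are handed to forward_transducer.
    return y.to(x.device)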

train.py

@@ -789,7 +789,7 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
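
Note on the train.py change: compute_loss no longer calls .to(device) on the ragged labels; y stays on the CPU, where k2.RaggedTensor(y) creates it from the SentencePiece token ids, and the model performs the transfer as shown above. A small illustration (the token ids are made up):

import k2

# Token ids as sp.encode(texts, out_type=int) would return them; values are illustrative.
y = k2.RaggedTensor([[25, 7, 93], [12, 4]])
print(y.device)  # cpu -- the model moves y next to the encoder input itself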