From 0a465794a806ca42f43fb626ae8300b878b7ec43 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Thu, 15 Jun 2023 17:52:14 +0800 Subject: [PATCH] Fix Zipformer (#1132) * Update model.py * Update train.py * Update decoder.py --- egs/librispeech/ASR/zipformer/decoder.py | 1 - egs/librispeech/ASR/zipformer/model.py | 2 +- egs/librispeech/ASR/zipformer/train.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 45432d570..e8db988f6 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -58,7 +58,6 @@ class Decoder(nn.Module): self.embedding = nn.Embedding( num_embeddings=vocab_size, embedding_dim=decoder_dim, - padding_idx=blank_id, ) # the balancers are to avoid any drift in the magnitude of the # embeddings, which would interact badly with parameter averaging. diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 9b7494972..0c3ea6a86 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -333,7 +333,7 @@ class AsrModel(nn.Module): simple_loss, pruned_loss = self.forward_transducer( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - y=y, + y=y.to(x.device), y_lens=y_lens, prune_range=prune_range, am_scale=am_scale, diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 1d1bee947..bc3e9c1ba 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -789,7 +789,7 @@ def compute_loss( texts = batch["supervisions"]["text"] y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) + y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): simple_loss, pruned_loss, ctc_loss = model(