This commit is contained in:
Yifan Yang 2023-06-15 17:54:16 +08:00
parent 33a8bf54be
commit f6a18ec34d
6 changed files with 16 additions and 4 deletions

View File

@ -58,6 +58,7 @@ class Decoder(nn.Module):
self.embedding = nn.Embedding( self.embedding = nn.Embedding(
num_embeddings=vocab_size, num_embeddings=vocab_size,
embedding_dim=decoder_dim, embedding_dim=decoder_dim,
padding_idx=blank_id,
) )
# the balancers are to avoid any drift in the magnitude of the # the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging. # embeddings, which would interact badly with parameter averaging.

View File

@ -333,7 +333,7 @@ class AsrModel(nn.Module):
simple_loss, pruned_loss = self.forward_transducer( simple_loss, pruned_loss = self.forward_transducer(
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens, encoder_out_lens=encoder_out_lens,
y=y.to(x.device), y=y,
y_lens=y_lens, y_lens=y_lens,
prune_range=prune_range, prune_range=prune_range,
am_scale=am_scale, am_scale=am_scale,

View File

@ -789,7 +789,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( simple_loss, pruned_loss, ctc_loss = model(

View File

@ -2190,7 +2190,7 @@ class ConvolutionModule(nn.Module):
x = self.in_proj(x) # (time, batch, 2*channels) x = self.in_proj(x) # (time, batch, 2*channels)
x, s = x.chunk(2, dim=2) x, s = x.chunk(2, dim=-1)
s = self.sigmoid(s) s = self.sigmoid(s)
x = x * s x = x * s
# (time, batch, channels) # (time, batch, channels)

View File

@ -66,6 +66,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
cd data/fbank cd data/fbank
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz) . ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_dev-clean.jsonl.gz) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_dev-other.jsonl.gz) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_test-clean.jsonl.gz) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_test-other.jsonl.gz) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-clean-100) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-clean-360) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-other-500) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_dev-clean) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_dev-other) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_test-clean) .
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_test-other) .
cd ../.. cd ../..
else else
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3" log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"

View File

@ -790,7 +790,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int) y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss, ctc_loss = model( simple_loss, pruned_loss, ctc_loss = model(