mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
update
This commit is contained in:
parent
33a8bf54be
commit
f6a18ec34d
@ -58,6 +58,7 @@ class Decoder(nn.Module):
|
|||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
num_embeddings=vocab_size,
|
num_embeddings=vocab_size,
|
||||||
embedding_dim=decoder_dim,
|
embedding_dim=decoder_dim,
|
||||||
|
padding_idx=blank_id,
|
||||||
)
|
)
|
||||||
# the balancers are to avoid any drift in the magnitude of the
|
# the balancers are to avoid any drift in the magnitude of the
|
||||||
# embeddings, which would interact badly with parameter averaging.
|
# embeddings, which would interact badly with parameter averaging.
|
||||||
|
|||||||
@ -333,7 +333,7 @@ class AsrModel(nn.Module):
|
|||||||
simple_loss, pruned_loss = self.forward_transducer(
|
simple_loss, pruned_loss = self.forward_transducer(
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
y=y.to(x.device),
|
y=y,
|
||||||
y_lens=y_lens,
|
y_lens=y_lens,
|
||||||
prune_range=prune_range,
|
prune_range=prune_range,
|
||||||
am_scale=am_scale,
|
am_scale=am_scale,
|
||||||
|
|||||||
@ -789,7 +789,7 @@ def compute_loss(
|
|||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
simple_loss, pruned_loss, ctc_loss = model(
|
||||||
|
|||||||
@ -2190,7 +2190,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||||
|
|
||||||
x, s = x.chunk(2, dim=2)
|
x, s = x.chunk(2, dim=-1)
|
||||||
s = self.sigmoid(s)
|
s = self.sigmoid(s)
|
||||||
x = x * s
|
x = x * s
|
||||||
# (time, batch, channels)
|
# (time, batch, channels)
|
||||||
|
|||||||
@ -66,6 +66,17 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
|||||||
if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
|
if [ -e ../../librispeech/ASR/data/fbank/.librispeech.done ]; then
|
||||||
cd data/fbank
|
cd data/fbank
|
||||||
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz) .
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_dev-clean.jsonl.gz) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_dev-other.jsonl.gz) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_test-clean.jsonl.gz) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_cuts_test-other.jsonl.gz) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-clean-100) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-clean-360) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_train-other-500) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_dev-clean) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_dev-other) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_test-clean) .
|
||||||
|
ln -svf $(realpath ../../../../librispeech/ASR/data/fbank/librispeech_feats_test-other) .
|
||||||
cd ../..
|
cd ../..
|
||||||
else
|
else
|
||||||
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
|
log "Abort! Please run ../../librispeech/ASR/prepare.sh --stage 3 --stop-stage 3"
|
||||||
|
|||||||
@ -790,7 +790,7 @@ def compute_loss(
|
|||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
simple_loss, pruned_loss, ctc_loss = model(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user