mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 21:44:18 +00:00
Merge branch 'k2-fsa:master' into multi
commit 33a8bf54be
@@ -58,7 +58,6 @@ class Decoder(nn.Module):
         self.embedding = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=decoder_dim,
-            padding_idx=blank_id,
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
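A note on what this hunk removes: when nn.Embedding is constructed with padding_idx, PyTorch zero-initializes that row and never sends gradient to it, so with padding_idx=blank_id the blank symbol's embedding would stay frozen at zero; dropping the argument lets that row train like any other. A minimal, self-contained sketch of that PyTorch behavior (toy sizes, not icefall code):

```python
import torch
import torch.nn as nn

# With padding_idx set, that embedding row is zero-initialized and
# receives no gradient, so it can never be trained.
emb = nn.Embedding(num_embeddings=5, embedding_dim=4, padding_idx=0)
out = emb(torch.tensor([0, 1, 2]))
out.sum().backward()
print(emb.weight.grad[0])  # all zeros: the padding_idx row is never updated
print(emb.weight.grad[1])  # nonzero: ordinary rows train as usual
```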
@@ -333,7 +333,7 @@ class AsrModel(nn.Module):
         simple_loss, pruned_loss = self.forward_transducer(
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
-            y=y,
+            y=y.to(x.device),
             y_lens=y_lens,
             prune_range=prune_range,
             am_scale=am_scale,
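This hunk and the compute_loss hunk below are two halves of one change: the host-to-device copy of the labels moves from the caller into the model, which fetches y onto x's device right where it is consumed. A hedged sketch of the pattern with a toy module (names and shapes are assumptions, not icefall's actual AsrModel):

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Toy stand-in for a model that moves labels to its input's device."""

    def __init__(self, dim: int = 8, vocab: int = 10):
        super().__init__()
        self.proj = nn.Linear(dim, vocab)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        y = y.to(x.device)  # mirrors y=y.to(x.device) in the diff
        logits = self.proj(x)  # (batch, vocab)
        return nn.functional.cross_entropy(logits, y)

model = TinyModel()
x = torch.randn(4, 8)            # inputs, possibly on a GPU in real use
y = torch.randint(0, 10, (4,))   # labels can stay on the CPU at the call site
loss = model(x, y)
```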
@@ -789,7 +789,7 @@ def compute_loss(

     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
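This is the caller side of the same change: compute_loss now builds the ragged label tensor on the CPU and leaves it there, and the single transfer happens inside the model (previous hunk). A small sketch of the relevant k2 behavior, assuming k2 is installed; k2.RaggedTensor and .to() come from the diff itself, while the toy values are mine:

```python
import torch
import k2

# A RaggedTensor built from Python lists lives on the CPU; .to() returns
# a copy on the target device, so one transfer at the point of use suffices.
y = k2.RaggedTensor([[1, 2], [3]])
print(y.device)  # cpu
if torch.cuda.is_available():
    print(y.to(torch.device("cuda", 0)).device)  # cuda:0
```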
@@ -2190,7 +2190,7 @@ class ConvolutionModule(nn.Module):

         x = self.in_proj(x)  # (time, batch, 2*channels)

-        x, s = x.chunk(2, dim=-1)
+        x, s = x.chunk(2, dim=2)
         s = self.sigmoid(s)
         x = x * s
         # (time, batch, channels)
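For a 3-D activation laid out as (time, batch, 2*channels), dim=-1 and dim=2 address the same axis, so this change is behavior-preserving; spelling the dimension positively is, presumably, friendlier to tracing and export tooling. A quick self-contained check, followed by the gating computation the module performs:

```python
import torch

x = torch.randn(5, 3, 8)          # (time, batch, 2*channels)
a1, s1 = x.chunk(2, dim=-1)
a2, s2 = x.chunk(2, dim=2)
assert torch.equal(a1, a2) and torch.equal(s1, s2)

s = torch.sigmoid(s2)             # gating branch, as in the module
out = a2 * s                      # (time, batch, channels)
print(out.shape)                  # torch.Size([5, 3, 4])
```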