mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
FIx
This commit is contained in:
parent
d92b6781b9
commit
f44d1b00b1
@ -789,7 +789,7 @@ def compute_loss(
|
|||||||
|
|
||||||
texts = batch["supervisions"]["text"]
|
texts = batch["supervisions"]["text"]
|
||||||
y = sp.encode(texts, out_type=int)
|
y = sp.encode(texts, out_type=int)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss, ctc_loss = model(
|
simple_loss, pruned_loss, ctc_loss = model(
|
||||||
|
@ -2138,7 +2138,7 @@ class ConvolutionModule(nn.Module):
|
|||||||
|
|
||||||
x = self.in_proj(x) # (time, batch, 2*channels)
|
x = self.in_proj(x) # (time, batch, 2*channels)
|
||||||
|
|
||||||
x, s = x.chunk(2, dim=-1)
|
x, s = x.chunk(2, dim=2)
|
||||||
s = self.balancer1(s)
|
s = self.balancer1(s)
|
||||||
s = self.sigmoid(s)
|
s = self.sigmoid(s)
|
||||||
x = self.activation1(x) # identity.
|
x = self.activation1(x) # identity.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user