mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Revert model.py so there are no constraints on the output.
This commit is contained in:
parent
45c38dec61
commit
d37c159174
@ -19,12 +19,10 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from scaling import random_clamp
|
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||||
"Sequence Transduction with Recurrent Neural Networks"
|
"Sequence Transduction with Recurrent Neural Networks"
|
||||||
@ -142,12 +140,6 @@ class Transducer(nn.Module):
|
|||||||
lm = self.simple_lm_proj(decoder_out)
|
lm = self.simple_lm_proj(decoder_out)
|
||||||
am = self.simple_am_proj(encoder_out)
|
am = self.simple_am_proj(encoder_out)
|
||||||
|
|
||||||
if self.training:
|
|
||||||
lm = random_clamp(lm, min=-8.0, max=2.0, prob=0.5,
|
|
||||||
reflect=0.1)
|
|
||||||
am = random_clamp(am, min=-5.0, max=5.0, prob=0.5,
|
|
||||||
reflect=0.1)
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
lm=lm.float(),
|
lm=lm.float(),
|
||||||
@ -183,10 +175,6 @@ class Transducer(nn.Module):
|
|||||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||||
|
|
||||||
if self.training:
|
|
||||||
logits = random_clamp(logits, -8.0, 2.0, prob=0.5,
|
|
||||||
reflect=0.1)
|
|
||||||
|
|
||||||
with torch.cuda.amp.autocast(enabled=False):
|
with torch.cuda.amp.autocast(enabled=False):
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
logits=logits.float(),
|
logits=logits.float(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user