Commit 8b7f43a027 (parent 725f6ddb9b) in k2-fsa/icefall
(mirror of https://github.com/k2-fsa/icefall.git)

Add backward decoder
@@ -38,6 +38,7 @@ class Decoder(nn.Module):
         embedding_dim: int,
         blank_id: int,
         context_size: int,
+        backward: bool = False,
     ):
         """
         Args:
@@ -61,6 +62,7 @@ class Decoder(nn.Module):
 
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.backward = backward
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
@@ -86,13 +88,20 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embeding_out = embeding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
-                )
+                if self.backward:
+                    assert self.context_size == 2
+                    embeding_out = F.pad(
+                        embeding_out, pad=(0, self.context_size - 1)
+                    )
+                else:
+                    embeding_out = F.pad(
+                        embeding_out, pad=(self.context_size - 1, 0)
+                    )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
                 assert embeding_out.size(-1) == self.context_size
+                assert self.backward is False
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
         return embeding_out
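The two branches above differ only in the padding direction: the forward decoder left-pads the label embeddings, so the Conv1d that follows (kernel size equal to context_size) covers the current and previous labels at each position, while the backward decoder right-pads, so the same convolution covers the current and following labels. A minimal standalone sketch (not part of the commit; the tensor values and the forward_padded/backward_padded names are illustrative) of what the two F.pad calls do to a toy (N, embedding_dim, U) tensor with context_size == 2:

    import torch
    import torch.nn.functional as F

    context_size = 2
    # Toy embedding output after the permute above: shape (N=1, embedding_dim=1, U=3)
    embeding_out = torch.arange(1.0, 4.0).reshape(1, 1, 3)

    # Forward decoder: pad on the left of the label axis.
    # tensor([[[0., 1., 2., 3.]]])
    forward_padded = F.pad(embeding_out, pad=(context_size - 1, 0))

    # Backward decoder: pad on the right of the label axis.
    # tensor([[[1., 2., 3., 0.]]])
    backward_padded = F.pad(embeding_out, pad=(0, context_size - 1))

    print(forward_padded)
    print(backward_padded)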
@@ -32,8 +32,12 @@ class Transducer(nn.Module):
         self,
         encoder: EncoderInterface,
         decoder: nn.Module,
+        backward_decoder: nn.Module,
         joiner: nn.Module,
+        backward_joiner: nn.Module,
         prune_range: int = 3,
+        lm_scale: float = 0.0,
+        am_scale: float = 0.0,
     ):
         """
         Args:
@@ -57,8 +61,12 @@ class Transducer(nn.Module):
 
         self.encoder = encoder
         self.decoder = decoder
+        self.backward_decoder = backward_decoder
         self.joiner = joiner
+        self.backward_joiner = backward_joiner
         self.prune_range = prune_range
+        self.lm_scale = lm_scale
+        self.am_scale = am_scale
 
     def forward(
         self,
@@ -109,22 +117,46 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
-            decoder_out, encoder_out, y_padded, blank_id, boundary, True
+        # calculate prune ranges
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            decoder_out,
+            encoder_out,
+            y_padded,
+            blank_id,
+            lm_only_scale=self.lm_scale,
+            am_only_scale=self.am_scale,
+            boundary=boundary,
+            return_grad=True,
         )
 
         ranges = k2.get_rnnt_prune_ranges(
             px_grad, py_grad, boundary, self.prune_range
         )
 
-        am_pruning, lm_pruning = k2.do_rnnt_pruning(
+        # forward loss
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
             encoder_out, decoder_out, ranges
         )
-
-        logits = self.joiner(am_pruning, lm_pruning)
-
-        pruning_loss = k2.rnnt_loss_pruned(
+        logits = self.joiner(am_pruned, lm_pruned)
+        pruned_loss = k2.rnnt_loss_pruned(
             logits, y_padded, ranges, blank_id, boundary
         )
-
-        return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
+        # backward loss
+        assert self.backward_decoder is not None
+        assert self.backward_joiner is not None
+        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
+            encoder_out, backward_decoder_out, ranges
+        )
+        backward_logits = self.backward_joiner(
+            backward_am_pruned, backward_lm_pruned
+        )
+        backward_pruned_loss = k2.rnnt_loss_pruned(
+            backward_logits, y_padded, ranges, blank_id, boundary
+        )
+
+        return (
+            -torch.sum(simple_loss),
+            -torch.sum(pruned_loss),
+            -torch.sum(backward_pruned_loss),
+        )
@@ -154,6 +154,22 @@ def get_parser():
         "we are using to compute the loss",
     )
 
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
+    )
+
     return parser
 
 
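For reference, a standalone sketch (not part of the commit) of how the two new flags parse. It rebuilds just these two arguments with argparse instead of importing the recipe's get_parser(), and the values 0.25 / 0.1 are arbitrary examples. The hyphenated flags become args.lm_scale and args.am_scale, which feed the lm_scale / am_scale values that get_transducer_model() passes to the Transducer constructor further down.

    import argparse

    # Mirrors the two parser.add_argument() calls added above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--lm-scale", type=float, default=0.0)
    parser.add_argument("--am-scale", type=float, default=0.0)

    args = parser.parse_args(["--lm-scale", "0.25", "--am-scale", "0.1"])
    print(args.lm_scale, args.am_scale)  # 0.25 0.1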
@@ -245,12 +261,13 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict, backward: bool = False):
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
+        backward=backward,
     )
     return decoder
 
@@ -267,11 +284,18 @@ def get_transducer_model(params: AttributeDict):
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
+    backward_decoder = get_decoder_model(params, backward=True)
+    backward_joiner = get_joiner_model(params)
 
     model = Transducer(
         encoder=encoder,
         decoder=decoder,
+        backward_decoder=backward_decoder,
         joiner=joiner,
+        backward_joiner=backward_joiner,
+        prune_range=params.prune_range,
+        lm_scale=params.lm_scale,
+        am_scale=params.am_scale,
     )
     return model
 
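A hypothetical sanity-check sketch (not part of the commit; it assumes params is the recipe's usual AttributeDict with vocab_size, blank_id, etc. already filled in). The point it illustrates: the forward and backward prediction networks are separate modules distinguished only by the backward flag, and the backward joiner is an independent instance of the same joiner architecture; nothing is weight-tied in this commit.

    model = get_transducer_model(params)

    # Separate modules, separate parameters.
    assert model.decoder is not model.backward_decoder
    assert model.joiner is not model.backward_joiner

    # Only the padding direction distinguishes the two decoders.
    assert model.decoder.backward is False
    assert model.backward_decoder.backward is True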
@@ -400,8 +424,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
-        loss = simple_loss + pruned_loss
+        simple_loss, pruned_loss, backward_pruned_loss = model(
+            x=feature, x_lens=feature_lens, y=y
+        )
+        loss = simple_loss + pruned_loss + backward_pruned_loss
 
     assert loss.requires_grad == is_training
 
@@ -412,6 +438,7 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    info["backward_pruned_loss"] = backward_pruned_loss.detach().cpu().item()
 
     return loss, info
 