Add backward decoder

pkufool 2022-01-26 14:27:41 +08:00
parent 725f6ddb9b
commit 8b7f43a027
3 changed files with 83 additions and 15 deletions


```diff
@@ -38,6 +38,7 @@ class Decoder(nn.Module):
         embedding_dim: int,
         blank_id: int,
         context_size: int,
+        backward: bool = False,
     ):
         """
         Args:
@@ -61,6 +62,7 @@ class Decoder(nn.Module):
         assert context_size >= 1, context_size
         self.context_size = context_size
+        self.backward = backward
         if context_size > 1:
             self.conv = nn.Conv1d(
                 in_channels=embedding_dim,
@@ -86,13 +88,20 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embeding_out = embeding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
-                )
+                if self.backward:
+                    assert self.context_size == 2
+                    embeding_out = F.pad(
+                        embeding_out, pad=(0, self.context_size - 1)
+                    )
+                else:
+                    embeding_out = F.pad(
+                        embeding_out, pad=(self.context_size - 1, 0)
+                    )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
                 assert embeding_out.size(-1) == self.context_size
+                assert self.backward is False
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
         return embeding_out
```
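
For context on the padding change: the forward decoder left-pads the symbol axis, so the `Conv1d` window ending at position `u` covers only past symbols, while the backward decoder right-pads, so the window starting at `u` covers only future symbols (and is restricted here to `context_size == 2`). A minimal sketch, not from this commit, with illustrative shapes and a sum kernel:

```python
import torch
import torch.nn.functional as F

context_size = 2
x = torch.arange(1.0, 6.0).reshape(1, 1, 5)  # (N, C, U): batch 1, 1 channel, 5 symbols

# Forward decoder: left-pad, so the window ending at u covers [u - 1, u].
left_padded = F.pad(x, pad=(context_size - 1, 0))   # -> [0, 1, 2, 3, 4, 5]

# Backward decoder: right-pad, so the window starting at u covers [u, u + 1].
right_padded = F.pad(x, pad=(0, context_size - 1))  # -> [1, 2, 3, 4, 5, 0]

conv = torch.nn.Conv1d(1, 1, kernel_size=context_size, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0)  # a "sum the window" kernel, for readability
    print(conv(left_padded))   # tensor([[[1., 3., 5., 7., 9.]]]) -- past context
    print(conv(right_padded))  # tensor([[[3., 5., 7., 9., 5.]]]) -- future context
```

Either way the output keeps length `U`, which is why the shared `conv` and `permute` calls below the branch are untouched.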


```diff
@@ -32,8 +32,12 @@ class Transducer(nn.Module):
         self,
         encoder: EncoderInterface,
         decoder: nn.Module,
+        backward_decoder: nn.Module,
         joiner: nn.Module,
+        backward_joiner: nn.Module,
         prune_range: int = 3,
+        lm_scale: float = 0.0,
+        am_scale: float = 0.0,
     ):
         """
         Args:
@@ -57,8 +61,12 @@ class Transducer(nn.Module):
         self.encoder = encoder
         self.decoder = decoder
+        self.backward_decoder = backward_decoder
         self.joiner = joiner
+        self.backward_joiner = backward_joiner
         self.prune_range = prune_range
+        self.lm_scale = lm_scale
+        self.am_scale = am_scale

     def forward(
         self,
@@ -109,22 +117,46 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens

-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
-            decoder_out, encoder_out, y_padded, blank_id, boundary, True
+        # calculate prune ranges
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            decoder_out,
+            encoder_out,
+            y_padded,
+            blank_id,
+            lm_only_scale=self.lm_scale,
+            am_only_scale=self.am_scale,
+            boundary=boundary,
+            return_grad=True,
         )

         ranges = k2.get_rnnt_prune_ranges(
             px_grad, py_grad, boundary, self.prune_range
         )

-        am_pruning, lm_pruning = k2.do_rnnt_pruning(
+        # forward loss
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
             encoder_out, decoder_out, ranges
         )

-        logits = self.joiner(am_pruning, lm_pruning)
+        logits = self.joiner(am_pruned, lm_pruned)

-        pruning_loss = k2.rnnt_loss_pruned(
+        pruned_loss = k2.rnnt_loss_pruned(
             logits, y_padded, ranges, blank_id, boundary
         )

-        return (-torch.sum(simple_loss), -torch.sum(pruning_loss))
+        # backward loss
+        assert self.backward_decoder is not None
+        assert self.backward_joiner is not None
+        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
+            encoder_out, backward_decoder_out, ranges
+        )
+        backward_logits = self.backward_joiner(
+            backward_am_pruned, backward_lm_pruned
+        )
+        backward_pruned_loss = k2.rnnt_loss_pruned(
+            backward_logits, y_padded, ranges, blank_id, boundary
+        )
+        return (
+            -torch.sum(simple_loss),
+            -torch.sum(pruned_loss),
+            -torch.sum(backward_pruned_loss),
+        )
```
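
Two details of this hunk are easy to miss. First, the `boundary` tensor filled in at its top follows k2's `[begin_symbol, begin_frame, end_symbol, end_frame]` row convention, which all three k2 loss/pruning calls consume; a standalone sketch with made-up lengths:

```python
import torch

# Hypothetical lengths for a batch of two utterances.
y_lens = torch.tensor([3, 2])   # symbols per utterance
x_lens = torch.tensor([10, 7])  # acoustic frames per utterance

# Rows are [begin_symbol, begin_frame, end_symbol, end_frame];
# the begin columns stay 0, so the end columns are just the lengths.
boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
print(boundary)
# tensor([[ 0,  0,  3, 10],
#         [ 0,  0,  2,  7]])
```

Second, the backward branch reuses the `ranges` computed once from the forward smoothed loss, so both joiners are evaluated on the same pruned (frame, symbol) windows rather than pruning twice.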


```diff
@@ -154,6 +154,22 @@ def get_parser():
         "we are using to compute the loss",
     )

+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with lm "
+        "(output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am "
+        "(output of encoder network) part.",
+    )
+
     return parser
```
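
A self-contained sketch of the two new flags in isolation (sample values are arbitrary; the real options live on the parser returned by `get_parser`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lm-scale",
    type=float,
    default=0.0,
    help="The scale to smooth the loss with lm "
    "(output of prediction network) part.",
)
parser.add_argument(
    "--am-scale",
    type=float,
    default=0.0,
    help="The scale to smooth the loss with am "
    "(output of encoder network) part.",
)

args = parser.parse_args(["--lm-scale", "0.25"])
print(args.lm_scale, args.am_scale)  # 0.25 0.0
```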
```diff
@@ -245,12 +261,13 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict, backward: bool = False):
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
+        backward=backward,
     )
     return decoder
@@ -267,11 +284,18 @@ def get_transducer_model(params: AttributeDict):
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
+    backward_decoder = get_decoder_model(params, backward=True)
+    backward_joiner = get_joiner_model(params)

     model = Transducer(
         encoder=encoder,
         decoder=decoder,
+        backward_decoder=backward_decoder,
         joiner=joiner,
+        backward_joiner=backward_joiner,
+        prune_range=params.prune_range,
+        lm_scale=params.lm_scale,
+        am_scale=params.am_scale,
     )
     return model
```
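
Note that the backward decoder and joiner come from second calls to the same constructors, so they share hyperparameters with the forward branch but not parameters. A toy check of that design choice; `make_joiner` is a hypothetical stand-in for `get_joiner_model`:

```python
import torch.nn as nn

def make_joiner(input_dim: int, output_dim: int) -> nn.Module:
    # Any module built twice from the same config yields
    # two independent parameter sets.
    return nn.Linear(input_dim, output_dim)

joiner = make_joiner(512, 500)
backward_joiner = make_joiner(512, 500)

assert joiner.weight.shape == backward_joiner.weight.shape
assert joiner.weight.data_ptr() != backward_joiner.weight.data_ptr()
```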
```diff
@@ -400,8 +424,10 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
-        loss = simple_loss + pruned_loss
+        simple_loss, pruned_loss, backward_pruned_loss = model(
+            x=feature, x_lens=feature_lens, y=y
+        )
+        loss = simple_loss + pruned_loss + backward_pruned_loss

     assert loss.requires_grad == is_training
@@ -412,6 +438,7 @@ def compute_loss(
         info["loss"] = loss.detach().cpu().item()
         info["simple_loss"] = simple_loss.detach().cpu().item()
         info["pruned_loss"] = pruned_loss.detach().cpu().item()
+        info["backward_pruned_loss"] = backward_pruned_loss.detach().cpu().item()

     return loss, info
```