Fix bugs in backward decoder

This commit is contained in:
pkufool 2022-01-27 17:34:44 +08:00
parent 8b7f43a027
commit 18f997fe51
4 changed files with 55 additions and 22 deletions

View File

@ -297,7 +297,7 @@ def beam_search(
current_encoder_out, decoder_out.unsqueeze(1) current_encoder_out, decoder_out.unsqueeze(1)
) )
# TODO(fangjun): Cache the blank posterior # TODO(fangjun): Scale the blank posterior
log_prob = logits.log_softmax(dim=-1) log_prob = logits.log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size) # log_prob is (1, 1, 1, vocab_size)

View File

@ -84,24 +84,30 @@ class Decoder(nn.Module):
Returns: Returns:
Return a tensor of shape (N, U, embedding_dim). Return a tensor of shape (N, U, embedding_dim).
""" """
embeding_out = self.embedding(y) embedding_out = self.embedding(y)
if self.context_size > 1: if self.context_size > 1:
embeding_out = embeding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True: if need_pad is True:
# If the input is [sos, a, b, c, d] and output is
# [a, b, c, d, eos], padding left and using kernel-size 2,
# it uses left context.
# If the input is [a, b, c, d, eos] and output is
# [sos, a, b, c, d], padding right and using kernel-size 2,
# it uses right context.
if self.backward: if self.backward:
assert self.context_size == 2 assert self.context_size == 2
embeding_out = F.pad( embedding_out = F.pad(
embeding_out, pad=(0, self.context_size - 1) embedding_out, pad=(0, self.context_size - 1)
) )
else: else:
embeding_out = F.pad( embedding_out = F.pad(
embeding_out, pad=(self.context_size - 1, 0) embedding_out, pad=(self.context_size - 1, 0)
) )
else: else:
# During inference time, there is no need to do extra padding # During inference time, there is no need to do extra padding
# as we only need one output # as we only need one output
assert embeding_out.size(-1) == self.context_size assert embedding_out.size(-1) == self.context_size
assert self.backward is False assert self.backward is False
embeding_out = self.conv(embeding_out) embedding_out = self.conv(embedding_out)
embeding_out = embeding_out.permute(0, 2, 1) embedding_out = embedding_out.permute(0, 2, 1)
return embeding_out return embedding_out

View File

@ -20,7 +20,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from encoder_interface import EncoderInterface from encoder_interface import EncoderInterface
from icefall.utils import add_sos from icefall.utils import add_eos, add_sos
class Transducer(nn.Module): class Transducer(nn.Module):
@ -50,10 +50,30 @@ class Transducer(nn.Module):
It is the prediction network in the paper. Its input shape It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`. one attribute: `blank_id`.
backward_decoder:
Almost the same as decoder, except that it uses right context and
the decoder uses left context.
joiner: joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax. unnormalized probs, i.e., not processed by log-softmax.
backward_joiner:
The same as joiner, it intends for backward_decoder.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
""" """
super().__init__() super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder) assert isinstance(encoder, EncoderInterface), type(encoder)
@ -110,7 +130,6 @@ class Transducer(nn.Module):
# Note: y does not start with SOS # Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0) y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros( boundary = torch.zeros(
(x.size(0), 4), dtype=torch.int64, device=x.device (x.size(0), 4), dtype=torch.int64, device=x.device
) )
@ -121,7 +140,7 @@ class Transducer(nn.Module):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
decoder_out, decoder_out,
encoder_out, encoder_out,
y_padded, y_padded.to(torch.int64),
blank_id, blank_id,
lm_only_scale=self.lm_scale, lm_only_scale=self.lm_scale,
am_only_scale=self.am_scale, am_only_scale=self.am_scale,
@ -138,13 +157,15 @@ class Transducer(nn.Module):
) )
logits = self.joiner(am_pruned, lm_pruned) logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned( pruned_loss = k2.rnnt_loss_pruned(
logits, y_padded, ranges, blank_id, boundary logits, y_padded.to(torch.int64), ranges, blank_id, boundary
) )
eos_y = add_eos(y, eos_id=blank_id)
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
# backward loss # backward loss
assert self.backward_decoder is not None assert self.backward_decoder is not None
assert self.backward_joiner is not None assert self.backward_joiner is not None
backward_decoder_out = self.backward_decoder(sos_y_padded) backward_decoder_out = self.backward_decoder(eos_y_padded)
backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning( backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
encoder_out, backward_decoder_out, ranges encoder_out, backward_decoder_out, ranges
) )
@ -152,7 +173,11 @@ class Transducer(nn.Module):
backward_am_pruned, backward_lm_pruned backward_am_pruned, backward_lm_pruned
) )
backward_pruned_loss = k2.rnnt_loss_pruned( backward_pruned_loss = k2.rnnt_loss_pruned(
backward_logits, y_padded, ranges, blank_id, boundary backward_logits,
sos_y_padded.to(torch.int64),
ranges,
blank_id,
boundary,
) )
return ( return (

View File

@ -227,7 +227,7 @@ def get_params() -> AttributeDict:
"log_interval": 50, "log_interval": 50,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800 "valid_interval": 3000, # For the 100h subset, use 800
"log_diagnostics": False, "log_diagnostics": True,
# parameters for conformer # parameters for conformer
"feature_dim": 80, "feature_dim": 80,
"encoder_out_dim": 512, "encoder_out_dim": 512,
@ -246,7 +246,7 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer( encoder = Conformer(
num_features=params.feature_dim, num_features=params.feature_dim,
@ -261,7 +261,9 @@ def get_encoder_model(params: AttributeDict):
return encoder return encoder
def get_decoder_model(params: AttributeDict, backward: bool = False): def get_decoder_model(
params: AttributeDict, backward: bool = False
) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, embedding_dim=params.encoder_out_dim,
@ -272,7 +274,7 @@ def get_decoder_model(params: AttributeDict, backward: bool = False):
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, input_dim=params.encoder_out_dim,
output_dim=params.vocab_size, output_dim=params.vocab_size,
@ -280,7 +282,7 @@ def get_joiner_model(params: AttributeDict):
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)