Fix bugs in backward decoder

pkufool 2022-01-27 17:34:44 +08:00
parent 8b7f43a027
commit 18f997fe51
4 changed files with 55 additions and 22 deletions

beam_search.py

@@ -297,7 +297,7 @@ def beam_search(
             current_encoder_out, decoder_out.unsqueeze(1)
         )
 
-        # TODO(fangjun): Cache the blank posterior
+        # TODO(fangjun): Scale the blank posterior
         log_prob = logits.log_softmax(dim=-1)
 
         # log_prob is (1, 1, 1, vocab_size)
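
As the inline comment says, the joiner output here is (1, 1, 1, vocab_size). A minimal sketch of the log_softmax step, with a toy vocab size and random logits (not part of the diff):

    import torch

    logits = torch.randn(1, 1, 1, 500)     # toy joiner output; vocab_size = 500
    log_prob = logits.log_softmax(dim=-1)  # same shape, (1, 1, 1, 500)
    print(log_prob.exp().sum())            # ~1.0: a normalized distribution over tokens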

decoder.py

@@ -84,24 +84,30 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
+                # If the input is [sos, a, b, c, d] and the output is
+                # [a, b, c, d, eos], padding left and using kernel-size 2
+                # gives left context.
+                # If the input is [a, b, c, d, eos] and the output is
+                # [sos, a, b, c, d], padding right and using kernel-size 2
+                # gives right context.
                 if self.backward:
                     assert self.context_size == 2
-                    embeding_out = F.pad(
-                        embeding_out, pad=(0, self.context_size - 1)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(0, self.context_size - 1)
                     )
                 else:
-                    embeding_out = F.pad(
-                        embeding_out, pad=(self.context_size - 1, 0)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(self.context_size - 1, 0)
                     )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
+                assert embedding_out.size(-1) == self.context_size
                 assert self.backward is False
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
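
A minimal sketch of the padding trick the comments above describe, using a toy tensor and assuming context_size = 2 (mirroring the assert):

    import torch
    import torch.nn.functional as F

    # After permute(0, 2, 1) the embedding is (N, embedding_dim, U), so
    # F.pad(x, pad=(left, right)) pads along the symbol axis U.
    embedding_out = torch.arange(12.0).reshape(1, 3, 4)  # N=1, dim=3, U=4

    # Forward decoder: pad left; a Conv1d with kernel_size=2 then sees
    # (previous symbol, current symbol) at each step -> left context.
    left_padded = F.pad(embedding_out, pad=(2 - 1, 0))   # (1, 3, 5)

    # Backward decoder: pad right; the same conv sees
    # (current symbol, next symbol) at each step -> right context.
    right_padded = F.pad(embedding_out, pad=(0, 2 - 1))  # (1, 3, 5)

With kernel_size equal to context_size, the Conv1d brings the padded length back to U, so the decoder output remains (N, U, embedding_dim) in both directions.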

model.py

@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 
-from icefall.utils import add_sos
+from icefall.utils import add_eos, add_sos
 
 
 class Transducer(nn.Module):
@@ -50,10 +50,30 @@ class Transducer(nn.Module):
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
             one attribute: `blank_id`.
+          backward_decoder:
+            Almost the same as decoder, except that it uses right context
+            while the decoder uses left context.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
+          backward_joiner:
+            The same as joiner, but intended for backward_decoder.
+          prune_range:
+            The prune range for the rnnt loss, i.e., how many symbols
+            (context) we consider for each frame when computing the loss.
+          am_scale:
+            The scale to smooth the loss with the am (output of the
+            encoder network) part.
+          lm_scale:
+            The scale to smooth the loss with the lm (output of the
+            prediction network) part.
+
+        Note:
+          Regarding am_scale & lm_scale, they make the loss function take
+          the form:
+            lm_scale * lm_probs + am_scale * am_probs
+              + (1 - lm_scale - am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
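
To make the Note above concrete, here is a sketch of the smoothed-loss call with toy shapes and made-up scale values (the real call appears two hunks below; nothing here is part of the diff):

    import k2
    import torch

    # Toy sizes: N=2 utterances, T=10 frames, U=5 symbols, vocab C=20.
    lm = torch.randn(2, 6, 20)   # decoder output, (N, U + 1, C); +1 for the sos prefix
    am = torch.randn(2, 10, 20)  # encoder output, (N, T, C)
    symbols = torch.randint(1, 20, (2, 5), dtype=torch.int64)
    boundary = torch.zeros(2, 4, dtype=torch.int64)
    boundary[:, 2] = 5   # symbols per utterance
    boundary[:, 3] = 10  # frames per utterance

    # lm_only_scale weights the lm-only term, am_only_scale the am-only
    # term; the remaining 1 - lm_scale - am_scale weights the combined term.
    loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm,
        am,
        symbols,
        0,  # blank_id as the termination symbol
        lm_only_scale=0.25,  # made-up value
        am_only_scale=0.0,   # made-up value
        boundary=boundary,
        return_grad=True,
    )

The returned px_grad/py_grad are what feed k2.get_rnnt_prune_ranges to obtain the `ranges` used by the pruned loss later in this file.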
@@ -110,7 +130,6 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)
 
-        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros(
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
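
The hunk ends before boundary is filled in. In k2's convention each row is [s_begin, t_begin, s_end, t_end], so presumably the elided lines set the last two columns to the label and frame counts; a sketch (y_lens and x_lens are assumed names, not shown in this diff):

    boundary[:, 2] = y_lens  # number of symbols in each utterance
    boundary[:, 3] = x_lens  # number of acoustic frames in each utterance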
@@ -121,7 +140,7 @@
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
             decoder_out,
             encoder_out,
-            y_padded,
+            y_padded.to(torch.int64),
             blank_id,
             lm_only_scale=self.lm_scale,
             am_only_scale=self.am_scale,
@@ -138,13 +157,15 @@
         )
         logits = self.joiner(am_pruned, lm_pruned)
         pruned_loss = k2.rnnt_loss_pruned(
-            logits, y_padded, ranges, blank_id, boundary
+            logits, y_padded.to(torch.int64), ranges, blank_id, boundary
         )
 
+        eos_y = add_eos(y, eos_id=blank_id)
+        eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
         # backward loss
         assert self.backward_decoder is not None
         assert self.backward_joiner is not None
-        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_decoder_out = self.backward_decoder(eos_y_padded)
         backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
             encoder_out, backward_decoder_out, ranges
         )
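
The fix above feeds the backward decoder [a, b, c, d, eos] instead of [sos, a, b, c, d]. A small sketch of what add_sos/add_eos from icefall.utils produce on a toy ragged tensor, assuming blank_id = 0 doubles as sos/eos as in the hunk:

    import k2
    from icefall.utils import add_eos, add_sos

    y = k2.RaggedTensor([[2, 3], [5]])
    sos_y = add_sos(y, sos_id=0)  # [[0, 2, 3], [0, 5]]
    eos_y = add_eos(y, eos_id=0)  # [[2, 3, 0], [5, 0]]

    # Pad to rectangular tensors for the decoders:
    sos_y_padded = sos_y.pad(mode="constant", padding_value=0)  # [[0, 2, 3], [0, 5, 0]]
    eos_y_padded = eos_y.pad(mode="constant", padding_value=0)  # [[2, 3, 0], [5, 0, 0]]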
@@ -152,7 +173,11 @@
             backward_am_pruned, backward_lm_pruned
         )
         backward_pruned_loss = k2.rnnt_loss_pruned(
-            backward_logits, y_padded, ranges, blank_id, boundary
+            backward_logits,
+            sos_y_padded.to(torch.int64),
+            ranges,
+            blank_id,
+            boundary,
         )
 
         return (

train.py

@@ -227,7 +227,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            "log_diagnostics": False,
+            "log_diagnostics": True,
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -246,7 +246,7 @@
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -261,7 +261,9 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict, backward: bool = False):
+def get_decoder_model(
+    params: AttributeDict, backward: bool = False
+) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -272,7 +274,7 @@ def get_decoder_model(params: AttributeDict, backward: bool = False):
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -280,7 +282,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
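
The displayed context stops before the backward components are built. Presumably they use the new backward flag from get_decoder_model; a hypothetical sketch, not the verbatim continuation of this file:

    # Hypothetical usage of the factories above (not part of this diff):
    backward_decoder = get_decoder_model(params, backward=True)
    backward_joiner = get_joiner_model(params)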