icefall (mirror of https://github.com/k2-fsa/icefall.git)

commit 18f997fe51
parent 8b7f43a027

    Fix bugs in backward decoder
@@ -297,7 +297,7 @@ def beam_search(
             current_encoder_out, decoder_out.unsqueeze(1)
         )

-        # TODO(fangjun): Cache the blank posterior
+        # TODO(fangjun): Scale the blank posterior

         log_prob = logits.log_softmax(dim=-1)
         # log_prob is (1, 1, 1, vocab_size)
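The TODO above is only a marker; this commit does not implement the scaling itself. A minimal sketch of what scaling the blank posterior could look like, assuming a hypothetical blank_scale parameter (not defined anywhere in this commit) and the (1, 1, 1, vocab_size) log_prob shape noted above:

import math

import torch


def scale_blank_log_prob(
    log_prob: torch.Tensor, blank_id: int, blank_scale: float
) -> torch.Tensor:
    # log_prob has shape (1, 1, 1, vocab_size). Multiplying the blank
    # probability by blank_scale corresponds to adding log(blank_scale)
    # in log space; note the result is no longer a normalized
    # distribution.
    log_prob = log_prob.clone()
    log_prob[..., blank_id] += math.log(blank_scale)
    return log_prob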
@@ -84,24 +84,30 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
+                # If the input is [sos, a, b, c, d] and the output is
+                # [a, b, c, d, eos], padding on the left with kernel size 2
+                # uses left context.
+                # If the input is [a, b, c, d, eos] and the output is
+                # [sos, a, b, c, d], padding on the right with kernel size 2
+                # uses right context.
                 if self.backward:
                     assert self.context_size == 2
-                    embeding_out = F.pad(
-                        embeding_out, pad=(0, self.context_size - 1)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(0, self.context_size - 1)
                     )
                 else:
-                    embeding_out = F.pad(
-                        embeding_out, pad=(self.context_size - 1, 0)
+                    embedding_out = F.pad(
+                        embedding_out, pad=(self.context_size - 1, 0)
                     )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
+                assert embedding_out.size(-1) == self.context_size
                 assert self.backward is False
-        embeding_out = self.conv(embeding_out)
-        embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+        embedding_out = self.conv(embedding_out)
+        embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
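The left/right-context comment added above is easy to verify with a standalone sketch. Nothing below is from the repository; it only shows how the padding side determines whether a Conv1d with kernel size 2 sees left or right context:

import torch
import torch.nn as nn
import torch.nn.functional as F

# A 1-channel Conv1d with kernel size 2 and no built-in padding,
# standing in for the decoder's context convolution.
conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=2, bias=False)
with torch.no_grad():
    conv.weight.fill_(1.0)  # output[t] = input[t] + input[t + 1]

x = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])  # (N, C, U)

# Forward decoder: pad on the left, so output[t] sees input[t - 1]
# and input[t], i.e., left context only.
left = conv(F.pad(x, pad=(1, 0)))   # [[[1., 3., 5., 7.]]]

# Backward decoder: pad on the right, so output[t] sees input[t]
# and input[t + 1], i.e., right context only.
right = conv(F.pad(x, pad=(0, 1)))  # [[[3., 5., 7., 4.]]]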
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface

-from icefall.utils import add_sos
+from icefall.utils import add_eos, add_sos


 class Transducer(nn.Module):
@@ -50,10 +50,30 @@ class Transducer(nn.Module):
           It is the prediction network in the paper. Its input shape
           is (N, U) and its output shape is (N, U, C). It should contain
           one attribute: `blank_id`.
+        backward_decoder:
+          Almost the same as decoder, except that it uses right context
+          whereas the decoder uses left context.
         joiner:
           It has two inputs with shapes: (N, T, C) and (N, U, C). Its
           output shape is (N, T, U, C). Note that its output contains
           unnormalized probs, i.e., not processed by log-softmax.
+        backward_joiner:
+          The same as joiner, but intended for the backward_decoder.
+        prune_range:
+          The prune range for the RNN-T loss, i.e., how many symbols
+          (context) are considered for each frame when computing the loss.
+        am_scale:
+          The scale used to smooth the loss with the am (output of the
+          encoder network) part.
+        lm_scale:
+          The scale used to smooth the loss with the lm (output of the
+          predictor network) part.
+
+        Note:
+          Regarding am_scale and lm_scale, they make the loss function
+          take the form:
+            lm_scale * lm_probs + am_scale * am_probs
+            + (1 - lm_scale - am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
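The Note above describes the interpolation performed by k2.rnnt_loss_smoothed. As a plain-Python sketch of the weighting (the actual computation runs over lattice scores inside k2, not scalar probabilities):

def smoothed(lm_probs, am_probs, combined_probs, lm_scale, am_scale):
    # The combined (joint) term receives whatever weight remains
    # after the lm-only and am-only terms are taken out.
    assert 0.0 <= lm_scale + am_scale <= 1.0
    return (
        lm_scale * lm_probs
        + am_scale * am_probs
        + (1.0 - lm_scale - am_scale) * combined_probs
    )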
@@ -110,7 +130,6 @@ class Transducer(nn.Module):
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

-        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros(
             (x.size(0), 4), dtype=torch.int64, device=x.device
         )
@@ -121,7 +140,7 @@ class Transducer(nn.Module):
         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
             decoder_out,
             encoder_out,
-            y_padded,
+            y_padded.to(torch.int64),
             blank_id,
             lm_only_scale=self.lm_scale,
             am_only_scale=self.am_scale,
@@ -138,13 +157,15 @@ class Transducer(nn.Module):
         )
         logits = self.joiner(am_pruned, lm_pruned)
         pruned_loss = k2.rnnt_loss_pruned(
-            logits, y_padded, ranges, blank_id, boundary
+            logits, y_padded.to(torch.int64), ranges, blank_id, boundary
         )

+        eos_y = add_eos(y, eos_id=blank_id)
+        eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
         # backward loss
         assert self.backward_decoder is not None
         assert self.backward_joiner is not None
-        backward_decoder_out = self.backward_decoder(sos_y_padded)
+        backward_decoder_out = self.backward_decoder(eos_y_padded)
         backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
             encoder_out, backward_decoder_out, ranges
         )
@@ -152,7 +173,11 @@ class Transducer(nn.Module):
             backward_am_pruned, backward_lm_pruned
         )
         backward_pruned_loss = k2.rnnt_loss_pruned(
-            backward_logits, y_padded, ranges, blank_id, boundary
+            backward_logits,
+            sos_y_padded.to(torch.int64),
+            ranges,
+            blank_id,
+            boundary,
         )

         return (
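For reference, the forward decoder consumes sos_y_padded ([sos, a, b, c, d], predicting the next symbol), while the fixed backward pass now consumes eos_y_padded ([a, b, c, d, eos], predicting the previous symbol), with sos_y_padded serving as the backward labels. A sketch with plain Python lists (in the code, y is a ragged tensor and add_sos/add_eos come from icefall.utils):

blank_id = 0
y = [1, 2, 3, 4]  # one utterance: [a, b, c, d]

sos_y = [blank_id] + y  # [sos, a, b, c, d] -> input to decoder
eos_y = y + [blank_id]  # [a, b, c, d, eos] -> input to backward_decoder

# Labels: the forward loss is computed against y itself, while the
# backward loss is computed against sos_y (shifted by one), matching
# the sos_y_padded passed to k2.rnnt_loss_pruned above.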
@@ -227,7 +227,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
-            "log_diagnostics": False,
+            "log_diagnostics": True,
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -246,7 +246,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -261,7 +261,9 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict, backward: bool = False):
+def get_decoder_model(
+    params: AttributeDict, backward: bool = False
+) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -272,7 +274,7 @@ def get_decoder_model(params: AttributeDict, backward: bool = False):
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -280,7 +282,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)