Update aishell with k2 pruned rnnt loss

Author: pkufool
Date:   2022-01-20 11:42:02 +08:00
Parent: f94ff19bfe
Commit: e46409e90f

6 changed files with 90 additions and 51 deletions
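At a high level, the commit drops torchaudio's `rnnt_loss` and adopts the pruned RNN-T loss from k2: the encoder's `output_dim` and the joiner's `input_dim` become `vocab_size`, and the decoder gains an `output_linear` projection to `vocab_size`, so the "am" (encoder) and "lm" (decoder) outputs can be combined directly by `k2.rnnt_loss_smoothed`. A smoothed "simple" loss is computed over the full lattice, its gradients define a pruning range, and the joiner is only evaluated inside that pruned region. The sketch below walks through the k2 call sequence this commit adopts, using random tensors; every dimension (N, T, S, C), the `prune_range` value, and the plain sum standing in for the recipe's Joiner module are illustrative assumptions, not values from the recipe.

# Hedged sketch of the pruned RNN-T loss flow adopted in this commit.
import k2
import torch

N, T, S, C = 2, 50, 10, 100     # batch, frames, symbols, vocab_size (placeholders)
prune_range, blank_id = 5, 0

am = torch.randn(N, T, C)        # encoder output, projected to vocab_size
lm = torch.randn(N, S + 1, C)    # decoder output (with SOS), projected to vocab_size
symbols = torch.randint(1, C, (N, S), dtype=torch.int64)
boundary = torch.zeros((N, 4), dtype=torch.int64)
boundary[:, 2] = S               # y_lens
boundary[:, 3] = T               # x_lens

# 1. Smoothed "simple" loss over the full lattice; the returned gradients
#    indicate where the probable alignments are.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm, am, symbols, blank_id,
    lm_only_scale=0.0, am_only_scale=0.0,
    boundary=boundary, return_grad=True,
)

# 2. Pick `prune_range` symbol positions per frame from those gradients.
ranges = k2.get_rnnt_prune_ranges(px_grad, py_grad, boundary, prune_range)

# 3. Gather the pruned am/lm slices, join them (a plain sum here instead of the
#    recipe's Joiner), and evaluate the exact loss on the pruned lattice only.
am_pruned, lm_pruned = k2.do_rnnt_pruning(am, lm, ranges)
logits = am_pruned + lm_pruned   # (N, T, prune_range, C)
pruned_loss = k2.rnnt_loss_pruned(logits, symbols, ranges, blank_id, boundary)

loss = -(simple_loss.sum() + pruned_loss.sum())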

View File

@@ -73,9 +73,9 @@ def greedy_search(
             continue
         # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
+        current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
         # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
+        logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1))
         # logits is (1, 1, 1, vocab_size)
         y = logits.argmax().item()
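Because the joiner no longer unsqueezes its inputs internally (see the joiner changes below), greedy search must supply the singleton time and symbol axes itself. A minimal shape check of the new calling convention, mimicking the joiner's additive combination; `vocab_size` is an arbitrary placeholder, and in this recipe both encoder and decoder already project to the vocabulary:

import torch

vocab_size = 100                               # placeholder
encoder_out = torch.randn(1, 20, vocab_size)   # (N, T, C) from the encoder
decoder_out = torch.randn(1, 1, vocab_size)    # (N, 1, C) from the decoder
t = 0

current_encoder_out = encoder_out[:, t:t + 1, :].unsqueeze(2)  # (1, 1, 1, C)
logits = current_encoder_out + decoder_out.unsqueeze(1)        # (1, 1, 1, C)
assert logits.shape == (1, 1, 1, vocab_size)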

View File

@@ -128,7 +128,7 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "encoder_out_dim": 512,
+            "embedding_dim": 256,
             "subsampling_factor": 4,
             "attention_dim": 512,
             "nhead": 8,
@@ -145,7 +145,7 @@ def get_encoder_model(params: AttributeDict):
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -159,7 +159,7 @@ def get_encoder_model(params: AttributeDict):
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -168,8 +168,9 @@ def get_decoder_model(params: AttributeDict):
 def get_joiner_model(params: AttributeDict):
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
         output_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
     )
     return joiner
@@ -408,7 +409,7 @@ def main():
         device=device,
     )
-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)

View File

@@ -70,6 +70,7 @@ class Decoder(nn.Module):
             groups=embedding_dim,
             bias=False,
         )
+        self.output_linear = nn.Linear(embedding_dim, vocab_size)

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -80,7 +81,7 @@
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
-          Return a tensor of shape (N, U, embedding_dim).
+          Return a tensor of shape (N, U, vocab_size).
         """
         embeding_out = self.embedding(y)
         if self.context_size > 1:
@@ -95,4 +96,5 @@
                 assert embeding_out.size(-1) == self.context_size
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
+        embeding_out = self.output_linear(embeding_out)
         return embeding_out

View File

@@ -19,10 +19,12 @@ import torch.nn as nn
 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()

-        self.output_linear = nn.Linear(input_dim, output_dim)
+        self.output_linear = nn.Sequential(
+            nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim)
+        )

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -36,15 +38,8 @@
         Returns:
           Return a tensor of shape (N, T, U, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 3
-        assert encoder_out.size(0) == decoder_out.size(0)
-        assert encoder_out.size(2) == decoder_out.size(2)
-
-        encoder_out = encoder_out.unsqueeze(2)
-        # Now encoder_out is (N, T, 1, C)
-        decoder_out = decoder_out.unsqueeze(1)
-        # Now decoder_out is (N, 1, U, C)
+        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.shape == decoder_out.shape

         logit = encoder_out + decoder_out
         logit = torch.tanh(logit)
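The joiner now expects its two inputs to arrive with identical 4-D shapes, which is exactly what `k2.do_rnnt_pruning` produces, and the single linear layer becomes a two-layer bottleneck (input_dim to inner_dim to output_dim). A small sketch of the new contract, assuming the Joiner class from this file is in scope; all dimensions are placeholders (in the recipe input_dim and output_dim are vocab_size and inner_dim is embedding_dim):

import torch

N, T, s_range, vocab_size, inner_dim = 2, 50, 5, 100, 256   # placeholders

joiner = Joiner(input_dim=vocab_size, inner_dim=inner_dim, output_dim=vocab_size)
am_pruned = torch.randn(N, T, s_range, vocab_size)   # e.g. from k2.do_rnnt_pruning
lm_pruned = torch.randn(N, T, s_range, vocab_size)   # same shape as am_pruned
logits = joiner(am_pruned, lm_pruned)                # (N, T, s_range, vocab_size)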

View File

@@ -1,4 +1,4 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -14,15 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-"""
-Note we use `rnnt_loss` from torchaudio, which exists only in
-torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
-"""
 import k2
 import torch
 import torch.nn as nn
-import torchaudio
-import torchaudio.functional
 from encoder_interface import EncoderInterface

 from icefall.utils import add_sos
@@ -38,6 +32,9 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
+        prune_range: int = 5,
+        lm_scale: float = 0.0,
+        am_scale: float = 0.0,
     ):
         """
         Args:
@@ -62,6 +59,9 @@
         self.encoder = encoder
         self.decoder = decoder
         self.joiner = joiner
+        self.prune_range = prune_range
+        self.lm_scale = lm_scale
+        self.am_scale = am_scale

     def forward(
         self,
@@ -102,24 +102,38 @@
         decoder_out = self.decoder(sos_y_padded)

-        logits = self.joiner(encoder_out, decoder_out)
-
-        # rnnt_loss requires 0 padded targets
         # Note: y does not start with SOS
         y_padded = y.pad(mode="constant", padding_value=0)

-        assert hasattr(torchaudio.functional, "rnnt_loss"), (
-            f"Current torchaudio version: {torchaudio.__version__}\n"
-            "Please install a version >= 0.10.0"
-        )
-
-        loss = torchaudio.functional.rnnt_loss(
-            logits=logits,
-            targets=y_padded,
-            logit_lengths=x_lens,
-            target_lengths=y_lens,
-            blank=blank_id,
-            reduction="sum",
-        )
-
-        return loss
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros(
+            (x.size(0), 4), dtype=torch.int64, device=x.device
+        )
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+            decoder_out,
+            encoder_out,
+            y_padded,
+            blank_id,
+            lm_only_scale=self.lm_scale,
+            am_only_scale=self.am_scale,
+            boundary=boundary,
+            return_grad=True,
+        )
+
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad, py_grad, boundary, self.prune_range
+        )
+
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            encoder_out, decoder_out, ranges
+        )
+
+        logits = self.joiner(am_pruned, lm_pruned)
+
+        pruned_loss = k2.rnnt_loss_pruned(
+            logits, y_padded, ranges, blank_id, boundary
+        )
+
+        return (-torch.sum(simple_loss), -torch.sum(pruned_loss))

View File

@@ -38,7 +38,6 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -129,6 +128,28 @@ def get_parser():
         "2 means tri-gram",
     )

+    parser.add_argument(
+        "--prune-range",
+        type=int,
+        default=5,
+        help="The prune range for rnnt loss, it means how many symbols(context)"
+        "we are using to compute the loss",
+    )
+
+    parser.add_argument(
+        "--lm-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with lm (output of prediction network) part.",
+    )
+
+    parser.add_argument(
+        "--am-scale",
+        type=float,
+        default=0.0,
+        help="The scale to smooth the loss with am (output of encoder network) part.",
+    )
+
     return parser
@@ -185,18 +206,19 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 800,
             # parameters for conformer
             "feature_dim": 80,
-            "encoder_out_dim": 512,
             "subsampling_factor": 4,
             "attention_dim": 512,
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
+            # parameters for decoder
+            "embedding_dim": 256,
             # parameters for Noam
-            "warm_step": 80000,  # For the 100h subset, use 8k
+            "warm_step": 30000,
             "env_info": get_env_info(),
         }
     )
@@ -208,7 +230,7 @@ def get_encoder_model(params: AttributeDict):
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -222,7 +244,7 @@ def get_encoder_model(params: AttributeDict):
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
@@ -231,7 +253,8 @@ def get_decoder_model(params: AttributeDict):
 def get_joiner_model(params: AttributeDict):
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
@@ -246,6 +269,9 @@ def get_transducer_model(params: AttributeDict):
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        prune_range=params.prune_range,
+        lm_scale=params.lm_scale,
+        am_scale=params.am_scale,
     )
     return model
@@ -374,7 +400,8 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        simple_loss, pruned_loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = simple_loss + pruned_loss

     assert loss.requires_grad == is_training
@@ -383,6 +410,8 @@
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
+    info["simple_loss"] = simple_loss.detach().cpu().item()
+    info["pruned_loss"] = pruned_loss.detach().cpu().item()

     return loss, info
@@ -476,7 +505,6 @@ def train_one_epoch(
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

         if batch_idx % params.log_interval == 0:
@@ -555,10 +583,9 @@ def run(rank, world_size, args):
     graph_compiler = CharCtcTrainingGraphCompiler(
         lexicon=lexicon,
         device=device,
-        oov="<unk>",
     )

-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

     logging.info(params)