diff --git a/egs/librispeech/ASR/zipformer/attention_decoder.py b/egs/librispeech/ASR/zipformer/attention_decoder.py
new file mode 100644
index 000000000..0ee24f9a8
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/attention_decoder.py
@@ -0,0 +1,485 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The model structure is modified from Daniel Povey's Zipformer
+# https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+
+import math
+from typing import List
+
+import k2
+import torch
+import torch.nn as nn
+from label_smoothing import LabelSmoothingLoss
+
+from icefall.utils import add_eos, add_sos, make_pad_mask
+
+
+class AttentionDecoderModel(nn.Module):
+    """
+    Args:
+      vocab_size (int): number of classes.
+      decoder_dim (int): embedding dimension of the decoder.
+      num_decoder_layers (int): number of decoder layers.
+      attention_dim (int): total dimension of the multi-head attention.
+      nhead (int): number of attention heads.
+      feedforward_dim (int): hidden dimension of the feed-forward module.
+      sos_id (int): start-of-sequence token id.
+      eos_id (int): end-of-sequence token id.
+      dropout (float): dropout rate.
+      ignore_id (int): id used to pad the target sequences; ignored in the loss.
+      label_smoothing (float): label-smoothing rate.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        decoder_dim: int = 512,
+        num_decoder_layers: int = 6,
+        attention_dim: int = 512,
+        nhead: int = 8,
+        feedforward_dim: int = 2048,
+        sos_id: int = 1,
+        eos_id: int = 1,
+        dropout: float = 0.1,
+        ignore_id: int = -1,
+        label_smoothing: float = 0.1,
+    ):
+        super().__init__()
+        self.eos_id = eos_id
+        self.sos_id = sos_id
+        self.ignore_id = ignore_id
+
+        # Transformer decoder that attends to the encoder output
+        self.decoder = TransformerDecoder(
+            vocab_size=vocab_size,
+            d_model=decoder_dim,
+            num_decoder_layers=num_decoder_layers,
+            attention_dim=attention_dim,
+            nhead=nhead,
+            feedforward_dim=feedforward_dim,
+            dropout=dropout,
+        )
+
+        # Used to calculate attention-decoder loss
+        self.loss_fun = LabelSmoothingLoss(
+            ignore_index=ignore_id, label_smoothing=label_smoothing, reduction="sum"
+        )
+
+    def _pre_ys_in_out(self, ys: k2.RaggedTensor, ys_lens: torch.Tensor):
+        """Prepare ys_in_pad and ys_out_pad."""
+        ys_in = add_sos(ys, sos_id=self.sos_id)
+        # [B, S+1], start with SOS
+        ys_in_pad = ys_in.pad(mode="constant", padding_value=self.eos_id)
+        ys_in_lens = ys_lens + 1
+
+        ys_out = add_eos(ys, eos_id=self.eos_id)
+        # [B, S+1], end with EOS
+        ys_out_pad = ys_out.pad(mode="constant", padding_value=self.ignore_id)
+
+        return ys_in_pad.to(torch.int64), ys_in_lens, ys_out_pad.to(torch.int64)
+
+    def calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys: k2.RaggedTensor,
+        ys_lens: torch.Tensor,
+    ) -> torch.Tensor:
+        """Calculate attention-decoder loss.
+        Args:
+          encoder_out: (batch, num_frames, encoder_dim)
+          encoder_out_lens: (batch,)
+          ys: a ragged tensor containing the target token ids.
+          ys_lens: (batch,), number of tokens in each target sequence.
+
+        Return: The attention-decoder loss.
+        """
+        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
+
+        # decoder forward
+        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+
+        loss = self.loss_fun(x=decoder_out, target=ys_out_pad)
+        return loss
+
+    def nll(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        token_ids: List[List[int]],
+    ) -> torch.Tensor:
+        """Compute negative log likelihood (nll) from the attention-decoder.
+        Args:
+          encoder_out: (batch, num_frames, encoder_dim)
+          encoder_out_lens: (batch,)
+          token_ids: A list of token id lists.
+
+        Return: A tensor of shape (batch, num_tokens).
+        """
+        ys = k2.RaggedTensor(token_ids).to(device=encoder_out.device)
+        row_splits = ys.shape.row_splits(1)
+        ys_lens = row_splits[1:] - row_splits[:-1]
+
+        ys_in_pad, ys_in_lens, ys_out_pad = self._pre_ys_in_out(ys, ys_lens)
+
+        # decoder forward
+        decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens)
+
+        batch_size, _, num_classes = decoder_out.size()
+        nll = nn.functional.cross_entropy(
+            decoder_out.view(-1, num_classes),
+            ys_out_pad.view(-1),
+            ignore_index=self.ignore_id,
+            reduction="none",
+        )
+        nll = nll.view(batch_size, -1)
+        return nll
+
+
+class TransformerDecoder(nn.Module):
+    """Transformer decoder module.
+    It is modified from https://github.com/espnet/espnet/blob/master/espnet2/asr/decoder/transformer_decoder.py.
+
+    Args:
+      vocab_size: output dim
+      d_model: decoder dimension
+      num_decoder_layers: number of decoder layers
+      attention_dim: total dimension of multi-head attention
+      nhead: number of attention heads
+      feedforward_dim: hidden dimension of feed_forward module
+      dropout: dropout rate
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        d_model: int = 512,
+        num_decoder_layers: int = 6,
+        attention_dim: int = 512,
+        nhead: int = 8,
+        feedforward_dim: int = 2048,
+        dropout: float = 0.1,
+    ):
+        super().__init__()
+        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
+
+        # Using absolute positional encoding
+        self.pos = PositionalEncoding(d_model, dropout_rate=0.1)
+
+        self.num_layers = num_decoder_layers
+        self.layers = nn.ModuleList(
+            [
+                DecoderLayer(d_model, attention_dim, nhead, feedforward_dim, dropout)
+                for _ in range(num_decoder_layers)
+            ]
+        )
+
+        self.output_layer = nn.Linear(d_model, vocab_size)
+
+    def forward(
+        self,
+        memory: torch.Tensor,
+        memory_lens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward decoder.
+
+        Args:
+          memory: encoded memory, (batch, maxlen_in, feat)
+          memory_lens: (batch,)
+          ys_in_pad: input token ids, (batch, maxlen_out)
+          ys_in_lens: (batch,)
+
+        Returns:
+          tgt: decoded token scores before softmax, (batch, maxlen_out, vocab_size)
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = make_pad_mask(ys_in_lens)[:, None, :].to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask | (~m)
+
+        memory_mask = make_pad_mask(memory_lens)[:, None, :].to(memory.device)
+
+        tgt = self.embed(tgt)
+        tgt = self.pos(tgt)
+
+        for mod in self.layers:
+            tgt = mod(tgt, tgt_mask, memory, memory_mask)
+
+        tgt = self.output_layer(tgt)
+        return tgt
+
+
+class DecoderLayer(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+      d_model: equal to encoder_dim
+      attention_dim: total dimension of multi-head attention
+      nhead: number of attention heads
+      feedforward_dim: hidden dimension of feed_forward module
+      dropout: dropout rate
+    """
+
+    def __init__(
+        self,
+        d_model: int = 512,
+        attention_dim: int = 512,
+        nhead: int = 8,
+        feedforward_dim: int = 2048,
+        dropout: float = 0.1,
+    ):
+        """Construct a DecoderLayer object."""
+        super(DecoderLayer, self).__init__()
+
+        self.norm_self_attn = nn.LayerNorm(d_model)
+        self.self_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+
+        self.norm_src_attn = nn.LayerNorm(d_model)
+        self.src_attn = MultiHeadedAttention(d_model, attention_dim, nhead, dropout=0.0)
+
+        self.norm_ff = nn.LayerNorm(d_model)
+        self.feed_forward = nn.Sequential(
+            nn.Linear(d_model, feedforward_dim),
+            Swish(),
+            nn.Dropout(dropout),
+            nn.Linear(feedforward_dim, d_model),
+        )
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor
+                (#batch, maxlen_out, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, 1, maxlen_in).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+        """
+        # self-attn module
+        tgt_norm = self.norm_self_attn(tgt)
+        tgt = tgt + self.dropout(self.self_attn(tgt_norm, tgt_norm, tgt_norm, tgt_mask))
+
+        # cross-attn module
+        tgt = tgt + self.dropout(
+            self.src_attn(self.norm_src_attn(tgt), memory, memory, memory_mask)
+        )
+
+        # feed-forward module
+        tgt = tgt + self.dropout(self.feed_forward(self.norm_ff(tgt)))
+
+        return tgt
+
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+      embed_dim: total dimension of the model.
+      attention_dim: dimension in the attention module; may be less or more than
+        embed_dim, but must be a multiple of num_heads.
+      num_heads: number of parallel attention heads.
+      dropout: a Dropout layer on attn_output_weights. Default: 0.0.
+    """
+
+    def __init__(
+        self, embed_dim: int, attention_dim: int, num_heads: int, dropout: float = 0.0
+    ):
+        """Construct a MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.attention_dim = attention_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.head_dim = attention_dim // num_heads
+        assert self.head_dim * num_heads == attention_dim, (
+            self.head_dim,
+            num_heads,
+            attention_dim,
+        )
+
+        self.linear_q = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_k = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.linear_v = nn.Linear(embed_dim, attention_dim, bias=True)
+        self.scale = math.sqrt(self.head_dim)
+
+        self.out_proj = nn.Linear(attention_dim, embed_dim, bias=True)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot-product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+ + """ + bsz, tgt_len, _ = query.size() + src_len = key.size(1) + num_heads = self.num_heads + head_dim = self.head_dim + + q = self.linear_q(query) + k = self.linear_k(key) + v = self.linear_v(value) + + q = q.reshape(bsz, tgt_len, num_heads, head_dim) + q = q.transpose(1, 2) # (batch, head, time1, head_dim) + k = k.reshape(bsz, src_len, num_heads, head_dim) + k = k.permute(0, 2, 3, 1) # (batch, head, head_dim, time2) + v = v.reshape(bsz, src_len, num_heads, head_dim) + v = v.transpose(1, 2).reshape(bsz * num_heads, src_len, head_dim) + + # (batch, head, time1, time2) + attn_output_weights = torch.matmul(q, k) / self.scale + + if mask is not None: + attn_output_weights = attn_output_weights.masked_fill( + mask.unsqueeze(1), float("-inf") + ) + + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=self.dropout, training=self.training + ) + + # (bsz * head, time1, head_dim_v) + attn_output = torch.bmm(attn_output_weights, v) + assert attn_output.shape == (bsz * num_heads, tgt_len, head_dim) + attn_output = ( + attn_output.reshape(bsz, num_heads, tgt_len, head_dim) + .transpose(1, 2) + .reshape(bsz, tgt_len, self.attention_dim) + ) + attn_output = self.out_proj(attn_output) + + return attn_output + + +class PositionalEncoding(nn.Module): + """Positional encoding. + Copied from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py#L35. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def subsequent_mask(size, device="cpu", dtype=torch.bool): + """Create mask for subsequent steps (size, size). 
+
+    :param int size: size of mask
+    :param str device: "cpu" or "cuda" or torch.Tensor.device
+    :param torch.dtype dtype: result dtype
+    :rtype: torch.Tensor
+    >>> subsequent_mask(3)
+    [[1, 0, 0],
+     [1, 1, 0],
+     [1, 1, 1]]
+    """
+    ret = torch.ones(size, size, device=device, dtype=dtype)
+    return torch.tril(ret, out=ret)
+
+
+def _test_attention_decoder_model():
+    m = AttentionDecoderModel(
+        vocab_size=500,
+        decoder_dim=512,
+        num_decoder_layers=6,
+        attention_dim=512,
+        nhead=8,
+        feedforward_dim=2048,
+        dropout=0.1,
+        sos_id=1,
+        eos_id=1,
+        ignore_id=-1,
+    )
+
+    num_param = sum([p.numel() for p in m.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+    m.eval()
+    encoder_out = torch.randn(2, 50, 512)
+    encoder_out_lens = torch.full((2,), 50)
+    token_ids = [[1, 2, 3, 4], [2, 3, 10]]
+
+    nll = m.nll(encoder_out, encoder_out_lens, token_ids)
+    print(nll)
+
+
+if __name__ == "__main__":
+    _test_attention_decoder_model()
diff --git a/egs/librispeech/ASR/zipformer/label_smoothing.py b/egs/librispeech/ASR/zipformer/label_smoothing.py
new file mode 100644
index 000000000..52d2eda3b
--- /dev/null
+++ b/egs/librispeech/ASR/zipformer/label_smoothing.py
@@ -0,0 +1,109 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class LabelSmoothingLoss(torch.nn.Module):
+    """
+    Implement the LabelSmoothingLoss proposed in the following paper
+    https://arxiv.org/pdf/1512.00567.pdf
+    (Rethinking the Inception Architecture for Computer Vision)
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -1,
+        label_smoothing: float = 0.1,
+        reduction: str = "sum",
+    ) -> None:
+        """
+        Args:
+          ignore_index:
+            ignored class id
+          label_smoothing:
+            smoothing rate (0.0 means the conventional cross entropy loss)
+          reduction:
+            It has the same meaning as the reduction in
+            `torch.nn.CrossEntropyLoss`. It can be one of the following three
+            values: (1) "none": No reduction will be applied. (2) "mean": the
+            mean of the output is taken. (3) "sum": the output will be summed.
+        """
+        super().__init__()
+        assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+        assert reduction in ("none", "sum", "mean"), reduction
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute loss between x and target.
+
+        Args:
+          x:
+            prediction of dimension
+            (batch_size, input_length, number_of_classes).
+          target:
+            target masked with self.ignore_index of
+            dimension (batch_size, input_length).
+
+        Returns:
+          A scalar tensor containing the loss without normalization.
+ """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) + + true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + ) + + # Set the value of ignored indexes to 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index f2f86af47..a0202b34d 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -34,11 +34,13 @@ class AsrModel(nn.Module): encoder: EncoderInterface, decoder: Optional[nn.Module] = None, joiner: Optional[nn.Module] = None, + attention_decoder: Optional[nn.Module] = None, encoder_dim: int = 384, decoder_dim: int = 512, vocab_size: int = 500, use_transducer: bool = True, use_ctc: bool = False, + use_attention_decoder: bool = False, ): """A joint CTC & Transducer ASR model. 
@@ -111,6 +113,12 @@ class AsrModel(nn.Module):
                 nn.LogSoftmax(dim=-1),
             )
 
+        self.use_attention_decoder = use_attention_decoder
+        if use_attention_decoder:
+            self.attention_decoder = attention_decoder
+        else:
+            assert attention_decoder is None
+
     def forward_encoder(
         self, x: torch.Tensor, x_lens: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -286,7 +294,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
           x:
@@ -308,7 +316,7 @@ class AsrModel(nn.Module):
            part
         Returns:
           Return the transducer losses and CTC loss,
-          in form of (simple_loss, pruned_loss, ctc_loss)
+          in form of (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss)
 
         Note:
            Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -322,6 +330,8 @@ class AsrModel(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
 
+        device = x.device
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
@@ -333,7 +343,7 @@ class AsrModel(nn.Module):
             simple_loss, pruned_loss = self.forward_transducer(
                 encoder_out=encoder_out,
                 encoder_out_lens=encoder_out_lens,
-                y=y.to(x.device),
+                y=y.to(device),
                 y_lens=y_lens,
                 prune_range=prune_range,
                 am_scale=am_scale,
@@ -355,4 +365,14 @@ class AsrModel(nn.Module):
         else:
             ctc_loss = torch.empty(0)
 
-        return simple_loss, pruned_loss, ctc_loss
+        if self.use_attention_decoder:
+            attention_decoder_loss = self.attention_decoder.calc_att_loss(
+                encoder_out=encoder_out,
+                encoder_out_lens=encoder_out_lens,
+                ys=y.to(device),
+                ys_lens=y_lens.to(device),
+            )
+        else:
+            attention_decoder_loss = torch.empty(0)
+
+        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss
diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py
index 7009f3346..b5bb789d2 100755
--- a/egs/librispeech/ASR/zipformer/train.py
+++ b/egs/librispeech/ASR/zipformer/train.py
@@ -66,6 +66,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
+from attention_decoder import AttentionDecoderModel
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
@@ -220,6 +221,41 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         """,
     )
 
+    parser.add_argument(
+        "--attention-decoder-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-layers",
+        type=int,
+        default=6,
+        help="""Number of transformer layers used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-attention-dim",
+        type=int,
+        default=512,
+        help="""Attention dimension used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-num-heads",
+        type=int,
+        default=8,
+        help="""Number of attention heads used in attention decoder""",
+    )
+
+    parser.add_argument(
+        "--attention-decoder-feedforward-dim",
+        type=int,
+        default=2048,
+        help="""Feedforward dimension used in attention decoder""",
+    )
+
     parser.add_argument(
         "--causal",
         type=str2bool,
@@ -258,6 +294,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="If True, use CTC head.",
     )
 
+    parser.add_argument(
+        "--use-attention-decoder",
+        type=str2bool,
+        default=False,
+        help="If True, use attention-decoder head.",
+    )
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -403,6 +446,21 @@ def get_parser():
         help="Scale for CTC loss.",
     )
 
+    parser.add_argument(
+        "--attention-decoder-loss-scale",
+        type=float,
+        default=0.8,
+        help="Scale for attention-decoder loss.",
+    )
+
+    parser.add_argument(
+        "--label-smoothing",
+        type=float,
+        default=0.1,
+        help="""Label smoothing rate used in the attention decoder
+        (0.0 means the conventional cross-entropy loss).""",
+    )
+
     parser.add_argument(
         "--seed",
         type=int,
@@ -531,6 +589,8 @@ def get_params() -> AttributeDict:
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
+            # parameters for attention-decoder
+            "ignore_id": -1,
             "warm_step": 2000,
             "env_info": get_env_info(),
         }
@@ -603,6 +663,25 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
     return joiner
 
 
+def get_attention_decoder_model(params: AttributeDict) -> nn.Module:
+    encoder_dim = max(_to_int_tuple(params.encoder_dim))
+    assert params.attention_decoder_dim == encoder_dim, (params.attention_decoder_dim, encoder_dim)
+
+    decoder = AttentionDecoderModel(
+        vocab_size=params.vocab_size,
+        decoder_dim=params.attention_decoder_dim,
+        num_decoder_layers=params.attention_decoder_num_layers,
+        attention_dim=params.attention_decoder_attention_dim,
+        nhead=params.attention_decoder_num_heads,
+        feedforward_dim=params.attention_decoder_feedforward_dim,
+        sos_id=params.sos_id,
+        eos_id=params.eos_id,
+        ignore_id=params.ignore_id,
+        label_smoothing=params.label_smoothing,
+    )
+    return decoder
+
+
 def get_model(params: AttributeDict) -> nn.Module:
     assert params.use_transducer or params.use_ctc, (
         f"At least one of them should be True, "
@@ -620,16 +699,23 @@ def get_model(params: AttributeDict) -> nn.Module:
         decoder = None
         joiner = None
 
+    if params.use_attention_decoder:
+        attention_decoder = get_attention_decoder_model(params)
+    else:
+        attention_decoder = None
+
     model = AsrModel(
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        attention_decoder=attention_decoder,
         encoder_dim=max(_to_int_tuple(params.encoder_dim)),
         decoder_dim=params.decoder_dim,
         vocab_size=params.vocab_size,
         use_transducer=params.use_transducer,
         use_ctc=params.use_ctc,
+        use_attention_decoder=params.use_attention_decoder,
     )
     return model
 
@@ -792,7 +878,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
@@ -822,6 +908,9 @@ def compute_loss(
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
 
+        if params.use_attention_decoder:
+            loss += params.attention_decoder_loss_scale * attention_decoder_loss
+
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
@@ -836,6 +925,8 @@ def compute_loss(
             info["pruned_loss"] = pruned_loss.detach().cpu().item()
         if params.use_ctc:
             info["ctc_loss"] = ctc_loss.detach().cpu().item()
+        if params.use_attention_decoder:
+            info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
 
     return loss, info
 
@@ -1114,10 +1205,17 @@ def run(rank, world_size, args):
     # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.eos_id = sp.piece_to_id("<sos/eos>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()
 
     if not params.use_transducer:
-        params.ctc_loss_scale = 1.0
+        if not params.use_attention_decoder:
+            params.ctc_loss_scale = 1.0
+        else:
+            assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
+                params.ctc_loss_scale, params.attention_decoder_loss_scale
+            )
 
     logging.info(params)
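Note (not part of the patch above): besides contributing a training loss, the new AttentionDecoderModel.nll() method naturally lends itself to rescoring n-best hypotheses at decode time. The sketch below is a minimal, hypothetical illustration of that idea; the helper name `rescore_with_attention_decoder`, the argument layout (encoder output repeated once per hypothesis), and the linear score combination are illustrative assumptions rather than the recipe's actual decoding code.

# Hypothetical usage sketch, not part of the diff.
from typing import List

import torch


def rescore_with_attention_decoder(
    attention_decoder: torch.nn.Module,   # an AttentionDecoderModel instance
    encoder_out: torch.Tensor,            # (num_hyps, T, encoder_dim), repeated per hypothesis
    encoder_out_lens: torch.Tensor,       # (num_hyps,)
    hyp_token_ids: List[List[int]],       # n-best token sequences for one utterance
    base_scores: torch.Tensor,            # (num_hyps,) e.g. CTC/transducer log-probs
    attn_scale: float = 1.0,
) -> int:
    """Return the index of the best hypothesis after rescoring."""
    # nll() returns (num_hyps, num_tokens); padded positions contribute 0
    # because it uses ignore_index with reduction="none" internally, so
    # summing over dim=1 gives the total negative log-likelihood per hypothesis.
    nll = attention_decoder.nll(encoder_out, encoder_out_lens, hyp_token_ids)
    attn_scores = -nll.sum(dim=1)

    # Combine the base decoding score with the attention-decoder score.
    total = base_scores + attn_scale * attn_scores
    return int(total.argmax().item())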