diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
index 7388af389..2bc4c1788 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
@@ -542,7 +542,9 @@ def greedy_search(
         [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
     ).reshape(1, context_size)

-    decoder_out = model.decoder(decoder_input, need_pad=False)
+    k = torch.zeros(1, 1, device=device, dtype=torch.int64)
+
+    decoder_out = model.decoder(decoder_input, k, need_pad=False)
     decoder_out = model.joiner.decoder_proj(decoder_out)

     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -586,7 +588,20 @@ def greedy_search(
                 1, context_size
             )

-            decoder_out = model.decoder(decoder_input, need_pad=False)
+            c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
+                1, context_size + 1
+            )
+
+            k[:, 0] = torch.sum(
+                (
+                    c[:, -context_size - 1 : -1]
+                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            )
+
+            decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)

             sym_per_utt += 1
@@ -594,7 +609,7 @@
         else:
             sym_per_frame = 0
             t += 1
-    hyp = hyp[context_size:]  # remove blanks
+    hyp = hyp[context_size :]  # remove blanks

     if not return_timestamps:
         return hyp
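Note (not part of the patch): a minimal sketch of the per-step repeat statistic that the greedy-search change above computes each time a symbol is emitted, assuming context_size = 2 and a made-up hypothesis. It counts how often the newest symbol already occurs among the preceding context_size symbols; the patch writes the same count into k[:, 0] (with device= and keepdim=True, which give an equivalent result).

    import torch

    context_size = 2
    hyp = [0, 0, 5, 7, 7]  # two leading blanks plus decoded symbols; 7 was just emitted

    # Last context_size + 1 symbols of the hypothesis, e.g. [[5, 7, 7]].
    c = torch.tensor([hyp[-context_size - 1 :]]).reshape(1, context_size + 1)

    # Compare the newest symbol against the context_size symbols before it.
    prev = c[:, -context_size - 1 : -1]  # [[5, 7]]
    cur = c[:, -1].expand_as(prev)       # [[7, 7]]

    k = torch.zeros(1, 1, dtype=torch.int64)
    k[:, 0] = torch.sum((prev == cur).int(), dim=1)
    print(k)  # tensor([[1]]) -> the new symbol repeats once within the last 2 labels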
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless9/decode.py
index b9bce465f..98794c6db 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/decode.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/decode.py
@@ -20,36 +20,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless7/decode.py \
-    --epoch 28 \
-    --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+./pruned_transducer_stateless9/decode.py \
+    --epoch 30 \
+    --avg 8 \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search (one best)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 20.0 \
@@ -57,10 +57,10 @@ Usage:
     --max-states 64

 (5) fast beam search (nbest)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest \
     --beam 20.0 \
@@ -70,10 +70,10 @@ Usage:
     --nbest-scale 0.5

 (6) fast beam search (nbest oracle WER)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_oracle \
     --beam 20.0 \
@@ -83,10 +83,10 @@ Usage:
     --nbest-scale 0.5

 (7) fast beam search (with LG)
-./pruned_transducer_stateless7/decode.py \
+./pruned_transducer_stateless9/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless7/exp \
+    --exp-dir ./pruned_transducer_stateless9/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search_nbest_LG \
     --beam 20.0 \
@@ -223,7 +223,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="The experiment dir",
     )

""" @@ -90,7 +126,7 @@ class Decoder(nn.Module): embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: + if need_pad is True: embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) else: # During inference time, there is no need to do extra padding @@ -98,5 +134,10 @@ class Decoder(nn.Module): assert embedding_out.size(-1) == self.context_size embedding_out = self.conv(embedding_out) embedding_out = embedding_out.permute(0, 2, 1) + + embedding_out = self._add_repeat_param( + embedding_out=embedding_out, + k=k, + ) embedding_out = F.relu(embedding_out) return embedding_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py index 0e59b0f2f..56ffa0661 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py @@ -20,10 +20,11 @@ import random import k2 import torch import torch.nn as nn +import torch.nn.functional as F from encoder_interface import EncoderInterface from scaling import penalize_abs_values_gt -from icefall.utils import add_sos +from icefall.utils import add_sos, make_pad_mask class Transducer(nn.Module): @@ -75,7 +76,9 @@ class Transducer(nn.Module): self, x: torch.Tensor, x_lens: torch.Tensor, - y: k2.RaggedTensor, + y: torch.Tensor, + y_lens: torch.Tensor, + k: torch.Tensor, prune_range: int = 5, am_scale: float = 0.0, lm_scale: float = 0.0, @@ -88,8 +91,14 @@ class Transducer(nn.Module): A 1-D tensor of shape (N,). It contains the number of frames in `x` before padding. y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each + A 2-D tensor with 2 axes [utt][label]. It contains labels of each utterance. + y_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `y` + before padding. + k: + A statistic given the context_size with respect to utt. + A 2-D tensor of shape (N, U). prune_range: The prune range for rnnt loss, it means how many symbols(context) we are considering for each frame to compute the loss. @@ -110,31 +119,24 @@ class Transducer(nn.Module): """ assert x.ndim == 3, x.shape assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes + assert len(y.shape) == 2, len(y.shape) - assert x.size(0) == x_lens.size(0) == y.dim0 + assert x.size(0) == x_lens.size(0) == y.size(0) encoder_out, x_lens = self.encoder(x, x_lens) assert torch.all(x_lens > 0) - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) # sos_y_padded: [B, S + 1], start with SOS. 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
index 0e59b0f2f..56ffa0661 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
@@ -20,10 +20,11 @@ import random
 import k2
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt

-from icefall.utils import add_sos
+from icefall.utils import add_sos, make_pad_mask


 class Transducer(nn.Module):
@@ -75,7 +76,9 @@ class Transducer(nn.Module):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: k2.RaggedTensor,
+        y: torch.Tensor,
+        y_lens: torch.Tensor,
+        k: torch.Tensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
@@ -88,8 +91,14 @@ class Transducer(nn.Module):
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
           y:
-            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            A 2-D tensor of shape (N, S), padded with 0. It contains labels of each
             utterance.
+          y_lens:
+            A 1-D tensor of shape (N,). It contains the number of labels in `y`
+            before padding.
+          k:
+            A 2-D tensor of shape (N, S + 1). For each position it gives the number
+            of times that label occurs in the preceding context_size labels.
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -110,31 +119,24 @@ class Transducer(nn.Module):
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
-        assert y.num_axes == 2, y.num_axes
+        assert len(y.shape) == 2, len(y.shape)

-        assert x.size(0) == x_lens.size(0) == y.dim0
+        assert x.size(0) == x_lens.size(0) == y.size(0)

         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)

-        # Now for the decoder, i.e., the prediction network
-        row_splits = y.shape.row_splits(1)
-        y_lens = row_splits[1:] - row_splits[:-1]
-
         blank_id = self.decoder.blank_id
-        sos_y = add_sos(y, sos_id=blank_id)

         # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)

         # decoder_out: [B, S + 1, decoder_dim]
-        decoder_out = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded, k)

-        # Note: y does not start with SOS
+        # Note: y_padded does not start with SOS
         # y_padded : [B, S]
-        y_padded = y.pad(mode="constant", padding_value=0)
-
-        y_padded = y_padded.to(torch.int64)
+        y_padded = y.to(torch.int64)
         boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
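Note (not part of the patch): the model now receives y as an ordinary padded tensor instead of a k2.RaggedTensor, so the SOS prefix is added with F.pad rather than add_sos plus ragged padding. A tiny illustration with made-up labels:

    import torch
    import torch.nn.functional as F

    blank_id = 0
    y = torch.tensor([[3, 5, 5, 0],    # padded labels, shape (N, S); 0 is padding
                      [7, 2, 0, 0]])
    y_lens = torch.tensor([3, 2])      # number of labels per utterance before padding

    sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)  # (N, S + 1)
    print(sos_y_padded)
    # tensor([[0, 3, 5, 5, 0],
    #         [0, 7, 2, 0, 0]])  -> every row now starts with the blank/SOS symbol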
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
index 6022406eb..5c9ce28e4 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
@@ -22,22 +22,22 @@ Usage:

 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
+  --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless7/train.py \
+./pruned_transducer_stateless9/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
+  --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
   --max-duration 550

@@ -58,6 +58,7 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
+import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -69,6 +70,7 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer

@@ -235,7 +237,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless7/exp",
+        default="pruned_transducer_stateless9/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved
@@ -625,6 +627,41 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)


+def compute_k(
+    y: torch.Tensor,
+    context_size: int = 2,
+    blank_id: int = 0,
+) -> torch.Tensor:
+    """
+    Args:
+      y:
+        A 2-D tensor of shape (N, U).
+      context_size:
+        Number of previous words to use to predict the next word.
+        1 means bigram; 2 means trigram. n means (n+1)-gram.
+    Returns:
+      Return a tensor of shape (N, U + 1); an SOS/blank column is prepended internally.
+    """
+    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
+    k = torch.zeros_like(y)
+
+    for i in range(2, y.size(1) - 1):
+        k[:, i : i + 1] = torch.where(
+            y[:, i : i + 1] != 0,
+            torch.sum(
+                (
+                    y[:, i - context_size : i]
+                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
+                ).int(),
+                dim=1,
+                keepdim=True,
+            ),
+            y[:, i : i + 1],
+        )
+
+    return k
+
+
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -675,13 +712,26 @@ def compute_loss(

     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y = k2.RaggedTensor(y).to(device)
+    y_lens = torch.tensor(list(map(len, y))).to(device)
+    y = list(map(torch.tensor, y))
+    y = pad_sequence(y, batch_first=True)  # [B, S]
+
+    k = compute_k(
+        y,
+        params.context_size,
+        model.module.decoder.blank_id
+        if isinstance(model, DDP)
+        else model.decoder.blank_id,
+    ).to(device)
+    y = y.to(device)

     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            y_lens=y_lens,
+            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
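Note (not part of the patch): a quick check of what compute_k returns for a made-up batch, assuming context_size = 2, blank_id = 0, and that the compute_k defined in train.py above is importable from the recipe directory. The expected output was traced by hand from the loop above.

    import torch
    from train import compute_k  # the helper added in the hunk above

    y = torch.tensor([[5, 5, 5, 7]])   # padded labels, shape (N, S)
    print(compute_k(y, context_size=2, blank_id=0))
    # tensor([[0, 0, 1, 2, 0]])
    # Shape is (N, S + 1) because an SOS/blank column is prepended internally.
    # Position 2 sees one earlier 5 and position 3 sees two earlier 5s; positions
    # 0 and 1 and the final position are left at zero by the loop bounds.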