diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
index 341fc6611..2397f0d41 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/beam_search.py
@@ -588,18 +588,10 @@ def greedy_search(
                 1, context_size
             )
 
-            c = torch.tensor([hyp[-context_size - 1 :]], device=device).reshape(
-                1, context_size + 1
-            )
-
-            k = torch.sum(
-                (
-                    c[:, -context_size - 1 : -1]
-                    == c[:, -1].expand_as(c[:, -context_size - 1 : -1])
-                ).int(),
-                dim=1,
-                keepdim=True,
-            )
+            if hyp[-context_size - 1] == hyp[-1]:
+                k += 1
+            else:
+                k[0, 0] = 0
 
             decoder_out = model.decoder(decoder_input, k, need_pad=False)
             decoder_out = model.joiner.decoder_proj(decoder_out)
@@ -712,21 +704,14 @@ def greedy_search_batch(
             # update decoder output
             decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
 
-            c = torch.tensor(
-                [h[-context_size - 1 :] for h in hyps[:batch_size]],
-                device=device,
-                dtype=torch.int64,
-            )
-
-            k = torch.sum(
-                (
-                    c[:, :context_size]
-                    == c[:, context_size : context_size + 1].expand_as(
-                        c[:, :context_size]
-                    )
-                ).int(),
-                dim=1,
-                keepdim=True,
+            k = torch.where(
+                torch.tensor(
+                    [h[-context_size - 1] for h in hyps[:batch_size]],
+                    device=device,
+                    dtype=torch.int64,
+                ) == decoder_input[:, -1],
+                k + 1,
+                torch.zeros(N, 1, device=device, dtype=torch.int64),
             )
 
             decoder_input = torch.tensor(
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py
index 89b21b0f9..bb460a85e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/decoder.py
@@ -89,15 +89,17 @@ class Decoder(nn.Module):
             A tensor of shape (N, U, decoder_dim).
           k:
             A tensor of shape (N, U).
-            Should be (N, S + 1) during training.
+            Should be (N, S) during training.
             Should be (N, 1) during inference.
+          is_training:
+            Whether it is training.
         Returns:
           Return a tensor of shape (N, U, decoder_dim).
         """
-        return embedding_out + torch.matmul(
-            (k / (1 + k)).unsqueeze(2),
-            self.repeat_param.unsqueeze(0),
-        )
+        if is_training:
+            k = F.pad(k, (1, 0), mode="constant", value=self.blank_id)
+
+        return embedding_out + (k / (1 + k)).unsqueeze(2) * self.repeat_param
 
     def forward(
         self,
@@ -138,6 +140,7 @@ class Decoder(nn.Module):
         embedding_out = self._add_repeat_param(
             embedding_out=embedding_out,
             k=k,
+            is_training=need_pad,
         )
         embedding_out = F.relu(embedding_out)
         return embedding_out
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
index 56ffa0661..9e35692d6 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/model.py
@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from encoder_interface import EncoderInterface
 from scaling import penalize_abs_values_gt
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos
 
 
 class Transducer(nn.Module):
@@ -72,13 +72,35 @@ class Transducer(nn.Module):
         )
         self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
 
+    def _compute_k(
+        self,
+        y: torch.Tensor,
+        context_size: int = 2,
+    ) -> torch.Tensor:
+        """
+        Args:
+          y:
+            A 2-D tensor of shape (N, U).
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
+        Returns:
+          Return a tensor of shape (N, U).
+        """
+        y_shift = F.pad(y, (context_size, 0), mode="constant", value=self.decoder.blank_id)[:, :-context_size]
+        mask = y_shift != y
+
+        T_arange = torch.arange(y.size(1)).expand_as(y).to(device=y.device)
+        cummax_out = (T_arange * mask).cummax(dim=-1)[0]
+        k = T_arange - cummax_out
+
+        return k
+
     def forward(
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
-        y: torch.Tensor,
-        y_lens: torch.Tensor,
-        k: torch.Tensor,
+        y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
@@ -91,14 +113,8 @@ class Transducer(nn.Module):
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
           y:
-            A 2-D tensor with 2 axes [utt][label]. It contains labels of each
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
-          y_lens:
-            A 1-D tensor of shape (N,). It contains the number of frames in `y`
-            before padding.
-          k:
-            A statistic given the context_size with respect to utt.
-            A 2-D tensor of shape (N, U).
           prune_range:
             The prune range for rnnt loss, it means how many symbols(context)
             we are considering for each frame to compute the loss.
@@ -119,24 +135,34 @@ class Transducer(nn.Module):
         """
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
-        assert len(y.shape) == 2, len(y.shape)
+        assert y.num_axes == 2, y.num_axes
 
-        assert x.size(0) == x_lens.size(0) == y.size(0)
+        assert x.size(0) == x_lens.size(0) == y.dim0
 
         encoder_out, x_lens = self.encoder(x, x_lens)
         assert torch.all(x_lens > 0)
 
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
         blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
         # sos_y_padded: [B, S + 1], start with SOS.
-        sos_y_padded = F.pad(y, (1, 0), mode="constant", value=blank_id)
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # compute k
+        k = self._compute_k(sos_y_padded, context_size=self.decoder.context_size)
+
         # decoder_out: [B, S + 1, decoder_dim]
         decoder_out = self.decoder(sos_y_padded, k)
 
-        # Note: y_padded does not start with SOS
+        # Note: y does not start with SOS
         # y_padded : [B, S]
-        y_padded = y.to(torch.int64)
+        y_padded = y.pad(mode="constant", padding_value=0)
+
+        y_padded = y_padded.to(torch.int64)
         boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
index 5c9ce28e4..86d79e48f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless9/train.py
@@ -39,7 +39,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --use-fp16 1 \
   --exp-dir pruned_transducer_stateless9/exp \
   --full-libri 1 \
-  --max-duration 550
+  --max-duration 750
 
 """
 
@@ -58,7 +58,6 @@ import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.nn.functional as F
 from asr_datamodule import LibriSpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
@@ -70,7 +69,6 @@ from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils.rnn import pad_sequence
 from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer
 
@@ -237,7 +235,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless9/exp",
+        default="pruned_transducer_stateless7/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -627,41 +625,6 @@ def save_checkpoint(
         copyfile(src=filename, dst=best_valid_filename)
 
 
-def compute_k(
-    y: torch.Tensor,
-    context_size: int = 2,
-    blank_id: int = 0,
-) -> torch.Tensor:
-    """
-    Args:
-      y:
-        A 2-D tensor of shape (N, U).
-      context_size:
-        Number of previous words to use to predict the next word.
-        1 means bigram; 2 means trigram. n means (n+1)-gram.
-    Returns:
-      Return a tensor of shape (N, U).
-    """
-    y = F.pad(y, (1, 0), mode="constant", value=blank_id)  # [B, S + 1], start with SOS.
-    k = torch.zeros_like(y)
-
-    for i in range(2, y.size(1) - 1):
-        k[:, i : i + 1] = torch.where(
-            y[:, i : i + 1] != 0,
-            torch.sum(
-                (
-                    y[:, i - context_size : i]
-                    == y[:, i : i + 1].expand_as(y[:, i - context_size : i])
-                ).int(),
-                dim=1,
-                keepdim=True,
-            ),
-            y[:, i : i + 1],
-        )
-
-    return k
-
-
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -712,26 +675,13 @@ def compute_loss(
     texts = batch["supervisions"]["text"]
     y = sp.encode(texts, out_type=int)
-    y_lens = torch.tensor(list(map(len, y))).to(device)
-
-    y = list(map(torch.tensor, y))
-    y = pad_sequence(y, batch_first=True)  # [B, S]
-
-    k = compute_k(
-        y,
-        params.context_size,
-        model.module.decoder.blank_id
-        if isinstance(model, DDP)
-        else model.decoder.blank_id,
-    ).to(device)
-
-    y = y.to(device)
+    y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
-            y_lens=y_lens,
-            k=k,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
@@ -1093,10 +1043,10 @@ def run(rank, world_size, args):
 
     librispeech = LibriSpeechAsrDataModule(args)
 
-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
 
     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds