diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 137273c56..aaa6add0e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -24,7 +24,7 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear
-
+from torch import Tensor
 
 class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
@@ -90,8 +90,8 @@ class PromptedTransducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         text: torch.Tensor,
-        style_lens: torch.Tensor,
         text_lens: torch.Tensor,
+        style_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
@@ -111,14 +111,14 @@ class PromptedTransducer(nn.Module):
           A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
           It is expected to contain the style prompt (first) and then the content
           prompt.
-        style_lens:
-          A 1-D tensor of shape (N,), containing the number of elements (bytes)
-          within each row of `text` that correspond to the style prompt (these
-          are expected to come first).
         text_lens:
           A 1-D tensor of shape (N,). It contains the number of elements (bytes)
           in `text` before padding, which will include the lengths of the style
           plus the content prompt.
+        style_lens:
+          A 1-D tensor of shape (N,), containing the number of elements (bytes)
+          within each row of `text` that correspond to the style prompt (these
+          are expected to come first).
         y:
           A ragged tensor with 2 axes [utt][label]. It contains labels of each
           utterance.
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 822723ea5..a2f392355 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -51,9 +51,10 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import k2
+import numpy
 import optim
 import sentencepiece as spm
 import torch
@@ -68,7 +69,7 @@ from subsampling import Conv2dSubsampling
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from model import Transducer
+from model import PromptedTransducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -635,7 +636,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
-    model = Transducer(
+    model = PromptedTransducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
         text_embed=text_embed,
@@ -768,7 +769,36 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+
+def _encode_texts_as_bytes(texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Encode texts as bytes and then integer tensors.
+    Args:
+      texts: the texts to encode, as a list of strings
+      device: the PyTorch device we want the texts on
+    Returns:
+      (text, text_lens, style_lens), where:
+        text: a torch.Tensor of shape (batch_size, text_len) containing integers
+          0 <= i < 256
+        text_lens: a torch.Tensor of shape (batch_size,), giving the length of
+          each byte sequence
+        style_lens: a torch.Tensor of shape (batch_size,), giving the length of
+          each style prompt (style prompts are supposed to come first).  Since
+          there is no style prompt here, this is just all zeros.
+    """
+    texts = [bytes(s, "utf-8") for s in texts]
+    N = len(texts)
+    lengths = [len(s) for s in texts]
+    max_len = max(lengths)
+    # Pad each byte sequence to max_len with NUL bytes.
+    texts = [s + (b"\0" * (max_len - len(s))) for s in texts]
+    text = b"".join(texts)  # one bytes object containing all of the padded texts
+    # numpy.frombuffer returns a read-only array, so copy it before wrapping.
+    text = numpy.frombuffer(text, dtype=numpy.uint8).copy()
+    text = torch.from_numpy(text).to(device=device, dtype=torch.long)
+    text = text.reshape(N, max_len)
+    text_lens = torch.tensor(lengths, device=device)
+    style_lens = torch.zeros(N, dtype=torch.long, device=device)
+    return text, text_lens, style_lens
+
 
 def compute_loss(
     params: AttributeDict,
@@ -815,10 +845,15 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
 
+    text, text_lens, style_lens = _encode_texts_as_bytes(texts, device)
+
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss = model(
             x=feature,
             x_lens=feature_lens,
+            text=text,
+            text_lens=text_lens,
+            style_lens=style_lens,
             y=y,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index a55b4bd57..8538a3cfe 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -316,7 +316,8 @@ class Zipformer2(EncoderInterface):
             # setting memory to zero should be equivalent to not using the
             # memory input at all, since the Attention module has no biases.
             memory_dropout_rate = 0.05
-            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+            memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
+                               memory_dropout_rate)
 
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
@@ -354,11 +355,9 @@ class Zipformer2(EncoderInterface):
         # most recent output that has it present.
         x = get_full_dim_output()
         x = self.downsample_output(x)
-        # class Downsample has this rounding behavior..
-        assert self.output_downsampling_factor == 2
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            lengths = (x_lens + 1) // 2
+
+        d = self.output_downsampling_factor
+        lengths = (x_lens + d - 1) // d
 
         return x, lengths
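For reference, a minimal standalone sketch of what the new `_encode_texts_as_bytes` helper in train.py produces; the toy transcripts and assertions are illustrative only, not part of the patch:

```python
import numpy
import torch

# Two toy transcripts; the shorter one gets padded with NUL bytes.
texts = ["HELLO WORLD", "HI"]
byte_strs = [bytes(s, "utf-8") for s in texts]
lengths = [len(s) for s in byte_strs]  # [11, 2]
max_len = max(lengths)
padded = [s + b"\0" * (max_len - len(s)) for s in byte_strs]

# Same frombuffer -> reshape path as the helper, run on CPU.
arr = numpy.frombuffer(b"".join(padded), dtype=numpy.uint8).copy()
text = torch.from_numpy(arr).to(torch.long).reshape(len(texts), max_len)

assert text.shape == (2, 11)
assert text[0, 0].item() == ord("H")  # values are raw byte codes, 0 <= i < 256
assert text[1, 2:].eq(0).all()        # padding positions hold NUL (zero)
```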
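The zipformer.py change at the end replaces the hard-coded `(x_lens + 1) // 2`, which is only valid for a downsampling factor of 2, with a general ceiling division `(x_lens + d - 1) // d`. A quick sanity check of that identity with hypothetical length values:

```python
import torch

x_lens = torch.tensor([1, 2, 3, 7, 8])  # toy pre-downsampling lengths

d = 2  # the only factor the old hard-coded rounding supported
assert torch.equal((x_lens + d - 1) // d, (x_lens + 1) // 2)

d = 4  # any other output_downsampling_factor now rounds up correctly too
assert ((x_lens + d - 1) // d).tolist() == [1, 1, 1, 2, 2]
```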