From 95af0397336ac840a5bfed1ae8de79dbddcdad71 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 7 Dec 2021 21:44:37 +0800 Subject: [PATCH] RNN-T training for yesno. (#141) * RNN-T training for yesno. * Rename Jointer to Joiner. --- .gitignore | 2 + egs/yesno/ASR/tdnn/train.py | 1 + egs/yesno/ASR/transducer/__init__.py | 0 egs/yesno/ASR/transducer/asr_datamodule.py | 1 + egs/yesno/ASR/transducer/beam_search.py | 69 +++ egs/yesno/ASR/transducer/decode.py | 310 +++++++++++ egs/yesno/ASR/transducer/decoder.py | 92 ++++ egs/yesno/ASR/transducer/encoder.py | 87 +++ egs/yesno/ASR/transducer/joiner.py | 55 ++ egs/yesno/ASR/transducer/model.py | 120 ++++ egs/yesno/ASR/transducer/test_decoder.py | 65 +++ egs/yesno/ASR/transducer/test_encoder.py | 47 ++ egs/yesno/ASR/transducer/test_joiner.py | 50 ++ egs/yesno/ASR/transducer/test_transducer.py | 77 +++ egs/yesno/ASR/transducer/train.py | 581 ++++++++++++++++++++ icefall/utils.py | 125 +++++ test/test_utils.py | 41 +- 17 files changed, 1722 insertions(+), 1 deletion(-) create mode 100644 egs/yesno/ASR/transducer/__init__.py create mode 120000 egs/yesno/ASR/transducer/asr_datamodule.py create mode 100644 egs/yesno/ASR/transducer/beam_search.py create mode 100755 egs/yesno/ASR/transducer/decode.py create mode 100644 egs/yesno/ASR/transducer/decoder.py create mode 100644 egs/yesno/ASR/transducer/encoder.py create mode 100644 egs/yesno/ASR/transducer/joiner.py create mode 100644 egs/yesno/ASR/transducer/model.py create mode 100755 egs/yesno/ASR/transducer/test_decoder.py create mode 100755 egs/yesno/ASR/transducer/test_encoder.py create mode 100755 egs/yesno/ASR/transducer/test_joiner.py create mode 100755 egs/yesno/ASR/transducer/test_transducer.py create mode 100755 egs/yesno/ASR/transducer/train.py diff --git a/.gitignore b/.gitignore index f4f703243..31da5ed3e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ exp exp*/ *.pt download +*.bak +*-bak diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 30d83666a..d8454b7c5 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -487,6 +487,7 @@ def run(rank, world_size, args): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", rank) + logging.info(f"device: {device}") graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) diff --git a/egs/yesno/ASR/transducer/__init__.py b/egs/yesno/ASR/transducer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/yesno/ASR/transducer/asr_datamodule.py b/egs/yesno/ASR/transducer/asr_datamodule.py new file mode 120000 index 000000000..c9c8adb57 --- /dev/null +++ b/egs/yesno/ASR/transducer/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn/asr_datamodule.py \ No newline at end of file diff --git a/egs/yesno/ASR/transducer/beam_search.py b/egs/yesno/ASR/transducer/beam_search.py new file mode 100644 index 000000000..ae0f39478 --- /dev/null +++ b/egs/yesno/ASR/transducer/beam_search.py @@ -0,0 +1,69 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch
from transducer.model import Transducer


def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[str]:
    """
    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
        Only N == 1 is supported for now.
    Returns:
      Return the decoded result as a list of words.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id
    device = model.device

    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
    decoder_out, (h, c) = model.decoder(sos)
    T = encoder_out.size(1)
    t = 0
    hyp = []
    max_u = 1000  # terminate after this number of steps
    u = 0

    while t < T and u < max_u:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        logits = model.joiner(current_encoder_out, decoder_out)

        log_prob = logits.log_softmax(dim=-1)
        # log_prob is (1, 1, 1, vocab_size)
        # TODO: Use logits.argmax() directly; log-softmax does not
        # change the argmax.
        y = log_prob.argmax()
        if y != blank_id:
            hyp.append(y.item())
            y = y.reshape(1, 1)
            decoder_out, (h, c) = model.decoder(y, (h, c))
            u += 1
        else:
            t += 1
    id2word = {1: "YES", 2: "NO"}

    hyp = [id2word[i] for i in hyp]

    return hyp
diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py
new file mode 100755
index 000000000..abb34da4c
--- /dev/null
+++ b/egs/yesno/ASR/transducer/decode.py
@@ -0,0 +1,310 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
from asr_datamodule import YesNoAsrDataModule
from transducer.beam_search import greedy_search
from transducer.decoder import Decoder
from transducer.encoder import Tdnn
from transducer.joiner import Joiner
from transducer.model import Transducer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    write_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=125,
        help="It specifies the checkpoint to use for decoding. "
+ "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory from which to load the checkpoints", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 23, + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, +) -> List[List[int]]: + """Decode one batch and return the result in a list-of-list. + Each sub list contains the word IDs for an utterance in the batch. + + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding. + - params.method is "nbest", it uses nbest decoding. + + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) + Returns: + Return the decoding result. `len(ans)` == batch size. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + feature_lens = batch["supervisions"]["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + hyp = greedy_search(model=model, encoder_out=encoder_out_i) + hyps.append(hyp) + + # hyps = [[word_table[i] for i in ids] for ids in hyps] + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, +) -> List[Tuple[List[int], List[int]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a tuple contains two elements (ref_text, hyp_text): + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = [] + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps = decode_one_batch( + params=params, + model=model, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results.extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + exp_dir: Path, + test_set_name: str, + results: List[Tuple[List[int], List[int]]], +) -> None: + """Save results to `exp_dir`. + Args: + exp_dir: + The output directory. 
        This function creates the following files inside
        this directory:

        - recogs-{test_set_name}.txt

          It contains the reference and hypothesis results, like below::

              ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
              hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
              ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
              hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']

        - errs-{test_set_name}.txt

          It contains the detailed WER.
      test_set_name:
        The name of the test set, which will be part of the result filename.
      results:
        A list of tuples, each of which contains (ref_words, hyp_words).
    Returns:
      Return None.
    """
    recog_path = exp_dir / f"recogs-{test_set_name}.txt"
    store_transcripts(filename=recog_path, texts=results)
    logging.info(f"The transcripts are stored in {recog_path}")

    # The following prints out WERs, per-word error statistics and aligned
    # ref/hyp pairs.
    errs_filename = exp_dir / f"errs-{test_set_name}.txt"
    with open(errs_filename, "w") as f:
        write_error_stats(f, f"{test_set_name}", results)

    logging.info("Wrote detailed error stats to {}".format(errs_filename))


def get_transducer_model(params: AttributeDict):
    encoder = Tdnn(
        num_features=params.feature_dim,
        output_dim=params.hidden_dim,
    )
    decoder = Decoder(
        vocab_size=params.vocab_size,
        embedding_dim=params.embedding_dim,
        blank_id=params.blank_id,
        num_layers=params.num_decoder_layers,
        hidden_dim=params.hidden_dim,
        embedding_dropout=0.4,
        rnn_dropout=0.4,
    )
    joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size)
    transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner)
    return transducer


@torch.no_grad()
def main():
    parser = get_parser()
    YesNoAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()

    setup_logger(f"{params.exp_dir}/log/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    model = get_transducer_model(params)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

    model.to(device)
    model.eval()
    model.device = device

    yes_no = YesNoAsrDataModule(args)
    test_dl = yes_no.test_dataloaders()
    results = decode_dataset(
        dl=test_dl,
        params=params,
        model=model,
    )

    save_results(
        exp_dir=params.exp_dir, test_set_name="test_set", results=results
    )

    logging.info("Done!")


if __name__ == "__main__":
    main()
diff --git a/egs/yesno/ASR/transducer/decoder.py b/egs/yesno/ASR/transducer/decoder.py
new file mode 100644
index 000000000..aa8a16845
--- /dev/null
+++ b/egs/yesno/ASR/transducer/decoder.py
@@ -0,0 +1,92 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        blank_id: int,
        num_layers: int,
        hidden_dim: int,
        embedding_dropout: float = 0.0,
        rnn_dropout: float = 0.0,
    ):
        """
        Args:
          vocab_size:
            Number of tokens of the modeling unit.
          embedding_dim:
            Dimension of the input embedding.
          blank_id:
            The ID of the blank symbol.
          num_layers:
            Number of RNN layers.
          hidden_dim:
            Hidden dimension of RNN layers.
          embedding_dropout:
            Dropout rate for the embedding layer.
          rnn_dropout:
            Dropout rate for RNN layers.
        """
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=blank_id,
        )
        self.embedding_dropout = nn.Dropout(embedding_dropout)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=rnn_dropout,
        )
        self.blank_id = blank_id
        self.output_linear = nn.Linear(hidden_dim, hidden_dim)

    def forward(
        self,
        y: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          y:
            A 2-D tensor of shape (N, U).
          states:
            A tuple of two tensors containing the state information of
            RNN layers in this decoder.
        Returns:
          Return a tuple containing:

            - rnn_output, a tensor of shape (N, U, C)
            - (h, c), which contain the state information for RNN layers.
              Both are of shape (num_layers, N, C)
        """
        embedding_out = self.embedding(y)
        embedding_out = self.embedding_dropout(embedding_out)
        rnn_out, (h, c) = self.rnn(embedding_out, states)
        out = self.output_linear(rnn_out)

        return out, (h, c)
diff --git a/egs/yesno/ASR/transducer/encoder.py b/egs/yesno/ASR/transducer/encoder.py
new file mode 100644
index 000000000..8c50df293
--- /dev/null
+++ b/egs/yesno/ASR/transducer/encoder.py
@@ -0,0 +1,87 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn


# We use a TDNN model as encoder, as it works very well with CTC training
# for this tiny dataset.
class Tdnn(nn.Module):
    def __init__(self, num_features: int, output_dim: int):
        """
        Args:
          num_features:
            Model input dimension.
          output_dim:
            Model output dimension.
        """
        super().__init__()

        # Note: We don't use paddings inside conv layers
        self.tdnn = nn.Sequential(
            nn.Conv1d(
                in_channels=num_features,
                out_channels=32,
                kernel_size=3,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=2,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
            nn.Conv1d(
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                dilation=4,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=32, affine=False),
        )
        self.output_linear = nn.Linear(in_features=32, out_features=output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor with shape (N, T, C)
          x_lens:
            It contains the number of frames in each utterance in x
            before padding.

        Returns:
          Return a tuple with 2 tensors:

          - logits, a tensor of shape (N, T, C)
          - logit_lens, a tensor of shape (N,)
        """
        x = x.permute(0, 2, 1)  # (N, T, C) -> (N, C, T)
        x = self.tdnn(x)
        x = x.permute(0, 2, 1)  # (N, C, T) -> (N, T, C)
        logits = self.output_linear(x)

        # The first conv layer reduces T by 3 - 1 = 2 frames,
        # the second conv layer reduces T by (5 - 1) * 2 = 8 frames,
        # and the third conv layer reduces T by (5 - 1) * 4 = 16 frames,
        # so T is reduced by 2 + 8 + 16 = 26 frames in total.
        x_lens = x_lens - 26
        return logits, x_lens
diff --git a/egs/yesno/ASR/transducer/joiner.py b/egs/yesno/ASR/transducer/joiner.py
new file mode 100644
index 000000000..0422f8a6f
--- /dev/null
+++ b/egs/yesno/ASR/transducer/joiner.py
@@ -0,0 +1,55 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Joiner(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()

        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(
        self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, C).
          decoder_out:
            Output from the decoder. Its shape is (N, U, C).
        Returns:
          Return a tensor of shape (N, T, U, C).
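
        For example, with the shapes used in ./transducer/test_joiner.py
        (values are random; only the shapes matter):

            >>> joiner = Joiner(input_dim=4, output_dim=10)
            >>> encoder_out = torch.rand(2, 3, 4)  # (N, T, C)
            >>> decoder_out = torch.rand(2, 5, 4)  # (N, U, C)
            >>> joiner(encoder_out, decoder_out).shape
            torch.Size([2, 3, 5, 10])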
+ """ + assert encoder_out.ndim == decoder_out.ndim == 3 + assert encoder_out.size(0) == decoder_out.size(0) + assert encoder_out.size(2) == decoder_out.size(2) + + encoder_out = encoder_out.unsqueeze(2) + # Now encoder_out is (N, T, 1, C) + + decoder_out = decoder_out.unsqueeze(1) + # Now decoder_out is (N, 1, U, C) + + logit = encoder_out + decoder_out + logit = F.relu(logit) + + output = self.output_linear(logit) + + return output diff --git a/egs/yesno/ASR/transducer/model.py b/egs/yesno/ASR/transducer/model.py new file mode 100644 index 000000000..caf9bed37 --- /dev/null +++ b/egs/yesno/ASR/transducer/model.py @@ -0,0 +1,120 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Note we use `rnnt_loss` from torchaudio, which exists only in +torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 +""" +import k2 +import torch +import torch.nn as nn +import torchaudio +import torchaudio.functional + +from icefall.utils import add_sos + +assert hasattr(torchaudio.functional, "rnnt_loss"), ( + f"Current torchaudio version: {torchaudio.__version__}\n" + "Please install a version >= 0.10.0" +) + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + joiner: nn.Module, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, C) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, C) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, C). It should contain + one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, C) and (N, U, C). Its + output shape is (N, T, U, C). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + Returns: + Return the transducer loss. 
+ """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + decoder_out, _ = self.decoder(sos_y_padded) + + logits = self.joiner(encoder_out, decoder_out) + + # rnnt_loss requires 0 padded targets + y_padded = y.pad(mode="constant", padding_value=0) + + loss = torchaudio.functional.rnnt_loss( + logits=logits, + targets=y_padded, + logit_lengths=x_lens, + target_lengths=y_lens, + blank=blank_id, + reduction="mean", + ) + + return loss diff --git a/egs/yesno/ASR/transducer/test_decoder.py b/egs/yesno/ASR/transducer/test_decoder.py new file mode 100755 index 000000000..88c54f678 --- /dev/null +++ b/egs/yesno/ASR/transducer/test_decoder.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/yesno/ASR + python ./transducer/test_decoder.py +""" + +import torch +from transducer.decoder import Decoder + + +def test_decoder(): + vocab_size = 3 + blank_id = 0 + embedding_dim = 128 + num_layers = 2 + hidden_dim = 6 + N = 3 + U = 5 + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + num_layers=num_layers, + hidden_dim=hidden_dim, + embedding_dropout=0.0, + rnn_dropout=0.0, + ) + x = torch.randint(1, vocab_size, (N, U)) + rnn_out, (h, c) = decoder(x) + + assert rnn_out.shape == (N, U, hidden_dim) + assert h.shape == (num_layers, N, hidden_dim) + assert c.shape == (num_layers, N, hidden_dim) + + rnn_out, (h, c) = decoder(x, (h, c)) + assert rnn_out.shape == (N, U, hidden_dim) + assert h.shape == (num_layers, N, hidden_dim) + assert c.shape == (num_layers, N, hidden_dim) + + +def main(): + test_decoder() + + +if __name__ == "__main__": + main() diff --git a/egs/yesno/ASR/transducer/test_encoder.py b/egs/yesno/ASR/transducer/test_encoder.py new file mode 100755 index 000000000..481fb558b --- /dev/null +++ b/egs/yesno/ASR/transducer/test_encoder.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/yesno/ASR + python ./transducer/test_encoder.py +""" + +import torch +from transducer.encoder import Tdnn + + +def test_encoder(): + input_dim = 10 + output_dim = 20 + encoder = Tdnn(input_dim, output_dim) + N = 10 + T = 85 + x = torch.rand(N, T, input_dim) + x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32) + logits, logit_lens = encoder(x, x_lens) + assert logits.shape == (N, T - 26, output_dim) + assert torch.all(torch.eq(x_lens - 26, logit_lens)) + + +def main(): + test_encoder() + + +if __name__ == "__main__": + main() diff --git a/egs/yesno/ASR/transducer/test_joiner.py b/egs/yesno/ASR/transducer/test_joiner.py new file mode 100755 index 000000000..2773ca319 --- /dev/null +++ b/egs/yesno/ASR/transducer/test_joiner.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +To run this file, do: + + cd icefall/egs/yesno/ASR + python ./transducer/test_joiner.py +""" + + +import torch +from transducer.joiner import Joiner + + +def test_joiner(): + N = 2 + T = 3 + C = 4 + U = 5 + + joiner = Joiner(C, 10) + + encoder_out = torch.rand(N, T, C) + decoder_out = torch.rand(N, U, C) + + joint = joiner(encoder_out, decoder_out) + assert joint.shape == (N, T, U, 10) + + +def main(): + test_joiner() + + +if __name__ == "__main__": + main() diff --git a/egs/yesno/ASR/transducer/test_transducer.py b/egs/yesno/ASR/transducer/test_transducer.py new file mode 100755 index 000000000..db7bf9c68 --- /dev/null +++ b/egs/yesno/ASR/transducer/test_transducer.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +To run this file, do: + + cd icefall/egs/yesno/ASR + python ./transducer/test_transducer.py +""" + + +import k2 +import torch +from transducer.decoder import Decoder +from transducer.encoder import Tdnn +from transducer.joiner import Joiner +from transducer.model import Transducer + + +def test_transducer(): + # encoder params + input_dim = 10 + output_dim = 20 + + # decoder params + vocab_size = 3 + blank_id = 0 + embedding_dim = 128 + num_layers = 2 + + encoder = Tdnn(input_dim, output_dim) + + decoder = Decoder( + vocab_size=vocab_size, + embedding_dim=embedding_dim, + blank_id=blank_id, + num_layers=num_layers, + hidden_dim=output_dim, + embedding_dropout=0.0, + rnn_dropout=0.0, + ) + + joiner = Joiner(output_dim, vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + y = k2.RaggedTensor([[1, 2, 1], [1, 1, 1, 2, 1]]) + N = y.dim0 + T = 50 + + x = torch.rand(N, T, input_dim) + x_lens = torch.randint(low=30, high=T, size=(N,), dtype=torch.int32) + x_lens[0] = T + + loss = transducer(x, x_lens, y) + print(loss) + + +def main(): + test_transducer() + + +if __name__ == "__main__": + main() diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py new file mode 100755 index 000000000..7d2d1edeb --- /dev/null +++ b/egs/yesno/ASR/transducer/train.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import List, Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from asr_datamodule import YesNoAsrDataModule +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transducer.decoder import Decoder +from transducer.encoder import Tdnn +from transducer.joiner import Joiner +from transducer.model import Transducer + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_labels(texts: List[str]) -> k2.RaggedTensor: + """ + Args: + texts: + A list of transcripts. Each transcript contains spaces separated + "NO" or "YES". + Returns: + Return a ragged tensor containing the corresponding word ID. 
+ """ + # blank is 0 + word2id = {"YES": 1, "NO": 2} + word_ids = [] + for t in texts: + words = t.split() + ids = [word2id[w] for w in words] + word_ids.append(ids) + + return k2.RaggedTensor(word_ids) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=200, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer/exp", + help="Directory to save results", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - start_epoch: If it is not zero, load checkpoint `start_epoch-1` + and continue training from that checkpoint. + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + + """ + params = AttributeDict( + { + "lr": 1e-3, + "feature_dim": 23, + "weight_decay": 1e-6, + "start_epoch": 0, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "reset_interval": 20, + "valid_interval": 10, + # encoder/decoder params + "vocab_size": 3, # blank, yes, no + "blank_id": 0, + "embedding_dim": 32, + "hidden_dim": 16, + "num_decoder_layers": 4, + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. 
+ + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Tdnn in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + feature_lens = batch["supervisions"]["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + labels = get_labels(texts).to(device) + + with torch.set_grad_enabled(is_training): + loss = model(x=feature, x_lens=feature_lens, y=labels) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = feature.size(0) + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. 
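
    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model used for validation.
      valid_dl:
        Dataloader for the validation dataset.
      world_size:
        Number of GPUs in DDP training. If it is greater than 1, the
        statistics are reduced across all processes.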
+ """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + is_training=True, + ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def get_transducer_model(params: AttributeDict): + encoder = Tdnn( + num_features=params.feature_dim, + output_dim=params.hidden_dim, + ) + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + blank_id=params.blank_id, + num_layers=params.num_decoder_layers, + hidden_dim=params.hidden_dim, + embedding_dropout=0.4, + rnn_dropout=0.4, + ) + joiner = Joiner(input_dim=params.hidden_dim, output_dim=params.vocab_size) + transducer = Transducer(encoder=encoder, decoder=decoder, joiner=joiner) + + return transducer + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value 
        between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))
    params["env_info"] = get_env_info()

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")
    logging.info(params)

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)
    logging.info(f"device: {device}")

    model = get_transducer_model(params)

    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    model.device = device

    optimizer = optim.Adam(
        model.parameters(),
        lr=params.lr,
        weight_decay=params.weight_decay,
    )

    if checkpoints:
        optimizer.load_state_dict(checkpoints["optimizer"])

    yes_no = YesNoAsrDataModule(args)
    train_dl = yes_no.train_dataloaders()

    # There are only 60 waves: 30 files are used for training
    # and the remaining 30 files are used for testing.
    # We use test data as validation.
    valid_dl = yes_no.test_dataloaders()

    for epoch in range(params.start_epoch, params.num_epochs):
        train_dl.sampler.set_epoch(epoch)

        if tb_writer is not None:
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            optimizer=optimizer,
            train_dl=train_dl,
            valid_dl=valid_dl,
            tb_writer=tb_writer,
            world_size=world_size,
        )

        save_checkpoint(
            params=params,
            model=model,
            optimizer=optimizer,
            scheduler=None,
            rank=rank,
        )

    logging.info("Done!")
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


def main():
    parser = get_parser()
    YesNoAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    main()
diff --git a/icefall/utils.py b/icefall/utils.py
index 1b2f12184..7237c8d62 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -565,3 +565,128 @@ class MetricsTracker(collections.defaultdict):
         """
         for k, v in self.norm_items():
             tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+
+def concat(
+    ragged: k2.RaggedTensor, value: int, direction: str
+) -> k2.RaggedTensor:
+    """Prepend a value to the beginning of each sublist or append a value
+    to the end of each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      value:
+        The value to prepend or append.
+      direction:
+        It can be either "left" or "right". If it is "left", we
+        prepend the value to the beginning of each sublist;
+        if it is "right", we append the value to the end of each
+        sublist.
+
+    Returns:
+      Return a new ragged tensor, whose sublists either start with
+      or end with the given value.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> concat(a, value=0, direction="left")
+    [ [ 0 1 3 ] [ 0 5 ] ]
+    >>> concat(a, value=0, direction="right")
+    [ [ 1 3 0 ] [ 5 0 ] ]
+
+    """
+    dtype = ragged.dtype
+    device = ragged.device
+
+    assert ragged.num_axes == 2, f"num_axes: {ragged.num_axes}"
+    pad_values = torch.full(
+        size=(ragged.tot_size(0), 1),
+        fill_value=value,
+        device=device,
+        dtype=dtype,
+    )
+    pad = k2.RaggedTensor(pad_values)
+
+    if direction == "left":
+        ans = k2.ragged.cat([pad, ragged], axis=1)
+    elif direction == "right":
+        ans = k2.ragged.cat([ragged, pad], axis=1)
+    else:
+        raise ValueError(
+            f"Unsupported direction: {direction}. "
+            'Expect either "left" or "right"'
+        )
+    return ans
+
+
+def add_sos(ragged: k2.RaggedTensor, sos_id: int) -> k2.RaggedTensor:
+    """Add SOS to each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      sos_id:
+        The ID of the SOS symbol.
+
+    Returns:
+      Return a new ragged tensor, where each sublist starts with SOS.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> add_sos(a, sos_id=0)
+    [ [ 0 1 3 ] [ 0 5 ] ]
+
+    """
+    return concat(ragged, sos_id, direction="left")
+
+
+def add_eos(ragged: k2.RaggedTensor, eos_id: int) -> k2.RaggedTensor:
+    """Add EOS to each sublist.
+
+    Args:
+      ragged:
+        A ragged tensor with two axes.
+      eos_id:
+        The ID of the EOS symbol.
+
+    Returns:
+      Return a new ragged tensor, where each sublist ends with EOS.
+
+    >>> a = k2.RaggedTensor([[1, 3], [5]])
+    >>> a
+    [ [ 1 3 ] [ 5 ] ]
+    >>> add_eos(a, eos_id=0)
+    [ [ 1 3 0 ] [ 5 0 ] ]
+
+    """
+    return concat(ragged, eos_id, direction="right")
+
+
+def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """
+    Args:
+      lengths:
+        A 1-D tensor containing sentence lengths.
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False,  True,  True,  True,  True],
+            [False, False, False,  True,  True],
+            [False, False,  True,  True,  True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+
+    max_len = lengths.max()
+    n = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(n, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
diff --git a/test/test_utils.py b/test/test_utils.py
index 01916bc59..6a9ce7853 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -21,7 +21,14 @@ import pytest
 import torch
 
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, encode_supervisions, get_texts
+from icefall.utils import (
+    AttributeDict,
+    add_eos,
+    add_sos,
+    encode_supervisions,
+    get_texts,
+    make_pad_mask,
+)
 
 
 @pytest.fixture
@@ -126,3 +133,35 @@ def test_attribute_dict():
 def test_get_env_info():
     s = get_env_info()
     print(s)
+
+
+def test_make_pad_mask():
+    lengths = torch.tensor([1, 3, 2])
+    mask = make_pad_mask(lengths)
+    expected = torch.tensor(
+        [
+            [False, True, True],
+            [False, False, False],
+            [False, False, True],
+        ]
+    )
+    assert torch.all(torch.eq(mask, expected))
+    assert (~expected).sum() == lengths.sum()
+
+
+def test_add_sos():
+    sos_id = 100
+    ragged = k2.RaggedTensor([[1, 2], [3], [0]])
+    sos_ragged = add_sos(ragged, sos_id)
+    expected = k2.RaggedTensor([[sos_id, 1, 2], [sos_id, 3], [sos_id, 0]])
+    assert str(sos_ragged) == str(expected)
+
+
+def test_add_eos():
+    eos_id = 30
+    ragged = k2.RaggedTensor([[1, 2], [3], [], [5, 8, 9]])
+    ragged_eos = add_eos(ragged, eos_id)
+    expected = k2.RaggedTensor(
+        [[1, 2, eos_id], [3, eos_id], [eos_id], [5, 8, 9, eos_id]]
+    )
+    assert str(ragged_eos) == str(expected)
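+
+
+# A sketch of one extra check: per the docstring of `concat` above,
+# directions other than "left" and "right" should raise ValueError.
+# It assumes `concat` can be imported from icefall.utils (it is not
+# in the import list added above).
+def test_concat_invalid_direction():
+    from icefall.utils import concat
+
+    ragged = k2.RaggedTensor([[1, 2], [3]])
+    with pytest.raises(ValueError):
+        concat(ragged, value=0, direction="up")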