diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py deleted file mode 100644 index 26839c61c..000000000 --- a/icefall/lm_wrapper.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging - -import torch - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.rnn_lm.model import RnnLmModel -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import AttributeDict, str2bool - - -class LmScorer(torch.nn.Module): - """This is a wrapper for NN LMs - The language models supported include: - RNN, - Transformer - """ - - def __init__( - self, - lm_type: str, - params: AttributeDict, - device, - lm_scale: float = 0.3, - ): - super(LmScorer, self).__init__() - assert lm_type in ["rnn", "transformer"], f"{lm_type} is not supported" - self.lm_type = lm_type - self.lm = self.get_lm(lm_type, device, params) - self.lm_scale = lm_scale - self.params = params - - @classmethod - def add_arguments(cls, parser): - # LM general arguments - parser.add_argument( - "--lm-vocab-size", - type=int, - default=500, - ) - - parser.add_argument( - "--lm-epoch", - type=int, - default=7, - help="""Which epoch to be used - """, - ) - - parser.add_argument( - "--lm-avg", - type=int, - default=1, - help="""Number of checkpoints to be averaged - """, - ) - - parser.add_argument("--lm-exp-dir", type=str, help="Path to LM experiments") - - # Now RNNLM related arguments - parser.add_argument( - "--rnn-lm-embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--rnn-lm-hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--rnn-lm-num-layers", - type=int, - default=3, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--rnn-lm-tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - # Now transformers - parser.add_argument( - "--transformer-lm-exp-dir", type=str, help="Directory of transformer LM exp" - ) - - parser.add_argument( - "--transformer-lm-dim-feedforward", - type=int, - default=2048, - help="Dimension of FFW module in transformer", - ) - - parser.add_argument( - "--transformer-lm-encoder-dim", - type=int, - default=768, - help="Encoder dimension of transformer", - ) - - parser.add_argument( - "--transformer-lm-embedding-dim", - type=int, - default=768, - help="Input embedding dimension of transformer", - ) - - parser.add_argument( - "--transformer-lm-nhead", - type=int, - default=8, - help="Number of attention heads in transformer", - ) - - parser.add_argument( - "--transformer-lm-num-layers", - type=int, - default=16, - help="Number of encoder layers in transformer", - ) - - parser.add_argument( - "--transformer-lm-tie-weights", - 
type=str2bool,
-            default=True,
-            help="Whether to tie weights in the transformer LM",
-        )
-
-    def get_lm(self, lm_type: str, device, params: AttributeDict) -> torch.nn.Module:
-        """Return the neural network LM
-
-        Args:
-          lm_type (str): Type name of NN LM
-        """
-        if lm_type == "rnn":
-            model = RnnLmModel(
-                vocab_size=params.lm_vocab_size,
-                embedding_dim=params.rnn_lm_embedding_dim,
-                hidden_dim=params.rnn_lm_hidden_dim,
-                num_layers=params.rnn_lm_num_layers,
-                tie_weights=params.rnn_lm_tie_weights,
-            )
-
-            if params.lm_avg == 1:
-                load_checkpoint(
-                    f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
-                )
-                model.to(device)
-            else:
-                start = params.lm_epoch - params.lm_avg + 1
-                filenames = []
-                for i in range(start, params.lm_epoch + 1):
-                    if i >= 0:
-                        filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
-                logging.info(f"averaging {filenames}")
-                model.to(device)
-                model.load_state_dict(average_checkpoints(filenames, device=device))
-
-        elif lm_type == "transformer":
-            model = TransformerLM(
-                vocab_size=params.lm_vocab_size,
-                d_model=params.transformer_lm_encoder_dim,
-                embedding_dim=params.transformer_lm_embedding_dim,
-                dim_feedforward=params.transformer_lm_dim_feedforward,
-                nhead=params.transformer_lm_nhead,
-                num_layers=params.transformer_lm_num_layers,
-                tie_weights=params.transformer_lm_tie_weights,
-                params=params,
-            )
-
-            if params.lm_avg == 1:
-                load_checkpoint(
-                    f"{params.lm_exp_dir}/epoch-{params.lm_epoch}.pt", model
-                )
-                model.to(device)
-            else:
-                start = params.lm_epoch - params.lm_avg + 1
-                filenames = []
-                for i in range(start, params.lm_epoch + 1):
-                    if i >= 0:
-                        filenames.append(f"{params.lm_exp_dir}/epoch-{i}.pt")
-                logging.info(f"averaging {filenames}")
-                model.to(device)
-                model.load_state_dict(average_checkpoints(filenames, device=device))
-        else:
-            raise NotImplementedError()
-
-        return model
-
-    def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
-        """Score the input and return the prediction.
-        This requires the LM to have the method `score_token`.
-        Args:
-          x (torch.Tensor): Input tokens
-          x_lens (torch.Tensor): Lengths of the input tokens
-          state (optional): LM states
-
-        """
-        return self.lm.score_token(x, x_lens, state)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    LmScorer.add_arguments(parser)
-    args = parser.parse_args()
-
-    params = AttributeDict()
-    params.update(vars(args))
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    Scorer = LmScorer(lm_type="rnn", params=params, device=device)
-    Scorer.eval()
-
-    x = (
-        torch.tensor([[1, 4, 19, 256, 77], [1, 4, 19, 256, 77]])
-        .to(device)
-        .to(torch.int64)
-    )
-    x_lens = torch.tensor([5, 5]).to(device)
-
-    state = None
-
-    score, state = Scorer.score_token(x, x_lens, state)
-    print(score.shape)
-    print(score[0])
-    print(score[1])
diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore
deleted file mode 100644
index 877fb1e18..000000000
--- a/icefall/rnn_lm/.gitignore
+++ /dev/null
@@ -1 +0,0 @@
-icefall-librispeech-rnn-lm
diff --git a/icefall/rnn_lm/__init__.py b/icefall/rnn_lm/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py
deleted file mode 100755
index 28b908f82..000000000
--- a/icefall/rnn_lm/check-onnx-streaming.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#!/usr/bin/env python3
-#
-# Copyright 2023 Xiaomi Corporation
-
-"""
-Usage:
-
-./check-onnx-streaming.py \
-  --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \
-
--onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx - -Note: You can download pre-trained models from -https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm - -""" - -import argparse -import logging -from typing import Tuple - -import onnxruntime as ort -import torch - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx", - required=True, - type=str, - help="Path to the onnx model", - ) - - return parser - - -class OnnxModel: - def __init__(self, filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.model = ort.InferenceSession( - filename, - sess_options=session_opts, - ) - - meta_data = self.model.get_modelmeta().custom_metadata_map - self.sos_id = int(meta_data["sos_id"]) - self.eos_id = int(meta_data["eos_id"]) - self.vocab_size = int(meta_data["vocab_size"]) - self.num_layers = int(meta_data["num_layers"]) - self.hidden_size = int(meta_data["hidden_size"]) - print(meta_data) - - def __call__( - self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - out = self.model.run( - [ - self.model.get_outputs()[0].name, - self.model.get_outputs()[1].name, - self.model.get_outputs()[2].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: y.numpy(), - self.model.get_inputs()[2].name: h0.numpy(), - self.model.get_inputs()[3].name: c0.numpy(), - }, - ) - return ( - torch.from_numpy(out[0]), - torch.from_numpy(out[1]), - torch.from_numpy(out[2]), - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_model = torch.jit.load(args.jit).cpu() - onnx_model = OnnxModel(args.onnx) - N = torch.arange(1, 5).tolist() - - num_layers = onnx_model.num_layers - hidden_size = onnx_model.hidden_size - - for n in N: - L = torch.randint(low=1, high=100, size=(1,)).item() - x = torch.randint( - low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 - ) - h0 = torch.rand(num_layers, n, hidden_size) - c0 = torch.rand(num_layers, n, hidden_size) - - torch_log_prob, torch_h0, torch_c0 = torch_model.score_token_onnx(x, h0, c0) - onnx_log_prob, onnx_h0, onnx_c0 = onnx_model(x, h0, c0) - - for torch_v, onnx_v in zip( - (torch_log_prob, torch_h0, torch_c0), (onnx_log_prob, onnx_h0, onnx_c0) - ): - assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( - torch_v.shape, - onnx_v.shape, - (torch_v - onnx_v).abs().max(), - ) - print(n, L, torch_v.sum(), onnx_v.sum()) - - -if __name__ == "__main__": - torch.manual_seed(20230423) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/rnn_lm/check-onnx.py b/icefall/rnn_lm/check-onnx.py deleted file mode 100755 index 24c5395f8..000000000 --- a/icefall/rnn_lm/check-onnx.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation - -""" -Usage: - -./check-onnx.py \ - --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ - --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx - -Note: You can download pre-trained models from -https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm - -""" - -import argparse -import 
logging - -import onnxruntime as ort -import torch - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx", - required=True, - type=str, - help="Path to the onnx model", - ) - - return parser - - -class OnnxModel: - def __init__(self, filename: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.model = ort.InferenceSession( - filename, - sess_options=session_opts, - ) - - meta_data = self.model.get_modelmeta().custom_metadata_map - self.sos_id = int(meta_data["sos_id"]) - self.eos_id = int(meta_data["eos_id"]) - self.vocab_size = int(meta_data["vocab_size"]) - print(meta_data) - - def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_model = torch.jit.load(args.jit).cpu() - onnx_model = OnnxModel(args.onnx) - N = torch.arange(1, 5).tolist() - - for n in N: - L = torch.randint(low=1, high=100, size=(1,)).item() - x = torch.randint( - low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 - ) - x_lens = torch.full((n,), fill_value=L, dtype=torch.int64) - if n > 1: - x_lens[0] = L // 2 + 1 - - sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1) - sos_x = torch.cat([sos, x], dim=1) - - pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1) - x_eos = torch.cat([x, pad_col], dim=1) - - row_index = torch.arange(0, n, dtype=x.dtype) - x_eos[row_index, x_lens] = onnx_model.eos_id - - torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1) - onnx_nll = onnx_model(x, x_lens) - # Note: For int8 models, the differences may be quite large, - # e.g., within 0.9 - assert torch.allclose(torch_nll, onnx_nll), ( - torch_nll, - onnx_nll, - ) - print(n, L, torch_nll, onnx_nll) - - -if __name__ == "__main__": - torch.manual_seed(20230420) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py deleted file mode 100755 index cc566bd92..000000000 --- a/icefall/rnn_lm/compute_perplexity.py +++ /dev/null @@ -1,275 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
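The perplexity reported by compute_perplexity.py below is simply exp(total negative log-likelihood / number of scored tokens). A minimal sketch of that bookkeeping (hypothetical numbers, not part of the original file):

import math
import torch

# Hypothetical per-batch NLL sums and token counts, as produced by
# nll = model(x, y, sentence_lengths); nll.sum() over a few batches.
batch_nll = [torch.tensor(2304.7), torch.tensor(1987.3)]
batch_tokens = [512, 450]

tot_loss = sum(t.item() for t in batch_nll)
num_tokens = sum(batch_tokens)

# Matches `ppl = math.exp(tot_loss / num_tokens)` in the script below.
ppl = math.exp(tot_loss / num_tokens)
print(f"ppl: {ppl:.3f}")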
- -""" -Usage: - ./rnn_lm/compute_perplexity.py \ - --epoch 4 \ - --avg 2 \ - --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt - -""" - -import argparse -import logging -import math -from pathlib import Path - -import torch -from dataset import get_dataloader -from model import RnnLmModel - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=49, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lm-data", - type=str, - help="Path to the LM test data for computing perplexity", - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=3, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--max-sent-len", - type=int, - default=100, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--sos-id", - type=int, - default=1, - help="SOS ID", - ) - - parser.add_argument( - "--eos-id", - type=int, - default=1, - help="EOS ID", - ) - - parser.add_argument( - "--blank-id", - type=int, - default=0, - help="Blank ID", - ) - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lm_data = Path(args.lm_data) - - params = AttributeDict(vars(args)) - - if params.iter > 0: - setup_logger( - f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}" - ) - else: - setup_logger( - f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}" - ) - logging.info("Computing perplexity started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - logging.info("About to create model") - model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - hidden_dim=params.hidden_dim, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - ) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if 
len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(
-            average_checkpoints(filenames, device=device), strict=False
-        )
-    elif params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if i >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(
-            average_checkpoints(filenames, device=device), strict=False
-        )
-
-    model.to(device)
-    model.eval()
-    num_param = sum([p.numel() for p in model.parameters()])
-    num_param_requires_grad = sum(
-        [p.numel() for p in model.parameters() if p.requires_grad]
-    )
-
-    logging.info(f"Number of model parameters: {num_param}")
-    logging.info(
-        f"Number of model parameters (requires_grad): "
-        f"{num_param_requires_grad} "
-        f"({num_param_requires_grad/num_param*100}%)"
-    )
-
-    logging.info(f"Loading LM test data from {params.lm_data}")
-    test_dl = get_dataloader(
-        filename=params.lm_data,
-        is_distributed=False,
-        params=params,
-    )
-
-    tot_loss = 0.0
-    num_tokens = 0
-    num_sentences = 0
-    for batch_idx, batch in enumerate(test_dl):
-        x, y, sentence_lengths = batch
-        x = x.to(device)
-        y = y.to(device)
-        sentence_lengths = sentence_lengths.to(device)
-
-        nll = model(x, y, sentence_lengths)
-        loss = nll.sum().cpu().item()
-
-        tot_loss += loss
-        num_tokens += sentence_lengths.sum().cpu().item()
-        num_sentences += x.size(0)
-
-    ppl = math.exp(tot_loss / num_tokens)
-    logging.info(
-        f"total nll: {tot_loss}, num tokens: {num_tokens}, "
-        f"num sentences: {num_sentences}, ppl: {ppl:.3f}"
-    )
-
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py
deleted file mode 100644
index 53be53f64..000000000
--- a/icefall/rnn_lm/dataset.py
+++ /dev/null
@@ -1,214 +0,0 @@
-# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Tuple
-
-import k2
-import torch
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-
-from icefall.utils import AttributeDict, add_eos, add_sos
-
-
-class LmDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        sentences: k2.RaggedTensor,
-        words: k2.RaggedTensor,
-        sentence_lengths: torch.Tensor,
-        max_sent_len: int,
-        batch_size: int,
-    ):
-        """
-        Args:
-          sentences:
-            A ragged tensor of dtype torch.int32 with 2 axes [sentence][word].
- words: - A ragged tensor of dtype torch.int32 with 2 axes [word][token]. - sentence_lengths: - A 1-D tensor of dtype torch.int32 containing number of tokens - of each sentence. - max_sent_len: - Maximum sentence length. It is used to change the batch size - dynamically. In general, we try to keep the product of - "max_sent_len in a batch" and "num_of_sent in a batch" being - a constant. - batch_size: - The expected batch size. It is changed dynamically according - to the "max_sent_len". - - See `../local/prepare_lm_training_data.py` for how `sentences` and - `words` are generated. We assume that `sentences` are sorted by length. - See `../local/sort_lm_training_data.py`. - """ - super().__init__() - self.sentences = sentences - self.words = words - - sentence_lengths = sentence_lengths.tolist() - - assert batch_size > 0, batch_size - assert max_sent_len > 1, max_sent_len - batch_indexes = [] - num_sentences = sentences.dim0 - cur = 0 - while cur < num_sentences: - sz = sentence_lengths[cur] // max_sent_len + 1 - # Assume the current sentence has 3 * max_sent_len tokens, - # in the worst case, the subsequent sentences also have - # this number of tokens, we should reduce the batch size - # so that this batch will not contain too many tokens - actual_batch_size = batch_size // sz + 1 - actual_batch_size = min(actual_batch_size, batch_size) - end = cur + actual_batch_size - end = min(end, num_sentences) - this_batch_indexes = torch.arange(cur, end).tolist() - batch_indexes.append(this_batch_indexes) - cur = end - assert batch_indexes[-1][-1] == num_sentences - 1 - - self.batch_indexes = k2.RaggedTensor(batch_indexes) - - def __len__(self) -> int: - """Return number of batches in this dataset""" - return self.batch_indexes.dim0 - - def __getitem__(self, i: int) -> k2.RaggedTensor: - """Get the i'th batch in this dataset - Return a ragged tensor with 2 axes [sentence][token]. - """ - assert 0 <= i < len(self), i - - # indexes is a 1-D tensor containing sentence indexes - indexes = self.batch_indexes[i] - - # sentence_words is a ragged tensor with 2 axes - # [sentence][word] - sentence_words = self.sentences[indexes] - - # in case indexes contains only 1 entry, the returned - # sentence_words is a 1-D tensor, we have to convert - # it to a ragged tensor - if isinstance(sentence_words, torch.Tensor): - sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0)) - - # sentence_word_tokens is a ragged tensor with 3 axes - # [sentence][word][token] - sentence_word_tokens = self.words.index(sentence_words) - assert sentence_word_tokens.num_axes == 3 - - sentence_tokens = sentence_word_tokens.remove_axis(1) - return sentence_tokens - - -class LmDatasetCollate: - def __init__(self, sos_id: int, eos_id: int, blank_id: int): - """ - Args: - sos_id: - Token ID of the SOS symbol. - eos_id: - Token ID of the EOS symbol. - blank_id: - Token ID of the blank symbol. - """ - self.sos_id = sos_id - self.eos_id = eos_id - self.blank_id = blank_id - - def __call__( - self, batch: List[k2.RaggedTensor] - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Return a tuple containing 3 tensors: - - - x, a 2-D tensor of dtype torch.int32; each row contains tokens - for a sentence starting with `self.sos_id`. It is padded to - the max sentence length with `self.blank_id`. - - - y, a 2-D tensor of dtype torch.int32; each row contains tokens - for a sentence ending with `self.eos_id` before padding. - Then it is padded to the max sentence length with - `self.blank_id`. 
- - - lengths, a 2-D tensor of dtype torch.int32, containing the number of - tokens of each sentence before padding. - """ - # The batching stuff has already been done in LmDataset - assert len(batch) == 1 - sentence_tokens = batch[0] - row_splits = sentence_tokens.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id) - sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id) - - x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id) - y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id) - sentence_token_lengths += 1 # plus 1 since we added a SOS - - return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths - - -def get_dataloader( - filename: str, - is_distributed: bool, - params: AttributeDict, -) -> torch.utils.data.DataLoader: - """Get dataloader for LM training. - - Args: - filename: - Path to the file containing LM data. The file is assumed to - be generated by `../local/sort_lm_training_data.py`. - is_distributed: - True if using DDP training. False otherwise. - params: - Set `get_params()` from `rnn_lm/train.py` - Returns: - Return a dataloader containing the LM data. - """ - lm_data = torch.load(filename) - - words = lm_data["words"] - sentences = lm_data["sentences"] - sentence_lengths = lm_data["sentence_lengths"] - - dataset = LmDataset( - sentences=sentences, - words=words, - sentence_lengths=sentence_lengths, - max_sent_len=params.max_sent_len, - batch_size=params.batch_size, - ) - if is_distributed: - sampler = DistributedSampler(dataset, shuffle=True, drop_last=True) - else: - sampler = None - - collate_fn = LmDatasetCollate( - sos_id=params.sos_id, - eos_id=params.eos_id, - blank_id=params.blank_id, - ) - - dataloader = DataLoader( - dataset, - batch_size=1, - collate_fn=collate_fn, - sampler=sampler, - shuffle=sampler is None, - ) - return dataloader diff --git a/icefall/rnn_lm/export-onnx.py b/icefall/rnn_lm/export-onnx.py deleted file mode 100755 index 1070d443a..000000000 --- a/icefall/rnn_lm/export-onnx.py +++ /dev/null @@ -1,395 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation - -import argparse -import logging -from pathlib import Path -from typing import Dict - -import onnx -import torch -from model import RnnLmModel -from onnxruntime.quantization import QuantType, quantize_dynamic -from train import get_params - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, str2bool - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -# A wrapper for RnnLm model to simpily the C++ calling code -# when exporting the model to ONNX. -# -# TODO(fangjun): The current wrapper works only for non-streaming ASR -# since we don't expose the LM state and it is used to score -# a complete sentence at once. 
-class RnnLmModelWrapper(torch.nn.Module): - def __init__(self, model: RnnLmModel, sos_id: int, eos_id: int): - super().__init__() - self.model = model - self.sos_id = sos_id - self.eos_id = eos_id - - def forward(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - A 2-D tensor of shape (N, L) with dtype torch.int64. - It does not contain SOS or EOS. We will add SOS and EOS inside - this function. - x_lens: - A 1-D tensor of shape (N,) with dtype torch.int64. It contains - number of valid tokens in ``x`` before padding. - Returns: - Return a 1-D tensor of shape (N,) containing negative loglikelihood. - Its dtype is torch.float32 - """ - N = x.size(0) - - sos_tensor = torch.full((1,), fill_value=self.sos_id, dtype=x.dtype).expand( - N, 1 - ) - sos_x = torch.cat([sos_tensor, x], dim=1) - - pad_col = torch.zeros((1,), dtype=x.dtype).expand(N, 1) - x_eos = torch.cat([x, pad_col], dim=1) - - row_index = torch.arange(0, N, dtype=x.dtype) - x_eos[row_index, x_lens] = self.eos_id - - # use x_lens + 1 here since we prepended x with sos - return ( - self.model(x=sos_x, y=x_eos, lengths=x_lens + 1) - .to(torch.float32) - .sum(dim=1) - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=3, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - return parser - - -def export_without_state( - model: RnnLmModel, - filename: str, - params: AttributeDict, - opset_version: int, -): - model_wrapper = RnnLmModelWrapper( - model, - sos_id=params.sos_id, - eos_id=params.eos_id, - ) - - N = 1 - L = 20 - x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) - x_lens = torch.full((N,), fill_value=L, dtype=torch.int64) - - # Note(fangjun): The following warnings can be ignored. - # We can use ./check-onnx.py to validate the exported model with batch_size > 1 - """ - torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX - with a batch_size other than 1, with a variable length with LSTM can cause - an error when running the ONNX model with a different batch size. 
Make sure - to save the model with a batch size of 1, or define the initial states - (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX - with a batch_size other than 1, " + - """ - - torch.onnx.export( - model_wrapper, - (x, x_lens), - filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["nll"], - dynamic_axes={ - "x": {0: "N", 1: "L"}, - "x_lens": {0: "N"}, - "nll": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "rnnlm", - "version": "1", - "model_author": "k2-fsa", - "comment": "rnnlm without state", - "sos_id": str(params.sos_id), - "eos_id": str(params.eos_id), - "vocab_size": str(params.vocab_size), - "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=filename, meta_data=meta_data) - - -def export_with_state( - model: RnnLmModel, - filename: str, - params: AttributeDict, - opset_version: int, -): - N = 1 - L = 20 - num_layers = model.rnn.num_layers - hidden_size = model.rnn.hidden_size - embedding_dim = model.embedding_dim - - x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) - h0 = torch.zeros(num_layers, N, hidden_size) - c0 = torch.zeros(num_layers, N, hidden_size) - - # Note(fangjun): The following warnings can be ignored. - # We can use ./check-onnx.py to validate the exported model with batch_size > 1 - """ - torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX - with a batch_size other than 1, with a variable length with LSTM can cause - an error when running the ONNX model with a different batch size. Make sure - to save the model with a batch size of 1, or define the initial states - (h0/c0) as inputs of the model. 
warnings.warn("Exporting a model to ONNX - with a batch_size other than 1, " + - """ - - torch.onnx.export( - model, - (x, h0, c0), - filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "h0", "c0"], - output_names=["log_softmax", "next_h0", "next_c0"], - dynamic_axes={ - "x": {0: "N", 1: "L"}, - "h0": {1: "N"}, - "c0": {1: "N"}, - "log_softmax": {0: "N"}, - "next_h0": {1: "N"}, - "next_c0": {1: "N"}, - }, - ) - - meta_data = { - "model_type": "rnnlm", - "version": "1", - "model_author": "k2-fsa", - "comment": "rnnlm state", - "sos_id": str(params.sos_id), - "eos_id": str(params.eos_id), - "vocab_size": str(params.vocab_size), - "num_layers": str(num_layers), - "hidden_size": str(hidden_size), - "embedding_dim": str(embedding_dim), - "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - device = torch.device("cpu") - logging.info(f"device: {device}") - - model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - hidden_dim=params.hidden_dim, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - ) - - model.to(device) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to("cpu") - model.eval() - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting model without state") - filename = params.exp_dir / f"no-state-{suffix}.onnx" - export_without_state( - model=model, - filename=filename, - params=params, - opset_version=opset_version, - ) - - filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx" - quantize_dynamic( - model_input=filename, - model_output=filename_int8, - weight_type=QuantType.QInt8, - ) - - # now for streaming export - saved_forward = model.__class__.forward - model.__class__.forward = model.__class__.score_token_onnx - streaming_filename = params.exp_dir / f"with-state-{suffix}.onnx" - export_with_state( - model=model, - filename=streaming_filename, - params=params, - opset_version=opset_version, - ) - model.__class__.forward = saved_forward - - streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx" - quantize_dynamic( - model_input=streaming_filename, - model_output=streaming_filename_int8, - weight_type=QuantType.QInt8, - ) - - -if 
__name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh deleted file mode 100755 index 6e3262b5e..000000000 --- a/icefall/rnn_lm/export-onnx.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -# We use the model from -# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main -# as an example - -export CUDA_VISIBLE_DEVICES= - -if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm - pushd icefall-librispeech-rnn-lm/exp - git lfs pull --include "pretrained.pt" - ln -s pretrained.pt epoch-99.pt - popd -fi - -python3 ./export-onnx.py \ - --exp-dir ./icefall-librispeech-rnn-lm/exp \ - --epoch 99 \ - --avg 1 \ - --vocab-size 500 \ - --embedding-dim 2048 \ - --hidden-dim 2048 \ - --num-layers 3 \ - --tie-weights 1 - diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py deleted file mode 100644 index dadf23009..000000000 --- a/icefall/rnn_lm/export.py +++ /dev/null @@ -1,205 +0,0 @@ -#!/usr/bin/env python3 -# -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -import argparse -import logging -from pathlib import Path - -import torch -from model import RnnLmModel - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. 
- """, - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=3, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = AttributeDict({}) - params.update(vars(args)) - - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - hidden_dim=params.hidden_dim, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - ) - - model.to(device) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - - model.__class__.score_token_onnx = torch.jit.export( - model.__class__.score_token_onnx - ) - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/rnn_lm/export.sh b/icefall/rnn_lm/export.sh deleted file mode 100755 index 678bc294e..000000000 --- a/icefall/rnn_lm/export.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env bash - -# We use the model from 
-# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main -# as an example - -export CUDA_VISIBLE_DEVICES= - -if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then - GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm - pushd icefall-librispeech-rnn-lm/exp - git lfs pull --include "pretrained.pt" - ln -s pretrained.pt epoch-99.pt - popd -fi - -python3 ./export.py \ - --exp-dir ./icefall-librispeech-rnn-lm/exp \ - --epoch 99 \ - --avg 1 \ - --vocab-size 500 \ - --embedding-dim 2048 \ - --hidden-dim 2048 \ - --num-layers 3 \ - --tie-weights 1 \ - --jit 1 - diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py deleted file mode 100644 index 5eacf5d40..000000000 --- a/icefall/rnn_lm/model.py +++ /dev/null @@ -1,295 +0,0 @@ -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Tuple - -import torch -import torch.nn.functional as F - -from icefall.utils import add_eos, add_sos, make_pad_mask - - -class RnnLmModel(torch.nn.Module): - def __init__( - self, - vocab_size: int, - embedding_dim: int, - hidden_dim: int, - num_layers: int, - tie_weights: bool = False, - ): - """ - Args: - vocab_size: - Vocabulary size of BPE model. - embedding_dim: - Input embedding dimension. - hidden_dim: - Hidden dimension of RNN layers. - num_layers: - Number of RNN layers. - tie_weights: - True to share the weights between the input embedding layer and the - last output linear layer. See https://arxiv.org/abs/1608.05859 - and https://arxiv.org/abs/1611.01462 - """ - super().__init__() - self.vocab_size = vocab_size - self.embedding_dim = embedding_dim - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.tie_weights = tie_weights - - self.input_embedding = torch.nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - ) - - self.rnn = torch.nn.LSTM( - input_size=embedding_dim, - hidden_size=hidden_dim, - num_layers=num_layers, - batch_first=True, - ) - - self.output_linear = torch.nn.Linear( - in_features=hidden_dim, out_features=vocab_size - ) - - self.vocab_size = vocab_size - if tie_weights: - logging.info("Tying weights") - assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim) - self.output_linear.weight = self.input_embedding.weight - else: - logging.info("Not tying weights") - - self.cache = {} - - def streaming_forward( - self, - x: torch.Tensor, - y: torch.Tensor, - h0: torch.Tensor, - c0: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 2-D tensor of shape (N, L). We won't prepend it with SOS. - y: - A 2-D tensor of shape (N, L). We won't append it with EOS. - h0: - A 3-D tensor of shape (num_layers, N, hidden_size). - (If proj_size > 0, then it is (num_layers, N, proj_size)) - c0: - A 3-D tensor of shape (num_layers, N, hidden_size). 
- Returns: - Return a tuple containing 3 tensors: - - negative loglike (nll), a 1-D tensor of shape (N,) - - next_h0, a 3-D tensor with the same shape as h0 - - next_c0, a 3-D tensor with the same shape as c0 - """ - assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) - assert x.shape == y.shape, (x.shape, y.shape) - - # embedding is of shape (N, L, embedding_dim) - embedding = self.input_embedding(x) - # Note: We use batch_first==True - rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0)) - logits = self.output_linear(rnn_out) - nll_loss = F.cross_entropy( - logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" - ) - - batch_size = x.size(0) - nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1) - return nll_loss, next_h0, next_c0 - - def forward( - self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor - ) -> torch.Tensor: - """ - Args: - x: - A 2-D tensor with shape (N, L). Each row - contains token IDs for a sentence and starts with the SOS token. - y: - A shifted version of `x` and with EOS appended. - lengths: - A 1-D tensor of shape (N,). It contains the sentence lengths - before padding. - Returns: - Return a 2-D tensor of shape (N, L) containing negative log-likelihood - loss values. Note: Loss values for padding positions are set to 0. - """ - assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) - assert lengths.ndim == 1, lengths.ndim - assert x.shape == y.shape, (x.shape, y.shape) - - batch_size = x.size(0) - assert lengths.size(0) == batch_size, (lengths.size(0), batch_size) - - # embedding is of shape (N, L, embedding_dim) - embedding = self.input_embedding(x) - - # Note: We use batch_first==True - rnn_out, _ = self.rnn(embedding) - logits = self.output_linear(rnn_out) - - # Note: No need to use `log_softmax()` here - # since F.cross_entropy() expects unnormalized probabilities - - # nll_loss is of shape (N*L,) - # nll -> negative log-likelihood - nll_loss = F.cross_entropy( - logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" - ) - # Set loss values for padding positions to 0 - mask = make_pad_mask(lengths).reshape(-1) - nll_loss.masked_fill_(mask, 0) - - nll_loss = nll_loss.reshape(batch_size, -1) - - return nll_loss - - def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id): - device = next(self.parameters()).device - batch_size = len(token_lens) - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64).to(device) - y_tokens = y_tokens.to(torch.int64).to(device) - sentence_lengths = sentence_lengths.to(torch.int64).to(device) - - embedding = self.input_embedding(x_tokens) - - # Note: We use batch_first==True - rnn_out, states = self.rnn(embedding) - logits = self.output_linear(rnn_out) - mask = torch.zeros(logits.shape).bool().to(device) - for i in range(batch_size): - mask[i, token_lens[i], :] = True - logits = logits[mask].reshape(batch_size, -1) - - return logits[:, :].log_softmax(-1), states - - def clean_cache(self): - self.cache = {} - - def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - """Score a batch of tokens, i.e each sample in the batch should be a - single token. 
For example, x = torch.tensor([[5],[10],[20]]) - - - Args: - x (torch.Tensor): - A batch of tokens - x_lens (torch.Tensor): - The length of tokens in the batch before padding - state (optional): - Either None or a tuple of two torch.Tensor. Each tensor has - the shape of (num_layers, bs, hidden_dim) - - Returns: - _type_: _description_ - """ - device = next(self.parameters()).device - batch_size = x.size(0) - if state: - h, c = state - else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( - device - ) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( - device - ) - - embedding = self.input_embedding(x) - rnn_out, states = self.rnn(embedding, (h, c)) - logits = self.output_linear(rnn_out) - - return logits[:, 0].log_softmax(-1), states - - def score_token_onnx( - self, - x: torch.Tensor, - state_h: torch.Tensor, - state_c: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Score a batch of tokens, i.e each sample in the batch should be a - single token. For example, x = torch.tensor([[5],[10],[20]]) - - - Args: - x (torch.Tensor): - A batch of tokens - state_h: - state h of RNN has the shape of (num_layers, bs, hidden_dim) - state_c: - state c of RNN has the shape of (num_layers, bs, hidden_dim) - - Returns: - _type_: _description_ - """ - embedding = self.input_embedding(x) - rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c)) - logits = self.output_linear(rnn_out) - - return logits[:, 0].log_softmax(-1), next_h0, next_c0 - - def forward_with_state( - self, tokens, token_lens, sos_id, eos_id, blank_id, state=None - ): - batch_size = len(token_lens) - if state: - h, c = state - else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) - - device = next(self.parameters()).device - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64).to(device) - y_tokens = y_tokens.to(torch.int64).to(device) - sentence_lengths = sentence_lengths.to(torch.int64).to(device) - - embedding = self.input_embedding(x_tokens) - - # Note: We use batch_first==True - rnn_out, states = self.rnn(embedding, (h, c)) - logits = self.output_linear(rnn_out) - - return logits, states diff --git a/icefall/rnn_lm/test_dataset.py b/icefall/rnn_lm/test_dataset.py deleted file mode 100755 index bf961f54b..000000000 --- a/icefall/rnn_lm/test_dataset.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
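A minimal sketch (not part of the original files) of how RnnLmModel.score_token from model.py above is driven one token at a time, carrying the LSTM state across calls; it assumes rnn_lm is importable, as in the tests below:

import torch
from rnn_lm.model import RnnLmModel

model = RnnLmModel(vocab_size=500, embedding_dim=2048, hidden_dim=2048, num_layers=3)
model.eval()

# Three hypotheses, one freshly proposed token each.
x = torch.tensor([[5], [10], [20]], dtype=torch.int64)
x_lens = torch.tensor([1, 1, 1], dtype=torch.int64)

with torch.no_grad():
    # First call: state is None, so zero-initialized (h, c) states are used.
    log_probs, state = model.score_token(x, x_lens, None)  # log_probs: (3, vocab_size)
    # Later calls pass the returned state to continue each hypothesis.
    next_x = torch.tensor([[7], [7], [7]], dtype=torch.int64)
    log_probs, state = model.score_token(next_x, x_lens, state)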
- -import k2 -import torch -from rnn_lm.dataset import LmDataset, LmDatasetCollate - - -def main(): - sentences = k2.RaggedTensor( - [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] - ) - words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) - - num_sentences = sentences.dim0 - - sentence_lengths = [0] * num_sentences - for i in range(num_sentences): - word_ids = sentences[i] - - # NOTE: If word_ids is a tensor with only 1 entry, - # token_ids is a torch.Tensor - token_ids = words[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - - # token_ids is a 1-D tensor containing the BPE tokens - # of the current sentence - - sentence_lengths[i] = token_ids.numel() - - sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) - - indices = torch.argsort(sentence_lengths, descending=True) - sentences = sentences[indices.to(torch.int32)] - sentence_lengths = sentence_lengths[indices] - - dataset = LmDataset( - sentences=sentences, - words=words, - sentence_lengths=sentence_lengths, - max_sent_len=3, - batch_size=4, - ) - - collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=1, collate_fn=collate_fn - ) - - for i in dataloader: - print(i) - # I've checked the output manually; the output is as expected. - - -if __name__ == "__main__": - main() diff --git a/icefall/rnn_lm/test_dataset_ddp.py b/icefall/rnn_lm/test_dataset_ddp.py deleted file mode 100755 index 48fbb19cb..000000000 --- a/icefall/rnn_lm/test_dataset_ddp.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -import k2 -import torch -import torch.multiprocessing as mp -from rnn_lm.dataset import LmDataset, LmDatasetCollate -from torch import distributed as dist - - -def generate_data(): - sentences = k2.RaggedTensor( - [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] - ) - words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) - - num_sentences = sentences.dim0 - - sentence_lengths = [0] * num_sentences - for i in range(num_sentences): - word_ids = sentences[i] - - # NOTE: If word_ids is a tensor with only 1 entry, - # token_ids is a torch.Tensor - token_ids = words[word_ids] - if isinstance(token_ids, k2.RaggedTensor): - token_ids = token_ids.values - - # token_ids is a 1-D tensor containing the BPE tokens - # of the current sentence - - sentence_lengths[i] = token_ids.numel() - - sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) - - indices = torch.argsort(sentence_lengths, descending=True) - sentences = sentences[indices.to(torch.int32)] - sentence_lengths = sentence_lengths[indices] - - return sentences, words, sentence_lengths - - -def run(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12352" - - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - - sentences, words, sentence_lengths = generate_data() - - dataset = LmDataset( - sentences=sentences, - words=words, - sentence_lengths=sentence_lengths, - max_sent_len=3, - batch_size=4, - ) - sampler = torch.utils.data.distributed.DistributedSampler( - dataset, shuffle=True, drop_last=False - ) - - collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=1, - collate_fn=collate_fn, - sampler=sampler, - shuffle=False, - ) - - for i in dataloader: - print(f"rank: {rank}", i) - - dist.destroy_process_group() - - -def main(): - world_size = 2 - mp.spawn(run, args=(world_size,), nprocs=world_size, join=True) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/icefall/rnn_lm/test_model.py b/icefall/rnn_lm/test_model.py deleted file mode 100755 index 5a216a3fb..000000000 --- a/icefall/rnn_lm/test_model.py +++ /dev/null @@ -1,69 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch -from rnn_lm.model import RnnLmModel - - -def test_rnn_lm_model(): - vocab_size = 4 - model = RnnLmModel( - vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2 - ) - x = torch.tensor( - [ - [1, 3, 2, 2], - [1, 2, 2, 0], - [1, 2, 0, 0], - ] - ) - y = torch.tensor( - [ - [3, 2, 2, 1], - [2, 2, 1, 0], - [2, 1, 0, 0], - ] - ) - lengths = torch.tensor([4, 3, 2]) - nll_loss = model(x, y, lengths) - print(nll_loss) - """ - tensor([[1.1180, 1.3059, 1.2426, 1.7773], - [1.4231, 1.2783, 1.7321, 0.0000], - [1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=) - """ - - -def test_rnn_lm_model_tie_weights(): - model = RnnLmModel( - vocab_size=10, - embedding_dim=10, - hidden_dim=10, - num_layers=2, - tie_weights=True, - ) - assert model.input_embedding.weight is model.output_linear.weight - - -def main(): - test_rnn_lm_model() - test_rnn_lm_model_tie_weights() - - -if __name__ == "__main__": - torch.manual_seed(20211122) - main() diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py deleted file mode 100755 index 0178b80bf..000000000 --- a/icefall/rnn_lm/train.py +++ /dev/null @@ -1,689 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
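[Editor's note] As the model test above shows, RnnLmModel returns a per-token NLL matrix with zeros at padded positions. The scripts below (rnn_lm/train.py and transformer_lm/compute_perplexity.py) turn this into perplexity as exp(total NLL / total number of tokens); a minimal standalone sketch of that computation:

import math
import torch
from rnn_lm.model import RnnLmModel

def perplexity(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor,
               lengths: torch.Tensor) -> float:
    with torch.no_grad():
        nll = model(x, y, lengths)  # (batch, time); zeros at padding
    return math.exp(nll.sum().item() / lengths.sum().item())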
- -""" -Usage: - ./rnn_lm/train.py \ - --start-epoch 0 \ - --world-size 2 \ - --num-epochs 1 \ - --use-fp16 0 \ - --tie-weights 0 \ - --embedding-dim 800 \ - --hidden-dim 200 \ - --num-layers 2 \ - --batch-size 400 - -""" - -import argparse -import logging -import math -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from dataset import get_dataloader -from lhotse.utils import fix_random_seed -from model import RnnLmModel -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - exp_dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, logs, etc, are saved - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--batch-size", - type=int, - default=400, - ) - - parser.add_argument( - "--lm-data", - type=str, - default="data/lm_training_bpe_500/sorted_lm_data.pt", - help="LM training data", - ) - - parser.add_argument( - "--lm-data-valid", - type=str, - default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", - help="LM validation data", - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--embedding-dim", - type=int, - default=2048, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--hidden-dim", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=3, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--lr", - type=float, - default=1e-3, - ) - - parser.add_argument( - "--max-sent-len", - type=int, - default=200, - help="""Maximum number of tokens in a sentence. This is used - to adjust batch-size dynamically""", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters.""" - - params = AttributeDict( - { - "max_sent_len": 200, - "sos_id": 1, - "eos_id": 1, - "blank_id": 0, - "weight_decay": 1e-6, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 2000, - "valid_interval": 200, - "env_info": get_env_info(), - } - ) - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. 
- Returns: - Return None. - """ - - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - logging.info(f"Loading checkpoint: {filename}") - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - if "cur_batch_idx" in saved_params: - params["cur_batch_idx"] = saved_params["cur_batch_idx"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - model: nn.Module, - x: torch.Tensor, - y: torch.Tensor, - sentence_lengths: torch.Tensor, - is_training: bool, -) -> Tuple[torch.Tensor, MetricsTracker]: - """Compute the negative log-likelihood loss given a model and its input. - Args: - model: - The NN model, e.g., RnnLmModel. - x: - A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, - each row starts with SOS ID. - y: - A 2-D tensor. Each row is a shifted version of the corresponding row - in `x` but ends with an EOS ID (before padding). - sentence_lengths: - A 1-D tensor containing number of tokens of each sentence - before padding. - is_training: - True for training. False for validation. - """ - with torch.set_grad_enabled(is_training): - device = model.device - x = x.to(device) - y = y.to(device) - sentence_lengths = sentence_lengths.to(device) - - nll = model(x, y, sentence_lengths) - loss = nll.sum() - - num_tokens = sentence_lengths.sum().item() - - loss_info = MetricsTracker() - # Note: Due to how MetricsTracker() is designed, - # we use "frames" instead of "num_tokens" as a key here - loss_info["frames"] = num_tokens - loss_info["loss"] = loss.detach().item() - return loss, loss_info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
- """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - model=model, - x=x, - y=y, - sentence_lengths=sentence_lengths, - is_training=False, - ) - - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all sentences is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - cur_batch_idx = params.get("cur_batch_idx", 0) - - for batch_idx, batch in enumerate(train_dl): - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx - - params.batch_idx_train += 1 - x, y, sentence_lengths = batch - batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - model=model, - x=x, - y=y, - sentence_lengths=sentence_lengths, - is_training=True, - ) - - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - rank=rank, - ) - del params.cur_batch_idx - - if batch_idx % params.log_interval == 0: - # Note: "frames" here means "num_tokens" - this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) - tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " - f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " - f"batch size: {batch_size}" - ) - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - tb_writer.add_scalar( - "train/current_ppl", this_batch_ppl, params.batch_idx_train - ) - - tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - 
model.train() - - valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}, " - f"ppl: {valid_ppl}" - ) - - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - tb_writer.add_scalar( - "train/valid_ppl", valid_ppl, params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - is_distributed = world_size > 1 - - fix_random_seed(params.seed) - if is_distributed: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - logging.info(f"Device: {device}") - - logging.info("About to create model") - model = RnnLmModel( - vocab_size=params.vocab_size, - embedding_dim=params.embedding_dim, - hidden_dim=params.hidden_dim, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - ) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if is_distributed: - model = DDP(model, device_ids=[rank]) - - model.device = device - - optimizer = optim.Adam( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - if checkpoints: - logging.info("Load optimizer state_dict from checkpoint") - optimizer.load_state_dict(checkpoints["optimizer"]) - - logging.info(f"Loading LM training data from {params.lm_data}") - train_dl = get_dataloader( - filename=params.lm_data, - is_distributed=is_distributed, - params=params, - ) - - logging.info(f"Loading LM validation data from {params.lm_data_valid}") - valid_dl = get_dataloader( - filename=params.lm_data_valid, - is_distributed=is_distributed, - params=params, - ) - - # Note: No learning rate scheduler is used here - for epoch in range(params.start_epoch, params.num_epochs): - if is_distributed: - train_dl.sampler.set_epoch(epoch) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if is_distributed: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - 
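[Editor's note] save_checkpoint() above writes exp-dir/epoch-{cur_epoch}.pt at the end of each epoch. A trained checkpoint can be loaded back for evaluation roughly as follows (a sketch: the exp dir, epoch number, and model dimensions are placeholders matching the parser defaults above):

import torch
from rnn_lm.model import RnnLmModel
from icefall.checkpoint import load_checkpoint

model = RnnLmModel(vocab_size=500, embedding_dim=2048,
                   hidden_dim=2048, num_layers=3, tie_weights=True)
load_checkpoint("rnn_lm/exp/epoch-29.pt", model)  # epoch number is illustrative
model.eval()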
-torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/icefall/transformer_lm/__init__.py b/icefall/transformer_lm/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/icefall/transformer_lm/attention.py b/icefall/transformer_lm/attention.py deleted file mode 100644 index 5ce83b15e..000000000 --- a/icefall/transformer_lm/attention.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List, Optional, Tuple - -import torch -from torch import Tensor, nn - -from icefall.transformer_lm.scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledConv2d, - ScaledLinear, -) -from icefall.utils import is_jit_tracing - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear( - embed_dim, embed_dim, bias=True, initial_scale=0.25 - ) - - # linear transformation for positional encoding. 
- self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() - - def _pos_bias_u(self): - return self.pos_bias_u * self.pos_bias_u_scale.exp() - - def _pos_bias_v(self): - return self.pos_bias_v * self.pos_bias_v_scale.exp() - - def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.01) - nn.init.normal_(self.pos_bias_v, std=0.01) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. 
- - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), - self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - left_context=left_context, - ) - - def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1+left_context). - time1 means the length of query vector. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - - time2 = time1 + left_context - if not is_jit_tracing(): - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" - - if is_jit_tracing(): - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(time2) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - - x = x.reshape(-1, n) - x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, time1, time2) - return x - else: - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. 
- left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. 
- """ - - tgt_len, bsz, embed_dim = query.size() - if not is_jit_tracing(): - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - if not is_jit_tracing(): - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." 
- ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None and not is_jit_tracing(): - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - if not is_jit_tracing(): - assert pos_emb_bsz in (1, bsz) # actually it is 1 - - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) - p = p.permute(0, 2, 3, 1) - - q_with_bias_u = (q + self._pos_bias_u()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self._pos_bias_v()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd, left_context) - - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - if not is_jit_tracing(): - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - - # If we are using dynamic_chunk_training and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`, at this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking - # positions to avoid invalid loss value below. 
- if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - - if not is_jit_tracing(): - assert list(attn_output.size()) == [ - bsz * num_heads, - tgt_len, - head_dim, - ] - - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None diff --git a/icefall/transformer_lm/compute_perplexity.py b/icefall/transformer_lm/compute_perplexity.py deleted file mode 100644 index 72d7c477b..000000000 --- a/icefall/transformer_lm/compute_perplexity.py +++ /dev/null @@ -1,195 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import logging -import math -from pathlib import Path - -import torch -from dataset import get_dataloader -from train import get_params - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import AttributeDict, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=7, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transformer_lm/exp_full_libri_16layer_maxlen200_8gpu", - ) - - parser.add_argument( - "--lm-data", - type=str, - help="Path to the LM test data for computing perplexity", - default="transformer_lm/libri_lm_training_bpe500/sorted_lm_data-test.pt", - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=16, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=False, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="Number of RNN layers the model", - ) - - parser.add_argument( - "--max-sent-len", - type=int, - default=100, - help="Number of RNN layers the model", - ) - - return parser - - -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lm_data = Path(args.lm_data) - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-ppl/") - logging.info("Computing perplexity started") - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - logging.info("About to create model") - model = TransformerLM( - vocab_size=params.vocab_size, - d_model=params.encoder_dim, - embedding_dim=params.embedding_dim, - dim_feedforward=params.dim_feedforward, - nhead=params.nhead, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - params=params, - ) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.to(device) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - num_param_requires_grad = sum( - [p.numel() for p in model.parameters() if p.requires_grad] - ) - - logging.info(f"Number of model parameters: {num_param}") - logging.info( - f"Number of model parameters (requires_grad): " - f"{num_param_requires_grad} " - f"({num_param_requires_grad/num_param_requires_grad*100}%)" - ) - - logging.info(f"Loading LM test data from {params.lm_data}") - test_dl = get_dataloader( - filename=params.lm_data, - is_distributed=False, - params=params, - ) - - tot_loss = 0.0 - num_tokens = 0 - num_sentences = 0 - for batch_idx, batch in enumerate(test_dl): - x, y, sentence_lengths = batch - x = x.to(device) - y = y.to(device) - sentence_lengths = sentence_lengths.to(device) - - nll = model(x, y, sentence_lengths) - loss = nll.sum().cpu().item() - - tot_loss += loss - num_tokens += sentence_lengths.sum().cpu().item() - num_sentences += x.size(0) - - ppl = math.exp(tot_loss / num_tokens) - logging.info( - f"total nll: {tot_loss}, num tokens: {num_tokens}, " - f"num sentences: {num_sentences}, ppl: {ppl:.3f}" - ) - - -if __name__ == "__main__": - main() diff --git a/icefall/transformer_lm/dataset.py b/icefall/transformer_lm/dataset.py deleted file mode 120000 index 5792a6cf0..000000000 --- a/icefall/transformer_lm/dataset.py +++ /dev/null @@ -1 +0,0 @@ 
-../rnn_lm/dataset.py \ No newline at end of file diff --git a/icefall/transformer_lm/encoder.py b/icefall/transformer_lm/encoder.py deleted file mode 100644 index 4357b83d7..000000000 --- a/icefall/transformer_lm/encoder.py +++ /dev/null @@ -1,329 +0,0 @@ -# Copyright (c) 2021 Xiaomi Corporation (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -from typing import List, Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import Tensor, nn - -from icefall.transformer_lm.attention import RelPositionMultiheadAttention -from icefall.transformer_lm.scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledConv2d, - ScaledLinear, -) -from icefall.utils import is_jit_tracing, make_pad_mask - - -class Transformer(torch.nn.Module): - """_summary_ - - Args: - input_dim (int): Input feature dimension - d_mode (int): The dimension of the transformer - dim_feedforward (int ): The dimension of the ffw module - nhead (int): The number of attention heads - dropout_rate (float): dropout rate - att_dropout (float): dropout rate in attention module - """ - - def __init__( - self, - input_dim: int, - d_model: int, - dim_feedforward: int, - nhead: int = 4, - num_layers: int = 6, - dropout_rate: float = 0.1, - att_dropout: float = 0.0, - ): - super().__init__() - - self.encoder_layers = num_layers - self.d_model = d_model - - self.embed = ScaledLinear(input_dim, d_model) - self.norm_before = BasicNorm(d_model, learn_eps=False) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - dim_feedforward=dim_feedforward, - nhead=nhead, - dropout_rate=dropout_rate, - ) - - self.encoder = TransformerEncoder(encoder_layer, num_layers) - - def _create_attention_mask(self, x_lens: torch.Tensor): - # create a 2D attention mask to mask out - # the upper right half of the attention matrix - max_len = max(x_lens) - ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool) - return torch.triu(ones, diagonal=1) - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Transformer forward - - Args: - x (torch.Tensor): Input tensor (B,T,input_dim) - x_lens (torch.Tensor): The length of input tensors before padding (B,) - - Returns: - Return a tuple of 2 tensors: - - x: output feature of the transformer (B,T,d_model) - - x_lens: output feature lens of the transformer - """ - - attention_mask = self._create_attention_mask(x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - - x = self.norm_before(self.embed(x)) - - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) - - x = self.encoder( - x, - pos_emb, - mask=attention_mask, # pass the attention mast - src_key_padding_mask=src_key_padding_mask, - ) # (T, N, C) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, x_lens - - -class 
TransformerEncoder(torch.nn.Module): - def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None: - """TransformerEncoder is a stack of N encoder layers - - Args: - encoder_layer (torch.nn.Module): an instance of the TransformerEncoderLayer() - num_layers (int): Number of layers to be stacked - """ - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - def forward( - self, - src: torch.Tensor, - pos_emb: torch.Tensor, - src_key_padding_mask: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """_summary_ - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Returns: - output: transformer encoded features - """ - output = src - - for layer_index, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - src_key_padding_mask=src_key_padding_mask, - src_mask=mask, - ) - - return output - - -class TransformerEncoderLayer(torch.nn.Module): - def __init__( - self, - d_model: int, - dim_feedforward: int, - nhead: int, - dropout_rate: float, - ): - """TransformerEncoderLayer is made up of self-attn and feedforward module - - Args: - d_model (int): The model size - dim_feedforward (int): Dimension of ffw module - nhead (int): Number of heads - dropout_rate (float): Dropout rate - """ - super().__init__() - - self.d_model = d_model - - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout_rate), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.norm_final = BasicNorm(d_model) - - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout_rate) - - def forward( - self, - src: torch.Tensor, - pos_emb: torch.Tensor, - src_key_padding_mask: Optional[torch.Tensor] = None, - src_mask: Optional[torch.Tensor] = None, - cache=None, - ): - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_key_padding_mask: the mask for the src keys per batch (optional). - src_mask: the mask for the src sequence (optional). - """ - src_orig = src - - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - - src = src + self.dropout(src_att) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - return src - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. 
- - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - if is_jit_tracing(): - # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., - # It assumes that the maximum input won't have more than - # 10k frames. - # - # TODO(fangjun): Use torch.jit.script() for this module - max_len = 10000 - - self.d_model = d_model - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: - """Reset the positional encodings.""" - x_size_1 = x.size(1) + left_context - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_1 * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - - """ - self.extend_pe(x, left_context) - x_size_1 = x.size(1) + left_context - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_1 - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) diff --git a/icefall/transformer_lm/export.py b/icefall/transformer_lm/export.py deleted file mode 100644 index c08982e37..000000000 --- a/icefall/transformer_lm/export.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -import argparse -import logging -from pathlib import Path - -import torch -from model import TransformerLM - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict, load_averaged_model, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=11, - help="It specifies the checkpoint to use for decoding." 
- "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--embedding-dim", - type=int, - default=768, - help="Embedding dim of the model", - ) - - parser.add_argument( - "--encoder-dim", - type=int, - default=768, - help="Encoder dim of the model", - ) - - parser.add_argument( - "--dim_feedforward", - type=int, - default=2048, - help="Hidden dim of the model", - ) - - parser.add_argument( - "--nhead", - type=int, - default=8, - help="Number of attention heads", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=16, - help="Number of Transformer layers", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="rnn_lm/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=True, - help="""True to save a model after applying torch.jit.script. - """, - ) - - return parser - - -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = AttributeDict({}) - params.update(vars(args)) - - logging.info(params) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("About to create model") - model = TransformerLM( - vocab_size=params.vocab_size, - d_model=params.encoder_dim, - embedding_dim=params.embedding_dim, - dim_feedforward=params.dim_feedforward, - nhead=params.nhead, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - params=params, - ) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - model = load_averaged_model( - params.exp_dir, model, params.epoch, params.avg, device - ) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/icefall/transformer_lm/model.py b/icefall/transformer_lm/model.py deleted file mode 100644 index c78cf1821..000000000 --- a/icefall/transformer_lm/model.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2022 Xiaomi Corporation (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may 
not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F - -from icefall.transformer_lm.encoder import Transformer -from icefall.utils import AttributeDict, add_eos, add_sos, make_pad_mask - - -class TransformerLM(torch.nn.Module): - def __init__( - self, - vocab_size: int, - embedding_dim: int, - d_model: int, - dim_feedforward: int, - nhead: int = 8, - num_layers: int = 16, - tie_weights: bool = True, - dropout: float = 0.1, - emb_dropout_rate: float = 0.0, - params: AttributeDict = None, - ): - super().__init__() - - self.vocab_size = vocab_size - self.params = params - - self.input_embedding = torch.nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=embedding_dim, - ) - - self.encoder = Transformer( - input_dim=embedding_dim, - d_model=d_model, - dim_feedforward=dim_feedforward, - nhead=nhead, - num_layers=num_layers, - dropout_rate=dropout, - ) - - self.output_linear = torch.nn.Linear( - in_features=d_model, out_features=vocab_size - ) - if tie_weights: - logging.info("Tying weights") - assert d_model == embedding_dim, (d_model, embedding_dim) - self.output_linear.weight = self.input_embedding.weight - else: - logging.info("Not tying weights") - - def forward( - self, - x: torch.Tensor, - y: torch.Tensor, - x_lens: torch.Tensor, - return_logits: bool = False, - ): - """Forward transformer language model - - Args: - x (torch.Tensor): Input tokens (B,L) - y (torch.Tensor): Output tokens (with EOS appended) (B,L) - x_lens (torch.Tensor): Length of input tokens before padding (B,) - return_logits (bool, optional): Return logits instead of NLL - - """ - - x = self.input_embedding(x) - - x, x_lens = self.encoder(x, x_lens) - - logits = self.output_linear(x) - - if return_logits: - return logits - - nll_loss = F.cross_entropy( - logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" - ) - - mask = make_pad_mask(x_lens).reshape(-1) - nll_loss.masked_fill_(mask, 0) - - return nll_loss - - def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None): - bs = x.size(0) - - state = None - logits = self.forward(x, x, x_lens, return_logits=True) - index = torch.arange(bs) - - last_logits = logits[index, x_lens - 1, :] - - return last_logits.log_softmax(-1), state diff --git a/icefall/transformer_lm/scaling.py b/icefall/transformer_lm/scaling.py deleted file mode 120000 index 0876c0704..000000000 --- a/icefall/transformer_lm/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/icefall/transformer_lm/train.py b/icefall/transformer_lm/train.py deleted file mode 100644 index c36abfcdf..000000000 --- a/icefall/transformer_lm/train.py +++ /dev/null @@ -1,609 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -Usage: - ./transformer_lm/train.py \ - --start-epoch 0 \ - --world-size 2 \ - --num-epochs 1 \ - --use-fp16 0 \ - --num-layers 12 \ - --batch-size 400 - -""" - -import argparse -import logging -import math -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from dataset import get_dataloader -from lhotse.utils import fix_random_seed -from model import TransformerLM -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Whether to log various information to tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from this epoch. - If it is positive, it will load checkpoint from - exp_dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transformer_lm/exp", - help="""The experiment dir. 
- It specifies the directory where all training related - files, e.g., checkpoints, logs, etc., are saved - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=True, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--batch-size", - type=int, - default=400, - ) - - parser.add_argument( - "--lm-data", - type=str, - default="data/lm_training_bpe_500/sorted_lm_data.pt", - help="LM training data", - ) - - parser.add_argument( - "--lm-data-valid", - type=str, - default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", - help="LM validation data", - ) - - parser.add_argument( - "--vocab-size", - type=int, - default=500, - help="Vocabulary size of the model", - ) - - parser.add_argument( - "--num-layers", - type=int, - default=12, - help="Number of Transformer layers in the model", - ) - - parser.add_argument( - "--tie-weights", - type=str2bool, - default=True, - help="""True to share the weights between the input embedding layer and the - last output linear layer - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters.""" - - params = AttributeDict( - { - "max_sent_len": 200, - "sos_id": 1, - "eos_id": 1, - "blank_id": 0, - "lr": 1e-3, - "weight_decay": 1e-6, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 200, - "reset_interval": 2000, - "valid_interval": 1000, - "nhead": 8, - "embedding_dim": 768, - "encoder_dim": 768, - "dim_feedforward": 2048, - "dropout": 0.1, - "env_info": get_env_info(), - } - ) - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> Optional[dict]: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return the loaded checkpoint (a dict) if one is loaded; otherwise None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - logging.info(f"Loading checkpoint: {filename}") - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - model: nn.Module, - x: torch.Tensor, - y: torch.Tensor, - sentence_lengths: torch.Tensor, - is_training: bool, -) -> Tuple[torch.Tensor, MetricsTracker]: - """Compute the negative log-likelihood loss given a model and its input. - Args: - model: - The NN model, - x: - A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, - each row starts with SOS ID. - y: - A 2-D tensor. Each row is a shifted version of the corresponding row - in `x` but ends with an EOS ID (before padding). - sentence_lengths: - A 1-D tensor containing number of tokens of each sentence - before padding. - is_training: - True for training. False for validation. - """ - with torch.set_grad_enabled(is_training): - device = model.device - x = x.to(device) - y = y.to(device) - sentence_lengths = sentence_lengths.to(device) - - nll = model(x, y, sentence_lengths) - loss = nll.sum() - - num_tokens = sentence_lengths.sum().item() - - loss_info = MetricsTracker() - # Note: Due to how MetricsTracker() is designed, - # we use "frames" instead of "num_tokens" as a key here - loss_info["frames"] = num_tokens - loss_info["loss"] = loss.detach().item() - return loss, loss_info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - x, y, sentence_lengths = batch - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - model=model, - x=x, - y=y, - sentence_lengths=sentence_lengths, - is_training=False, - ) - - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all sentences is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. 
- """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - x, y, sentence_lengths = batch - batch_size = x.size(0) - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - model=model, - x=x, - y=y, - sentence_lengths=sentence_lengths, - is_training=True, - ) - - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - # Note: "frames" here means "num_tokens" - this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) - tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " - f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " - f"batch size: {batch_size}" - ) - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - tb_writer.add_scalar( - "train/current_ppl", this_batch_ppl, params.batch_idx_train - ) - - tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - - valid_info = compute_validation_loss( - params=params, - model=model, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - - valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) - logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}, " - f"ppl: {valid_ppl}" - ) - - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - tb_writer.add_scalar( - "train/valid_ppl", valid_ppl, params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - is_distributed = world_size > 1 - - fix_random_seed(params.seed) - if is_distributed: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - logging.info(f"Device: {device}") - - logging.info("About to create model") - model = TransformerLM( - vocab_size=params.vocab_size, - d_model=params.encoder_dim, - embedding_dim=params.embedding_dim, - dim_feedforward=params.dim_feedforward, - nhead=params.nhead, - num_layers=params.num_layers, - tie_weights=params.tie_weights, - params=params, - ) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if is_distributed: - model = DDP(model, device_ids=[rank]) - - model.device = device - - optimizer = optim.Adam( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - if checkpoints: - logging.info("Load optimizer state_dict from checkpoint") - optimizer.load_state_dict(checkpoints["optimizer"]) - - logging.info(f"Loading LM training data from {params.lm_data}") - train_dl = get_dataloader( - filename=params.lm_data, - is_distributed=is_distributed, - params=params, - ) - - logging.info(f"Loading LM validation data from {params.lm_data_valid}") - valid_dl = get_dataloader( - filename=params.lm_data_valid, - is_distributed=is_distributed, - params=params, - ) - - # Note: No learning rate scheduler is used here - for epoch in range(params.start_epoch, params.num_epochs): - if is_distributed: - train_dl.sampler.set_epoch(epoch) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if is_distributed: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main()
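For readers tracking this removal, here is a minimal scoring sketch (not part of the diff) showing how the deleted TransformerLM in icefall/transformer_lm/model.py could be exercised once trained. It assumes the pre-removal module is still importable; the checkpoint path and token IDs are hypothetical, and the sizes mirror the defaults used elsewhere in this diff.

import torch

from icefall.checkpoint import load_checkpoint
from icefall.transformer_lm.model import TransformerLM  # pre-removal module path

# Model sizes follow the defaults above (vocab 500, 768-dim embedding/encoder, 16 layers).
model = TransformerLM(
    vocab_size=500,
    d_model=768,
    embedding_dim=768,
    dim_feedforward=2048,
    nhead=8,
    num_layers=16,
    tie_weights=True,
)
load_checkpoint("transformer_lm/exp/epoch-0.pt", model)  # hypothetical checkpoint path
model.eval()

x = torch.tensor([[1, 4, 19, 256, 77]], dtype=torch.int64)  # hypothetical BPE token IDs
x_lens = torch.tensor([x.size(1)])
with torch.no_grad():
    log_probs, _ = model.score_token(x, x_lens)  # next-token log-probs, shape (batch, vocab_size)
print(log_probs.shape)  # torch.Size([1, 500])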
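Similarly, compute_loss in the deleted train.py expects SOS-prefixed inputs and EOS-terminated targets. The illustrative snippet below (again not part of the diff, with made-up token IDs) shows how one (x, y, sentence_lengths) triple lines up with that docstring; real batches come from dataset.get_dataloader, which is outside this diff.

import torch

# sos_id = eos_id = 1 and blank_id = 0 (used for padding) follow get_params() above.
sentence = [4, 19, 256, 77]                   # hypothetical BPE token IDs for one sentence
sos_id = eos_id = 1

x = torch.tensor([[sos_id] + sentence])       # each row of x starts with SOS
y = torch.tensor([sentence + [eos_id]])       # y is x shifted left by one, ending with EOS
sentence_lengths = torch.tensor([x.size(1)])  # token count per row before any padding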