diff --git a/.flake8 b/.flake8 index 229cf1d6c..dd9239b2d 100644 --- a/.flake8 +++ b/.flake8 @@ -13,4 +13,5 @@ per-file-ignores = exclude = .git, **/data/**, - icefall/shared/make_kn_lm.py + icefall/shared/make_kn_lm.py, + icefall/__init__.py diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 2a743705a..6b3d856df 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -45,7 +45,9 @@ jobs: - name: Install Python dependencies run: | - python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 + python3 -m pip install --upgrade pip black==21.6b0 flake8==3.9.2 click==8.0.4 + # See https://github.com/psf/black/issues/2964 + # The version of click should be selected from 8.0.0, 8.0.1, 8.0.2, 8.0.3, and 8.0.4 - name: Run flake8 shell: bash diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b59784dbf..446ba0fe7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,8 @@ repos: hooks: - id: black args: [--line-length=80] + additional_dependencies: ['click==8.0.1'] + exclude: icefall\/__init__\.py - repo: https://github.com/PyCQA/flake8 rev: 3.9.2 diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst index a8c3b6865..5d364dbc0 100644 --- a/docs/source/installation/index.rst +++ b/docs/source/installation/index.rst @@ -27,9 +27,21 @@ Installation ``icefall`` depends on `k2 `_ and `lhotse `_. -We recommend you to install ``k2`` first, as ``k2`` is bound to -a specific version of PyTorch after compilation. Install ``k2`` also -installs its dependency PyTorch, which can be reused by ``lhotse``. +We recommend you to use the following steps to install the dependencies. + +- (0) Install PyTorch and torchaudio +- (1) Install k2 +- (2) Install lhotse + +.. caution:: + + Installation order matters. + +(0) Install PyTorch and torchaudio +---------------------------------- + +Please refer ``_ to install PyTorch +and torchaudio. (1) Install k2 @@ -54,14 +66,15 @@ to install ``k2``. Please refer to ``_ to install ``lhotse``. -.. HINT:: - Install ``lhotse`` also installs its dependency `torchaudio `_. +.. hint:: -.. CAUTION:: + We strongly recommend you to use:: + + pip install git+https://github.com/lhotse-speech/lhotse + + to install the latest version of lhotse. - If you have installed ``torchaudio``, please consider uninstalling it before - installing ``lhotse``. Otherwise, it may update your already installed PyTorch. (3) Download icefall -------------------- diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. - """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_mmi/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. - """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/aishell/ASR/conformer_mmi/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 68f5c54d3..26324b0af 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -70,7 +70,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then # |-- lexicon.txt # `-- speaker.info - if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then + if [ ! 
-d $dl_dir/aishell/data_aishell/wav/train ]; then lhotse download aishell $dl_dir fi diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py index 31bab122c..9e6ed96b1 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/pretrained.py @@ -55,18 +55,17 @@ from typing import List import kaldifeat import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model -from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict def get_parser(): @@ -111,6 +110,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -137,70 +143,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - "sample_rate": 16000, - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -225,6 +167,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -249,7 +192,7 @@ def main(): model = get_transducer_model(params) checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() model.device = device @@ -279,12 +222,22 @@ def main(): features, batch_first=True, padding_value=math.log(1e-10) ) - hyps = [] - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens + encoder_out, 
encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + hyp_list = [] + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, ) - + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: for i in range(encoder_out.size(0)): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -301,17 +254,15 @@ def main(): encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.method}" ) - hyps.append([lexicon.token_table[i] for i in hyp]) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py index 698594e92..f7c5b24ba 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py @@ -55,18 +55,17 @@ from typing import List import kaldifeat import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence +from train import get_params, get_transducer_model -from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict def get_parser(): @@ -111,6 +110,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -137,70 +143,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - "sample_rate": 16000, - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - 
encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -225,6 +167,7 @@ def read_sound_files( return ans +@torch.no_grad() def main(): parser = get_parser() args = parser.parse_args() @@ -279,12 +222,22 @@ def main(): features, batch_first=True, padding_value=math.log(1e-10) ) - hyps = [] - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lens + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + hyp_list = [] + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, ) - + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: for i in range(encoder_out.size(0)): # fmt: off encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] @@ -301,17 +254,15 @@ def main(): encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.method}" ) - hyps.append([lexicon.token_table[i] for i in hyp]) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module): target = target.clone().reshape(-1) ignored = target == self.ignore_index - target[ignored] = 0 + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) true_dist = torch.nn.functional.one_hot( target, num_classes=num_classes @@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module): true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) + # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) if self.reduction == "sum": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 651854999..815e1c02a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -106,7 +106,7 @@ def fast_beam_search( def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int ) -> List[int]: - """ + """Greedy search for a single utterance. 
Args: model: An instance of `Transducer`. @@ -178,6 +178,68 @@ def greedy_search( return hyp +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_out: (batch_size, 1, decoder_out_dim) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + + ans = [h[context_size:] for h in hyps] + return ans + + @dataclass class Hypothesis: # The predicted tokens so far. @@ -304,13 +366,156 @@ class HypothesisList(object): return ", ".join(s) +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + # decoder_output is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + + Args: model: An instance of `Transducer`. 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ad76411c0..49b1308b0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -71,6 +71,7 @@ from beam_search import ( beam_search, fast_beam_search, greedy_search, + greedy_search_batch, modified_beam_search, ) from train import get_params, get_transducer_model @@ -97,27 +98,28 @@ def get_parser(): "--epoch", type=int, default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--avg", type=int, default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, + "'--epoch' and '--iter'", ) parser.add_argument( @@ -191,7 +193,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -261,6 +263,24 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -280,12 +300,6 @@ def decode_one_batch( encoder_out=encoder_out_i, beam=params.beam_size, ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) else: raise ValueError( f"Unsupported decoding method: {params.decoding_method}" @@ -440,13 +454,19 @@ def main(): ) params.res_dir = params.exp_dir / params.decoding_method - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -472,8 
+492,20 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index e6528b8d7..b0eb4d749 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -50,7 +50,12 @@ import kaldifeat import sentencepiece as spm import torch import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model @@ -122,7 +127,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -224,28 +229,43 @@ def main(): if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") + if params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) - hyps.append(sp.decode(hyp).split()) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 
e71f0d1c6..e743106ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -33,6 +33,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple @@ -392,12 +393,16 @@ def load_checkpoint_if_available( "batch_idx_train", "best_train_loss", "best_valid_loss", - "cur_batch_idx", ] for k in keys: params[k] = saved_params[k] - params["start_epoch"] = saved_params["cur_epoch"] + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] return saved_params @@ -492,7 +497,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -600,21 +609,6 @@ def train_one_epoch( global_step=params.batch_idx_train, ) - def maybe_log_param_relative_changes(): - if ( - params.log_diagnostics - and tb_writer is not None - and params.batch_idx_train % (params.log_interval * 5) == 0 - ): - deltas = optim_step_and_measure_param_change(model, optimizer) - tb_writer.add_scalars( - "train/relative_param_change_per_minibatch", - deltas, - global_step=params.batch_idx_train, - ) - else: - optimizer.step() - cur_batch_idx = params.get("cur_batch_idx", 0) for batch_idx, batch in enumerate(train_dl): @@ -642,7 +636,26 @@ def train_one_epoch( maybe_log_weights("train/param_norms") maybe_log_gradients("train/grad_norms") - maybe_log_param_relative_changes() + + old_parameters = None + if ( + params.log_diagnostics + and tb_writer is not None + and params.batch_idx_train % (params.log_interval * 5) == 0 + ): + old_parameters = { + n: p.detach().clone() for n, p in model.named_parameters() + } + + optimizer.step() + + if old_parameters is not None: + deltas = optim_step_and_measure_param_change(model, old_parameters) + tb_writer.add_scalars( + "train/relative_param_change_per_minibatch", + deltas, + global_step=params.batch_idx_train, + ) optimizer.zero_grad() @@ -783,6 +796,13 @@ def run(rank, world_size, args): def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold return 1.0 <= c.duration <= 20.0 num_in_total = len(train_cuts) @@ -797,7 +817,9 @@ def run(rank, world_size, args): logging.info(f"After removing short and long utterances: {num_left}") logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") - if checkpoints and "sampler" in checkpoints: + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch sampler_state_dict = checkpoints["sampler"] else: sampler_state_dict = None diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index a460c8eb8..8dd1459ca 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -23,6 +23,7 @@ from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional +import torch from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( BucketingSampler, @@ -34,11 +35,20 @@ from lhotse.dataset import ( SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. @@ -301,12 +311,18 @@ class LibriSpeechAsrDataModule: logging.info("Loading sampler state dict") train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index a6ce79520..cbd9259e0 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -393,7 +394,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. 
info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 9f06ed512..eef4d3430 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -35,6 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -397,7 +398,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index c5efb733d..7b4fac31d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from typing import Dict, List, Optional +import k2 import torch from model import Transducer @@ -24,7 +25,7 @@ from model import Transducer def greedy_search( model: Transducer, encoder_out: torch.Tensor, max_sym_per_frame: int ) -> List[int]: - """ + """Greedy search for a single utterance. Args: model: An instance of `Transducer`. @@ -80,7 +81,7 @@ def greedy_search( logits = model.joiner( current_encoder_out, decoder_out, encoder_out_len, decoder_out_len ) - # logits is (1, 1, 1, vocab_size) + # logits is (1, vocab_size) y = logits.argmax().item() if y != blank_id: @@ -101,6 +102,75 @@ def greedy_search( return hyp +def greedy_search_batch( + model: Transducer, encoder_out: torch.Tensor +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + Returns: + Return a list-of-list of token IDs containing the decoded results. + len(ans) equals to encoder_out.size(0). 
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + device = model.device + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + + hyps = [[blank_id] * context_size for _ in range(batch_size)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_out: (batch_size, 1, decoder_out_dim) + + encoder_out_len = torch.ones(batch_size, dtype=torch.int32) + decoder_out_len = torch.ones(batch_size, dtype=torch.int32) + + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + logits = model.joiner( + current_encoder_out, decoder_out, encoder_out_len, decoder_out_len + ) # (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) # (batch_size, context_size) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) # (batch_size, 1, decoder_out_dim) + + ans = [h[context_size:] for h in hyps] + return ans + + @dataclass class Hypothesis: # The predicted tokens so far. @@ -252,9 +322,11 @@ def run_decoder( device = model.device - decoder_input = torch.tensor([ys[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [ys[-context_size:]], + device=device, + dtype=torch.int64, + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_cache[key] = decoder_out @@ -314,13 +386,158 @@ def run_joiner( return log_prob +def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + def modified_beam_search( model: Transducer, encoder_out: torch.Tensor, beam: int = 4, +) -> List[List[int]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcodded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + beam: + Number of active paths during the beam search. + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + + batch_size = encoder_out.size(0) + T = encoder_out.size(1) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + B = [HypothesisList() for _ in range(batch_size)] + for i in range(batch_size): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + encoder_out_len = torch.tensor([1]) + decoder_out_len = torch.tensor([1]) + for t in range(T): + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + # current_encoder_out's shape is: (batch_size, 1, encoder_out_dim) + + hyps_shape = _get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + # decoder_output is of shape (num_hyps, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + encoder_out_len.expand(decoder_out.size(0)), + decoder_out_len.expand(decoder_out.size(0)), + ) + # logits is of shape (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + ans = [h.ys[context_size:] for h in best_hyps] + + return ans + + +def _deprecated_modified_beam_search( + model: Transducer, + encoder_out: torch.Tensor, + beam: int = 4, ) -> List[int]: """It limits the maximum number of symbols per frame to 1. + It decodes only one utterance at a time. We keep it only for reference. + The function :func:`modified_beam_search` should be preferred as it + supports batch decoding. + Args: model: An instance of `Transducer`. 
@@ -341,12 +558,6 @@ def modified_beam_search( device = model.device - decoder_input = torch.tensor( - [blank_id] * context_size, device=device - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - T = encoder_out.size(1) B = HypothesisList() diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index fc838f75b..488c82386 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -109,8 +109,11 @@ class Conformer(Transformer): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # Caution: We assume the subsampling factor is 4! + lengths = ((x_lens - 1) // 2 - 1) // 2 + assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index f23a3a300..ac66c9b49 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -55,14 +55,15 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info from icefall.utils import ( AttributeDict, setup_logger, @@ -135,7 +136,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. 
Used only when --decoding_method is greedy_search""", ) @@ -143,70 +144,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -251,32 +188,47 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder( x=feature, x_lens=feature_lens ) - hyps = [] - batch_size = encoder_out.size(0) + hyp_list: List[List[int]] = [] - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] if params.decoding_method == 
"greedy_search": return {"greedy_search": hyps} @@ -487,8 +439,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index ad8d89918..4fb5d92c5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -59,17 +59,15 @@ from typing import List import kaldifeat import sentencepiece as spm import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict +from train import get_params, get_transducer_model def get_parser(): @@ -115,6 +113,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -132,7 +137,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search. """, @@ -141,70 +146,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -294,33 +235,45 @@ def main(): ) num_waves = encoder_out.size(0) - hyps = [] + hyp_list = [] msg = f"Using {params.method}" if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == 
"greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 2cc6480d5..d6827c17c 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -34,6 +34,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -419,7 +420,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py index 669ad1d1b..c6cf739fb 100644 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/asr_datamodule.py @@ -22,6 +22,7 @@ import logging from pathlib import Path from typing import Optional +import torch from lhotse import CutSet, Fbank, FbankConfig from lhotse.dataset import ( BucketingSampler, @@ -34,11 +35,20 @@ from lhotse.dataset.input_strategies import ( OnTheFlyFeatures, PrecomputedFeatures, ) +from lhotse.utils import fix_random_seed from torch.utils.data import DataLoader from icefall.utils import str2bool +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + class AsrDataModule: def __init__(self, args: argparse.Namespace): self.args = args @@ -253,12 +263,19 @@ class AsrDataModule: ) logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. 
+ seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, num_workers=self.args.num_workers, persistent_workers=False, + worker_init_fn=worker_init_fn, ) return train_dl diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 136afe9c0..22f137d36 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -46,15 +46,16 @@ import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import AsrDataModule -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from librispeech import LibriSpeech -from model import Transducer +from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info from icefall.utils import ( AttributeDict, setup_logger, @@ -127,7 +128,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --decoding_method is greedy_search""", ) @@ -135,71 +136,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict): - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict): - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict): - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict): - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - - return model - - def decode_one_batch( params: AttributeDict, model: nn.Module, @@ -244,32 +180,47 @@ def decode_one_batch( encoder_out, encoder_out_lens = model.encoder( x=feature, x_lens=feature_lens ) - hyps = [] + hyp_list = [] batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif 
params.decoding_method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.decoding_method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) + if ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + else: + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] if params.decoding_method == "greedy_search": return {"greedy_search": hyps} @@ -483,8 +434,5 @@ def main(): logging.info("Done!") -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 5ba3acea1..df9c3186f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -59,17 +59,15 @@ from typing import List import kaldifeat import sentencepiece as spm import torch -import torch.nn as nn import torchaudio -from beam_search import beam_search, greedy_search, modified_beam_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer +from beam_search import ( + beam_search, + greedy_search, + greedy_search_batch, + modified_beam_search, +) from torch.nn.utils.rnn import pad_sequence - -from icefall.env import get_env_info -from icefall.utils import AttributeDict +from train import get_params, get_transducer_model def get_parser(): @@ -115,6 +113,13 @@ def get_parser(): "The sample rate has to be 16kHz.", ) + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + parser.add_argument( "--beam-size", type=int, @@ -132,7 +137,7 @@ def get_parser(): parser.add_argument( "--max-sym-per-frame", type=int, - default=3, + default=1, help="""Maximum number of symbols per frame. Used only when --method is greedy_search.
""", @@ -141,70 +146,6 @@ def get_parser(): return parser -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - def read_sound_files( filenames: List[str], expected_sample_rate: float ) -> List[torch.Tensor]: @@ -294,33 +235,46 @@ def main(): ) num_waves = encoder_out.size(0) - hyps = [] + hyp_list = [] msg = f"Using {params.method}" if params.method == "beam_search": msg += f" with beam size {params.beam_size}" logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - elif params.method == "modified_beam_search": - hyp = modified_beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + if params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + beam=params.beam_size, + ) + + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + hyp_list.append(hyp) + + hyps = [sp.decode(hyp).split() for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py 
b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 105f82417..5572d3f4c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -58,6 +58,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" import argparse import logging import random +import warnings from pathlib import Path from shutil import copyfile from typing import Optional, Tuple @@ -466,7 +467,11 @@ def compute_loss( assert loss.requires_grad == is_training info = MetricsTracker() - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() diff --git a/icefall/__init__.py b/icefall/__init__.py index e69de29bb..ec77e89b5 100644 --- a/icefall/__init__.py +++ b/icefall/__init__.py @@ -0,0 +1,65 @@ +# isort:skip_file + +from . import ( + checkpoint, + decode, + dist, + env, + utils +) + +from .checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, + remove_checkpoints, + save_checkpoint, + save_checkpoint_with_global_batch_idx, +) + +from .decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) + +from .dist import ( + cleanup_dist, + setup_dist, +) + +from .env import ( + get_env_info, + get_git_branch_name, + get_git_date, + get_git_sha1, +) + +from .utils import ( + AttributeDict, + MetricsTracker, + add_eos, + add_sos, + concat, + encode_supervisions, + get_alignments, + get_executor, + get_texts, + l1_norm, + l2_norm, + linf_norm, + load_alignments, + make_pad_mask, + measure_gradient_norms, + measure_weight_norms, + optim_step_and_measure_param_change, + save_alignments, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 251456c95..1ef05d964 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -216,27 +216,62 @@ def save_checkpoint_with_global_batch_idx( ) -def find_checkpoints(out_dir: Path) -> List[str]: +def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """Find all available checkpoints in a directory. The checkpoint filenames have the form: `checkpoint-xxx.pt` where xxx is a numerical value. + Assume you have the following checkpoints in the folder `foo`: + + - checkpoint-1.pt + - checkpoint-20.pt + - checkpoint-300.pt + - checkpoint-4000.pt + + Case 1 (Return all checkpoints):: + + find_checkpoints(out_dir='foo') + + Case 2 (Return checkpoints newer than checkpoint-20.pt, i.e., + checkpoint-4000.pt, checkpoint-300.pt, and checkpoint-20.pt) + + find_checkpoints(out_dir='foo', iteration=20) + + Case 3 (Return checkpoints older than checkpoint-20.pt, i.e., + checkpoint-20.pt, checkpoint-1.pt):: + + find_checkpoints(out_dir='foo', iteration=-20) + Args: out_dir: The directory where to search for checkpoints. + iteration: + If it is 0, return all available checkpoints. + If it is positive, return the checkpoints whose iteration number is + greater than or equal to `iteration`. + If it is negative, return the checkpoints whose iteration number is + less than or equal to `-iteration`. Returns: Return a list of checkpoint filenames, sorted in descending order by the numerical value in the filename. 
""" checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - idx_checkpoints = [ + iter_checkpoints = [ (int(pattern.search(c).group(1)), c) for c in checkpoints ] + # iter_checkpoints is a list of tuples. Each tuple contains + # two elements: (iteration_number, checkpoint-iteration_number.pt) + + iter_checkpoints = sorted( + iter_checkpoints, reverse=True, key=lambda x: x[0] + ) + if iteration >= 0: + ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] + else: + ans = [ic[1] for ic in iter_checkpoints if ic[0] <= -iteration] - idx_checkpoints = sorted(idx_checkpoints, reverse=True, key=lambda x: x[0]) - ans = [ic[1] for ic in idx_checkpoints] return ans diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index fa9b98fa0..ce4ac1464 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -135,8 +135,13 @@ def get_diagnostics_for_dim( return "" count = sum(counts) stats = stats / count - stats, _ = torch.symeig(stats) - stats = stats.abs().sqrt() + try: + eigs, _ = torch.symeig(stats) + stats = eigs.abs().sqrt() + except: # noqa + print("Error getting eigenvalues, trying another method.") + eigs = torch.linalg.eigvals(stats) + stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance elif sizes_same: stats = torch.stack(stats).sum(dim=0) diff --git a/icefall/env.py b/icefall/env.py index 0684c4bf1..c29cbb078 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -95,6 +95,7 @@ def get_env_info() -> Dict[str, Any]: "k2-git-sha1": k2.version.__git_sha1__, "k2-git-date": k2.version.__git_date__, "lhotse-version": lhotse.__version__, + "torch-version": torch.__version__, "torch-cuda-available": torch.cuda.is_available(), "torch-cuda-version": torch.version.cuda, "python-version": sys.version[:3], diff --git a/icefall/utils.py b/icefall/utils.py index c231dbbe4..daccd4346 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -25,15 +25,14 @@ from collections import defaultdict from contextlib import contextmanager from datetime import datetime from pathlib import Path -from typing import Dict, Iterable, List, TextIO, Optional, Tuple, Union +from typing import Dict, Iterable, List, TextIO, Tuple, Union import k2 import k2.version import kaldialign import torch -import torch.nn as nn import torch.distributed as dist -from torch.cuda.amp import GradScaler +import torch.nn as nn from torch.utils.tensorboard import SummaryWriter Pathlike = Union[str, Path] @@ -758,11 +757,10 @@ def measure_gradient_norms( def optim_step_and_measure_param_change( model: nn.Module, - optimizer: torch.optim.Optimizer, - scaler: Optional[GradScaler] = None, + old_parameters: Dict[str, nn.parameter.Parameter], ) -> Dict[str, float]: """ - Perform model weight update and measure the "relative change in parameters per minibatch." + Measure the "relative change in parameters per minibatch." It is understood as a ratio between the L2 norm of the difference between original and updates parameters, and the L2 norm of the original parameter. It is given by the formula: @@ -770,16 +768,31 @@ def optim_step_and_measure_param_change( \begin{aligned} \delta = \frac{\Vert\theta - \theta_{new}\Vert^2}{\Vert\theta\Vert^2} \end{aligned} - """ - param_copy = {n: p.detach().clone() for n, p in model.named_parameters()} - if scaler: - scaler.step(optimizer) - else: + + This function is supposed to be used as follows: + + .. 
code-block:: python + + old_parameters = { + n: p.detach().clone() for n, p in model.named_parameters() + } + optimizer.step() + + deltas = optim_step_and_measure_param_change(model, old_parameters) + + Args: + model: A torch.nn.Module instance. + old_parameters: + A Dict of named_parameters before optimizer.step(). + + Return: + A Dict containing the relative change for each parameter. + """ relative_change = {} with torch.no_grad(): for n, p_new in model.named_parameters(): - p_orig = param_copy[n] + p_orig = old_parameters[n] delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) relative_change[n] = delta.item() return relative_change diff --git a/pyproject.toml b/pyproject.toml index 01ff869db..b4f8c3377 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [tool.isort] profile = "black" +skip = ["icefall/__init__.py"] [tool.black] line-length = 80 @@ -9,4 +10,5 @@ exclude = ''' | \.github )/ | make_kn_lm.py + | icefall\/__init__\.py ''' diff --git a/requirements-ci.txt b/requirements-ci.txt index b5ee6b51c..7fb4b1665 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -11,7 +11,7 @@ graphviz==0.19.1 -f https://download.pytorch.org/whl/cpu/torch_stable.html torch==1.10.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html torchaudio==0.10.0+cpu --f https://k2-fsa.org/nightly/ k2==1.9.dev20211101+cpu.torch1.10.0 +-f https://k2-fsa.org/nightly/ k2==1.14.dev20220316+cpu.torch1.10.0 git+https://github.com/lhotse-speech/lhotse kaldilm==1.11
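
Note on the worker seeding added in asr_datamodule.py: the ``_SeedWorkers`` callable gives dataloader worker ``worker_id`` the seed ``seed + worker_id``, so augmentation randomness differs across workers yet stays reproducible for a given base seed, and, unlike a lambda, the callable can be pickled when workers are spawned. A minimal sketch of the same idea with a toy dataset (the dataset, batch size, and worker count are made up for illustration; ``fix_random_seed`` is the lhotse utility imported in the patch)::

    import torch
    from lhotse.utils import fix_random_seed
    from torch.utils.data import DataLoader, Dataset


    class _SeedWorkers:
        # Picklable callable: worker `worker_id` gets the seed `seed + worker_id`.
        def __init__(self, seed: int):
            self.seed = seed

        def __call__(self, worker_id: int):
            fix_random_seed(self.seed + worker_id)


    class ToyDataset(Dataset):
        def __len__(self) -> int:
            return 4

        def __getitem__(self, idx: int) -> torch.Tensor:
            # Any randomness here (e.g. on-the-fly augmentation) now differs
            # per worker but is reproducible given the same base seed.
            return torch.rand(1)


    # The base seed is drawn from the main process's random state, as in the patch.
    seed = torch.randint(0, 100000, ()).item()
    dl = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        worker_init_fn=_SeedWorkers(seed),
    )
    for batch in dl:
        print(batch)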
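
Note on the ``warnings.catch_warnings()`` blocks added to both ``train.py`` files: they wrap a floor division on a tensor, presumably because recent PyTorch versions emit a ``UserWarning`` that ``__floordiv__`` is deprecated (it truncates toward zero rather than flooring, which is equivalent here since the lengths are non-negative). A small self-contained sketch of the same computation, with made-up lengths and subsampling factor::

    import warnings

    import torch

    # Made-up per-utterance feature lengths and encoder subsampling factor.
    feature_lens = torch.tensor([1000, 960, 870])
    subsampling_factor = 4

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # On newer PyTorch, `tensor // int` warns about __floordiv__ deprecation;
        # the result is still the number of frames after subsampling.
        num_frames = (feature_lens // subsampling_factor).sum().item()

    print(num_frames)  # 250 + 240 + 217 = 707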
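
Note on the new ``iteration`` argument of ``find_checkpoints()``: the three cases described in its docstring reduce to one signed comparison after sorting by iteration number. A standalone sketch of just that selection rule, operating on plain filenames (the helper name and the example list are made up for illustration)::

    import re
    from typing import List


    def _select_checkpoints(filenames: List[str], iteration: int = 0) -> List[str]:
        # Mirror the selection rule of find_checkpoints() on a list of names.
        pattern = re.compile(r"checkpoint-([0-9]+)\.pt")
        pairs = sorted(
            ((int(pattern.search(name).group(1)), name) for name in filenames),
            key=lambda x: x[0],
            reverse=True,
        )
        if iteration >= 0:
            return [name for it, name in pairs if it >= iteration]
        return [name for it, name in pairs if it <= -iteration]


    names = [
        "checkpoint-1.pt",
        "checkpoint-20.pt",
        "checkpoint-300.pt",
        "checkpoint-4000.pt",
    ]
    print(_select_checkpoints(names))       # all four, newest first
    print(_select_checkpoints(names, 20))   # checkpoint-4000.pt, checkpoint-300.pt, checkpoint-20.pt
    print(_select_checkpoints(names, -20))  # checkpoint-20.pt, checkpoint-1.pt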
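
Note on the changed ``optim_step_and_measure_param_change()``: it no longer performs the optimizer step itself; the caller snapshots the parameters, steps, and then asks for the relative change per parameter. A hedged end-to-end sketch with a throwaway linear model (the model, data, and learning rate are made up; only the call pattern follows the patch)::

    import torch
    import torch.nn as nn

    from icefall.utils import optim_step_and_measure_param_change

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = model(torch.rand(8, 4)).pow(2).sum()
    loss.backward()

    # 1) Snapshot the parameters before the update.
    old_parameters = {n: p.detach().clone() for n, p in model.named_parameters()}
    # 2) Update the model.
    optimizer.step()
    # 3) Measure ||theta_old - theta_new|| / ||theta_old|| for each parameter.
    deltas = optim_step_and_measure_param_change(model, old_parameters)
    for name, delta in deltas.items():
        print(name, delta)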