From 4aead988126ddd149bb5bcc4e6eb0347d4192b63 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 19 Feb 2022 23:09:12 +0800 Subject: [PATCH] Minor fixes. --- egs/aishell/ASR/RESULTS.md | 6 +- egs/aishell/ASR/transducer_stateless/model.py | 11 - .../transducer_stateless_modified/README.md | 4 +- .../transducer_stateless_modified/decode.py | 65 ++-- .../transducer_stateless_modified/export.py | 249 ------------- .../transducer_stateless_modified/model.py | 8 +- .../pretrained.py | 327 ------------------ .../transducer_stateless_modified/train.py | 51 ++- .../ASR/transducer_stateless/decode.py | 9 + 9 files changed, 106 insertions(+), 624 deletions(-) delete mode 100755 egs/aishell/ASR/transducer_stateless_modified/export.py delete mode 100755 egs/aishell/ASR/transducer_stateless_modified/pretrained.py diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index ceb63b4cf..23954e47a 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -23,7 +23,7 @@ python3 ./transducer_stateless/train.py \ lang_dir=data/lang_char dir=exp/transducer_stateless_context_size2 -python3 ./transducer_stateless/decode.py\ +python3 ./transducer_stateless/decode.py \ --epoch 59 \ --avg 10 \ --exp-dir $dir \ @@ -35,8 +35,8 @@ python3 ./transducer_stateless/decode.py\ lang_dir=data/lang_char dir=exp/transducer_stateless_context_size2 python3 ./transducer_stateless/decode.py \ - --epoch 59\ - --avg 10\ + --epoch 59 \ + --avg 10 \ --exp-dir $dir \ --lang-dir $lang_dir \ --decoding-method beam_search \ diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py index 0322edeed..23eabdf49 100644 --- a/egs/aishell/ASR/transducer_stateless/model.py +++ b/egs/aishell/ASR/transducer_stateless/model.py @@ -14,15 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. 
It also means you have to use torch >= v1.10.0 -""" import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos @@ -115,11 +109,6 @@ class Transducer(nn.Module): boundary[:, 2] = y_lens boundary[:, 3] = x_lens - assert hasattr(torchaudio.functional, "rnnt_loss"), ( - f"Current torchaudio version: {torchaudio.__version__}\n" - "Please install a version >= 0.10.0" - ) - loss = k2.rnnt_loss(logits, y_padded, blank_id, boundary) return torch.sum(loss) diff --git a/egs/aishell/ASR/transducer_stateless_modified/README.md b/egs/aishell/ASR/transducer_stateless_modified/README.md index 622cb837c..9709eb9a0 100644 --- a/egs/aishell/ASR/transducer_stateless_modified/README.md +++ b/egs/aishell/ASR/transducer_stateless_modified/README.md @@ -11,11 +11,11 @@ cd egs/aishell/ASR export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" -./transducer_stateless/train.py \ +./transducer_stateless_modified/train.py \ --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ - --exp-dir transducer_stateless/exp \ + --exp-dir transducer_stateless_modified/exp \ --max-duration 250 \ --lr-factor 2.5 ``` diff --git a/egs/aishell/ASR/transducer_stateless_modified/decode.py b/egs/aishell/ASR/transducer_stateless_modified/decode.py index a7b030fa5..822f0a00f 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/decode.py +++ b/egs/aishell/ASR/transducer_stateless_modified/decode.py @@ -15,6 +15,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +""" +Usage: +(1) greedy search +./transducer_stateless_modified/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless_modified/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./transducer_stateless_modified/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless_modified/exp \ + --max-duration 100 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./transducer_stateless_modified/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless_modified/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" import argparse import logging @@ -25,7 +53,7 @@ from typing import Dict, List, Tuple import torch import torch.nn as nn from asr_datamodule import AishellAsrDataModule -from beam_search import beam_search, greedy_search +from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -39,7 +67,6 @@ from icefall.utils import ( setup_logger, store_transcripts, write_error_stats, - str2bool, ) @@ -67,7 +94,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/exp", + default="transducer_stateless_modified/exp", help="The experiment dir", ) @@ -85,6 +112,7 @@ def get_parser(): help="""Possible values are: - greedy_search - beam_search + - modified_beam_search """, ) @@ -108,17 +136,6 @@ def get_parser(): default=3, help="Maximum number of symbols per frame", ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - transducer_stateless/exp/pretrained.pt. Note: only model.state_dict() - is saved. 
pretrained.pt contains a dict {"model": model.state_dict()},
-        which can be loaded by `icefall.checkpoint.load_checkpoint()`.
-        """,
-    )
-
     return parser


@@ -247,6 +264,10 @@ def decode_one_batch(
             hyp = beam_search(
                 model=model, encoder_out=encoder_out_i, beam=params.beam_size
             )
+        elif params.decoding_method == "modified_beam_search":
+            hyp = modified_beam_search(
+                model=model, encoder_out=encoder_out_i, beam=params.beam_size
+            )
         else:
             raise ValueError(
                 f"Unsupported decoding method: {params.decoding_method}"
@@ -382,11 +403,15 @@ def main():
     params = get_params()
     params.update(vars(args))

-    assert params.decoding_method in ("greedy_search", "beam_search")
+    assert params.decoding_method in (
+        "greedy_search",
+        "beam_search",
+        "modified_beam_search",
+    )
     params.res_dir = params.exp_dir / params.decoding_method

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    if params.decoding_method == "beam_search":
+    if "beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
@@ -403,7 +428,6 @@ def main():

     lexicon = Lexicon(params.lang_dir)

-    # params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1

@@ -424,13 +448,6 @@ def main():
     model.to(device)
     model.load_state_dict(average_checkpoints(filenames, device=device))

-    if params.export:
-        logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
-        torch.save(
-            {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
-        )
-        return
-
     model.to(device)
     model.eval()
     model.device = device
diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py
deleted file mode 100755
index 5687260df..000000000
--- a/egs/aishell/ASR/transducer_stateless_modified/export.py
+++ /dev/null
@@ -1,249 +0,0 @@
-#!/usr/bin/env python3
-#
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This script converts several saved checkpoints
-# to a single one using model averaging.
-""" -Usage: -./transducer_stateless/export.py \ - --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file exp_dir/pretrained.pt - -To use the generated file with `transducer_stateless/decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./transducer_stateless/decode.py \ - --exp-dir ./transducer_stateless/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 1 \ - --bpe-model data/lang_bpe_500/bpe.model -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=10, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="transducer_stateless/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 
1 means bigram; "
-        "2 means tri-gram",
-    )
-
-    return parser
-
-
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
-def main():
-    args = get_parser().parse_args()
-    args.exp_dir = Path(args.exp_dir)
-
-    assert args.jit is False, "Support torchscript will be added later"
-
-    params = get_params()
-    params.update(vars(args))
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
-
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_transducer_model(params)
-
-    model.to(device)
-
-    if params.avg == 1:
-        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    else:
-        start = params.epoch - params.avg + 1
-        filenames = []
-        for i in range(start, params.epoch + 1):
-            if start >= 0:
-                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
-        logging.info(f"averaging {filenames}")
-        model.to(device)
-        model.load_state_dict(average_checkpoints(filenames, device=device))
-
-    model.eval()
-
-    model.to("cpu")
-    model.eval()
-
-    if params.jit:
-        logging.info("Using torch.jit.script")
-        model = torch.jit.script(model)
-        filename = params.exp_dir / "cpu_jit.pt"
-        model.save(str(filename))
-        logging.info(f"Saved to {filename}")
-    else:
-        logging.info("Not using torch.jit.script")
-        # Save it using a format so that it can be loaded
-        # by :func:`load_checkpoint`
-        filename = params.exp_dir / "pretrained.pt"
-        torch.save({"model": model.state_dict()}, str(filename))
-        logging.info(f"Saved to {filename}")
-
-
-if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/aishell/ASR/transducer_stateless_modified/model.py b/egs/aishell/ASR/transducer_stateless_modified/model.py
index 4edbeb119..8281e1fb5 100644
--- a/egs/aishell/ASR/transducer_stateless_modified/model.py
+++ 
b/egs/aishell/ASR/transducer_stateless_modified/model.py @@ -14,15 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Note we use `rnnt_loss` from torchaudio, which exists only in -torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0 -""" +import random + import k2 import torch import torch.nn as nn -import torchaudio -import torchaudio.functional from encoder_interface import EncoderInterface from icefall.utils import add_sos diff --git a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py b/egs/aishell/ASR/transducer_stateless_modified/pretrained.py deleted file mode 100755 index db89c4d67..000000000 --- a/egs/aishell/ASR/transducer_stateless_modified/pretrained.py +++ /dev/null @@ -1,327 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -(1) greedy search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav \ - -(1) beam search -./transducer_stateless/pretrained.py \ - --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav \ - -You can also use `./transducer_stateless/exp/epoch-xx.pt`. - -Note: ./transducer_stateless/exp/pretrained.pt is generated by -./transducer_stateless/export.py -""" - - -import argparse -import logging -import math -from pathlib import Path -from typing import List - -import kaldifeat -import torch -import torch.nn as nn -import torchaudio -from beam_search import beam_search, greedy_search -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from model import Transducer -from torch.nn.utils.rnn import pad_sequence - -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). 
" - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="Used only when --method is beam_search", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=3, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "feature_dim": 80, - "encoder_out_dim": 512, - "subsampling_factor": 4, - "attention_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - "vgg_frontend": False, - "env_info": get_env_info(), - } - ) - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Conformer( - num_features=params.feature_dim, - output_dim=params.encoder_out_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.attention_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - vgg_frontend=params.vgg_frontend, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - embedding_dim=params.encoder_out_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - input_dim=params.encoder_out_dim, - output_dim=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - ) - return model - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. 
" - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = graph_compiler.texts_to_ids("")[0][0] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info("Creating model") - model = get_transducer_model(params) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - with torch.no_grad(): - encoder_out, encoder_out_lens = model.encoder( - x=features, x_lens=feature_lengths - ) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, encoder_out=encoder_out_i, beam=params.beam_size - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append([lexicon.token_table[i] for i in hyp]) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/transducer_stateless_modified/train.py b/egs/aishell/ASR/transducer_stateless_modified/train.py index 9f90d0002..a0c2e0a47 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified/train.py @@ -623,8 +623,8 @@ def run(rank, world_size, args): train_cuts = aishell.train_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - return 1.0 <= c.duration <= 20.0 + # Keep only utterances with duration between 1 second and 12 seconds + return 1.0 <= c.duration <= 12.0 num_in_total = len(train_cuts) @@ -641,6 
+641,14 @@ def run(rank, world_size, args): train_dl = aishell.train_dataloaders(train_cuts) valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) @@ -681,6 +689,45 @@ def run(rank, world_size, args): cleanup_dist() +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index c101d9397..f23a3a300 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -33,6 +33,15 @@ Usage: --max-duration 100 \ --decoding-method beam_search \ --beam-size 4 + +(3) modified beam search +./transducer_stateless/decode.py \ + --epoch 14 \ + --avg 7 \ + --exp-dir ./transducer_stateless/exp \ + --max-duration 100 \ + --decoding-method modified_beam_search \ + --beam-size 4 """
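
Note on the loss change in model.py above: the recipe now calls k2.rnnt_loss directly, so the old torchaudio >= 0.10.0 (and hence torch >= 1.10.0) requirement is gone; a k2 build with RNN-T loss support is needed instead. A minimal, self-contained sketch of the shapes the new call expects -- only the k2.rnnt_loss invocation and the boundary layout are taken from this patch; the sizes and random tensors are made-up illustration:

import k2
import torch

N, T, U, C = 2, 10, 5, 30               # batch, frames, max label length, vocab
logits = torch.randn(N, T, U + 1, C)    # joiner output
y_padded = torch.randint(1, C, (N, U))  # padded label sequences; 0 is the blank
x_lens = torch.tensor([10, 8])          # valid frames per utterance
y_lens = torch.tensor([5, 3])           # valid labels per utterance

# boundary[i] = [begin_symbol, begin_frame, end_symbol, end_frame],
# mirroring the construction kept in model.py
boundary = torch.zeros((N, 4), dtype=torch.int64)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens

loss = k2.rnnt_loss(logits, y_padded, 0, boundary)  # blank_id == 0
print(torch.sum(loss))  # the model sums the per-utterance losses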
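
Note on modified_beam_search in the two decode.py files: its implementation lives in beam_search.py and is not part of this patch. As a rough mental model only -- this is a schematic sketch, NOT the icefall implementation -- the "modified" search advances every hypothesis one frame per step and allows at most one non-blank symbol per frame, so all beams stay time-synchronous:

import torch

def modified_beam_search_sketch(model, encoder_out, beam=4, blank_id=0):
    # encoder_out: (1, T, C) for a single utterance, as in decode.py
    context = model.decoder.context_size
    hyps = [([blank_id] * context, 0.0)]  # (token history, total log-prob)
    for t in range(encoder_out.size(1)):
        candidates = []
        for tokens, logp in hyps:
            y = torch.tensor([tokens[-context:]], dtype=torch.int64)
            decoder_out = model.decoder(y, need_pad=False)      # (1, 1, C)
            logits = model.joiner(encoder_out[:, t : t + 1], decoder_out)
            log_probs = logits.log_softmax(dim=-1).reshape(-1)  # (vocab,)
            for k, lp in enumerate(log_probs.tolist()):
                # blank keeps the history; a non-blank symbol extends it,
                # but never by more than one symbol per frame
                new_tokens = tokens if k == blank_id else tokens + [k]
                candidates.append((new_tokens, logp + lp))
        hyps = sorted(candidates, key=lambda h: h[1], reverse=True)[:beam]
    best = max(hyps, key=lambda h: h[1])[0]
    return best[context:]  # strip the blank priming context

The real implementation additionally merges duplicate hypotheses; keeping the beams frame-synchronous is also what makes the search batchable, unlike the label-synchronous plain beam_search.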
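
Since export.py and the --export flag are both removed from transducer_stateless_modified, this recipe has no in-tree way left to dump an averaged pretrained.pt. If you need one, a stand-alone sketch (the epoch/avg numbers and paths are placeholders, and `model` is assumed to be the Transducer built exactly as in decode.py):

import torch

from icefall.checkpoint import average_checkpoints

epoch, avg = 14, 7
exp_dir = "transducer_stateless_modified/exp"
filenames = [
    f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)
]
# `model` must already be constructed with the same hyper-parameters
model.load_state_dict(average_checkpoints(filenames, device=torch.device("cpu")))
torch.save({"model": model.state_dict()}, f"{exp_dir}/pretrained.pt")

The {"model": ...} dict layout matches what the deleted export code wrote and what icefall.checkpoint.load_checkpoint() expects.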