From 8c3ea93fc8a9be5dcac2bd0ad0d8e34cc13f6dd3 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Mar 2023 11:39:29 +0800 Subject: [PATCH 01/17] Save meta data to exported ONNX models (#968) --- .../ASR/pruned_transducer_stateless3/export-onnx.py | 10 ++++++++++ .../ASR/pruned_transducer_stateless5/export-onnx.py | 10 ++++++++++ .../ASR/pruned_transducer_stateless7/export-onnx.py | 10 ++++++++++ 3 files changed, 30 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index ca8be307c..36e57e946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -273,6 +273,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless3", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py index 743fe8a92..3d94760dc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export-onnx.py @@ -296,6 +296,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "conformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless5", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index f76915a74..93eb8df3d 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -291,6 +291,16 @@ def export_encoder_model_onnx( }, ) + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "comment": "stateless7", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + def export_decoder_model_onnx( decoder_model: OnnxDecoder, From 35e21a0d2ede47c632f97ae6d560194b66439472 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 27 Mar 2023 14:08:26 +0800 Subject: [PATCH 02/17] Fix torchscript export for aishell (#969) --- egs/aishell/ASR/pruned_transducer_stateless3/export.py | 2 ++ egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py | 1 + .../ASR/pruned_transducer_stateless3/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py index 7f10eb36e..723414167 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/export.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -48,6 +48,7 @@ import logging from pathlib import Path import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -244,6 +245,7 @@ def 
main(): model.eval() if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. # Otherwise, one of its arguments is a ragged tensor and is not diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py new file mode 120000 index 000000000..557e18aa1 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/lstmp.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From 15d48e3a6af6cd1626ec842686ffe76f928e158f Mon Sep 17 00:00:00 2001 From: PF Luo Date: Tue, 28 Mar 2023 19:14:08 +0800 Subject: [PATCH 03/17] fix rnn_lm && transformer_lm import problem (#971) --- icefall/rnn_lm/__init__.py | 0 icefall/transformer_lm/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 icefall/rnn_lm/__init__.py create mode 100644 icefall/transformer_lm/__init__.py diff --git a/icefall/rnn_lm/__init__.py b/icefall/rnn_lm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/icefall/transformer_lm/__init__.py b/icefall/transformer_lm/__init__.py new file mode 100644 index 000000000..e69de29bb From bcc5923ab92c93bf38829f7d5de84d84c9050eb1 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 28 Mar 2023 23:24:24 +0800 Subject: [PATCH 04/17] Support batch-wise forced-alignment (#970) * support batch-wise forced-alignment based on beam search * add length_norm to HypothesisList.topk() * Use Hypothesis and HypothesisList instead --- .../beam_search.py | 17 +- .../pruned_transducer_stateless7/alignment.py | 206 +++++++++++ .../compute_ali.py | 345 ++++++++++++++++++ .../test_compute_ali.py | 130 +++++++ icefall/utils.py | 2 +- 5 files changed, 696 insertions(+), 4 deletions(-) create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index bd2d6e258..999d793a4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -829,11 +829,22 @@ class HypothesisList(object): ans.add(hyp) # shallow copy return ans - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": + """Return the top-k hypothesis. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. 
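+
+            For example, with length_norm=True a 10-token hypothesis with
+            log_prob -5.0 (i.e. -0.5 per token) is ranked above a 4-token
+            hypothesis with log_prob -2.4 (i.e. -0.6 per token), whereas
+            without normalization the shorter hypothesis would rank higher.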
+ """ hyps = list(self._data.items()) - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py new file mode 100644 index 000000000..76cd56bbb --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py @@ -0,0 +1,206 @@ +# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + +import k2 +import torch + +from beam_search import Hypothesis, HypothesisList, get_hyps_shape + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the highest log probabilities. +# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. + + +def batch_force_alignment( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_list: List[List[int]], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of a batch of utterances given their transcripts + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + This function is modified from `modified_beam_search` in beam_search.py. + We assume that the maximum number of sybmols per frame is 1. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + ys_list: + A list of BPE token IDs list. We require that for each utterance i, + len(ys_list[i]) <= encoder_out_lens[i]. + beam_size: + Size of the beam used in beam search. + + Returns: + Return a list of frame indexes list for each utterance i, + where len(ans[i]) == len(ys_list[i]). 
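+      Each entry ans[i][j] is the index of the encoder output frame (i.e.,
+      after subsampling) at which token ys_list[i][j] is emitted.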
+ """ + assert encoder_out.ndim == 3, encoder_out.ndim + assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list)) + assert encoder_out.size(0) > 0, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + sorted_indices = packed_encoder_out.sorted_indices.tolist() + encoder_out_lens = encoder_out_lens.tolist() + ys_lens = [len(ys) for ys in ys_list] + sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices] + sorted_ys_lens = [ys_lens[i] for i in sorted_indices] + sorted_ys_list = [ys_list[i] for i in sorted_indices] + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size] + sorted_ys_lens = sorted_ys_lens[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs.reshape(-1) + ) # [batch][num_hyps*vocab_size] + + for i in range(batch_size): + for h, hyp in enumerate(A[i]): + pos_u = len(hyp.timestamp) + idx_offset = h * vocab_size + if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u): + # emit blank token + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + blank_id], + ys=hyp.ys[:], + timestamp=hyp.timestamp[:], + ) + B[i].add(new_hyp) + if pos_u < sorted_ys_lens[i]: + # emit non-blank token + new_token = sorted_ys_list[i][pos_u] + new_hyp = Hypothesis( + log_prob=ragged_log_probs[i][idx_offset + new_token], + ys=hyp.ys + [new_token], + timestamp=hyp.timestamp + [t], + ) + B[i].add(new_hyp) + + if len(B[i]) > beam_size: + B[i] = B[i].topk(beam_size, length_norm=True) + + B = B + finalized_B + sorted_hyps = [b.get_most_probable() for b in B] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + hyps = [sorted_hyps[i] for i in unsorted_indices] + ans = [] + for i, hyp in enumerate(hyps): + assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i]) + ans.append(hyp.timestamp) + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py new file mode 100755 index 000000000..8bcb56d62 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The script gets forced-alignments based on the modified_beam_search decoding method. +Both token-level alignments and word-level alignments are saved to the new cuts manifests. + +It loads a checkpoint and uses it to get the forced-alignments. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 + +Usage of this script: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import sentencepiece as spm +import torch +import torch.nn as nn +from alignment import batch_force_alignment +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp +from lhotse import CutSet +from lhotse.serialization import SequentialJsonlWriter +from lhotse.supervision import AlignmentItem + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--cuts-out-dir", + type=str, + default="data/fbank_ali_beam_search", + help="The dir to save the new cuts manifests with alignments", + ) + + add_model_arguments(parser) + + return parser + + +def align_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, +) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]: + """Get forced-alignments for one batch. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + + Returns: + token_list: + A list of token list. + word_list: + A list of word list. + token_time_list: + A list of timestamps list for tokens. + word_time_list. + A list of timestamps list for words. 
+ + where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list), + len(token_list[i]) == len(token_time_list[i]), + and len(word_list[i]) == len(word_time_list[i]) + + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + texts = supervisions["text"] + ys_list: List[List[int]] = sp.encode(texts, out_type=int) + + frame_indexes = batch_force_alignment( + model, encoder_out, encoder_out_lens, ys_list, params.beam_size + ) + + token_list = [] + word_list = [] + token_time_list = [] + word_time_list = [] + for i in range(encoder_out.size(0)): + tokens = sp.id_to_piece(ys_list[i]) + words = texts[i].split() + token_time = convert_timestamp( + frame_indexes[i], params.subsampling_factor, params.frame_shift_ms + ) + word_time = parse_timestamp(tokens, token_time) + assert len(word_time) == len(words), (len(word_time), len(words)) + + token_list.append(tokens) + word_list.append(words) + token_time_list.append(token_time) + word_time_list.append(word_time) + + return token_list, word_list, token_time_list, word_time_list + + +def align_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + writer: SequentialJsonlWriter, +) -> None: + """Get forced-alignments for the dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + writer: + Writer to save the cuts with alignments. + """ + log_interval = 20 + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
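+    # For each batch, we attach the token-level and word-level AlignmentItem
+    # lists to each cut's supervision and write the cut to the output
+    # manifest right away.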
+    for batch_idx, batch in enumerate(dl):
+        token_list, word_list, token_time_list, word_time_list = align_one_batch(
+            params=params, model=model, sp=sp, batch=batch
+        )
+
+        cut_list = batch["supervisions"]["cut"]
+        for cut, token, word, token_time, word_time in zip(
+            cut_list, token_list, word_list, token_time_list, word_time_list
+        ):
+            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
+            token_ali = [
+                AlignmentItem(
+                    symbol=token[i],
+                    start=round(token_time[i], ndigits=3),
+                    duration=None,
+                )
+                for i in range(len(token))
+            ]
+            word_ali = [
+                AlignmentItem(
+                    symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
+                )
+                for i in range(len(word))
+            ]
+            cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
+            writer.write(cut, flush=True)
+
+        num_cuts += len(cut_list)
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    # we need cut ids to display recognition results.
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + cuts_out_dir = Path(params.cuts_out_dir) + cuts_out_dir.mkdir(parents=True, exist_ok=True) + cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + + with CutSet.open_writer(cuts_out_path) as writer: + align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer) + + logging.info( + f"For dataset {params.dataset}, the cut manifest with framewise token alignments " + f"and word alignments are saved to {cuts_out_path}" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py new file mode 100755 index 000000000..081f7ba1a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script compares the word-level alignments generated based on modified_beam_search decoding +(in ./pruned_transducer_stateless7/compute_ali.py) to the reference alignments generated +by torchaudio framework (in ./add_alignments.sh). 
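+It reports the mean and standard deviation of the absolute differences
+between the word start times in the two sets of alignments.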
+ +Usage: + +./pruned_transducer_stateless7/compute_ali.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --dataset test-clean \ + --max-duration 300 \ + --beam-size 4 \ + --cuts-out-dir data/fbank_ali_beam_search + +And the you can run: + +./pruned_transducer_stateless7/test_compute_ali.py \ + --cuts-out-dir ./data/fbank_ali_test \ + --cuts-ref-dir ./data/fbank_ali_torch \ + --dataset train-clean-100 +""" +import argparse +import logging +from pathlib import Path + +import torch +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--cuts-out-dir", + type=Path, + default="./data/fbank_ali", + help="The dir that saves the generated cuts manifests with alignments", + ) + + parser.add_argument( + "--cuts-ref-dir", + type=Path, + default="./data/fbank_ali_torch", + help="The dir that saves the reference cuts manifests with alignments", + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset: + Possible values are: + - test-clean + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + + cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz" + + logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}") + cuts_out = load_manifest(cuts_out_jsonl) + cuts_ref = load_manifest(cuts_ref_jsonl) + cuts_ref = cuts_ref.sort_like(cuts_out) + + all_time_diffs = [] + for cut_out, cut_ref in zip(cuts_out, cuts_ref): + time_out = [ + ali.start + for ali in cut_out.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + time_ref = [ + ali.start + for ali in cut_ref.supervisions[0].alignment["word"] + if ali.symbol != "" + ] + assert len(time_out) == len(time_ref), (len(time_out), len(time_ref)) + diff = [ + round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref) + ] + all_time_diffs += diff + + all_time_diffs = torch.tensor(all_time_diffs) + logging.info( + f"For the word-level alignments abs difference on dataset {args.dataset}, " + f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s" + ) + logging.info("Done!") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/utils.py b/icefall/utils.py index 5d86472b5..1fd9156bd 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1378,7 +1378,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: List of timestamp of each word. 
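 
     For example (illustrative values), tokens ['▁HE', 'LL', 'O', '▁WORLD']
     with timestamps [0.0, 0.04, 0.08, 0.4] would yield [0.0, 0.4], i.e. the
     timestamp of the word-initial token of each word.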
""" start_token = b"\xe2\x96\x81".decode() # '_' - assert len(tokens) == len(timestamp) + assert len(tokens) == len(timestamp), (len(tokens), len(timestamp)) ans = [] for i in range(len(tokens)): flag = False From 2a5a75cb5655126c5b88a370f848cb87f2e491ac Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 30 Mar 2023 14:30:13 +0800 Subject: [PATCH 05/17] add option of using full attention for streaming model decoding (#975) --- .../ASR/pruned_transducer_stateless/decode.py | 13 ++++++++++++- .../pruned_transducer_stateless2/conformer.py | 5 +++++ .../ASR/pruned_transducer_stateless2/decode.py | 16 +++++++++------- .../ASR/pruned_transducer_stateless3/decode.py | 16 +++++++++------- .../ASR/pruned_transducer_stateless4/decode.py | 18 ++++++++++-------- .../ASR/pruned_transducer_stateless5/decode.py | 16 +++++++++------- .../ASR/transducer_stateless/conformer.py | 5 +++++ 7 files changed, 59 insertions(+), 30 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 6dfe11cee..3c4500087 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -107,6 +107,7 @@ Usage: import argparse import logging +import math from collections import defaultdict from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -138,6 +139,8 @@ from icefall.utils import ( write_error_stats, ) +LOG_EPS = math.log(1e-10) + def get_parser(): parser = argparse.ArgumentParser( @@ -288,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( "--left-context", @@ -370,6 +373,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index f94ffef59..9bac46004 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -375,6 +375,11 @@ class Conformer(EncoderInterface): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 172c9ab7c..c57514193 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -295,7 +295,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). 
Set -1 to use full attention.", ) parser.add_argument( @@ -378,12 +378,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index aa055049e..b39007dfc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -344,7 +344,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -508,12 +508,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 5ec3d3b45..79d919ab1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -326,14 +326,14 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( "--left-context", type=int, default=64, - help="left context can be seen during decoding (in frames after subsampling)", # noqa + help="""Left context can be seen during decoding (in frames after subsampling). 
""", # noqa ) parser.add_argument( @@ -409,12 +409,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index 2be895feb..af0b2d9fc 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -291,7 +291,7 @@ def get_parser(): "--decode-chunk-size", type=int, default=16, - help="The chunk size for decoding (in frames after subsampling)", + help="The chunk size for decoding (in frames after subsampling). Set -1 to use full attention.", ) parser.add_argument( @@ -470,12 +470,14 @@ def decode_one_batch( feature_lens = supervisions["num_frames"].to(device) if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) + if params.decode_chunk_size > 0: + # except the case of using full attention + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens, diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 01e8c5b21..94d0393c2 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -358,6 +358,11 @@ class Conformer(Transformer): assert x.size(0) == lengths.max().item() + if chunk_size < 0: + # use full attention + chunk_size = x.size(0) + left_context = -1 + num_left_chunks = -1 if left_context >= 0: assert left_context % chunk_size == 0 From c21b6a208b7919ae867d2c833f8be4ff9386e47b Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Thu, 30 Mar 2023 17:08:46 +0800 Subject: [PATCH 06/17] Add finetuning script for aishell (#974) * add aishell finetune scripts * add an example bash script --- egs/wenetspeech/ASR/finetune.sh | 82 ++ .../pruned_transducer_stateless2/aishell.py | 1 + .../decode_aishell.py | 547 +++++++++ .../pruned_transducer_stateless2/finetune.py | 1050 +++++++++++++++++ 4 files changed, 1680 insertions(+) create mode 100755 egs/wenetspeech/ASR/finetune.sh create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py create mode 100755 egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py create mode 100755 egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py diff --git a/egs/wenetspeech/ASR/finetune.sh b/egs/wenetspeech/ASR/finetune.sh new file mode 100755 index 000000000..8559780e9 --- /dev/null +++ b/egs/wenetspeech/ASR/finetune.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + 
+stage=-1 +stop_stage=100 + +# This is an example script for fine-tuning. Here, we fine-tune a model trained +# on WenetSpeech on Aishell. The model used for fine-tuning is +# pruned_transducer_stateless2 (zipformer). If you want to fine-tune model +# from another recipe, you can adapt ./pruned_transducer_stateless2/finetune.py +# for that recipe. If you have any problem, please open up an issue in https://github.com/k2-fsa/icefall/issues. + +# We assume that you have already prepared the Aishell manfiest&features under ./data. +# If you haven't done that, please see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/prepare.sh. + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download Pre-trained model" + + # clone from huggingface + git lfs install + git clone https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 + +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Start fine-tuning" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + initial_lr=0.0001 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=icefall_asr_wenetspeech_pruned_transducer_stateless2/exp/pretrained_epoch_10_avg_2.pt + lang_dir=icefall_asr_wenetspeech_pruned_transducer_stateless2/data/lang_char + export CUDA_VISIBLE_DEVICES="0,1" + + ./pruned_transducer_stateless2/finetune.py \ + --world-size 2 \ + --master-port 18180 \ + --num-epochs 15 \ + --context-size 2 \ + --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ + --initial-lr $initial_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --lang-dir $lang_dir \ + --do-finetune True \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 200 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decoding" + + epoch=4 + avg=4 + + for m in greedy_search modified_beam_search; do + python pruned_transducer_stateless2/decode_aishell.py \ + --epoch $epoch \ + --avg $avg \ + --context-size 2 \ + --beam-size 4 \ + --exp-dir pruned_transducer_stateless2/exp_aishell_finetune \ + --max-duration 400 \ + --decoding-method $m + done +fi diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py new file mode 120000 index 000000000..f7321272b --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/aishell.py @@ -0,0 +1 @@ +../../../aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py new file mode 100755 index 000000000..2e644ec2f --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode_aishell.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless2/decode.py \ + --epoch 84 \ + --avg 25 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from aishell import AishellAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from finetune import get_params, get_transducer_model + +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. 
Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + token_table: + It maps token ID to a string. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
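+        For example, with --decoding-method greedy_search the returned dict is
+        {"greedy_search": hyps}, where hyps[i] is the list of decoded tokens
+        (Chinese characters for Aishell) of the i-th utterance in the batch.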
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + else: + hyp_tokens = [] + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_tokens.append(hyp) + + hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + token_table: + It maps a token ID to a string. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
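+    # greedy_search decodes batches much faster than the beam-search-based
+    # methods, so we log progress less frequently for it.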
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + token_table=token_table, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + # we compute CER for aishell dataset. + results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + 
logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + dev_dl = aishell.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py new file mode 100755 index 000000000..e703100a9 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/finetune.py @@ -0,0 +1,1050 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
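+# See ../finetune.sh for an example of how to invoke this script to
+# fine-tune a model trained on WenetSpeech on the Aishell dataset.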
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless2/finetune.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless2/exp \ + --full-libri 1 \ + --do-finetune 1 \ + --max-duration 100 + +""" + + +import argparse +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from aishell import AishellAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--do-finetune", type=str2bool, default=False) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.0001, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=100000, + help="""Number of steps that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=100, + help="""Number of epochs that affects how rapidly the learning rate + decreases. During fine-tuning, we set this very large so that the + learning rate slowly decays with number of batches. You may tune + its value by yourself. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--valid-interval", + type=int, + default=3000, + help="""When training_subset is L, set the valid_interval to 3000. + When training_subset is M, set the valid_interval to 1000. + When training_subset is S, set the valid_interval to 400. 
+ """, + ) + + parser.add_argument( + "--model-warm-step", + type=int, + default=3000, + help="""When training_subset is L, set the model_warm_step to 3000. + When training_subset is M, set the model_warm_step to 500. + When training_subset is S, set the model_warm_step to 100. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + Explanation of options saved in `params`: + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + - best_train_epoch: It is the epoch that has the best training loss. + - best_valid_epoch: It is the epoch that has the best validation loss. + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + - log_interval: Print training loss if batch_idx % log_interval` is 0 + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + - subsampling_factor: The subsampling factor for the model. + - encoder_dim: Hidden dim for multi-head attention model. + - num_decoder_layers: Number of decoder layer of transformer decoder. + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + 
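# The --init-modules option described earlier and the load_model_params()
# helper defined a bit further below select which parameters to copy from the
# fine-tuning checkpoint by name prefix. A minimal sketch of that matching;
# the state-dict keys and module names here are hypothetical:
src_state_dict = {"encoder.layer0.weight": 1.0, "decoder.embedding.weight": 2.0}
init_modules = "encoder".split(",")
selected = [
    k for k in src_state_dict if any(k.startswith(m) for m in init_modules)
]
# selected == ["encoder.layer0.weight"]; parameters whose names do not start
# with one of the given prefixes keep their freshly initialized values.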
return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 0: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [k for k in src_state_dict.keys() if k.startswith(module)] + dst_keys = [k for k in dst_state_dict.keys() if k.startswith(module)] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. 
+ optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + + y = graph_compiler.texts_to_ids(texts) + if isinstance(y, list): + y = k2.RaggedTensor(y).to(device) + else: + y = y.to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
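# A small worked example of the pruned_loss_scale schedule applied just above
# (numbers are illustrative): the pruned RNN-T loss is ignored while
# warmup < 1, down-weighted to 0.1 until warmup reaches 2, and fully weighted
# afterwards, with warmup == batch_idx_train / model_warm_step, so that
#   loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
def pruned_scale(warmup: float) -> float:
    return 0.0 if warmup < 1.0 else (0.1 if 1.0 < warmup < 2.0 else 1.0)

assert pruned_scale(0.5) == 0.0   # still warming up: pruned loss not used
assert pruned_scale(1.5) == 0.1   # shortly after warmup: down-weighted
assert pruned_scale(3.0) == 1.0   # fully warmed up: full weight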
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
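# Note on the mixed-precision update that follows (standard PyTorch GradScaler
# pattern): scaler.scale(loss).backward() scales the loss so fp16 gradients do
# not underflow; scaler.step(optimizer) unscales the gradients and skips the
# optimizer step if any gradient is inf/NaN; scaler.update() then adjusts the
# scale factor for the next iteration, and optimizer.zero_grad() clears the
# gradients. With --use-fp16 false the scaler is created disabled and these
# calls reduce to a plain backward()/step().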
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # load model parameters for model fine-tuning + if params.do_finetune: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + else: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + aishell = AishellAsrDataModule(args) + train_dl = aishell.train_dataloaders(aishell.train_cuts()) + valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + rank=rank, + 
) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments( + parser + ) # you may replace this with your own dataset + add_finetune_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From a632b24c353be0ac113a17a756156387fda5c7e0 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 31 Mar 2023 22:46:19 +0800 Subject: [PATCH 07/17] Export int8 quantized models for non-streaming Zipformer. (#977) * Export int8 quantized models for non-streaming Zipformer. 
* Delete export-onnx.py * Export int8 models for other folders --- .../export-onnx-zh.py | 632 ++++++++++++++++ .../lstm_transducer_stateless2/export-onnx.py | 35 +- .../export-onnx.py | 30 + .../export-onnx.py | 30 + .../export-onnx-zh.py | 678 ++++++++++++++++++ .../export-onnx.py | 30 + 6 files changed, 1434 insertions(+), 1 deletion(-) create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py new file mode 100755 index 000000000..f068f6a0f --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx-zh.py @@ -0,0 +1,632 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14 +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lexicon.txt" +git lfs pull --include "data/L.pt" +git lfs pull --include "exp/epoch-11.pt" +git lfs pull --include "exp/epoch-10.pt" + +popd + +2. Export the model to ONNX + +./lstm_transducer_stateless2/export-onnx-zh.py \ + --lang-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/data/lang_char \ + --use-averaged-model 1 \ + --epoch 11 \ + --avg 1 \ + --exp-dir ./icefall-asr-wenetspeech-lstm-transducer-stateless-2022-10-14/exp \ + --num-encoder-layers 12 \ + --encoder-dim 512 \ + --rnn-hidden-size 1024 + +It will generate the following files inside $repo/exp: + + - encoder-epoch-11-avg-1.onnx + - decoder-epoch-11-avg-1.onnx + - joiner-epoch-11-avg-1.onnx + - encoder-epoch-11-avg-1.int8.onnx + - decoder-epoch-11-avg-1.int8.onnx + - joiner-epoch-11-avg-1.int8.onnx + +See ./onnx_pretrained.py and ./onnx_check.py for how to +use the exported ONNX models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, Optional, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +class OnnxEncoder(nn.Module): + """A wrapper for RNN and the encoder_proj from the joiner""" + + def __init__(self, encoder: RNN, encoder_proj: nn.Linear): + """ + Args: + encoder: + An RNN encoder. + encoder_proj: + The projection layer for encoder from the joiner. + """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward( + self, + x: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Please see the help information of RNN.forward + + Args: + x: + A 3-D tensor of shape (N, T, C) + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + Returns: + Return a tuple containing: + - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) + - updated states, whose shape is the same as the input states. + """ + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, dtype=torch.int64, device=x.device) + encoder_out, _, next_states = self.encoder(x, x_lens, states) + + encoder_out = self.encoder_proj(encoder_out) + # Now encoder_out is of shape (N, T, joiner_dim) + + return encoder_out, next_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). 
+ Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. + The exported model has the following inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T', joiner_dim) + - new_state0, a tensor of shape (num_encoder_layers, batch_size, d_model) + - new_state1, a tensor of shape (num_encoder_layers, batch_size, rnn_hidden_size) + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + num_encoder_layers = encoder_model.encoder.num_encoder_layers + d_model = encoder_model.encoder.d_model + rnn_hidden_size = encoder_model.encoder.rnn_hidden_size + + decode_chunk_len = 4 + T = 9 + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.encoder.get_init_states() + # state0: (num_encoder_layers, batch_size, d_model) + # state1: (num_encoder_layers, batch_size, rnn_hidden_size) + + torch.onnx.export( + encoder_model, + (x, states), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "state0", "state1"], + output_names=["encoder_out", "new_state0", "new_state1"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "state0": {1: "N"}, + "state1": {1: "N"}, + "encoder_out": {0: "N"}, + "new_state0": {1: "N"}, + "new_state1": {1: "N"}, + }, + ) + + meta_data = { + "model_type": "lstm", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": str(num_encoder_layers), + "d_model": str(d_model), + "rnn_hidden_size": str(rnn_hidden_size), + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: OnnxDecoder, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params, enable_giga=False) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif 
len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + 
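# A small sketch (hypothetical filenames) of sanity-checking the int8 export
# produced by the quantize_dynamic() calls around here: the quantized file
# should load in onnxruntime exactly like the fp32 one, only smaller, with
# MatMul weights stored as int8.
import os
import onnxruntime as ort

fp32_path = "exp/encoder-epoch-11-avg-1.onnx"        # hypothetical paths
int8_path = "exp/encoder-epoch-11-avg-1.int8.onnx"
print(os.path.getsize(fp32_path), os.path.getsize(int8_path))
sess = ort.InferenceSession(int8_path, providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])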
joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py index 46873ebf9..acaff8540 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export-onnx.py @@ -34,11 +34,14 @@ popd --avg 1 \ --exp-dir $repo/exp -It will generate the following 3 files inside $repo/exp: +It will generate the following files inside $repo/exp: - encoder-epoch-99-avg-1.onnx - decoder-epoch-99-avg-1.onnx - joiner-epoch-99-avg-1.onnx + - encoder-epoch-99-avg-1.int8.onnx + - decoder-epoch-99-avg-1.int8.onnx + - joiner-epoch-99-avg-1.int8.onnx See ./onnx_pretrained.py and ./onnx_check.py for how to use the exported ONNX models. @@ -55,6 +58,7 @@ import torch import torch.nn as nn from decoder import Decoder from lstm import RNN +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -586,6 +590,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py index 36e57e946..9645b7801 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export-onnx.py @@ -54,6 +54,7 @@ import torch import torch.nn as nn from conformer import Conformer from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model @@ -500,6 +501,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + 
model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py index 93eb8df3d..2f5d9e338 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export-onnx.py @@ -55,6 +55,7 @@ import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from zipformer import Zipformer @@ -563,6 +564,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py new file mode 100755 index 000000000..04d97808d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py @@ -0,0 +1,678 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) + +""" +This script exports a transducer model from PyTorch to ONNX. + +We use the pre-trained model from +https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +as an example to show how to use this file. + +1. Download the pre-trained model + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/pfluo/k2fsa-zipformer-chinese-english-mixed +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +pushd $repo +git lfs pull --include "data/lang_char_bpe/L.pt" +git lfs pull --include "data/lang_char_bpe/Linv.pt" +git lfs pull --include "data/lang_char_bpe/L_disambig.pt" +git lfs pull --include "exp/pretrained.pt" +cd exp +ln -s pretrained.pt epoch-99.pt +popd + +2. 
Export the model to ONNX + +./pruned_transducer_stateless7_streaming/export-onnx-zh.py \ + --lang-dir $repo/data/lang_char_bpe \ + --use-averaged-model 0 \ + --epoch 99 \ + --avg 1 \ + --exp-dir $repo/exp/ \ + --decode-chunk-len 32 \ + --num-encoder-layers "2,4,3,2,4" \ + --feedforward-dims "1024,1024,1536,1536,1024" \ + --nhead "8,8,8,8,8" \ + --encoder-dims "384,384,384,384,384" \ + --attention-dims "192,192,192,192,192" \ + --encoder-unmasked-dims "256,256,256,256,256" \ + --zipformer-downsampling-factors "1,2,4,8,2" \ + --cnn-module-kernels "31,31,31,31,31" \ + --decoder-dim 512 \ + --joiner-dim 512 + +It will generate the following 3 files in $repo/exp + + - encoder-epoch-99-avg-1.onnx + - decoder-epoch-99-avg-1.onnx + - joiner-epoch-99-avg-1.onnx + +See ./onnx_pretrained.py for how to use the exported models. +""" + +import argparse +import logging +from pathlib import Path +from typing import Dict, List, Tuple + +import onnx +import torch +import torch.nn as nn +from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor +from train import add_model_arguments, get_params, get_transducer_model +from zipformer import Zipformer + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_streaming/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class OnnxEncoder(nn.Module): + """A wrapper for Zipformer and the encoder_proj from the joiner""" + + def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): + """ + Args: + encoder: + A Zipformer encoder. + encoder_proj: + The projection layer for encoder from the joiner. 
+ """ + super().__init__() + self.encoder = encoder + self.encoder_proj = encoder_proj + + def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + """Please see the help information of Zipformer.streaming_forward""" + N = x.size(0) + T = x.size(1) + x_lens = torch.tensor([T] * N, device=x.device) + + output, _, new_states = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=states, + ) + + output = self.encoder_proj(output) + # Now output is of shape (N, T, joiner_dim) + + return output, new_states + + +class OnnxDecoder(nn.Module): + """A wrapper for Decoder and the decoder_proj from the joiner""" + + def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): + super().__init__() + self.decoder = decoder + self.decoder_proj = decoder_proj + + def forward(self, y: torch.Tensor) -> torch.Tensor: + """ + Args: + y: + A 2-D tensor of shape (N, context_size). + Returns + Return a 2-D tensor of shape (N, joiner_dim) + """ + need_pad = False + decoder_output = self.decoder(y, need_pad=need_pad) + decoder_output = decoder_output.squeeze(1) + output = self.decoder_proj(decoder_output) + + return output + + +class OnnxJoiner(nn.Module): + """A wrapper for the joiner""" + + def __init__(self, output_linear: nn.Linear): + super().__init__() + self.output_linear = output_linear + + def forward( + self, + encoder_out: torch.Tensor, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A 2-D tensor of shape (N, joiner_dim) + decoder_out: + A 2-D tensor of shape (N, joiner_dim) + Returns: + Return a 2-D tensor of shape (N, vocab_size) + """ + logit = encoder_out + decoder_out + logit = self.output_linear(torch.tanh(logit)) + return logit + + +def add_meta_data(filename: str, meta_data: Dict[str, str]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. + """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = value + + onnx.save(model, filename) + + +def export_encoder_model_onnx( + encoder_model: OnnxEncoder, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """ + Onnx model inputs: + - 0: src + - many state tensors (the exact number depending on the actual model) + + Onnx model outputs: + - 0: output, its shape is (N, T, joiner_dim) + - many state tensors (the exact number depending on the actual model) + + Args: + encoder_model: + The model to be exported + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. 
+ """ + + encoder_model.encoder.__class__.forward = ( + encoder_model.encoder.__class__.streaming_forward + ) + + decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 + pad_length = 7 + T = decode_chunk_len + pad_length + logging.info(f"decode_chunk_len: {decode_chunk_len}") + logging.info(f"pad_length: {pad_length}") + logging.info(f"T: {T}") + + x = torch.rand(1, T, 80, dtype=torch.float32) + + init_state = encoder_model.encoder.get_init_state() + + num_encoders = encoder_model.encoder.num_encoders + logging.info(f"num_encoders: {num_encoders}") + logging.info(f"len(init_state): {len(init_state)}") + + inputs = {} + input_names = ["x"] + + outputs = {} + output_names = ["encoder_out"] + + def build_inputs_outputs(tensors, name, N): + for i, s in enumerate(tensors): + logging.info(f"{name}_{i}.shape: {s.shape}") + inputs[f"{name}_{i}"] = {N: "N"} + outputs[f"new_{name}_{i}"] = {N: "N"} + input_names.append(f"{name}_{i}") + output_names.append(f"new_{name}_{i}") + + num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) + encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) + attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) + cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) + ds = encoder_model.encoder.zipformer_downsampling_factors + left_context_len = encoder_model.encoder.left_context_len + left_context_len = [left_context_len // k for k in ds] + left_context_len = ",".join(map(str, left_context_len)) + + meta_data = { + "model_type": "zipformer", + "version": "1", + "model_author": "k2-fsa", + "decode_chunk_len": str(decode_chunk_len), # 32 + "T": str(T), # 39 + "num_encoder_layers": num_encoder_layers, + "encoder_dims": encoder_dims, + "attention_dims": attention_dims, + "cnn_module_kernels": cnn_module_kernels, + "left_context_len": left_context_len, + } + logging.info(f"meta_data: {meta_data}") + + # (num_encoder_layers, 1) + cached_len = init_state[num_encoders * 0 : num_encoders * 1] + + # (num_encoder_layers, 1, encoder_dim) + cached_avg = init_state[num_encoders * 1 : num_encoders * 2] + + # (num_encoder_layers, left_context_len, 1, attention_dim) + cached_key = init_state[num_encoders * 2 : num_encoders * 3] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val = init_state[num_encoders * 3 : num_encoders * 4] + + # (num_encoder_layers, left_context_len, 1, attention_dim//2) + cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] + + # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) + cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] + + build_inputs_outputs(cached_len, "cached_len", 1) + build_inputs_outputs(cached_avg, "cached_avg", 1) + build_inputs_outputs(cached_key, "cached_key", 2) + build_inputs_outputs(cached_val, "cached_val", 2) + build_inputs_outputs(cached_val2, "cached_val2", 2) + build_inputs_outputs(cached_conv1, "cached_conv1", 1) + build_inputs_outputs(cached_conv2, "cached_conv2", 1) + + logging.info(inputs) + logging.info(outputs) + logging.info(input_names) + logging.info(output_names) + + torch.onnx.export( + encoder_model, + (x, init_state), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=input_names, + output_names=output_names, + dynamic_axes={ + "x": {0: "N"}, + "encoder_out": {0: "N"}, + **inputs, + **outputs, + }, + ) + + 
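# A minimal sketch of how the key/value pairs attached by add_meta_data() below
# can be read back from the exported file (assuming onnxruntime is installed;
# the file name matches the example command in the module docstring):
#
#   import onnxruntime
#
#   session = onnxruntime.InferenceSession("encoder-epoch-99-avg-1.onnx")
#   meta = session.get_modelmeta().custom_metadata_map
#   print(meta["decode_chunk_len"])  # "32" for the settings shown above
#   print(meta["encoder_dims"])      # "384,384,384,384,384"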
add_meta_data(filename=encoder_filename, meta_data=meta_data) + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + context_size = decoder_model.decoder.context_size + vocab_size = decoder_model.decoder.vocab_size + y = torch.zeros(10, context_size, dtype=torch.int64) + torch.onnx.export( + decoder_model, + y, + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + meta_data = { + "context_size": str(context_size), + "vocab_size": str(vocab_size), + } + add_meta_data(filename=decoder_filename, meta_data=meta_data) + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported joiner model has two inputs: + + - encoder_out: a tensor of shape (N, joiner_dim) + - decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + """ + joiner_dim = joiner_model.output_linear.weight.shape[1] + logging.info(f"joiner dim: {joiner_dim}") + + projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) + + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "encoder_out", + "decoder_out", + ], + output_names=["logit"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + meta_data = { + "joiner_dim": str(joiner_dim), + } + add_meta_data(filename=joiner_filename, meta_data=meta_data) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + 
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + convert_scaled_to_non_scaled(model, inplace=True) + encoder = OnnxEncoder( + encoder=model.encoder, + encoder_proj=model.joiner.encoder_proj, + ) + + decoder = OnnxDecoder( + decoder=model.decoder, + decoder_proj=model.joiner.decoder_proj, + ) + + joiner = OnnxJoiner(output_linear=model.joiner.output_linear) + + encoder_num_param = sum([p.numel() for p in encoder.parameters()]) + decoder_num_param = sum([p.numel() for p in decoder.parameters()]) + joiner_num_param = sum([p.numel() for p in joiner.parameters()]) + total_num_param = encoder_num_param + decoder_num_param + joiner_num_param + logging.info(f"encoder parameters: {encoder_num_param}") + logging.info(f"decoder parameters: {decoder_num_param}") + logging.info(f"joiner parameters: {joiner_num_param}") + logging.info(f"total parameters: {total_num_param}") + + if params.iter > 0: + suffix = f"iter-{params.iter}" + else: + suffix = f"epoch-{params.epoch}" + + suffix += f"-avg-{params.avg}" + if params.use_averaged_model: + suffix += "-with-averaged-model" + + opset_version = 13 + + logging.info("Exporting encoder") + encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" + export_encoder_model_onnx( + encoder, + encoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported encoder to {encoder_filename}") + + logging.info("Exporting decoder") + decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" + export_decoder_model_onnx( + decoder, + decoder_filename, + opset_version=opset_version, + ) + logging.info(f"Exported decoder to {decoder_filename}") + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" + export_joiner_model_onnx( + joiner, + joiner_filename, + opset_version=opset_version, + ) + logging.info(f"Exported joiner to {joiner_filename}") + + # Generate int8 
quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py index d7092403e..e71bcaf29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -53,6 +53,7 @@ import sentencepiece as spm import torch import torch.nn as nn from decoder import Decoder +from onnxruntime.quantization import QuantType, quantize_dynamic from scaling_converter import convert_scaled_to_non_scaled from torch import Tensor from train import add_model_arguments, get_params, get_transducer_model @@ -634,6 +635,35 @@ def main(): ) logging.info(f"Exported joiner to {joiner_filename}") + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + logging.info("Generate int8 quantization models") + + encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" + quantize_dynamic( + model_input=joiner_filename, + model_output=joiner_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + if __name__ == "__main__": main() From 12a222aa4b70f145d91eb7acf4fce201ad35bdc7 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 2 Apr 2023 16:32:43 +0800 Subject: [PATCH 08/17] Fix comments on the usage of train.py (#981) --- .../ASR/pruned_transducer_stateless2/train.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 9edc42b61..6a9f9f32f 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -19,26 +19,24 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --exp-dir pruned_transducer_stateless2/exp \ 
- --full-libri 1 \ - --max-duration 300 + --max-duration 120 # For mix precision training: ./pruned_transducer_stateless2/train.py \ - --world-size 4 \ + --world-size 8 \ --num-epochs 30 \ --start-epoch 0 \ --use_fp16 1 \ --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 + --max-duration 200 """ From 180c7c2b7ae1e03319839d1782a89c9d012e0ddb Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 3 Apr 2023 12:39:34 +0800 Subject: [PATCH 09/17] Add UniqueLexicon for gigaspeech (#982) --- egs/gigaspeech/ASR/local/generate_unique_lexicon.py | 1 + 1 file changed, 1 insertion(+) create mode 120000 egs/gigaspeech/ASR/local/generate_unique_lexicon.py diff --git a/egs/gigaspeech/ASR/local/generate_unique_lexicon.py b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/gigaspeech/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file From 46bf6df62fb3d78b4b9bf1b0592c889a04d7be9b Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Mon, 3 Apr 2023 14:55:45 +0800 Subject: [PATCH 10/17] Remove simulate streaming from stateless7 (#983) * Remove simulate streaming from stateless7 --- .../pruned_transducer_stateless7/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 32b3134b9..576621e24 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -343,29 +343,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. 
- """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - parser.add_argument( "--use-shallow-fusion", type=str2bool, @@ -474,22 +451,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -782,10 +744,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -834,11 +792,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") From d337398d29e401600473ca11fd9d827eea86fadc Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Mon, 3 Apr 2023 16:20:29 +0800 Subject: [PATCH 11/17] Shallow fusion for Aishell (#954) * add shallow fusion and LODR for aishell * update RESULTS * add save by iterations --- egs/aishell/ASR/RESULTS.md | 74 ++++++++ .../local/prepare_char_lm_training_data.py | 0 .../ASR/local/sort_lm_training_data.py | 1 + egs/aishell/ASR/prepare.sh | 11 +- .../pruned_transducer_stateless3/decode.py | 166 ++++++++++++++++++ .../exp-context-size-1 | 1 - .../ASR/lstm_transducer_stateless2/decode.py | 2 - .../beam_search.py | 8 +- .../pruned_transducer_stateless3/decode.py | 2 - .../pruned_transducer_stateless5/decode.py | 2 - .../pruned_transducer_stateless7/decode.py | 2 - egs/wenetspeech/ASR/local/text2segments.py | 4 +- egs/wenetspeech/ASR/prepare.sh | 104 +++++++++++ .../pruned_transducer_stateless5/decode.py | 163 ++++++++++++++++- icefall/lm_wrapper.py | 2 +- icefall/rnn_lm/compute_perplexity.py | 50 +++++- icefall/rnn_lm/export.py | 43 ++++- icefall/rnn_lm/train.py | 48 ++++- 18 files changed, 647 insertions(+), 36 deletions(-) mode change 100644 => 100755 egs/aishell/ASR/local/prepare_char_lm_training_data.py create mode 120000 egs/aishell/ASR/local/sort_lm_training_data.py delete mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index 9a06fbe9f..4c730c4ae 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -15,6 +15,8 @@ It uses pruned RNN-T. 
|------------------------|------|------|---------------------------------------| | greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | | modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + RNNLM shallow fusion | 4.73 | 4.53 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search + LODR | 4.57 | 4.37 | --epoch 29 --avg 5 --max-duration 600 | | fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | Training command is: @@ -73,6 +75,78 @@ for epoch in 29; do done ``` +We provide the option of shallow fusion with a RNN language model. The pre-trained language model is +available at . To decode with the language model, +please use the following command: + +```bash +# download pre-trained model +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + +aishell_exp=icefall-aishell-pruned-transducer-stateless3-2022-06-20/ + +pushd ${aishell_exp}/exp +ln -s pretrained-epoch-29-avg-5-torch-1.10.0.pt epoch-99.pt +popd + +# download RNN LM +git lfs install +git clone https://huggingface.co/marcoyang/icefall-aishell-rnn-lm +rnnlm_dir=icefall-aishell-rnn-lm + +# RNNLM shallow fusion +for lm_scale in $(seq 0.26 0.02 0.34); do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 +done + +# RNNLM Low-order density ratio (LODR) with a 2-gram + +cp ${rnnlm_dir}/2gram.fst.txt ${aishell_exp}/data/lang_char/2gram.fst.txt + +for lm_scale in 0.48; do + for LODR_scale in -0.28; do + python ./pruned_transducer_stateless3/decode.py \ + --epoch 99 \ + --avg 1 \ + --lang-dir ${aishell_exp}/data/lang_char \ + --exp-dir ${aishell_exp}/exp \ + --use-averaged-model False \ + --decoding-method modified_beam_search_LODR \ + --use-shallow-fusion 1 \ + --lm-type rnn \ + --lm-exp-dir ${rnnlm_dir}/exp \ + --lm-epoch 99 \ + --lm-scale $lm_scale \ + --lm-avg 1 \ + --rnn-lm-embedding-dim 2048 \ + --rnn-lm-hidden-dim 2048 \ + --rnn-lm-num-layers 2 \ + --lm-vocab-size 4336 \ + --tokens-ngram 2 \ + --backoff-id 4336 \ + --ngram-lm-scale $LODR_scale + done +done + +``` + Pretrained models, training logs, decoding logs, and decoding results are available at diff --git a/egs/aishell/ASR/local/prepare_char_lm_training_data.py b/egs/aishell/ASR/local/prepare_char_lm_training_data.py old mode 100644 new mode 100755 diff --git a/egs/aishell/ASR/local/sort_lm_training_data.py b/egs/aishell/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/aishell/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index cf4ee7818..bd34c1f44 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -230,12 +230,14 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then if [ ! 
-f $dl_dir/lm/aishell-train-word.txt ]; then cp $lang_phone_dir/transcript_words.txt $dl_dir/lm/aishell-train-word.txt fi - + + # training words ./local/prepare_char_lm_training_data.py \ --lang-char data/lang_char \ --lm-data $dl_dir/lm/aishell-train-word.txt \ --lm-archive $out_dir/lm_data.pt + # valid words if [ ! -f $dl_dir/lm/aishell-valid-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_valid_uid=$dl_dir/aishell/data_aishell/transcript/aishell_valid_uid @@ -249,6 +251,7 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then --lm-data $dl_dir/lm/aishell-valid-word.txt \ --lm-archive $out_dir/lm_data_valid.pt + # test words if [ ! -f $dl_dir/lm/aishell-test-word.txt ]; then aishell_text=$dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt aishell_test_uid=$dl_dir/aishell/data_aishell/transcript/aishell_test_uid @@ -303,9 +306,9 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then --hidden-dim 512 \ --num-layers 2 \ --batch-size 400 \ - --exp-dir rnnlm_char/exp \ - --lm-data data/lm_training_char/sorted_lm_data.pt \ - --lm-data-valid data/lm_training_char/sorted_lm_data-valid.pt \ + --exp-dir rnnlm_char/exp_aishell1_small \ + --lm-data data/lm_char/sorted_lm_data_aishell1.pt \ + --lm-data-valid data/lm_char/sorted_lm_data_valid.pt \ --vocab-size 4336 \ --master-port 12345 fi diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py index 954d9dc7e..27c64efaa 100755 --- a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -54,6 +54,40 @@ Usage: --beam 4 \ --max-contexts 4 \ --max-states 8 + +(5) modified beam search (with LM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + +(6) modified beam search with LM shallow fusion + LODR +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --decoding-method modified_beam_search_LODR \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.48 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + --tokens-ngram 2 \ + --ngram-lm-scale -0.28 \ """ @@ -74,9 +108,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -212,6 +249,60 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. 
+ If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) + add_model_arguments(parser) return parser @@ -223,6 +314,9 @@ def decode_one_batch( token_table: k2.SymbolTable, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -287,6 +381,24 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) else: hyp_tokens = [] batch_size = encoder_out.size(0) @@ -334,6 +446,9 @@ def decode_dataset( model: nn.Module, token_table: k2.SymbolTable, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. 
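To make the role of the two scales concrete, here is a rough, self-contained sketch of how a hypothesis score is combined under shallow fusion and under LODR. The function and variable names are illustrative and are not taken from beam_search.py.

def shallow_fusion_score(am_logp: float, nnlm_logp: float, lm_scale: float) -> float:
    # transducer score plus a scaled neural-LM score
    return am_logp + lm_scale * nnlm_logp


def lodr_score(
    am_logp: float,
    nnlm_logp: float,
    ngram_logp: float,
    lm_scale: float,
    lodr_scale: float,
) -> float:
    # LODR additionally adds a scaled low-order n-gram score; with the negative
    # --ngram-lm-scale used in the aishell recipe above (e.g. -0.28) the n-gram
    # term is subtracted, acting as a density-ratio style correction.
    return am_logp + lm_scale * nnlm_logp + lodr_scale * ngram_logp


# e.g. with lm_scale=0.48 and lodr_scale=-0.28 as in RESULTS.md:
print(lodr_score(-3.2, -1.5, -2.0, 0.48, -0.28))  # about -3.36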
@@ -379,6 +494,9 @@ def decode_dataset( token_table=token_table, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -445,6 +563,7 @@ def save_results( def main(): parser = get_parser() AsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -458,6 +577,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -479,6 +600,19 @@ def main(): if params.use_averaged_model: params.suffix += "-use-averaged-model" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -588,6 +722,35 @@ def main(): else: decoding_graph = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = params.lang_dir / f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + lm_filename, + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + else: + LM = None + num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -610,6 +773,9 @@ def main(): model=model, token_table=lexicon.token_table, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 deleted file mode 120000 index bcd4abc2f..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 +++ /dev/null @@ -1 +0,0 @@ -/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 6c58a57e1..73207017b 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -550,7 +550,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -561,7 +560,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git 
a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 999d793a4..75edf0c54 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -1863,7 +1863,6 @@ def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, LM: LmScorer, @@ -1883,8 +1882,6 @@ def modified_beam_search_LODR( encoder_out_lens (torch.Tensor): A 1-D tensor of shape (N,), containing the number of valid frames in encoder_out before padding. - sp: - Sentence piece generator. LODR_lm: A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: @@ -1912,7 +1909,7 @@ def modified_beam_search_LODR( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device @@ -2137,7 +2134,6 @@ def modified_beam_search_lm_shallow_fusion( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, LM: LmScorer, beam: int = 4, return_timestamps: bool = False, @@ -2176,7 +2172,7 @@ def modified_beam_search_lm_shallow_fusion( ) blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") + sos_id = getattr(LM, "sos_id", 1) unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size device = next(model.parameters()).device diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index b39007dfc..7c62bfa58 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -675,7 +675,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -686,7 +685,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index af0b2d9fc..7a3e63218 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -586,7 +586,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -597,7 +596,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index 576621e24..55a2493e9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -533,7 +533,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LM=LM, ) for hyp in sp.decode(hyp_tokens): @@ -544,7 +543,6 @@ def decode_one_batch( encoder_out=encoder_out, 
encoder_out_lens=encoder_out_lens, beam=params.beam_size, - sp=sp, LODR_lm=ngram_lm, LODR_lm_scale=ngram_lm_scale, LM=LM, diff --git a/egs/wenetspeech/ASR/local/text2segments.py b/egs/wenetspeech/ASR/local/text2segments.py index df5b3c119..bdf5a3984 100644 --- a/egs/wenetspeech/ASR/local/text2segments.py +++ b/egs/wenetspeech/ASR/local/text2segments.py @@ -40,8 +40,8 @@ from tqdm import tqdm # and 'data()' is only supported in static graph mode. So if you # want to use this api, should call 'paddle.enable_static()' before # this api to enter static graph mode. -paddle.enable_static() -paddle.disable_signal_handler() +# paddle.enable_static() +# paddle.disable_signal_handler() jieba.enable_paddle() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 50a00253d..f7b521794 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -261,3 +261,107 @@ if [ $stage -le 18 ] && [ $stop_stage -ge 18 ]; then log "Stage 18: Compile LG" python ./local/compile_lg.py --lang-dir $lang_char_dir fi + +# prepare RNNLM data +if [ $stage -le 19 ] && [ $stop_stage -ge 19 ]; then + log "Stage 19: Prepare LM training data" + + log "Processing char based data" + text_out_dir=data/lm_char + + mkdir -p $text_out_dir + + log "Genearating training text data" + + if [ ! -f $text_out_dir/lm_data.pt ]; then + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $lang_char_dir/text_words_segmentation \ + --lm-archive $text_out_dir/lm_data.pt + fi + + log "Generating DEV text data" + # prepare validation text data + if [ ! -f $text_out_dir/valid_text_words_segmentation ]; then + valid_text=${text_out_dir}/ + + gunzip -c data/manifests/wenetspeech_supervisions_DEV.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/valid_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/valid_text \ + --output-file $text_out_dir/valid_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/valid_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_valid.pt + + # prepare TEST text data + if [ ! -f $text_out_dir/TEST_text_words_segmentation ]; then + log "Prepare text for test set." 
+ for test_set in TEST_MEETING TEST_NET; do + gunzip -c data/manifests/wenetspeech_supervisions_${test_set}.jsonl.gz \ + | jq '.text' | sed 's/"//g' \ + | ./local/text2token.py -t "char" > $text_out_dir/${test_set}_text + + python3 ./local/text2segments.py \ + --num-process $nj \ + --input-file $text_out_dir/${test_set}_text \ + --output-file $text_out_dir/${test_set}_text_words_segmentation + done + + cat $text_out_dir/TEST_*_text_words_segmentation > $text_out_dir/test_text_words_segmentation + fi + + ./local/prepare_char_lm_training_data.py \ + --lang-char data/lang_char \ + --lm-data $text_out_dir/test_text_words_segmentation \ + --lm-archive $text_out_dir/lm_data_test.pt + +fi + +# sort RNNLM data +if [ $stage -le 20 ] && [ $stop_stage -ge 20 ]; then + text_out_dir=data/lm_char + + log "Sort lm data" + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data.pt \ + --out-lm-data $text_out_dir/sorted_lm_data.pt \ + --out-statistics $text_out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_valid.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-valid.pt \ + --out-statistics $text_out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $text_out_dir/lm_data_test.pt \ + --out-lm-data $text_out_dir/sorted_lm_data-test.pt \ + --out-statistics $text_out_dir/statistics-test.txt +fi + +export CUDA_VISIBLE_DEVICES="0,1" + +if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then + log "Stage 21: Train RNN LM model" + python ../../../icefall/rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 20 \ + --use-fp16 0 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 2 \ + --batch-size 400 \ + --exp-dir rnnlm_char/exp \ + --lm-data data/lm_char/sorted_lm_data.pt \ + --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ + --vocab-size 4336 \ + --master-port 12340 +fi \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index de12b2ff0..46ba6b005 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -2,6 +2,7 @@ # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) # Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2022 Xiaomi Corporation (Author: Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -91,6 +92,22 @@ When training with the L subset, the streaming usage: --causal-convolution 1 \ --decode-chunk-size 16 \ --left-context 64 + +(4) modified beam search with RNNLM shallow fusion +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_lm_shallow_fusion \ + --beam-size 4 \ + --lm-type rnn \ + --lm-scale 0.3 \ + --lm-exp-dir /path/to/LM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 """ @@ -111,9 +128,12 @@ from beam_search import ( greedy_search, greedy_search_batch, modified_beam_search, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, ) from train import add_model_arguments, get_params, get_transducer_model +from icefall import LmScorer, NgramLm from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, @@ -224,6 +244,16 @@ def get_parser(): Used only when --decoding-method is 
fast_beam_search""", ) + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + parser.add_argument( "--max-contexts", type=int, @@ -277,6 +307,50 @@ def get_parser(): help="left context can be seen during decoding (in frames after subsampling)", ) + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=3, + help="""Token Ngram used for rescoring. + Used only when the decoding method is + modified_beam_search_ngram_rescoring, or LODR + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="""ID of the backoff symbol. + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", + ) add_model_arguments(parser) return parser @@ -288,6 +362,9 @@ def decode_one_batch( lexicon: Lexicon, batch: dict, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -374,6 +451,28 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) else: batch_size = encoder_out.size(0) @@ -419,6 +518,9 @@ def decode_dataset( model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + LM: Optional[LmScorer] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. @@ -432,6 +534,8 @@ def decode_dataset( decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. + LM: + A neural network LM, used during shallow fusion Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. 
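The --tokens-ngram and --backoff-id options configure the token-level n-gram used for LODR, and "backoff" has its usual n-gram meaning: when an n-gram was not observed, the model falls back to a lower-order estimate and pays a backoff penalty; in the FST text format used here the backoff arcs carry a dedicated symbol id, which is what --backoff-id specifies (e.g. 4336 for the aishell char vocabulary). A toy, self-contained sketch with made-up probabilities, not read from any FST:

import math

# toy backoff bigram over token ids
bigram_logp = {(7, 3): math.log(0.5)}
unigram_logp = {3: math.log(0.1), 7: math.log(0.2)}
backoff_logp = {3: math.log(0.6), 7: math.log(0.4)}


def ngram_logprob(h: int, w: int) -> float:
    if (h, w) in bigram_logp:      # bigram seen in the training text
        return bigram_logp[(h, w)]
    # otherwise back off: penalty for history h plus the unigram score of w
    return backoff_logp[h] + unigram_logp[w]


print(ngram_logprob(7, 3))  # seen bigram: log 0.5
print(ngram_logprob(3, 7))  # unseen bigram: backoff(3) + log P(7)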
@@ -449,7 +553,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 100 else: - log_interval = 2 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -463,6 +567,9 @@ def decode_dataset( lexicon=lexicon, decoding_graph=decoding_graph, batch=batch, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) for name, hyps in hyps_dict.items(): @@ -524,6 +631,7 @@ def save_results( def main(): parser = get_parser() WenetSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -535,6 +643,8 @@ def main(): "beam_search", "fast_beam_search", "modified_beam_search", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", ) params.res_dir = params.exp_dir / params.decoding_method @@ -549,6 +659,22 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + if params.use_shallow_fusion: + if params.lm_type == "rnn": + params.suffix += f"-rnnlm-lm-scale-{params.lm_scale}" + elif params.lm_type == "transformer": + params.suffix += f"-transformer-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -558,6 +684,7 @@ def main(): logging.info(f"Device: {device}") + # import pdb; pdb.set_trace() lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 @@ -652,6 +779,37 @@ def main(): model.to(device) model.eval() model.device = device + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # import pdb; pdb.set_trace() + # only load the neural network LM if doing shallow fusion + if params.use_shallow_fusion: + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + + num_param = sum([p.numel() for p in LM.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + else: + LM = None if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) @@ -684,6 +842,9 @@ def main(): model=model, lexicon=lexicon, decoding_graph=decoding_graph, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + LM=LM, ) save_results( params=params, diff --git a/icefall/lm_wrapper.py b/icefall/lm_wrapper.py index 0468befd0..5e2783a47 100644 --- a/icefall/lm_wrapper.py +++ b/icefall/lm_wrapper.py @@ -50,7 +50,7 @@ class LmScorer(torch.nn.Module): def add_arguments(cls, parser): # LM general arguments parser.add_argument( - "--vocab-size", + "--lm-vocab-size", type=int, default=500, ) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index 
f75a89590..cc566bd92 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -33,7 +33,7 @@ import torch from dataset import get_dataloader from model import RnnLmModel -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, setup_logger, str2bool @@ -49,6 +49,7 @@ def get_parser(): help="It specifies the checkpoint to use for decoding." "Note: Epoch counts from 0.", ) + parser.add_argument( "--avg", type=int, @@ -58,6 +59,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + parser.add_argument( "--exp-dir", type=str, @@ -154,7 +165,14 @@ def main(): params = AttributeDict(vars(args)) - setup_logger(f"{params.exp_dir}/log-ppl/") + if params.iter > 0: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-iter-{params.iter}-avg-{params.avg}" + ) + else: + setup_logger( + f"{params.exp_dir}/log-ppl/log-ppl-epoch-{params.epoch}-avg-{params.avg}" + ) logging.info("Computing perplexity started") logging.info(params) @@ -173,19 +191,39 @@ def main(): tie_weights=params.tie_weights, ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): - if start >= 0: + if i >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) num_param_requires_grad = sum( diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index 2411cb1f0..a8598a1ce 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -25,7 +25,7 @@ from pathlib import Path import torch from model import RnnLmModel -from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint from icefall.utils import AttributeDict, load_averaged_model, str2bool @@ -51,6 +51,16 @@ def get_parser(): "'--epoch'. ", ) + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + parser.add_argument( "--vocab-size", type=int, @@ -133,11 +143,36 @@ def main(): model.to(device) - if params.avg == 1: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - model = load_averaged_model( - params.exp_dir, model, params.epoch, params.avg, device + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False ) model.to("cpu") diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index f43e66cd2..91df4f921 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -49,6 +49,7 @@ from torch.utils.tensorboard import SummaryWriter from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import save_checkpoint_with_global_batch_idx from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -178,6 +179,33 @@ def get_parser(): help="The seed for random generators intended for reproducibility", ) + parser.add_argument( + "--lr", + type=float, + default=1e-3, + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=200, + help="""Maximum number of tokens in a sentence. This is used + to adjust batch-size dynamically""", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + return parser @@ -190,16 +218,15 @@ def get_params() -> AttributeDict: "sos_id": 1, "eos_id": 1, "blank_id": 0, - "lr": 1e-3, "weight_decay": 1e-6, "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 200, + "log_interval": 100, "reset_interval": 2000, - "valid_interval": 5000, + "valid_interval": 200, "env_info": get_env_info(), } ) @@ -382,6 +409,7 @@ def train_one_epoch( valid_dl: torch.utils.data.DataLoader, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, + rank: int = 0, ) -> None: """Train the model for one epoch. 
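The export.py and compute_perplexity.py changes above give the RNN LM the same checkpoint-averaging options as the ASR recipes. Conceptually, averaging checkpoints is an element-wise mean of the saved parameter tensors; the sketch below illustrates the idea only and is not the icefall.checkpoint implementation.

from typing import Dict, List

import torch


def average_state_dicts(
    state_dicts: List[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    # element-wise mean of several model state dicts (simplified; assumes all
    # entries are floating-point tensors with identical keys and shapes)
    n = len(state_dicts)
    avg = {k: v.clone() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k]
    return {k: v / n for k, v in avg.items()}


# e.g. averaging epoch-29.pt and epoch-30.pt amounts to loading their saved
# model state dicts and passing both to this function
a = {"w": torch.tensor([1.0, 2.0])}
b = {"w": torch.tensor([3.0, 4.0])}
print(average_state_dicts([a, b])["w"])  # tensor([2., 3.])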
@@ -430,6 +458,19 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + params=params, + optimizer=optimizer, + rank=rank, + ) + if batch_idx % params.log_interval == 0: # Note: "frames" here means "num_tokens" this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) @@ -580,6 +621,7 @@ def run(rank, world_size, args): valid_dl=valid_dl, tb_writer=tb_writer, world_size=world_size, + rank=rank, ) save_checkpoint( From c90f57afdbc8116944771aab1c3a42217a53eeec Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 4 Apr 2023 11:04:00 +0800 Subject: [PATCH 12/17] Remove simulate streaming from stateless8 (#985) --- .../pruned_transducer_stateless8/decode.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 7b651a632..e07777c9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -301,29 +301,6 @@ def get_parser(): fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="""Whether to simulate streaming in decoding, this is a good way to - test a streaming model. - """, - ) - - parser.add_argument( - "--decode-chunk-size", - type=int, - default=16, - help="The chunk size for decoding (in frames after subsampling)", - ) - - parser.add_argument( - "--left-context", - type=int, - default=64, - help="left context can be seen during decoding (in frames after subsampling)", - ) - add_model_arguments(parser) return parser @@ -378,22 +355,7 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - if params.simulate_streaming: - feature_lens += params.left_context - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.left_context), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( - x=feature, - x_lens=feature_lens, - chunk_size=params.decode_chunk_size, - left_context=params.left_context, - simulate_streaming=True, - ) - else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) hyps = [] @@ -651,10 +613,6 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.simulate_streaming: - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" - params.suffix += f"-left-context-{params.left_context}" - if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" @@ -690,11 +648,6 @@ def main(): params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() - if params.simulate_streaming: - assert ( - params.causal_convolution - ), "Decoding in streaming requires causal convolution" - logging.info(params) logging.info("About to create model") From 136aa94d5757ef02654612f0b19fb5d25d5eda39 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Thu, 6 Apr 2023 17:47:33 +0800 Subject: [PATCH 13/17] remove 
duplicated lines (#988) --- .../ASR/pruned_transducer_stateless3/onnx_pretrained.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 5adb6c16a..3947aa102 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -403,9 +403,8 @@ def main(): text += symbol_table[i] return text.replace("▁", " ").strip() - context_size = model.context_size for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp[context_size:]) + words = token_ids_to_words(hyp) s += f"{filename}:\n{words}\n" logging.info(s) From 6434c8eadc0d4326e1db69824cf0e40dc9a71c8a Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Sun, 9 Apr 2023 20:53:47 +0800 Subject: [PATCH 14/17] Add averaged model && change start from 0 to 1 && fix typo for gigaspeech (#990) * Add averaged model && change start from 0 to 1 && fix typo * Update train.py * Set use-averaged-model False for BC --------- Co-authored-by: yifanyang --- .../pruned_transducer_stateless2/decode.py | 181 ++++++++++++------ .../ASR/pruned_transducer_stateless2/train.py | 79 ++++++-- 2 files changed, 184 insertions(+), 76 deletions(-) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index ee694a9e0..72f74c968 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -19,40 +19,40 @@ Usage: (1) greedy search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 (4) fast beam search ./pruned_transducer_stateless2/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 """ @@ -76,12 +76,17 @@ from beam_search import ( ) from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + 
average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -94,9 +99,9 @@ def get_parser(): parser.add_argument( "--epoch", type=int, - default=29, + default=30, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -119,6 +124,17 @@ def get_parser(): "'--epoch' and '--iter'", ) + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + parser.add_argument( "--exp-dir", type=str, @@ -464,6 +480,9 @@ def main(): params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -476,7 +495,7 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -486,37 +505,85 @@ def main(): logging.info("About to create model") model = get_transducer_model(params) - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - 
model.load_state_dict(average_checkpoints(filenames, device=device)) + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() - model.device = device if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 6a9f9f32f..578bd9218 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -42,6 +42,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" import argparse +import copy import logging import warnings from pathlib import Path @@ -70,7 +71,10 @@ from torch.utils.tensorboard import SummaryWriter from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool @@ -114,10 +118,10 @@ def get_parser(): parser.add_argument( "--start-epoch", type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - transducer_stateless2/exp/epoch-{start_epoch-1}.pt + default=1, + help="""Resume training from this epoch. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt """, ) @@ -240,7 +244,7 @@ def get_parser(): parser.add_argument( "--keep-last-k", type=int, - default=20, + default=30, help="""Only keep this number of checkpoints on disk. For instance, if it is 3, there are only 3 checkpoints in the exp-dir with filenames `checkpoint-xxx.pt`. @@ -248,6 +252,19 @@ def get_parser(): """, ) + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. 
`model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + parser.add_argument( "--use-fp16", type=str2bool, @@ -385,6 +402,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def load_checkpoint_if_available( params: AttributeDict, model: nn.Module, + model_avg: nn.Module = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, ) -> Optional[Dict[str, Any]]: @@ -392,7 +410,7 @@ def load_checkpoint_if_available( If params.start_batch is positive, it will load the checkpoint from `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from + params.start_epoch is larger than 1, it will load the checkpoint from `params.start_epoch - 1`. Apart from loading state dict for `model` and `optimizer` it also updates @@ -404,6 +422,8 @@ def load_checkpoint_if_available( The return value of :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer that we are using. scheduler: @@ -413,7 +433,7 @@ def load_checkpoint_if_available( """ if params.start_batch > 0: filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: + elif params.start_epoch > 1: filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" else: return None @@ -423,6 +443,7 @@ def load_checkpoint_if_available( saved_params = load_checkpoint( filename, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, ) @@ -449,7 +470,8 @@ def load_checkpoint_if_available( def save_checkpoint( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[LRSchedulerType] = None, sampler: Optional[CutSampler] = None, @@ -463,6 +485,8 @@ def save_checkpoint( It is returned by :func:`get_params`. model: The training model. + model_avg: + The stored model averaged from the start of training. optimizer: The optimizer used in the training. sampler: @@ -476,6 +500,7 @@ def save_checkpoint( save_checkpoint_impl( filename=filename, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -495,14 +520,14 @@ def save_checkpoint( def compute_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, batch: dict, is_training: bool, warmup: float = 1.0, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute transducer loss given the model and its inputs. 
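# A scalar sketch (not part of the patch) of the running average described in
# the --average-period help string above; `w` stands in for a single model
# parameter and `w_avg` for the matching entry of model_avg.
def update_average(w: float, w_avg: float, batch_idx_train: int, average_period: int = 200) -> float:
    weight = average_period / batch_idx_train
    return w * weight + w_avg * (1.0 - weight)

# With average_period=200 and batch_idx_train=600 the current parameter
# contributes one third and the stored average the remaining two thirds,
# so w=0.9 and w_avg=0.6 combine to 0.7.  Because every checkpoint stores
# this running average, the new --use-averaged-model mode in decode.py only
# needs to load the checkpoints at `epoch-avg` and `epoch` to recover the
# average over that epoch range.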
Args: params: @@ -568,7 +593,7 @@ def compute_loss( def compute_validation_loss( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], sp: spm.SentencePieceProcessor, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, @@ -602,13 +627,14 @@ def compute_validation_loss( def train_one_epoch( params: AttributeDict, - model: nn.Module, + model: Union[nn.Module, DDP], optimizer: torch.optim.Optimizer, scheduler: LRSchedulerType, sp: spm.SentencePieceProcessor, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, scaler: GradScaler, + model_avg: Optional[nn.Module] = None, tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, rank: int = 0, @@ -634,6 +660,8 @@ def train_one_epoch( Dataloader for the validation dataset. scaler: The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. tb_writer: Writer to write log messages to tensorboard. world_size: @@ -660,6 +688,7 @@ def train_one_epoch( loss, loss_info = compute_loss( params=params, model=model, + model_avg=model_avg, sp=sp, batch=batch, is_training=True, @@ -688,6 +717,7 @@ def train_one_epoch( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, model=model, + model_avg=model_avg, params=params, optimizer=optimizer, scheduler=scheduler, @@ -791,7 +821,16 @@ def run(rank, world_size, args): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - checkpoints = load_checkpoint_if_available(params=params, model=model) + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) model.to(device) if world_size > 1: @@ -850,10 +889,10 @@ def run(rank, world_size, args): logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) if tb_writer is not None: tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) @@ -863,6 +902,7 @@ def run(rank, world_size, args): train_one_epoch( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sp=sp, @@ -881,6 +921,7 @@ def run(rank, world_size, args): save_checkpoint( params=params, model=model, + model_avg=model_avg, optimizer=optimizer, scheduler=scheduler, sampler=train_dl.sampler, @@ -896,7 +937,7 @@ def run(rank, world_size, args): def scan_pessimistic_batches_for_oom( - model: nn.Module, + model: Union[nn.Module, DDP], train_dl: torch.utils.data.DataLoader, optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, From 33578cca48613c04c227cd22a2d2fdd207a5e928 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:12:05 +0800 Subject: [PATCH 15/17] Fix filter_cuts in compute_fbank_librispeech.py (#993) --- egs/librispeech/ASR/local/compute_fbank_librispeech.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 745eaf1e8..554d7f109 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -121,10 +121,10 @@ def compute_fbank_librispeech( recordings=m["recordings"], supervisions=m["supervisions"], ) - if bpe_model: - cut_set = filter_cuts(cut_set, sp) if "train" in partition: + if bpe_model: + cut_set = filter_cuts(cut_set, sp) cut_set = ( cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) ) From 3cb0a0121ba7ad934afb7b9e59045995e33e77b6 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 20:56:40 +0800 Subject: [PATCH 16/17] Add Common Voice (#994) * Add commonvoice * Add data preparation recipe * Updata * update prepare.sh * Fix for black * Update prefix with cv- * 20 -> * Update compute_fbank_commonvoice_dev_test.py * Update prepare.sh * Update compute_fbank_commonvoice_dev_test.py --- .../compute_fbank_commonvoice_dev_test.py | 107 ++++++++++++ .../local/compute_fbank_commonvoice_splits.py | 157 ++++++++++++++++++ .../ASR/local/compute_fbank_musan.py | 1 + egs/commonvoice/ASR/local/filter_cuts.py | 1 + .../ASR/local/preprocess_commonvoice.py | 119 +++++++++++++ egs/commonvoice/ASR/prepare.sh | 156 +++++++++++++++++ egs/commonvoice/ASR/shared | 1 + 7 files changed, 542 insertions(+) create mode 100755 egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py create mode 100755 egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py create mode 120000 egs/commonvoice/ASR/local/compute_fbank_musan.py create mode 120000 egs/commonvoice/ASR/local/filter_cuts.py create mode 100755 egs/commonvoice/ASR/local/preprocess_commonvoice.py create mode 100755 egs/commonvoice/ASR/prepare.sh create mode 120000 egs/commonvoice/ASR/shared diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py new file mode 100755 index 000000000..c8f9b6ccb --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_dev_test.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the CommonVoice dataset. +It looks for manifests in the directory data/${lang}/manifests. + +The generated fbank features are saved in data/${lang}/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. 
+# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_dev_test(language: str): + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") + num_workers = 42 + batch_duration = 600 + + subsets = ("dev", "test") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + + logging.info(f"device: {device}") + + for partition in subsets: + cuts_path = output_dir / f"cv-{language}_cuts_{partition}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{partition}_raw.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{partition}", + num_workers=num_workers, + batch_duration=batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_dev_test(language=args.language) diff --git a/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py new file mode 100755 index 000000000..f8c09ccf3 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_commonvoice_splits.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from datetime import datetime +from pathlib import Path + +import torch +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomChunkyWriter, + set_audio_duration_mismatch_tolerance, + set_caching_enabled, +) + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + parser.add_argument( + "--num-workers", + type=int, + default=20, + help="Number of dataloading workers used for reading the audio.", + ) + + parser.add_argument( + "--batch-duration", + type=float, + default=600.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + + parser.add_argument( + "--num-splits", + type=int, + required=True, + help="The number of splits of the train subset", + ) + + parser.add_argument( + "--start", + type=int, + default=0, + help="Process pieces starting from this number (inclusive).", + ) + + parser.add_argument( + "--stop", + type=int, + default=-1, + help="Stop processing pieces until this number (exclusive).", + ) + + return parser.parse_args() + + +def compute_fbank_commonvoice_splits(args): + subset = "train" + num_splits = args.num_splits + language = args.language + output_dir = f"data/{language}/fbank/{subset}_split_{num_splits}" + output_dir = Path(output_dir) + assert output_dir.exists(), f"{output_dir} does not exist!" + + num_digits = len(str(num_splits)) + + start = args.start + stop = args.stop + if stop < start: + stop = num_splits + + stop = min(stop, num_splits) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device)) + logging.info(f"device: {device}") + + set_audio_duration_mismatch_tolerance(0.01) # 10ms tolerance + set_caching_enabled(False) + for i in range(start, stop): + idx = f"{i + 1}".zfill(num_digits) + logging.info(f"Processing {idx}/{num_splits}") + + cuts_path = output_dir / f"cv-{language}_cuts_{subset}.{idx}.jsonl.gz" + if cuts_path.is_file(): + logging.info(f"{cuts_path} exists - skipping") + continue + + raw_cuts_path = output_dir / f"cv-{language}_cuts_{subset}_raw.{idx}.jsonl.gz" + + logging.info(f"Loading {raw_cuts_path}") + cut_set = CutSet.from_file(raw_cuts_path) + + logging.info("Splitting cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + + logging.info("Computing features") + cut_set = cut_set.compute_and_store_features_batch( + extractor=extractor, + storage_path=f"{output_dir}/cv-{language}_feats_{subset}_{idx}", + num_workers=args.num_workers, + batch_duration=args.batch_duration, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + logging.info(f"Saving to {cuts_path}") + cut_set.to_file(cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_commonvoice_splits(args) + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/local/compute_fbank_musan.py b/egs/commonvoice/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/commonvoice/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/commonvoice/ASR/local/filter_cuts.py b/egs/commonvoice/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/commonvoice/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ 
No newline at end of file diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py new file mode 100755 index 000000000..22f7aa03c --- /dev/null +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from typing import Optional + +from lhotse import CutSet, SupervisionSegment +from lhotse.recipes.utils import read_manifests_if_cached + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--language", + type=str, + help="""Language of Common Voice""", + ) + + return parser.parse_args() + + +def preprocess_commonvoice( + language: str, + dataset: Optional[str] = None, +): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + output_dir.mkdir(exist_ok=True) + + if dataset is None: + dataset_parts = ( + "dev", + "test", + "train", + ) + else: + dataset_parts = dataset.split(" ", -1) + + logging.info("Loading manifest") + prefix = f"cv-{language}" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + prefix=prefix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + for partition, m in manifests.items(): + logging.info(f"Processing {partition}") + raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}" + if raw_cuts_path.is_file(): + logging.info(f"{partition} already exists - skipping") + continue + + # Create long-recording cut manifests. + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ).resample(16000) + + # Run data augmentation that needs to be done in the + # time domain. 
+ if "train" in partition: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 2 minutes and saving may take 7 minutes)" + ) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + logging.info(f"Saving to {raw_cuts_path}") + cut_set.to_file(raw_cuts_path) + + +def main(): + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + preprocess_commonvoice( + language=args.language, + dataset=args.dataset, + ) + logging.info("Done") + + +if __name__ == "__main__": + main() diff --git a/egs/commonvoice/ASR/prepare.sh b/egs/commonvoice/ASR/prepare.sh new file mode 100755 index 000000000..9a28167b1 --- /dev/null +++ b/egs/commonvoice/ASR/prepare.sh @@ -0,0 +1,156 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=16 +stage=-1 +stop_stage=100 + +# Split data/${lang}set to this number of pieces +# This is to avoid OOM during feature extraction. +num_splits=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/$release/$lang +# This directory contains the following files downloaded from +# https://mozilla-common-voice-datasets.s3.dualstack.us-west-2.amazonaws.com/${release}/${release}-${lang}.tar.gz +# +# - clips +# - dev.tsv +# - invalidated.tsv +# - other.tsv +# - reported.tsv +# - test.tsv +# - train.tsv +# - validated.tsv +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download +release=cv-corpus-13.0-2023-03-09 +lang=en + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/${lang}/lang_bpe_xxx, +# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + # 1000 + 500 +) + +# All files generated by this script are saved in "data/${lang}". +# You can safely remove "data/${lang}" and rerun this script to regenerate it. +mkdir -p data/${lang} + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/$release, + # you can create a symlink + # + # ln -sfv /path/to/$release $dl_dir/$release + # + if [ ! -d $dl_dir/$release/$lang/clips ]; then + lhotse download commonvoice --languages $lang --release $release $dl_dir + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare CommonVoice manifest" + # We assume that you have downloaded the CommonVoice corpus + # to $dl_dir/$release + mkdir -p data/${lang}/manifests + if [ ! 
-e data/${lang}/manifests/.cv-${lang}.done ]; then + lhotse prepare commonvoice --language $lang -j $nj $dl_dir/$release data/${lang}/manifests + touch data/${lang}/manifests/.cv-${lang}.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Preprocess CommonVoice manifest" + if [ ! -e data/${lang}/fbank/.preprocess_complete ]; then + ./local/preprocess_commonvoice.py --language $lang + touch data/${lang}/fbank/.preprocess_complete + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for dev and test subsets of CommonVoice" + mkdir -p data/${lang}/fbank + if [ ! -e data/${lang}/fbank/.cv-${lang}_dev_test.done ]; then + ./local/compute_fbank_commonvoice_dev_test.py --language $lang + touch data/${lang}/fbank/.cv-${lang}_dev_test.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Split train subset into ${num_splits} pieces" + split_dir=data/${lang}/fbank/train_split_${num_splits} + if [ ! -e $split_dir/.cv-${lang}_train_split.done ]; then + lhotse split $num_splits ./data/${lang}/fbank/cv-${lang}_cuts_train_raw.jsonl.gz $split_dir + touch $split_dir/.cv-${lang}_train_split.done + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Compute features for train subset of CommonVoice" + if [ ! -e data/${lang}/fbank/.cv-${lang}_train.done ]; then + ./local/compute_fbank_commonvoice_splits.py \ + --num-workers $nj \ + --batch-duration 600 \ + --start 0 \ + --num-splits $num_splits \ + --language $lang + touch data/${lang}/fbank/.cv-${lang}_train.done + fi +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi diff --git a/egs/commonvoice/ASR/shared b/egs/commonvoice/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/commonvoice/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file From dbf2aa3212a47650c85da574a4a2ed61773d6be3 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Tue, 11 Apr 2023 21:04:54 +0800 Subject: [PATCH 17/17] Create preprocess_commonvoice.py (#996) --- egs/commonvoice/ASR/local/preprocess_commonvoice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/commonvoice/ASR/local/preprocess_commonvoice.py b/egs/commonvoice/ASR/local/preprocess_commonvoice.py index 22f7aa03c..241e81963 100755 --- a/egs/commonvoice/ASR/local/preprocess_commonvoice.py +++ b/egs/commonvoice/ASR/local/preprocess_commonvoice.py @@ -46,8 +46,8 @@ def preprocess_commonvoice( language: str, dataset: Optional[str] = None, ): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") + src_dir = Path(f"data/{language}/manifests") + output_dir = Path(f"data/{language}/fbank") output_dir.mkdir(exist_ok=True) if dataset is None: