From cd3bf0a1fd25cad1f26dfce6fc4b3b3142e13423 Mon Sep 17 00:00:00 2001
From: luomingshuang <739314837@qq.com>
Date: Mon, 16 May 2022 20:53:19 +0800
Subject: [PATCH] do some changes

---
 egs/aidatatang_200zh/ASR/README.md            |  39 ++
 egs/aidatatang_200zh/ASR/RESULTS.md           |  72 ++++
 .../ASR/local/process_aidatatang_200zh.py     |  72 ----
 egs/aidatatang_200zh/ASR/local/text2token.py  |   2 -
 .../asr_datamodule.py                         |   2 +-
 .../beam_search.py                            | 299 ++++++++++++---
 .../pruned_transducer_stateless2/decode.py    |  74 ++--
 .../pruned_transducer_stateless2/export.py    |   6 +-
 .../pretrained.py                             | 347 ++++++++++++++++++
 .../ASR/pruned_transducer_stateless2/train.py |   2 +-
 10 files changed, 739 insertions(+), 176 deletions(-)
 create mode 100644 egs/aidatatang_200zh/ASR/README.md
 create mode 100644 egs/aidatatang_200zh/ASR/RESULTS.md
 delete mode 100755 egs/aidatatang_200zh/ASR/local/process_aidatatang_200zh.py
 create mode 100644 egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py

diff --git a/egs/aidatatang_200zh/ASR/README.md b/egs/aidatatang_200zh/ASR/README.md
new file mode 100644
index 000000000..fd72767c8
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/README.md
@@ -0,0 +1,39 @@
+Note: This recipe was trained with the code from this PR: https://github.com/k2-fsa/icefall/pull/355
+and the SpecAugment code from this PR: https://github.com/lhotse-speech/lhotse/pull/604.
+# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall.
+The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
+## Training procedure
+The main repositories are listed below; the training and decoding scripts will be updated as new versions of these projects are released.
+k2: https://github.com/k2-fsa/k2
+icefall: https://github.com/k2-fsa/icefall
+lhotse: https://github.com/lhotse-speech/lhotse
+* Install k2 and lhotse. The k2 installation guide is at https://k2.readthedocs.io/en/latest/installation/index.html and the lhotse one is at https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall.
+* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit mentioned above.
+```
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+```
+* Prepare the data.
+```
+cd egs/aidatatang_200zh/ASR
+bash ./prepare.sh
+```
+* Train the model.
+```
+export CUDA_VISIBLE_DEVICES="0,1"
+./pruned_transducer_stateless2/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --max-duration 250
+```
+## Evaluation results
+The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below. They were obtained by averaging the models from epoch 11 to 29.
+The WERs are
+|                                    | dev  | test | comment                                   |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search                      | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100  |
+| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100  |
+| fast beam search (set as default)  | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500 |
diff --git a/egs/aidatatang_200zh/ASR/RESULTS.md b/egs/aidatatang_200zh/ASR/RESULTS.md
new file mode 100644
index 000000000..ab7b9cb92
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/RESULTS.md
@@ -0,0 +1,72 @@
+## Results
+
+### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)
+
+#### 2022-05-16
+
+Using the code from this PR: https://github.com/k2-fsa/icefall/pull/355.
+
+The WERs are
+
+|                                    | dev  | test | comment                                   |
+|------------------------------------|------------|------------|------------------------------------------|
+| greedy search                      | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100  |
+| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100  |
+| fast beam search (set as default)  | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500 |
+
+The training command for reproducing these results is given below:
+
+```
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./pruned_transducer_stateless2/train.py \
+  --world-size 2 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --lang-dir data/lang_char \
+  --max-duration 250 \
+  --save-every-n 1000
+
+```
+
+The tensorboard training log can be found at
+https://tensorboard.dev/experiment/VpA8b7SZQ7CEjZs9WZ5HNA/#scalars
+
+The decoding command is:
+```
+epoch=29
+avg=19
+
+## greedy search
+./pruned_transducer_stateless2/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 100
+
+## modified beam search
+./pruned_transducer_stateless2/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 100 \
+  --decoding-method modified_beam_search \
+  --beam-size 4
+
+## fast beam search
+./pruned_transducer_stateless2/decode.py \
+  --epoch $epoch \
+  --avg $avg \
+  --exp-dir ./pruned_transducer_stateless2/exp \
+  --lang-dir ./data/lang_char \
+  --max-duration 1500 \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8
+```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/aidatatang_200zh/ASR/local/process_aidatatang_200zh.py b/egs/aidatatang_200zh/ASR/local/process_aidatatang_200zh.py
deleted file mode 100755
index 2c6951d42..000000000
--- a/egs/aidatatang_200zh/ASR/local/process_aidatatang_200zh.py
+++ /dev/null
@@ -1,72 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -import logging -from pathlib import Path - -from lhotse import CutSet -from lhotse.recipes.utils import read_manifests_if_cached - - -def preprocess_aidatatang_200zh(): - src_dir = Path("data/manifests/aidatatang_200zh") - output_dir = Path("data/fbank/aidatatang_200zh") - output_dir.mkdir(exist_ok=True, parents=True) - - dataset_parts = ( - "train", - "test", - "dev", - ) - - logging.info("Loading manifest") - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, - output_dir=src_dir, - ) - assert len(manifests) > 0 - - for partition, m in manifests.items(): - logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" - if raw_cuts_path.is_file(): - logging.info(f"{partition} already exists - skipping") - continue - - for sup in m["supervisions"]: - sup.custom = {"origin": "aidatatang_200zh"} - - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - - logging.info(f"Saving to {raw_cuts_path}") - cut_set.to_file(raw_cuts_path) - - -def main(): - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - logging.basicConfig(format=formatter, level=logging.INFO) - - preprocess_aidatatang_200zh() - - -if __name__ == "__main__": - main() diff --git a/egs/aidatatang_200zh/ASR/local/text2token.py b/egs/aidatatang_200zh/ASR/local/text2token.py index 06e1188ec..71be2a613 100755 --- a/egs/aidatatang_200zh/ASR/local/text2token.py +++ b/egs/aidatatang_200zh/ASR/local/text2token.py @@ -186,8 +186,6 @@ def main(): for z in a: a_flat.append("".join(z)) - # a_chars = [z.replace(" ", args.space) for z in a_flat] - a_chars = [z for z in a_flat] a_chars = "".join(a_flat) print(a_chars) line = f.readline() diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py index 5fb5b13db..de83cef5a 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -413,6 +413,6 @@ class Aidatatang_200zhAsrDataModule: return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz") @lru_cache() - def test_net_cuts(self) -> List[CutSet]: + def test_cuts(self) -> List[CutSet]: logging.info("About to get test cuts") return load_manifest(self.args.manifest_dir / "cuts_test.json.gz") diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py index 2e9bf3e0b..ba20911bd 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/beam_search.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from dataclasses import dataclass from typing import Dict, List, Optional @@ -21,11 +22,11 @@ import k2 import torch from model import Transducer -from icefall.decode import one_best_decoding +from icefall.decode import Nbest, one_best_decoding from icefall.utils import get_texts -def fast_beam_search( +def fast_beam_search_one_best( model: Transducer, decoding_graph: k2.Fsa, encoder_out: torch.Tensor, @@ -35,7 +36,8 @@ def fast_beam_search( max_contexts: int, ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. 
-
+    A lattice is first obtained using fast beam search, and then
+    the shortest path within the lattice is used as the final output.
     Args:
       model:
         An instance of `Transducer`.
@@ -55,6 +57,143 @@ def fast_beam_search(
     Returns:
       Return the decoded result.
     """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+
+    best_path = one_best_decoding(lattice)
+    hyps = get_texts(best_path)
+    return hyps
+
+
+def fast_beam_search_nbest_oracle(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+    num_paths: int,
+    ref_texts: List[List[int]],
+    use_double_scores: bool = True,
+    nbest_scale: float = 0.5,
+) -> List[List[int]]:
+    """It limits the maximum number of symbols per frame to 1.
+    A lattice is first obtained using fast beam search, and then
+    we select `num_paths` linear paths from the lattice. The path
+    that has the minimum edit distance with the given reference transcript
+    is used as the output.
+    This is the best result we can achieve for any nbest-based rescoring
+    method.
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding; may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+      num_paths:
+        Number of paths to extract from the decoded lattice.
+      ref_texts:
+        A list-of-list of integers containing the reference transcripts.
+        If the decoding_graph is a trivial_graph, the integer ID is the
+        BPE token ID.
+      use_double_scores:
+        True to use double precision for computation. False to use
+        single precision.
+      nbest_scale:
+        It's the scale applied to the lattice.scores. A smaller value
+        yields more unique paths.
+    Returns:
+      Return the decoded result.
+    """
+    lattice = fast_beam_search(
+        model=model,
+        decoding_graph=decoding_graph,
+        encoder_out=encoder_out,
+        encoder_out_lens=encoder_out_lens,
+        beam=beam,
+        max_states=max_states,
+        max_contexts=max_contexts,
+    )
+
+    nbest = Nbest.from_lattice(
+        lattice=lattice,
+        num_paths=num_paths,
+        use_double_scores=use_double_scores,
+        nbest_scale=nbest_scale,
+    )
+
+    hyps = nbest.build_levenshtein_graphs()
+    refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
+
+    levenshtein_alignment = k2.levenshtein_alignment(
+        refs=refs,
+        hyps=hyps,
+        hyp_to_ref_map=nbest.shape.row_ids(1),
+        sorted_match_ref=True,
+    )
+
+    tot_scores = levenshtein_alignment.get_tot_scores(
+        use_double_scores=False, log_semiring=False
+    )
+    ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+
+    max_indexes = ragged_tot_scores.argmax()
+
+    best_path = k2.index_fsa(nbest.fsa, max_indexes)
+
+    hyps = get_texts(best_path)
+    return hyps
+
+
+def fast_beam_search(
+    model: Transducer,
+    decoding_graph: k2.Fsa,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    beam: float,
+    max_states: int,
+    max_contexts: int,
+) -> k2.Fsa:
+    """It limits the maximum number of symbols per frame to 1.
+    Args:
+      model:
+        An instance of `Transducer`.
+      decoding_graph:
+        Decoding graph used for decoding; may be a TrivialGraph or an HLG.
+      encoder_out:
+        A tensor of shape (N, T, C) from the encoder.
+      encoder_out_lens:
+        A tensor of shape (N,) containing the number of frames in `encoder_out`
+        before padding.
+      beam:
+        Beam value, similar to the beam used in Kaldi.
+      max_states:
+        Max states per stream per frame.
+      max_contexts:
+        Max contexts per stream per frame.
+    Returns:
+      Return an FsaVec with axes [utt][state][arc] containing the decoded
+      lattice. Note: When the input graph is a TrivialGraph, the returned
+      lattice is actually an acceptor.
+    """
     assert encoder_out.ndim == 3

     context_size = model.decoder.context_size
@@ -103,9 +242,7 @@ def fast_beam_search(
     decoding_streams.terminate_and_flush_to_streams()

     lattice = decoding_streams.format_output(encoder_out_lens.tolist())
-    best_path = one_best_decoding(lattice)
-    hyps = get_texts(best_path)
-    return hyps
+    return lattice


 def greedy_search(
@@ -130,8 +267,9 @@ def greedy_search(

     blank_id = model.decoder.blank_id
     context_size = model.decoder.context_size
+    unk_id = getattr(model, "unk_id", blank_id)

-    device = model.device
+    device = next(model.parameters()).device

     decoder_input = torch.tensor(
         [blank_id] * context_size, device=device, dtype=torch.int64
@@ -170,7 +308,7 @@ def greedy_search(
             # logits is (1, 1, 1, vocab_size)

             y = logits.argmax().item()
-            if y != blank_id:
+            if y not in (blank_id, unk_id):
                 hyp.append(y)
                 decoder_input = torch.tensor(
                     [hyp[-context_size:]], device=device
@@ -190,7 +328,9 @@ def greedy_search(


 def greedy_search_batch(
-    model: Transducer, encoder_out: torch.Tensor
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
 ) -> List[List[int]]:
     """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
     Args:
       model:
         The transducer model.
       encoder_out:
         Output from the encoder. Its shape is (N, T, C), where N >= 1.
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing the number of valid frames in
+        encoder_out before padding.
     Returns:
       Return a list-of-list of token IDs containing the decoded results.
       len(ans) equals to encoder_out.size(0).
@@ -205,30 +348,49 @@ def greedy_search_batch( assert encoder_out.ndim == 3 assert encoder_out.size(0) >= 1, encoder_out.size(0) - device = model.device + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) - batch_size = encoder_out.size(0) - T = encoder_out.size(1) + device = next(model.parameters()).device blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(batch_size)] + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] decoder_input = torch.tensor( hyps, device=device, dtype=torch.int64, - ) # (batch_size, context_size) + ) # (N, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) - encoder_out = model.joiner.encoder_proj(encoder_out) + # decoder_out: (N, 1, decoder_out_dim) - # decoder_out: (batch_size, 1, decoder_out_dim) - for t in range(T): - current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + logits = model.joiner( current_encoder_out, decoder_out.unsqueeze(1), project_input=False ) @@ -239,12 +401,12 @@ def greedy_search_batch( y = logits.argmax(dim=1).tolist() emitted = False for i, v in enumerate(y): - if v != blank_id: + if v not in (blank_id, unk_id): hyps[i].append(v) emitted = True if emitted: # update decoder output - decoder_input = [h[-context_size:] for h in hyps] + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] decoder_input = torch.tensor( decoder_input, device=device, @@ -253,7 +415,12 @@ def greedy_search_batch( decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) - ans = [h[context_size:] for h in hyps] + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + return ans @@ -291,10 +458,8 @@ class HypothesisList(object): def add(self, hyp: Hypothesis) -> None: """Add a Hypothesis to `self`. - If `hyp` already exists in `self`, its probability is updated using `log-sum-exp` with the existed one. - Args: hyp: The hypothesis to be added. @@ -311,7 +476,6 @@ class HypothesisList(object): def get_most_probable(self, length_norm: bool = False) -> Hypothesis: """Get the most probable hypothesis, i.e., the one with the largest `log_prob`. - Args: length_norm: If True, the `log_prob` of a hypothesis is normalized by the @@ -328,10 +492,8 @@ class HypothesisList(object): def remove(self, hyp: Hypothesis) -> None: """Remove a given hypothesis. - Caution: `self` is modified **in-place**. - Args: hyp: The hypothesis to be removed from `self`. 
@@ -344,10 +506,8 @@ class HypothesisList(object):

     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
         """Remove all Hypotheses whose log_prob is less than threshold.
-
         Caution:
           `self` is not modified. Instead, a new HypothesisList is returned.
-
         Returns:
           Return a new HypothesisList containing all hypotheses from `self`
           with `log_prob` being greater than the given `threshold`.
@@ -385,7 +545,6 @@ class HypothesisList(object):

 def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
     """Return a ragged shape with axes [utt][num_hyps].
-
     Args:
       hyps:
         len(hyps) == batch_size. It contains the current hypothesis for
@@ -411,15 +570,18 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
     beam: int = 4,
 ) -> List[List[int]]:
     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
-
     Args:
       model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
+      encoder_out_lens:
+        A 1-D tensor of shape (N,), containing the number of valid frames in
+        encoder_out before padding.
      beam:
        Number of active paths during the beam search.
    Returns:
@@ -427,15 +589,26 @@ def modified_beam_search(
      for the i-th utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
-
-    batch_size = encoder_out.size(0)
-    T = encoder_out.size(1)
+    assert encoder_out.size(0) >= 1, encoder_out.size(0)

+    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+        input=encoder_out,
+        lengths=encoder_out_lens.cpu(),
+        batch_first=True,
+        enforce_sorted=False,
+    )

     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size
-    device = model.device

-    B = [HypothesisList() for _ in range(batch_size)]
-    for i in range(batch_size):
+    device = next(model.parameters()).device
+
+    batch_size_list = packed_encoder_out.batch_sizes.tolist()
+    N = encoder_out.size(0)
+    assert torch.all(encoder_out_lens > 0), encoder_out_lens
+    assert N == batch_size_list[0], (N, batch_size_list)
+
+    B = [HypothesisList() for _ in range(N)]
+    for i in range(N):
         B[i].add(
             Hypothesis(
                 ys=[blank_id] * context_size,
@@ -443,11 +616,20 @@ def modified_beam_search(
             )
         )

-    encoder_out = model.joiner.encoder_proj(encoder_out)
+    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

-    for t in range(T):
-        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+    offset = 0
+    finalized_B = []
+    for batch_size in batch_size_list:
+        start = offset
+        end = offset + batch_size
+        current_encoder_out = encoder_out.data[start:end]
+        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+        offset = end
+
+        finalized_B = B[batch_size:] + finalized_B
+        B = B[:batch_size]

         hyps_shape = _get_hyps_shape(B).to(device)

@@ -503,8 +685,10 @@ def modified_beam_search(

         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

-            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+                topk_token_indexes = (topk_indexes % vocab_size).tolist()

             for k in range(len(topk_hyp_indexes)):
                 hyp_idx = topk_hyp_indexes[k]
@@ -512,15 +696,21 @@ def modified_beam_search(

                 new_ys = hyp.ys[:]
                 new_token = topk_token_indexes[k]
-                if new_token != blank_id:
+                if new_token not in (blank_id, unk_id):
                     new_ys.append(new_token)
                 new_log_prob = topk_log_probs[k]
                 new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
                 B[i].add(new_hyp)

-    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
-    ans = [h.ys[context_size:] for h in best_hyps]
+    B = B + finalized_B
+    best_hyps = [b.get_most_probable(length_norm=False) for b in B]
+
+    sorted_ans = [h.ys[context_size:] for h in best_hyps]
+    ans = []
+    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+    for i in range(N):
+        ans.append(sorted_ans[unsorted_indices[i]])

     return ans


@@ -531,12 +721,9 @@ def _deprecated_modified_beam_search(
     beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
-
     It decodes only one utterance at a time. We keep it only for reference.
     The function :func:`modified_beam_search` should be preferred as it
     supports batch decoding.
-
-
     Args:
       model:
         An instance of `Transducer`.
@@ -553,9 +740,10 @@ def _deprecated_modified_beam_search(
     # support only batch_size == 1 for now
     assert encoder_out.size(0) == 1, encoder_out.size(0)
     blank_id = model.decoder.blank_id
+    unk_id = getattr(model, "unk_id", blank_id)
     context_size = model.decoder.context_size

-    device = model.device
+    device = next(model.parameters()).device

     T = encoder_out.size(1)

@@ -614,14 +802,16 @@ def _deprecated_modified_beam_search(
             topk_hyp_indexes = topk_indexes // logits.size(-1)
             topk_token_indexes = topk_indexes % logits.size(-1)

-            topk_hyp_indexes = topk_hyp_indexes.tolist()
-            topk_token_indexes = topk_token_indexes.tolist()
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                topk_hyp_indexes = topk_hyp_indexes.tolist()
+                topk_token_indexes = topk_token_indexes.tolist()

         for i in range(len(topk_hyp_indexes)):
             hyp = A[topk_hyp_indexes[i]]
             new_ys = hyp.ys[:]
             new_token = topk_token_indexes[i]
-            if new_token != blank_id:
+            if new_token not in (blank_id, unk_id):
                 new_ys.append(new_token)
             new_log_prob = topk_log_probs[i]
             new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@@ -640,9 +830,7 @@ def beam_search(
 ) -> List[int]:
     """
     It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
-
     espnet/nets/beam_search_transducer.py#L247 is used as a reference.
-
     Args:
       model:
         An instance of `Transducer`.
@@ -658,9 +846,10 @@ def beam_search( # support only batch_size == 1 for now assert encoder_out.size(0) == 1, encoder_out.size(0) blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) context_size = model.decoder.context_size - device = model.device + device = next(model.parameters()).device decoder_input = torch.tensor( [blank_id] * context_size, @@ -743,7 +932,7 @@ def beam_search( # Second, process other non-blank labels values, indices = log_prob.topk(beam + 1) for i, v in zip(indices.tolist(), values.tolist()): - if i == blank_id: + if i in (blank_id, unk_id): continue new_ys = y_star.ys + [i] new_log_prob = y_star.log_prob + v diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py index 05f1ead69..83a442b90 100755 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/decode.py @@ -59,10 +59,10 @@ from typing import Dict, List, Optional, Tuple import k2 import torch import torch.nn as nn -from asr_datamodule import WenetSpeechAsrDataModule +from asr_datamodule import Aidatatang_200zhAsrDataModule from beam_search import ( beam_search, - fast_beam_search, + fast_beam_search_one_best, greedy_search, greedy_search_batch, modified_beam_search, @@ -255,7 +255,7 @@ def decode_one_batch( hyps = [] if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search( + hyp_tokens = fast_beam_search_one_best( model=model, decoding_graph=decoding_graph, encoder_out=encoder_out, @@ -273,6 +273,7 @@ def decode_one_batch( hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) @@ -280,6 +281,7 @@ def decode_one_batch( hyp_tokens = modified_beam_search( model=model, encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, beam=params.beam_size, ) for i in range(encoder_out.size(0)): @@ -359,12 +361,12 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 100 else: - log_interval = 2 + log_interval = 50 results = defaultdict(list) for batch_idx, batch in enumerate(dl): texts = batch["supervisions"]["text"] - texts = [list(str(text)) for text in texts] + texts = [list(str(text).replace(" ", "")) for text in texts] hyps_dict = decode_one_batch( params=params, @@ -440,7 +442,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - WenetSpeechAsrDataModule.add_arguments(parser) + Aidatatang_200zhAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -506,6 +508,13 @@ def main(): model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) + average = average_checkpoints(filenames, device=device) + checkpoint = {"model": average} + torch.save( + checkpoint, + "pruned_transducer_stateless2/pretrained_average_11_to_29.pt", + ) + model.to(device) model.eval() model.device = device @@ -526,33 +535,26 @@ def main(): from lhotse import CutSet from lhotse.dataset.webdataset import export_to_webdataset - wenetspeech = WenetSpeechAsrDataModule(args) + aidatatang_200zh = Aidatatang_200zhAsrDataModule(args) dev = "dev" - test_net = "test_net" - test_meet = "test_meet" + test = "test" if not os.path.exists(f"{dev}/shared-0.tar"): - dev_cuts = wenetspeech.valid_cuts() + os.makedirs(dev) + dev_cuts = aidatatang_200zh.valid_cuts() 
         export_to_webdataset(
             dev_cuts,
             output_path=f"{dev}/shared-%d.tar",
             shard_size=300,
         )

-    if not os.path.exists(f"{test_net}/shared-0.tar"):
-        test_net_cuts = wenetspeech.test_net_cuts()
+    if not os.path.exists(f"{test}/shared-0.tar"):
+        os.makedirs(test)
+        test_cuts = aidatatang_200zh.test_cuts()
         export_to_webdataset(
-            test_net_cuts,
-            output_path=f"{test_net}/shared-%d.tar",
-            shard_size=300,
-        )
-
-    if not os.path.exists(f"{test_meet}/shared-0.tar"):
-        test_meeting_cuts = wenetspeech.test_meeting_cuts()
-        export_to_webdataset(
-            test_meeting_cuts,
-            output_path=f"{test_meet}/shared-%d.tar",
+            test_cuts,
+            output_path=f"{test}/shared-%d.tar",
             shard_size=300,
         )

@@ -567,34 +569,22 @@ def main():
         shuffle_shards=True,
     )

-    test_net_shards = [
+    test_shards = [
         str(path)
-        for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
+        for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
     ]
-    cuts_test_net_webdataset = CutSet.from_webdataset(
-        test_net_shards,
+    cuts_test_webdataset = CutSet.from_webdataset(
+        test_shards,
         split_by_worker=True,
         split_by_node=True,
         shuffle_shards=True,
     )

-    test_meet_shards = [
-        str(path)
-        for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
-    ]
-    cuts_test_meet_webdataset = CutSet.from_webdataset(
-        test_meet_shards,
-        split_by_worker=True,
-        split_by_node=True,
-        shuffle_shards=True,
-    )
+    dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
+    test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)

-    dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
-    test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
-    test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
-
-    test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
-    test_dl = [dev_dl, test_net_dl, test_meeting_dl]
+    test_sets = ["dev", "test"]
+    test_dl = [dev_dl, test_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
index c72aa3464..43033e517 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py
@@ -21,8 +21,8 @@ Usage:
 ./pruned_transducer_stateless2/export.py \
   --exp-dir ./pruned_transducer_stateless2/exp \
   --lang-dir data/lang_char \
-  --epoch 20 \
-  --avg 10
+  --epoch 29 \
+  --avg 19

 It will generate a file exp_dir/pretrained.pt

@@ -32,7 +32,7 @@ you can do:

     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt

-    cd /path/to/egs/wenetspeech/ASR
+    cd /path/to/egs/aidatatang_200zh/ASR
     ./pruned_transducer_stateless2/decode.py \
         --exp-dir ./pruned_transducer_stateless2/exp \
         --epoch 9999 \
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
new file mode 100644
index 000000000..eb5e6b0d4
--- /dev/null
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/pretrained.py
@@ -0,0 +1,347 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./pruned_transducer_stateless2/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+  --lang-dir ./data/lang_char \
+  --decoding-method greedy_search \
+  --max-sym-per-frame 1 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) modified beam search
+./pruned_transducer_stateless2/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+  --lang-dir ./data/lang_char \
+  --decoding-method modified_beam_search \
+  --beam-size 4 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(3) fast beam search
+./pruned_transducer_stateless2/pretrained.py \
+  --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
+  --lang-dir ./data/lang_char \
+  --decoding-method fast_beam_search \
+  --beam 4 \
+  --max-contexts 4 \
+  --max-states 8 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
+
+Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
+./pruned_transducer_stateless2/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import k2
+import kaldifeat
+import torch
+import torchaudio
+from beam_search import (
+    beam_search,
+    fast_beam_search_one_best,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import get_params, get_transducer_model
+
+from icefall.lexicon import Lexicon
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Path to the lang dir, e.g., data/lang_char.
+        """,
+    )
+
+    parser.add_argument(
+        "--decoding-method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - modified_beam_search
+          - fast_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="Used only when --decoding-method is beam_search "
+        "or modified_beam_search.",
+    )
+
+    parser.add_argument(
+        "--beam",
+        type=float,
+        default=4,
+        help="""A floating point value to calculate the cutoff score during beam
+        search (i.e., `cutoff = max-score - beam`), which is the same as the
+        `beam` in Kaldi.
+        Used only when --decoding-method is fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-contexts",
+        type=int,
+        default=4,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--max-states",
+        type=int,
+        default=8,
+        help="""Used only when --decoding-method is
+        fast_beam_search""",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame. Used only when
+        --decoding-method is greedy_search.
+        """,
+    )
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    params = get_params()
+
+    params.update(vars(args))
+
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = lexicon.token_table["<blk>"]
+    params.vocab_size = max(lexicon.tokens) + 1
+
+    logging.info(f"{params}")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    logging.info("Creating model")
+    model = get_transducer_model(params)
+
+    checkpoint = torch.load(args.checkpoint, map_location="cpu")
+    model.load_state_dict(checkpoint["model"], strict=False)
+    model.to(device)
+    model.eval()
+    model.device = device
+
+    if params.decoding_method == "fast_beam_search":
+        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+    else:
+        decoding_graph = None
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {params.sound_files}")
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.size(0) for f in features]
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    feature_lengths = torch.tensor(feature_lengths, device=device)
+
+    with torch.no_grad():
+        encoder_out, encoder_out_lens = model.encoder(
+            x=features, x_lens=feature_lengths
+        )
+
+    hyps = []
+    msg = f"Using {params.decoding_method}"
+    logging.info(msg)
+
+    if params.decoding_method == "fast_beam_search":
+        hyp_tokens = fast_beam_search_one_best(
+            model=model,
+            decoding_graph=decoding_graph,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam,
+            max_contexts=params.max_contexts,
+            max_states=params.max_states,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_tokens = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
+            beam=params.beam_size,
+        )
+        for i in range(encoder_out.size(0)):
+            hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
+    else:
+        batch_size = encoder_out.size(0)
+
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyps.append([lexicon.token_table[idx] for idx in hyp])
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = (
+        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    )
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
index 50dfe8917..1fd709320 100644
--- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/train.py
@@ -103,7 +103,7 @@ def get_parser():
     parser.add_argument(
         "--master-port",
         type=int,
-        default=12354,
+        default=12359,
         help="Master port to use for DDP training.",
     )
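A note on the batched decoding introduced in `beam_search.py` above: `greedy_search_batch` and `modified_beam_search` now take `encoder_out_lens` and call `torch.nn.utils.rnn.pack_padded_sequence`, which sorts the utterances by length; `batch_sizes[t]` then says how many utterances are still active at frame `t`, so shorter utterances simply drop out of the batch as decoding proceeds, and `unsorted_indices` restores the original order at the end. Below is a minimal, self-contained sketch of this idea — not the icefall implementation; the toy `step_fn` and the choice of 0 as the blank id are assumptions for illustration:

```
import torch

def batched_greedy_decode_sketch(encoder_out, encoder_out_lens, step_fn, blank_id=0):
    """encoder_out: (N, T, C) padded; encoder_out_lens: (N,).
    step_fn(frames) -> (batch,) token IDs, one per live stream for this frame."""
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    N = encoder_out.size(0)
    hyps = [[] for _ in range(N)]  # hypotheses, in length-sorted order
    offset = 0
    for batch_size in packed.batch_sizes.tolist():
        # rows offset:offset+batch_size hold this frame for all live streams
        frames = packed.data[offset : offset + batch_size]
        offset += batch_size
        for i, t in enumerate(step_fn(frames).tolist()):
            if t != blank_id:  # emit at most one non-blank symbol per frame
                hyps[i].append(t)
    # undo the length sorting performed by pack_padded_sequence
    unsorted_indices = packed.unsorted_indices.tolist()
    return [hyps[unsorted_indices[i]] for i in range(N)]

if __name__ == "__main__":
    enc = torch.randn(3, 5, 4)
    lens = torch.tensor([5, 3, 4])
    # toy "decoder": the argmax feature index plays the role of a token
    print(batched_greedy_decode_sketch(enc, lens, lambda f: f.argmax(dim=-1)))
```

The padded frames past each utterance's true length are never touched, which is also why `decode.py` and `pretrained.py` now pass `encoder_out_lens` into these functions.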
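Relatedly, `fast_beam_search_nbest_oracle` computes an oracle result: among the `num_paths` paths sampled from the lattice, it keeps the one with the smallest edit distance to the reference, an upper bound on what any n-best rescoring method could recover. A toy illustration of that selection rule in plain Python — no k2 or lattices; hypotheses here are just lists of token IDs:

```
from typing import List

def edit_distance(a: List[int], b: List[int]) -> int:
    # classic dynamic-programming Levenshtein distance
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        cur = [i]
        for j, y in enumerate(b, 1):
            cur.append(min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (x != y),  # substitution (or match)
            ))
        prev = cur
    return prev[-1]

def oracle_hyp(nbest: List[List[int]], ref: List[int]) -> List[int]:
    # the "oracle" keeps the candidate closest to the reference
    return min(nbest, key=lambda hyp: edit_distance(hyp, ref))

print(oracle_hyp([[1, 2, 4], [1, 3], [1, 2, 3]], ref=[1, 2, 3]))  # -> [1, 2, 3]
```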