From 250ff30875c0a545a5abc75682e1a1faec8344b6 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 15:08:31 +0900 Subject: [PATCH 01/10] add streaming support to reazonresearch --- egs/librispeech/ASR/zipformer/decode.py | 88 +-- egs/librispeech/ASR/zipformer/zipformer.py | 2 +- egs/reazonspeech/ASR/local/utils/tokenizer.py | 1 - .../ASR/zipformer/streaming_decode.py | 526 ++++++++++++++---- 4 files changed, 482 insertions(+), 135 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cbfb3728e..f14ea847a 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -103,10 +103,9 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 -import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import ReazonSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -122,6 +121,7 @@ from beam_search import ( modified_beam_search_LODR, ) from lhotse import set_caching_enabled +from tokenizer import Tokenizer from train import add_model_arguments, get_model, get_params from icefall import ContextGraph, LmScorer, NgramLm @@ -134,6 +134,7 @@ from icefall.checkpoint import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + make_pad_mask, setup_logger, store_transcripts, str2bool, @@ -204,7 +205,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=Path, - default="data/lang_bpe_500", + default="data/lang_char", help="The lang dir containing word table and LG graph", ) @@ -377,6 +378,17 @@ def get_parser(): default=False, help="""Skip scoring, but still save the ASR output (for eval sets).""", ) + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) add_model_arguments(parser) @@ -386,7 +398,7 @@ def get_parser(): def decode_one_batch( params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, + sp: Tokenizer, batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, @@ -465,9 +477,10 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -479,6 +492,7 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in hyp_tokens: hyps.append([word_table[i] for i in hyp]) @@ -493,9 +507,10 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -508,17 +523,19 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -526,9 +543,10 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, context_graph=context_graph, + blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -538,7 +556,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -551,7 +569,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + hyps.append(sp.text2word(hyp)) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -597,10 +615,11 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.decode(hyp).split()) + hyps.append(sp.text2word(sp.decode(hyp))) # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) prefix = f"{params.decoding_method}" + key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": return {"greedy_search": hyps} elif "fast_beam_search" in params.decoding_method: @@ -639,7 +658,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - sp: spm.SentencePieceProcessor, + sp: Tokenizer, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, context_graph: Optional[ContextGraph] = None, @@ -705,7 +724,7 @@ def decode_dataset( this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() + ref_words = sp.text2word(ref_text) this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -778,8 +797,8 @@ def save_wer_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - LmScorer.add_arguments(parser) + ReazonSpeechAsrDataModule.add_arguments(parser) + Tokenizer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -853,6 +872,8 @@ def main(): f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" ) + params.suffix += f"-blank-penalty-{params.blank_penalty}" + if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -865,10 +886,9 @@ def main(): logging.info(f"Device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + sp = Tokenizer.load(params.lang, params.lang_type) - # and are defined in local/train_bpe_model.py + # and are defined in local/prepare_lang_char.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -1040,20 +1060,13 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) + reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): + for subdir in ["valid"]: results_dict = decode_dataset( - dl=test_dl, + dl=reazonspeech_corpus.test_dataloaders( + getattr(reazonspeech_corpus, f"{subdir}_cuts")() + ), params=params, model=model, sp=sp, @@ -1067,9 +1080,20 @@ def main(): save_asr_output( params=params, - test_set_name=test_set, + test_set_name=subdir, results_dict=results_dict, ) + # with ( + # params.res_dir + # / ( + # f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" + # f"_{params.avg}_{params.epoch}.cer" + # ) + # ).open("w") as fout: + # if len(tot_err) == 1: + # fout.write(f"{tot_err[0][1]}") + # else: + # fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) if not params.skip_scoring: save_wer_results( diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 69059287b..d42a5b145 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -2434,4 +2434,4 @@ if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) _test_zipformer_main(False) - _test_zipformer_main(True) + _test_zipformer_main(True) \ No newline at end of file diff --git a/egs/reazonspeech/ASR/local/utils/tokenizer.py b/egs/reazonspeech/ASR/local/utils/tokenizer.py index c9be72be1..ba71cff89 100644 --- a/egs/reazonspeech/ASR/local/utils/tokenizer.py +++ b/egs/reazonspeech/ASR/local/utils/tokenizer.py @@ -12,7 +12,6 @@ class Tokenizer: @staticmethod def add_arguments(parser: argparse.ArgumentParser): group = parser.add_argument_group(title="Lang related options") - group.add_argument("--lang", type=Path, help="Path to lang directory.") group.add_argument( diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 4c18c7563..1a726724d 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,28 +18,29 @@ """ Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ +./zipformer/streaming_decode.py \ --epoch 28 \ --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --lang data/lang_char \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ --num-decode-streams 2000 """ +import pdb import argparse import logging import math from pathlib import Path from typing import Dict, List, Optional, Tuple +from tokenizer import Tokenizer import k2 import numpy as np import torch -import torch.nn as nn from asr_datamodule import ReazonSpeechAsrDataModule -from decode import save_results from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet @@ -47,10 +49,9 @@ from streaming_beam_search import ( greedy_search, modified_beam_search, ) -from tokenizer import Tokenizer +from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states +from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, @@ -58,7 +59,17 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import AttributeDict, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +import subprocess as sp +import os LOG_EPS = math.log(1e-10) @@ -73,7 +84,7 @@ def get_parser(): type=int, default=28, help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. + Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) @@ -87,12 +98,6 @@ def get_parser(): """, ) - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - parser.add_argument( "--avg", type=int, @@ -116,7 +121,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless2/exp", + default="zipformer/exp", help="The experiment dir", ) @@ -126,6 +131,13 @@ def get_parser(): default="data/lang_bpe_500/bpe.model", help="Path to the BPE model", ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) parser.add_argument( "--decoding-method", @@ -138,14 +150,6 @@ def get_parser(): """, ) - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - parser.add_argument( "--num_active_paths", type=int, @@ -157,7 +161,7 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4.0, + default=4, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. @@ -194,18 +198,235 @@ def get_parser(): help="The number of streams that can be decoded parallel.", ) - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - add_model_arguments(parser) return parser +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -224,27 +445,34 @@ def decode_one_chunk( Returns: Return a List containing which DecodeStreams are finished. """ - device = model.device + # pdb.set_trace() + # print(model) + # print(model.device) + # device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) features = [] feature_lens = [] states = [] - processed_lens = [] - + processed_lens = [] # Used in fast-beam-search + for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) + feat, feat_len = stream.get_feature_frames(chunk_size * 2) features.append(feat) feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) + + print(feature_lens) + feature_lens = torch.tensor(feature_lens, device=model.device) + print(feature_lens) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. - tail_length = 23 + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -256,12 +484,14 @@ def decode_one_chunk( ) states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, ) encoder_out = model.joiner.encoder_proj(encoder_out) @@ -269,6 +499,7 @@ def decode_one_chunk( if params.decoding_method == "greedy_search": greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( model=model, @@ -295,9 +526,11 @@ def decode_one_chunk( for i in range(len(decode_streams)): decode_streams[i].states = states[i] decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - + # if decode_streams[i].done: + # finished_streams.append(i) + finished_streams.append(i) + + print(finished_streams) return finished_streams @@ -338,14 +571,14 @@ def decode_dataset( opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 - log_interval = 50 + log_interval = 100 decode_results = [] # Contain decode streams currently running. decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) + initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -361,15 +594,19 @@ def decode_dataset( assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] - + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text decode_streams.append(decode_stream) while len(decode_streams) >= params.num_decode_streams: @@ -380,8 +617,8 @@ def decode_dataset( decode_results.append( ( decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), ) ) del decode_streams[i] @@ -389,21 +626,43 @@ def decode_dataset( if num % log_interval == 0: logging.info(f"Cuts processed until now is {num}.") + print("cuts processed finished") + print(len(decode_streams)) # decode final chunks of last sequences while len(decode_streams): + # print("INSIDE LEN DECODE STREAMS") + # pdb.set_trace() + # print(model.device) + # test_device = model.device + # print("done") finished_streams = decode_one_chunk( params=params, model=model, decode_streams=decode_streams ) + # print('INSIDE FOR LOOP ') + # print(finished_streams) + + if not finished_streams: + print("No finished streams, breaking the loop") + break + + for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), - ) - ) - del decode_streams[i] - + try: + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + except IndexError as e: + print(f"IndexError: {e}") + print(f"decode_streams length: {len(decode_streams)}") + print(f"finished_streams: {finished_streams}") + print(f"i: {i}") + continue + if params.decoding_method == "greedy_search": key = "greedy_search" elif params.decoding_method == "fast_beam_search": @@ -416,9 +675,57 @@ def decode_dataset( key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + torch.cuda.synchronize() return {key: decode_results} +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + print("error stats") + print("results") + print(results) + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + @torch.no_grad() def main(): parser = get_parser() @@ -430,16 +737,20 @@ def main(): params = get_params() params.update(vars(args)) - if not params.res_dir: - params.res_dir = params.exp_dir / "streaming" / params.decoding_method + params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" # for fast_beam_search if params.decoding_method == "fast_beam_search": @@ -455,13 +766,13 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) + device = torch.device("cuda", 0) logging.info(f"Device: {device}") sp = Tokenizer.load(params.lang, params.lang_type) - # and is defined in local/prepare_lang_char.py + # and is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -469,7 +780,7 @@ def main(): logging.info(params) logging.info("About to create model") - model = get_transducer_model(params) + model = get_model(params) if not params.use_averaged_model: if params.iter > 0: @@ -553,43 +864,56 @@ def main(): model.device = device decoding_graph = None - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif params.decoding_method == "fast_beam_search": + if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - + + # we need cut ids to display recognition results. args.return_cuts = True reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + + valid_cuts = reazonspeech_corpus.valid_cuts() + test_cuts = reazonspeech_corpus.test_cuts() - for subdir in ["valid"]: + test_sets = ["valid", "test"] + test_cuts = [valid_cuts, test_cuts] + print('test cuts') + print(test_cuts) + + for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( - cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(), + cuts=test_cut, params=params, model=model, sp=sp, decoding_graph=decoding_graph, ) - tot_err = save_results( - params=params, test_set_name=subdir, results_dict=results_dict + print(r"esults_dict") + print(results_dict) + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, ) - - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - + + # valid_cuts = reazonspeech_corpus.valid_cuts() + + # for valid_cut in valid_cuts: + # results_dict = decode_dataset( + # cuts=valid_cut, + # params=params, + # model=model, + # sp=sp, + # decoding_graph=decoding_graph, + # ) + # save_results( + # params=params, + # test_set_name="valid", + # results_dict=results_dict, + # ) + logging.info("Done!") From d24f1e1baeba0f556c5cc4bdf761646a3f6716d3 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 15:31:20 +0900 Subject: [PATCH 02/10] update README for streaming --- egs/reazonspeech/ASR/RESULTS.md | 38 +++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/egs/reazonspeech/ASR/RESULTS.md b/egs/reazonspeech/ASR/RESULTS.md index c0b4fe54a..92610d75b 100644 --- a/egs/reazonspeech/ASR/RESULTS.md +++ b/egs/reazonspeech/ASR/RESULTS.md @@ -47,3 +47,41 @@ The decoding command is: --blank-penalty 0 ``` +#### Streaming + +We have not completed evaluation of our models yet and will add evaluation results here once it's completed. + +The training command is: +```shell +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 40 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp-large \ + --causal 1 \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 \ + --lang data/lang_char \ + --max-duration 1600 +``` + +The decoding command is: + +```shell +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp-large \ + --lang data/lang_char \ + --num-encoder-layers 2,2,4,5,4,2 \ + --feedforward-dim 512,768,1536,2048,1536,768 \ + --encoder-dim 192,256,512,768,512,256 \ + --encoder-unmasked-dim 192,192,256,320,256,192 +``` + From ed5fa02c5c759f99c30afc183f3c5e51642173ae Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 15:43:44 +0900 Subject: [PATCH 03/10] update streaming decoding file --- ...2-left-context-256-use-averaged-model-2024-07-29-17-52-11 | 2 ++ ...2-left-context-256-use-averaged-model-2024-07-29-17-54-22 | 2 ++ ...2-left-context-256-use-averaged-model-2024-07-29-17-55-15 | 2 ++ ...2-left-context-256-use-averaged-model-2024-07-29-17-59-02 | 2 ++ ...2-left-context-256-use-averaged-model-2024-07-29-18-01-06 | 5 +++++ 5 files changed, 13 insertions(+) create mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 create mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 create mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 create mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 create mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 new file mode 100644 index 000000000..e20ff996d --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 @@ -0,0 +1,2 @@ +2024-07-29 17:52:11,668 INFO [streaming_decode.py:736] Decoding started +2024-07-29 17:52:11,669 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 new file mode 100644 index 000000000..6ad2cb344 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 @@ -0,0 +1,2 @@ +2024-07-29 17:54:22,556 INFO [streaming_decode.py:736] Decoding started +2024-07-29 17:54:22,556 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 new file mode 100644 index 000000000..40a190b75 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 @@ -0,0 +1,2 @@ +2024-07-29 17:55:15,276 INFO [streaming_decode.py:736] Decoding started +2024-07-29 17:55:15,277 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 new file mode 100644 index 000000000..4c721e7d2 --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 @@ -0,0 +1,2 @@ +2024-07-29 17:59:02,028 INFO [streaming_decode.py:736] Decoding started +2024-07-29 17:59:02,029 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 new file mode 100644 index 000000000..6fe40297f --- /dev/null +++ b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 @@ -0,0 +1,5 @@ +2024-07-29 18:01:06,736 INFO [streaming_decode.py:736] Decoding started +2024-07-29 18:01:06,736 INFO [streaming_decode.py:742] Device: cuda:0 +2024-07-29 18:01:06,740 INFO [streaming_decode.py:753] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.24.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '8f976a1e1407e330e2a233d68f81b1eb5269fdaa', 'k2-git-date': 'Thu Jun 6 02:13:08 2024', 'lhotse-version': '1.26.0.dev+git.bd12d5d.clean', 'torch-version': '2.3.1+cu121', 'torch-cuda-available': True, 'torch-cuda-version': '12.1', 'python-version': '3.10', 'icefall-git-branch': 'jp-streaming', 'icefall-git-sha1': '4af81af-dirty', 'icefall-git-date': 'Thu Jul 18 22:05:59 2024', 'icefall-path': '/root/tmp/icefall', 'k2-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/k2/__init__.py', 'lhotse-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'KDA00', 'IP address': '192.168.0.1'}, 'epoch': 28, 'iter': 0, 'avg': 15, 'use_averaged_model': True, 'exp_dir': PosixPath('zipformer'), 'bpe_model': 'data/lang_bpe_500/bpe.model', 'lang_dir': PosixPath('data/lang_char'), 'decoding_method': 'greedy_search', 'num_active_paths': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 32, 'context_size': 2, 'num_decode_streams': 2000, 'num_encoder_layers': '2,2,3,4,3,2', 'downsampling_factor': '1,2,4,8,4,2', 'feedforward_dim': '512,768,1024,1536,1024,768', 'num_heads': '4,4,4,8,4,4', 'encoder_dim': '192,256,384,512,384,256', 'query_head_dim': '32', 'value_head_dim': '12', 'pos_head_dim': '4', 'pos_dim': 48, 'encoder_unmasked_dim': '192,192,256,256,256,192', 'cnn_module_kernel': '31,31,15,15,15,31', 'decoder_dim': 512, 'joiner_dim': 512, 'causal': True, 'chunk_size': '32', 'left_context_frames': '256', 'use_transducer': True, 'use_ctc': False, 'manifest_dir': PosixPath('data/manifests'), 'max_duration': 200.0, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'drop_last': True, 'return_cuts': False, 'num_workers': 2, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': False, 'lang': PosixPath('data/lang_char'), 'lang_type': None, 'res_dir': PosixPath('zipformer/streaming/greedy_search'), 'suffix': 'epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model', 'blank_id': 0, 'unk_id': 2990, 'vocab_size': 2992} +2024-07-29 18:01:06,740 INFO [streaming_decode.py:755] About to create model +2024-07-29 18:01:07,118 INFO [streaming_decode.py:822] Calculating the averaged model over epoch range from 13 (excluded) to 28 From eebe6add4c19bc49d8f03c686291b8312572d4b0 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 15:44:50 +0900 Subject: [PATCH 04/10] update streaming decode --- egs/reazonspeech/ASR/zipformer/streaming_decode.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 1a726724d..0b58853a3 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -18,15 +18,7 @@ """ Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 +./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192 """ import pdb From cfef53dcbee1d372cd81d09d6e0e2daabb7e8cd8 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 16:00:27 +0900 Subject: [PATCH 05/10] remove streaming/greedy_search results folder --- egs/librispeech/ASR/zipformer/decode.py | 306 +++++++++--------- ...256-use-averaged-model-2024-07-29-17-52-11 | 2 - ...256-use-averaged-model-2024-07-29-17-54-22 | 2 - ...256-use-averaged-model-2024-07-29-17-55-15 | 2 - ...256-use-averaged-model-2024-07-29-17-59-02 | 2 - ...256-use-averaged-model-2024-07-29-18-01-06 | 5 - 6 files changed, 148 insertions(+), 171 deletions(-) delete mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 delete mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 delete mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 delete mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 delete mode 100644 egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index f14ea847a..cfe5638b7 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -103,9 +103,10 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule +from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, fast_beam_search_nbest, @@ -134,7 +135,6 @@ from icefall.checkpoint import ( from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - make_pad_mask, setup_logger, store_transcripts, str2bool, @@ -205,7 +205,7 @@ def get_parser(): parser.add_argument( "--lang-dir", type=Path, - default="data/lang_char", + default="data/lang_bpe_500", help="The lang dir containing word table and LG graph", ) @@ -371,6 +371,7 @@ def get_parser(): modified_beam_search_LODR. """, ) +<<<<<<< HEAD parser.add_argument( "--skip-scoring", @@ -398,7 +399,7 @@ def get_parser(): def decode_one_batch( params: AttributeDict, model: nn.Module, - sp: Tokenizer, + sp: spm.SentencePieceProcessor, batch: dict, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, @@ -477,10 +478,9 @@ def decode_one_batch( beam=params.beam, max_contexts=params.max_contexts, max_states=params.max_states, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_LG": hyp_tokens = fast_beam_search_nbest_LG( model=model, @@ -492,7 +492,6 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, ) for hyp in hyp_tokens: hyps.append([word_table[i] for i in hyp]) @@ -507,10 +506,9 @@ def decode_one_batch( max_states=params.max_states, num_paths=params.num_paths, nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -523,19 +521,17 @@ def decode_one_batch( num_paths=params.num_paths, ref_texts=sp.encode(supervisions["text"]), nbest_scale=params.nbest_scale, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -543,10 +539,9 @@ def decode_one_batch( encoder_out_lens=encoder_out_lens, beam=params.beam_size, context_graph=context_graph, - blank_penalty=params.blank_penalty, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": hyp_tokens = modified_beam_search_lm_shallow_fusion( model=model, @@ -556,7 +551,7 @@ def decode_one_batch( LM=LM, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_LODR": hyp_tokens = modified_beam_search_LODR( model=model, @@ -569,7 +564,7 @@ def decode_one_batch( context_graph=context_graph, ) for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) + hyps.append(hyp.split()) elif params.decoding_method == "modified_beam_search_lm_rescore": lm_scale_list = [0.01 * i for i in range(10, 50)] ans_dict = modified_beam_search_lm_rescore( @@ -615,7 +610,7 @@ def decode_one_batch( raise ValueError( f"Unsupported decoding method: {params.decoding_method}" ) - hyps.append(sp.text2word(sp.decode(hyp))) + hyps.append(sp.decode(hyp).split()) # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" ) prefix = f"{params.decoding_method}" @@ -636,9 +631,9 @@ def decode_one_batch( elif "modified_beam_search" in params.decoding_method: prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): @@ -655,17 +650,17 @@ def decode_one_batch( def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, + ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -708,23 +703,23 @@ def decode_dataset( cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) for name, hyps in hyps_dict.items(): this_batch = [] assert len(hyps) == len(texts) for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) + ref_words = ref_text.split() this_batch.append((cut_id, ref_words, hyp_words)) results[name].extend(this_batch) @@ -739,10 +734,10 @@ def decode_dataset( def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], + ): """ Save text produced by ASR. """ @@ -757,10 +752,10 @@ def save_asr_output( def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], -): + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], + ): """ Save WER and per-utterance word alignments. """ @@ -771,8 +766,8 @@ def save_wer_results( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True - ) + fd, f"{test_set_name}-{key}", results, enable_log=True + ) test_set_wers[key] = wer logging.info(f"Wrote detailed error stats to {errs_filename}") @@ -797,8 +792,8 @@ def save_wer_results( @torch.no_grad() def main(): parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) + LibriSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) @@ -809,18 +804,18 @@ def main(): set_caching_enabled(True) # lhotse assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) params.res_dir = params.exp_dir / params.decoding_method if os.path.exists(params.context_file): @@ -835,11 +830,11 @@ def main(): if params.causal: assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" @@ -855,9 +850,9 @@ def main(): elif "beam_search" in params.decoding_method: params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): + "modified_beam_search", + "modified_beam_search_LODR", + ): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: @@ -869,10 +864,8 @@ def main(): if "LODR" in params.decoding_method: params.suffix += ( - f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) - - params.suffix += f"-blank-penalty-{params.blank_penalty}" + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -886,9 +879,10 @@ def main(): logging.info(f"Device: {device}") - sp = Tokenizer.load(params.lang, params.lang_type) + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) - # and are defined in local/prepare_lang_char.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -901,18 +895,18 @@ def main(): if not params.use_averaged_model: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) elif len(filenames) < params.avg: raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) @@ -930,32 +924,32 @@ def main(): else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) elif len(filenames) < params.avg + 1: raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) filename_start = filenames[-1] filename_end = filenames[0] logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) model.to(device) model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) else: assert params.avg > 0, params.avg start = params.epoch - params.avg @@ -963,34 +957,34 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) model.to(device) model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() # only load the neural network LM if required if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) LM.to(device) LM.eval() else: @@ -1016,10 +1010,10 @@ def main(): lm_filename = f"{params.tokens_ngram}gram.fst.txt" logging.info(f"Loading token level lm: {lm_filename}") ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) logging.info(f"num states: {ngram_lm.lm.num_states}") ngram_lm_scale = params.ngram_lm_scale else: @@ -1033,8 +1027,8 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) + torch.load(lg_filename, map_location=device) + ) decoding_graph.scores *= params.ngram_lm_scale else: word_table = None @@ -1060,40 +1054,36 @@ def main(): # we need cut ids to display recognition results. args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) + librispeech = LibriSpeechAsrDataModule(args) - for subdir in ["valid"]: + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( - dl=reazonspeech_corpus.test_dataloaders( - getattr(reazonspeech_corpus, f"{subdir}_cuts")() - ), - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) save_asr_output( params=params, - test_set_name=subdir, + test_set_name=test_set, results_dict=results_dict, ) - # with ( - # params.res_dir - # / ( - # f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" - # f"_{params.avg}_{params.epoch}.cer" - # ) - # ).open("w") as fout: - # if len(tot_err) == 1: - # fout.write(f"{tot_err[0][1]}") - # else: - # fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) if not params.skip_scoring: save_wer_results( diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 deleted file mode 100644 index e20ff996d..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-52-11 +++ /dev/null @@ -1,2 +0,0 @@ -2024-07-29 17:52:11,668 INFO [streaming_decode.py:736] Decoding started -2024-07-29 17:52:11,669 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 deleted file mode 100644 index 6ad2cb344..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-54-22 +++ /dev/null @@ -1,2 +0,0 @@ -2024-07-29 17:54:22,556 INFO [streaming_decode.py:736] Decoding started -2024-07-29 17:54:22,556 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 deleted file mode 100644 index 40a190b75..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-55-15 +++ /dev/null @@ -1,2 +0,0 @@ -2024-07-29 17:55:15,276 INFO [streaming_decode.py:736] Decoding started -2024-07-29 17:55:15,277 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 deleted file mode 100644 index 4c721e7d2..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-17-59-02 +++ /dev/null @@ -1,2 +0,0 @@ -2024-07-29 17:59:02,028 INFO [streaming_decode.py:736] Decoding started -2024-07-29 17:59:02,029 INFO [streaming_decode.py:742] Device: cuda:0 diff --git a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 b/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 deleted file mode 100644 index 6fe40297f..000000000 --- a/egs/reazonspeech/ASR/zipformer/streaming/greedy_search/log-decode-epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model-2024-07-29-18-01-06 +++ /dev/null @@ -1,5 +0,0 @@ -2024-07-29 18:01:06,736 INFO [streaming_decode.py:736] Decoding started -2024-07-29 18:01:06,736 INFO [streaming_decode.py:742] Device: cuda:0 -2024-07-29 18:01:06,740 INFO [streaming_decode.py:753] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'warm_step': 2000, 'env_info': {'k2-version': '1.24.4', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '8f976a1e1407e330e2a233d68f81b1eb5269fdaa', 'k2-git-date': 'Thu Jun 6 02:13:08 2024', 'lhotse-version': '1.26.0.dev+git.bd12d5d.clean', 'torch-version': '2.3.1+cu121', 'torch-cuda-available': True, 'torch-cuda-version': '12.1', 'python-version': '3.10', 'icefall-git-branch': 'jp-streaming', 'icefall-git-sha1': '4af81af-dirty', 'icefall-git-date': 'Thu Jul 18 22:05:59 2024', 'icefall-path': '/root/tmp/icefall', 'k2-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/k2/__init__.py', 'lhotse-path': '/root/miniconda3/envs/myenv/lib/python3.10/site-packages/lhotse/__init__.py', 'hostname': 'KDA00', 'IP address': '192.168.0.1'}, 'epoch': 28, 'iter': 0, 'avg': 15, 'use_averaged_model': True, 'exp_dir': PosixPath('zipformer'), 'bpe_model': 'data/lang_bpe_500/bpe.model', 'lang_dir': PosixPath('data/lang_char'), 'decoding_method': 'greedy_search', 'num_active_paths': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 32, 'context_size': 2, 'num_decode_streams': 2000, 'num_encoder_layers': '2,2,3,4,3,2', 'downsampling_factor': '1,2,4,8,4,2', 'feedforward_dim': '512,768,1024,1536,1024,768', 'num_heads': '4,4,4,8,4,4', 'encoder_dim': '192,256,384,512,384,256', 'query_head_dim': '32', 'value_head_dim': '12', 'pos_head_dim': '4', 'pos_dim': 48, 'encoder_unmasked_dim': '192,192,256,256,256,192', 'cnn_module_kernel': '31,31,15,15,15,31', 'decoder_dim': 512, 'joiner_dim': 512, 'causal': True, 'chunk_size': '32', 'left_context_frames': '256', 'use_transducer': True, 'use_ctc': False, 'manifest_dir': PosixPath('data/manifests'), 'max_duration': 200.0, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'drop_last': True, 'return_cuts': False, 'num_workers': 2, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': False, 'lang': PosixPath('data/lang_char'), 'lang_type': None, 'res_dir': PosixPath('zipformer/streaming/greedy_search'), 'suffix': 'epoch-28-avg-15-chunk-32-left-context-256-use-averaged-model', 'blank_id': 0, 'unk_id': 2990, 'vocab_size': 2992} -2024-07-29 18:01:06,740 INFO [streaming_decode.py:755] About to create model -2024-07-29 18:01:07,118 INFO [streaming_decode.py:822] Calculating the averaged model over epoch range from 13 (excluded) to 28 From 887366370d31e99c8d6648869797f26acd632fd4 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 16:13:59 +0900 Subject: [PATCH 06/10] update for streaming --- icefall/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/icefall/utils.py b/icefall/utils.py index 1dbb954de..9a25784cb 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -631,7 +631,8 @@ def write_error_stats( results[i] = (cut_id, ref, hyp) for cut_id, ref, hyp in results: - ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + # ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) + ali = kaldialign.align(ref, hyp, ERR) for ref_word, hyp_word in ali: if ref_word == ERR: ins[hyp_word] += 1 From 8d0107a9591fdc43587be957c4acee9c3f506ad1 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 16:17:59 +0900 Subject: [PATCH 07/10] remove prints --- egs/reazonspeech/ASR/zipformer/streaming_decode.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 0b58853a3..693c8db1c 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -19,6 +19,7 @@ """ Usage: ./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192 + """ import pdb @@ -456,9 +457,7 @@ def decode_one_chunk( states.append(stream.states) processed_lens.append(stream.done_frames) - print(feature_lens) feature_lens = torch.tensor(feature_lens, device=model.device) - print(feature_lens) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) # Make sure the length after encoder_embed is at least 1. @@ -522,7 +521,6 @@ def decode_one_chunk( # finished_streams.append(i) finished_streams.append(i) - print(finished_streams) return finished_streams @@ -618,8 +616,6 @@ def decode_dataset( if num % log_interval == 0: logging.info(f"Cuts processed until now is {num}.") - print("cuts processed finished") - print(len(decode_streams)) # decode final chunks of last sequences while len(decode_streams): # print("INSIDE LEN DECODE STREAMS") @@ -691,9 +687,6 @@ def save_results( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: - print("error stats") - print("results") - print(results) wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True ) @@ -871,8 +864,6 @@ def main(): test_sets = ["valid", "test"] test_cuts = [valid_cuts, test_cuts] - print('test cuts') - print(test_cuts) for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( @@ -882,8 +873,6 @@ def main(): sp=sp, decoding_graph=decoding_graph, ) - print(r"esults_dict") - print(results_dict) save_results( params=params, test_set_name=test_set, From 563292599b286f00cfcaebe8022e9a5f1b829dcf Mon Sep 17 00:00:00 2001 From: root Date: Thu, 1 Aug 2024 18:27:34 +0900 Subject: [PATCH 08/10] resolve PR issue --- egs/reazonspeech/ASR/zipformer/streaming_decode.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 693c8db1c..81bdc4845 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -490,7 +490,7 @@ def decode_one_chunk( if params.decoding_method == "greedy_search": greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) elif params.decoding_method == "fast_beam_search": - processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = torch.tensor(processed_lens, device=model.device) processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( model=model, From 814d3ac702fd8bf995660658ba241c9ad41ce2ab Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 Aug 2024 15:29:29 +0900 Subject: [PATCH 09/10] fix formatting issues --- .../ASR/zipformer/streaming_decode.py | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py index 81bdc4845..9274f4dc4 100755 --- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py +++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py @@ -22,13 +22,15 @@ Usage: """ -import pdb import argparse import logging import math +import os +import pdb + +# import subprocess as sp from pathlib import Path from typing import Dict, List, Optional, Tuple -from tokenizer import Tokenizer import k2 import numpy as np @@ -42,6 +44,7 @@ from streaming_beam_search import ( greedy_search, modified_beam_search, ) +from tokenizer import Tokenizer from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params @@ -61,9 +64,6 @@ from icefall.utils import ( write_error_stats, ) -import subprocess as sp -import os - LOG_EPS = math.log(1e-10) @@ -124,7 +124,7 @@ def get_parser(): default="data/lang_bpe_500/bpe.model", help="Path to the BPE model", ) - + parser.add_argument( "--lang-dir", type=Path, @@ -449,14 +449,14 @@ def decode_one_chunk( feature_lens = [] states = [] processed_lens = [] # Used in fast-beam-search - + for stream in decode_streams: feat, feat_len = stream.get_feature_frames(chunk_size * 2) features.append(feat) feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) - + feature_lens = torch.tensor(feature_lens, device=model.device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) @@ -518,9 +518,9 @@ def decode_one_chunk( decode_streams[i].states = states[i] decode_streams[i].done_frames += encoder_out_lens[i] # if decode_streams[i].done: - # finished_streams.append(i) + # finished_streams.append(i) finished_streams.append(i) - + return finished_streams @@ -628,21 +628,20 @@ def decode_dataset( ) # print('INSIDE FOR LOOP ') # print(finished_streams) - + if not finished_streams: print("No finished streams, breaking the loop") break - - + for i in sorted(finished_streams, reverse=True): - try: + try: decode_results.append( ( decode_streams[i].id, decode_streams[i].ground_truth.split(), sp.decode(decode_streams[i].decoding_result()).split(), ) - ) + ) del decode_streams[i] except IndexError as e: print(f"IndexError: {e}") @@ -650,7 +649,7 @@ def decode_dataset( print(f"finished_streams: {finished_streams}") print(f"i: {i}") continue - + if params.decoding_method == "greedy_search": key = "greedy_search" elif params.decoding_method == "fast_beam_search": @@ -663,7 +662,7 @@ def decode_dataset( key = f"num_active_paths_{params.num_active_paths}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - torch.cuda.synchronize() + torch.cuda.synchronize() return {key: decode_results} @@ -854,11 +853,11 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - + # we need cut ids to display recognition results. args.return_cuts = True reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - + valid_cuts = reazonspeech_corpus.valid_cuts() test_cuts = reazonspeech_corpus.test_cuts() @@ -878,9 +877,9 @@ def main(): test_set_name=test_set, results_dict=results_dict, ) - + # valid_cuts = reazonspeech_corpus.valid_cuts() - + # for valid_cut in valid_cuts: # results_dict = decode_dataset( # cuts=valid_cut, @@ -894,7 +893,7 @@ def main(): # test_set_name="valid", # results_dict=results_dict, # ) - + logging.info("Done!") From 9b13eac94686fe4a31eb190b8deaf3ff967feacd Mon Sep 17 00:00:00 2001 From: root Date: Wed, 14 Aug 2024 15:47:28 +0900 Subject: [PATCH 10/10] fix misc line --- egs/librispeech/ASR/zipformer/decode.py | 235 ++++++++++++------------ 1 file changed, 117 insertions(+), 118 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index cfe5638b7..52a489eb3 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -371,7 +371,6 @@ def get_parser(): modified_beam_search_LODR. """, ) -<<<<<<< HEAD parser.add_argument( "--skip-scoring", @@ -631,9 +630,9 @@ def decode_one_batch( elif "modified_beam_search" in params.decoding_method: prefix += f"_beam-size-{params.beam_size}" if params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ): + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): ans = dict() assert ans_dict is not None for key, hyps in ans_dict.items(): @@ -650,17 +649,17 @@ def decode_one_batch( def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, - context_graph: Optional[ContextGraph] = None, - LM: Optional[LmScorer] = None, - ngram_lm=None, - ngram_lm_scale: float = 0.0, - ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: """Decode dataset. Args: @@ -703,17 +702,17 @@ def decode_dataset( cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - context_graph=context_graph, - word_table=word_table, - batch=batch, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) for name, hyps in hyps_dict.items(): this_batch = [] @@ -734,10 +733,10 @@ def decode_dataset( def save_asr_output( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], - ): + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): """ Save text produced by ASR. """ @@ -752,10 +751,10 @@ def save_asr_output( def save_wer_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], - ): + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str], Tuple]]], +): """ Save WER and per-utterance word alignments. """ @@ -766,8 +765,8 @@ def save_wer_results( errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" with open(errs_filename, "w", encoding="utf8") as fd: wer = write_error_stats( - fd, f"{test_set_name}-{key}", results, enable_log=True - ) + fd, f"{test_set_name}-{key}", results, enable_log=True + ) test_set_wers[key] = wer logging.info(f"Wrote detailed error stats to {errs_filename}") @@ -804,18 +803,18 @@ def main(): set_caching_enabled(True) # lhotse assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - "modified_beam_search_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - ) + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) params.res_dir = params.exp_dir / params.decoding_method if os.path.exists(params.context_file): @@ -830,11 +829,11 @@ def main(): if params.causal: assert ( - "," not in params.chunk_size - ), "chunk_size should be one value in decoding." + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." params.suffix += f"_chunk-{params.chunk_size}" params.suffix += f"_left-context-{params.left_context_frames}" @@ -850,9 +849,9 @@ def main(): elif "beam_search" in params.decoding_method: params.suffix += f"__{params.decoding_method}__beam-size-{params.beam_size}" if params.decoding_method in ( - "modified_beam_search", - "modified_beam_search_LODR", - ): + "modified_beam_search", + "modified_beam_search_LODR", + ): if params.has_contexts: params.suffix += f"-context-score-{params.context_score}" else: @@ -864,8 +863,8 @@ def main(): if "LODR" in params.decoding_method: params.suffix += ( - f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" - ) + f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) if params.use_averaged_model: params.suffix += "_use-averaged-model" @@ -895,18 +894,18 @@ def main(): if not params.use_averaged_model: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + : params.avg + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) elif len(filenames) < params.avg: raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) @@ -924,32 +923,32 @@ def main(): else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + : params.avg + 1 + ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) elif len(filenames) < params.avg + 1: raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) filename_start = filenames[-1] filename_end = filenames[0] logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) model.to(device) model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) else: assert params.avg > 0, params.avg start = params.epoch - params.avg @@ -957,34 +956,34 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) model.to(device) model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) model.to(device) model.eval() # only load the neural network LM if required if params.use_shallow_fusion or params.decoding_method in ( - "modified_beam_search_lm_rescore", - "modified_beam_search_lm_rescore_LODR", - "modified_beam_search_lm_shallow_fusion", - "modified_beam_search_LODR", - ): + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): LM = LmScorer( - lm_type=params.lm_type, - params=params, - device=device, - lm_scale=params.lm_scale, - ) + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) LM.to(device) LM.eval() else: @@ -1010,10 +1009,10 @@ def main(): lm_filename = f"{params.tokens_ngram}gram.fst.txt" logging.info(f"Loading token level lm: {lm_filename}") ngram_lm = NgramLm( - str(params.lang_dir / lm_filename), - backoff_id=params.backoff_id, - is_binary=False, - ) + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) logging.info(f"num states: {ngram_lm.lm.num_states}") ngram_lm_scale = params.ngram_lm_scale else: @@ -1027,8 +1026,8 @@ def main(): lg_filename = params.lang_dir / "LG.pt" logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) + torch.load(lg_filename, map_location=device) + ) decoding_graph.scores *= params.ngram_lm_scale else: word_table = None @@ -1067,17 +1066,17 @@ def main(): for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - context_graph=context_graph, - LM=LM, - ngram_lm=ngram_lm, - ngram_lm_scale=ngram_lm_scale, - ) + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) save_asr_output( params=params,