From 7b6e0de5bbc7b7e3a8b36e4fe1969a046d41eb23 Mon Sep 17 00:00:00 2001 From: jinzr Date: Fri, 3 Nov 2023 11:38:37 +0800 Subject: [PATCH 01/15] Delete streaming_decode.py --- egs/multi_zh-hans/ASR/zipformer/streaming_decode.py | 1 - 1 file changed, 1 deletion(-) delete mode 120000 egs/multi_zh-hans/ASR/zipformer/streaming_decode.py diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py deleted file mode 120000 index 13fd02a78..000000000 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/zipformer/streaming_decode.py \ No newline at end of file From 8fe390518c696bf6b1ae174d3c05a40be8120185 Mon Sep 17 00:00:00 2001 From: jinzr Date: Fri, 3 Nov 2023 11:45:59 +0800 Subject: [PATCH 02/15] refined `streaming_decode.py` --- .../ASR/zipformer/decode_stream.py | 1 + .../ASR/zipformer/streaming_decode.py | 858 ++++++++++++++++++ 2 files changed, 859 insertions(+) create mode 120000 egs/multi_zh-hans/ASR/zipformer/decode_stream.py create mode 100755 egs/multi_zh-hans/ASR/zipformer/streaming_decode.py diff --git a/egs/multi_zh-hans/ASR/zipformer/decode_stream.py b/egs/multi_zh-hans/ASR/zipformer/decode_stream.py new file mode 120000 index 000000000..b8d8ddfc4 --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decode_stream.py \ No newline at end of file diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py new file mode 100755 index 000000000..dfc046fdc --- /dev/null +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py @@ -0,0 +1,858 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import AsrDataModule +from decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from multi_dataset import MultiDataset +from streaming_beam_search import ( + fast_beam_search_one_best, + greedy_search, + modified_beam_search, +) +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + if params.decoding_method == "greedy_search": + greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + elif params.decoding_method == "fast_beam_search": + processed_lens = torch.tensor(processed_lens, device=device) + processed_lens = processed_lens + encoder_out_lens + fast_beam_search_one_best( + model=model, + encoder_out=encoder_out, + processed_lens=processed_lens, + streams=decode_streams, + beam=params.beam, + max_states=params.max_states, + max_contexts=params.max_contexts, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=decode_streams, + encoder_out=encoder_out, + num_active_paths=params.num_active_paths, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + elif params.decoding_method == "modified_beam_search": + key = f"num_active_paths_{params.num_active_paths}" + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + decoding_graph = None + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + data_module = AsrDataModule(args) + multi_dataset = MultiDataset(args.manifest_dir) + + test_sets_cuts = multi_dataset.test_cuts() + test_sets = test_sets_cuts.keys() + test_cuts = [test_sets_cuts[cuts_name] for cuts_name in test_sets] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 72fd8e2c7c2a9cce5097521251d42aba7314b2d3 Mon Sep 17 00:00:00 2001 From: jinzr Date: Sun, 5 Nov 2023 16:42:46 +0800 Subject: [PATCH 03/15] black formatted --- egs/multi_zh-hans/ASR/zipformer/streaming_decode.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py index dfc046fdc..34ff7ad45 100755 --- a/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py +++ b/egs/multi_zh-hans/ASR/zipformer/streaming_decode.py @@ -376,11 +376,7 @@ def streaming_forward( Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( + (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, From 2aab91e561566f703155a740fe5dea054765ce73 Mon Sep 17 00:00:00 2001 From: jinzr Date: Mon, 6 Nov 2023 10:45:01 +0800 Subject: [PATCH 04/15] Update run-multi-zh_hans-zipformer.sh --- .../scripts/run-multi-zh_hans-zipformer.sh | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index dd32a94f8..7cc02de87 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -92,4 +92,53 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav +done + +log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-20.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-ctc 1 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --method greedy_search \ +$repo/test_wavs/DEV_T0000000000.wav \ +$repo/test_wavs/DEV_T0000000001.wav \ +$repo/test_wavs/DEV_T0000000002.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --use-ctc 1 \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav done \ No newline at end of file From dfe263a7fb6d4b01320521d67dba7f835b0b1710 Mon Sep 17 00:00:00 2001 From: jinzr Date: Mon, 6 Nov 2023 14:23:25 +0800 Subject: [PATCH 05/15] Update run-multi-zh_hans-zipformer.sh --- .github/scripts/run-multi-zh_hans-zipformer.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/scripts/run-multi-zh_hans-zipformer.sh b/.github/scripts/run-multi-zh_hans-zipformer.sh index 7cc02de87..02d262747 100755 --- a/.github/scripts/run-multi-zh_hans-zipformer.sh +++ b/.github/scripts/run-multi-zh_hans-zipformer.sh @@ -51,6 +51,8 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000002.wav done +rm -rf $repo + log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ @@ -94,6 +96,8 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000002.wav done +rm -rf $repo + log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05/ @@ -141,4 +145,6 @@ for method in modified_beam_search fast_beam_search; do $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ $repo/test_wavs/DEV_T0000000002.wav -done \ No newline at end of file +done + +rm -rf $repo \ No newline at end of file From 4833bc0defa74bc82b63d3056cca1100d88438b4 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 9 Nov 2023 10:21:08 +0800 Subject: [PATCH 06/15] Update RESULTS.md --- egs/multi_zh-hans/ASR/RESULTS.md | 41 ++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 15e789604..8c54806c4 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -4,6 +4,47 @@ This is the [pull request #1238](https://github.com/k2-fsa/icefall/pull/1238) in icefall. +#### Streaming (with CTC head) + +Best results (num of params : ~69M): + +The training command: + +``` +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 20 \ + --use-fp16 1 \ + --max-duration 600 \ + --num-workers 8 \ + --use-ctc 1 \ + --causal 1 \ + --exp-dir ./zipformer/exp-streaming +``` + +The decoding command: + +``` +./zipformer/decode.py \ + --causal 1 \ + --avg 1 \ + --epoch 20 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --use-ctc 1 \ + --exp-dir ./zipformer/exp-streaming +``` + +Character Error Rates (CERs) listed below are produced by the checkpoint of the 20th epoch using BPE model ( # tokens is 2000, byte fallback enabled). + +| Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | +|--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| CTC Decoding | 6.17 | 6.75 | 30.0 | 31.44 | 3.8 | 4.48 | 4.69 | 5.18 | 19.59 | 10.85 | 9.83 | 14.45 | 8.53 | 16.48 | 10.24 | 9.73 | 12.03 | +| Greedy Search | 4.71 | 5.18 | 33.78 | 35.2 | 3.28 | 3.63 | 4.34 | 4.76 | 21.99 | 4.65 | 3.98 | 10.42 | 3.94 | 12.54 | 8.16 | 9.73 | 9.66 | + +Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05/ + #### Non-streaming (with CTC head) Best results (num of params : ~69M): From cac22f4075263b46d68d37c1a3f5705eaaca1977 Mon Sep 17 00:00:00 2001 From: jinzr Date: Fri, 10 Nov 2023 11:18:19 +0800 Subject: [PATCH 07/15] Update RESULTS.md --- egs/multi_zh-hans/ASR/RESULTS.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/egs/multi_zh-hans/ASR/RESULTS.md b/egs/multi_zh-hans/ASR/RESULTS.md index 8c54806c4..b89c42997 100644 --- a/egs/multi_zh-hans/ASR/RESULTS.md +++ b/egs/multi_zh-hans/ASR/RESULTS.md @@ -39,9 +39,9 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the | Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | |--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 6.17 | 6.75 | 30.0 | 31.44 | 3.8 | 4.48 | 4.69 | 5.18 | 19.59 | 10.85 | 9.83 | 14.45 | 8.53 | 16.48 | 10.24 | 9.73 | 12.03 | -| Greedy Search | 4.71 | 5.18 | 33.78 | 35.2 | 3.28 | 3.63 | 4.34 | 4.76 | 21.99 | 4.65 | 3.98 | 10.42 | 3.94 | 12.54 | 8.16 | 9.73 | 9.66 | +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Encoder CTC Decoding (Non-streaming) | 6.17 | 6.75 | 30.0 | 31.44 | 3.8 | 4.48 | 4.69 | 5.18 | 19.59 | 10.85 | 9.83 | 14.45 | 8.53 | 16.48 | 10.24 | 9.73 | 12.03 | +| Transducer Greedy Search | 4.71 | 5.18 | 33.78 | 35.2 | 3.28 | 3.63 | 4.34 | 4.76 | 21.99 | 4.65 | 3.98 | 10.42 | 3.94 | 12.54 | 8.16 | 9.73 | 9.66 | Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05/ @@ -74,9 +74,9 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the | Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | |--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | -| CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | -| Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | +| Zipformer CER (%) | dev | test | eval | test | dev | test | dev | test | test | dev | test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Encoder CTC Decoding | 2.86 | 3.36 | 22.93 | 24.28 | 2.05 | 2.27 | 3.33 | 3.82 | 15.45 | 3.49 | 2.77 | 6.90 | 2.85 | 8.29 | 9.41 | 6.92 | 8.57 | +| Transducer Greedy Search | 3.36 | 3.83 | 23.90 | 25.18 | 2.77 | 3.08 | 3.70 | 4.04 | 16.13 | 3.77 | 3.15 | 6.88 | 3.14 | 8.08 | 9.04 | 7.19 | 8.17 | Pre-trained model can be found here : https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/ @@ -107,7 +107,7 @@ Character Error Rates (CERs) listed below are produced by the checkpoint of the | Datasets | aidatatang _200zh | aidatatang _200zh | alimeeting | alimeeting | aishell-1 | aishell-1 | aishell-2 | aishell-2 | aishell-4 | magicdata | magicdata | kespeech-asr | kespeech-asr | kespeech-asr | WenetSpeech | WenetSpeech | WenetSpeech | |--------------------------------|------------------------------|-------------|-------------------|--------------|----------------|-------------|------------------|-------------|------------------|------------------|-------------|-----------------------|-----------------------|-------------|--------------------|-------------------------|---------------------| -| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | +| Zipformer CER (%) | dev | test | eval| test | dev | test | dev| test | test | dev| test | dev phase1 | dev phase2 | test | dev | test meeting | test net | | Greedy Search | 3.2 | 3.67 | 23.15 | 24.78 | 2.91 | 3.04 | 3.59 | 4.03 | 15.68 | 3.68 | 3.12 | 6.69 | 3.19 | 8.01 | 9.32 | 7.05 | 8.78 | From 20f96ac6a82f55312a5008e2114135ee646aeed3 Mon Sep 17 00:00:00 2001 From: jinzr Date: Mon, 4 Dec 2023 15:31:46 +0800 Subject: [PATCH 08/15] init commit --- .../ASR/zipformer/ctc_decode_stream.py | 127 +++ .../ASR/zipformer/streaming_ctc_decode.py | 885 ++++++++++++++++++ 2 files changed, 1012 insertions(+) create mode 100644 egs/librispeech/ASR/zipformer/ctc_decode_stream.py create mode 100755 egs/librispeech/ASR/zipformer/streaming_ctc_decode.py diff --git a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py new file mode 100644 index 000000000..ed83d02a7 --- /dev/null +++ b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py @@ -0,0 +1,127 @@ +# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + initial_states: List[torch.Tensor], + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + Used only when decoding_method is fast_beam_search. + device: + The device to run this stream. + """ + + self.params = params + self.cut_id = cut_id + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. + self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # The decoding result (partial or final) of current utterance. + self.hyp: List = [] + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut_id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + return self.hyp diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py new file mode 100755 index 000000000..9d86b507a --- /dev/null +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -0,0 +1,885 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: +./zipformer/streaming_decode.py \ + --epoch 28 \ + --avg 15 \ + --causal 1 \ + --chunk-size 32 \ + --left-context-frames 256 \ + --exp-dir ./zipformer/exp \ + --decoding-method greedy_search \ + --num-decode-streams 2000 +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from ctc_decode_stream import DecodeStream +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import ( + AttributeDict, + get_texts, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=20, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Supported decoding methods are: + greedy_search + modified_beam_search + fast_beam_search + """, + ) + + parser.add_argument( + "--num_active_paths", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=32, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + add_model_arguments(parser) + + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + embed_states = model.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. + + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. + """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. + """ + cached_embed_left_pad = states[-2] + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + H: Optional[k2.Fsa], + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = model.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = streaming_forward( + features=features, + feature_lens=feature_lens, + model=model, + states=states, + chunk_size=chunk_size, + left_context_len=left_context_len, + ) + ctc_output = model.ctc_output(encoder_out) # (N, T, C) + + supervision_segments = torch.stack( + ( + # supervisions["sequence_idx"], + list(map(lambda x: x.cut_id, decode_streams)), + torch.div( + 0, + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + feature_lens, + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + decoding_graph = H + + lattice = get_lattice( + nnet_output=ctc_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].hyp += token_ids[i] + decode_streams[i].states = states[i] + decode_streams[i].done_frames += encoder_out_lens[i] + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = model.device + + opts = FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + log_interval = 100 + + decode_results = [] + # Contain decode streams currently running. + decode_streams = [] + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = get_init_states(model=model, batch_size=1, device=device) + decode_stream = DecodeStream( + params=params, + cut_id=cut.id, + initial_states=initial_states, + decoding_graph=decoding_graph, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank(samples.to(device)) + decode_stream.set_features(feature, tail_pad_len=30) + decode_stream.ground_truth = cut.supervisions[0].text + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, H=decoding_graph, decode_streams=decode_streams + ) + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + decode_streams[i].id, + decode_streams[i].ground_truth.split(), + sp.decode(decode_streams[i].decoding_result()).split(), + ) + ) + del decode_streams[i] + + key = "ctc-decoding" + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params = get_decoding_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + assert params.causal, params.causal + assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + # for fast_beam_search + if params.decoding_method == "fast_beam_search": + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + max_token_id = sp.get_piece_size() - 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + H = k2.ctc_topo( + max_token=max_token_id, + modified=True, + device=device, + ) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_cuts = librispeech.test_cuts() + test_sets = {"test": test_cuts} + + for test_set, test_cut in test_sets.items(): + results_dict = decode_dataset( + cuts=test_cut, + params=params, + model=model, + sp=sp, + decoding_graph=H, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() From 7fb3b130668bb2709c6932af788652e7078993eb Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 7 Dec 2023 10:49:18 +0800 Subject: [PATCH 09/15] Update streaming_ctc_decode.py --- .../ASR/zipformer/streaming_ctc_decode.py | 248 +----------------- 1 file changed, 9 insertions(+), 239 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index 9d86b507a..aa031d930 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -45,7 +45,13 @@ from asr_datamodule import LibriSpeechAsrDataModule from ctc_decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet -from torch import Tensor, nn +from streaming_decode import ( + get_init_states, + stack_states, + streaming_forward, + unstack_states, +) +from torch import nn from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params @@ -55,18 +61,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) +from icefall.decode import get_lattice, one_best_decoding from icefall.utils import ( AttributeDict, get_texts, - make_pad_mask, setup_logger, store_transcripts, str2bool, @@ -198,234 +196,6 @@ def get_parser(): return parser -def get_init_states( - model: nn.Module, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), -) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = model.encoder.get_init_states(batch_size, device) - - embed_states = model.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. For element-n, - state_list[n] is a list of cached tensors of all encoder layers. For layer-i, - state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, - cached_val2, cached_conv1, cached_conv2). - state_list[n][-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - state_list[n][-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Note: - It is the inverse of :func:`unstack_states`. - """ - batch_size = len(state_list) - assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) - tot_num_layers = (len(state_list[0]) - 2) // 6 - - batch_states = [] - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key = torch.cat( - [state_list[i][layer_offset] for i in range(batch_size)], dim=1 - ) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn = torch.cat( - [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1 = torch.cat( - [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2 = torch.cat( - [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1 = torch.cat( - [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2 = torch.cat( - [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 - ) - batch_states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - - cached_embed_left_pad = torch.cat( - [state_list[i][-2] for i in range(batch_size)], dim=0 - ) - batch_states.append(cached_embed_left_pad) - - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) - batch_states.append(processed_lens) - - return batch_states - - -def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - batch_states: A list of cached tensors of all encoder layers. For layer-i, - states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, - cached_conv1, cached_conv2). - state_list[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - - Returns: - state_list: A list of list. Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - """ - assert (len(batch_states) - 2) % 6 == 0, len(batch_states) - tot_num_layers = (len(batch_states) - 2) // 6 - - processed_lens = batch_states[-1] - batch_size = processed_lens.shape[0] - - state_list = [[] for _ in range(batch_size)] - - for layer in range(tot_num_layers): - layer_offset = layer * 6 - # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) - # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) - cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( - chunks=batch_size, dim=1 - ) - # cached_val1: (left_context_len, batch_size, value_dim) - cached_val1_list = batch_states[layer_offset + 2].chunk( - chunks=batch_size, dim=1 - ) - # cached_val2: (left_context_len, batch_size, value_dim) - cached_val2_list = batch_states[layer_offset + 3].chunk( - chunks=batch_size, dim=1 - ) - # cached_conv1: (#batch, channels, left_pad) - cached_conv1_list = batch_states[layer_offset + 4].chunk( - chunks=batch_size, dim=0 - ) - # cached_conv2: (#batch, channels, left_pad) - cached_conv2_list = batch_states[layer_offset + 5].chunk( - chunks=batch_size, dim=0 - ) - for i in range(batch_size): - state_list[i] += [ - cached_key_list[i], - cached_nonlin_attn_list[i], - cached_val1_list[i], - cached_val2_list[i], - cached_conv1_list[i], - cached_conv2_list[i], - ] - - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(cached_embed_left_pad_list[i]) - - processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) - for i in range(batch_size): - state_list[i].append(processed_lens_list[i]) - - return state_list - - -def streaming_forward( - features: Tensor, - feature_lens: Tensor, - model: nn.Module, - states: List[Tensor], - chunk_size: int, - left_context_len: int, -) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Returns encoder outputs, output lengths, and updated states. - """ - cached_embed_left_pad = states[-2] - ( - x, - x_lens, - new_cached_embed_left_pad, - ) = model.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = model.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -493,7 +263,7 @@ def decode_one_chunk( supervision_segments = torch.stack( ( # supervisions["sequence_idx"], - list(map(lambda x: x.cut_id, decode_streams)), + torch.tensor([index for index, _ in enumerate(decode_streams)]), torch.div( 0, params.subsampling_factor, From 206f2143e8b86f81f4879e6b587677deae136415 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 7 Dec 2023 11:12:47 +0800 Subject: [PATCH 10/15] Update streaming_ctc_decode.py --- .../ASR/zipformer/streaming_ctc_decode.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index aa031d930..525c0f405 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -128,7 +128,7 @@ def get_parser(): parser.add_argument( "--bpe-model", type=str, - default="data/lang_bpe_2000/bpe.model", + default="data/lang_bpe_500/bpe.model", help="Path to the BPE model", ) @@ -196,6 +196,22 @@ def get_parser(): return parser +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "feature_dim": 80, + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + def decode_one_chunk( params: AttributeDict, model: nn.Module, @@ -463,21 +479,6 @@ def save_results( logging.info(s) -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - @torch.no_grad() def main(): parser = get_parser() From 44f0c8bd8abd738e2e22c36e717b6f6b313e20b7 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 7 Dec 2023 11:15:45 +0800 Subject: [PATCH 11/15] Update streaming_ctc_decode.py --- egs/librispeech/ASR/zipformer/streaming_ctc_decode.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index 525c0f405..3e7b5e1e8 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -631,8 +631,12 @@ def main(): librispeech = LibriSpeechAsrDataModule(args) - test_cuts = librispeech.test_cuts() - test_sets = {"test": test_cuts} + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + test_sets = { + "test-clean": test_clean_cuts, + "test-other": test_other_cuts, + } for test_set, test_cut in test_sets.items(): results_dict = decode_dataset( From 975f6e03fdd37251d9778c543cc4aef730341564 Mon Sep 17 00:00:00 2001 From: jinzr Date: Thu, 7 Dec 2023 11:17:25 +0800 Subject: [PATCH 12/15] Update streaming_ctc_decode.py --- egs/librispeech/ASR/zipformer/streaming_ctc_decode.py | 1 - 1 file changed, 1 deletion(-) diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index 3e7b5e1e8..ab550290c 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -377,7 +377,6 @@ def decode_dataset( params=params, cut_id=cut.id, initial_states=initial_states, - decoding_graph=decoding_graph, device=device, ) From ad7aeaccb0b445b21acdc1d60ac8a5355882ea22 Mon Sep 17 00:00:00 2001 From: JinZr <60612200+JinZr@users.noreply.github.com> Date: Thu, 7 Dec 2023 18:34:25 +0800 Subject: [PATCH 13/15] minor updates --- .../ASR/zipformer/ctc_decode_stream.py | 2 + .../ASR/zipformer/streaming_ctc_decode.py | 104 ++++++++++++------ 2 files changed, 74 insertions(+), 32 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py index ed83d02a7..066f8afc3 100644 --- a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py +++ b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py @@ -31,6 +31,7 @@ class DecodeStream(object): params: AttributeDict, cut_id: str, initial_states: List[torch.Tensor], + decode_state: k2.DecodeStateInfo, device: torch.device = torch.device("cpu"), ) -> None: """ @@ -50,6 +51,7 @@ class DecodeStream(object): self.LOG_EPS = math.log(1e-10) self.states = initial_states + self.decode_state = decode_state # It contains a 2-D tensors representing the feature frames. self.features: torch.Tensor = None diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py index ab550290c..52a5162fb 100755 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py @@ -201,6 +201,7 @@ def get_decoding_params() -> AttributeDict: params = AttributeDict( { "feature_dim": 80, + "subsampling_factor": 4, "frame_shift_ms": 10, "search_beam": 20, "output_beam": 8, @@ -216,7 +217,9 @@ def decode_one_chunk( params: AttributeDict, model: nn.Module, H: Optional[k2.Fsa], + intersector: k2.OnlineDenseIntersecter, decode_streams: List[DecodeStream], + streams_to_pad: int = None, ) -> List[int]: """Decode one chunk frames of features for each decode_streams and return the indexes of finished streams in a List. @@ -276,39 +279,49 @@ def decode_one_chunk( ) ctc_output = model.ctc_output(encoder_out) # (N, T, C) - supervision_segments = torch.stack( - ( - # supervisions["sequence_idx"], - torch.tensor([index for index, _ in enumerate(decode_streams)]), - torch.div( - 0, - params.subsampling_factor, - rounding_mode="floor", - ), - torch.div( - feature_lens, - params.subsampling_factor, - rounding_mode="floor", - ), - ), - 1, - ).to(torch.int32) + if streams_to_pad: + ctc_output = torch.cat( + [ + ctc_output, + torch.zeros( + (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)), + device=device, + ), + ] + ) - decoding_graph = H - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=decoding_graph, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, + supervision_segments = torch.tensor( + [[i, 0, 8] for i in range(params.num_decode_streams)], + dtype=torch.int32, ) + + # decoding_graph = H + + # lattice = get_lattice( + # nnet_output=ctc_output, + # decoding_graph=decoding_graph, + # supervision_segments=supervision_segments, + # search_beam=params.search_beam, + # output_beam=params.output_beam, + # min_active_states=params.min_active_states, + # max_active_states=params.max_active_states, + # subsampling_factor=params.subsampling_factor, + # ) + dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments) + + current_decode_states = [ + decode_stream.decode_state for decode_stream in decode_streams + ] + if streams_to_pad: + current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad + lattice, current_decode_states = intersector.decode( + dense_fsa_vec, current_decode_states + ) + best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs # since we are using H, not HLG here. # @@ -317,10 +330,15 @@ def decode_one_chunk( states = unstack_states(new_states) + num_streams = ( + len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams) + ) + finished_streams = [] - for i in range(len(decode_streams)): + for i in range(num_streams): decode_streams[i].hyp += token_ids[i] decode_streams[i].states = states[i] + decode_streams[i].decode_state = current_decode_states[i] decode_streams[i].done_frames += encoder_out_lens[i] if decode_streams[i].done: finished_streams.append(i) @@ -367,6 +385,15 @@ def decode_dataset( log_interval = 100 + intersector = k2.OnlineDenseIntersecter( + decoding_graph=decoding_graph, + num_streams=params.num_decode_streams, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + decode_results = [] # Contain decode streams currently running. decode_streams = [] @@ -377,6 +404,7 @@ def decode_dataset( params=params, cut_id=cut.id, initial_states=initial_states, + decode_state=k2.DecodeStateInfo(), device=device, ) @@ -400,7 +428,11 @@ def decode_dataset( while len(decode_streams) >= params.num_decode_streams: finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams + params=params, + model=model, + H=decoding_graph, + intersector=intersector, + decode_streams=decode_streams, ) for i in sorted(finished_streams, reverse=True): decode_results.append( @@ -415,10 +447,16 @@ def decode_dataset( if num % log_interval == 0: logging.info(f"Cuts processed until now is {num}.") + num_remained_decode_streams = len(decode_streams) # decode final chunks of last sequences - while len(decode_streams): + while num_remained_decode_streams: finished_streams = decode_one_chunk( - params=params, model=model, H=decoding_graph, decode_streams=decode_streams + params=params, + model=model, + H=decoding_graph, + intersector=intersector, + decode_streams=decode_streams, + streams_to_pad=params.num_decode_streams - num_remained_decode_streams, ) for i in sorted(finished_streams, reverse=True): decode_results.append( @@ -429,6 +467,7 @@ def decode_dataset( ) ) del decode_streams[i] + num_remained_decode_streams -= 1 key = "ctc-decoding" return {key: decode_results} @@ -624,6 +663,7 @@ def main(): modified=True, device=device, ) + H = k2.Fsa.from_fsas([H]) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") From a1177c979a0c19a7bbaf16c1b1fcd8bd1b194bdf Mon Sep 17 00:00:00 2001 From: jinzr Date: Tue, 30 Jan 2024 09:48:57 +0800 Subject: [PATCH 14/15] minor updates --- .../scripts/run-multi-corpora-zipformer.sh | 47 ++ .../ASR/zipformer/ctc_decode_stream.py | 129 ---- .../ASR/zipformer/streaming_ctc_decode.py | 699 ------------------ 3 files changed, 47 insertions(+), 828 deletions(-) delete mode 100644 egs/librispeech/ASR/zipformer/ctc_decode_stream.py delete mode 100755 egs/librispeech/ASR/zipformer/streaming_ctc_decode.py diff --git a/.github/scripts/run-multi-corpora-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh index 90f859f43..7b79da122 100755 --- a/.github/scripts/run-multi-corpora-zipformer.sh +++ b/.github/scripts/run-multi-corpora-zipformer.sh @@ -98,6 +98,53 @@ done rm -rf $repo +log "==== Test icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05 ====" +repo_url=https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-streaming-2023-11-05/ + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-20.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + + +./zipformer/pretrained.py \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + --use-ctc 1 \ + --causal 1 \ + --method greedy_search \ +$repo/test_wavs/DEV_T0000000000.wav \ +$repo/test_wavs/DEV_T0000000001.wav \ +$repo/test_wavs/DEV_T0000000002.wav + +for method in modified_beam_search fast_beam_search; do + log "$method" + + ./zipformer/pretrained.py \ + --method $method \ + --beam-size 4 \ + --use-ctc 1 \ + --causal 1 \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_2000/tokens.txt \ + $repo/test_wavs/DEV_T0000000000.wav \ + $repo/test_wavs/DEV_T0000000001.wav \ + $repo/test_wavs/DEV_T0000000002.wav +done + +rm -rf $repo + cd ../../../egs/multi_zh_en/ASR log "==== Test icefall-asr-zipformer-multi-zh-en-2023-11-22 ====" repo_url=https://huggingface.co/zrjin/icefall-asr-zipformer-multi-zh-en-2023-11-22/ diff --git a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py b/egs/librispeech/ASR/zipformer/ctc_decode_stream.py deleted file mode 100644 index 066f8afc3..000000000 --- a/egs/librispeech/ASR/zipformer/ctc_decode_stream.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decode_state: k2.DecodeStateInfo, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - self.decode_state = decode_state - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, at encoder output - self.done_frames: int = 0 - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - tail_pad_len: int = 0, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length + tail_pad_len), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - return self.hyp diff --git a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py b/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py deleted file mode 100755 index 52a5162fb..000000000 --- a/egs/librispeech/ASR/zipformer/streaming_ctc_decode.py +++ /dev/null @@ -1,699 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, -# Fangjun Kuang, -# Zengwei Yao, -# Zengrui Jin,) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./zipformer/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --causal 1 \ - --chunk-size 32 \ - --left-context-frames 256 \ - --exp-dir ./zipformer/exp \ - --decoding-method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from ctc_decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_decode import ( - get_init_states, - stack_states, - streaming_forward, - unstack_states, -) -from torch import nn -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=20, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def get_decoding_params() -> AttributeDict: - """Parameters for decoding.""" - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 4, - "frame_shift_ms": 10, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - H: Optional[k2.Fsa], - intersector: k2.OnlineDenseIntersecter, - decode_streams: List[DecodeStream], - streams_to_pad: int = None, -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - chunk_size = int(params.chunk_size) - left_context_len = int(params.left_context_frames) - - features = [] - feature_lens = [] - states = [] - processed_lens = [] # Used in fast-beam-search - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(chunk_size * 2) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # Make sure the length after encoder_embed is at least 1. - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - tail_length = chunk_size * 2 + 7 + 2 * 3 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - - encoder_out, encoder_out_lens, new_states = streaming_forward( - features=features, - feature_lens=feature_lens, - model=model, - states=states, - chunk_size=chunk_size, - left_context_len=left_context_len, - ) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - if streams_to_pad: - ctc_output = torch.cat( - [ - ctc_output, - torch.zeros( - (streams_to_pad, ctc_output.size(-2), ctc_output.size(-1)), - device=device, - ), - ] - ) - - supervision_segments = torch.tensor( - [[i, 0, 8] for i in range(params.num_decode_streams)], - dtype=torch.int32, - ) - - # decoding_graph = H - - # lattice = get_lattice( - # nnet_output=ctc_output, - # decoding_graph=decoding_graph, - # supervision_segments=supervision_segments, - # search_beam=params.search_beam, - # output_beam=params.output_beam, - # min_active_states=params.min_active_states, - # max_active_states=params.max_active_states, - # subsampling_factor=params.subsampling_factor, - # ) - dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments) - - current_decode_states = [ - decode_stream.decode_state for decode_stream in decode_streams - ] - if streams_to_pad: - current_decode_states += [k2.DecodeStateInfo()] * streams_to_pad - lattice, current_decode_states = intersector.decode( - dense_fsa_vec, current_decode_states - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - token_ids = get_texts(best_path) - - states = unstack_states(new_states) - - num_streams = ( - len(decode_streams) - streams_to_pad if streams_to_pad else len(decode_streams) - ) - - finished_streams = [] - for i in range(num_streams): - decode_streams[i].hyp += token_ids[i] - decode_streams[i].states = states[i] - decode_streams[i].decode_state = current_decode_states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 100 - - intersector = k2.OnlineDenseIntersecter( - decoding_graph=decoding_graph, - num_streams=params.num_decode_streams, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = get_init_states(model=model, batch_size=1, device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decode_state=k2.DecodeStateInfo(), - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=30) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, - model=model, - H=decoding_graph, - intersector=intersector, - decode_streams=decode_streams, - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - num_remained_decode_streams = len(decode_streams) - # decode final chunks of last sequences - while num_remained_decode_streams: - finished_streams = decode_one_chunk( - params=params, - model=model, - H=decoding_graph, - intersector=intersector, - decode_streams=decode_streams, - streams_to_pad=params.num_decode_streams - num_remained_decode_streams, - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - num_remained_decode_streams -= 1 - - key = "ctc-decoding" - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params = get_decoding_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." - assert ( - "," not in params.left_context_frames - ), "left_context_frames should be one value in decoding." - params.suffix += f"-chunk-{params.chunk_size}" - params.suffix += f"-left-context-{params.left_context_frames}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - max_token_id = sp.get_piece_size() - 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - H = k2.ctc_topo( - max_token=max_token_id, - modified=True, - device=device, - ) - H = k2.Fsa.from_fsas([H]) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - test_sets = { - "test-clean": test_clean_cuts, - "test-other": test_other_cuts, - } - - for test_set, test_cut in test_sets.items(): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=H, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() From 76002685398ce4e203cebeec50691c70939ba59e Mon Sep 17 00:00:00 2001 From: jinzr Date: Tue, 30 Jan 2024 09:59:19 +0800 Subject: [PATCH 15/15] fixing the CI test --- .github/scripts/run-multi-corpora-zipformer.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/scripts/run-multi-corpora-zipformer.sh b/.github/scripts/run-multi-corpora-zipformer.sh index 7b79da122..6d9fbb973 100755 --- a/.github/scripts/run-multi-corpora-zipformer.sh +++ b/.github/scripts/run-multi-corpora-zipformer.sh @@ -123,6 +123,8 @@ ls -lh $repo/exp/*.pt --tokens $repo/data/lang_bpe_2000/tokens.txt \ --use-ctc 1 \ --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ --method greedy_search \ $repo/test_wavs/DEV_T0000000000.wav \ $repo/test_wavs/DEV_T0000000001.wav \ @@ -136,6 +138,8 @@ for method in modified_beam_search fast_beam_search; do --beam-size 4 \ --use-ctc 1 \ --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ --checkpoint $repo/exp/epoch-99.pt \ --tokens $repo/data/lang_bpe_2000/tokens.txt \ $repo/test_wavs/DEV_T0000000000.wav \