#!/usr/bin/env python3 # Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, # Fangjun Kuang, # Zengwei Yao) # # See ../../../../LICENSE for clarification regarding multiple authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Usage: ./zipformer/streaming_decode.py \ --epoch 28 \ --avg 15 \ --causal 1 \ --chunk-size 16 \ --left-context-frames 256 \ --exp-dir ./zipformer/exp \ --decoding-method greedy_search \ --num-decode-streams 2000 """ import argparse import logging import math from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 import numpy as np import torch from asr_datamodule import AishellAsrDataModule from decode_stream import DecodeStream from kaldifeat import Fbank, FbankOptions from lhotse import CutSet from streaming_beam_search import ( fast_beam_search_one_best, greedy_search, modified_beam_search, ) from torch import Tensor, nn from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_model, get_params from icefall.checkpoint import ( average_checkpoints, average_checkpoints_with_averaged_model, find_checkpoints, load_checkpoint, ) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, make_pad_mask, setup_logger, store_transcripts, str2bool, write_error_stats, ) LOG_EPS = math.log(1e-10) def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument( "--epoch", type=int, default=28, help="""It specifies the checkpoint to use for decoding. Note: Epoch counts from 1. You can specify --avg to use more checkpoints for model averaging.""", ) parser.add_argument( "--iter", type=int, default=0, help="""If positive, --epoch is ignored and it will use the checkpoint exp_dir/checkpoint-iter.pt. You can specify --avg to use more checkpoints for model averaging. """, ) parser.add_argument( "--avg", type=int, default=15, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, help="Whether to load averaged model. Currently it only supports " "using --epoch. If True, it would decode with the averaged model " "over the epoch range from `epoch-avg` (excluded) to `epoch`." "Actually only the models with epoch number of `epoch-avg` and " "`epoch` are loaded for averaging. ", ) parser.add_argument( "--exp-dir", type=str, default="zipformer/exp", help="The experiment dir", ) parser.add_argument( "--lang-dir", type=str, default="data/lang_char", help="Path to the lang dir(containing lexicon, tokens, etc.)", ) parser.add_argument( "--decoding-method", type=str, default="greedy_search", help="""Supported decoding methods are: greedy_search modified_beam_search fast_beam_search """, ) parser.add_argument( "--num_active_paths", type=int, default=4, help="""An interger indicating how many candidates we will keep for each frame. Used only when --decoding-method is modified_beam_search.""", ) parser.add_argument( "--beam", type=float, default=4, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( "--max-contexts", type=int, default=4, help="""Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( "--max-states", type=int, default=32, help="""Used only when --decoding-method is fast_beam_search""", ) parser.add_argument( "--context-size", type=int, default=2, help="The context size in the decoder. 1 means bigram; 2 means tri-gram", ) parser.add_argument( "--blank-penalty", type=float, default=0.0, help=""" The penalty applied on blank symbol during decoding. Note: It is a positive value that would be applied to logits like this `logits[:, 0] -= blank_penalty` (suppose logits.shape is [batch_size, vocab] and blank id is 0). """, ) parser.add_argument( "--num-decode-streams", type=int, default=2000, help="The number of streams that can be decoded parallel.", ) add_model_arguments(parser) return parser def get_init_states( model: nn.Module, batch_size: int = 1, device: torch.device = torch.device("cpu"), ) -> List[torch.Tensor]: """ Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). states[-2] is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) states[-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. """ states = model.encoder.get_init_states(batch_size, device) embed_states = model.encoder_embed.get_init_states(batch_size, device) states.append(embed_states) processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) states.append(processed_lens) return states def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: """Stack list of zipformer states that correspond to separate utterances into a single emformer state, so that it can be used as an input for zipformer when those utterances are formed into a batch. Args: state_list: Each element in state_list corresponding to the internal state of the zipformer model for a single utterance. For element-n, state_list[n] is a list of cached tensors of all encoder layers. For layer-i, state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). state_list[n][-2] is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) state_list[n][-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. Note: It is the inverse of :func:`unstack_states`. """ batch_size = len(state_list) assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0]) tot_num_layers = (len(state_list[0]) - 2) // 6 batch_states = [] for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) cached_key = torch.cat( [state_list[i][layer_offset] for i in range(batch_size)], dim=1 ) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn = torch.cat( [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 ) # cached_val1: (left_context_len, batch_size, value_dim) cached_val1 = torch.cat( [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 ) # cached_val2: (left_context_len, batch_size, value_dim) cached_val2 = torch.cat( [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 ) # cached_conv1: (#batch, channels, left_pad) cached_conv1 = torch.cat( [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 ) # cached_conv2: (#batch, channels, left_pad) cached_conv2 = torch.cat( [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 ) batch_states += [ cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2, ] cached_embed_left_pad = torch.cat( [state_list[i][-2] for i in range(batch_size)], dim=0 ) batch_states.append(cached_embed_left_pad) processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) batch_states.append(processed_lens) return batch_states def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: """Unstack the zipformer state corresponding to a batch of utterances into a list of states, where the i-th entry is the state from the i-th utterance in the batch. Note: It is the inverse of :func:`stack_states`. Args: batch_states: A list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). state_list[-2] is the cached left padding for ConvNeXt module, of shape (batch_size, num_channels, left_pad, num_freqs) states[-1] is processed_lens of shape (batch,), which records the number of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. Returns: state_list: A list of list. Each element in state_list corresponding to the internal state of the zipformer model for a single utterance. """ assert (len(batch_states) - 2) % 6 == 0, len(batch_states) tot_num_layers = (len(batch_states) - 2) // 6 processed_lens = batch_states[-1] batch_size = processed_lens.shape[0] state_list = [[] for _ in range(batch_size)] for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 ) # cached_val1: (left_context_len, batch_size, value_dim) cached_val1_list = batch_states[layer_offset + 2].chunk( chunks=batch_size, dim=1 ) # cached_val2: (left_context_len, batch_size, value_dim) cached_val2_list = batch_states[layer_offset + 3].chunk( chunks=batch_size, dim=1 ) # cached_conv1: (#batch, channels, left_pad) cached_conv1_list = batch_states[layer_offset + 4].chunk( chunks=batch_size, dim=0 ) # cached_conv2: (#batch, channels, left_pad) cached_conv2_list = batch_states[layer_offset + 5].chunk( chunks=batch_size, dim=0 ) for i in range(batch_size): state_list[i] += [ cached_key_list[i], cached_nonlin_attn_list[i], cached_val1_list[i], cached_val2_list[i], cached_conv1_list[i], cached_conv2_list[i], ] cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) for i in range(batch_size): state_list[i].append(processed_lens_list[i]) return state_list def streaming_forward( features: Tensor, feature_lens: Tensor, model: nn.Module, states: List[Tensor], chunk_size: int, left_context_len: int, ) -> Tuple[Tensor, Tensor, List[Tensor]]: """ Returns encoder outputs, output lengths, and updated states. """ cached_embed_left_pad = states[-2] (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward( x=features, x_lens=feature_lens, cached_left_pad=cached_embed_left_pad, ) assert x.size(1) == chunk_size, (x.size(1), chunk_size) src_key_padding_mask = make_pad_mask(x_lens) # processed_mask is used to mask out initial states processed_mask = torch.arange(left_context_len, device=x.device).expand( x.size(0), left_context_len ) processed_lens = states[-1] # (batch,) # (batch, left_context_size) processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) # Update processed lengths new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] ( encoder_out, encoder_out_lens, new_encoder_states, ) = model.encoder.streaming_forward( x=x, x_lens=x_lens, states=encoder_states, src_key_padding_mask=src_key_padding_mask, ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) new_states = new_encoder_states + [ new_cached_embed_left_pad, new_processed_lens, ] return encoder_out, encoder_out_lens, new_states def decode_one_chunk( params: AttributeDict, model: nn.Module, decode_streams: List[DecodeStream], ) -> List[int]: """Decode one chunk frames of features for each decode_streams and return the indexes of finished streams in a List. Args: params: It's the return value of :func:`get_params`. model: The neural model. decode_streams: A List of DecodeStream, each belonging to a utterance. Returns: Return a List containing which DecodeStreams are finished. """ device = model.device chunk_size = int(params.chunk_size) left_context_len = int(params.left_context_frames) features = [] feature_lens = [] states = [] processed_lens = [] # Used in fast-beam-search for stream in decode_streams: feat, feat_len = stream.get_feature_frames(chunk_size * 2) features.append(feat) feature_lens.append(feat_len) states.append(stream.states) processed_lens.append(stream.done_frames) feature_lens = torch.tensor(feature_lens, device=device) features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) # Make sure the length after encoder_embed is at least 1. # The encoder_embed subsample features (T - 7) // 2 # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling tail_length = chunk_size * 2 + 7 + 2 * 3 if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length features = torch.nn.functional.pad( features, (0, 0, 0, pad_length), mode="constant", value=LOG_EPS, ) states = stack_states(states) encoder_out, encoder_out_lens, new_states = streaming_forward( features=features, feature_lens=feature_lens, model=model, states=states, chunk_size=chunk_size, left_context_len=left_context_len, ) encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": greedy_search( model=model, encoder_out=encoder_out, streams=decode_streams, blank_penalty=params.blank_penalty, ) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( model=model, encoder_out=encoder_out, processed_lens=processed_lens, streams=decode_streams, beam=params.beam, max_states=params.max_states, max_contexts=params.max_contexts, blank_penalty=params.blank_penalty, ) elif params.decoding_method == "modified_beam_search": modified_beam_search( model=model, streams=decode_streams, encoder_out=encoder_out, num_active_paths=params.num_active_paths, blank_penalty=params.blank_penalty, ) else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") states = unstack_states(new_states) finished_streams = [] for i in range(len(decode_streams)): decode_streams[i].states = states[i] decode_streams[i].done_frames += encoder_out_lens[i] if decode_streams[i].done: finished_streams.append(i) return finished_streams def decode_dataset( cuts: CutSet, params: AttributeDict, model: nn.Module, lexicon: Lexicon, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: cuts: Lhotse Cutset containing the dataset to decode. params: It is returned by :func:`get_params`. model: The neural model. lexicon: The Lexicon. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used only when --decoding_method is fast_beam_search. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. Its value is a list of tuples. Each tuple contains two elements: The first is the reference transcript, and the second is the predicted result. """ device = model.device opts = FbankOptions() opts.device = device opts.frame_opts.dither = 0 opts.frame_opts.snip_edges = False opts.frame_opts.samp_freq = 16000 opts.mel_opts.num_bins = 80 log_interval = 100 decode_results = [] # Contain decode streams currently running. decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. initial_states = get_init_states(model=model, batch_size=1, device=device) decode_stream = DecodeStream( params=params, cut_id=cut.id, initial_states=initial_states, decoding_graph=decoding_graph, device=device, ) audio: np.ndarray = cut.load_audio() # audio.shape: (1, num_samples) assert len(audio.shape) == 2 assert audio.shape[0] == 1, "Should be single channel" assert audio.dtype == np.float32, audio.dtype # The trained model is using normalized samples if audio.max() > 1: logging.warning( f"The audio should be normalized to [-1, 1], audio.max : {audio.max()}." f"Clipping to [-1, 1]." ) audio = np.clip(audio, -1, 1) samples = torch.from_numpy(audio).squeeze(0) fbank = Fbank(opts) feature = fbank(samples.to(device)) decode_stream.set_features(feature, tail_pad_len=30) decode_stream.ground_truth = cut.supervisions[0].text decode_streams.append(decode_stream) while len(decode_streams) >= params.num_decode_streams: finished_streams = decode_one_chunk( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): decode_results.append( ( decode_streams[i].id, list(decode_streams[i].ground_truth.strip()), [ lexicon.token_table[idx] for idx in decode_streams[i].decoding_result() ], ) ) del decode_streams[i] if num % log_interval == 0: logging.info(f"Cuts processed until now is {num}.") # decode final chunks of last sequences while len(decode_streams): finished_streams = decode_one_chunk( params=params, model=model, decode_streams=decode_streams ) for i in sorted(finished_streams, reverse=True): decode_results.append( ( decode_streams[i].id, decode_streams[i].ground_truth.split(), [ lexicon.token_table[idx] for idx in decode_streams[i].decoding_result() ], ) ) del decode_streams[i] key = f"blank_penalty_{params.blank_penalty}" if params.decoding_method == "greedy_search": key = f"greedy_search_{key}" elif params.decoding_method == "fast_beam_search": key = ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" f"max_states_{params.max_states}_{key}" ) elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}_{key}" else: raise ValueError(f"Unsupported decoding method: {params.decoding_method}") return {key: decode_results} def save_results( params: AttributeDict, test_set_name: str, results_dict: Dict[str, List[Tuple[List[str], List[str]]]], ): test_set_wers = dict() for key, results in results_dict.items(): recog_path = ( params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( f, f"{test_set_name}-{key}", results, enable_log=True ) test_set_wers[key] = wer logging.info("Wrote detailed error stats to {}".format(errs_filename)) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) s = "\nFor {}, WER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key, val in test_set_wers: s += "{}\t{}{}\n".format(key, val, note) note = "" logging.info(s) @torch.no_grad() def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) params = get_params() params.update(vars(args)) params.res_dir = params.exp_dir / "streaming" / params.decoding_method if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal assert "," not in params.chunk_size, "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"-left-context-{params.left_context_frames}" params.suffix += f"-blank-penalty-{params.blank_penalty}" # for fast_beam_search if params.decoding_method == "fast_beam_search": params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" if params.use_averaged_model: params.suffix += "-use-averaged-model" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) logging.info(f"Device: {device}") lexicon = Lexicon(params.lang_dir) params.blank_id = lexicon.token_table[""] params.vocab_size = max(lexicon.tokens) + 1 logging.info(params) logging.info("About to create model") model = get_model(params) if not params.use_averaged_model: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ : params.avg ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( f"Not enough checkpoints ({len(filenames)}) found for" f" --iter {params.iter}, --avg {params.avg}" ) logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) elif params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: start = params.epoch - params.avg + 1 filenames = [] for i in range(start, params.epoch + 1): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") model.to(device) model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ : params.avg + 1 ] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( f"Not enough checkpoints ({len(filenames)}) found for" f" --iter {params.iter}, --avg {params.avg}" ) filename_start = filenames[-1] filename_end = filenames[0] logging.info( "Calculating the averaged model over iteration checkpoints" f" from {filename_start} (excluded) to {filename_end}" ) model.to(device) model.load_state_dict( average_checkpoints_with_averaged_model( filename_start=filename_start, filename_end=filename_end, device=device, ) ) else: assert params.avg > 0, params.avg start = params.epoch - params.avg assert start >= 1, start filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) model.load_state_dict( average_checkpoints_with_averaged_model( filename_start=filename_start, filename_end=filename_end, device=device, ) ) model.to(device) model.eval() model.device = device decoding_graph = None if params.decoding_method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") aishell = AishellAsrDataModule(args) dev_cuts = aishell.valid_cuts() test_cuts = aishell.test_cuts() test_sets = ["dev", "test"] test_cuts = [dev_cuts, test_cuts] for test_set, test_cut in zip(test_sets, test_cuts): results_dict = decode_dataset( cuts=test_cut, params=params, model=model, lexicon=lexicon, decoding_graph=decoding_graph, ) save_results( params=params, test_set_name=test_set, results_dict=results_dict, ) logging.info("Done!") if __name__ == "__main__": main()