diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py deleted file mode 100755 index 3345d20d3..000000000 --- a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py +++ /dev/null @@ -1,436 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a CTC model from PyTorch to ONNX. - -Note that the model is trained using both transducer and CTC loss. This script -exports only the CTC head. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx-ctc.py \ - --use-transducer 0 \ - --use-ctc 1 \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal False \ - --chunk-size 16 \ - --left-context-frames 128 - -It will generate the following 2 files inside $repo/exp: - - - model.onnx - - model.int8.onnx - -See ./onnx_pretrained_ctc.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. 
If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxModel(nn.Module): - """A wrapper for encoder_embed, Zipformer, and ctc_output layer""" - - def __init__( - self, - encoder: Zipformer2, - encoder_embed: nn.Module, - ctc_output: nn.Module, - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_embed: - The first downsampling layer for zipformer. - """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.ctc_output = ctc_output - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Zipformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - log_probs, a 3-D tensor of shape (N, T', vocab_size) - - log_probs_len, a 1-D int64 tensor of shape (N,) - """ - x, x_lens = self.encoder_embed(x, x_lens) - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) - encoder_out, log_probs_len = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) - log_probs = self.ctc_output(encoder_out) - - return log_probs, log_probs_len - - -def export_ctc_model_onnx( - model: OnnxModel, - filename: str, - opset_version: int = 11, -) -> None: - """Export the given model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - log_probs, a tensor of shape (N, T', vocab_size) - - log_probs_len, a tensor of shape (N,) - - Args: - model: - The input model - filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use.
- """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - model = torch.jit.trace(model, (x, x_lens)) - - torch.onnx.export( - model, - (x, x_lens), - filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["log_probs", "log_probs_len"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "log_probs": {0: "N", 1: "T"}, - "log_probs_len": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "zipformer2_ctc", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming zipformer2 CTC", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint( - f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False - ) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - 
average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ), - strict=False, - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) - - model = OnnxModel( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - ctc_output=model.ctc_output, - ) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"num parameters: {num_param}") - - opset_version = 13 - - logging.info("Exporting ctc model") - filename = params.exp_dir / f"model.onnx" - export_ctc_model_onnx( - model, - filename, - opset_version=opset_version, - ) - logging.info(f"Exported to {filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - filename_int8 = params.exp_dir / f"model.int8.onnx" - quantize_dynamic( - model_input=filename, - model_output=filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py new file mode 120000 index 000000000..f9d756352 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py deleted file mode 100755 index e2c7d7d95..000000000 --- a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py +++ /dev/null @@ -1,775 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx-streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal True \ - --chunk-size 16 \ - --left-context-frames 64 - -The --chunk-size in training is "16,32,64,-1", so we select one of them -(excluding -1) during streaming export. The same applies to `--left-context`, -whose value is "64,128,256,-1". 
- -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1-chunk-16-left-64.onnx - - decoder-epoch-99-avg-1-chunk-16-left-64.onnx - - joiner-epoch-99-avg-1-chunk-16-left-64.onnx - -See ./onnx_pretrained-streaming.py for how to use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_model, get_params -from zipformer import Zipformer2 - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__( - self, encoder: Zipformer2, encoder_embed: nn.Module, encoder_proj: nn.Linear - ): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. 
- """ - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - self.encoder_proj = encoder_proj - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - self.pad_length = 7 + 2 * 3 - - def forward( - self, - x: torch.Tensor, - states: List[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: - N = x.size(0) - T = self.chunk_size * 2 + self.pad_length - x_lens = torch.tensor([T] * N, device=x.device) - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=x, - x_lens=x_lens, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == self.chunk_size, (x.size(1), self.chunk_size) - - src_key_padding_mask = torch.zeros(N, self.chunk_size, dtype=torch.bool) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) - encoder_states = states[:-2] - logging.info(f"len_encoder_states={len(encoder_states)}") - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - - return encoder_out, new_states - - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int64, device=device) - states.append(processed_lens) - - return states - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). 
- Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - encoder_model.encoder.__class__.forward = ( - encoder_model.encoder.__class__.streaming_forward - ) - - decode_chunk_len = encoder_model.chunk_size * 2 - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - T = decode_chunk_len + encoder_model.pad_length - - x = torch.rand(1, T, 80, dtype=torch.float32) - init_state = encoder_model.get_init_states() - num_encoders = len(encoder_model.encoder.encoder_dim) - logging.info(f"num_encoders: {num_encoders}") - logging.info(f"len(init_state): {len(init_state)}") - - inputs = {} - input_names = ["x"] - - outputs = {} - output_names = ["encoder_out"] - - def build_inputs_outputs(tensors, i): - assert len(tensors) == 6, len(tensors) - - # (downsample_left, batch_size, key_dim) - name = f"cached_key_{i}" - logging.info(f"{name}.shape: {tensors[0].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (1, batch_size, downsample_left, nonlin_attn_head_dim) - name = f"cached_nonlin_attn_{i}" - logging.info(f"{name}.shape: {tensors[1].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val1_{i}" - logging.info(f"{name}.shape: {tensors[2].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val2_{i}" - logging.info(f"{name}.shape: {tensors[3].shape}") - inputs[name] = {1: "N"} - outputs[f"new_{name}"] = {1: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv1_{i}" - logging.info(f"{name}.shape: {tensors[4].shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv2_{i}" - logging.info(f"{name}.shape: {tensors[5].shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) - encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dim)) - cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernel)) - ds = encoder_model.encoder.downsampling_factor - left_context_len 
= encoder_model.left_context_len - left_context_len = [left_context_len // k for k in ds] - left_context_len = ",".join(map(str, left_context_len)) - query_head_dims = ",".join(map(str, encoder_model.encoder.query_head_dim)) - value_head_dims = ",".join(map(str, encoder_model.encoder.value_head_dim)) - num_heads = ",".join(map(str, encoder_model.encoder.num_heads)) - - meta_data = { - "model_type": "zipformer2", - "version": "1", - "model_author": "k2-fsa", - "comment": "streaming zipformer2", - "decode_chunk_len": str(decode_chunk_len), # 32 - "T": str(T), # 32+7+2*3=45 - "num_encoder_layers": num_encoder_layers, - "encoder_dims": encoder_dims, - "cnn_module_kernels": cnn_module_kernels, - "left_context_len": left_context_len, - "query_head_dims": query_head_dims, - "value_head_dims": value_head_dims, - "num_heads": num_heads, - } - logging.info(f"meta_data: {meta_data}") - - for i in range(len(init_state[:-2]) // 6): - build_inputs_outputs(init_state[i * 6 : (i + 1) * 6], i) - - # (batch_size, channels, left_pad, freq) - embed_states = init_state[-2] - name = "embed_states" - logging.info(f"{name}.shape: {embed_states.shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - # (batch_size,) - processed_lens = init_state[-1] - name = "processed_lens" - logging.info(f"{name}.shape: {processed_lens.shape}") - inputs[name] = {0: "N"} - outputs[f"new_{name}"] = {0: "N"} - input_names.append(name) - output_names.append(f"new_{name}") - - logging.info(inputs) - logging.info(outputs) - logging.info(input_names) - logging.info(output_names) - - torch.onnx.export( - encoder_model, - (x, init_state), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "x": {0: "N"}, - "encoder_out": {0: "N"}, - **inputs, - **outputs, - }, - ) - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. 
- The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range 
from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_embed=model.encoder_embed, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - suffix += f"-chunk-{params.chunk_size}" - suffix += f"-left-{params.left_context_frames}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py new file mode 120000 index 000000000..2962eb784 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No 
newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py deleted file mode 100755 index a41fbc1c9..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py +++ /dev/null @@ -1,280 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, exported by `torch.jit.script()` -and uses them to decode waves. -You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained.py \ - --nn-model-filename ./zipformer/exp/cpu_jit.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model cpu_jit.pt", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def greedy_search( - model: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. 
- """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = model.decoder.blank_id - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.decoder.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - - model.eval() - - model.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder( - features=features, - feature_lengths=feature_lengths, - ) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - - s = "\n" - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = 
token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py new file mode 120000 index 000000000..25108391f --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py deleted file mode 100755 index 660a4bfc6..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py +++ /dev/null @@ -1,436 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -(1) ctc-decoding -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method ctc-decoding \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) 1best -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --method 1best \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) nbest-rescoring -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method nbest-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) whole-lattice-rescoring -./zipformer/jit_pretrained_ctc.py \ - --model-filename ./zipformer/exp/jit_script.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method whole-lattice-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from ctc_decode import get_decoding_params -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import 
get_params - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.utils import get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--model-filename", - type=str, - required=True, - help="Path to the torchscript model.", - ) - - parser.add_argument( - "--words-file", - type=str, - help="""Path to words.txt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.pt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a token table, - i.e., lang_dir/token.txt, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an LM, the path with - the highest score is the decoding result. - We call it HLG decoding + nbest n-gram LM rescoring. - (3) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + whole-lattice n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or nbest-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and nbest-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=1.0, - help=""" - Used only when method is nbest-rescoring. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float = 16000 -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.model_filename) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(features, feature_lengths) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - batch_size = ctc_output.shape[0] - supervision_segments = torch.tensor( - [ - [i, 0, feature_lengths[i].item() // params.subsampling_factor] - for i in range(batch_size) - ], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - max_token_id = params.vocab_size - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_table[i] for i in ids] for ids in token_ids] - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - G = G.to(device) - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. 
- G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - if params.method == "nbest-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=[params.ngram_lm_scale], - nbest_scale=params.nbest_scale, - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - if params.method == "ctc-decoding": - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - words = words.replace("▁", " ").strip() - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py new file mode 120000 index 000000000..9a8da5844 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py deleted file mode 100755 index d4ceacefd..000000000 --- a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -# flake8: noqa -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models exported by `torch.jit.script()` -and uses them to decode waves. 
-You can use the following command to get the exported models: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -Usage of this script: - -./zipformer/jit_pretrained_streaming.py \ - --nn-model-filename ./zipformer/exp-causal/jit_script_chunk_16_left_128.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - /path/to/foo.wav \ -""" - -import argparse -import logging -import math -from typing import List, Optional - -import k2 -import kaldifeat -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model-filename", - type=str, - required=True, - help="Path to the torchscript model jit_script.pt", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list of 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, - device: torch.device = torch.device("cpu"), -): - assert encoder_out.ndim == 2 - context_size = decoder.context_size - blank_id = decoder.blank_id - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) - # decoder_input.shape: (1, context_size) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - else: - assert decoder_out.ndim == 2 - assert hyp is not None, hyp - - T = encoder_out.size(0) - for i in range(T): - cur_encoder_out = encoder_out[i : i + 1] - joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - - decoder_input = torch.tensor( - decoder_input, dtype=torch.int32, device=device - ).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) - - return hyp, decoder_out - - -def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options.
In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = sample_rate - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - model = torch.jit.load(args.nn_model_filename) - model.eval() - model.to(device) - - encoder = model.encoder - decoder = model.decoder - joiner = model.joiner - - token_table = k2.SymbolTable.from_file(args.tokens) - context_size = decoder.context_size - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor(args.sample_rate) - - logging.info(f"Reading sound files: {args.sound_file}") - wave_samples = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=args.sample_rate, - )[0] - logging.info(wave_samples.shape) - - logging.info("Decoding started") - - chunk_length = encoder.chunk_size * 2 - T = chunk_length + encoder.pad_length - - logging.info(f"chunk_length: {chunk_length}") - logging.info(f"T: {T}") - - states = encoder.get_init_states(device=device) - - tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) - - wave_samples = torch.cat([wave_samples, tail_padding]) - - chunk = int(0.25 * args.sample_rate) # 0.25 second - num_processed_frames = 0 - - hyp = None - decoder_out = None - - start = 0 - while start < wave_samples.numel(): - logging.info(f"{start}/{wave_samples.numel()}") - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - online_fbank.accept_waveform( - sampling_rate=args.sample_rate, - waveform=samples, - ) - while online_fbank.num_frames_ready - num_processed_frames >= T: - frames = [] - for i in range(T): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - frames = torch.cat(frames, dim=0).to(device).unsqueeze(0) - x_lens = torch.tensor([T], dtype=torch.int32, device=device) - encoder_out, out_lens, states = encoder( - features=frames, - feature_lengths=x_lens, - states=states, - ) - num_processed_frames += chunk_length - - hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device - ) - - text = "" - for i in hyp[context_size:]: - text += token_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -torch.set_num_threads(4) -torch.set_num_interop_threads(1) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py new file mode 120000 index 000000000..1962351e9 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/jit_pretrained_streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/jit_pretrained_streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_check.py 
b/egs/gigaspeech/ASR/zipformer/onnx_check.py deleted file mode 100755 index 93bd3a211..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_check.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -as the given torchscript model for the same input. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model via torchscript (torch.jit.script()) - -./zipformer/export.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp/ \ - --jit 1 - -It will generate the following file in $repo/exp: - - jit_script.pt - -3. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -4. 
Run this file - -./zipformer/onnx_check.py \ - --jit-filename $repo/exp/jit_script.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - return parser - - -def test_encoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - C = 80 - for i in range(3): - N = torch.randint(low=1, high=20, size=(1,)).item() - T = torch.randint(low=30, high=50, size=(1,)).item() - logging.info(f"test_encoder: iter {i}, N={N}, T={T}") - - x = torch.rand(N, T, C) - x_lens = torch.randint(low=30, high=T + 1, size=(N,)) - x_lens[0] = T - - torch_encoder_out, torch_encoder_out_lens = torch_model.encoder(x, x_lens) - torch_encoder_out = torch_model.joiner.encoder_proj(torch_encoder_out) - - onnx_encoder_out, onnx_encoder_out_lens = onnx_model.run_encoder(x, x_lens) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-05), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_model.decoder(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_model.joiner.decoder_proj(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_model.joiner.encoder_proj.weight.shape[1] - decoder_dim = torch_model.joiner.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_model.joiner.encoder_proj(encoder_out) - projected_decoder_out = torch_model.joiner.decoder_proj(decoder_out) - - torch_joiner_out = torch_model.joiner(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_model = 
torch.jit.load(args.jit_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - test_encoder(torch_model, onnx_model) - - logging.info("Test decoder") - test_decoder(torch_model, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_model, onnx_model) - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid <pid> --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/onnx_check.py b/egs/gigaspeech/ASR/zipformer/onnx_check.py new file mode 120000 index 000000000..f3dd42004 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_check.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_decode.py b/egs/gigaspeech/ASR/zipformer/onnx_decode.py deleted file mode 100755 index 356c2a830..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_decode.py +++ /dev/null @@ -1,325 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX exported models and uses them to decode the test sets. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_bpe_500/bpe.model" -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. 
Run this file - -./zipformer/onnx_decode.py \ - --exp-dir $repo/exp \ - --max-duration 600 \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt -""" - - -import argparse -import logging -import time -from pathlib import Path -from typing import List, Tuple - -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule - -from onnx_pretrained import greedy_search, OnnxModel - -from icefall.utils import setup_logger, store_transcripts, write_error_stats -from k2 import SymbolTable - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="Valid values are greedy_search and modified_beam_search", - ) - - return parser - - -def decode_one_batch( - model: OnnxModel, token_table: SymbolTable, batch: dict -) -> List[List[str]]: - """Decode one batch and return the result. - Currently only greedy_search is supported. - - Args: - model: - The neural model. - token_table: - The token table. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - - Returns: - Return the decoded results for each utterance. - """ - feature = batch["inputs"] - assert feature.ndim == 3 - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(dtype=torch.int64) - - encoder_out, encoder_out_lens = model.run_encoder(x=feature, x_lens=feature_lens) - - hyps = greedy_search( - model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens - ) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - hyps = [token_ids_to_words(h).split() for h in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - model: nn.Module, - token_table: SymbolTable, -) -> Tuple[List[Tuple[str, List[str], List[str]]], float]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - model: - The neural model. - token_table: - The token table. - - Returns: - - A list of tuples. Each tuple contains three elements: - - cut_id, - - reference transcript, - - predicted result. - - The total duration (in seconds) of the dataset. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
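As a quick illustration of the token_ids_to_words() helper above: BPE token strings use "▁" to mark word boundaries, so joining the pieces and substituting spaces recovers words (the token strings below are made up):

pieces = ["▁HE", "LL", "▁WORLD"]        # hypothetical token_table entries
text = "".join(pieces)                  # "▁HELL▁WORLD"
words = text.replace("▁", " ").strip()  # "HELL WORLD"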
- - log_interval = 10 - total_duration = 0 - - results = [] - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - total_duration += sum([cut.duration for cut in batch["supervisions"]["cut"]]) - - hyps = decode_one_batch(model=model, token_table=token_table, batch=batch) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - - return results, total_duration - - -def save_results( - res_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -): - recog_path = res_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = res_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - errs_info = res_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("WER", file=f) - print(wer, file=f) - - s = "\nFor {}, WER is {}:\n".format(test_set_name, wer) - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - assert ( - args.decoding_method == "greedy_search" - ), "Only supports greedy_search currently." - res_dir = Path(args.exp_dir) / f"onnx-{args.decoding_method}" - - setup_logger(f"{res_dir}/log-decode") - logging.info("Decoding started") - - device = torch.device("cpu") - logging.info(f"Device: {device}") - - token_table = SymbolTable.from_file(args.tokens) - - logging.info(vars(args)) - - logging.info("About to create model") - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - # we need cut ids to display recognition results. 
- args.return_cuts = True - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - start_time = time.time() - results, total_duration = decode_dataset( - dl=test_dl, model=model, token_table=token_table - ) - end_time = time.time() - elapsed_seconds = end_time - start_time - rtf = elapsed_seconds / total_duration - - logging.info(f"Elapsed time: {elapsed_seconds:.3f} s") - logging.info(f"Wave duration: {total_duration:.3f} s") - logging.info( - f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}" - ) - - save_results(res_dir=res_dir, test_set_name=test_set, results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/gigaspeech/ASR/zipformer/onnx_decode.py b/egs/gigaspeech/ASR/zipformer/onnx_decode.py new file mode 120000 index 000000000..0573b88c5 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_decode.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py deleted file mode 100755 index e62491444..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py +++ /dev/null @@ -1,546 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) -# Copyright 2023 Danqing Fu (danqing.fu@gmail.com) - -""" -This script loads ONNX models exported by ./export-onnx-streaming.py -and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx-streaming.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --num-encoder-layers "2,2,3,4,3,2" \ - --downsampling-factor "1,2,4,8,4,2" \ - --feedforward-dim "512,768,1024,1536,1024,768" \ - --num-heads "4,4,4,8,4,4" \ - --encoder-dim "192,256,384,512,384,256" \ - --query-head-dim 32 \ - --value-head-dim 12 \ - --pos-head-dim 4 \ - --pos-dim 48 \ - --encoder-unmasked-dim "192,192,256,256,256,192" \ - --cnn-module-kernel "31,31,15,15,15,31" \ - --decoder-dim 512 \ - --joiner-dim 512 \ - --causal True \ - --chunk-size 16 \ - --left-context-frames 64 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. 
Run this file with the exported ONNX models - -./zipformer/onnx_pretrained-streaming.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav - -Note: Even though this script only supports decoding a single file, -the exported ONNX models do support batch processing. -""" - -import argparse -import logging -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - self.init_encoder_states() - - def init_encoder_states(self, batch_size: int = 1): - encoder_meta = self.encoder.get_modelmeta().custom_metadata_map - logging.info(f"encoder_meta={encoder_meta}") - - model_type = encoder_meta["model_type"] - assert model_type == "zipformer2", model_type - - decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - T = int(encoder_meta["T"]) - - num_encoder_layers = encoder_meta["num_encoder_layers"] - encoder_dims = encoder_meta["encoder_dims"] - cnn_module_kernels = encoder_meta["cnn_module_kernels"] - left_context_len = encoder_meta["left_context_len"] - query_head_dims = encoder_meta["query_head_dims"] - value_head_dims = encoder_meta["value_head_dims"] - num_heads = encoder_meta["num_heads"] - - def to_int_list(s): - return list(map(int, s.split(","))) - - num_encoder_layers = to_int_list(num_encoder_layers) - encoder_dims = to_int_list(encoder_dims) - cnn_module_kernels = to_int_list(cnn_module_kernels) - left_context_len = to_int_list(left_context_len) - query_head_dims = to_int_list(query_head_dims) - value_head_dims = to_int_list(value_head_dims) - num_heads = to_int_list(num_heads) - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - logging.info(f"num_encoder_layers: {num_encoder_layers}") - 
logging.info(f"encoder_dims: {encoder_dims}") - logging.info(f"cnn_module_kernels: {cnn_module_kernels}") - logging.info(f"left_context_len: {left_context_len}") - logging.info(f"query_head_dims: {query_head_dims}") - logging.info(f"value_head_dims: {value_head_dims}") - logging.info(f"num_heads: {num_heads}") - - num_encoders = len(num_encoder_layers) - - self.states = [] - for i in range(num_encoders): - num_layers = num_encoder_layers[i] - key_dim = query_head_dims[i] * num_heads[i] - embed_dim = encoder_dims[i] - nonlin_attn_head_dim = 3 * embed_dim // 4 - value_dim = value_head_dims[i] * num_heads[i] - conv_left_pad = cnn_module_kernels[i] // 2 - - for layer in range(num_layers): - cached_key = torch.zeros( - left_context_len[i], batch_size, key_dim - ).numpy() - cached_nonlin_attn = torch.zeros( - 1, batch_size, left_context_len[i], nonlin_attn_head_dim - ).numpy() - cached_val1 = torch.zeros( - left_context_len[i], batch_size, value_dim - ).numpy() - cached_val2 = torch.zeros( - left_context_len[i], batch_size, value_dim - ).numpy() - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).numpy() - self.states += [ - cached_key, - cached_nonlin_attn, - cached_val1, - cached_val2, - cached_conv1, - cached_conv2, - ] - embed_states = torch.zeros(batch_size, 128, 3, 19).numpy() - self.states.append(embed_states) - processed_lens = torch.zeros(batch_size, dtype=torch.int64).numpy() - self.states.append(processed_lens) - - self.num_encoders = num_encoders - - self.segment = T - self.offset = decode_chunk_len - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def _build_encoder_input_output( - self, - x: torch.Tensor, - ) -> Tuple[Dict[str, np.ndarray], List[str]]: - encoder_input = {"x": x.numpy()} - encoder_output = ["encoder_out"] - - def build_inputs_outputs(tensors, i): - assert len(tensors) == 6, len(tensors) - - # (downsample_left, batch_size, key_dim) - name = f"cached_key_{i}" - encoder_input[name] = tensors[0] - encoder_output.append(f"new_{name}") - - # (1, batch_size, downsample_left, nonlin_attn_head_dim) - name = f"cached_nonlin_attn_{i}" - encoder_input[name] = tensors[1] - encoder_output.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val1_{i}" - encoder_input[name] = tensors[2] - encoder_output.append(f"new_{name}") - - # (downsample_left, batch_size, value_dim) - name = f"cached_val2_{i}" - encoder_input[name] = tensors[3] - encoder_output.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = f"cached_conv1_{i}" - encoder_input[name] = tensors[4] - encoder_output.append(f"new_{name}") - - # (batch_size, embed_dim, conv_left_pad) - name = 
f"cached_conv2_{i}" - encoder_input[name] = tensors[5] - encoder_output.append(f"new_{name}") - - for i in range(len(self.states[:-2]) // 6): - build_inputs_outputs(self.states[i * 6 : (i + 1) * 6], i) - - # (batch_size, channels, left_pad, freq) - name = "embed_states" - embed_states = self.states[-2] - encoder_input[name] = embed_states - encoder_output.append(f"new_{name}") - - # (batch_size,) - name = "processed_lens" - processed_lens = self.states[-1] - encoder_input[name] = processed_lens - encoder_output.append(f"new_{name}") - - return encoder_input, encoder_output - - def _update_states(self, states: List[np.ndarray]): - self.states = states - - def run_encoder(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 - """ - encoder_input, encoder_output_names = self._build_encoder_input_output(x) - - out = self.encoder.run(encoder_output_names, encoder_input) - - self._update_states(out[1:]) - - return torch.from_numpy(out[0]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - return OnlineFbank(opts) - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. 
- encoder_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. - Returns: - Return the decoded results so far. - """ - - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - encoder_out = encoder_out.squeeze(0) - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {args.sound_file}") - waves = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=sample_rate, - )[0] - - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) - - num_processed_frames = 0 - segment = model.segment - offset = model.offset - - context_size = model.context_size - hyp = None - decoder_out = None - - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - encoder_out = model.run_encoder(frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - - token_table = k2.SymbolTable.from_file(args.tokens) - - text = "" - for i in hyp[context_size:]: - text += token_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py new file mode 120000 index 000000000..cfea104c2 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py deleted file mode 100755 index 334376093..000000000 --- 
a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py +++ /dev/null @@ -1,421 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/librispeech/ASR - -repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "exp/pretrained.pt" - -cd exp -ln -s pretrained.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./zipformer/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp \ - --causal False - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./zipformer/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav \ - $repo/test_wavs/1221-135766-0001.wav \ - $repo/test_wavs/1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. 
- Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = 
model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py new file mode 120000 index 000000000..8f32f4ee7 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py deleted file mode 100755 index eb5cee9cd..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -""" -This script loads ONNX models and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -as an example to show how to use this file. - -1. Please follow ./export-onnx-ctc.py to get the onnx model. - -2. Run this file - -./zipformer/onnx_pretrained_ctc.py \ - --nn-model /path/to/model.onnx \ - --tokens /path/to/data/lang_bpe_500/tokens.txt \ - 1089-134686-0001.wav \ - 1221-135766-0001.wav \ - 1221-135766-0002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="Path to the onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - nn_model: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_model(nn_model) - - def init_model(self, nn_model: str): - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - meta = self.model.get_modelmeta().custom_metadata_map - print(meta) - - def __call__( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D float tensor of shape (N, T, C) - x_lens: - A 1-D int64 tensor of shape (N,) - Returns: - Return a tuple containing: - - A float tensor containing log_probs of shape (N, T, C) - - A int64 tensor containing log_probs_len of shape (N) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - self.model.get_outputs()[1].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - nn_model=args.nn_model, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - log_probs, log_probs_len = model(features, feature_lengths) - - token_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - blank_id = 0 - s = "\n" - for i in range(log_probs.size(0)): - # greedy search - indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1) - token_ids = torch.unique_consecutive(indexes) - - token_ids = token_ids[token_ids != blank_id] - words = token_ids_to_words(token_ids.tolist()) - s += f"{args.sound_files[i]}:\n{words}\n\n" - - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py new file mode 120000 index 000000000..a3183ebf6 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/onnx_pretrained_ctc.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py deleted file mode 100755 index 683a7dc20..000000000 --- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py +++ /dev/null @@ -1,277 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -""" -This script loads ONNX models and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -as an example to show how to use this file. - -1. Please follow ./export-onnx-ctc.py to get the onnx model. - -2. 
Run this file - -./zipformer/onnx_pretrained_ctc_H.py \ - --nn-model /path/to/model.onnx \ - --tokens /path/to/data/lang_bpe_500/tokens.txt \ - --H /path/to/H.fst \ - 1089-134686-0001.wav \ - 1221-135766-0001.wav \ - 1221-135766-0002.wav - -You can find exported ONNX models at -https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -from typing import Dict -import kaldifst -import onnxruntime as ort -import torch -import torchaudio -from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="Path to the onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--H", - type=str, - help="""Path to H.fst.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - nn_model: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_model(nn_model) - - def init_model(self, nn_model: str): - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - meta = self.model.get_modelmeta().custom_metadata_map - print(meta) - - def __call__( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D float tensor of shape (N, T, C) - x_lens: - A 1-D int64 tensor of shape (N,) - Returns: - Return a tuple containing: - - A float tensor containing log_probs of shape (N, T, C) - - A int64 tensor containing log_probs_len of shape (N) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - self.model.get_outputs()[1].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def decode( - filename: str, - log_probs: torch.Tensor, - H: kaldifst, - id2token: Dict[int, str], -) -> List[str]: - """ - Args: - filename: - Path to the filename for decoding. Used for debugging. 
-      log_probs:
-        A 2-D float32 tensor of shape (num_frames, vocab_size). It
-        contains output from log_softmax.
-      H:
-        The H graph.
-      id2token:
-        A map mapping token ID to token string.
-    Returns:
-      Return a list of decoded words.
-    """
-    logging.info(f"{filename}, {log_probs.shape}")
-    decodable = DecodableCtc(log_probs.cpu())
-
-    decoder_opts = FasterDecoderOptions(max_active=3000)
-    decoder = FasterDecoder(H, decoder_opts)
-    decoder.decode(decodable)
-
-    if not decoder.reached_final():
-        logging.info(f"failed to decode {filename}")
-        return [""]
-
-    ok, best_path = decoder.get_best_path()
-
-    (
-        ok,
-        isymbols_out,
-        osymbols_out,
-        total_weight,
-    ) = kaldifst.get_linear_symbol_sequence(best_path)
-    if not ok:
-        logging.info(f"failed to get linear symbol sequence for {filename}")
-        return [""]
-
-    # Token IDs are incremented by 1 during graph construction,
-    # so we shift them back here; output symbol 1 is the blank.
-    hyps = [id2token[i - 1] for i in osymbols_out if i != 1]
-    hyps = "".join(hyps).split("\u2581")  # unicode codepoint of ▁
-
-    return hyps
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-    logging.info(vars(args))
-    model = OnnxModel(
-        nn_model=args.nn_model,
-    )
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = "cpu"
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = args.sample_rate
-    opts.mel_opts.num_bins = 80
-
-    logging.info(f"Loading H from {args.H}")
-    H = kaldifst.StdVectorFst.read(args.H)
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {args.sound_files}")
-    waves = read_sound_files(
-        filenames=args.sound_files,
-        expected_sample_rate=args.sample_rate,
-    )
-
-    logging.info("Decoding started")
-    features = fbank(waves)
-    feature_lengths = [f.size(0) for f in features]
-    features = pad_sequence(
-        features,
-        batch_first=True,
-        padding_value=math.log(1e-10),
-    )
-
-    feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
-    log_probs, log_probs_len = model(features, feature_lengths)
-
-    token_table = k2.SymbolTable.from_file(args.tokens)
-
-    hyps = []
-    for i in range(log_probs.shape[0]):
-        hyp = decode(
-            filename=args.sound_files[i],
-            log_probs=log_probs[i, : log_probs_len[i]],
-            H=H,
-            id2token=token_table,
-        )
-        hyps.append(hyp)
-
-    s = "\n"
-    for filename, hyp in zip(args.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py
new file mode 120000
index 000000000..a4fd76ac2
--- /dev/null
+++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_H.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_H.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
deleted file mode 100755
index 0b94bfa65..000000000
--- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
+++ /dev/null
@@ -1,275 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-"""
-This script loads ONNX models and uses them to decode waves.
- -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -as an example to show how to use this file. - -1. Please follow ./export-onnx-ctc.py to get the onnx model. - -2. Run this file - -./zipformer/onnx_pretrained_ctc_HL.py \ - --nn-model /path/to/model.onnx \ - --words /path/to/data/lang_bpe_500/words.txt \ - --HL /path/to/HL.fst \ - 1089-134686-0001.wav \ - 1221-135766-0001.wav \ - 1221-135766-0002.wav - -You can find exported ONNX models at -https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -from typing import Dict -import kaldifst -import onnxruntime as ort -import torch -import torchaudio -from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="Path to the onnx model. ", - ) - - parser.add_argument( - "--words", - type=str, - help="""Path to words.txt.""", - ) - - parser.add_argument( - "--HL", - type=str, - help="""Path to HL.fst.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - nn_model: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_model(nn_model) - - def init_model(self, nn_model: str): - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - meta = self.model.get_modelmeta().custom_metadata_map - print(meta) - - def __call__( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D float tensor of shape (N, T, C) - x_lens: - A 1-D int64 tensor of shape (N,) - Returns: - Return a tuple containing: - - A float tensor containing log_probs of shape (N, T, C) - - A int64 tensor containing log_probs_len of shape (N) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - self.model.get_outputs()[1].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}"
-        # We use only the first channel
-        ans.append(wave[0].contiguous())
-    return ans
-
-
-def decode(
-    filename: str,
-    log_probs: torch.Tensor,
-    HL: kaldifst,
-    id2word: Dict[int, str],
-) -> List[str]:
-    """
-    Args:
-      filename:
-        Path to the filename for decoding. Used for debugging.
-      log_probs:
-        A 2-D float32 tensor of shape (num_frames, vocab_size). It
-        contains output from log_softmax.
-      HL:
-        The HL graph.
-      id2word:
-        A map mapping word ID to word string.
-    Returns:
-      Return a list of decoded words.
-    """
-    logging.info(f"{filename}, {log_probs.shape}")
-    decodable = DecodableCtc(log_probs.cpu())
-
-    decoder_opts = FasterDecoderOptions(max_active=3000)
-    decoder = FasterDecoder(HL, decoder_opts)
-    decoder.decode(decodable)
-
-    if not decoder.reached_final():
-        logging.info(f"failed to decode {filename}")
-        return [""]
-
-    ok, best_path = decoder.get_best_path()
-
-    (
-        ok,
-        isymbols_out,
-        osymbols_out,
-        total_weight,
-    ) = kaldifst.get_linear_symbol_sequence(best_path)
-    if not ok:
-        logging.info(f"failed to get linear symbol sequence for {filename}")
-        return [""]
-
-    # The output symbols of the HL graph are word IDs, so they are
-    # used directly; unlike in H decoding, no shift by 1 is needed.
-    hyps = [id2word[i] for i in osymbols_out]
-
-    return hyps
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-    logging.info(vars(args))
-    model = OnnxModel(
-        nn_model=args.nn_model,
-    )
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = "cpu"
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = args.sample_rate
-    opts.mel_opts.num_bins = 80
-
-    logging.info(f"Loading HL from {args.HL}")
-    HL = kaldifst.StdVectorFst.read(args.HL)
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {args.sound_files}")
-    waves = read_sound_files(
-        filenames=args.sound_files,
-        expected_sample_rate=args.sample_rate,
-    )
-
-    logging.info("Decoding started")
-    features = fbank(waves)
-    feature_lengths = [f.size(0) for f in features]
-    features = pad_sequence(
-        features,
-        batch_first=True,
-        padding_value=math.log(1e-10),
-    )
-
-    feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
-    log_probs, log_probs_len = model(features, feature_lengths)
-
-    word_table = k2.SymbolTable.from_file(args.words)
-
-    hyps = []
-    for i in range(log_probs.shape[0]):
-        hyp = decode(
-            filename=args.sound_files[i],
-            log_probs=log_probs[i, : log_probs_len[i]],
-            HL=HL,
-            id2word=word_table,
-        )
-        hyps.append(hyp)
-
-    s = "\n"
-    for filename, hyp in zip(args.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
new file mode 120000
index 000000000..f805e3761
--- /dev/null
+++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HL.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
deleted file mode 100755
index 93569142a..000000000
--- a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
+++ /dev/null
@@ -1,275 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi
Corp. (authors: Fangjun Kuang) -# -""" -This script loads ONNX models and uses them to decode waves. - -We use the pre-trained model from -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-transducer-ctc-2023-06-13 -as an example to show how to use this file. - -1. Please follow ./export-onnx-ctc.py to get the onnx model. - -2. Run this file - -./zipformer/onnx_pretrained_ctc_HLG.py \ - --nn-model /path/to/model.onnx \ - --words /path/to/data/lang_bpe_500/words.txt \ - --HLG /path/to/HLG.fst \ - 1089-134686-0001.wav \ - 1221-135766-0001.wav \ - 1221-135766-0002.wav - -You can find exported ONNX models at -https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -from typing import Dict -import kaldifst -import onnxruntime as ort -import torch -import torchaudio -from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="Path to the onnx model. ", - ) - - parser.add_argument( - "--words", - type=str, - help="""Path to words.txt.""", - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.fst.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - nn_model: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_model(nn_model) - - def init_model(self, nn_model: str): - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - meta = self.model.get_modelmeta().custom_metadata_map - print(meta) - - def __call__( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D float tensor of shape (N, T, C) - x_lens: - A 1-D int64 tensor of shape (N,) - Returns: - Return a tuple containing: - - A float tensor containing log_probs of shape (N, T, C) - - A int64 tensor containing log_probs_len of shape (N) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - self.model.get_outputs()[1].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - self.model.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}"
-        # We use only the first channel
-        ans.append(wave[0].contiguous())
-    return ans
-
-
-def decode(
-    filename: str,
-    log_probs: torch.Tensor,
-    HLG: kaldifst,
-    id2word: Dict[int, str],
-) -> List[str]:
-    """
-    Args:
-      filename:
-        Path to the filename for decoding. Used for debugging.
-      log_probs:
-        A 2-D float32 tensor of shape (num_frames, vocab_size). It
-        contains output from log_softmax.
-      HLG:
-        The HLG graph.
-      id2word:
-        A map mapping word ID to word string.
-    Returns:
-      Return a list of decoded words.
-    """
-    logging.info(f"{filename}, {log_probs.shape}")
-    decodable = DecodableCtc(log_probs.cpu())
-
-    decoder_opts = FasterDecoderOptions(max_active=3000)
-    decoder = FasterDecoder(HLG, decoder_opts)
-    decoder.decode(decodable)
-
-    if not decoder.reached_final():
-        logging.info(f"failed to decode {filename}")
-        return [""]
-
-    ok, best_path = decoder.get_best_path()
-
-    (
-        ok,
-        isymbols_out,
-        osymbols_out,
-        total_weight,
-    ) = kaldifst.get_linear_symbol_sequence(best_path)
-    if not ok:
-        logging.info(f"failed to get linear symbol sequence for {filename}")
-        return [""]
-
-    # The output symbols of the HLG graph are word IDs, so they are
-    # used directly; unlike in H decoding, no shift by 1 is needed.
-    hyps = [id2word[i] for i in osymbols_out]
-
-    return hyps
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-    logging.info(vars(args))
-    model = OnnxModel(
-        nn_model=args.nn_model,
-    )
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = "cpu"
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = args.sample_rate
-    opts.mel_opts.num_bins = 80
-
-    logging.info(f"Loading HLG from {args.HLG}")
-    HLG = kaldifst.StdVectorFst.read(args.HLG)
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {args.sound_files}")
-    waves = read_sound_files(
-        filenames=args.sound_files,
-        expected_sample_rate=args.sample_rate,
-    )
-
-    logging.info("Decoding started")
-    features = fbank(waves)
-    feature_lengths = [f.size(0) for f in features]
-    features = pad_sequence(
-        features,
-        batch_first=True,
-        padding_value=math.log(1e-10),
-    )
-
-    feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64)
-    log_probs, log_probs_len = model(features, feature_lengths)
-
-    word_table = k2.SymbolTable.from_file(args.words)
-
-    hyps = []
-    for i in range(log_probs.shape[0]):
-        hyp = decode(
-            filename=args.sound_files[i],
-            log_probs=log_probs[i, : log_probs_len[i]],
-            HLG=HLG,
-            id2word=word_table,
-        )
-        hyps.append(hyp)
-
-    s = "\n"
-    for filename, hyp in zip(args.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
new file mode 120000
index 000000000..8343d5079
--- /dev/null
+++ b/egs/gigaspeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/onnx_pretrained_ctc_HLG.py
\ No newline at end of file
diff --git a/egs/gigaspeech/ASR/zipformer/pretrained.py b/egs/gigaspeech/ASR/zipformer/pretrained.py
deleted file mode 100755
index 3104b6084..000000000
--- a/egs/gigaspeech/ASR/zipformer/pretrained.py
+++ /dev/null
@@ -1,381 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This script loads a checkpoint and uses it to decode waves.
-You can generate the checkpoint with the following command:
-
-Note: This is an example for the librispeech dataset; if you are using a
-different dataset, you should change the argument values according to your
-dataset.
-
-- For non-streaming model:
-
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --epoch 30 \
-  --avg 9
-
-- For streaming model:
-
-./zipformer/export.py \
-  --exp-dir ./zipformer/exp \
-  --causal 1 \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --epoch 30 \
-  --avg 9
-
-Usage of this script:
-
-- For non-streaming model:
-
-(1) greedy search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens data/lang_bpe_500/tokens.txt \
-  --method greedy_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(2) modified beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method modified_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(3) fast beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method fast_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-- For streaming model:
-
-(1) greedy search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method greedy_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(2) modified beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method modified_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-(3) fast beam search
-./zipformer/pretrained.py \
-  --checkpoint ./zipformer/exp/pretrained.pt \
-  --causal 1 \
-  --chunk-size 16 \
-  --left-context-frames 128 \
-  --tokens ./data/lang_bpe_500/tokens.txt \
-  --method fast_beam_search \
-  /path/to/foo.wav \
-  /path/to/bar.wav
-
-
-You can also use `./zipformer/exp/epoch-xx.pt`.
- -Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - fast_beam_search_one_best, - greedy_search_batch, - modified_beam_search, -) -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.utils import make_pad_mask - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}"
-        # We use only the first channel
-        ans.append(wave[0].contiguous())
-    return ans
-
-
-@torch.no_grad()
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-
-    params = get_params()
-
-    params.update(vars(args))
-
-    token_table = k2.SymbolTable.from_file(params.tokens)
-
-    params.blank_id = token_table["<blk>"]
-    params.unk_id = token_table["<unk>"]
-    params.vocab_size = num_tokens(token_table) + 1
-
-    logging.info(f"{params}")
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    logging.info(f"device: {device}")
-
-    if params.causal:
-        assert (
-            "," not in params.chunk_size
-        ), "chunk_size should be one value in decoding."
-        assert (
-            "," not in params.left_context_frames
-        ), "left_context_frames should be one value in decoding."
-
-    logging.info("Creating model")
-    model = get_model(params)
-
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
-
-    checkpoint = torch.load(args.checkpoint, map_location="cpu")
-    model.load_state_dict(checkpoint["model"], strict=False)
-    model.to(device)
-    model.eval()
-
-    logging.info("Constructing Fbank computer")
-    opts = kaldifeat.FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = params.sample_rate
-    opts.mel_opts.num_bins = params.feature_dim
-
-    fbank = kaldifeat.Fbank(opts)
-
-    logging.info(f"Reading sound files: {params.sound_files}")
-    waves = read_sound_files(
-        filenames=params.sound_files, expected_sample_rate=params.sample_rate
-    )
-    waves = [w.to(device) for w in waves]
-
-    logging.info("Decoding started")
-    features = fbank(waves)
-    feature_lengths = [f.size(0) for f in features]
-
-    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
-    feature_lengths = torch.tensor(feature_lengths, device=device)
-
-    # model forward
-    encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
-
-    hyps = []
-    msg = f"Using {params.method}"
-    logging.info(msg)
-
-    def token_ids_to_words(token_ids: List[int]) -> str:
-        text = ""
-        for i in token_ids:
-            text += token_table[i]
-        return text.replace("▁", " ").strip()
-
-    if params.method == "fast_beam_search":
-        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
-        hyp_tokens = fast_beam_search_one_best(
-            model=model,
-            decoding_graph=decoding_graph,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam,
-            max_contexts=params.max_contexts,
-            max_states=params.max_states,
-        )
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    elif params.method == "modified_beam_search":
-        hyp_tokens = modified_beam_search(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-            beam=params.beam_size,
-        )
-
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
-        hyp_tokens = greedy_search_batch(
-            model=model,
-            encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
-        )
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
-    else:
-        raise ValueError(f"Unsupported method: {params.method}")
-
-    s = "\n"
-    for filename, hyp in zip(params.sound_files, hyps):
-        s += f"{filename}:\n{hyp}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-
logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/gigaspeech/ASR/zipformer/pretrained.py b/egs/gigaspeech/ASR/zipformer/pretrained.py new file mode 120000 index 000000000..0bd71dde4 --- /dev/null +++ b/egs/gigaspeech/ASR/zipformer/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/pretrained.py \ No newline at end of file diff --git a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py deleted file mode 100755 index 9dff2e6fc..000000000 --- a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py +++ /dev/null @@ -1,455 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --use-ctc 1 \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -Usage of this script: - -(1) ctc-decoding -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method ctc-decoding \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) 1best -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --method 1best \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) nbest-rescoring -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method nbest-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav - - -(4) whole-lattice-rescoring -./zipformer/pretrained_ctc.py \ - --checkpoint ./zipformer/exp/pretrained.pt \ - --HLG data/lang_bpe_500/HLG.pt \ - --words-file data/lang_bpe_500/words.txt \ - --G data/lm/G_4_gram.pt \ - --method whole-lattice-rescoring \ - --sample-rate 16000 \ - /path/to/foo.wav \ - /path/to/bar.wav -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from ctc_decode import get_decoding_params -from export import num_tokens -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_model, get_params - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.utils import get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - 
formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        required=True,
-        help="Path to the checkpoint. "
-        "The checkpoint is assumed to be saved by "
-        "icefall.checkpoint.save_checkpoint().",
-    )
-
-    parser.add_argument(
-        "--context-size",
-        type=int,
-        default=2,
-        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
-    )
-
-    parser.add_argument(
-        "--words-file",
-        type=str,
-        help="""Path to words.txt.
-        Used only when method is not ctc-decoding.
-        """,
-    )
-
-    parser.add_argument(
-        "--HLG",
-        type=str,
-        help="""Path to HLG.pt.
-        Used only when method is not ctc-decoding.
-        """,
-    )
-
-    parser.add_argument(
-        "--tokens",
-        type=str,
-        help="""Path to tokens.txt.
-        Used only when method is ctc-decoding.
-        """,
-    )
-
-    parser.add_argument(
-        "--method",
-        type=str,
-        default="1best",
-        help="""Decoding method.
-        Possible values are:
-        (0) ctc-decoding - Use CTC decoding. It uses a token table,
-            i.e., lang_dir/tokens.txt, to convert
-            word pieces to words. It needs neither a lexicon
-            nor an n-gram LM.
-        (1) 1best - Use the best path as decoding output. Only
-            the transformer encoder output is used for decoding.
-            We call it HLG decoding.
-        (2) nbest-rescoring - Extract n paths from the decoding lattice,
-            rescore them with an LM; the path with
-            the highest score is the decoding result.
-            We call it HLG decoding + nbest n-gram LM rescoring.
-        (3) whole-lattice-rescoring - Use an LM to rescore the
-            decoding lattice and then use 1best to decode the
-            rescored lattice.
-            We call it HLG decoding + whole-lattice n-gram LM rescoring.
-        """,
-    )
-
-    parser.add_argument(
-        "--G",
-        type=str,
-        help="""An LM for rescoring.
-        Used only when method is
-        whole-lattice-rescoring or nbest-rescoring.
-        It's usually a 4-gram LM.
-        """,
-    )
-
-    parser.add_argument(
-        "--num-paths",
-        type=int,
-        default=100,
-        help="""
-        Used only when method is nbest-rescoring.
-        It specifies the size of the n-best list.""",
-    )
-
-    parser.add_argument(
-        "--ngram-lm-scale",
-        type=float,
-        default=1.3,
-        help="""
-        Used only when method is whole-lattice-rescoring or nbest-rescoring.
-        It specifies the scale for n-gram LM scores.
-        (Note: You need to tune it on a dataset.)
-        """,
-    )
-
-    parser.add_argument(
-        "--nbest-scale",
-        type=float,
-        default=1.0,
-        help="""
-        Used only when method is nbest-rescoring.
-        It specifies the scale for lattice.scores when
-        extracting n-best lists. A smaller value results in
-        more unique paths, at the risk of missing
-        the best path.
-        """,
-    )
-
-    parser.add_argument(
-        "--sample-rate",
-        type=int,
-        default=16000,
-        help="The sample rate of the input sound file",
-    )
-
-    parser.add_argument(
-        "sound_files",
-        type=str,
-        nargs="+",
-        help="The input sound file(s) to transcribe. "
-        "Supported formats are those supported by torchaudio.load(). "
-        "For example, wav and flac are supported. "
-        "The sample rate has to be 16kHz.",
-    )
-
-    add_model_arguments(parser)
-
-    return parser
-
-
-def read_sound_files(
-    filenames: List[str], expected_sample_rate: float = 16000
-) -> List[torch.Tensor]:
-    """Read a list of sound files into a list of 1-D float32 torch tensors.
-    Args:
-      filenames:
-        A list of sound filenames.
-      expected_sample_rate:
-        The expected sample rate of the sound files.
-    Returns:
-      Return a list of 1-D float32 torch tensors.
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - # add decoding params - params.update(get_decoding_params()) - params.update(vars(args)) - - token_table = k2.SymbolTable.from_file(params.tokens) - params.vocab_size = num_tokens(token_table) + 1 # +1 for blank - params.blank_id = token_table[""] - assert params.blank_id == 0 - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) - ctc_output = model.ctc_output(encoder_out) # (N, T, C) - - batch_size = ctc_output.shape[0] - supervision_segments = torch.tensor( - [ - [i, 0, feature_lengths[i].item() // params.subsampling_factor] - for i in range(batch_size) - ], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - max_token_id = params.vocab_size - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=ctc_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = [[token_table[i] for i in ids] for ids in token_ids] - elif params.method in [ - "1best", - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "nbest-rescoring", - "whole-lattice-rescoring", - ]: - logging.info(f"Loading G from {params.G}") - G = 
-            G = G.to(device)
-            if params.method == "whole-lattice-rescoring":
-                # Add epsilon self-loops to G as we will compose
-                # it with the whole lattice later
-                G = k2.add_epsilon_self_loops(G)
-                G = k2.arc_sort(G)
-
-            # G.lm_scores is used to replace HLG.lm_scores during
-            # LM rescoring.
-            G.lm_scores = G.scores.clone()
-
-        lattice = get_lattice(
-            nnet_output=ctc_output,
-            decoding_graph=HLG,
-            supervision_segments=supervision_segments,
-            search_beam=params.search_beam,
-            output_beam=params.output_beam,
-            min_active_states=params.min_active_states,
-            max_active_states=params.max_active_states,
-            subsampling_factor=params.subsampling_factor,
-        )
-
-        if params.method == "1best":
-            logging.info("Use HLG decoding")
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-        elif params.method == "nbest-rescoring":
-            logging.info("Use HLG decoding + LM rescoring")
-            best_path_dict = rescore_with_n_best_list(
-                lattice=lattice,
-                G=G,
-                num_paths=params.num_paths,
-                lm_scale_list=[params.ngram_lm_scale],
-                nbest_scale=params.nbest_scale,
-            )
-            best_path = next(iter(best_path_dict.values()))
-        elif params.method == "whole-lattice-rescoring":
-            logging.info("Use HLG decoding + LM rescoring")
-            best_path_dict = rescore_with_whole_lattice(
-                lattice=lattice,
-                G_with_epsilon_loops=G,
-                lm_scale_list=[params.ngram_lm_scale],
-            )
-            best_path = next(iter(best_path_dict.values()))
-
-        hyps = get_texts(best_path)
-        word_sym_table = k2.SymbolTable.from_file(params.words_file)
-        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    else:
-        raise ValueError(f"Unsupported decoding method: {params.method}")
-
-    s = "\n"
-    if params.method == "ctc-decoding":
-        for filename, hyp in zip(params.sound_files, hyps):
-            words = "".join(hyp)
-            words = words.replace("▁", " ").strip()
-            s += f"{filename}:\n{words}\n\n"
-    elif params.method in [
-        "1best",
-        "nbest-rescoring",
-        "whole-lattice-rescoring",
-    ]:
-        for filename, hyp in zip(params.sound_files, hyps):
-            words = " ".join(hyp)
-            words = words.replace("▁", " ").strip()
-            s += f"{filename}:\n{words}\n\n"
-    logging.info(s)
-
-    logging.info("Decoding Done")
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    main()
diff --git a/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py
new file mode 120000
index 000000000..c2f6f6fc3
--- /dev/null
+++ b/egs/gigaspeech/ASR/zipformer/pretrained_ctc.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/zipformer/pretrained_ctc.py
\ No newline at end of file
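
All of the decoding scripts removed above share the same CTC front end (80-dim fbank features, padding with log(1e-10)); they differ only in how per-frame log-probabilities become text: a plain greedy collapse in onnx_pretrained_ctc.py, or a FasterDecoder search over an H/HL/HLG graph. A minimal, self-contained sketch of the shared greedy-search core, mirroring the loop in the deleted onnx_pretrained_ctc.py; the helper name ctc_greedy_search is ours, and log_probs, log_probs_len, and token_table are assumed to come from the exported ONNX model and tokens.txt as in the scripts above:

import k2
import torch


def ctc_greedy_search(
    log_probs: torch.Tensor,
    log_probs_len: torch.Tensor,
    token_table: k2.SymbolTable,
    blank_id: int = 0,
) -> list:
    """Return one decoded string per utterance in the batch.

    log_probs has shape (N, T, vocab_size); log_probs_len has shape (N,).
    """
    hyps = []
    for i in range(log_probs.size(0)):
        # Best token ID for each valid frame of utterance i.
        indexes = log_probs[i, : log_probs_len[i]].argmax(dim=-1)
        # CTC collapse rule: merge consecutive repeats, then drop blanks.
        token_ids = torch.unique_consecutive(indexes)
        token_ids = token_ids[token_ids != blank_id]
        text = "".join(token_table[t] for t in token_ids.tolist())
        hyps.append(text.replace("▁", " ").strip())
    return hyps

The FST-based variants (onnx_pretrained_ctc_H.py, _HL.py, _HLG.py) replace this collapse step with a kaldi_decoder FasterDecoder search over the corresponding graph, which is why they can incorporate a lexicon and an n-gram LM while the greedy version cannot.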