diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py index 1b613b9d1..83034df95 100755 --- a/egs/audioset/AT/zipformer/export.py +++ b/egs/audioset/AT/zipformer/export.py @@ -68,65 +68,15 @@ you can do: ln -s pretrained.pt epoch-9999.pt cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ + ./zipformer/evaluate.py \ --exp-dir ./zipformer/exp \ + --use-averaged-model False \ --epoch 9999 \ --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --max-duration 600 Check ./pretrained.py for its usage. -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -- non-streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -- streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - # You will find the pre-trained models in exp dir """ import argparse @@ -219,13 +169,6 @@ def get_parser(): """, ) - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - add_model_arguments(parser) return parser @@ -258,107 +201,6 @@ class EncoderModel(nn.Module): return encoder_out, encoder_out_lens -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - @torch.no_grad() def main(): args = get_parser().parse_args() @@ -368,15 +210,8 @@ def main(): params.update(vars(args)) device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - logging.info(params) logging.info("About to create model") @@ -467,15 +302,9 @@ def main(): # torch scriptabe. model.__class__.forward = torch.jit.ignore(model.__class__.forward) - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" + + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" logging.info("Using torch.jit.script") model = torch.jit.script(model) diff --git a/egs/audioset/AT/zipformer/scaling_converter.py b/egs/audioset/AT/zipformer/scaling_converter.py new file mode 120000 index 000000000..b0ecee05e --- /dev/null +++ b/egs/audioset/AT/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file