Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 18:12:19 +00:00)

add streaming support to reazonresearch

This commit is contained in:
parent 1730fce688
commit 250ff30875
@@ -103,10 +103,9 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import ReazonSpeechAsrDataModule
 from beam_search import (
     beam_search,
     fast_beam_search_nbest,
@@ -122,6 +121,7 @@ from beam_search import (
     modified_beam_search_LODR,
 )
 from lhotse import set_caching_enabled
+from tokenizer import Tokenizer
 from train import add_model_arguments, get_model, get_params

 from icefall import ContextGraph, LmScorer, NgramLm
@@ -134,6 +134,7 @@ from icefall.checkpoint import (
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
+    make_pad_mask,
     setup_logger,
     store_transcripts,
     str2bool,
@@ -204,7 +205,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=Path,
-        default="data/lang_bpe_500",
+        default="data/lang_char",
         help="The lang dir containing word table and LG graph",
     )

@@ -377,6 +378,17 @@ def get_parser():
         default=False,
         help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )

     add_model_arguments(parser)

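For reference, the new --blank-penalty flag does exactly what its help text describes: a constant is subtracted from the blank logit before the search step, which discourages blank emissions and so trades insertion errors against deletions. A minimal sketch with illustrative values (not part of the patch):

import torch

blank_penalty = 0.5           # hypothetical value; the patch's default is 0.0
logits = torch.randn(8, 500)  # (batch_size, vocab_size); blank id assumed 0
logits[:, 0] -= blank_penalty # blank gets less probability mass after softmax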
@@ -386,7 +398,7 @@ def get_parser():
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    sp: Tokenizer,
     batch: dict,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
@@ -465,9 +477,10 @@ def decode_one_batch(
             beam=params.beam,
             max_contexts=params.max_contexts,
             max_states=params.max_states,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
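The recurring change from `hyp.split()` to `sp.text2word(hyp)` in this and the following hunks swaps whitespace splitting for the Tokenizer's own segmentation. Japanese text is written without spaces, so `.split()` would return the whole sentence as one token and make word-level scoring meaningless. A rough character-level stand-in for what such a method has to provide (the real implementation lives in tokenizer.py; this sketch only assumes it maps a decoded string to a list of scoring units):

def text2word_sketch(text: str) -> list:
    # Hypothetical fallback: treat every non-space character as one unit.
    return [c for c in text if not c.isspace()]

print("こんにちは世界".split())            # ['こんにちは世界'] -- one opaque token
print(text2word_sketch("こんにちは世界"))  # ['こ', 'ん', 'に', 'ち', 'は', '世', '界']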
@@ -479,6 +492,7 @@
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in hyp_tokens:
             hyps.append([word_table[i] for i in hyp])
@@ -493,9 +507,10 @@
             max_states=params.max_states,
             num_paths=params.num_paths,
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -508,17 +523,19 @@
             num_paths=params.num_paths,
             ref_texts=sp.encode(supervisions["text"]),
             nbest_scale=params.nbest_scale,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -526,9 +543,10 @@
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
             context_graph=context_graph,
+            blank_penalty=params.blank_penalty,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
         hyp_tokens = modified_beam_search_lm_shallow_fusion(
             model=model,
@@ -538,7 +556,7 @@
             LM=LM,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_LODR":
         hyp_tokens = modified_beam_search_LODR(
             model=model,
@@ -551,7 +569,7 @@
             context_graph=context_graph,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyps.append(sp.text2word(hyp))
     elif params.decoding_method == "modified_beam_search_lm_rescore":
         lm_scale_list = [0.01 * i for i in range(10, 50)]
         ans_dict = modified_beam_search_lm_rescore(
@@ -597,10 +615,11 @@
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyps.append(sp.text2word(sp.decode(hyp)))

     # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
     prefix = f"{params.decoding_method}"
+    key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
     elif "fast_beam_search" in params.decoding_method:
@@ -639,7 +658,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    sp: spm.SentencePieceProcessor,
+    sp: Tokenizer,
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     context_graph: Optional[ContextGraph] = None,
@@ -705,7 +724,7 @@
         this_batch = []
         assert len(hyps) == len(texts)
         for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
-            ref_words = ref_text.split()
+            ref_words = sp.text2word(ref_text)
             this_batch.append((cut_id, ref_words, hyp_words))

         results[name].extend(this_batch)
@@ -778,8 +797,8 @@ def save_wer_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
     LmScorer.add_arguments(parser)
+    ReazonSpeechAsrDataModule.add_arguments(parser)
+    Tokenizer.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)

@@ -853,6 +872,8 @@ def main():
             f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
         )

+    params.suffix += f"-blank-penalty-{params.blank_penalty}"
+
     if params.use_averaged_model:
         params.suffix += "_use-averaged-model"

@@ -865,10 +886,9 @@ def main():
     logging.info(f"Device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    sp = Tokenizer.load(params.lang, params.lang_type)

-    # <blk> and <unk> are defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/prepare_lang_char.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -1040,20 +1060,13 @@ def main():

     # we need cut ids to display recognition results.
     args.return_cuts = True
-    librispeech = LibriSpeechAsrDataModule(args)
+    reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
-
-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
-
-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
-
-    for test_set, test_dl in zip(test_sets, test_dl):
+    for subdir in ["valid"]:
         results_dict = decode_dataset(
-            dl=test_dl,
+            dl=reazonspeech_corpus.test_dataloaders(
+                getattr(reazonspeech_corpus, f"{subdir}_cuts")()
+            ),
             params=params,
             model=model,
             sp=sp,
@@ -1067,9 +1080,20 @@ def main():

         save_asr_output(
             params=params,
-            test_set_name=test_set,
+            test_set_name=subdir,
             results_dict=results_dict,
         )
+        # with (
+        #     params.res_dir
+        #     / (
+        #         f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
+        #         f"_{params.avg}_{params.epoch}.cer"
+        #     )
+        # ).open("w") as fout:
+        #     if len(tot_err) == 1:
+        #         fout.write(f"{tot_err[0][1]}")
+        #     else:
+        #         fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)

         if not params.skip_scoring:
             save_wer_results(
@@ -2434,4 +2434,4 @@ if __name__ == "__main__":
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
-    _test_zipformer_main(False)
+    _test_zipformer_main(True)
     _test_zipformer_main(True)
@@ -12,7 +12,6 @@ class Tokenizer:
     @staticmethod
     def add_arguments(parser: argparse.ArgumentParser):
        group = parser.add_argument_group(title="Lang related options")

        group.add_argument("--lang", type=Path, help="Path to lang directory.")
-
        group.add_argument(
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
-#
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+#                                                  Fangjun Kuang,
+#                                                  Zengwei Yao)
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,28 +18,29 @@

 """
 Usage:
-./pruned_transducer_stateless7_streaming/streaming_decode.py \
+./zipformer/streaming_decode.py \
   --epoch 28 \
   --avg 15 \
-  --decode-chunk-len 32 \
-  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
-  --decoding_method greedy_search \
   --lang data/lang_char \
+  --causal 1 \
+  --chunk-size 32 \
+  --left-context-frames 256 \
+  --exp-dir ./zipformer/exp \
+  --decoding-method greedy_search \
   --num-decode-streams 2000
 """

+import pdb
 import argparse
 import logging
 import math
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
+from tokenizer import Tokenizer

 import k2
 import numpy as np
 import torch
 import torch.nn as nn
 from asr_datamodule import ReazonSpeechAsrDataModule
+from decode import save_results
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -47,10 +49,9 @@ from streaming_beam_search import (
     greedy_search,
     modified_beam_search,
 )
-from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
-from zipformer import stack_states, unstack_states
+from train import add_model_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,
@@ -58,7 +59,17 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    make_pad_mask,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)

+import subprocess as sp
+import os

 LOG_EPS = math.log(1e-10)
@@ -73,7 +84,7 @@ def get_parser():
         type=int,
         default=28,
         help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 0.
+        Note: Epoch counts from 1.
         You can specify --avg to use more checkpoints for model averaging.""",
     )

@@ -87,12 +98,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--gpu",
-        type=int,
-        default=0,
-    )
-
     parser.add_argument(
         "--avg",
         type=int,
@@ -116,7 +121,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="zipformer/exp",
         help="The experiment dir",
     )

@@ -126,6 +131,13 @@ def get_parser():
         default="data/lang_bpe_500/bpe.model",
         help="Path to the BPE model",
     )

+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_char",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
@@ -138,14 +150,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--decoding-graph",
-        type=str,
-        default="",
-        help="""Used only when --decoding-method is
-        fast_beam_search""",
-    )
-
     parser.add_argument(
         "--num_active_paths",
         type=int,
@@ -157,7 +161,7 @@ def get_parser():
     parser.add_argument(
         "--beam",
         type=float,
-        default=4.0,
+        default=4,
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
@@ -194,18 +198,235 @@ def get_parser():
         help="The number of streams that can be decoded parallel.",
     )

+    parser.add_argument(
+        "--res-dir",
+        type=Path,
+        default=None,
+        help="The path to save results.",
+    )
+
     add_model_arguments(parser)

     return parser


+def get_init_states(
+    model: nn.Module,
+    batch_size: int = 1,
+    device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+    """
+    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+    states[-2] is the cached left padding for ConvNeXt module,
+    of shape (batch_size, num_channels, left_pad, num_freqs)
+    states[-1] is processed_lens of shape (batch,), which records the number
+    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+    """
+    states = model.encoder.get_init_states(batch_size, device)
+
+    embed_states = model.encoder_embed.get_init_states(batch_size, device)
+    states.append(embed_states)
+
+    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+    states.append(processed_lens)
+
+    return states
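The flat state layout documented in the docstring above (six cached tensors per layer, then the ConvNeXt left padding, then processed_lens) can be unpacked per layer like this; a minimal sketch, assuming `states` came from `get_init_states`:

def layer_cache(states, i):
    # Illustration only: the six per-layer tensors named in the docstring.
    (cached_key, cached_nonlin_attn, cached_val1,
     cached_val2, cached_conv1, cached_conv2) = states[i * 6 : (i + 1) * 6]
    return cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2

# states[-2] is the ConvNeXt left padding; states[-1] is processed_lens.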
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+    """Stack list of zipformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    zipformer when those utterances are formed into a batch.
+
+    Args:
+      state_list:
+        Each element in state_list corresponding to the internal state
+        of the zipformer model for a single utterance. For element-n,
+        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+        cached_val2, cached_conv1, cached_conv2).
+        state_list[n][-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        state_list[n][-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+    """
+    batch_size = len(state_list)
+    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+    tot_num_layers = (len(state_list[0]) - 2) // 6
+
+    batch_states = []
+    for layer in range(tot_num_layers):
+        layer_offset = layer * 6
+        # cached_key: (left_context_len, batch_size, key_dim)
+        cached_key = torch.cat(
+            [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+        )
+        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+        cached_nonlin_attn = torch.cat(
+            [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+        )
+        # cached_val1: (left_context_len, batch_size, value_dim)
+        cached_val1 = torch.cat(
+            [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+        )
+        # cached_val2: (left_context_len, batch_size, value_dim)
+        cached_val2 = torch.cat(
+            [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+        )
+        # cached_conv1: (#batch, channels, left_pad)
+        cached_conv1 = torch.cat(
+            [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+        )
+        # cached_conv2: (#batch, channels, left_pad)
+        cached_conv2 = torch.cat(
+            [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+        )
+        batch_states += [
+            cached_key,
+            cached_nonlin_attn,
+            cached_val1,
+            cached_val2,
+            cached_conv1,
+            cached_conv2,
+        ]
+
+    cached_embed_left_pad = torch.cat(
+        [state_list[i][-2] for i in range(batch_size)], dim=0
+    )
+    batch_states.append(cached_embed_left_pad)
+
+    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+    batch_states.append(processed_lens)
+
+    return batch_states
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+    """Unstack the zipformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Note:
+      It is the inverse of :func:`stack_states`.
+
+    Args:
+      batch_states: A list of cached tensors of all encoder layers. For layer-i,
+        states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+        cached_conv1, cached_conv2).
+        state_list[-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        states[-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+    Returns:
+      state_list: A list of list. Each element in state_list corresponding to the internal state
+      of the zipformer model for a single utterance.
+    """
+    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+    tot_num_layers = (len(batch_states) - 2) // 6
+
+    processed_lens = batch_states[-1]
+    batch_size = processed_lens.shape[0]
+
+    state_list = [[] for _ in range(batch_size)]
+
+    for layer in range(tot_num_layers):
+        layer_offset = layer * 6
+        # cached_key: (left_context_len, batch_size, key_dim)
+        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+        cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_val1: (left_context_len, batch_size, value_dim)
+        cached_val1_list = batch_states[layer_offset + 2].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_val2: (left_context_len, batch_size, value_dim)
+        cached_val2_list = batch_states[layer_offset + 3].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_conv1: (#batch, channels, left_pad)
+        cached_conv1_list = batch_states[layer_offset + 4].chunk(
+            chunks=batch_size, dim=0
+        )
+        # cached_conv2: (#batch, channels, left_pad)
+        cached_conv2_list = batch_states[layer_offset + 5].chunk(
+            chunks=batch_size, dim=0
+        )
+        for i in range(batch_size):
+            state_list[i] += [
+                cached_key_list[i],
+                cached_nonlin_attn_list[i],
+                cached_val1_list[i],
+                cached_val2_list[i],
+                cached_conv1_list[i],
+                cached_conv2_list[i],
+            ]
+
+    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+    for i in range(batch_size):
+        state_list[i].append(cached_embed_left_pad_list[i])
+
+    processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+    for i in range(batch_size):
+        state_list[i].append(processed_lens_list[i])
+
+    return state_list
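Because the two helpers are declared inverses of each other, a quick round trip is a useful sanity check; a sketch, assuming `model` and `device` are in scope as in the surrounding code:

s1 = get_init_states(model=model, batch_size=1, device=device)
s2 = get_init_states(model=model, batch_size=1, device=device)
batched = stack_states([s1, s2])            # one state with batch size 2
s1_back, s2_back = unstack_states(batched)  # back to per-utterance states
assert all(torch.equal(a, b) for a, b in zip(s1, s1_back))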
+def streaming_forward(
+    features: Tensor,
+    feature_lens: Tensor,
+    model: nn.Module,
+    states: List[Tensor],
+    chunk_size: int,
+    left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+    """
+    Returns encoder outputs, output lengths, and updated states.
+    """
+    cached_embed_left_pad = states[-2]
+    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+        x=features,
+        x_lens=feature_lens,
+        cached_left_pad=cached_embed_left_pad,
+    )
+    assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+    src_key_padding_mask = make_pad_mask(x_lens)
+
+    # processed_mask is used to mask out initial states
+    processed_mask = torch.arange(left_context_len, device=x.device).expand(
+        x.size(0), left_context_len
+    )
+    processed_lens = states[-1]  # (batch,)
+    # (batch, left_context_size)
+    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+    # Update processed lengths
+    new_processed_lens = processed_lens + x_lens
+
+    # (batch, left_context_size + chunk_size)
+    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+    encoder_states = states[:-2]
+    (
+        encoder_out,
+        encoder_out_lens,
+        new_encoder_states,
+    ) = model.encoder.streaming_forward(
+        x=x,
+        x_lens=x_lens,
+        states=encoder_states,
+        src_key_padding_mask=src_key_padding_mask,
+    )
+    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+    new_states = new_encoder_states + [
+        new_cached_embed_left_pad,
+        new_processed_lens,
+    ]
+    return encoder_out, encoder_out_lens, new_states
+
+
 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
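streaming_forward is the per-chunk entry point that decode_one_chunk (next hunk) now calls instead of model.encoder.streaming_forward directly; conceptually, a decoding loop drives it like this (a sketch with a hypothetical feature_chunks iterator, not the patch's code):

states = get_init_states(model=model, batch_size=1, device=device)
for chunk in feature_chunks:  # hypothetical (1, T_chunk, 80) fbank tensors
    chunk_lens = torch.full((1,), chunk.size(1), dtype=torch.int64, device=device)
    encoder_out, encoder_out_lens, states = streaming_forward(
        features=chunk,
        feature_lens=chunk_lens,
        model=model,
        states=states,
        chunk_size=32,         # must match --chunk-size
        left_context_len=256,  # must match --left-context-frames
    )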
@@ -224,27 +445,34 @@ def decode_one_chunk(
     Returns:
       Return a List containing which DecodeStreams are finished.
     """
-    device = model.device
+    # pdb.set_trace()
+    # print(model)
+    # print(model.device)
+    # device = model.device
+    chunk_size = int(params.chunk_size)
+    left_context_len = int(params.left_context_frames)
+
     features = []
     feature_lens = []
     states = []
-    processed_lens = []
+
+    processed_lens = []  # Used in fast-beam-search
+
     for stream in decode_streams:
-        feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
+        feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)

-    feature_lens = torch.tensor(feature_lens, device=device)
+    print(feature_lens)
+    feature_lens = torch.tensor(feature_lens, device=model.device)
+    print(feature_lens)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)

-    # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
-    # factor in encoders is 8.
-    # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
-    tail_length = 23
+    # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsample features (T - 7) // 2
+    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+    tail_length = chunk_size * 2 + 7 + 2 * 3
     if features.size(1) < tail_length:
         pad_length = tail_length - features.size(1)
         feature_lens += pad_length
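The new tail_length bound can be checked by hand: with the default --chunk-size 32, a chunk carries 64 input frames (the encoder runs at half the 100 Hz feature rate), encoder_embed keeps (T - 7) // 2 of them, and the ConvNeXt module wants 3 more subsampled frames of right padding, so 64 + 7 + 2 * 3 = 77 input frames are needed:

chunk_size = 32
tail_length = chunk_size * 2 + 7 + 2 * 3  # 77 input frames (10 ms each)
subsampled = (tail_length - 7) // 2       # 35 frames after encoder_embed
assert subsampled == chunk_size + 3       # the chunk plus ConvNeXt right pad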
@@ -256,12 +484,14 @@
         )

     states = stack_states(states)
-    processed_lens = torch.tensor(processed_lens, device=device)

-    encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
-        x=features,
-        x_lens=feature_lens,
+    encoder_out, encoder_out_lens, new_states = streaming_forward(
+        features=features,
+        feature_lens=feature_lens,
+        model=model,
         states=states,
+        chunk_size=chunk_size,
+        left_context_len=left_context_len,
     )

     encoder_out = model.joiner.encoder_proj(encoder_out)
@@ -269,6 +499,7 @@
     if params.decoding_method == "greedy_search":
         greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
+        processed_lens = torch.tensor(processed_lens, device=device)
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
             model=model,
@@ -295,9 +526,11 @@
     for i in range(len(decode_streams)):
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if decode_streams[i].done:
-            finished_streams.append(i)
+
+        # if decode_streams[i].done:
+        #     finished_streams.append(i)
+        finished_streams.append(i)

+    print(finished_streams)
     return finished_streams
@@ -338,14 +571,14 @@
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80

-    log_interval = 50
+    log_interval = 100

     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
-        initial_states = model.encoder.get_init_state(device=device)
+        initial_states = get_init_states(model=model, batch_size=1, device=device)
         decode_stream = DecodeStream(
             params=params,
             cut_id=cut.id,
@@ -361,15 +594,19 @@
         assert audio.dtype == np.float32, audio.dtype

         # The trained model is using normalized samples
-        assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+        # - this is to avoid sending [-32k,+32k] signal in...
+        # - some lhotse AudioTransform classes can make the signal
+        #   be out of range [-1, 1], hence the tolerance 10
+        assert (
+            np.abs(audio).max() <= 10
+        ), "Should be normalized to [-1, 1], 10 for tolerance..."

         samples = torch.from_numpy(audio).squeeze(0)

         fbank = Fbank(opts)
         feature = fbank(samples.to(device))
-        decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
-        decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
+
+        decode_stream.set_features(feature, tail_pad_len=30)
+        decode_stream.ground_truth = cut.supervisions[0].text
         decode_streams.append(decode_stream)

         while len(decode_streams) >= params.num_decode_streams:
@@ -380,8 +617,8 @@
             decode_results.append(
                 (
                     decode_streams[i].id,
-                    sp.text2word(decode_streams[i].ground_truth),
-                    sp.text2word(sp.decode(decode_streams[i].decoding_result())),
+                    decode_streams[i].ground_truth.split(),
+                    sp.decode(decode_streams[i].decoding_result()).split(),
                 )
             )
             del decode_streams[i]
@@ -389,21 +626,43 @@
         if num % log_interval == 0:
             logging.info(f"Cuts processed until now is {num}.")

+    print("cuts processed finished")
+    print(len(decode_streams))
     # decode final chunks of last sequences
     while len(decode_streams):
+        # print("INSIDE LEN DECODE STREAMS")
+        # pdb.set_trace()
+        # print(model.device)
+        # test_device = model.device
+        # print("done")
         finished_streams = decode_one_chunk(
             params=params, model=model, decode_streams=decode_streams
         )
+        # print('INSIDE FOR LOOP ')
+        # print(finished_streams)
+
+        if not finished_streams:
+            print("No finished streams, breaking the loop")
+            break
+
         for i in sorted(finished_streams, reverse=True):
-            decode_results.append(
-                (
-                    decode_streams[i].id,
-                    sp.text2word(decode_streams[i].ground_truth),
-                    sp.text2word(sp.decode(decode_streams[i].decoding_result())),
-                )
-            )
-            del decode_streams[i]
+            try:
+                decode_results.append(
+                    (
+                        decode_streams[i].id,
+                        decode_streams[i].ground_truth.split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
+                    )
+                )
+                del decode_streams[i]
+            except IndexError as e:
+                print(f"IndexError: {e}")
+                print(f"decode_streams length: {len(decode_streams)}")
+                print(f"finished_streams: {finished_streams}")
+                print(f"i: {i}")
+                continue

     if params.decoding_method == "greedy_search":
         key = "greedy_search"
     elif params.decoding_method == "fast_beam_search":
@@ -416,9 +675,57 @@
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+    torch.cuda.synchronize()
     return {key: decode_results}


+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            print("error stats")
+            print("results")
+            print(results)
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
 @torch.no_grad()
 def main():
     parser = get_parser()
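The save_results added here consumes a results_dict that maps each decoding key to (cut_id, ref_words, hyp_words) triples, the same shape store_transcripts and write_error_stats expect; a sketch with made-up entries:

results_dict = {
    "greedy_search": [
        ("cut-0001", ["こんにちは"], ["こんにちは"]),            # correct
        ("cut-0002", ["おはよう"], ["おはよう", "ございます"]),  # one insertion
    ]
}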
@@ -430,16 +737,20 @@ def main():
     params = get_params()
     params.update(vars(args))

-    if not params.res_dir:
-        params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method

     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

     # for streaming
-    params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"

     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":
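With the values shown in the usage string (--epoch 28 --avg 15 --chunk-size 32 --left-context-frames 256), the filename suffix built by the new code comes out as follows:

suffix = "epoch-28-avg-15" + "-chunk-32" + "-left-context-256"
assert suffix == "epoch-28-avg-15-chunk-32-left-context-256"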
@@ -455,13 +766,13 @@ def main():

     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", params.gpu)
+        device = torch.device("cuda", 0)

     logging.info(f"Device: {device}")

     sp = Tokenizer.load(params.lang, params.lang_type)

-    # <blk> and <unk> is defined in local/prepare_lang_char.py
+    # <blk> and <unk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
@@ -469,7 +780,7 @@ def main():
     logging.info(params)

     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)

     if not params.use_averaged_model:
         if params.iter > 0:
@@ -553,43 +864,56 @@ def main():
     model.device = device

     decoding_graph = None
-    if params.decoding_graph:
-        decoding_graph = k2.Fsa.from_dict(
-            torch.load(params.decoding_graph, map_location=device)
-        )
-    elif params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

+    valid_cuts = reazonspeech_corpus.valid_cuts()
+    test_cuts = reazonspeech_corpus.test_cuts()

-    for subdir in ["valid"]:
+    test_sets = ["valid", "test"]
+    test_cuts = [valid_cuts, test_cuts]
+    print('test cuts')
+    print(test_cuts)
+
+    for test_set, test_cut in zip(test_sets, test_cuts):
         results_dict = decode_dataset(
-            cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
+            cuts=test_cut,
             params=params,
             model=model,
             sp=sp,
             decoding_graph=decoding_graph,
         )
-        tot_err = save_results(
-            params=params, test_set_name=subdir, results_dict=results_dict
-        )
+        print("results_dict")
+        print(results_dict)
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )

-        with (
-            params.res_dir
-            / (
-                f"{subdir}-{params.decode_chunk_len}"
-                f"_{params.avg}_{params.epoch}.cer"
-            )
-        ).open("w") as fout:
-            if len(tot_err) == 1:
-                fout.write(f"{tot_err[0][1]}")
-            else:
-                fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)

+    # valid_cuts = reazonspeech_corpus.valid_cuts()
+
+    # for valid_cut in valid_cuts:
+    #     results_dict = decode_dataset(
+    #         cuts=valid_cut,
+    #         params=params,
+    #         model=model,
+    #         sp=sp,
+    #         decoding_graph=decoding_graph,
+    #     )
+    #     save_results(
+    #         params=params,
+    #         test_set_name="valid",
+    #         results_dict=results_dict,
+    #     )

     logging.info("Done!")