add streaming support to reazonresearch

Authored by root on 2024-08-01 15:08:31 +09:00; committed by root
parent 1730fce688
commit 250ff30875
4 changed files with 482 additions and 135 deletions

View File

@ -103,10 +103,9 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from asr_datamodule import ReazonSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
@ -122,6 +121,7 @@ from beam_search import (
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
from tokenizer import Tokenizer
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
@ -134,6 +134,7 @@ from icefall.checkpoint import (
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
@ -204,7 +205,7 @@ def get_parser():
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
@ -377,6 +378,17 @@ def get_parser():
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied to the blank symbol during decoding.
Note: It is a positive value that is subtracted from the blank logit, i.e.
`logits[:, 0] -= blank_penalty` (assuming logits.shape is
[batch_size, vocab] and the blank id is 0).
""",
)
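As an aside, a minimal sketch (not part of this recipe) of the effect the help text describes, assuming logits of shape [batch_size, vocab] with blank id 0; apply_blank_penalty is a hypothetical helper:

    import torch

    def apply_blank_penalty(logits: torch.Tensor, blank_penalty: float) -> torch.Tensor:
        # Subtract a positive penalty from the blank logit so decoding emits
        # fewer blanks (i.e. produces more symbols).
        if blank_penalty > 0.0:
            logits = logits.clone()
            logits[:, 0] -= blank_penalty
        return logits

    logits = torch.randn(8, 500)  # hypothetical (batch_size, vocab)
    penalized = apply_blank_penalty(logits, blank_penalty=2.0)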
add_model_arguments(parser)
@ -386,7 +398,7 @@ def get_parser():
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
sp: Tokenizer,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
@ -465,9 +477,10 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
@ -479,6 +492,7 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
@ -493,9 +507,10 @@ def decode_one_batch(
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
@ -508,17 +523,19 @@ def decode_one_batch(
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -526,9 +543,10 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
blank_penalty=params.blank_penalty,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_shallow_fusion":
hyp_tokens = modified_beam_search_lm_shallow_fusion(
model=model,
@ -538,7 +556,7 @@ def decode_one_batch(
LM=LM,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_LODR":
hyp_tokens = modified_beam_search_LODR(
model=model,
@ -551,7 +569,7 @@ def decode_one_batch(
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
hyps.append(sp.text2word(hyp))
elif params.decoding_method == "modified_beam_search_lm_rescore":
lm_scale_list = [0.01 * i for i in range(10, 50)]
ans_dict = modified_beam_search_lm_rescore(
@ -597,10 +615,11 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
hyps.append(sp.text2word(sp.decode(hyp)))
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:
@ -639,7 +658,7 @@ def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
sp: Tokenizer,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
@ -705,7 +724,7 @@ def decode_dataset(
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
ref_words = sp.text2word(ref_text)
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
@ -778,8 +797,8 @@ def save_wer_results(
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
LmScorer.add_arguments(parser)
ReazonSpeechAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -853,6 +872,8 @@ def main():
f"_LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}"
)
params.suffix += f"-blank-penalty-{params.blank_penalty}"
if params.use_averaged_model:
params.suffix += "_use-averaged-model"
@ -865,10 +886,9 @@ def main():
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> are defined in local/train_bpe_model.py
# <blk> and <unk> are defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
@ -1040,20 +1060,13 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
librispeech = LibriSpeechAsrDataModule(args)
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
test_sets = ["test-clean", "test-other"]
test_dl = [test_clean_dl, test_other_dl]
for test_set, test_dl in zip(test_sets, test_dl):
for subdir in ["valid"]:
results_dict = decode_dataset(
dl=test_dl,
dl=reazonspeech_corpus.test_dataloaders(
getattr(reazonspeech_corpus, f"{subdir}_cuts")()
),
params=params,
model=model,
sp=sp,
@ -1067,9 +1080,20 @@ def main():
save_asr_output(
params=params,
test_set_name=test_set,
test_set_name=subdir,
results_dict=results_dict,
)
# with (
# params.res_dir
# / (
# f"{subdir}-{params.decode_chunk_len}_{params.beam_size}"
# f"_{params.avg}_{params.epoch}.cer"
# )
# ).open("w") as fout:
# if len(tot_err) == 1:
# fout.write(f"{tot_err[0][1]}")
# else:
# fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
if not params.skip_scoring:
save_wer_results(

View File

@ -2434,4 +2434,4 @@ if __name__ == "__main__":
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
_test_zipformer_main(False)
_test_zipformer_main(True)

View File

@ -12,7 +12,6 @@ class Tokenizer:
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,28 +18,29 @@
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding_method greedy_search \
--lang data/lang_char \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer
import k2
import numpy as np
import torch
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from decode import save_results
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@ -47,10 +49,9 @@ from streaming_beam_search import (
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@ -58,7 +59,17 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
import subprocess as sp
import os
LOG_EPS = math.log(1e-10)
@ -73,7 +84,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -87,12 +98,6 @@ def get_parser():
""",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument(
"--avg",
type=int,
@ -116,7 +121,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="zipformer/exp",
help="The experiment dir",
)
@ -126,6 +131,13 @@ def get_parser():
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
@ -138,14 +150,6 @@ def get_parser():
""",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--num_active_paths",
type=int,
@ -157,7 +161,7 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4.0,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
@ -194,18 +198,235 @@ def get_parser():
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
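For orientation, a hedged sketch of how the returned list is laid out according to the docstring above; split_init_states is a hypothetical helper and the layer count is a made-up example, not the recipe's actual value:

    def split_init_states(states, num_layers=6):
        # 6 cached tensors per encoder layer, followed by the encoder_embed
        # left padding and the per-sample processed_lens counter.
        assert len(states) == num_layers * 6 + 2, len(states)
        per_layer = [states[i * 6 : (i + 1) * 6] for i in range(num_layers)]
        cached_embed_left_pad = states[-2]
        processed_lens = states[-1]
        return per_layer, cached_embed_left_pad, processed_lens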
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance. For element n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
batch_states[-2] is the cached left padding for the ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs).
batch_states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50 Hz frame rate, after encoder_embed) for each sample in the batch.
Returns:
state_list: A list of lists. Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
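Since stack_states and unstack_states are documented as inverses of each other, a toy round-trip check could look like the sketch below; all tensor shapes here are invented purely to satisfy the layout, not real zipformer dimensions:

    import torch

    def fake_stream_state():
        # One "layer" (6 cached tensors) + embed left pad + processed_lens,
        # with made-up dimensions for a single utterance (batch dim = 1).
        return [
            torch.randn(4, 1, 8),               # cached_key: (left_context_len, 1, key_dim)
            torch.randn(2, 1, 4, 3),            # cached_nonlin_attn: (num_heads, 1, left_context_len, head_dim)
            torch.randn(4, 1, 8),               # cached_val1
            torch.randn(4, 1, 8),               # cached_val2
            torch.randn(1, 8, 3),               # cached_conv1: (1, channels, left_pad)
            torch.randn(1, 8, 3),               # cached_conv2
            torch.randn(1, 8, 3, 10),           # cached_embed_left_pad
            torch.zeros(1, dtype=torch.int32),  # processed_lens
        ]

    state_list = [fake_stream_state(), fake_stream_state()]
    batched = stack_states(state_list)
    recovered = unstack_states(batched)
    assert all(
        torch.equal(a, b)
        for orig, rec in zip(state_list, recovered)
        for a, b in zip(orig, rec)
    )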
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
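A small worked example of the processed_mask logic above, with made-up numbers: a stream that has decoded 2 frames so far and keeps a left-context cache of 4 slots:

    import torch

    left_context_len = 4
    processed_lens = torch.tensor([2])  # frames decoded so far for this stream
    positions = torch.arange(left_context_len).expand(1, left_context_len)  # [[0, 1, 2, 3]]
    processed_mask = (processed_lens.unsqueeze(1) <= positions).flip(1)
    # tensor([[ True,  True, False, False]]): the two oldest, still-empty cache
    # slots are masked out; the two most recently filled slots are attended to.
    print(processed_mask)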
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@ -224,27 +445,34 @@ def decode_one_chunk(
Returns:
Return a list of the indices of finished DecodeStreams.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
# factor in encoders is 8.
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
tail_length = 23
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsamples features to (T - 7) // 2 frames, and the
# ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after
# subsampling (see the worked example after this function).
tail_length = chunk_size * 2 + 7 + 2 * 3
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
@ -256,12 +484,14 @@ def decode_one_chunk(
)
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
@ -269,6 +499,7 @@ def decode_one_chunk(
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=model.device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
@ -295,9 +526,11 @@ def decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
# if decode_streams[i].done:
# finished_streams.append(i)
finished_streams.append(i)
return finished_streams
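To make the tail-padding arithmetic in decode_one_chunk concrete, a worked example assuming chunk_size = 32 (the value used in the usage string at the top of the file):

    chunk_size = 32
    # Each step consumes 2 * chunk_size input frames (100 Hz fbank), since the
    # encoder output is at 50 Hz after encoder_embed's 2x subsampling.
    tail_length = chunk_size * 2 + 7 + 2 * 3   # 64 + 7 + 6 = 77
    after_embed = (tail_length - 7) // 2       # 35 frames at 50 Hz
    usable = after_embed - 3                   # minus ConvNeXt right padding -> 32
    assert usable == chunk_size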
@ -338,14 +571,14 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -361,15 +594,19 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
@ -380,8 +617,8 @@ def decode_dataset(
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@ -389,21 +626,43 @@ def decode_dataset(
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
print("cuts processed finished")
print(len(decode_streams))
# decode final chunks of last sequences
while len(decode_streams):
# print("INSIDE LEN DECODE STREAMS")
# pdb.set_trace()
# print(model.device)
# test_device = model.device
# print("done")
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
if not finished_streams:
logging.warning("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
)
)
del decode_streams[i]
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
except IndexError as e:
logging.warning(
f"IndexError: {e} (decode_streams length: {len(decode_streams)}, "
f"finished_streams: {finished_streams}, i: {i})"
)
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
@ -416,9 +675,57 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
print("error stats")
print("results")
print(results)
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
@ -430,16 +737,20 @@ def main():
params = get_params()
params.update(vars(args))
if not params.res_dir:
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
@ -455,13 +766,13 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", params.gpu)
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> are defined in local/prepare_lang_char.py
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
@ -469,7 +780,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
@ -553,43 +864,56 @@ def main():
model.device = device
decoding_graph = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
)
elif params.decoding_method == "fast_beam_search":
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
for subdir in ["valid"]:
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
tot_err = save_results(
params=params, test_set_name=subdir, results_dict=results_dict
print(r"esults_dict")
print(results_dict)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
# valid_cuts = reazonspeech_corpus.valid_cuts()
# for valid_cut in valid_cuts:
# results_dict = decode_dataset(
# cuts=valid_cut,
# params=params,
# model=model,
# sp=sp,
# decoding_graph=decoding_graph,
# )
# save_results(
# params=params,
# test_set_name="valid",
# results_dict=results_dict,
# )
logging.info("Done!")