Merge 9b13eac94686fe4a31eb190b8deaf3ff967feacd into dd5d7e358bf6a49eb2a31dff2dd2b943f792c1fc

This commit is contained in:
Machiko Bailey 2025-02-01 09:56:33 +00:00 committed by GitHub
commit b1064e93ea
6 changed files with 453 additions and 98 deletions

View File

@ -122,6 +122,7 @@ from beam_search import (
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
from tokenizer import Tokenizer
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
@ -377,6 +378,17 @@ def get_parser():
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
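As a hedged illustration of the help text above, this is roughly what applying the penalty looks like at the logits stage; the shapes and blank id follow the assumptions stated in the flag's description, and the variable names are illustrative only:

```python
import torch

# Illustrative shapes only: (batch_size, vocab_size), blank id assumed to be 0,
# matching the convention described in the --blank-penalty help text.
logits = torch.randn(4, 500)
blank_penalty = 1.5

if blank_penalty != 0:
    # Subtracting a positive penalty from the blank logit makes the model
    # less likely to emit blank, which typically trades deletion errors
    # for insertion errors.
    logits[:, 0] -= blank_penalty

best_tokens = logits.argmax(dim=-1)  # greedy pick per frame
```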
add_model_arguments(parser)
@ -601,6 +613,7 @@ def decode_one_batch(
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:

View File

@ -47,3 +47,41 @@ The decoding command is:
--blank-penalty 0
```
#### Streaming
We have not yet finished evaluating our models; evaluation results will be added here once they are available.
The training command is:
```shell
./zipformer/train.py \
--world-size 8 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--causal 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--lang data/lang_char \
--max-duration 1600
```
The decoding command is:
```shell
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp-large \
--lang data/lang_char \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```

View File

@ -12,7 +12,6 @@ class Tokenizer:
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,28 +18,24 @@
"""
Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding_method greedy_search \
--lang data/lang_char \
--num-decode-streams 2000
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp-large \
--lang data/lang_char \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
"""
import argparse
import logging
import math
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from decode import save_results
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@ -48,9 +45,9 @@ from streaming_beam_search import (
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
@ -58,7 +55,14 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@ -73,7 +77,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -87,12 +91,6 @@ def get_parser():
""",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument(
"--avg",
type=int,
@ -116,7 +114,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
default="zipformer/exp",
help="The experiment dir",
)
@ -127,6 +125,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -138,14 +143,6 @@ def get_parser():
""",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--num_active_paths",
type=int,
@ -157,7 +154,7 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4.0,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
@ -194,18 +191,235 @@ def get_parser():
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for the ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs).
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at the 50 Hz frame rate, after encoder_embed) for each sample in the batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
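As a sanity check on the layout described above, the returned list should hold six cached tensors per encoder layer plus the two trailing entries (embed left-pad and processed_lens). A minimal sketch, assuming the `--num-encoder-layers 2,2,4,5,4,2` configuration used in this recipe's decoding commands:

```python
# Assumed configuration, taken from the decoding commands in this PR.
num_encoder_layers = [2, 2, 4, 5, 4, 2]
num_layers = sum(num_encoder_layers)  # 19 layers in total

# 6 cached tensors per layer, plus embed_states and processed_lens.
expected_len = 6 * num_layers + 2  # 116

# After calling get_init_states(model, batch_size, device), one would expect:
# assert len(states) == expected_len
```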
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for the ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs).
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at the 50 Hz frame rate, after encoder_embed) for each sample in the batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
batch_states[-2] is the cached left padding for the ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs).
batch_states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at the 50 Hz frame rate, after encoder_embed) for each sample in the batch.
Returns:
state_list: A list of lists. Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
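Because `unstack_states` is documented as the inverse of `stack_states`, a round trip through the two functions should reproduce each utterance's state exactly. A minimal sketch of that property, assuming `model` and `device` are set up as in `main()` below:

```python
# Hypothetical round-trip check (not part of the recipe itself).
states_a = get_init_states(model=model, batch_size=1, device=device)
states_b = get_init_states(model=model, batch_size=1, device=device)

batched = stack_states([states_a, states_b])   # batch the two streams
recovered = unstack_states(batched)            # split them apart again

assert len(recovered) == 2
for original, roundtripped in zip(states_a, recovered[0]):
    assert torch.equal(original, roundtripped)
```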
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
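# A position is True (masked) when its index is >= processed_lens, i.e. that
# cache slot has not been filled with real frames yet; flip(1) right-aligns
# the valid cached context against the incoming chunk.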
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
@ -224,27 +438,32 @@ def decode_one_chunk(
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
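# chunk_size counts 50 Hz frames after encoder_embed; the fbank features
# are at 100 Hz, hence 2 * chunk_size feature frames per chunk.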
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
# factor in encoders is 8.
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
tail_length = 23
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsamples features: T -> (T - 7) // 2.
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling.
tail_length = chunk_size * 2 + 7 + 2 * 3
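# e.g. with --chunk-size 32: tail_length = 64 + 7 + 6 = 77 feature frames.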
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
@ -256,12 +475,14 @@ def decode_one_chunk(
)
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
@ -269,6 +490,7 @@ def decode_one_chunk(
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=model.device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
@ -295,8 +517,9 @@ def decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
finished_streams.append(i)
return finished_streams
@ -338,14 +561,14 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -361,15 +584,19 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
@ -380,8 +607,8 @@ def decode_dataset(
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@ -391,18 +618,37 @@ def decode_dataset(
# decode final chunks of last sequences
while len(decode_streams):
# print("INSIDE LEN DECODE STREAMS")
# pdb.set_trace()
# print(model.device)
# test_device = model.device
# print("done")
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
if not finished_streams:
logging.warning("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
)
del decode_streams[i]
del decode_streams[i]
except IndexError as e:
logging.error(f"IndexError: {e}")
logging.error(f"decode_streams length: {len(decode_streams)}")
logging.error(f"finished_streams: {finished_streams}")
logging.error(f"i: {i}")
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
@ -416,9 +662,54 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
if torch.cuda.is_available():
torch.cuda.synchronize()
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
@ -430,16 +721,20 @@ def main():
params = get_params()
params.update(vars(args))
if not params.res_dir:
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
@ -455,13 +750,13 @@ def main():
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", params.gpu)
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/prepare_lang_char.py
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
@ -469,7 +764,7 @@ def main():
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
@ -553,42 +848,51 @@ def main():
model.device = device
decoding_graph = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
)
elif params.decoding_method == "fast_beam_search":
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
for subdir in ["valid"]:
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
tot_err = save_results(
params=params, test_set_name=subdir, results_dict=results_dict
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
logging.info("Done!")

View File

@ -644,7 +644,8 @@ def write_error_stats(
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1