Merge 9b13eac94686fe4a31eb190b8deaf3ff967feacd into dd5d7e358bf6a49eb2a31dff2dd2b943f792c1fc

Machiko Bailey 2025-02-01 09:56:33 +00:00 committed by GitHub
commit b1064e93ea
6 changed files with 453 additions and 98 deletions

View File

@ -122,6 +122,7 @@ from beam_search import (
modified_beam_search_LODR,
)
from lhotse import set_caching_enabled
from tokenizer import Tokenizer
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm
@ -377,6 +378,17 @@ def get_parser():
default=False,
help="""Skip scoring, but still save the ASR output (for eval sets).""",
)
parser.add_argument(
"--blank-penalty",
type=float,
default=0.0,
help="""
The penalty applied on blank symbol during decoding.
Note: It is a positive value that would be applied to logits like
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
[batch_size, vocab] and blank id is 0).
""",
)
add_model_arguments(parser)
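The `--blank-penalty` help text above describes the penalty as a direct subtraction from the blank logit. A minimal sketch of that operation, using hypothetical tensors (blank id assumed to be 0 and logits of shape `[batch_size, vocab]`, as the help string states):

```python
import torch

# Hypothetical decoder-step logits: [batch_size, vocab], blank id = 0.
logits = torch.randn(4, 500)
blank_penalty = 1.5  # value that would be passed via --blank-penalty

if blank_penalty != 0:
    # Subtract the penalty from the blank logit, as the help text describes,
    # making blank less likely to win the argmax / beam comparison.
    logits[:, 0] -= blank_penalty
```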
@ -601,6 +613,7 @@ def decode_one_batch(
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
prefix = f"{params.decoding_method}"
key = f"blank_penalty_{params.blank_penalty}"
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif "fast_beam_search" in params.decoding_method:

View File

@ -47,3 +47,41 @@ The decoding command is:
--blank-penalty 0
```
#### Streaming
We have not yet completed evaluation of our models; evaluation results will be added here once it is done.
The training command is:
```shell
./zipformer/train.py \
--world-size 8 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--causal 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--lang data/lang_char \
--max-duration 1600
```
The decoding command is:
```shell
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp-large \
--lang data/lang_char \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
```

View File

@ -12,7 +12,6 @@ class Tokenizer:
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
-#
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
#                                                  Fangjun Kuang,
#                                                  Zengwei Yao)
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -17,28 +18,24 @@
""" """
Usage: Usage:
./pruned_transducer_stateless7_streaming/streaming_decode.py \ ./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
--decoding_method greedy_search \
--lang data/lang_char \
--num-decode-streams 2000
""" """
import argparse import argparse
import logging import logging
import math import math
import os
import pdb
# import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
-import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
-from decode import save_results
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
@ -48,9 +45,9 @@ from streaming_beam_search import (
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
from train import add_model_arguments, get_model, get_params
-from zipformer import stack_states, unstack_states
from icefall.checkpoint import (
average_checkpoints,
@ -58,7 +55,14 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
-from icefall.utils import AttributeDict, setup_logger, str2bool
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
@ -73,7 +77,7 @@ def get_parser():
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
-Note: Epoch counts from 0.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
@ -87,12 +91,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
@ -116,7 +114,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
-default="pruned_transducer_stateless2/exp",
default="zipformer/exp",
help="The experiment dir",
)
@ -127,6 +125,13 @@ def get_parser():
help="Path to the BPE model", help="Path to the BPE model",
) )
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -138,14 +143,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument( parser.add_argument(
"--num_active_paths", "--num_active_paths",
type=int, type=int,
@ -157,7 +154,7 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
-default=4.0,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
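The help text defines the cutoff as `cutoff = max-score - beam`, as in Kaldi. A small illustrative sketch of that pruning rule (hypothetical scores, not the recipe's actual beam-search loop):

```python
# Hypothetical per-hypothesis log-scores at one decoding step.
scores = {"hyp_a": -1.2, "hyp_b": -3.9, "hyp_c": -7.5}
beam = 4.0

cutoff = max(scores.values()) - beam  # cutoff = max-score - beam
# Only hypotheses scoring at or above the cutoff survive this step.
active = {h: s for h, s in scores.items() if s >= cutoff}
print(active)  # hyp_a and hyp_b survive; hyp_c is pruned
```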
@ -194,18 +191,235 @@ def get_parser():
help="The number of streams that can be decoded parallel.", help="The number of streams that can be decoded parallel.",
) )
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
state_list[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of list. Each element in state_list corresponding to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
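The `stack_states`/`unstack_states` pair added above is documented as mutually inverse. A quick, self-contained sanity check of that claim, using dummy tensors with the shapes given in the docstrings (the toy sizes below are assumptions, not values from the recipe, and the functions are assumed importable from the file above):

```python
import torch
# Assumed import of the functions defined above:
# from streaming_decode import stack_states, unstack_states

def make_dummy_state():
    # One encoder layer with toy sizes: left_context_len=8, key_dim=value_dim=16,
    # num_heads=4, head_dim=4, channels=16, left_pad=4, num_freqs=10 (all assumed).
    return [
        torch.randn(8, 1, 16),              # cached_key: (left_context_len, batch, key_dim)
        torch.randn(4, 1, 8, 4),            # cached_nonlin_attn: (num_heads, batch, left_context_len, head_dim)
        torch.randn(8, 1, 16),              # cached_val1
        torch.randn(8, 1, 16),              # cached_val2
        torch.randn(1, 16, 4),              # cached_conv1: (batch, channels, left_pad)
        torch.randn(1, 16, 4),              # cached_conv2
        torch.randn(1, 16, 4, 10),          # cached_embed_left_pad
        torch.zeros(1, dtype=torch.int32),  # processed_lens
    ]

state_list = [make_dummy_state(), make_dummy_state()]
recovered = unstack_states(stack_states(state_list))
assert all(
    torch.equal(a, b)
    for orig, rec in zip(state_list, recovered)
    for a, b in zip(orig, rec)
)
```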
@ -224,27 +438,32 @@ def decode_one_chunk(
Returns:
Return a List containing which DecodeStreams are finished.
"""
-device = model.device
# pdb.set_trace()
# print(model)
# print(model.device)
# device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
-processed_lens = []
processed_lens = []  # Used in fast-beam-search
for stream in decode_streams:
-feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
-feature_lens = torch.tensor(feature_lens, device=device)
feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
-# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
-# factor in encoders is 8.
-# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
-tail_length = 23
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsample features (T - 7) // 2
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
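To make the padding rule above concrete, here is the arithmetic for one streaming configuration (a chunk size of 32 is an assumed example value, matching the README command, not a hard-coded default):

```python
chunk_size = 32                            # frames per chunk after subsampling (example value)
tail_length = chunk_size * 2 + 7 + 2 * 3   # = 77 input frames
after_embed = (tail_length - 7) // 2       # = 35 = chunk_size + 3 frames of ConvNeXt right padding
```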
@ -256,12 +475,14 @@ def decode_one_chunk(
)
states = stack_states(states)
-processed_lens = torch.tensor(processed_lens, device=device)
-encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
-x=features,
-x_lens=feature_lens,
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
@ -269,6 +490,7 @@ def decode_one_chunk(
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=model.device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
@ -295,8 +517,9 @@ def decode_one_chunk(
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
-if decode_streams[i].done:
-finished_streams.append(i)
# if decode_streams[i].done:
# finished_streams.append(i)
finished_streams.append(i)
return finished_streams
@ -338,14 +561,14 @@ def decode_dataset(
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
-log_interval = 50
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
-initial_states = model.encoder.get_init_state(device=device)
initial_states = get_init_states(model=model, batch_size=1, device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -361,15 +584,19 @@ def decode_dataset(
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
-assert audio.max() <= 1, "Should be normalized to [-1, 1])"
# - this is to avoid sending [-32k,+32k] signal in...
# - some lhotse AudioTransform classes can make the signal
# be out of range [-1, 1], hence the tolerance 10
assert (
np.abs(audio).max() <= 10
), "Should be normalized to [-1, 1], 10 for tolerance..."
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
-decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.set_features(feature, tail_pad_len=30)
-decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
@ -380,8 +607,8 @@ def decode_dataset(
decode_results.append(
(
decode_streams[i].id,
-sp.text2word(decode_streams[i].ground_truth),
-sp.text2word(sp.decode(decode_streams[i].decoding_result())),
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
@ -391,18 +618,37 @@ def decode_dataset(
# decode final chunks of last sequences
while len(decode_streams):
# print("INSIDE LEN DECODE STREAMS")
# pdb.set_trace()
# print(model.device)
# test_device = model.device
# print("done")
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
# print('INSIDE FOR LOOP ')
# print(finished_streams)
if not finished_streams:
print("No finished streams, breaking the loop")
break
for i in sorted(finished_streams, reverse=True):
-decode_results.append(
-(
-decode_streams[i].id,
-sp.text2word(decode_streams[i].ground_truth),
-sp.text2word(sp.decode(decode_streams[i].decoding_result())),
-)
-)
-del decode_streams[i]
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
except IndexError as e:
print(f"IndexError: {e}")
print(f"decode_streams length: {len(decode_streams)}")
print(f"finished_streams: {finished_streams}")
print(f"i: {i}")
continue
if params.decoding_method == "greedy_search":
key = "greedy_search"
@ -416,9 +662,54 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}" key = f"num_active_paths_{params.num_active_paths}"
else: else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}") raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
@ -430,16 +721,20 @@ def main():
params = get_params()
params.update(vars(args))
-if not params.res_dir:
-params.res_dir = params.exp_dir / "streaming" / params.decoding_method
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-# for streaming
-params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
@ -455,13 +750,13 @@ def main():
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", params.gpu) device = torch.device("cuda", 0)
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type) sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/prepare_lang_char.py # <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
@ -469,7 +764,7 @@ def main():
logging.info(params)
logging.info("About to create model")
-model = get_transducer_model(params)
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
@ -553,42 +848,51 @@ def main():
model.device = device
decoding_graph = None
-if params.decoding_graph:
-decoding_graph = k2.Fsa.from_dict(
-torch.load(params.decoding_graph, map_location=device)
-)
-elif params.decoding_method == "fast_beam_search":
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-for subdir in ["valid"]:
valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()
test_sets = ["valid", "test"]
test_cuts = [valid_cuts, test_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
-cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
-tot_err = save_results(
-params=params, test_set_name=subdir, results_dict=results_dict
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
-with (
-params.res_dir
-/ (
-f"{subdir}-{params.decode_chunk_len}"
-f"_{params.avg}_{params.epoch}.cer"
-)
-).open("w") as fout:
-if len(tot_err) == 1:
-fout.write(f"{tot_err[0][1]}")
-else:
-fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
# valid_cuts = reazonspeech_corpus.valid_cuts()
# for valid_cut in valid_cuts:
# results_dict = decode_dataset(
# cuts=valid_cut,
# params=params,
# model=model,
# sp=sp,
# decoding_graph=decoding_graph,
# )
# save_results(
# params=params,
# test_set_name="valid",
# results_dict=results_dict,
# )
logging.info("Done!")

View File

@ -644,7 +644,8 @@ def write_error_stats(
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
-ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
# ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
ali = kaldialign.align(ref, hyp, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
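For reference, `kaldialign.align` returns `(ref_word, hyp_word)` pairs in which the chosen epsilon symbol fills the missing side of insertions and deletions; a small usage sketch with toy word lists (not data from the recipe):

```python
import kaldialign

ERR = "*"  # epsilon symbol used to mark insertions/deletions
ref = ["hello", "world"]
hyp = ["hello", "there", "world"]

ali = kaldialign.align(ref, hyp, ERR)
# [('hello', 'hello'), ('*', 'there'), ('world', 'world')]
for ref_word, hyp_word in ali:
    if ref_word == ERR:
        print(f"insertion: {hyp_word}")
    elif hyp_word == ERR:
        print(f"deletion: {ref_word}")
    elif ref_word != hyp_word:
        print(f"substitution: {ref_word} -> {hyp_word}")
```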