mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 9b13eac94686fe4a31eb190b8deaf3ff967feacd into dd5d7e358bf6a49eb2a31dff2dd2b943f792c1fc
This commit is contained in:
commit
b1064e93ea
@ -122,6 +122,7 @@ from beam_search import (
|
||||
modified_beam_search_LODR,
|
||||
)
|
||||
from lhotse import set_caching_enabled
|
||||
from tokenizer import Tokenizer
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall import ContextGraph, LmScorer, NgramLm
|
||||
@ -377,6 +378,17 @@ def get_parser():
|
||||
default=False,
|
||||
help="""Skip scoring, but still save the ASR output (for eval sets).""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
@ -601,6 +613,7 @@ def decode_one_batch(
|
||||
|
||||
# prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
|
||||
prefix = f"{params.decoding_method}"
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
|
@ -2459,4 +2459,4 @@ if __name__ == "__main__":
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
_test_zipformer_main(False)
|
||||
_test_zipformer_main(True)
|
||||
_test_zipformer_main(True)
|
@ -47,3 +47,41 @@ The decoding command is:
|
||||
--blank-penalty 0
|
||||
```
|
||||
|
||||
#### Streaming
|
||||
|
||||
We have not completed evaluation of our models yet and will add evaluation results here once it's completed.
|
||||
|
||||
The training command is:
|
||||
```shell
|
||||
./zipformer/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 40 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp-large \
|
||||
--causal 1 \
|
||||
--num-encoder-layers 2,2,4,5,4,2 \
|
||||
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||
--encoder-dim 192,256,512,768,512,256 \
|
||||
--encoder-unmasked-dim 192,192,256,320,256,192 \
|
||||
--lang data/lang_char \
|
||||
--max-duration 1600
|
||||
```
|
||||
|
||||
The decoding command is:
|
||||
|
||||
```shell
|
||||
./zipformer/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--causal 1 \
|
||||
--chunk-size 32 \
|
||||
--left-context-frames 256 \
|
||||
--exp-dir ./zipformer/exp-large \
|
||||
--lang data/lang_char \
|
||||
--num-encoder-layers 2,2,4,5,4,2 \
|
||||
--feedforward-dim 512,768,1536,2048,1536,768 \
|
||||
--encoder-dim 192,256,512,768,512,256 \
|
||||
--encoder-unmasked-dim 192,192,256,320,256,192
|
||||
```
|
||||
|
||||
|
@ -12,7 +12,6 @@ class Tokenizer:
|
||||
@staticmethod
|
||||
def add_arguments(parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(title="Lang related options")
|
||||
|
||||
group.add_argument("--lang", type=Path, help="Path to lang directory.")
|
||||
|
||||
group.add_argument(
|
||||
|
@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
|
||||
# Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -17,28 +18,24 @@
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless7_streaming/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--decode-chunk-len 32 \
|
||||
--exp-dir ./pruned_transducer_stateless7_streaming/exp \
|
||||
--decoding_method greedy_search \
|
||||
--lang data/lang_char \
|
||||
--num-decode-streams 2000
|
||||
./zipformer/streaming_decode.py--epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import pdb
|
||||
|
||||
# import subprocess as sp
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import ReazonSpeechAsrDataModule
|
||||
from decode import save_results
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
@ -48,9 +45,9 @@ from streaming_beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from tokenizer import Tokenizer
|
||||
from torch import Tensor, nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from zipformer import stack_states, unstack_states
|
||||
from train import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -58,7 +55,14 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
make_pad_mask,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
@ -73,7 +77,7 @@ def get_parser():
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
@ -87,12 +91,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--gpu",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
@ -116,7 +114,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -127,6 +125,13 @@ def get_parser():
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_char",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@ -138,14 +143,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-graph",
|
||||
type=str,
|
||||
default="",
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
@ -157,7 +154,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4.0,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
@ -194,18 +191,235 @@ def get_parser():
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--res-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="The path to save results.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_init_states(
|
||||
model: nn.Module,
|
||||
batch_size: int = 1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
|
||||
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
|
||||
states[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
"""
|
||||
states = model.encoder.get_init_states(batch_size, device)
|
||||
|
||||
embed_states = model.encoder_embed.get_init_states(batch_size, device)
|
||||
states.append(embed_states)
|
||||
|
||||
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
states.append(processed_lens)
|
||||
|
||||
return states
|
||||
|
||||
|
||||
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
|
||||
"""Stack list of zipformer states that correspond to separate utterances
|
||||
into a single emformer state, so that it can be used as an input for
|
||||
zipformer when those utterances are formed into a batch.
|
||||
|
||||
Args:
|
||||
state_list:
|
||||
Each element in state_list corresponding to the internal state
|
||||
of the zipformer model for a single utterance. For element-n,
|
||||
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
|
||||
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
|
||||
cached_val2, cached_conv1, cached_conv2).
|
||||
state_list[n][-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
state_list[n][-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
|
||||
Note:
|
||||
It is the inverse of :func:`unstack_states`.
|
||||
"""
|
||||
batch_size = len(state_list)
|
||||
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
|
||||
tot_num_layers = (len(state_list[0]) - 2) // 6
|
||||
|
||||
batch_states = []
|
||||
for layer in range(tot_num_layers):
|
||||
layer_offset = layer * 6
|
||||
# cached_key: (left_context_len, batch_size, key_dim)
|
||||
cached_key = torch.cat(
|
||||
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||
cached_nonlin_attn = torch.cat(
|
||||
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_val1: (left_context_len, batch_size, value_dim)
|
||||
cached_val1 = torch.cat(
|
||||
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_val2: (left_context_len, batch_size, value_dim)
|
||||
cached_val2 = torch.cat(
|
||||
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
|
||||
)
|
||||
# cached_conv1: (#batch, channels, left_pad)
|
||||
cached_conv1 = torch.cat(
|
||||
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
|
||||
)
|
||||
# cached_conv2: (#batch, channels, left_pad)
|
||||
cached_conv2 = torch.cat(
|
||||
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
|
||||
)
|
||||
batch_states += [
|
||||
cached_key,
|
||||
cached_nonlin_attn,
|
||||
cached_val1,
|
||||
cached_val2,
|
||||
cached_conv1,
|
||||
cached_conv2,
|
||||
]
|
||||
|
||||
cached_embed_left_pad = torch.cat(
|
||||
[state_list[i][-2] for i in range(batch_size)], dim=0
|
||||
)
|
||||
batch_states.append(cached_embed_left_pad)
|
||||
|
||||
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
|
||||
batch_states.append(processed_lens)
|
||||
|
||||
return batch_states
|
||||
|
||||
|
||||
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
|
||||
"""Unstack the zipformer state corresponding to a batch of utterances
|
||||
into a list of states, where the i-th entry is the state from the i-th
|
||||
utterance in the batch.
|
||||
|
||||
Note:
|
||||
It is the inverse of :func:`stack_states`.
|
||||
|
||||
Args:
|
||||
batch_states: A list of cached tensors of all encoder layers. For layer-i,
|
||||
states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
|
||||
cached_conv1, cached_conv2).
|
||||
state_list[-2] is the cached left padding for ConvNeXt module,
|
||||
of shape (batch_size, num_channels, left_pad, num_freqs)
|
||||
states[-1] is processed_lens of shape (batch,), which records the number
|
||||
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
|
||||
|
||||
Returns:
|
||||
state_list: A list of list. Each element in state_list corresponding to the internal state
|
||||
of the zipformer model for a single utterance.
|
||||
"""
|
||||
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
|
||||
tot_num_layers = (len(batch_states) - 2) // 6
|
||||
|
||||
processed_lens = batch_states[-1]
|
||||
batch_size = processed_lens.shape[0]
|
||||
|
||||
state_list = [[] for _ in range(batch_size)]
|
||||
|
||||
for layer in range(tot_num_layers):
|
||||
layer_offset = layer * 6
|
||||
# cached_key: (left_context_len, batch_size, key_dim)
|
||||
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
|
||||
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
|
||||
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_val1: (left_context_len, batch_size, value_dim)
|
||||
cached_val1_list = batch_states[layer_offset + 2].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_val2: (left_context_len, batch_size, value_dim)
|
||||
cached_val2_list = batch_states[layer_offset + 3].chunk(
|
||||
chunks=batch_size, dim=1
|
||||
)
|
||||
# cached_conv1: (#batch, channels, left_pad)
|
||||
cached_conv1_list = batch_states[layer_offset + 4].chunk(
|
||||
chunks=batch_size, dim=0
|
||||
)
|
||||
# cached_conv2: (#batch, channels, left_pad)
|
||||
cached_conv2_list = batch_states[layer_offset + 5].chunk(
|
||||
chunks=batch_size, dim=0
|
||||
)
|
||||
for i in range(batch_size):
|
||||
state_list[i] += [
|
||||
cached_key_list[i],
|
||||
cached_nonlin_attn_list[i],
|
||||
cached_val1_list[i],
|
||||
cached_val2_list[i],
|
||||
cached_conv1_list[i],
|
||||
cached_conv2_list[i],
|
||||
]
|
||||
|
||||
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
|
||||
for i in range(batch_size):
|
||||
state_list[i].append(cached_embed_left_pad_list[i])
|
||||
|
||||
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
|
||||
for i in range(batch_size):
|
||||
state_list[i].append(processed_lens_list[i])
|
||||
|
||||
return state_list
|
||||
|
||||
|
||||
def streaming_forward(
|
||||
features: Tensor,
|
||||
feature_lens: Tensor,
|
||||
model: nn.Module,
|
||||
states: List[Tensor],
|
||||
chunk_size: int,
|
||||
left_context_len: int,
|
||||
) -> Tuple[Tensor, Tensor, List[Tensor]]:
|
||||
"""
|
||||
Returns encoder outputs, output lengths, and updated states.
|
||||
"""
|
||||
cached_embed_left_pad = states[-2]
|
||||
(x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
cached_left_pad=cached_embed_left_pad,
|
||||
)
|
||||
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
|
||||
# processed_mask is used to mask out initial states
|
||||
processed_mask = torch.arange(left_context_len, device=x.device).expand(
|
||||
x.size(0), left_context_len
|
||||
)
|
||||
processed_lens = states[-1] # (batch,)
|
||||
# (batch, left_context_size)
|
||||
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
|
||||
# Update processed lengths
|
||||
new_processed_lens = processed_lens + x_lens
|
||||
|
||||
# (batch, left_context_size + chunk_size)
|
||||
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
|
||||
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
encoder_states = states[:-2]
|
||||
(
|
||||
encoder_out,
|
||||
encoder_out_lens,
|
||||
new_encoder_states,
|
||||
) = model.encoder.streaming_forward(
|
||||
x=x,
|
||||
x_lens=x_lens,
|
||||
states=encoder_states,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
new_states = new_encoder_states + [
|
||||
new_cached_embed_left_pad,
|
||||
new_processed_lens,
|
||||
]
|
||||
return encoder_out, encoder_out_lens, new_states
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
@ -224,27 +438,32 @@ def decode_one_chunk(
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
# pdb.set_trace()
|
||||
# print(model)
|
||||
# print(model.device)
|
||||
# device = model.device
|
||||
chunk_size = int(params.chunk_size)
|
||||
left_context_len = int(params.left_context_frames)
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
processed_lens = []
|
||||
processed_lens = [] # Used in fast-beam-search
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
|
||||
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
feature_lens = torch.tensor(feature_lens, device=model.device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
|
||||
# factor in encoders is 8.
|
||||
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
|
||||
tail_length = 23
|
||||
# Make sure the length after encoder_embed is at least 1.
|
||||
# The encoder_embed subsample features (T - 7) // 2
|
||||
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
|
||||
tail_length = chunk_size * 2 + 7 + 2 * 3
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
@ -256,12 +475,14 @@ def decode_one_chunk(
|
||||
)
|
||||
|
||||
states = stack_states(states)
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
encoder_out, encoder_out_lens, new_states = streaming_forward(
|
||||
features=features,
|
||||
feature_lens=feature_lens,
|
||||
model=model,
|
||||
states=states,
|
||||
chunk_size=chunk_size,
|
||||
left_context_len=left_context_len,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
@ -269,6 +490,7 @@ def decode_one_chunk(
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = torch.tensor(processed_lens, device=model.device)
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
@ -295,8 +517,9 @@ def decode_one_chunk(
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = states[i]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
# if decode_streams[i].done:
|
||||
# finished_streams.append(i)
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
@ -338,14 +561,14 @@ def decode_dataset(
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
log_interval = 100
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
initial_states = model.encoder.get_init_state(device=device)
|
||||
initial_states = get_init_states(model=model, batch_size=1, device=device)
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
cut_id=cut.id,
|
||||
@ -361,15 +584,19 @@ def decode_dataset(
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
# - this is to avoid sending [-32k,+32k] signal in...
|
||||
# - some lhotse AudioTransform classes can make the signal
|
||||
# be out of range [-1, 1], hence the tolerance 10
|
||||
assert (
|
||||
np.abs(audio).max() <= 10
|
||||
), "Should be normalized to [-1, 1], 10 for tolerance..."
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
|
||||
decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
|
||||
|
||||
decode_stream.set_features(feature, tail_pad_len=30)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
@ -380,8 +607,8 @@ def decode_dataset(
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
sp.text2word(decode_streams[i].ground_truth),
|
||||
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
@ -391,18 +618,37 @@ def decode_dataset(
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
# print("INSIDE LEN DECODE STREAMS")
|
||||
# pdb.set_trace()
|
||||
# print(model.device)
|
||||
# test_device = model.device
|
||||
# print("done")
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
# print('INSIDE FOR LOOP ')
|
||||
# print(finished_streams)
|
||||
|
||||
if not finished_streams:
|
||||
print("No finished streams, breaking the loop")
|
||||
break
|
||||
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
sp.text2word(decode_streams[i].ground_truth),
|
||||
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
|
||||
try:
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].id,
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
del decode_streams[i]
|
||||
except IndexError as e:
|
||||
print(f"IndexError: {e}")
|
||||
print(f"decode_streams length: {len(decode_streams)}")
|
||||
print(f"finished_streams: {finished_streams}")
|
||||
print(f"i: {i}")
|
||||
continue
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
@ -416,9 +662,54 @@ def decode_dataset(
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
|
||||
torch.cuda.synchronize()
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
@ -430,16 +721,20 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if not params.res_dir:
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
|
||||
assert params.causal, params.causal
|
||||
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
|
||||
assert (
|
||||
"," not in params.left_context_frames
|
||||
), "left_context_frames should be one value in decoding."
|
||||
params.suffix += f"-chunk-{params.chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -455,13 +750,13 @@ def main():
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", params.gpu)
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = Tokenizer.load(params.lang, params.lang_type)
|
||||
|
||||
# <blk> and <unk> is defined in local/prepare_lang_char.py
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
@ -469,7 +764,7 @@ def main():
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
model = get_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
@ -553,42 +848,51 @@ def main():
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_graph:
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(params.decoding_graph, map_location=device)
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
|
||||
|
||||
for subdir in ["valid"]:
|
||||
valid_cuts = reazonspeech_corpus.valid_cuts()
|
||||
test_cuts = reazonspeech_corpus.test_cuts()
|
||||
|
||||
test_sets = ["valid", "test"]
|
||||
test_cuts = [valid_cuts, test_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
|
||||
cuts=test_cut,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
tot_err = save_results(
|
||||
params=params, test_set_name=subdir, results_dict=results_dict
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
with (
|
||||
params.res_dir
|
||||
/ (
|
||||
f"{subdir}-{params.decode_chunk_len}"
|
||||
f"_{params.avg}_{params.epoch}.cer"
|
||||
)
|
||||
).open("w") as fout:
|
||||
if len(tot_err) == 1:
|
||||
fout.write(f"{tot_err[0][1]}")
|
||||
else:
|
||||
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
|
||||
# valid_cuts = reazonspeech_corpus.valid_cuts()
|
||||
|
||||
# for valid_cut in valid_cuts:
|
||||
# results_dict = decode_dataset(
|
||||
# cuts=valid_cut,
|
||||
# params=params,
|
||||
# model=model,
|
||||
# sp=sp,
|
||||
# decoding_graph=decoding_graph,
|
||||
# )
|
||||
# save_results(
|
||||
# params=params,
|
||||
# test_set_name="valid",
|
||||
# results_dict=results_dict,
|
||||
# )
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
@ -644,7 +644,8 @@ def write_error_stats(
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
# ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user