Repository: https://github.com/k2-fsa/icefall.git

Merge 9b13eac94686fe4a31eb190b8deaf3ff967feacd into dd5d7e358bf6a49eb2a31dff2dd2b943f792c1fc

Commit: b1064e93ea
@@ -122,6 +122,7 @@ from beam_search import (
     modified_beam_search_LODR,
 )
 from lhotse import set_caching_enabled
+from tokenizer import Tokenizer
 from train import add_model_arguments, get_model, get_params
 
 from icefall import ContextGraph, LmScorer, NgramLm
@@ -377,6 +378,17 @@ def get_parser():
         default=False,
         help="""Skip scoring, but still save the ASR output (for eval sets).""",
     )
+    parser.add_argument(
+        "--blank-penalty",
+        type=float,
+        default=0.0,
+        help="""
+        The penalty applied on blank symbol during decoding.
+        Note: It is a positive value that would be applied to logits like
+        this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
+        [batch_size, vocab] and blank id is 0).
+        """,
+    )
 
     add_model_arguments(parser)
 
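
Editor's note: the `--blank-penalty` help text above describes the adjustment `logits[:, 0] -= blank_penalty`. As a rough, hypothetical illustration (not part of the diff; toy shapes and values chosen for the example), the adjustment would be applied to the logits before picking tokens, e.g.:

```python
import torch

# Toy shapes: 4 utterances, 500-symbol vocabulary; blank id assumed to be 0.
logits = torch.randn(4, 500)
blank_penalty = 0.5  # the value that would be passed via --blank-penalty

if blank_penalty != 0.0:
    # Subtracting a positive penalty from the blank logit makes the model
    # less eager to emit blank, so it tends to keep more symbols.
    logits[:, 0] -= blank_penalty

tokens = logits.argmax(dim=-1)
```
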
@@ -601,6 +613,7 @@ def decode_one_batch(
 
     # prefix = ( "greedy_search" | "fast_beam_search_nbest" | "modified_beam_search" )
     prefix = f"{params.decoding_method}"
+    key = f"blank_penalty_{params.blank_penalty}"
     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
     elif "fast_beam_search" in params.decoding_method:
@@ -47,3 +47,41 @@ The decoding command is:
   --blank-penalty 0
 ```
 
+#### Streaming
+
+We have not completed evaluation of our models yet and will add evaluation results here once it's completed.
+
+The training command is:
+```shell
+./zipformer/train.py \
+  --world-size 8 \
+  --num-epochs 40 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --exp-dir zipformer/exp-large \
+  --causal 1 \
+  --num-encoder-layers 2,2,4,5,4,2 \
+  --feedforward-dim 512,768,1536,2048,1536,768 \
+  --encoder-dim 192,256,512,768,512,256 \
+  --encoder-unmasked-dim 192,192,256,320,256,192 \
+  --lang data/lang_char \
+  --max-duration 1600
+```
+
+The decoding command is:
+
+```shell
+./zipformer/streaming_decode.py \
+  --epoch 28 \
+  --avg 15 \
+  --causal 1 \
+  --chunk-size 32 \
+  --left-context-frames 256 \
+  --exp-dir ./zipformer/exp-large \
+  --lang data/lang_char \
+  --num-encoder-layers 2,2,4,5,4,2 \
+  --feedforward-dim 512,768,1536,2048,1536,768 \
+  --encoder-dim 192,256,512,768,512,256 \
+  --encoder-unmasked-dim 192,192,256,320,256,192
+```
+
@@ -12,7 +12,6 @@ class Tokenizer:
     @staticmethod
     def add_arguments(parser: argparse.ArgumentParser):
         group = parser.add_argument_group(title="Lang related options")
-
         group.add_argument("--lang", type=Path, help="Path to lang directory.")
 
         group.add_argument(

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
-#
+# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
+#                                                  Fangjun Kuang,
+#                                                  Zengwei Yao)
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,28 +18,24 @@
 
 """
 Usage:
-./pruned_transducer_stateless7_streaming/streaming_decode.py \
-  --epoch 28 \
-  --avg 15 \
-  --decode-chunk-len 32 \
-  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
-  --decoding_method greedy_search \
-  --lang data/lang_char \
-  --num-decode-streams 2000
+./zipformer/streaming_decode.py --epoch 28 --avg 15 --causal 1 --chunk-size 32 --left-context-frames 256 --exp-dir ./zipformer/exp-large --lang data/lang_char --num-encoder-layers 2,2,4,5,4,2 --feedforward-dim 512,768,1536,2048,1536,768 --encoder-dim 192,256,512,768,512,256 --encoder-unmasked-dim 192,192,256,320,256,192
 """
 
 import argparse
 import logging
 import math
+import os
+import pdb
+
+# import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
 import numpy as np
 import torch
-import torch.nn as nn
 from asr_datamodule import ReazonSpeechAsrDataModule
-from decode import save_results
 from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
@@ -48,9 +45,9 @@ from streaming_beam_search import (
     modified_beam_search,
 )
 from tokenizer import Tokenizer
+from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
-from train import add_model_arguments, get_params, get_transducer_model
-from zipformer import stack_states, unstack_states
+from train import add_model_arguments, get_model, get_params
 
 from icefall.checkpoint import (
     average_checkpoints,

@@ -58,7 +55,14 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import AttributeDict, setup_logger, str2bool
+from icefall.utils import (
+    AttributeDict,
+    make_pad_mask,
+    setup_logger,
+    store_transcripts,
+    str2bool,
+    write_error_stats,
+)
 
 LOG_EPS = math.log(1e-10)
 

@@ -73,7 +77,7 @@ def get_parser():
         type=int,
         default=28,
         help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 0.
+        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
     )
 

@@ -87,12 +91,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--gpu",
-        type=int,
-        default=0,
-    )
-
     parser.add_argument(
         "--avg",
         type=int,

@@ -116,7 +114,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="zipformer/exp",
         help="The experiment dir",
     )
 

@@ -127,6 +125,13 @@ def get_parser():
         help="Path to the BPE model",
     )
 
+    parser.add_argument(
+        "--lang-dir",
+        type=Path,
+        default="data/lang_char",
+        help="The lang dir containing word table and LG graph",
+    )
+
     parser.add_argument(
         "--decoding-method",
         type=str,

@@ -138,14 +143,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--decoding-graph",
-        type=str,
-        default="",
-        help="""Used only when --decoding-method is
-        fast_beam_search""",
-    )
-
     parser.add_argument(
         "--num_active_paths",
         type=int,

@@ -157,7 +154,7 @@ def get_parser():
     parser.add_argument(
         "--beam",
         type=float,
-        default=4.0,
+        default=4,
         help="""A floating point value to calculate the cutoff score during beam
         search (i.e., `cutoff = max-score - beam`), which is the same as the
         `beam` in Kaldi.
@@ -194,18 +191,235 @@ def get_parser():
         help="The number of streams that can be decoded parallel.",
     )
 
-    parser.add_argument(
-        "--res-dir",
-        type=Path,
-        default=None,
-        help="The path to save results.",
-    )
-
     add_model_arguments(parser)
 
     return parser
 
 
+def get_init_states(
+    model: nn.Module,
+    batch_size: int = 1,
+    device: torch.device = torch.device("cpu"),
+) -> List[torch.Tensor]:
+    """
+    Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
+    is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
+    states[-2] is the cached left padding for ConvNeXt module,
+    of shape (batch_size, num_channels, left_pad, num_freqs)
+    states[-1] is processed_lens of shape (batch,), which records the number
+    of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+    """
+    states = model.encoder.get_init_states(batch_size, device)
+
+    embed_states = model.encoder_embed.get_init_states(batch_size, device)
+    states.append(embed_states)
+
+    processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
+    states.append(processed_lens)
+
+    return states
+
+
+def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
+    """Stack list of zipformer states that correspond to separate utterances
+    into a single emformer state, so that it can be used as an input for
+    zipformer when those utterances are formed into a batch.
+
+    Args:
+      state_list:
+        Each element in state_list corresponding to the internal state
+        of the zipformer model for a single utterance. For element-n,
+        state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
+        state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
+        cached_val2, cached_conv1, cached_conv2).
+        state_list[n][-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        state_list[n][-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+    Note:
+      It is the inverse of :func:`unstack_states`.
+    """
+    batch_size = len(state_list)
+    assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
+    tot_num_layers = (len(state_list[0]) - 2) // 6
+
+    batch_states = []
+    for layer in range(tot_num_layers):
+        layer_offset = layer * 6
+        # cached_key: (left_context_len, batch_size, key_dim)
+        cached_key = torch.cat(
+            [state_list[i][layer_offset] for i in range(batch_size)], dim=1
+        )
+        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+        cached_nonlin_attn = torch.cat(
+            [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
+        )
+        # cached_val1: (left_context_len, batch_size, value_dim)
+        cached_val1 = torch.cat(
+            [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
+        )
+        # cached_val2: (left_context_len, batch_size, value_dim)
+        cached_val2 = torch.cat(
+            [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
+        )
+        # cached_conv1: (#batch, channels, left_pad)
+        cached_conv1 = torch.cat(
+            [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
+        )
+        # cached_conv2: (#batch, channels, left_pad)
+        cached_conv2 = torch.cat(
+            [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
+        )
+        batch_states += [
+            cached_key,
+            cached_nonlin_attn,
+            cached_val1,
+            cached_val2,
+            cached_conv1,
+            cached_conv2,
+        ]
+
+    cached_embed_left_pad = torch.cat(
+        [state_list[i][-2] for i in range(batch_size)], dim=0
+    )
+    batch_states.append(cached_embed_left_pad)
+
+    processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
+    batch_states.append(processed_lens)
+
+    return batch_states
+
+
+def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
+    """Unstack the zipformer state corresponding to a batch of utterances
+    into a list of states, where the i-th entry is the state from the i-th
+    utterance in the batch.
+
+    Note:
+      It is the inverse of :func:`stack_states`.
+
+    Args:
+      batch_states: A list of cached tensors of all encoder layers. For layer-i,
+        states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
+        cached_conv1, cached_conv2).
+        state_list[-2] is the cached left padding for ConvNeXt module,
+        of shape (batch_size, num_channels, left_pad, num_freqs)
+        states[-1] is processed_lens of shape (batch,), which records the number
+        of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
+
+    Returns:
+      state_list: A list of list. Each element in state_list corresponding to the internal state
+      of the zipformer model for a single utterance.
+    """
+    assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
+    tot_num_layers = (len(batch_states) - 2) // 6
+
+    processed_lens = batch_states[-1]
+    batch_size = processed_lens.shape[0]
+
+    state_list = [[] for _ in range(batch_size)]
+
+    for layer in range(tot_num_layers):
+        layer_offset = layer * 6
+        # cached_key: (left_context_len, batch_size, key_dim)
+        cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
+        # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
+        cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_val1: (left_context_len, batch_size, value_dim)
+        cached_val1_list = batch_states[layer_offset + 2].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_val2: (left_context_len, batch_size, value_dim)
+        cached_val2_list = batch_states[layer_offset + 3].chunk(
+            chunks=batch_size, dim=1
+        )
+        # cached_conv1: (#batch, channels, left_pad)
+        cached_conv1_list = batch_states[layer_offset + 4].chunk(
+            chunks=batch_size, dim=0
+        )
+        # cached_conv2: (#batch, channels, left_pad)
+        cached_conv2_list = batch_states[layer_offset + 5].chunk(
+            chunks=batch_size, dim=0
+        )
+        for i in range(batch_size):
+            state_list[i] += [
+                cached_key_list[i],
+                cached_nonlin_attn_list[i],
+                cached_val1_list[i],
+                cached_val2_list[i],
+                cached_conv1_list[i],
+                cached_conv2_list[i],
+            ]
+
+    cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
+    for i in range(batch_size):
+        state_list[i].append(cached_embed_left_pad_list[i])
+
+    processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
+    for i in range(batch_size):
+        state_list[i].append(processed_lens_list[i])
+
+    return state_list
+
+
+def streaming_forward(
+    features: Tensor,
+    feature_lens: Tensor,
+    model: nn.Module,
+    states: List[Tensor],
+    chunk_size: int,
+    left_context_len: int,
+) -> Tuple[Tensor, Tensor, List[Tensor]]:
+    """
+    Returns encoder outputs, output lengths, and updated states.
+    """
+    cached_embed_left_pad = states[-2]
+    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
+        x=features,
+        x_lens=feature_lens,
+        cached_left_pad=cached_embed_left_pad,
+    )
+    assert x.size(1) == chunk_size, (x.size(1), chunk_size)
+
+    src_key_padding_mask = make_pad_mask(x_lens)
+
+    # processed_mask is used to mask out initial states
+    processed_mask = torch.arange(left_context_len, device=x.device).expand(
+        x.size(0), left_context_len
+    )
+    processed_lens = states[-1]  # (batch,)
+    # (batch, left_context_size)
+    processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
+    # Update processed lengths
+    new_processed_lens = processed_lens + x_lens
+
+    # (batch, left_context_size + chunk_size)
+    src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
+
+    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
+    encoder_states = states[:-2]
+    (
+        encoder_out,
+        encoder_out_lens,
+        new_encoder_states,
+    ) = model.encoder.streaming_forward(
+        x=x,
+        x_lens=x_lens,
+        states=encoder_states,
+        src_key_padding_mask=src_key_padding_mask,
+    )
+    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+
+    new_states = new_encoder_states + [
+        new_cached_embed_left_pad,
+        new_processed_lens,
+    ]
+    return encoder_out, encoder_out_lens, new_states
+
+
 def decode_one_chunk(
     params: AttributeDict,
     model: nn.Module,
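
Editor's note: the docstrings above define the per-stream state layout (six cached tensors per encoder layer, plus the encoder_embed left padding and processed_lens). A small usage sketch, under the assumption that `model` is a causal zipformer built by `get_model(params)` as in this file, could look like this:

```python
import torch

# Sketch only: batch two single-utterance states, then split them again.
device = torch.device("cpu")
states_a = get_init_states(model=model, batch_size=1, device=device)
states_b = get_init_states(model=model, batch_size=1, device=device)

batched = stack_states([states_a, states_b])
assert (len(batched) - 2) % 6 == 0   # six cached tensors per encoder layer
assert batched[-1].shape[0] == 2     # processed_lens now has one entry per stream

restored = unstack_states(batched)
assert len(restored) == 2            # one state list per utterance again
```
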
@@ -224,27 +438,32 @@ def decode_one_chunk(
     Returns:
       Return a List containing which DecodeStreams are finished.
     """
-    device = model.device
+    # pdb.set_trace()
+    # print(model)
+    # print(model.device)
+    # device = model.device
+    chunk_size = int(params.chunk_size)
+    left_context_len = int(params.left_context_frames)
 
     features = []
     feature_lens = []
     states = []
-    processed_lens = []
+    processed_lens = []  # Used in fast-beam-search
 
     for stream in decode_streams:
-        feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
+        feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
 
-    feature_lens = torch.tensor(feature_lens, device=device)
+    feature_lens = torch.tensor(feature_lens, device=model.device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
 
-    # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
-    # factor in encoders is 8.
-    # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
-    tail_length = 23
+    # Make sure the length after encoder_embed is at least 1.
+    # The encoder_embed subsample features (T - 7) // 2
+    # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
+    tail_length = chunk_size * 2 + 7 + 2 * 3
     if features.size(1) < tail_length:
         pad_length = tail_length - features.size(1)
         feature_lens += pad_length
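
Editor's note: the new tail_length formula follows from the two comments above: encoder_embed subsamples roughly as (T - 7) // 2, and its ConvNeXt module wants (7 - 1) // 2 = 3 extra frames of right padding after subsampling. A worked check with the chunk size used in the README command (a sketch, not part of the diff):

```python
chunk_size = 32                    # --chunk-size at the 50 Hz encoder frame rate
frames_per_chunk = chunk_size * 2  # input fbank frames are at 100 Hz
tail_length = frames_per_chunk + 7 + 2 * 3
assert tail_length == 77

# After the (T - 7) // 2 subsampling in encoder_embed there are still enough
# frames for one full chunk plus the 3 frames of ConvNeXt right padding:
assert (tail_length - 7) // 2 == chunk_size + 3
```
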
@@ -256,12 +475,14 @@ def decode_one_chunk(
         )
 
     states = stack_states(states)
-    processed_lens = torch.tensor(processed_lens, device=device)
 
-    encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
-        x=features,
-        x_lens=feature_lens,
+    encoder_out, encoder_out_lens, new_states = streaming_forward(
+        features=features,
+        feature_lens=feature_lens,
+        model=model,
         states=states,
+        chunk_size=chunk_size,
+        left_context_len=left_context_len,
     )
 
     encoder_out = model.joiner.encoder_proj(encoder_out)
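
Editor's note: the `streaming_forward()` helper called above masks out the part of the left context a stream has not produced yet, using the `processed_lens` entry of the state. A tiny worked example of that masking logic (a sketch with made-up numbers, not part of the diff):

```python
import torch

left_context_len = 4
processed_lens = torch.tensor([0, 3])  # frames already processed by two streams

positions = torch.arange(left_context_len).expand(2, left_context_len)
processed_mask = (processed_lens.unsqueeze(1) <= positions).flip(1)
print(processed_mask)
# tensor([[ True,  True,  True,  True],
#         [ True, False, False, False]])
# Stream 0 has processed nothing, so its whole left context is masked out;
# stream 1 has processed 3 frames, so only one left-context slot is still masked.
```
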
@@ -269,6 +490,7 @@ def decode_one_chunk(
     if params.decoding_method == "greedy_search":
         greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
     elif params.decoding_method == "fast_beam_search":
+        processed_lens = torch.tensor(processed_lens, device=model.device)
         processed_lens = processed_lens + encoder_out_lens
         fast_beam_search_one_best(
             model=model,

@@ -295,8 +517,9 @@ def decode_one_chunk(
     for i in range(len(decode_streams)):
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
-        if decode_streams[i].done:
-            finished_streams.append(i)
+        # if decode_streams[i].done:
+        #     finished_streams.append(i)
+        finished_streams.append(i)
 
     return finished_streams
 

@@ -338,14 +561,14 @@ def decode_dataset(
     opts.frame_opts.samp_freq = 16000
     opts.mel_opts.num_bins = 80
 
-    log_interval = 50
+    log_interval = 100
 
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
-        initial_states = model.encoder.get_init_state(device=device)
+        initial_states = get_init_states(model=model, batch_size=1, device=device)
         decode_stream = DecodeStream(
             params=params,
             cut_id=cut.id,

@@ -361,15 +584,19 @@ def decode_dataset(
         assert audio.dtype == np.float32, audio.dtype
 
         # The trained model is using normalized samples
-        assert audio.max() <= 1, "Should be normalized to [-1, 1])"
+        # - this is to avoid sending [-32k,+32k] signal in...
+        # - some lhotse AudioTransform classes can make the signal
+        #   be out of range [-1, 1], hence the tolerance 10
+        assert (
+            np.abs(audio).max() <= 10
+        ), "Should be normalized to [-1, 1], 10 for tolerance..."
 
         samples = torch.from_numpy(audio).squeeze(0)
 
         fbank = Fbank(opts)
         feature = fbank(samples.to(device))
-        decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
-        decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode]
+        decode_stream.set_features(feature, tail_pad_len=30)
+        decode_stream.ground_truth = cut.supervisions[0].text
 
         decode_streams.append(decode_stream)
 
         while len(decode_streams) >= params.num_decode_streams:

@@ -380,8 +607,8 @@ def decode_dataset(
                 decode_results.append(
                     (
                         decode_streams[i].id,
-                        sp.text2word(decode_streams[i].ground_truth),
-                        sp.text2word(sp.decode(decode_streams[i].decoding_result())),
+                        decode_streams[i].ground_truth.split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
                     )
                 )
                 del decode_streams[i]

@@ -391,18 +618,37 @@ def decode_dataset(
 
     # decode final chunks of last sequences
     while len(decode_streams):
+        # print("INSIDE LEN DECODE STREAMS")
+        # pdb.set_trace()
+        # print(model.device)
+        # test_device = model.device
+        # print("done")
         finished_streams = decode_one_chunk(
             params=params, model=model, decode_streams=decode_streams
         )
+        # print('INSIDE FOR LOOP ')
+        # print(finished_streams)
+
+        if not finished_streams:
+            print("No finished streams, breaking the loop")
+            break
+
         for i in sorted(finished_streams, reverse=True):
-            decode_results.append(
-                (
-                    decode_streams[i].id,
-                    sp.text2word(decode_streams[i].ground_truth),
-                    sp.text2word(sp.decode(decode_streams[i].decoding_result())),
-                )
-            )
-            del decode_streams[i]
+            try:
+                decode_results.append(
+                    (
+                        decode_streams[i].id,
+                        decode_streams[i].ground_truth.split(),
+                        sp.decode(decode_streams[i].decoding_result()).split(),
+                    )
+                )
+                del decode_streams[i]
+            except IndexError as e:
+                print(f"IndexError: {e}")
+                print(f"decode_streams length: {len(decode_streams)}")
+                print(f"finished_streams: {finished_streams}")
+                print(f"i: {i}")
+                continue
 
     if params.decoding_method == "greedy_search":
         key = "greedy_search"
@@ -416,9 +662,54 @@ def decode_dataset(
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+    torch.cuda.synchronize()
     return {key: decode_results}
 
 
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(
+                f, f"{test_set_name}-{key}", results, enable_log=True
+            )
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
 @torch.no_grad()
 def main():
     parser = get_parser()
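
Editor's note: `decode_dataset()` above fills each entry of `results_dict` with `(cut_id, ref_words, hyp_words)` tuples keyed by decoding setting, which is what the new `save_results()` hands to `store_transcripts` and `write_error_stats`. An illustrative, made-up value would be:

```python
# Hypothetical example of the structure save_results() consumes.
results_dict = {
    "greedy_search": [
        ("cut-0001", ["おはよう", "ございます"], ["おはよう", "ございます"]),
        ("cut-0002", ["天気", "が", "いい"], ["天気", "は", "いい"]),
    ]
}
```
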
@@ -430,16 +721,20 @@ def main():
     params = get_params()
     params.update(vars(args))
 
-    if not params.res_dir:
-        params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+    params.res_dir = params.exp_dir / "streaming" / params.decoding_method
 
     if params.iter > 0:
         params.suffix = f"iter-{params.iter}-avg-{params.avg}"
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
-    # for streaming
-    params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
+    assert params.causal, params.causal
+    assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
+    assert (
+        "," not in params.left_context_frames
+    ), "left_context_frames should be one value in decoding."
+    params.suffix += f"-chunk-{params.chunk_size}"
+    params.suffix += f"-left-context-{params.left_context_frames}"
 
     # for fast_beam_search
     if params.decoding_method == "fast_beam_search":

@@ -455,13 +750,13 @@ def main():
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
-        device = torch.device("cuda", params.gpu)
+        device = torch.device("cuda", 0)
 
     logging.info(f"Device: {device}")
 
     sp = Tokenizer.load(params.lang, params.lang_type)
 
-    # <blk> and <unk> is defined in local/prepare_lang_char.py
+    # <blk> and <unk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

@@ -469,7 +764,7 @@ def main():
     logging.info(params)
 
     logging.info("About to create model")
-    model = get_transducer_model(params)
+    model = get_model(params)
 
     if not params.use_averaged_model:
         if params.iter > 0:

@@ -553,42 +848,51 @@ def main():
     model.device = device
 
     decoding_graph = None
-    if params.decoding_graph:
-        decoding_graph = k2.Fsa.from_dict(
-            torch.load(params.decoding_graph, map_location=device)
-        )
-    elif params.decoding_method == "fast_beam_search":
+    if params.decoding_method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
+    # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
 
-    for subdir in ["valid"]:
-        results_dict = decode_dataset(
-            cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
-            params=params,
-            model=model,
-            sp=sp,
-            decoding_graph=decoding_graph,
-        )
-        tot_err = save_results(
-            params=params, test_set_name=subdir, results_dict=results_dict
-        )
-
-        with (
-            params.res_dir
-            / (
-                f"{subdir}-{params.decode_chunk_len}"
-                f"_{params.avg}_{params.epoch}.cer"
-            )
-        ).open("w") as fout:
-            if len(tot_err) == 1:
-                fout.write(f"{tot_err[0][1]}")
-            else:
-                fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
+    valid_cuts = reazonspeech_corpus.valid_cuts()
+    test_cuts = reazonspeech_corpus.test_cuts()
+
+    test_sets = ["valid", "test"]
+    test_cuts = [valid_cuts, test_cuts]
+
+    for test_set, test_cut in zip(test_sets, test_cuts):
+        results_dict = decode_dataset(
+            cuts=test_cut,
+            params=params,
+            model=model,
+            sp=sp,
+            decoding_graph=decoding_graph,
+        )
+        save_results(
+            params=params,
+            test_set_name=test_set,
+            results_dict=results_dict,
+        )
+
+    # valid_cuts = reazonspeech_corpus.valid_cuts()
+
+    # for valid_cut in valid_cuts:
+    #     results_dict = decode_dataset(
+    #         cuts=valid_cut,
+    #         params=params,
+    #         model=model,
+    #         sp=sp,
+    #         decoding_graph=decoding_graph,
+    #     )
+    #     save_results(
+    #         params=params,
+    #         test_set_name="valid",
+    #         results_dict=results_dict,
+    #     )
 
     logging.info("Done!")
 
|
@ -644,7 +644,8 @@ def write_error_stats(
|
|||||||
results[i] = (cut_id, ref, hyp)
|
results[i] = (cut_id, ref, hyp)
|
||||||
|
|
||||||
for cut_id, ref, hyp in results:
|
for cut_id, ref, hyp in results:
|
||||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
# ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||||
|
ali = kaldialign.align(ref, hyp, ERR)
|
||||||
for ref_word, hyp_word in ali:
|
for ref_word, hyp_word in ali:
|
||||||
if ref_word == ERR:
|
if ref_word == ERR:
|
||||||
ins[hyp_word] += 1
|
ins[hyp_word] += 1
|
||||||
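
Editor's note: the loop above iterates over the pairs returned by `kaldialign.align()`. A quick, hedged illustration of that return value (the exact pairing can differ for ties):

```python
import kaldialign

ERR = "*"
ref = ["hello", "world"]
hyp = ["hello", "word", "again"]

ali = kaldialign.align(ref, hyp, ERR)
# Substitutions appear as (ref_word, hyp_word) pairs, while insertions and
# deletions use the ERR placeholder on the missing side, e.g.
# [("hello", "hello"), ("world", "word"), ("*", "again")]
for ref_word, hyp_word in ali:
    print(ref_word, hyp_word)
```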