mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Support streaming decoding
This commit is contained in:
parent
1b14b13047
commit
1d036a6a5a
@ -132,6 +132,7 @@ class Conformer(EncoderInterface):
|
||||
)
|
||||
),
|
||||
)
|
||||
self._init_state: List[torch.Tensor] = [torch.empty(0)]
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -197,6 +198,52 @@ class Conformer(EncoderInterface):
|
||||
|
||||
return x, lengths
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_state(
|
||||
self, left_context: int, device: torch.device
|
||||
) -> List[torch.Tensor]:
|
||||
"""Return the initial cache state of the model.
|
||||
Args:
|
||||
left_context: The left context size (in frames after subsampling).
|
||||
Returns:
|
||||
Return the initial state of the model, it is a list containing two
|
||||
tensors, the first one is the cache for attentions which has a shape
|
||||
of (num_encoder_layers, left_context, encoder_dim), the second one
|
||||
is the cache of conv_modules which has a shape of
|
||||
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
|
||||
NOTE: the returned tensors are on the given device.
|
||||
"""
|
||||
if (
|
||||
len(self._init_state) == 2
|
||||
and self._init_state[0].size(1) == left_context
|
||||
):
|
||||
# Note: It is OK to share the init state as it is
|
||||
# not going to be modified by the model
|
||||
return self._init_state
|
||||
|
||||
init_states: List[torch.Tensor] = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.encoder_layers,
|
||||
left_context,
|
||||
self.d_model,
|
||||
),
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(
|
||||
self.encoder_layers,
|
||||
self.cnn_module_kernel - 1,
|
||||
self.d_model,
|
||||
),
|
||||
device=device,
|
||||
),
|
||||
]
|
||||
|
||||
self._init_state = init_states
|
||||
|
||||
return init_states
|
||||
|
||||
@torch.jit.export
|
||||
def streaming_forward(
|
||||
self,
|
||||
@ -492,6 +539,98 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
return src
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context:
|
||||
How many previous frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `left_context` frames
|
||||
of left context, some have more.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (left_context, src.size(1), src.size(2))
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# We put the attention cache this level (i.e. before linear transformation)
|
||||
# to save memory consumption, when decoding in streaming fashion, the
|
||||
# batch size would be thousands (for 32GB machine), if we cache key & val
|
||||
# separately, it needs extra several GB memory.
|
||||
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
|
||||
# separately) if needed.
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
if right_context > 0:
|
||||
states[0] = key[
|
||||
-(left_context + right_context) : -right_context, ... # noqa
|
||||
]
|
||||
else:
|
||||
states[0] = key[-left_context:, ...]
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
key,
|
||||
val,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
left_context=left_context,
|
||||
)[0]
|
||||
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
conv, conv_cache = self.conv_module(src, states[1], right_context)
|
||||
states[1] = conv_cache
|
||||
|
||||
src = src + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
return src, states
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
@ -575,6 +714,77 @@ class ConformerEncoder(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context:
|
||||
How many previous frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `left_context` frames
|
||||
of left context, some have more.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (
|
||||
self.num_layers,
|
||||
left_context,
|
||||
src.size(1),
|
||||
src.size(2),
|
||||
)
|
||||
assert states[1].size(0) == self.num_layers
|
||||
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = [states[0][layer_index], states[1][layer_index]]
|
||||
output, cache = mod.chunk_forward(
|
||||
output,
|
||||
pos_emb,
|
||||
states=cache,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
left_context=left_context,
|
||||
right_context=right_context,
|
||||
)
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
|
||||
return output, states
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
@ -599,12 +809,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
x_size_1 = x.size(1) + left_context
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
@ -614,9 +825,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
@ -634,22 +845,28 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, left_context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
self.extend_pe(x, left_context)
|
||||
x_size_1 = x.size(1) + left_context
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
- x_size_1
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
@ -721,6 +938,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -734,6 +952,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
@ -779,14 +1000,18 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
left_context=left_context,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
@ -794,14 +1019,17 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
time2 = time1 + left_context
|
||||
assert (
|
||||
n == left_context + 2 * time1 - 1
|
||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
@ -823,6 +1051,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -840,6 +1069,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
@ -1003,7 +1235,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
|
||||
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||
1, 2
|
||||
@ -1023,9 +1256,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
q_with_bias_v, p
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
matrix_bd = self.rel_shift(matrix_bd, left_context)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
@ -1201,13 +1434,19 @@ class ConvolutionModule(nn.Module):
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, cache: Optional[Tensor] = None) -> Tensor:
|
||||
def forward(
|
||||
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
cache: The cache of depthwise_conv, only used in real streaming
|
||||
decoding.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
@ -1237,6 +1476,14 @@ class ConvolutionModule(nn.Module):
|
||||
), "Cache should be None in training time"
|
||||
assert cache.size(0) == self.lorder
|
||||
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
||||
if right_context > 0:
|
||||
cache = x.permute(2, 0, 1)[
|
||||
-(self.lorder + right_context) : ( # noqa
|
||||
-right_context
|
||||
),
|
||||
...,
|
||||
]
|
||||
else:
|
||||
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||
|
||||
x = self.depthwise_conv(x)
|
||||
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless/decode_stream.py
|
@ -137,6 +137,15 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--streaming-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to export a streaming model, if the models in exp-dir
|
||||
are streaming model, this should be True.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -162,6 +171,9 @@ def main():
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.streaming_model:
|
||||
assert params.causal_convolution
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
750
egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
Executable file
750
egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
Executable file
@ -0,0 +1,750 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless5/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 200
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import one_best_decoding
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Support only greedy_search and fast_beam_search now.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--right-context",
|
||||
type=int,
|
||||
default=0,
|
||||
help="right context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
streams: List[DecodeStream],
|
||||
) -> List[List[int]]:
|
||||
|
||||
assert len(streams) == encoder_out.size(0)
|
||||
assert encoder_out.ndim == 3
|
||||
|
||||
blank_id = model.decoder.blank_id
|
||||
context_size = model.decoder.context_size
|
||||
device = model.device
|
||||
T = encoder_out.size(1)
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
# decoder_out is of shape (N, decoder_out_dim)
|
||||
decoder_out = model.decoder(decoder_input, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# logging.info(f"decoder_out shape : {decoder_out.shape}")
|
||||
|
||||
for t in range(T):
|
||||
# current_encoder_out's shape: (batch_size, 1, encoder_out_dim)
|
||||
current_encoder_out = encoder_out[:, t : t + 1, :] # noqa
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2),
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
streams[i].hyp.append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = torch.tensor(
|
||||
[stream.hyp[-context_size:] for stream in streams],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=False,
|
||||
)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
|
||||
hyp_tokens = []
|
||||
for stream in streams:
|
||||
hyp_tokens.append(stream.hyp)
|
||||
return hyp_tokens
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
processed_lens: torch.Tensor,
|
||||
decoding_streams: k2.RnntDecodingStreams,
|
||||
) -> List[List[int]]:
|
||||
|
||||
B, T, C = encoder_out.shape
|
||||
for t in range(T):
|
||||
# shape is a RaggedShape of shape (B, context)
|
||||
# contexts is a Tensor of shape (shape.NumElements(), context_size)
|
||||
shape, contexts = decoding_streams.get_contexts()
|
||||
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
|
||||
contexts = contexts.to(torch.int64)
|
||||
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
|
||||
decoder_out = model.decoder(contexts, need_pad=False)
|
||||
decoder_out = model.joiner.decoder_proj(decoder_out)
|
||||
# current_encoder_out is of shape
|
||||
# (shape.NumElements(), 1, joiner_dim)
|
||||
# fmt: off
|
||||
current_encoder_out = torch.index_select(
|
||||
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
|
||||
)
|
||||
# fmt: on
|
||||
logits = model.joiner(
|
||||
current_encoder_out.unsqueeze(2),
|
||||
decoder_out.unsqueeze(1),
|
||||
project_input=False,
|
||||
)
|
||||
logits = logits.squeeze(1).squeeze(1)
|
||||
log_probs = logits.log_softmax(dim=-1)
|
||||
decoding_streams.advance(log_probs)
|
||||
|
||||
decoding_streams.terminate_and_flush_to_streams()
|
||||
|
||||
lattice = decoding_streams.format_output(processed_lens.tolist())
|
||||
best_path = one_best_decoding(lattice)
|
||||
hyp_tokens = get_texts(best_path)
|
||||
return hyp_tokens
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A List of DecodeStream, each belonging to a utterance.
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
|
||||
rnnt_stream_list = []
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(
|
||||
params.decode_chunk_size * params.subsampling_factor
|
||||
)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
rnnt_stream_list.append(stream.rnnt_decoding_stream)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# if T is less than 7 there will be an error in time reduction layer,
|
||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||
# we plus 2 here because we will cut off one frame on each size of
|
||||
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||
# frames.
|
||||
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
||||
if features.size(1) < tail_length:
|
||||
feature_lens += tail_length - features.size(1)
|
||||
features = torch.cat(
|
||||
[
|
||||
features,
|
||||
torch.tensor(
|
||||
LOG_EPS, dtype=features.dtype, device=device
|
||||
).expand(
|
||||
features.size(0),
|
||||
tail_length - features.size(1),
|
||||
features.size(2),
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
states = [
|
||||
torch.stack([x[0] for x in states], dim=2),
|
||||
torch.stack([x[1] for x in states], dim=2),
|
||||
]
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
left_context=params.left_context,
|
||||
right_context=params.right_context,
|
||||
processed_lens=processed_lens,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp_tokens = greedy_search(model, encoder_out, decode_streams)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
config = k2.RnntDecodingConfig(
|
||||
vocab_size=params.vocab_size,
|
||||
decoder_history_len=params.context_size,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
decoding_streams = k2.RnntDecodingStreams(rnnt_stream_list, config)
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
hyp_tokens = fast_beam_search(
|
||||
model, encoder_out, processed_lens, decoding_streams
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = [states[0][i], states[1][i]]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decode_streams[i].hyp = hyp_tokens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse Cutset containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
initial_states = model.encoder.get_init_state(
|
||||
params.left_context, device=device
|
||||
)
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(hyp).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
hyp = decode_streams[i].hyp
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = hyp[params.context_size :] # noqa
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(hyp).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
key = "greedy_search"
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
params.suffix += f"-right-context-{params.right_context}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
# Decoding in streaming requires causal convolution
|
||||
params.causal_convolution = True
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_cuts = [test_clean_cuts, test_other_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
cuts=test_cut,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user