streaming conformer code

This commit is contained in:
Guo Liyong 2021-11-22 18:52:12 +08:00
parent 898efa7e8c
commit 1e35ea3260
4 changed files with 1018 additions and 34 deletions

View File

@ -25,6 +25,42 @@ from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
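# Illustrative sketch (not part of the original code): `num_left_chunks`
# limits how much left context each chunk may attend to. For example:
#   >>> subsequent_chunk_mask(6, 2, num_left_chunks=1)
#   [[1, 1, 0, 0, 0, 0],
#    [1, 1, 0, 0, 0, 0],
#    [1, 1, 1, 1, 0, 0],
#    [1, 1, 1, 1, 0, 0],
#    [0, 0, 1, 1, 1, 1],
#    [0, 0, 1, 1, 1, 1]]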
class Conformer(Transformer):
"""
Args:
@ -57,6 +93,7 @@ class Conformer(Transformer):
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
causal: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -82,6 +119,7 @@ class Conformer(Transformer):
dropout,
cnn_module_kernel,
normalize_before,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before
@ -93,7 +131,13 @@ class Conformer(Transformer):
self.after_norm = identity
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
self,
x: Tensor,
supervisions: Optional[Supervisions] = None,
dynamic_chunk_training: bool = False,
short_chunk_proportion: float = 0.5,
chunk_size: int = -1,
simulate_streaming: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
@ -107,23 +151,235 @@ class Conformer(Transformer):
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
dynamic_chunk_training:
For training only.
If True, train some batches with a dynamic (limited) right context,
with the chunk size sampled from a distribution;
if False, always train with the full right context.
short_chunk_proportion:
For training only.
Proportion of samples that are trained with a dynamic (short) chunk.
chunk_size:
For eval only.
Right context used when evaluating test utterances.
-1 means the full right context.
simulate_streaming:
For eval only.
If True, the features are fed into the model chunk by chunk.
If False, the whole utterance is fed into the model at once, i.e.,
the model only forwards once.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
if self.encoder.training:
return self.train_run_encoder(
x, supervisions, dynamic_chunk_training, short_chunk_proportion
)
else:
return self.eval_run_encoder(
x, supervisions, chunk_size, simulate_streaming
)
def train_run_encoder(
self,
x: Tensor,
supervisions: Optional[Supervisions] = None,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.5,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x:
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
dynamic_chunk_training:
If True, train some batches with a dynamic (limited) right context,
with the chunk size sampled from a distribution;
if False, always train with the full right context.
short_chunk_threshold:
Proportion of samples that are trained with a dynamic (short) chunk.
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions)
if src_key_padding_mask is not None:
src_key_padding_mask = src_key_padding_mask.to(x.device)
if dynamic_chunk_training:
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
if self.normalize_before:
x = self.after_norm(x)
return x, src_key_padding_mask
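# Sketch of the chunk-size sampling rule above (with the default
# short_chunk_threshold=0.5): chunk_size is drawn uniformly from
# [1, max_len); if it exceeds max_len * 0.5, the full context is used
# (chunk_size = max_len); otherwise chunk_size % 25 + 1 gives a short
# chunk of 1..25 subsampled frames.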
def eval_run_encoder(
self,
feature: Tensor,
supervisions: Optional[Supervisions] = None,
chunk_size: int = -1,
simulate_streaming=False,
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
feature:
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.
Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.encoder_embed(x)
# feature.shape: N T C
num_frames = feature.size(1)
# Since icefall currently supports only subsampling_rate == 4,
# the following parameters are hard-coded here.
# Change them accordingly if other subsampling rates are supported.
embed_left_context = 7
embed_conv_right_context = 3
subsampling_rate = 4
stride = chunk_size * subsampling_rate
decoding_window = embed_conv_right_context + stride
# This is also only compatible with subsampling_rate == 4
length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2
src_key_padding_mask = encoder_padding_mask(
length_after_subsampling, supervisions
)
if src_key_padding_mask is not None:
src_key_padding_mask = src_key_padding_mask.to(feature.device)
if chunk_size < 0:
# non-streaming decoding
x = self.encoder_embed(feature)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
x = self.encoder(
x, pos_emb, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
if simulate_streaming:
# Simulate chunk-by-chunk streaming decoding.
# Results of this branch should be identical to those of the
# following "else" branch, but this branch is a little slower
# because the features are fed in chunk by chunk.
# store the results of chunk-by-chunk decoding
encoder_output = []
# caches
pos_emb_positive = []
pos_emb_negative = []
pos_emb_central = None
encoder_cache = [None for i in range(len(self.encoder.layers))]
conv_cache = [None for i in range(len(self.encoder.layers))]
# start chunk_by_chunk decoding
offset = 0
for cur in range(
0, num_frames - embed_left_context + 1, stride
):
end = min(cur + decoding_window, num_frames)
cur_feature = feature[:, cur:end, :]
cur_feature = self.encoder_embed(cur_feature)
cur_embed, cur_pos_emb = self.encoder_pos(
cur_feature, offset
)
cur_embed = cur_embed.permute(
1, 0, 2
) # (B, T, F) -> (T, B, F)
cur_T = cur_feature.size(1)
if cur == 0:
# for the first chunk, extract the central positional embedding
pos_emb_central = cur_pos_emb[
0, (chunk_size - 1), :
].view(1, 1, -1)
cur_T -= 1
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
assert pos_emb_positive[-1].size(0) == cur_T
pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
0
)
pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
0
)
cur_pos_emb = torch.cat(
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
dim=1,
)
x = self.encoder.chunk_forward(
cur_embed,
cur_pos_emb,
src_key_padding_mask=src_key_padding_mask[
:, : offset + cur_embed.size(0)
],
encoder_cache=encoder_cache,
conv_cache=conv_cache,
offset=offset,
) # (T, B, F)
encoder_output.append(x)
offset += cur_embed.size(0)
x = torch.cat(encoder_output, dim=0)
else:
# Do NOT simulate chunk-by-chunk decoding.
# Results of this branch should be identical to those of the
# previous chunk-by-chunk branch, but this branch is faster.
x = self.encoder_embed(feature)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = ~subsequent_chunk_mask(
size=x.size(0), chunk_size=chunk_size, device=x.device
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
) # (T, B, F)
if self.normalize_before:
x = self.after_norm(x)
return x, mask
return x, src_key_padding_mask
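# Worked example of the streaming constants above (default chunk_size=8,
# subsampling_rate=4): stride = 8 * 4 = 32 input frames per step and
# decoding_window = 3 + 32 = 35 input frames per step; by the length
# formula ((35 - 1) // 2 - 1) // 2 = 8, each step produces exactly one
# chunk of 8 subsampled frames.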
class ConformerEncoderLayer(nn.Module):
@ -154,6 +410,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
causal: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention(
@ -174,7 +431,9 @@ class ConformerEncoderLayer(nn.Module):
nn.Linear(dim_feedforward, d_model),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_ff_macaron = nn.LayerNorm(
d_model
@ -264,6 +523,97 @@ class ConformerEncoderLayer(nn.Module):
return src
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache: Optional[Tensor] = None,
conv_cache: Optional[Tensor] = None,
offset=0,
) -> Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
# macaron style feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff_macaron(src)
src = residual + self.ff_scale * self.dropout(
self.feed_forward_macaron(src)
)
if not self.normalize_before:
src = self.norm_ff_macaron(src)
# multi-headed self-attention module
residual = src
if self.normalize_before:
src = self.norm_mha(src)
if encoder_cache is None:
# src: [chunk_size, N, F] e.g. [8, 41, 512]
key = src
val = key
encoder_cache = key
else:
key = torch.cat([encoder_cache, src], dim=0)
val = key
encoder_cache = key
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
offset=offset,
)[0]
src = residual + self.dropout(src_att)
if not self.normalize_before:
src = self.norm_mha(src)
# convolution module
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
if self.normalize_before:
src = self.norm_conv(src)
if conv_cache is not None:
src = torch.cat([conv_cache, src], dim=0)
conv_cache = src
src = self.conv_module(src)
src = src[-residual.size(0) :, :, :] # noqa: E203
src = residual + self.dropout(src)
if not self.normalize_before:
src = self.norm_conv(src)
# feed forward module
residual = src
if self.normalize_before:
src = self.norm_ff(src)
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
if not self.normalize_before:
src = self.norm_ff(src)
if self.normalize_before:
src = self.norm_final(src)
return src, encoder_cache, conv_cache
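# Note on the caches (as used above): `encoder_cache` accumulates, along the
# time axis, the representations of all previously processed frames so they
# can serve as attention keys/values for later chunks; `conv_cache` keeps the
# convolution-module input of those frames so the causal depthwise conv sees
# the required left context across chunk boundaries.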
class ConformerEncoder(nn.TransformerEncoder):
r"""ConformerEncoder is a stack of N encoder layers
@ -326,6 +676,52 @@ class ConformerEncoder(nn.TransformerEncoder):
return output
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
encoder_cache=None,
conv_cache=None,
offset=0,
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
output = src
for layer_index, mod in enumerate(self.layers):
output, e_cache, c_cache = mod.chunk_forward(
output,
pos_emb,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
encoder_cache=encoder_cache[layer_index],
conv_cache=conv_cache[layer_index],
offset=offset,
)
encoder_cache[layer_index] = e_cache
conv_cache[layer_index] = c_cache
if self.norm is not None:
output = self.norm(output)
return output
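# Usage sketch (mirrors eval_run_encoder in this file): keep one cache entry
# per layer and let chunk_forward update them as chunks arrive.
#   encoder_cache = [None for _ in range(len(encoder.layers))]
#   conv_cache = [None for _ in range(len(encoder.layers))]
#   offset = 0
#   for chunk_embed, chunk_pos_emb in chunks:  # hypothetical chunk iterator
#       out = encoder.chunk_forward(
#           chunk_embed,
#           chunk_pos_emb,
#           encoder_cache=encoder_cache,
#           conv_cache=conv_cache,
#           offset=offset,
#       )
#       offset += chunk_embed.size(0)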
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@ -351,12 +747,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, offset: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = offset + x.size(1)
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
@ -366,9 +763,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` is the position of the query vector and `j` is the
# position of the key vector. We use positive relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -386,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self, x: torch.Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
@ -397,15 +796,31 @@ class RelPositionalEncoding(torch.nn.Module):
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, offset)
x = x * self.xscale
x_size_1 = offset + x.size(1)
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
+ x_size_1,
]
x_T = x.size(1)
if offset > 0:
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
else:
pos_emb = torch.cat(
[
pos_emb[:, : (x_T - 1)],
self.pe[0, self.pe.size(1) // 2].view(
1, 1, self.pe.size(-1)
),
pos_emb[:, -(x_T - 1) :], # noqa: E203
],
dim=1,
)
return self.dropout(x), self.dropout(pos_emb)
@ -469,6 +884,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -527,9 +943,10 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
offset=offset,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, offset=0) -> Tensor:
"""Compute relative positional encoding.
Args:
@ -538,18 +955,20 @@ class RelPositionMultiheadAttention(nn.Module):
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
(note: time2 has the same value as time1, but it is for
(note: time2 == time1 + offset, since it is for
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + offset
assert n == 2 * time2 - 1
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
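# Sketch of what the strided view selects: for query i and key j, the
# returned tensor reads x[..., i, (time1 - 1) + j - i], so last-axis index
# time2 - 1 corresponds to a relative distance of 0 (the queries occupy the
# last time1 positions of the key sequence when offset > 0).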
@ -571,6 +990,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
offset=0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -749,7 +1169,9 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time-1, head, d_k) --> (batch, head, d_k, 2*time-1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self.pos_bias_u).transpose(
1, 2
@ -769,10 +1191,11 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
matrix_bd = self.rel_shift(
matrix_bd, offset=offset
) # [B, head, time1, time2]
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
@ -843,7 +1266,11 @@ class ConvolutionModule(nn.Module):
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
@ -858,12 +1285,20 @@ class ConvolutionModule(nn.Module):
padding=0,
bias=bias,
)
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41
if causal:
self.lorder = kernel_size - 1
padding = 0 # manually pad self.lorder zeros on the left later
else:
assert (kernel_size - 1) % 2 == 0
self.lorder = 0
padding = (kernel_size - 1) // 2
self.depthwise_conv = nn.Conv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -896,6 +1331,10 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.lorder > 0:
# manually pad self.lorder zeros on the left to
# make depthwise_conv causal
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
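# Minimal sketch (illustration only): left-padding by lorder = kernel_size - 1
# zeros makes the depthwise conv causal while keeping the output length equal
# to the input length, e.g.
#   conv = nn.Conv1d(4, 4, kernel_size=3, padding=0, groups=4)
#   x = torch.randn(1, 4, 10)               # (batch, channels, time)
#   y = conv(nn.functional.pad(x, (2, 0)))  # pad lorder=2 zeros on the left
#   assert y.shape == x.shape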

View File

@ -0,0 +1,506 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
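# Example (CTC greedy collapse): repeated tokens are merged and blanks
# (token id 0) are dropped, e.g.
#   >>> remove_duplicates_and_blank([0, 3, 3, 0, 4, 4, 4, 0, 5])
#   [3, 4, 5]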
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--chunk-size",
type=int,
default=8,
help="Frames of right context"
"-1 for whole right context, i.e. non-streaming decoding",
)
parser.add_argument(
"--tailing-num-frames",
type=int,
default=20,
help="tailing dummy frames padded to the right,"
"only used during decoding",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="simulate chunk by chunk decoding",
)
parser.add_argument(
"--method",
type=str,
default="ctc-greedy-search",
help="Streaming Decoding method",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="streaming_conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe",
help="The lang dir",
)
parser.add_argument(
"--avg-models",
type=str,
default=None,
help="Manually select models to average, seperated by comma;"
"e.g. 60,62,63,72",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"lm_dir": Path("data/lm"),
# parameters for conformer
"causal": True,
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"num_decoder_layers": 6,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
chunk_size: int = -1,
simulate_streaming=False,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
feature = batch["inputs"]
device = torch.device("cuda")
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
# Extra dummy tailing frames may reduce deletion errors,
# example WITHOUT padding:
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
# example WITH padding:
# CHAPTER SEVEN ON THE RACES OF MAN
tailing_frames = (
torch.tensor([-23.0259])
.expand([feature.size(0), params.tailing_num_frames, 80])
.to(feature.device)
)
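# Note: -23.0259 == log(1e-10), so the dummy frames look like silence
# in the log-mel feature domain.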
feature = torch.cat([feature, tailing_frames], dim=1)
supervisions["num_frames"] += params.tailing_num_frames
nnet_output, memory, memory_key_padding_mask = model(
feature,
supervisions,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
assert params.method == "ctc-greedy-search"
key = "ctc-greedy-search"
batch_size = nnet_output.size(0)
maxlen = nnet_output.size(1)
topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
topk_index = topk_index.masked_fill_(
memory_key_padding_mask, 0
) # (B, maxlen)
token_ids = [token_id.tolist() for token_id in topk_index]
token_ids = [
remove_duplicates_and_blank(token_id) for token_id in token_ids
]
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
return {key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
chunk_size: int = -1,
simulate_streaming=False,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
chunk_size:
Right context used to simulate streaming decoding.
-1 means the whole right context, i.e. non-streaming decoding.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
sos_id=sos_id,
eos_id=eos_id,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
if params.avg_models is not None:
avg_models = params.avg_models.replace(",", "_")
result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
else:
result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
for key, results in results_dict.items():
recog_path = (
params.exp_dir
/ f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir
/ f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
causal=params.causal,
)
if params.avg == 1 and params.avg_models is not None:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
filenames = []
if params.avg_models is not None:
model_ids = params.avg_models.split(",")
for i in model_ids:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
else:
start = params.epoch - params.avg + 1
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
bpe_model=bpe_model,
word_table=lexicon.word_table,
sos_id=sos_id,
eos_id=eos_id,
chunk_size=params.chunk_size,
simulate_streaming=params.simulate_streaming,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -124,6 +124,20 @@ def get_parser():
""",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="Whether to use dynamic right context during training.",
)
parser.add_argument(
"--short-chunk-proportion",
type=float,
default=0.7,
help="Proportion of samples trained with short right context",
)
return parser
@ -340,7 +354,12 @@ def compute_loss(
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
nnet_output, encoder_memory, memory_mask = model(
feature,
supervisions,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_proportion=params.short_chunk_proportion,
)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with

View File

@ -158,7 +158,13 @@ class Transformer(nn.Module):
self.decoder_criterion = None
def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
self,
x: torch.Tensor,
supervision: Optional[Supervisions] = None,
dynamic_chunk_training: bool = False,
short_chunk_proportion: float = 0.5,
chunk_size: int = -1,
simulate_streaming=False,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
@ -184,13 +190,21 @@ class Transformer(nn.Module):
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x, supervision
x,
supervision,
dynamic_chunk_training=dynamic_chunk_training,
short_chunk_proportion=short_chunk_proportion,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
def run_encoder(
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
self,
x: torch.Tensor,
supervisions: Optional[Supervisions] = None,
chunk_size: int = -1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Run the transformer encoder.
@ -205,6 +219,8 @@ class Transformer(nn.Module):
It is read directly from the batch, without any sorting. It is used
to compute the encoder padding mask, which is used as memory key
padding mask for the decoder.
chunk_size: right chunk size used to simulate streaming decoding.
-1 means the whole right context.
Returns:
Return a tuple with two tensors:
- The encoder output, with shape (T, N, C)
@ -212,12 +228,16 @@ class Transformer(nn.Module):
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
# Streaming decoding (chunk_size >= 0) is only verified with Conformer
assert chunk_size == -1
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
x = self.encoder(
x, src_key_padding_mask=mask, chunk_size=chunk_size
) # (T, N, C)
return x, mask