mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
streaming conformer code
This commit is contained in:
parent
898efa7e8c
commit
1e35ea3260
@ -25,6 +25,42 @@ from torch import Tensor, nn
|
|||||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42
|
||||||
|
def subsequent_chunk_mask(
|
||||||
|
size: int,
|
||||||
|
chunk_size: int,
|
||||||
|
num_left_chunks: int = -1,
|
||||||
|
device: torch.device = torch.device("cpu"),
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||||
|
this is for streaming encoder
|
||||||
|
Args:
|
||||||
|
size (int): size of mask
|
||||||
|
chunk_size (int): size of chunk
|
||||||
|
num_left_chunks (int): number of left chunks
|
||||||
|
<0: use full chunk
|
||||||
|
>=0: use num_left_chunks
|
||||||
|
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: mask
|
||||||
|
Examples:
|
||||||
|
>>> subsequent_chunk_mask(4, 2)
|
||||||
|
[[1, 1, 0, 0],
|
||||||
|
[1, 1, 0, 0],
|
||||||
|
[1, 1, 1, 1],
|
||||||
|
[1, 1, 1, 1]]
|
||||||
|
"""
|
||||||
|
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||||
|
for i in range(size):
|
||||||
|
if num_left_chunks < 0:
|
||||||
|
start = 0
|
||||||
|
else:
|
||||||
|
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||||
|
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||||
|
ret[i, start:ending] = True
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class Conformer(Transformer):
|
class Conformer(Transformer):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -57,6 +93,7 @@ class Conformer(Transformer):
|
|||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
vgg_frontend: bool = False,
|
vgg_frontend: bool = False,
|
||||||
use_feat_batchnorm: bool = False,
|
use_feat_batchnorm: bool = False,
|
||||||
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__(
|
super(Conformer, self).__init__(
|
||||||
num_features=num_features,
|
num_features=num_features,
|
||||||
@ -82,6 +119,7 @@ class Conformer(Transformer):
|
|||||||
dropout,
|
dropout,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel,
|
||||||
normalize_before,
|
normalize_before,
|
||||||
|
causal,
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||||
self.normalize_before = normalize_before
|
self.normalize_before = normalize_before
|
||||||
@ -93,7 +131,13 @@ class Conformer(Transformer):
|
|||||||
self.after_norm = identity
|
self.after_norm = identity
|
||||||
|
|
||||||
def run_encoder(
|
def run_encoder(
|
||||||
self, x: Tensor, supervisions: Optional[Supervisions] = None
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
|
dynamic_chunk_training: bool = False,
|
||||||
|
short_chunk_proportion: float = 0.5,
|
||||||
|
chunk_size: int = -1,
|
||||||
|
simulate_streaming: bool = False,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -107,23 +151,235 @@ class Conformer(Transformer):
|
|||||||
It is read directly from the batch, without any sorting. It is used
|
It is read directly from the batch, without any sorting. It is used
|
||||||
to compute encoder padding mask, which is used as memory key padding
|
to compute encoder padding mask, which is used as memory key padding
|
||||||
mask for the decoder.
|
mask for the decoder.
|
||||||
|
dynamic_chunk_training:
|
||||||
|
For training only.
|
||||||
|
IF True, train with dynamic right context for some batches
|
||||||
|
sampled with a distribution
|
||||||
|
if False, train with full right context all the time.
|
||||||
|
short_chunk_proportion:
|
||||||
|
For training only.
|
||||||
|
Proportion of samples that will be trained with dynamic chunk.
|
||||||
|
chunk_size:
|
||||||
|
For eval only.
|
||||||
|
right context when evaluating test utts.
|
||||||
|
-1 means all right context.
|
||||||
|
simulate_streaming=False,
|
||||||
|
For eval only.
|
||||||
|
If true, the feature will be feeded into the model chunk by chunk.
|
||||||
|
If false, the whole utts if feeded into the model together i.e. the
|
||||||
|
model only foward once.
|
||||||
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||||
|
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||||
|
"""
|
||||||
|
if self.encoder.training:
|
||||||
|
return self.train_run_encoder(
|
||||||
|
x, supervisions, dynamic_chunk_training, short_chunk_proportion
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return self.eval_run_encoder(
|
||||||
|
x, supervisions, chunk_size, simulate_streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
def train_run_encoder(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
|
dynamic_chunk_training: bool = False,
|
||||||
|
short_chunk_threshold: float = 0.5,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
The model input. Its shape is (N, T, C).
|
||||||
|
supervisions:
|
||||||
|
Supervision in lhotse format.
|
||||||
|
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||||
|
CAUTION: It contains length information, i.e., start and number of
|
||||||
|
frames, before subsampling
|
||||||
|
It is read directly from the batch, without any sorting. It is used
|
||||||
|
to compute encoder padding mask, which is used as memory key padding
|
||||||
|
mask for the decoder.
|
||||||
|
dynamic_chunk_training:
|
||||||
|
IF True, train with dynamic right context for some batches
|
||||||
|
sampled with a distribution
|
||||||
|
if False, train with full right context all the time.
|
||||||
|
short_chunk_proportion:
|
||||||
|
Proportion of samples that will be trained with dynamic chunk.
|
||||||
|
"""
|
||||||
|
x = self.encoder_embed(x)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions)
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
src_key_padding_mask = src_key_padding_mask.to(x.device)
|
||||||
|
|
||||||
|
if dynamic_chunk_training:
|
||||||
|
max_len = x.size(0)
|
||||||
|
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||||
|
if chunk_size > (max_len * short_chunk_threshold):
|
||||||
|
chunk_size = max_len
|
||||||
|
else:
|
||||||
|
chunk_size = chunk_size % 25 + 1
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||||
|
)
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, B, F)
|
||||||
|
else:
|
||||||
|
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
x = self.after_norm(x)
|
||||||
|
|
||||||
|
return x, src_key_padding_mask
|
||||||
|
|
||||||
|
def eval_run_encoder(
|
||||||
|
self,
|
||||||
|
feature: Tensor,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
|
chunk_size: int = -1,
|
||||||
|
simulate_streaming=False,
|
||||||
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
feature:
|
||||||
|
The model input. Its shape is (N, T, C).
|
||||||
|
supervisions:
|
||||||
|
Supervision in lhotse format.
|
||||||
|
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||||
|
CAUTION: It contains length information, i.e., start and number of
|
||||||
|
frames, before subsampling
|
||||||
|
It is read directly from the batch, without any sorting. It is used
|
||||||
|
to compute encoder padding mask, which is used as memory key padding
|
||||||
|
mask for the decoder.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||||
"""
|
"""
|
||||||
x = self.encoder_embed(x)
|
# feature.shape: N T C
|
||||||
x, pos_emb = self.encoder_pos(x)
|
num_frames = feature.size(1)
|
||||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
|
||||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
# As temporarily in icefall only subsampling_rate == 4 is supported,
|
||||||
if mask is not None:
|
# following parameters are hard-coded here.
|
||||||
mask = mask.to(x.device)
|
# Change it accordingly if other subsamling_rate are supported.
|
||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
embed_left_context = 7
|
||||||
|
embed_conv_right_context = 3
|
||||||
|
subsampling_rate = 4
|
||||||
|
stride = chunk_size * subsampling_rate
|
||||||
|
decoding_window = embed_conv_right_context + stride
|
||||||
|
|
||||||
|
# This is also only compatible to sumsampling_rate == 4
|
||||||
|
length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2
|
||||||
|
src_key_padding_mask = encoder_padding_mask(
|
||||||
|
length_after_subsampling, supervisions
|
||||||
|
)
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
src_key_padding_mask = src_key_padding_mask.to(feature.device)
|
||||||
|
|
||||||
|
if chunk_size < 0:
|
||||||
|
# non-streaming decoding
|
||||||
|
x = self.encoder_embed(feature)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
x = self.encoder(
|
||||||
|
x, pos_emb, src_key_padding_mask=src_key_padding_mask
|
||||||
|
) # (T, B, F)
|
||||||
|
else:
|
||||||
|
if simulate_streaming:
|
||||||
|
# simulate chunk_by_chunk streaming decoding
|
||||||
|
# Results of this branch should be identical to following
|
||||||
|
# "else" branch.
|
||||||
|
# But this branch is a little slower
|
||||||
|
# as the feature is feeded chunk by chunk
|
||||||
|
|
||||||
|
# store the result of chunk_by_chunk decoding
|
||||||
|
encoder_output = []
|
||||||
|
|
||||||
|
# caches
|
||||||
|
pos_emb_positive = []
|
||||||
|
pos_emb_negative = []
|
||||||
|
pos_emb_central = None
|
||||||
|
encoder_cache = [None for i in range(len(self.encoder.layers))]
|
||||||
|
conv_cache = [None for i in range(len(self.encoder.layers))]
|
||||||
|
|
||||||
|
# start chunk_by_chunk decoding
|
||||||
|
offset = 0
|
||||||
|
for cur in range(
|
||||||
|
0, num_frames - embed_left_context + 1, stride
|
||||||
|
):
|
||||||
|
end = min(cur + decoding_window, num_frames)
|
||||||
|
cur_feature = feature[:, cur:end, :]
|
||||||
|
cur_feature = self.encoder_embed(cur_feature)
|
||||||
|
cur_embed, cur_pos_emb = self.encoder_pos(
|
||||||
|
cur_feature, offset
|
||||||
|
)
|
||||||
|
cur_embed = cur_embed.permute(
|
||||||
|
1, 0, 2
|
||||||
|
) # (B, T, F) -> (T, B, F)
|
||||||
|
|
||||||
|
cur_T = cur_feature.size(1)
|
||||||
|
if cur == 0:
|
||||||
|
# for first chunk extract the central pos embedding
|
||||||
|
pos_emb_central = cur_pos_emb[
|
||||||
|
0, (chunk_size - 1), :
|
||||||
|
].view(1, 1, -1)
|
||||||
|
cur_T -= 1
|
||||||
|
pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0))
|
||||||
|
pos_emb_negative.append(cur_pos_emb[0, -cur_T:])
|
||||||
|
assert pos_emb_positive[-1].size(0) == cur_T
|
||||||
|
|
||||||
|
pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(
|
||||||
|
0
|
||||||
|
)
|
||||||
|
pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(
|
||||||
|
0
|
||||||
|
)
|
||||||
|
cur_pos_emb = torch.cat(
|
||||||
|
[pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.encoder.chunk_forward(
|
||||||
|
cur_embed,
|
||||||
|
cur_pos_emb,
|
||||||
|
src_key_padding_mask=src_key_padding_mask[
|
||||||
|
:, : offset + cur_embed.size(0)
|
||||||
|
],
|
||||||
|
encoder_cache=encoder_cache,
|
||||||
|
conv_cache=conv_cache,
|
||||||
|
offset=offset,
|
||||||
|
) # (T, B, F)
|
||||||
|
encoder_output.append(x)
|
||||||
|
offset += cur_embed.size(0)
|
||||||
|
|
||||||
|
x = torch.cat(encoder_output, dim=0)
|
||||||
|
else:
|
||||||
|
# NOT simulate chunk_by_chunk decoding
|
||||||
|
# Results of this branch should be identical to previous
|
||||||
|
# simulate chunk_by_chunk decoding branch.
|
||||||
|
# But this branch is faster.
|
||||||
|
x = self.encoder_embed(feature)
|
||||||
|
x, pos_emb = self.encoder_pos(x)
|
||||||
|
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||||
|
mask = ~subsequent_chunk_mask(
|
||||||
|
size=x.size(0), chunk_size=chunk_size, device=x.device
|
||||||
|
)
|
||||||
|
x = self.encoder(
|
||||||
|
x,
|
||||||
|
pos_emb,
|
||||||
|
mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
) # (T, B, F)
|
||||||
|
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
x = self.after_norm(x)
|
x = self.after_norm(x)
|
||||||
|
|
||||||
return x, mask
|
return x, src_key_padding_mask
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoderLayer(nn.Module):
|
class ConformerEncoderLayer(nn.Module):
|
||||||
@ -154,6 +410,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
normalize_before: bool = True,
|
normalize_before: bool = True,
|
||||||
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(ConformerEncoderLayer, self).__init__()
|
super(ConformerEncoderLayer, self).__init__()
|
||||||
self.self_attn = RelPositionMultiheadAttention(
|
self.self_attn = RelPositionMultiheadAttention(
|
||||||
@ -174,7 +431,9 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
nn.Linear(dim_feedforward, d_model),
|
nn.Linear(dim_feedforward, d_model),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
self.conv_module = ConvolutionModule(
|
||||||
|
d_model, cnn_module_kernel, causal=causal
|
||||||
|
)
|
||||||
|
|
||||||
self.norm_ff_macaron = nn.LayerNorm(
|
self.norm_ff_macaron = nn.LayerNorm(
|
||||||
d_model
|
d_model
|
||||||
@ -264,6 +523,97 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
encoder_cache: Optional[Tensor] = None,
|
||||||
|
conv_cache: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Pass the input through the encoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder layer (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
src_mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, N is the batch size, E is the feature number
|
||||||
|
"""
|
||||||
|
|
||||||
|
# macaron style feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(
|
||||||
|
self.feed_forward_macaron(src)
|
||||||
|
)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff_macaron(src)
|
||||||
|
|
||||||
|
# multi-headed self-attention module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
if encoder_cache is None:
|
||||||
|
# src: [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
|
key = src
|
||||||
|
val = key
|
||||||
|
encoder_cache = key
|
||||||
|
else:
|
||||||
|
key = torch.cat([encoder_cache, src], dim=0)
|
||||||
|
val = key
|
||||||
|
encoder_cache = key
|
||||||
|
src_att = self.self_attn(
|
||||||
|
src,
|
||||||
|
key,
|
||||||
|
val,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask,
|
||||||
|
offset=offset,
|
||||||
|
)[0]
|
||||||
|
src = residual + self.dropout(src_att)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_mha(src)
|
||||||
|
|
||||||
|
# convolution module
|
||||||
|
residual = src # [chunk_size, N, F] e.g. [8, 41, 512]
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
if conv_cache is not None:
|
||||||
|
src = torch.cat([conv_cache, src], dim=0)
|
||||||
|
conv_cache = src
|
||||||
|
|
||||||
|
src = self.conv_module(src)
|
||||||
|
src = src[-residual.size(0) :, :, :] # noqa: E203
|
||||||
|
|
||||||
|
src = residual + self.dropout(src)
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_conv(src)
|
||||||
|
|
||||||
|
# feed forward module
|
||||||
|
residual = src
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||||
|
if not self.normalize_before:
|
||||||
|
src = self.norm_ff(src)
|
||||||
|
|
||||||
|
if self.normalize_before:
|
||||||
|
src = self.norm_final(src)
|
||||||
|
|
||||||
|
return src, encoder_cache, conv_cache
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoder(nn.TransformerEncoder):
|
class ConformerEncoder(nn.TransformerEncoder):
|
||||||
r"""ConformerEncoder is a stack of N encoder layers
|
r"""ConformerEncoder is a stack of N encoder layers
|
||||||
@ -326,6 +676,52 @@ class ConformerEncoder(nn.TransformerEncoder):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def chunk_forward(
|
||||||
|
self,
|
||||||
|
src: Tensor,
|
||||||
|
pos_emb: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
encoder_cache=None,
|
||||||
|
conv_cache=None,
|
||||||
|
offset=0,
|
||||||
|
) -> Tensor:
|
||||||
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
pos_emb: Positional embedding tensor (required).
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
pos_emb: (N, 2*S-1, E)
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
"""
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for layer_index, mod in enumerate(self.layers):
|
||||||
|
output, e_cache, c_cache = mod.chunk_forward(
|
||||||
|
output,
|
||||||
|
pos_emb,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
encoder_cache=encoder_cache[layer_index],
|
||||||
|
conv_cache=conv_cache[layer_index],
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
encoder_cache[layer_index] = e_cache
|
||||||
|
conv_cache[layer_index] = c_cache
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class RelPositionalEncoding(torch.nn.Module):
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
"""Relative positional encoding module.
|
"""Relative positional encoding module.
|
||||||
@ -351,12 +747,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
self.pe = None
|
self.pe = None
|
||||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||||
|
|
||||||
def extend_pe(self, x: Tensor) -> None:
|
def extend_pe(self, x: Tensor, offset: int = 0) -> None:
|
||||||
"""Reset the positional encodings."""
|
"""Reset the positional encodings."""
|
||||||
|
x_size_1 = offset + x.size(1)
|
||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
x.device
|
x.device
|
||||||
@ -366,9 +763,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||||
# position of key vector. We use position relative positions when keys
|
# position of key vector. We use position relative positions when keys
|
||||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp(
|
||||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||||
* -(math.log(10000.0) / self.d_model)
|
* -(math.log(10000.0) / self.d_model)
|
||||||
@ -386,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
def forward(
|
||||||
|
self, x: torch.Tensor, offset: int = 0
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -397,15 +796,31 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.extend_pe(x)
|
self.extend_pe(x, offset)
|
||||||
x = x * self.xscale
|
x = x * self.xscale
|
||||||
|
x_size_1 = offset + x.size(1)
|
||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x.size(1)
|
- x_size_1
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x.size(1),
|
+ x_size_1,
|
||||||
]
|
]
|
||||||
|
x_T = x.size(1)
|
||||||
|
if offset > 0:
|
||||||
|
pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1)
|
||||||
|
else:
|
||||||
|
pos_emb = torch.cat(
|
||||||
|
[
|
||||||
|
pos_emb[:, : (x_T - 1)],
|
||||||
|
self.pe[0, self.pe.size(1) // 2].view(
|
||||||
|
1, 1, self.pe.size(-1)
|
||||||
|
),
|
||||||
|
pos_emb[:, -(x_T - 1) :], # noqa: E203
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
@ -469,6 +884,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -527,9 +943,10 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask=key_padding_mask,
|
key_padding_mask=key_padding_mask,
|
||||||
need_weights=need_weights,
|
need_weights=need_weights,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def rel_shift(self, x: Tensor) -> Tensor:
|
def rel_shift(self, x: Tensor, offset=0) -> Tensor:
|
||||||
"""Compute relative positional encoding.
|
"""Compute relative positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -538,18 +955,20 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: tensor of shape (batch, head, time1, time2)
|
Tensor: tensor of shape (batch, head, time1, time2)
|
||||||
(note: time2 has the same value as time1, but it is for
|
(note: time2 == time1 + offset, since it is for
|
||||||
the key, while time1 is for the query).
|
the key, while time1 is for the query).
|
||||||
"""
|
"""
|
||||||
(batch_size, num_heads, time1, n) = x.shape
|
(batch_size, num_heads, time1, n) = x.shape
|
||||||
assert n == 2 * time1 - 1
|
time2 = time1 + offset
|
||||||
|
assert n == 2 * time2 - 1
|
||||||
# Note: TorchScript requires explicit arg for stride()
|
# Note: TorchScript requires explicit arg for stride()
|
||||||
batch_stride = x.stride(0)
|
batch_stride = x.stride(0)
|
||||||
head_stride = x.stride(1)
|
head_stride = x.stride(1)
|
||||||
time1_stride = x.stride(2)
|
time1_stride = x.stride(2)
|
||||||
n_stride = x.stride(3)
|
n_stride = x.stride(3)
|
||||||
|
|
||||||
return x.as_strided(
|
return x.as_strided(
|
||||||
(batch_size, num_heads, time1, time1),
|
(batch_size, num_heads, time1, time2),
|
||||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||||
storage_offset=n_stride * (time1 - 1),
|
storage_offset=n_stride * (time1 - 1),
|
||||||
)
|
)
|
||||||
@ -571,6 +990,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
key_padding_mask: Optional[Tensor] = None,
|
key_padding_mask: Optional[Tensor] = None,
|
||||||
need_weights: bool = True,
|
need_weights: bool = True,
|
||||||
attn_mask: Optional[Tensor] = None,
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
offset=0,
|
||||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@ -749,7 +1169,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
pos_emb_bsz = pos_emb.size(0)
|
pos_emb_bsz = pos_emb.size(0)
|
||||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
|
||||||
|
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||||
|
p = p.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||||
1, 2
|
1, 2
|
||||||
@ -769,10 +1191,11 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
|
|
||||||
# compute matrix b and matrix d
|
# compute matrix b and matrix d
|
||||||
matrix_bd = torch.matmul(
|
matrix_bd = torch.matmul(
|
||||||
q_with_bias_v, p.transpose(-2, -1)
|
q_with_bias_v, p
|
||||||
) # (batch, head, time1, 2*time1-1)
|
) # (batch, head, time1, 2*time1-1)
|
||||||
matrix_bd = self.rel_shift(matrix_bd)
|
matrix_bd = self.rel_shift(
|
||||||
|
matrix_bd, offset=offset
|
||||||
|
) # [B, head, time1, time2]
|
||||||
attn_output_weights = (
|
attn_output_weights = (
|
||||||
matrix_ac + matrix_bd
|
matrix_ac + matrix_bd
|
||||||
) * scaling # (batch, head, time1, time2)
|
) * scaling # (batch, head, time1, time2)
|
||||||
@ -843,7 +1266,11 @@ class ConvolutionModule(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, channels: int, kernel_size: int, bias: bool = True
|
self,
|
||||||
|
channels: int,
|
||||||
|
kernel_size: int,
|
||||||
|
bias: bool = True,
|
||||||
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an ConvolutionModule object."""
|
"""Construct an ConvolutionModule object."""
|
||||||
super(ConvolutionModule, self).__init__()
|
super(ConvolutionModule, self).__init__()
|
||||||
@ -858,12 +1285,20 @@ class ConvolutionModule(nn.Module):
|
|||||||
padding=0,
|
padding=0,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41
|
||||||
|
if causal:
|
||||||
|
self.lorder = kernel_size - 1
|
||||||
|
padding = 0 # manualy padding self.lorder zeros to the left later
|
||||||
|
else:
|
||||||
|
assert (kernel_size - 1) % 2 == 0
|
||||||
|
self.lorder = 0
|
||||||
|
padding = (kernel_size - 1) // 2
|
||||||
self.depthwise_conv = nn.Conv1d(
|
self.depthwise_conv = nn.Conv1d(
|
||||||
channels,
|
channels,
|
||||||
channels,
|
channels,
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=(kernel_size - 1) // 2,
|
padding=padding,
|
||||||
groups=channels,
|
groups=channels,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
@ -896,6 +1331,10 @@ class ConvolutionModule(nn.Module):
|
|||||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||||
|
|
||||||
# 1D Depthwise Conv
|
# 1D Depthwise Conv
|
||||||
|
if self.lorder > 0:
|
||||||
|
# manualy padding self.lorder zeros to the left
|
||||||
|
# make depthwise_conv causal
|
||||||
|
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||||
x = self.depthwise_conv(x)
|
x = self.depthwise_conv(x)
|
||||||
x = self.activation(self.norm(x))
|
x = self.activation(self.norm(x))
|
||||||
|
|
||||||
|
506
egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
Executable file
506
egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py
Executable file
@ -0,0 +1,506 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
|
from conformer import Conformer
|
||||||
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166
|
||||||
|
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
||||||
|
new_hyp: List[int] = []
|
||||||
|
cur = 0
|
||||||
|
while cur < len(hyp):
|
||||||
|
if hyp[cur] != 0:
|
||||||
|
new_hyp.append(hyp[cur])
|
||||||
|
prev = cur
|
||||||
|
while cur < len(hyp) and hyp[cur] == hyp[prev]:
|
||||||
|
cur += 1
|
||||||
|
return new_hyp
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=34,
|
||||||
|
help="It specifies the checkpoint to use for decoding."
|
||||||
|
"Note: Epoch counts from 0.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch'. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Frames of right context"
|
||||||
|
"-1 for whole right context, i.e. non-streaming decoding",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--tailing-num-frames",
|
||||||
|
type=int,
|
||||||
|
default=20,
|
||||||
|
help="tailing dummy frames padded to the right,"
|
||||||
|
"only used during decoding",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="simulate chunk by chunk decoding",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="ctc-greedy-search",
|
||||||
|
help="Streaming Decoding method",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--export",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""When enabled, the averaged model is saved to
|
||||||
|
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
||||||
|
pretrained.pt contains a dict {"model": model.state_dict()},
|
||||||
|
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=Path,
|
||||||
|
default="streaming_conformer_ctc/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_bpe",
|
||||||
|
help="The lang dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg-models",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Manually select models to average, seperated by comma;"
|
||||||
|
"e.g. 60,62,63,72",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"exp_dir": Path("conformer_ctc/exp"),
|
||||||
|
"lang_dir": Path("data/lang_bpe"),
|
||||||
|
"lm_dir": Path("data/lm"),
|
||||||
|
# parameters for conformer
|
||||||
|
"causal": True,
|
||||||
|
"subsampling_factor": 4,
|
||||||
|
"vgg_frontend": False,
|
||||||
|
"use_feat_batchnorm": True,
|
||||||
|
"feature_dim": 80,
|
||||||
|
"nhead": 8,
|
||||||
|
"attention_dim": 512,
|
||||||
|
"num_decoder_layers": 6,
|
||||||
|
# parameters for decoding
|
||||||
|
"search_beam": 20,
|
||||||
|
"output_beam": 8,
|
||||||
|
"min_active_states": 30,
|
||||||
|
"max_active_states": 10000,
|
||||||
|
"use_double_scores": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
|
batch: dict,
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
chunk_size: int = -1,
|
||||||
|
simulate_streaming=False,
|
||||||
|
) -> Dict[str, List[List[str]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if no rescoring is used, the key is the string `no_rescore`.
|
||||||
|
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||||
|
where `xxx` is the value of `lm_scale`. An example key is
|
||||||
|
`lm_scale_0.7`
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
bpe_model:
|
||||||
|
The BPE model. Used only when params.method is ctc-decoding.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
word_table:
|
||||||
|
The word symbol table.
|
||||||
|
sos_id:
|
||||||
|
The token ID of the SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID of the EOS.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
feature = batch["inputs"]
|
||||||
|
device = torch.device("cuda")
|
||||||
|
assert feature.ndim == 3
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
# Extra dummy tailing frames my reduce deletion error
|
||||||
|
# example WITHOUT padding:
|
||||||
|
# CHAPTER SEVEN ON THE RACES OF MAN
|
||||||
|
# example WITH padding:
|
||||||
|
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
|
||||||
|
tailing_frames = (
|
||||||
|
torch.tensor([-23.0259])
|
||||||
|
.expand([feature.size(0), params.tailing_num_frames, 80])
|
||||||
|
.to(feature.device)
|
||||||
|
)
|
||||||
|
feature = torch.cat([feature, tailing_frames], dim=1)
|
||||||
|
supervisions["num_frames"] += params.tailing_num_frames
|
||||||
|
|
||||||
|
nnet_output, memory, memory_key_padding_mask = model(
|
||||||
|
feature,
|
||||||
|
supervisions,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
simulate_streaming=simulate_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert params.method == "ctc-greedy-search"
|
||||||
|
key = "ctc-greedy-search"
|
||||||
|
batch_size = nnet_output.size(0)
|
||||||
|
maxlen = nnet_output.size(1)
|
||||||
|
topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1)
|
||||||
|
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
|
||||||
|
topk_index = topk_index.masked_fill_(
|
||||||
|
memory_key_padding_mask, 0
|
||||||
|
) # (B, maxlen)
|
||||||
|
token_ids = [token_id.tolist() for token_id in topk_index]
|
||||||
|
token_ids = [
|
||||||
|
remove_duplicates_and_blank(token_id) for token_id in token_ids
|
||||||
|
]
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
return {key: hyps}
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||||
|
word_table: k2.SymbolTable,
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
chunk_size: int = -1,
|
||||||
|
simulate_streaming=False,
|
||||||
|
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
bpe_model:
|
||||||
|
The BPE model. Used only when params.method is ctc-decoding.
|
||||||
|
word_table:
|
||||||
|
It is the word symbol table.
|
||||||
|
sos_id:
|
||||||
|
The token ID for SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID for EOS.
|
||||||
|
chunk_size:
|
||||||
|
right context to simulate streaming decoding
|
||||||
|
-1 for whole right context, i.e. non-stream decoding
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||||
|
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains two elements:
|
||||||
|
The first is the reference transcript, and the second is the
|
||||||
|
predicted result.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
bpe_model=bpe_model,
|
||||||
|
batch=batch,
|
||||||
|
word_table=word_table,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
simulate_streaming=simulate_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for hyp_words, ref_text in zip(hyps, texts):
|
||||||
|
ref_words = ref_text.split()
|
||||||
|
this_batch.append((ref_words, hyp_words))
|
||||||
|
|
||||||
|
results[lm_scale].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
|
if batch_idx % 100 == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
|
||||||
|
):
|
||||||
|
if params.method == "attention-decoder":
|
||||||
|
# Set it to False since there are too many logs.
|
||||||
|
enable_log = False
|
||||||
|
else:
|
||||||
|
enable_log = True
|
||||||
|
test_set_wers = dict()
|
||||||
|
if params.avg_models is not None:
|
||||||
|
avg_models = params.avg_models.replace(",", "_")
|
||||||
|
result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
|
||||||
|
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
|
||||||
|
else:
|
||||||
|
result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
|
||||||
|
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.exp_dir
|
||||||
|
/ f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
|
||||||
|
)
|
||||||
|
store_transcripts(filename=recog_path, texts=results)
|
||||||
|
if enable_log:
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.exp_dir
|
||||||
|
/ f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
if enable_log:
|
||||||
|
logging.info(
|
||||||
|
"Wrote detailed error stats to {}".format(errs_filename)
|
||||||
|
)
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
max_token_id = max(lexicon.tokens)
|
||||||
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
|
||||||
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
|
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||||
|
params.lang_dir,
|
||||||
|
device=device,
|
||||||
|
sos_token="<sos/eos>",
|
||||||
|
eos_token="<sos/eos>",
|
||||||
|
)
|
||||||
|
sos_id = graph_compiler.sos_id
|
||||||
|
eos_id = graph_compiler.eos_id
|
||||||
|
|
||||||
|
model = Conformer(
|
||||||
|
num_features=params.feature_dim,
|
||||||
|
nhead=params.nhead,
|
||||||
|
d_model=params.attention_dim,
|
||||||
|
num_classes=num_classes,
|
||||||
|
subsampling_factor=params.subsampling_factor,
|
||||||
|
num_decoder_layers=params.num_decoder_layers,
|
||||||
|
vgg_frontend=params.vgg_frontend,
|
||||||
|
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||||
|
causal=params.causal,
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.avg == 1 and params.avg_models is not None:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
filenames = []
|
||||||
|
if params.avg_models is not None:
|
||||||
|
model_ids = params.avg_models.split(",")
|
||||||
|
for i in model_ids:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if start >= 0:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.load_state_dict(average_checkpoints(filenames))
|
||||||
|
|
||||||
|
if params.export:
|
||||||
|
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||||
|
torch.save(
|
||||||
|
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
# CAUTION: `test_sets` is for displaying only.
|
||||||
|
# If you want to skip test-clean, you have to skip
|
||||||
|
# it inside the for loop. That is, use
|
||||||
|
#
|
||||||
|
# if test_set == 'test-clean': continue
|
||||||
|
#
|
||||||
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
|
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||||
|
test_sets = ["test-clean", "test-other"]
|
||||||
|
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
bpe_model=bpe_model,
|
||||||
|
word_table=lexicon.word_table,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
chunk_size=params.chunk_size,
|
||||||
|
simulate_streaming=params.simulate_streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params, test_set_name=test_set, results_dict=results_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
torch.set_num_threads(1)
|
||||||
|
torch.set_num_interop_threads(1)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -124,6 +124,20 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to use dynamic right context during training.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-proportion",
|
||||||
|
type=float,
|
||||||
|
default=0.7,
|
||||||
|
help="Proportion of samples trained with short right context",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -340,7 +354,12 @@ def compute_loss(
|
|||||||
|
|
||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
nnet_output, encoder_memory, memory_mask = model(
|
||||||
|
feature,
|
||||||
|
supervisions,
|
||||||
|
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||||
|
short_chunk_proportion=params.short_chunk_proportion,
|
||||||
|
)
|
||||||
# nnet_output is (N, T, C)
|
# nnet_output is (N, T, C)
|
||||||
|
|
||||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||||
|
@ -158,7 +158,13 @@ class Transformer(nn.Module):
|
|||||||
self.decoder_criterion = None
|
self.decoder_criterion = None
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
supervision: Optional[Supervisions] = None,
|
||||||
|
dynamic_chunk_training: bool = False,
|
||||||
|
short_chunk_proportion: float = 0.5,
|
||||||
|
chunk_size: int = -1,
|
||||||
|
simulate_streaming=False,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -184,13 +190,21 @@ class Transformer(nn.Module):
|
|||||||
x = self.feat_batchnorm(x)
|
x = self.feat_batchnorm(x)
|
||||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||||
x, supervision
|
x,
|
||||||
|
supervision,
|
||||||
|
dynamic_chunk_training=dynamic_chunk_training,
|
||||||
|
short_chunk_proportion=short_chunk_proportion,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
simulate_streaming=simulate_streaming,
|
||||||
)
|
)
|
||||||
x = self.ctc_output(encoder_memory)
|
x = self.ctc_output(encoder_memory)
|
||||||
return x, encoder_memory, memory_key_padding_mask
|
return x, encoder_memory, memory_key_padding_mask
|
||||||
|
|
||||||
def run_encoder(
|
def run_encoder(
|
||||||
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
supervisions: Optional[Supervisions] = None,
|
||||||
|
chunk_size: int = -1,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
"""Run the transformer encoder.
|
"""Run the transformer encoder.
|
||||||
|
|
||||||
@ -205,6 +219,8 @@ class Transformer(nn.Module):
|
|||||||
It is read directly from the batch, without any sorting. It is used
|
It is read directly from the batch, without any sorting. It is used
|
||||||
to compute the encoder padding mask, which is used as memory key
|
to compute the encoder padding mask, which is used as memory key
|
||||||
padding mask for the decoder.
|
padding mask for the decoder.
|
||||||
|
chunk_size: right chunk_size to simulate streaming decoding
|
||||||
|
-1 for whole right context
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple with two tensors:
|
Return a tuple with two tensors:
|
||||||
- The encoder output, with shape (T, N, C)
|
- The encoder output, with shape (T, N, C)
|
||||||
@ -212,12 +228,16 @@ class Transformer(nn.Module):
|
|||||||
The mask is None if `supervisions` is None.
|
The mask is None if `supervisions` is None.
|
||||||
It is used as memory key padding mask in the decoder.
|
It is used as memory key padding mask in the decoder.
|
||||||
"""
|
"""
|
||||||
|
# streaming decoding(chunk_size >= 0) is only verified with Conformer
|
||||||
|
assert chunk_size == -1
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x = self.encoder_pos(x)
|
x = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||||
mask = mask.to(x.device) if mask is not None else None
|
mask = mask.to(x.device) if mask is not None else None
|
||||||
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
|
x = self.encoder(
|
||||||
|
x, src_key_padding_mask=mask, chunk_size=chunk_size
|
||||||
|
) # (T, N, C)
|
||||||
|
|
||||||
return x, mask
|
return x, mask
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user