Support dynamic chunk streaming training in pruned_transcuder_stateless5 (#454)

* support dynamic chunk streaming training

* Add simulate streaming decoding

* Support streaming decoding

* fix causal

* Minor fixes

* fix streaming decode; add results
This commit is contained in:
Wei Kang 2022-07-29 16:40:06 +08:00 committed by GitHub
parent 1b478d3ac3
commit 2f75236c05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 1420 additions and 39 deletions

View File

@ -618,6 +618,80 @@ done
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless4_20220625>
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
See <https://github.com/k2-fsa/icefall/pull/454> for more details.
##### Training on full librispeech
The WERs are (the number in the table formatted as test-clean & test-other):
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|----------------------|--------------|----------------|----------------|----------------|----------------|
| greedy search | 32 | 3.93 & 9.88 | 3.64 & 9.43 | 3.51 & 8.92 | 3.26 & 8.37 |
| greedy search | 64 | 4.84 & 9.81 | 3.59 & 9.27 | 3.44 & 8.83 | 3.23 & 8.33 |
| fast beam search | 32 | 3.86 & 9.77 | 3.67 & 9.3 | 3.5 & 8.83 | 3.27 & 8.33 |
| fast beam search | 64 | 3.79 & 9.68 | 3.57 & 9.21 | 3.41 & 8.72 | 3.25 & 8.27 |
| modified beam search | 32 | 3.84 & 9.71 | 3.66 & 9.38 | 3.47 & 8.86 | 3.26 & 8.42 |
| modified beam search | 64 | 3.81 & 9.59 | 3.58 & 9.2 | 3.44 & 8.74 | 3.23 & 8.35 |
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless5/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
The training command is:
```bash
./pruned_transducer_stateless5/train.py \
--exp-dir pruned_transducer_stateless5/exp \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512 \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 20 \
--num-left-chunks 4 \
--max-duration 300 \
--world-size 4 \
--start-epoch 1 \
--num-epochs 25
```
You can find the tensorboard log here <https://tensorboard.dev/experiment/rO04h6vjTLyw0qSxjp4m4Q>
The decoding command is:
```bash
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
for chunk in 2 4 8 16; do
for left in 32 64; do
./pruned_transducer_stateless5/decode.py \
--num-encoder-layers 18 \
--dim-feedforward 2048 \
--nhead 8 \
--encoder-dim 512 \
--decoder-dim 512 \
--joiner-dim 512 \
--simulate-streaming 1 \
--decode-chunk-size ${chunk} \
--left-context ${left} \
--causal-convolution 1 \
--epoch 25 \
--avg 3 \
--exp-dir ./pruned_transducer_stateless5/exp \
--max-sym-per-frame 1 \
--max-duration 1000 \
--decoding-method ${decoding_method}
done
done
```
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729>
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)

View File

@ -32,7 +32,7 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface):
@ -46,8 +46,27 @@ class Conformer(EncoderInterface):
num_encoder_layers (int): number of encoder layers
dropout (float): dropout rate
layer_dropout (float): layer-dropout rate.
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
cnn_module_kernel (int): Kernel size of convolution module.
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
you want to train a streaming model, this is expected to be True.
When setting True, it will use a masking strategy to make the attention
see only limited left and right context.
short_chunk_threshold (float): a threshold to determinize the chunk size
to be used in masking training, if the randomly generated chunk size
is greater than ``max_len * short_chunk_threshold`` (max_len is the
max sequence length of current batch) then it will use
full context in training (i.e. with chunk size equals to max_len).
This will be used only when dynamic_chunk_training is True.
short_chunk_size (int): see docs above, if the randomly generated chunk
size equals to or less than ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context (in chunks) attention can see, the
chunk size is decided by short_chunk_threshold and short_chunk_size.
A minus value means seeing full left context.
This also will be used only when dynamic_chunk_training is True.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training.
"""
def __init__(
@ -62,6 +81,11 @@ class Conformer(EncoderInterface):
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
aux_layer_period: int = 3,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
) -> None:
super(Conformer, self).__init__()
@ -79,18 +103,28 @@ class Conformer(EncoderInterface):
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
self.encoder_layers = num_encoder_layers
self.d_model = d_model
self.cnn_module_kernel = cnn_module_kernel
self.causal = causal
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
layer_dropout,
cnn_module_kernel,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
layer_dropout=layer_dropout,
cnn_module_kernel=cnn_module_kernel,
causal=causal,
)
# aux_layers from 1/3
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
aux_layers=list(
range(
num_encoder_layers // 3,
@ -99,6 +133,7 @@ class Conformer(EncoderInterface):
)
),
)
self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -126,16 +161,246 @@ class Conformer(EncoderInterface):
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
if self.dynamic_chunk_training:
assert (
self.causal
), "Causal convolution is required for streaming conformer."
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
else:
x = self.encoder(
x,
pos_emb,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
@torch.jit.export
def get_init_state(
self, left_context: int, device: torch.device
) -> List[torch.Tensor]:
"""Return the initial cache state of the model.
Args:
left_context: The left context size (in frames after subsampling).
Returns:
Return the initial state of the model, it is a list containing two
tensors, the first one is the cache for attentions which has a shape
of (num_encoder_layers, left_context, encoder_dim), the second one
is the cache of conv_modules which has a shape of
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
NOTE: the returned tensors are on the given device.
"""
if (
len(self._init_state) == 2
and self._init_state[0].size(1) == left_context
):
# Note: It is OK to share the init state as it is
# not going to be modified by the model
return self._init_state
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
device=device,
),
]
self._init_state = init_states
return init_states
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[Tensor]] = None,
processed_lens: Optional[Tensor] = None,
left_context: int = 64,
right_context: int = 4,
chunk_size: int = 16,
simulate_streaming: bool = False,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time
When using simulate_streaming.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated states including the information
of current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert states is not None
assert processed_lens is not None
assert (
len(states) == 2
and states[0].shape
== (self.encoder_layers, left_context, x.size(0), self.d_model)
and states[1].shape
== (
self.encoder_layers,
self.cnn_module_kernel - 1,
x.size(0),
self.d_model,
)
), f"""The length of states MUST be equal to 2, and the shape of
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
given {states[0].shape}. the shape of second element should be
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
# cut off 1 frame on each size of embed as they see the padding
# value which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
if right_context > 0:
x = x[0:-right_context, ...]
lengths -= right_context
else:
assert states is None
states = [] # just to make torch.script.jit happy
# this branch simulates streaming decoding using mask as we are
# using in training time.
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths, states
class ConformerEncoderLayer(nn.Module):
"""
@ -148,6 +413,8 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -164,6 +431,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
causal: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
@ -191,7 +459,9 @@ class ConformerEncoderLayer(nn.Module):
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_final = BasicNorm(d_model)
@ -257,7 +527,8 @@ class ConformerEncoderLayer(nn.Module):
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
conv, _ = self.conv_module(src)
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -269,6 +540,98 @@ class ConformerEncoderLayer(nn.Module):
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (left_context, src.size(1), src.size(2))
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# We put the attention cache this level (i.e. before linear transformation)
# to save memory consumption, when decoding in streaming fashion, the
# batch size would be thousands (for 32GB machine), if we cache key & val
# separately, it needs extra several GB memory.
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
# multi-headed self-attention module
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
left_context=left_context,
)[0]
src = src + self.dropout(src_att)
# convolution module
conv, conv_cache = self.conv_module(src, states[1], right_context)
states[1] = conv_cache
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
return src, states
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers
@ -352,6 +715,77 @@ class ConformerEncoder(nn.Module):
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (
self.num_layers,
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == self.num_layers
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output, cache = mod.chunk_forward(
output,
pos_emb,
states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context,
right_context=right_context,
)
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output, states
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@ -376,12 +810,13 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + left_context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
@ -391,9 +826,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -411,22 +846,28 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self, x: torch.Tensor, left_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
@ -498,6 +939,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -511,6 +953,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
- Inputs:
@ -556,14 +1001,18 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
@ -571,14 +1020,17 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + left_context
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
@ -600,6 +1052,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -617,6 +1070,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Inputs:
@ -780,7 +1236,8 @@ class RelPositionMultiheadAttention(nn.Module):
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
@ -800,9 +1257,9 @@ class RelPositionMultiheadAttention(nn.Module):
# compute matrix b and matrix d
matrix_bd = torch.matmul(
q_with_bias_v, p.transpose(-2, -1)
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd)
matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = (
matrix_ac + matrix_bd
@ -837,6 +1294,39 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`, at this circumstances,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(
0
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
combined_mask, 0.0
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@ -870,17 +1360,24 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution.
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.causal = causal
self.pointwise_conv1 = ScaledConv1d(
channels,
2 * channels,
@ -907,12 +1404,17 @@ class ConvolutionModule(nn.Module):
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.lorder = kernel_size - 1
padding = (kernel_size - 1) // 2
if self.causal:
padding = 0
self.depthwise_conv = ScaledConv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -933,15 +1435,26 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming
decoding.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Returns:
Tensor: Output tensor (#time, batch, channels).
If cache is None return the output tensor (#time, batch, channels).
If cache is not None, return a tuple of Tensor, the first one is
the output tensor (#time, batch, channels), the second one is the
new cache for next chunk (#kernel_size - 1, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
x = x.permute(1, 2, 0) # (#batch, channels, time).
@ -953,6 +1466,27 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by
# manualy padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else:
assert (
not self.training
), "Cache should be None in training time"
assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
@ -960,7 +1494,11 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
# torch.jit.script requires return types be the same as annotated above
if cache is None:
cache = torch.empty(0)
return x.permute(2, 0, 1), cache
class Conv2dSubsampling(nn.Module):

View File

@ -96,6 +96,7 @@ Usage:
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -132,6 +133,8 @@ from icefall.utils import (
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@ -298,6 +301,29 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser
@ -352,9 +378,26 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -621,6 +664,10 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@ -658,6 +705,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")

View File

@ -0,0 +1 @@
../pruned_transducer_stateless/decode_stream.py

View File

@ -97,7 +97,7 @@ def get_parser():
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
@ -137,6 +137,15 @@ def get_parser():
"2 means tri-gram",
)
parser.add_argument(
"--streaming-model",
type=str2bool,
default=False,
help="""Whether to export a streaming model, if the models in exp-dir
are streaming model, this should be True.
""",
)
add_model_arguments(parser)
return parser
@ -162,6 +171,9 @@ def main():
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.streaming_model:
assert params.causal_convolution
logging.info(params)
logging.info("About to create model")

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/streaming_beam_search.py

View File

@ -0,0 +1,660 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./pruned_transducer_stateless5/streaming_decode.py \
--epoch 28 \
--avg 15 \
--left-context 32 \
--decode-chunk-size 8 \
--right-context 0 \
--exp-dir ./pruned_transducer_stateless5/exp \
--decoding_method greedy_search \
--num-decode-streams 200
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--right-context",
type=int,
default=0,
help="right context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(
params.decode_chunk_size * params.subsampling_factor
)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# if T is less than 7 there will be an error in time reduction layer,
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
# we plus 2 here because we will cut off one frame on each size of
# encoder_embed output as they see invalid paddings. so we need extra 2
# frames.
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = [
torch.stack([x[0] for x in states], dim=2),
torch.stack([x[1] for x in states], dim=2),
]
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
left_context=params.left_context,
right_context=params.right_context,
processed_lens=processed_lens,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = [states[0][i], states[1][i]]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
initial_states = model.encoder.get_init_state(
params.left_context, device=device
)
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
decode_stream = DecodeStream(
params=params,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
params.suffix += f"-right-context-{params.right_context}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
# Decoding in streaming requires causal convolution
params.causal_convolution = True
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -134,6 +134,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
""",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -408,6 +442,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -901,6 +939,11 @@ def run(rank, world_size, args):
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params)
logging.info("About to create model")