mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Support dynamic chunk streaming training in pruned_transcuder_stateless5 (#454)
* support dynamic chunk streaming training * Add simulate streaming decoding * Support streaming decoding * fix causal * Minor fixes * fix streaming decode; add results
This commit is contained in:
parent
1b478d3ac3
commit
2f75236c05
@ -618,6 +618,80 @@ done
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless4_20220625>
|
||||
|
||||
#### [pruned_transducer_stateless5](./pruned_transducer_stateless5)
|
||||
|
||||
See <https://github.com/k2-fsa/icefall/pull/454> for more details.
|
||||
|
||||
##### Training on full librispeech
|
||||
The WERs are (the number in the table formatted as test-clean & test-other):
|
||||
|
||||
We only trained 25 epochs for saving time, if you want to get better results you can train more epochs.
|
||||
|
||||
| decoding method | left context | chunk size = 2 | chunk size = 4 | chunk size = 8 | chunk size = 16|
|
||||
|----------------------|--------------|----------------|----------------|----------------|----------------|
|
||||
| greedy search | 32 | 3.93 & 9.88 | 3.64 & 9.43 | 3.51 & 8.92 | 3.26 & 8.37 |
|
||||
| greedy search | 64 | 4.84 & 9.81 | 3.59 & 9.27 | 3.44 & 8.83 | 3.23 & 8.33 |
|
||||
| fast beam search | 32 | 3.86 & 9.77 | 3.67 & 9.3 | 3.5 & 8.83 | 3.27 & 8.33 |
|
||||
| fast beam search | 64 | 3.79 & 9.68 | 3.57 & 9.21 | 3.41 & 8.72 | 3.25 & 8.27 |
|
||||
| modified beam search | 32 | 3.84 & 9.71 | 3.66 & 9.38 | 3.47 & 8.86 | 3.26 & 8.42 |
|
||||
| modified beam search | 64 | 3.81 & 9.59 | 3.58 & 9.2 | 3.44 & 8.74 | 3.23 & 8.35 |
|
||||
|
||||
|
||||
**NOTE:** The WERs in table above were decoded with simulate streaming method (i.e. using masking strategy), see commands below. We also have [real streaming decoding](./pruned_transducer_stateless5/streaming_decode.py) script which should produce almost the same results. We tried adding right context in the real streaming decoding, but it seemed not to benefit the performance for all the models, the reasons might be the training and decoding mismatching.
|
||||
|
||||
The training command is:
|
||||
|
||||
```bash
|
||||
./pruned_transducer_stateless5/train.py \
|
||||
--exp-dir pruned_transducer_stateless5/exp \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--full-libri 1 \
|
||||
--dynamic-chunk-training 1 \
|
||||
--causal-convolution 1 \
|
||||
--short-chunk-size 20 \
|
||||
--num-left-chunks 4 \
|
||||
--max-duration 300 \
|
||||
--world-size 4 \
|
||||
--start-epoch 1 \
|
||||
--num-epochs 25
|
||||
```
|
||||
|
||||
You can find the tensorboard log here <https://tensorboard.dev/experiment/rO04h6vjTLyw0qSxjp4m4Q>
|
||||
|
||||
The decoding command is:
|
||||
```bash
|
||||
decoding_method="greedy_search" # "fast_beam_search", "modified_beam_search"
|
||||
|
||||
for chunk in 2 4 8 16; do
|
||||
for left in 32 64; do
|
||||
./pruned_transducer_stateless5/decode.py \
|
||||
--num-encoder-layers 18 \
|
||||
--dim-feedforward 2048 \
|
||||
--nhead 8 \
|
||||
--encoder-dim 512 \
|
||||
--decoder-dim 512 \
|
||||
--joiner-dim 512 \
|
||||
--simulate-streaming 1 \
|
||||
--decode-chunk-size ${chunk} \
|
||||
--left-context ${left} \
|
||||
--causal-convolution 1 \
|
||||
--epoch 25 \
|
||||
--avg 3 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--max-sym-per-frame 1 \
|
||||
--max-duration 1000 \
|
||||
--decoding-method ${decoding_method}
|
||||
done
|
||||
done
|
||||
```
|
||||
|
||||
Pre-trained models, training and decoding logs, and decoding results are available at <https://huggingface.co/pkufool/icefall_librispeech_streaming_pruned_transducer_stateless5_20220729>
|
||||
|
||||
|
||||
### LibriSpeech BPE training results (Pruned Stateless Conv-Emformer RNN-T)
|
||||
|
||||
|
@ -32,7 +32,7 @@ from scaling import (
|
||||
)
|
||||
from torch import Tensor, nn
|
||||
|
||||
from icefall.utils import make_pad_mask
|
||||
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||
|
||||
|
||||
class Conformer(EncoderInterface):
|
||||
@ -46,8 +46,27 @@ class Conformer(EncoderInterface):
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
dropout (float): dropout rate
|
||||
layer_dropout (float): layer-dropout rate.
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
|
||||
you want to train a streaming model, this is expected to be True.
|
||||
When setting True, it will use a masking strategy to make the attention
|
||||
see only limited left and right context.
|
||||
short_chunk_threshold (float): a threshold to determinize the chunk size
|
||||
to be used in masking training, if the randomly generated chunk size
|
||||
is greater than ``max_len * short_chunk_threshold`` (max_len is the
|
||||
max sequence length of current batch) then it will use
|
||||
full context in training (i.e. with chunk size equals to max_len).
|
||||
This will be used only when dynamic_chunk_training is True.
|
||||
short_chunk_size (int): see docs above, if the randomly generated chunk
|
||||
size equals to or less than ``max_len * short_chunk_threshold``, the
|
||||
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
||||
This also will be used only when dynamic_chunk_training is True.
|
||||
num_left_chunks (int): the left context (in chunks) attention can see, the
|
||||
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
||||
A minus value means seeing full left context.
|
||||
This also will be used only when dynamic_chunk_training is True.
|
||||
causal (bool): Whether to use causal convolution in conformer encoder
|
||||
layer. This MUST be True when using dynamic_chunk_training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -62,6 +81,11 @@ class Conformer(EncoderInterface):
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
aux_layer_period: int = 3,
|
||||
dynamic_chunk_training: bool = False,
|
||||
short_chunk_threshold: float = 0.75,
|
||||
short_chunk_size: int = 25,
|
||||
num_left_chunks: int = -1,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__()
|
||||
|
||||
@ -79,18 +103,28 @@ class Conformer(EncoderInterface):
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
self.encoder_layers = num_encoder_layers
|
||||
self.d_model = d_model
|
||||
self.cnn_module_kernel = cnn_module_kernel
|
||||
self.causal = causal
|
||||
self.dynamic_chunk_training = dynamic_chunk_training
|
||||
self.short_chunk_threshold = short_chunk_threshold
|
||||
self.short_chunk_size = short_chunk_size
|
||||
self.num_left_chunks = num_left_chunks
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
layer_dropout,
|
||||
cnn_module_kernel,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
layer_dropout=layer_dropout,
|
||||
cnn_module_kernel=cnn_module_kernel,
|
||||
causal=causal,
|
||||
)
|
||||
# aux_layers from 1/3
|
||||
self.encoder = ConformerEncoder(
|
||||
encoder_layer,
|
||||
num_encoder_layers,
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=num_encoder_layers,
|
||||
aux_layers=list(
|
||||
range(
|
||||
num_encoder_layers // 3,
|
||||
@ -99,6 +133,7 @@ class Conformer(EncoderInterface):
|
||||
)
|
||||
),
|
||||
)
|
||||
self._init_state: List[torch.Tensor] = [torch.empty(0)]
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -126,16 +161,246 @@ class Conformer(EncoderInterface):
|
||||
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
assert x.size(0) == lengths.max().item()
|
||||
mask = make_pad_mask(lengths)
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
x = self.encoder(
|
||||
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
if self.dynamic_chunk_training:
|
||||
assert (
|
||||
self.causal
|
||||
), "Causal convolution is required for streaming conformer."
|
||||
max_len = x.size(0)
|
||||
chunk_size = torch.randint(1, max_len, (1,)).item()
|
||||
if chunk_size > (max_len * self.short_chunk_threshold):
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % self.short_chunk_size + 1
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=self.num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
else:
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return x, lengths
|
||||
|
||||
@torch.jit.export
|
||||
def get_init_state(
|
||||
self, left_context: int, device: torch.device
|
||||
) -> List[torch.Tensor]:
|
||||
"""Return the initial cache state of the model.
|
||||
Args:
|
||||
left_context: The left context size (in frames after subsampling).
|
||||
Returns:
|
||||
Return the initial state of the model, it is a list containing two
|
||||
tensors, the first one is the cache for attentions which has a shape
|
||||
of (num_encoder_layers, left_context, encoder_dim), the second one
|
||||
is the cache of conv_modules which has a shape of
|
||||
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
|
||||
NOTE: the returned tensors are on the given device.
|
||||
"""
|
||||
if (
|
||||
len(self._init_state) == 2
|
||||
and self._init_state[0].size(1) == left_context
|
||||
):
|
||||
# Note: It is OK to share the init state as it is
|
||||
# not going to be modified by the model
|
||||
return self._init_state
|
||||
|
||||
init_states: List[torch.Tensor] = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.encoder_layers,
|
||||
left_context,
|
||||
self.d_model,
|
||||
),
|
||||
device=device,
|
||||
),
|
||||
torch.zeros(
|
||||
(
|
||||
self.encoder_layers,
|
||||
self.cnn_module_kernel - 1,
|
||||
self.d_model,
|
||||
),
|
||||
device=device,
|
||||
),
|
||||
]
|
||||
|
||||
self._init_state = init_states
|
||||
|
||||
return init_states
|
||||
|
||||
@torch.jit.export
|
||||
def streaming_forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
states: Optional[List[Tensor]] = None,
|
||||
processed_lens: Optional[Tensor] = None,
|
||||
left_context: int = 64,
|
||||
right_context: int = 4,
|
||||
chunk_size: int = 16,
|
||||
simulate_streaming: bool = False,
|
||||
warmup: float = 1.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
processed_lens:
|
||||
How many frames (after subsampling) have been processed for each sequence.
|
||||
left_context:
|
||||
How many previous frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `left_context` frames
|
||||
of left context, some have more.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
chunk_size:
|
||||
The chunk size for decoding, this will be used to simulate streaming
|
||||
decoding using masking.
|
||||
simulate_streaming:
|
||||
If setting True, it will use a masking strategy to simulate streaming
|
||||
fashion (i.e. every chunk data only see limited left context and
|
||||
right context). The whole sequence is supposed to be send at a time
|
||||
When using simulate_streaming.
|
||||
warmup:
|
||||
A floating point value that gradually increases from 0 throughout
|
||||
training; when it is >= 1.0 we are "fully warmed up". It is used
|
||||
to turn modules on sequentially.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
- decode_states, the updated states including the information
|
||||
of current chunk.
|
||||
"""
|
||||
|
||||
# x: [N, T, C]
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
|
||||
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
||||
#
|
||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
||||
|
||||
if not simulate_streaming:
|
||||
assert states is not None
|
||||
assert processed_lens is not None
|
||||
assert (
|
||||
len(states) == 2
|
||||
and states[0].shape
|
||||
== (self.encoder_layers, left_context, x.size(0), self.d_model)
|
||||
and states[1].shape
|
||||
== (
|
||||
self.encoder_layers,
|
||||
self.cnn_module_kernel - 1,
|
||||
x.size(0),
|
||||
self.d_model,
|
||||
)
|
||||
), f"""The length of states MUST be equal to 2, and the shape of
|
||||
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
|
||||
given {states[0].shape}. the shape of second element should be
|
||||
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
|
||||
given {states[1].shape}."""
|
||||
|
||||
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
|
||||
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
|
||||
processed_mask = torch.arange(left_context, device=x.device).expand(
|
||||
x.size(0), left_context
|
||||
)
|
||||
processed_lens = processed_lens.view(x.size(0), 1)
|
||||
processed_mask = (processed_lens <= processed_mask).flip(1)
|
||||
|
||||
src_key_padding_mask = torch.cat(
|
||||
[processed_mask, src_key_padding_mask], dim=1
|
||||
)
|
||||
|
||||
embed = self.encoder_embed(x)
|
||||
|
||||
# cut off 1 frame on each size of embed as they see the padding
|
||||
# value which causes a training and decoding mismatch.
|
||||
embed = embed[:, 1:-1, :]
|
||||
|
||||
embed, pos_enc = self.encoder_pos(embed, left_context)
|
||||
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
|
||||
x, states = self.encoder.chunk_forward(
|
||||
embed,
|
||||
pos_enc,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
states=states,
|
||||
left_context=left_context,
|
||||
right_context=right_context,
|
||||
) # (T, B, F)
|
||||
if right_context > 0:
|
||||
x = x[0:-right_context, ...]
|
||||
lengths -= right_context
|
||||
else:
|
||||
assert states is None
|
||||
states = [] # just to make torch.script.jit happy
|
||||
# this branch simulates streaming decoding using mask as we are
|
||||
# using in training time.
|
||||
src_key_padding_mask = make_pad_mask(lengths)
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
assert x.size(0) == lengths.max().item()
|
||||
|
||||
num_left_chunks = -1
|
||||
if left_context >= 0:
|
||||
assert left_context % chunk_size == 0
|
||||
num_left_chunks = left_context // chunk_size
|
||||
|
||||
mask = ~subsequent_chunk_mask(
|
||||
size=x.size(0),
|
||||
chunk_size=chunk_size,
|
||||
num_left_chunks=num_left_chunks,
|
||||
device=x.device,
|
||||
)
|
||||
x = self.encoder(
|
||||
x,
|
||||
pos_emb,
|
||||
mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return x, lengths, states
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
@ -148,6 +413,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
causal (bool): Whether to use causal convolution in conformer encoder
|
||||
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
@ -164,6 +431,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
|
||||
@ -191,7 +459,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
self.conv_module = ConvolutionModule(
|
||||
d_model, cnn_module_kernel, causal=causal
|
||||
)
|
||||
|
||||
self.norm_final = BasicNorm(d_model)
|
||||
|
||||
@ -257,7 +527,8 @@ class ConformerEncoderLayer(nn.Module):
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
src = src + self.dropout(self.conv_module(src))
|
||||
conv, _ = self.conv_module(src)
|
||||
src = src + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
@ -269,6 +540,98 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
return src
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context:
|
||||
How many previous frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `left_context` frames
|
||||
of left context, some have more.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (left_context, src.size(1), src.size(2))
|
||||
|
||||
# macaron style feed forward module
|
||||
src = src + self.dropout(self.feed_forward_macaron(src))
|
||||
|
||||
# We put the attention cache this level (i.e. before linear transformation)
|
||||
# to save memory consumption, when decoding in streaming fashion, the
|
||||
# batch size would be thousands (for 32GB machine), if we cache key & val
|
||||
# separately, it needs extra several GB memory.
|
||||
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
|
||||
# separately) if needed.
|
||||
key = torch.cat([states[0], src], dim=0)
|
||||
val = key
|
||||
if right_context > 0:
|
||||
states[0] = key[
|
||||
-(left_context + right_context) : -right_context, ... # noqa
|
||||
]
|
||||
else:
|
||||
states[0] = key[-left_context:, ...]
|
||||
|
||||
# multi-headed self-attention module
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
key,
|
||||
val,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
left_context=left_context,
|
||||
)[0]
|
||||
|
||||
src = src + self.dropout(src_att)
|
||||
|
||||
# convolution module
|
||||
conv, conv_cache = self.conv_module(src, states[1], right_context)
|
||||
states[1] = conv_cache
|
||||
|
||||
src = src + self.dropout(conv)
|
||||
|
||||
# feed forward module
|
||||
src = src + self.dropout(self.feed_forward(src))
|
||||
|
||||
src = self.norm_final(self.balancer(src))
|
||||
|
||||
return src, states
|
||||
|
||||
|
||||
class ConformerEncoder(nn.Module):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
@ -352,6 +715,77 @@ class ConformerEncoder(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
@torch.jit.export
|
||||
def chunk_forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
states: List[Tensor],
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
warmup: float = 1.0,
|
||||
left_context: int = 0,
|
||||
right_context: int = 0,
|
||||
) -> Tuple[Tensor, List[Tensor]]:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
states:
|
||||
The decode states for previous frames which contains the cached data.
|
||||
It has two elements, the first element is the attn_cache which has
|
||||
a shape of (encoder_layers, left_context, batch, attention_dim),
|
||||
the second element is the conv_cache which has a shape of
|
||||
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
|
||||
Note: states will be modified in this function.
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
warmup: controls selective bypass of of layers; if < 1.0, we will
|
||||
bypass layers more frequently.
|
||||
left_context:
|
||||
How many previous frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `left_context` frames
|
||||
of left context, some have more.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*(S+left_context)-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
assert not self.training
|
||||
assert len(states) == 2
|
||||
assert states[0].shape == (
|
||||
self.num_layers,
|
||||
left_context,
|
||||
src.size(1),
|
||||
src.size(2),
|
||||
)
|
||||
assert states[1].size(0) == self.num_layers
|
||||
|
||||
output = src
|
||||
|
||||
for layer_index, mod in enumerate(self.layers):
|
||||
cache = [states[0][layer_index], states[1][layer_index]]
|
||||
output, cache = mod.chunk_forward(
|
||||
output,
|
||||
pos_emb,
|
||||
states=cache,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
warmup=warmup,
|
||||
left_context=left_context,
|
||||
right_context=right_context,
|
||||
)
|
||||
states[0][layer_index] = cache[0]
|
||||
states[1][layer_index] = cache[1]
|
||||
|
||||
return output, states
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
@ -376,12 +810,13 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
x_size_1 = x.size(1) + left_context
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
if self.pe.size(1) >= x_size_1 * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
@ -391,9 +826,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
# Suppose `i` means to the position of query vector and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
pe_positive = torch.zeros(x_size_1, self.d_model)
|
||||
pe_negative = torch.zeros(x_size_1, self.d_model)
|
||||
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
@ -411,22 +846,28 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def forward(
|
||||
self, x: torch.Tensor, left_context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
self.extend_pe(x, left_context)
|
||||
x_size_1 = x.size(1) + left_context
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
- x_size_1
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
@ -498,6 +939,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -511,6 +953,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
@ -556,14 +1001,18 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
left_context=left_context,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
@ -571,14 +1020,17 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
time2 = time1 + left_context
|
||||
assert (
|
||||
n == left_context + 2 * time1 - 1
|
||||
), f"{n} == {left_context} + 2 * {time1} - 1"
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_size, num_heads, time1, time2),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
@ -600,6 +1052,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
left_context: int = 0,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
@ -617,6 +1070,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context (in frames) used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
@ -780,7 +1236,8 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
|
||||
p = p.permute(0, 2, 3, 1)
|
||||
|
||||
q_with_bias_u = (q + self._pos_bias_u()).transpose(
|
||||
1, 2
|
||||
@ -800,9 +1257,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
q_with_bias_v, p
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
matrix_bd = self.rel_shift(matrix_bd, left_context)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
@ -837,6 +1294,39 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
|
||||
# If we are using dynamic_chunk_training and setting a limited
|
||||
# num_left_chunks, the attention may only see the padding values which
|
||||
# will also be masked out by `key_padding_mask`, at this circumstances,
|
||||
# the whole column of `attn_output_weights` will be `-inf`
|
||||
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
|
||||
# positions to avoid invalid loss value below.
|
||||
if (
|
||||
attn_mask is not None
|
||||
and attn_mask.dtype == torch.bool
|
||||
and key_padding_mask is not None
|
||||
):
|
||||
if attn_mask.size(0) != 1:
|
||||
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
|
||||
combined_mask = attn_mask | key_padding_mask.unsqueeze(
|
||||
1
|
||||
).unsqueeze(2)
|
||||
else:
|
||||
# attn_mask.shape == (1, tgt_len, src_len)
|
||||
combined_mask = attn_mask.unsqueeze(
|
||||
0
|
||||
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
combined_mask, 0.0
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
@ -870,17 +1360,24 @@ class ConvolutionModule(nn.Module):
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
causal (bool): Whether to use causal convolution.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int,
|
||||
bias: bool = True,
|
||||
causal: bool = False,
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.causal = causal
|
||||
|
||||
self.pointwise_conv1 = ScaledConv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
@ -907,12 +1404,17 @@ class ConvolutionModule(nn.Module):
|
||||
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
|
||||
)
|
||||
|
||||
self.lorder = kernel_size - 1
|
||||
padding = (kernel_size - 1) // 2
|
||||
if self.causal:
|
||||
padding = 0
|
||||
|
||||
self.depthwise_conv = ScaledConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
@ -933,15 +1435,26 @@ class ConvolutionModule(nn.Module):
|
||||
initial_scale=0.25,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
def forward(
|
||||
self, x: Tensor, cache: Optional[Tensor] = None, right_context: int = 0
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
cache: The cache of depthwise_conv, only used in real streaming
|
||||
decoding.
|
||||
right_context:
|
||||
How many future frames the attention can see in current chunk.
|
||||
Note: It's not that each individual frame has `right_context` frames
|
||||
of right context, some have more.
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
If cache is None return the output tensor (#time, batch, channels).
|
||||
If cache is not None, return a tuple of Tensor, the first one is
|
||||
the output tensor (#time, batch, channels), the second one is the
|
||||
new cache for next chunk (#kernel_size - 1, batch, channels).
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
@ -953,6 +1466,27 @@ class ConvolutionModule(nn.Module):
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
if self.causal and self.lorder > 0:
|
||||
if cache is None:
|
||||
# Make depthwise_conv causal by
|
||||
# manualy padding self.lorder zeros to the left
|
||||
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
|
||||
else:
|
||||
assert (
|
||||
not self.training
|
||||
), "Cache should be None in training time"
|
||||
assert cache.size(0) == self.lorder
|
||||
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
|
||||
if right_context > 0:
|
||||
cache = x.permute(2, 0, 1)[
|
||||
-(self.lorder + right_context) : ( # noqa
|
||||
-right_context
|
||||
),
|
||||
...,
|
||||
]
|
||||
else:
|
||||
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
|
||||
|
||||
x = self.depthwise_conv(x)
|
||||
|
||||
x = self.deriv_balancer2(x)
|
||||
@ -960,7 +1494,11 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
# torch.jit.script requires return types be the same as annotated above
|
||||
if cache is None:
|
||||
cache = torch.empty(0)
|
||||
|
||||
return x.permute(2, 0, 1), cache
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
|
@ -96,6 +96,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
@ -132,6 +133,8 @@ from icefall.utils import (
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -298,6 +301,29 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -352,9 +378,26 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
feature_lens += params.left_context
|
||||
feature = torch.nn.functional.pad(
|
||||
feature,
|
||||
pad=(0, 0, 0, params.left_context),
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -621,6 +664,10 @@ def main():
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
@ -658,6 +705,11 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless/decode_stream.py
|
@ -97,7 +97,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
@ -137,6 +137,15 @@ def get_parser():
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--streaming-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to export a streaming model, if the models in exp-dir
|
||||
are streaming model, this should be True.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -162,6 +171,9 @@ def main():
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.streaming_model:
|
||||
assert params.causal_convolution
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/streaming_beam_search.py
|
660
egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
Executable file
660
egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py
Executable file
@ -0,0 +1,660 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage:
|
||||
./pruned_transducer_stateless5/streaming_decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--left-context 32 \
|
||||
--decode-chunk-size 8 \
|
||||
--right-context 0 \
|
||||
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||
--decoding_method greedy_search \
|
||||
--num-decode-streams 200
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from decode_stream import DecodeStream
|
||||
from kaldifeat import Fbank, FbankOptions
|
||||
from lhotse import CutSet
|
||||
from streaming_beam_search import (
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless2/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Supported decoding methods are:
|
||||
greedy_search
|
||||
modified_beam_search
|
||||
fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num_active_paths",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=32,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--right-context",
|
||||
type=int,
|
||||
default=0,
|
||||
help="right context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decode-streams",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="The number of streams that can be decoded parallel.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_chunk(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
decode_streams: List[DecodeStream],
|
||||
) -> List[int]:
|
||||
"""Decode one chunk frames of features for each decode_streams and
|
||||
return the indexes of finished streams in a List.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
decode_streams:
|
||||
A List of DecodeStream, each belonging to a utterance.
|
||||
Returns:
|
||||
Return a List containing which DecodeStreams are finished.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
features = []
|
||||
feature_lens = []
|
||||
states = []
|
||||
|
||||
processed_lens = []
|
||||
|
||||
for stream in decode_streams:
|
||||
feat, feat_len = stream.get_feature_frames(
|
||||
params.decode_chunk_size * params.subsampling_factor
|
||||
)
|
||||
features.append(feat)
|
||||
feature_lens.append(feat_len)
|
||||
states.append(stream.states)
|
||||
processed_lens.append(stream.done_frames)
|
||||
|
||||
feature_lens = torch.tensor(feature_lens, device=device)
|
||||
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
|
||||
|
||||
# if T is less than 7 there will be an error in time reduction layer,
|
||||
# because we subsample features with ((x_len - 1) // 2 - 1) // 2
|
||||
# we plus 2 here because we will cut off one frame on each size of
|
||||
# encoder_embed output as they see invalid paddings. so we need extra 2
|
||||
# frames.
|
||||
tail_length = 7 + (2 + params.right_context) * params.subsampling_factor
|
||||
if features.size(1) < tail_length:
|
||||
pad_length = tail_length - features.size(1)
|
||||
feature_lens += pad_length
|
||||
features = torch.nn.functional.pad(
|
||||
features,
|
||||
(0, 0, 0, pad_length),
|
||||
mode="constant",
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
states = [
|
||||
torch.stack([x[0] for x in states], dim=2),
|
||||
torch.stack([x[1] for x in states], dim=2),
|
||||
]
|
||||
processed_lens = torch.tensor(processed_lens, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens, states = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lens,
|
||||
states=states,
|
||||
left_context=params.left_context,
|
||||
right_context=params.right_context,
|
||||
processed_lens=processed_lens,
|
||||
)
|
||||
|
||||
encoder_out = model.joiner.encoder_proj(encoder_out)
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
greedy_search(
|
||||
model=model, encoder_out=encoder_out, streams=decode_streams
|
||||
)
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
processed_lens = processed_lens + encoder_out_lens
|
||||
fast_beam_search_one_best(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
processed_lens=processed_lens,
|
||||
streams=decode_streams,
|
||||
beam=params.beam,
|
||||
max_states=params.max_states,
|
||||
max_contexts=params.max_contexts,
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
modified_beam_search(
|
||||
model=model,
|
||||
streams=decode_streams,
|
||||
encoder_out=encoder_out,
|
||||
num_active_paths=params.num_active_paths,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
|
||||
states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
|
||||
|
||||
finished_streams = []
|
||||
for i in range(len(decode_streams)):
|
||||
decode_streams[i].states = [states[0][i], states[1][i]]
|
||||
decode_streams[i].done_frames += encoder_out_lens[i]
|
||||
if decode_streams[i].done:
|
||||
finished_streams.append(i)
|
||||
|
||||
return finished_streams
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
cuts: CutSet,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
cuts:
|
||||
Lhotse Cutset containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
device = model.device
|
||||
|
||||
opts = FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
log_interval = 50
|
||||
|
||||
decode_results = []
|
||||
# Contain decode streams currently running.
|
||||
decode_streams = []
|
||||
initial_states = model.encoder.get_init_state(
|
||||
params.left_context, device=device
|
||||
)
|
||||
for num, cut in enumerate(cuts):
|
||||
# each utterance has a DecodeStream.
|
||||
decode_stream = DecodeStream(
|
||||
params=params,
|
||||
initial_states=initial_states,
|
||||
decoding_graph=decoding_graph,
|
||||
device=device,
|
||||
)
|
||||
|
||||
audio: np.ndarray = cut.load_audio()
|
||||
# audio.shape: (1, num_samples)
|
||||
assert len(audio.shape) == 2
|
||||
assert audio.shape[0] == 1, "Should be single channel"
|
||||
assert audio.dtype == np.float32, audio.dtype
|
||||
|
||||
# The trained model is using normalized samples
|
||||
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
|
||||
|
||||
samples = torch.from_numpy(audio).squeeze(0)
|
||||
|
||||
fbank = Fbank(opts)
|
||||
feature = fbank(samples.to(device))
|
||||
decode_stream.set_features(feature)
|
||||
decode_stream.ground_truth = cut.supervisions[0].text
|
||||
|
||||
decode_streams.append(decode_stream)
|
||||
|
||||
while len(decode_streams) >= params.num_decode_streams:
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if num % log_interval == 0:
|
||||
logging.info(f"Cuts processed until now is {num}.")
|
||||
|
||||
# decode final chunks of last sequences
|
||||
while len(decode_streams):
|
||||
finished_streams = decode_one_chunk(
|
||||
params=params, model=model, decode_streams=decode_streams
|
||||
)
|
||||
for i in sorted(finished_streams, reverse=True):
|
||||
decode_results.append(
|
||||
(
|
||||
decode_streams[i].ground_truth.split(),
|
||||
sp.decode(decode_streams[i].decoding_result()).split(),
|
||||
)
|
||||
)
|
||||
del decode_streams[i]
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
key = "greedy_search"
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
key = (
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
key = f"num_active_paths_{params.num_active_paths}"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
return {key: decode_results}
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
# for streaming
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
params.suffix += f"-right-context-{params.right_context}"
|
||||
|
||||
# for fast_beam_search
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
# Decoding in streaming requires causal convolution
|
||||
params.causal_convolution = True
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if start >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(
|
||||
params.exp_dir, iteration=-params.iter
|
||||
)[: params.avg + 1]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
decoding_graph = None
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_cuts = [test_clean_cuts, test_other_cuts]
|
||||
|
||||
for test_set, test_cut in zip(test_sets, test_cuts):
|
||||
results_dict = decode_dataset(
|
||||
cuts=test_cut,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -134,6 +134,40 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="How many left context can be seen in chunks when calculating attention.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -408,6 +442,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
num_left_chunks=params.num_left_chunks,
|
||||
causal=params.causal_convolution,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -901,6 +939,11 @@ def run(rank, world_size, args):
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.dynamic_chunk_training:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "dynamic_chunk_training requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
Loading…
x
Reference in New Issue
Block a user