copy pruned4, with streaming mode

This commit is contained in:
yaozengwei 2022-07-31 20:24:34 +08:00
parent b8abf38aca
commit b338471917
3 changed files with 949 additions and 69 deletions

View File

@ -18,7 +18,7 @@
import copy
import math
import warnings
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import torch
from encoder_interface import EncoderInterface
@ -32,7 +32,7 @@ from scaling import (
)
from torch import Tensor, nn
from icefall.utils import make_pad_mask
from icefall.utils import make_pad_mask, subsequent_chunk_mask
class Conformer(EncoderInterface):
@ -48,6 +48,26 @@ class Conformer(EncoderInterface):
layer_dropout (float): layer-dropout rate.
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
you want to train a streaming model, this is expected to be True.
When setting True, it will use a masking strategy to make the attention
see only limited left and right context.
short_chunk_threshold (float): a threshold to determinize the chunk size
to be used in masking training, if the randomly generated chunk size
is greater than ``max_len * short_chunk_threshold`` (max_len is the
max sequence length of current batch) then it will use
full context in training (i.e. with chunk size equals to max_len).
This will be used only when dynamic_chunk_training is True.
short_chunk_size (int): see docs above, if the randomly generated chunk
size equals to or less than ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context (in chunks) attention can see, the
chunk size is decided by short_chunk_threshold and short_chunk_size.
A minus value means seeing full left context.
This also will be used only when dynamic_chunk_training is True.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training.
"""
def __init__(
@ -61,6 +81,11 @@ class Conformer(EncoderInterface):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
dynamic_chunk_training: bool = False,
short_chunk_threshold: float = 0.75,
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
) -> None:
super(Conformer, self).__init__()
@ -76,6 +101,15 @@ class Conformer(EncoderInterface):
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_layers = num_encoder_layers
self.d_model = d_model
self.cnn_module_kernel = cnn_module_kernel
self.causal = causal
self.dynamic_chunk_training = dynamic_chunk_training
self.short_chunk_threshold = short_chunk_threshold
self.short_chunk_size = short_chunk_size
self.num_left_chunks = num_left_chunks
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
@ -85,8 +119,10 @@ class Conformer(EncoderInterface):
dropout,
layer_dropout,
cnn_module_kernel,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self._init_state: List[torch.Tensor] = [torch.empty(0)]
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@ -120,15 +156,249 @@ class Conformer(EncoderInterface):
lengths = (((x_lens - 1) >> 1) - 1) >> 1
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
src_key_padding_mask = make_pad_mask(lengths)
if self.dynamic_chunk_training:
assert (
self.causal
), "Causal convolution is required for streaming conformer."
max_len = x.size(0)
chunk_size = torch.randint(1, max_len, (1,)).item()
if chunk_size > (max_len * self.short_chunk_threshold):
chunk_size = max_len
else:
chunk_size = chunk_size % self.short_chunk_size + 1
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
else:
x = self.encoder(
x,
pos_emb,
mask=None,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
@torch.jit.export
def get_init_state(
self, left_context: int, device: torch.device
) -> List[torch.Tensor]:
"""Return the initial cache state of the model.
Args:
left_context: The left context size (in frames after subsampling).
Returns:
Return the initial state of the model, it is a list containing two
tensors, the first one is the cache for attentions which has a shape
of (num_encoder_layers, left_context, encoder_dim), the second one
is the cache of conv_modules which has a shape of
(num_encoder_layers, cnn_module_kernel - 1, encoder_dim).
NOTE: the returned tensors are on the given device.
"""
if (
len(self._init_state) == 2
and self._init_state[0].size(1) == left_context
):
# Note: It is OK to share the init state as it is
# not going to be modified by the model
return self._init_state
init_states: List[torch.Tensor] = [
torch.zeros(
(
self.encoder_layers,
left_context,
self.d_model,
),
device=device,
),
torch.zeros(
(
self.encoder_layers,
self.cnn_module_kernel - 1,
self.d_model,
),
device=device,
),
]
self._init_state = init_states
return init_states
@torch.jit.export
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
states: Optional[List[Tensor]] = None,
processed_lens: Optional[Tensor] = None,
left_context: int = 64,
right_context: int = 4,
chunk_size: int = 16,
simulate_streaming: bool = False,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
processed_lens:
How many frames (after subsampling) have been processed for each sequence.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time
When using simulate_streaming.
warmup:
A floating point value that gradually increases from 0 throughout
training; when it is >= 1.0 we are "fully warmed up". It is used
to turn modules on sequentially.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated states including the information
of current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4!
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
#
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
lengths = (((x_lens - 1) >> 1) - 1) >> 1
if not simulate_streaming:
assert states is not None
assert processed_lens is not None
assert (
len(states) == 2
and states[0].shape
== (self.encoder_layers, left_context, x.size(0), self.d_model)
and states[1].shape
== (
self.encoder_layers,
self.cnn_module_kernel - 1,
x.size(0),
self.d_model,
)
), f"""The length of states MUST be equal to 2, and the shape of
first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)},
given {states[0].shape}. the shape of second element should be
{(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)},
given {states[1].shape}."""
lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output
src_key_padding_mask = make_pad_mask(lengths)
processed_mask = torch.arange(left_context, device=x.device).expand(
x.size(0), left_context
)
processed_lens = processed_lens.view(x.size(0), 1)
processed_mask = (processed_lens <= processed_mask).flip(1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
embed = self.encoder_embed(x)
# cut off 1 frame on each size of embed as they see the padding
# value which causes a training and decoding mismatch.
embed = embed[:, 1:-1, :]
embed, pos_enc = self.encoder_pos(embed, left_context)
embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
x, states = self.encoder.chunk_forward(
embed,
pos_enc,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
states=states,
left_context=left_context,
right_context=right_context,
) # (T, B, F)
if right_context > 0:
x = x[0:-right_context, ...]
lengths -= right_context
else:
assert states is None
states = [] # just to make torch.script.jit happy
# this branch simulates streaming decoding using mask as we are
# using in training time.
src_key_padding_mask = make_pad_mask(lengths)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
assert x.size(0) == lengths.max().item()
num_left_chunks = -1
if left_context >= 0:
assert left_context % chunk_size == 0
num_left_chunks = left_context // chunk_size
mask = ~subsequent_chunk_mask(
size=x.size(0),
chunk_size=chunk_size,
num_left_chunks=num_left_chunks,
device=x.device,
)
x = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
return x, lengths, states
class ConformerEncoderLayer(nn.Module):
@ -142,6 +412,8 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -158,6 +430,7 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
causal: bool = False,
) -> None:
super(ConformerEncoderLayer, self).__init__()
@ -185,7 +458,9 @@ class ConformerEncoderLayer(nn.Module):
ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
)
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
self.conv_module = ConvolutionModule(
d_model, cnn_module_kernel, causal=causal
)
self.norm_final = BasicNorm(d_model)
@ -214,7 +489,6 @@ class ConformerEncoderLayer(nn.Module):
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
@ -248,10 +522,12 @@ class ConformerEncoderLayer(nn.Module):
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = src + self.dropout(src_att)
# convolution module
src = src + self.dropout(self.conv_module(src))
conv, _ = self.conv_module(src)
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
@ -263,6 +539,100 @@ class ConformerEncoderLayer(nn.Module):
return src
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (left_context, src.size(1), src.size(2))
# macaron style feed forward module
src = src + self.dropout(self.feed_forward_macaron(src))
# We put the attention cache this level (i.e. before linear transformation)
# to save memory consumption, when decoding in streaming fashion, the
# batch size would be thousands (for 32GB machine), if we cache key & val
# separately, it needs extra several GB memory.
# TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val
# separately) if needed.
key = torch.cat([states[0], src], dim=0)
val = key
if right_context > 0:
states[0] = key[
-(left_context + right_context) : -right_context, ... # noqa
]
else:
states[0] = key[-left_context:, ...]
# multi-headed self-attention module
src_att = self.self_attn(
src,
key,
val,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
left_context=left_context,
)[0]
src = src + self.dropout(src_att)
# convolution module
conv, conv_cache = self.conv_module(src, states[1], right_context)
states[1] = conv_cache
src = src + self.dropout(conv)
# feed forward module
src = src + self.dropout(self.feed_forward(src))
src = self.norm_final(self.balancer(src))
return src, states
class ConformerEncoder(nn.Module):
r"""ConformerEncoder is a stack of N encoder layers
@ -301,6 +671,8 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
Shape:
src: (S, N, E).
@ -312,7 +684,7 @@ class ConformerEncoder(nn.Module):
"""
output = src
for i, mod in enumerate(self.layers):
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
pos_emb,
@ -323,6 +695,79 @@ class ConformerEncoder(nn.Module):
return output
@torch.jit.export
def chunk_forward(
self,
src: Tensor,
pos_emb: Tensor,
states: List[Tensor],
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
left_context: int = 0,
right_context: int = 0,
) -> Tuple[Tensor, List[Tensor]]:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
states:
The decode states for previous frames which contains the cached data.
It has two elements, the first element is the attn_cache which has
a shape of (encoder_layers, left_context, batch, attention_dim),
the second element is the conv_cache which has a shape of
(encoder_layers, cnn_module_kernel-1, batch, conv_dim).
Note: states will be modified in this function.
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
warmup: controls selective bypass of of layers; if < 1.0, we will
bypass layers more frequently.
left_context:
How many previous frames the attention can see in current chunk.
Note: It's not that each individual frame has `left_context` frames
of left context, some have more.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Shape:
src: (S, N, E).
pos_emb: (N, 2*(S+left_context)-1, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
"""
assert not self.training
assert len(states) == 2
assert states[0].shape == (
self.num_layers,
left_context,
src.size(1),
src.size(2),
)
assert states[1].size(0) == self.num_layers
output = src
for layer_index, mod in enumerate(self.layers):
cache = [states[0][layer_index], states[1][layer_index]]
output, cache = mod.chunk_forward(
output,
pos_emb,
states=cache,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
left_context=left_context,
right_context=right_context,
)
states[0][layer_index] = cache[0]
states[1][layer_index] = cache[1]
return output, states
class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.
@ -347,24 +792,25 @@ class RelPositionalEncoding(torch.nn.Module):
self.pe = None
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
def extend_pe(self, x: Tensor) -> None:
def extend_pe(self, x: Tensor, left_context: int = 0) -> None:
"""Reset the positional encodings."""
x_size_1 = x.size(1) + left_context
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x_size_1 * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# Suppose `i` means to the position of query vecotr and `j` means the
# Suppose `i` means to the position of query vector and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x_size_1, self.d_model)
pe_negative = torch.zeros(x_size_1, self.d_model)
position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -382,22 +828,30 @@ class RelPositionalEncoding(torch.nn.Module):
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self,
x: torch.Tensor,
left_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
"""
self.extend_pe(x)
self.extend_pe(x, left_context)
x_size_1 = x.size(1) + left_context
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x_size_1
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
]
@ -405,7 +859,7 @@ class RelPositionalEncoding(torch.nn.Module):
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with simplified relative position encoding
r"""Multi-Head Attention layer with relative position encoding
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
@ -441,7 +895,24 @@ class RelPositionMultiheadAttention(nn.Module):
)
# linear transformation for positional encoding.
self.linear_pos = ScaledLinear(embed_dim, num_heads, bias=True)
self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False)
# these two learnable bias are used in matrix c and matrix d
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach())
self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach())
self._reset_parameters()
def _pos_bias_u(self):
return self.pos_bias_u * self.pos_bias_u_scale.exp()
def _pos_bias_v(self):
return self.pos_bias_v * self.pos_bias_v_scale.exp()
def _reset_parameters(self) -> None:
nn.init.normal_(self.pos_bias_u, std=0.01)
nn.init.normal_(self.pos_bias_v, std=0.01)
def forward(
self,
@ -452,6 +923,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -465,6 +937,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
- Inputs:
@ -510,14 +985,18 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask=key_padding_mask,
need_weights=need_weights,
attn_mask=attn_mask,
left_context=left_context,
)
def rel_shift(self, x: Tensor) -> Tensor:
def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor:
"""Compute relative positional encoding.
Args:
x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns:
Tensor: tensor of shape (batch, head, time1, time2)
@ -525,14 +1004,19 @@ class RelPositionMultiheadAttention(nn.Module):
the key, while time1 is for the query).
"""
(batch_size, num_heads, time1, n) = x.shape
assert n == 2 * time1 - 1
time2 = time1 + left_context
assert (
n == left_context + 2 * time1 - 1
), f"{n} == {left_context} + 2 * {time1} - 1"
# Note: TorchScript requires explicit arg for stride()
batch_stride = x.stride(0)
head_stride = x.stride(1)
time1_stride = x.stride(2)
n_stride = x.stride(3)
return x.as_strided(
(batch_size, num_heads, time1, time1),
(batch_size, num_heads, time1, time2),
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
storage_offset=n_stride * (time1 - 1),
)
@ -554,6 +1038,7 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
left_context: int = 0,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
@ -571,6 +1056,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context (in frames) used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Inputs:
@ -729,23 +1217,35 @@ class RelPositionMultiheadAttention(nn.Module):
key_padding_mask.size(1), src_len
)
q = q.permute(1, 2, 0, 3) # (batch, head, time1, d_k)
q = q.transpose(0, 1) # (batch, time1, head, d_k)
pos_emb_bsz = pos_emb.size(0)
assert pos_emb_bsz in (1, bsz) # actually it is 1
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
# (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1)
p = p.permute(0, 2, 3, 1)
q_with_bias_u = (q + self._pos_bias_u()).transpose(
1, 2
) # (batch, head, time1, d_k)
q_with_bias_v = (q + self._pos_bias_v()).transpose(
1, 2
) # (batch, head, time1, d_k)
# compute attention score
# first compute matrix a and matrix c
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
matrix_ac = torch.matmul(q, k) # (batch, head, time1, time2)
matrix_ac = torch.matmul(
q_with_bias_u, k
) # (batch, head, time1, time2)
# compute matrix b and matrix d
pos_emb = self.linear_pos(pos_emb) # (1, 2*time1-1, head)
matrix_bd = (
pos_emb.transpose(1, 2).unsqueeze(2).repeat(1, 1, tgt_len, 1)
) # (1, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd) # (1, head, time1, time2)
matrix_bd = torch.matmul(
q_with_bias_v, p
) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd, left_context)
attn_output_weights = (
matrix_ac + matrix_bd
@ -780,6 +1280,39 @@ class RelPositionMultiheadAttention(nn.Module):
)
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
# If we are using dynamic_chunk_training and setting a limited
# num_left_chunks, the attention may only see the padding values which
# will also be masked out by `key_padding_mask`, at this circumstances,
# the whole column of `attn_output_weights` will be `-inf`
# (i.e. be `nan` after softmax), so, we fill `0.0` at the masking
# positions to avoid invalid loss value below.
if (
attn_mask is not None
and attn_mask.dtype == torch.bool
and key_padding_mask is not None
):
if attn_mask.size(0) != 1:
attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len)
combined_mask = attn_mask | key_padding_mask.unsqueeze(
1
).unsqueeze(2)
else:
# attn_mask.shape == (1, tgt_len, src_len)
combined_mask = attn_mask.unsqueeze(
0
) | key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_output_weights = attn_output_weights.view(
bsz, num_heads, tgt_len, src_len
)
attn_output_weights = attn_output_weights.masked_fill(
combined_mask, 0.0
)
attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, src_len
)
attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@ -813,16 +1346,21 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution.
"""
def __init__(
self, channels: int, kernel_size: int, bias: bool = True
self,
channels: int,
kernel_size: int,
bias: bool = True,
causal: bool = False,
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.causal = causal
self.pointwise_conv1 = ScaledConv1d(
channels,
@ -850,12 +1388,17 @@ class ConvolutionModule(nn.Module):
channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0
)
self.lorder = kernel_size - 1
padding = (kernel_size - 1) // 2
if self.causal:
padding = 0
self.depthwise_conv = ScaledConv1d(
channels,
channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
padding=padding,
groups=channels,
bias=bias,
)
@ -876,14 +1419,28 @@ class ConvolutionModule(nn.Module):
initial_scale=0.25,
)
def forward(self, x: Tensor) -> Tensor:
def forward(
self,
x: Tensor,
cache: Optional[Tensor] = None,
right_context: int = 0,
) -> Tuple[Tensor, Tensor]:
"""Compute convolution module.
Args:
x: Input tensor (#time, batch, channels).
cache: The cache of depthwise_conv, only used in real streaming
decoding.
right_context:
How many future frames the attention can see in current chunk.
Note: It's not that each individual frame has `right_context` frames
of right context, some have more.
Returns:
Tensor: Output tensor (#time, batch, channels).
If cache is None return the output tensor (#time, batch, channels).
If cache is not None, return a tuple of Tensor, the first one is
the output tensor (#time, batch, channels), the second one is the
new cache for next chunk (#kernel_size - 1, batch, channels).
"""
# exchange the temporal dimension and the feature dimension
@ -896,6 +1453,26 @@ class ConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
# 1D Depthwise Conv
if self.causal and self.lorder > 0:
if cache is None:
# Make depthwise_conv causal by
# manualy padding self.lorder zeros to the left
x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
else:
assert (
not self.training
), "Cache should be None in training time"
assert cache.size(0) == self.lorder
x = torch.cat([cache.permute(1, 2, 0), x], dim=2)
if right_context > 0:
cache = x.permute(2, 0, 1)[
-(self.lorder + right_context) : ( # noqa
-right_context
),
...,
]
else:
cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa
x = self.depthwise_conv(x)
x = self.deriv_balancer2(x)
@ -903,7 +1480,11 @@ class ConvolutionModule(nn.Module):
x = self.pointwise_conv2(x) # (batch, channel, time)
return x.permute(2, 0, 1)
# torch.jit.script requires return types be the same as annotated above
if cache is None:
cache = torch.empty(0)
return x.permute(2, 0, 1), cache
class Conv2dSubsampling(nn.Module):

View File

@ -44,21 +44,74 @@ Usage:
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless3/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless4/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless4/decode.py \
--epoch 30 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless4/exp \
--max-duration 600 \
--decoding-method greedy_search
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
import argparse
import logging
import math
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -70,12 +123,15 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from train import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@ -83,6 +139,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@ -91,6 +148,8 @@ from icefall.utils import (
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
@ -150,6 +209,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -159,6 +225,11 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
If you use fast_beam_search_nbest_LG, you have to specify
`--lang-dir`, which should contain `LG.pt`.
""",
)
@ -174,27 +245,42 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is fast_beam_search,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
@ -212,6 +298,48 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
add_model_arguments(parser)
return parser
@ -220,6 +348,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -243,9 +372,12 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -260,9 +392,26 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
feature_lens += params.left_context
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
)
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -277,6 +426,49 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
@ -324,14 +516,17 @@ def decode_one_batch(
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -341,6 +536,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -354,9 +550,12 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -374,7 +573,7 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 10
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -385,6 +584,7 @@ def decode_dataset(
model=model,
sp=sp,
decoding_graph=decoding_graph,
word_table=word_table,
batch=batch,
)
@ -466,6 +666,9 @@ def main():
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
@ -475,10 +678,19 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
@ -507,6 +719,11 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -592,10 +809,24 @@ def main():
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
word_table = lexicon.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -617,6 +848,7 @@ def main():
params=params,
model=model,
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
)

View File

@ -41,8 +41,20 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \
--max-duration 550
"""
# train a streaming model
./pruned_transducer_stateless4/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
"""
import argparse
import copy
@ -88,6 +100,42 @@ LRSchedulerType = Union[
]
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -281,6 +329,8 @@ def get_parser():
help="Whether to use half precision training.",
)
add_model_arguments(parser)
return parser
@ -367,6 +417,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
)
return encoder
@ -603,6 +657,15 @@ def compute_loss(
(feature_lens // params.subsampling_factor).sum().item()
)
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = feature_lens.sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
@ -847,6 +910,11 @@ def run(rank, world_size, args):
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params)
logging.info("About to create model")
@ -932,6 +1000,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@ -992,6 +1061,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@ -1002,9 +1072,6 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
@ -1012,7 +1079,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
warmup=0.0,
warmup=warmup,
)
loss.backward()
optimizer.step()