mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 12:02:21 +00:00

Fix bugs

This commit is contained in:
parent c6c3750cab
commit a75f75bbad
@@ -16,19 +16,23 @@
# limitations under the License.

import copy
import random
import math
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, List
import torch_flow_sampling

import torch
from torch import Tensor, nn
from subsampling import Conv2dSubsampling, VggSubsampling
from transformer import Supervisions, TransformerEncoderLayer, TransformerDecoderLayer, encoder_padding_mask, \
LabelSmoothingLoss, PositionalEncoding, pad_sequence, add_sos, add_eos, decoder_padding_mask
LabelSmoothingLoss, PositionalEncoding, pad_sequence, add_sos, add_eos, decoder_padding_mask, \
generate_square_subsequent_mask


class ConformerTrunk(nn.Module):
def __init__(num_features: int,
def __init__(self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
@@ -37,6 +41,7 @@ class ConformerTrunk(nn.Module):
dropout: float = 0.1,
cnn_module_kernel: int = 31,
use_feat_batchnorm: bool = True) -> None:
super(ConformerTrunk, self).__init__()
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)

@@ -62,12 +67,11 @@ class ConformerTrunk(nn.Module):
cnn_module_kernel,
)

self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

self.encoder = ConformerEncoder(encoder_layer, num_layers)

def forward(
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
@@ -78,12 +82,10 @@ class ConformerTrunk(nn.Module):
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)

Returns:
Return a tuple containing 2 tensors:
- Encoder output with shape [T, N, C]. It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is [N, T].
Return (output, pos_emb, mask), where:
output: The output embedding, of shape (T, N, C).
pos_emb: The positional embedding (this will be used by ctc_encoder forward).
mask: The output padding mask, a Tensor of bool, of shape [N, T].
It is None if `supervision` is None.
"""
if hasattr(self, 'feat_batchnorm'):
@@ -92,13 +94,15 @@ class ConformerTrunk(nn.Module):
x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C]

x = self.feat_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, key_padding_mask=mask) # (T, N, C)

return x, mask
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervision)

mask = mask.to(x.device) if mask is not None else None
x = self.encoder(x, pos_emb=pos_emb, key_padding_mask=mask) # (T, N, C)

return x, pos_emb, mask

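# Illustrative usage sketch (not part of the diff above). It assumes this file is
# importable as `conformer` and that the constructor defaults shown above apply;
# shapes follow the docstring of the new ConformerTrunk.forward().
import torch
from conformer import ConformerTrunk  # assumed module name

trunk = ConformerTrunk(num_features=40)   # other arguments left at their defaults
feats = torch.randn(8, 100, 40)           # (N, T, C) acoustic features
output, pos_emb, mask = trunk(feats)      # supervision=None, so mask is None
# output: (T', N, d_model) after subsampling; pos_emb is the relative positional
# embedding to hand to the CTC encoder; mask is None because no supervision was given.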
class BidirectionalConformer(nn.Module):
@@ -178,25 +182,26 @@ class BidirectionalConformer(nn.Module):
in the discrete bottleneck
"""
def __init__(
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_trunk_encoder_layers: int = 12,
num_ctc_encoder_layers: int = 4,
num_decoder_layers: int = 6,
num_reverse_encoder_layers: int = 4,
num_reverse_decoder_layers: int = 4,
num_self_predictor_layers: int = 3,
bypass_bottleneck: bool = True,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
is_bpe: bool = False,
use_feat_batchnorm: bool = True,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 4
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_trunk_encoder_layers: int = 12,
num_ctc_encoder_layers: int = 4,
num_decoder_layers: int = 6,
num_reverse_encoder_layers: int = 4,
num_reverse_decoder_layers: int = 4,
num_self_predictor_layers: int = 3,
bypass_bottleneck: bool = True,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
is_bpe: bool = False,
use_feat_batchnorm: bool = True,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 4
) -> None:
super(BidirectionalConformer, self).__init__()

@@ -236,7 +241,7 @@ class BidirectionalConformer(nn.Module):
self.token_embed_scale = d_model ** 0.5
self.token_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model,
_weight=torch.randn(self.num_classes, d_model) * (1 / self.embed_scale)
_weight=torch.randn(self.decoder_num_class, d_model) * (1 / self.token_embed_scale)
)

decoder_layer = TransformerDecoderLayer(
@@ -315,8 +320,7 @@ class BidirectionalConformer(nn.Module):
if num_self_predictor_layers > 0:
encoder_layer = SimpleCausalEncoderLayer(d_model,
dropout=dropout)
self.self_predictor_encoder = simple_causal_encoder(encoder_layer,
num_self_predictor_layers)
self.self_predictor_encoder = encoder_layer


self.discrete_bottleneck = DiscreteBottleneck(
@@ -326,8 +330,8 @@ class BidirectionalConformer(nn.Module):


def forward(self, x: Tensor, supervision: Optional[Supervisions],
need_softmax: bool = True) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
def forward(self, x: Tensor, supervision: Optional[Supervisions] = None,
need_softmax: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""
Forward function that "encodes" the features.

@@ -343,26 +347,29 @@ class BidirectionalConformer(nn.Module):
If true, the last output ("softmax") will be computed. This can be useful
in the reverse model, but only necessary if straight_through_scale != 1.0.

Returns: (memory, bn_memory, sampled, softmax, key_padding_mask), where:
Returns: (memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask), where:

`memory` is a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
is actually a subsampled form of the num_frames of the input `x`.
If self.bypass_bottleneck, it will be taken before the discrete
bottleneck; otherwise, from after.
`bn_memory` is the same shape as `memory`, but comes after the discrete bottleneck
regardless of the value of self.bypass_bottleneck.
`sampled` is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
`softmax` is a "soft" version of `sampled`. Will only be returned if need_softmax == True;
memory: a Tensor of shape [T, N, E] i.e. [T, batch_size, embedding_dim] where T
is actually a subsampled form of the num_frames of the input `x`.
If self.bypass_bottleneck, it will be taken before the discrete
bottleneck; otherwise, from after.
bn_memory: The same shape as `memory`, but comes after the discrete bottleneck
regardless of the value of self.bypass_bottleneck.
pos_emb: The relative positional embedding; will be given to ctc_encoder_forward()
sampled: a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
softmax: a "soft" version of `sampled`. Will only be returned if need_softmax == True;
else will be None.
key_padding_mask: The padding mask for the "memory" output, a Tensor of bool of
shape [N, T] (only if supervision was supplied, else None).
"""
encoder_output, memory_key_padding_mask = self.trunk(x, supervision)
encoder_output, pos_emb, memory_key_padding_mask = self.trunk(x, supervision)

bn_memory, sampled, softmax = self.discrete_bottleneck(encoder_output)

memory = encoder_output if self.bypass_bottleneck else bn_memory

return (memory, bn_memory, sampled, softmax, memory_key_padding_mask)
return (memory, bn_memory, pos_emb, sampled, softmax, memory_key_padding_mask)

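# Illustrative sketch (not part of the diff above) of how the six return values of
# the new BidirectionalConformer.forward() are consumed; it mirrors the test code
# at the bottom of this file and assumes the file is importable as `conformer`.
import torch
from conformer import BidirectionalConformer  # assumed module name

model = BidirectionalConformer(num_features=40, num_classes=1000)
feats = torch.randn(4, 50, 40)  # (N, T, C)
memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask = model(feats)

# `pos_emb` is threaded through to the CTC branch together with `memory`:
ctc_log_probs = model.ctc_encoder_forward(memory, pos_emb, key_padding_mask)
# ctc_log_probs: (N, T', num_classes), normalized log-probabilities.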
def decoder_forward(
self,
@@ -373,11 +380,13 @@ class BidirectionalConformer(nn.Module):
eos_id: int,
) -> torch.Tensor:
"""
Compute the decoder loss function (given a particular list of hypotheses).

Args:
memory:
It's the first output of forward(), with shape [T, N, E]
The first output of forward(), with shape [T, N, E]
memory_key_padding_mask:
The padding mask from forward()
The padding mask from forward(), a tensor of bool with shape [N, T]
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
@@ -390,6 +399,7 @@ class BidirectionalConformer(nn.Module):
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""

ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
@@ -426,6 +436,7 @@ class BidirectionalConformer(nn.Module):
tgt_mask=tgt_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)

pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)

@@ -436,6 +447,7 @@ class BidirectionalConformer(nn.Module):
def ctc_encoder_forward(
self,
memory: torch.Tensor,
pos_emb: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
) -> torch.Tensor:
"""
@@ -444,17 +456,19 @@ class BidirectionalConformer(nn.Module):

Args:
memory:
It's the output of forward(), with shape [T, N, E]
It's the output of forward(), with shape (T, N, E)
pos_emb:
Relative positional embedding tensor: (N, 2*T-1, E)
memory_key_padding_mask:
The padding mask from forward()
The padding mask from forward(), a tensor of bool of shape (N, T)

Returns:
A Tensor with shape [N, T, C] where C is the number of classes
(e.g. number of phones or word pieces). Contains normalized
log-probabilities.
"""

x = self.ctc_encoder(memory,
pos_emb,
key_padding_mask=memory_key_padding_mask)
x = self.ctc_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@@ -470,10 +484,10 @@ class BidirectionalConformer(nn.Module):
softmax: Optional[torch.Tensor],
reverse_gradient: bool = True) -> Tensor:
"""
Returns the total log-prob of the the
labels sampled in the discrete bottleneck layer, as predicted using a relatively
simple model from previous frames sampled from the bottleneck layer.
[Appears on the denominator of an expressin for mutual information].
Returns the total log-prob of the the labels sampled in the discrete
bottleneck layer, as predicted using a relatively simple model that
predicts from previous frames sampled from the bottleneck layer.
[Appears on the denominator of an expression for mutual information].

Args:
memory_shifted:
@@ -482,21 +496,25 @@ class BidirectionalConformer(nn.Module):
(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)
memory_key_padding_mask:
The padding mask from the encoder.
sampled: is a Tensor of shape [T, N, C] where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse' model.
The padding mask from the encoder, of shape [N, T], boolean, True
for masked locations.
sampled: sampled and interpolated one-hot values, as a Tensor of shape [T, N, C]
where C corresponds to `discrete_bottleneck_tot_classes`
as given to the constructor. This will be needed for the 'reverse'
model.
softmax: is a "soft" version of `sampled`; if None, will default to `sampled`.
reverse_gradient: will likely be true. If true, the gradient is reversed twice
in this computation, so that we train predictors with the correct
gradient, i.e. to predict, not anti-predict (since the return value
of this function will appear with positive, not negative, sign in the
loss function, so will be minimized).
The gradient w.r.t. the non-self inputs to this function, though (i.e.
memory_shifted, sampled, softmax) will not be reversed, though.
Returns:
A scalar tensor, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""


if reverse_gradient:
# Reversing gradient for memory_shifted puts the gradient back into
# the correct sign; we reversed it in
@@ -506,6 +524,8 @@ class BidirectionalConformer(nn.Module):
# what happens to the gradients).
memory_shifted = ReverseGrad.apply(memory_shifted)

# no mask is needed for self_predictor_encoder; its CNN
# layer uses left-padding only, making it causal.
predictor = self.self_predictor_encoder(memory_shifted)

prob = self.discrete_bottleneck.compute_prob(predictor,
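# Illustrative sketch (not part of the diff above) of the gradient-reversal trick
# described in the docstring and comments of this hunk; a generic implementation,
# not necessarily identical to the ReverseGrad class defined later in this file.
import torch

class _ReverseGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return x.view_as(x)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return -grad_output  # negate the gradient on the way back

x = torch.randn(3, requires_grad=True)
_ReverseGrad.apply(x).sum().backward()
print(x.grad)  # tensor([-1., -1., -1.]): the gradient arrives with its sign flipped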
@@ -566,15 +586,21 @@ class BidirectionalConformer(nn.Module):
tokens_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=padding_id)


# Let S be the length of the longest sentence (padded)
token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale # (N, S) -> (N, S, C)
# add absolute position-encoding information
token_embedding = self.abs_pos(token_embedding)

token_embedding = token_embedding.permute(1, 0, 2) # (N, S, C) -> (S, N, C)

token_memory = self.reverse_encoder(token_embedding,
src_key_padding_mask=tokens_key_padding_mask)
# token_memory is of shape (S, N, C), if S is length of token sequence.

T = memory.shape[0]
# the targets, here, are the hidden discrete symbols we are predicting
tgt_mask = generate_square_subsequent_mask(T, device=memory.device)

token_embedding = self.token_embed(tokens_padded) * self.token_embed_scale
token_memory = self.reverse_encoder(token_embedding,
src_key_padding_mask=tokens_key_padding_mask)
# tokens_encoded is of shape (S, N, C), if S is length of token sequence.

hidden_predictor = self.reverse_decoder(
tgt=memory_shifted,
memory=token_memory,
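# Illustrative sketch (not part of the diff above) of the square "subsequent"
# (causal) mask that generate_square_subsequent_mask() produces; the real helper
# is imported from transformer.py and may differ in dtype or convention.
import torch

def _square_subsequent_mask(size: int) -> torch.Tensor:
    # Future positions (column > row) get -inf so attention cannot look ahead.
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

print(_square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])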
@@ -584,25 +610,13 @@ class BidirectionalConformer(nn.Module):

total_prob = self.discrete_bottleneck.compute_prob(
hidden_predictor,
TODO, # HERE
)
sampled,
softmax,
memory_key_padding_mask)

tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
# TODO: consider using a label-smoothed loss.
return total_prob

decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)

return decoder_loss

class SimpleCausalEncoderLayer(nn.Module):
"""
@@ -650,129 +664,6 @@ class SimpleCausalEncoderLayer(nn.Module):
src = self.norm_final(src)
return src

# for search: SimpleCausalEncoder
def simple_causal_encoder(encoder_layer: nn.Module,
num_layers: int):
return torch.nn.Sequential([copy.deepcopy(encoder_leyer) for _ in range(num_layers)])


class DiscreteBottleneckConformer(Transformer):
"""
Args:
num_features (int): Number of input features
num_classes (int): Number of output classes
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
num_decoder_layers (int): number of decoder layers
dropout (float): dropout rate
cnn_module_kernel (int): Kernel size of convolution module
discrete_bottleneck_pos (int): position in the encoder at which to place
the discrete bottleneck (this many encoder layers will
precede it)

"""

def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False,
discrete_bottleneck_pos: int = 8,
discrete_bottleneck_tot_classes: int = 512,
discrete_bottleneck_num_groups: int = 2
) -> None:
super(DiscreteBottleneckConformer, self).__init__(
num_features=num_features,
num_classes=num_classes,
subsampling_factor=subsampling_factor,
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm,
)

self.encoder_pos = RelPositionalEncoding(d_model, dropout)

encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
dropout,
cnn_module_kernel,
normalize_before,
is_espnet_structure,
)

discrete_bottleneck = DiscreteBottleneck(dim=d_model,
tot_classes=discrete_bottleneck_tot_classes,
num_groups=discrete_bottleneck_num_groups)

self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
discrete_bottleneck=discrete_bottleneck,
discrete_bottleneck_pos=discrete_bottleneck_pos)


self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model)
else:
# Note: TorchScript detects that self.after_norm could be used inside forward()
# and throws an error without this change.
self.after_norm = identity

def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
"""
Args:
x:
The model input. Its shape is [N, T, C].
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute encoder padding mask, which is used as memory key padding
mask for the decoder.

Returns:
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
Tensor: Mask tensor of dimension (batch_size, input_length)
"""
x = self.feat_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
mask = encoder_padding_mask(x.size(0), supervisions)
if mask is not None:
mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)

return x, mask


class ReverseGrad(torch.autograd.Function):
def apply(ctx, x):
@@ -878,9 +769,9 @@ class DiscreteBottleneck(nn.Module):
# [ True, True, True, True]])
self.register_buffer('pred_cross_mask',
((torch.arange(d) // c).unsqueeze(1) >= (torch.arange(d) // c).unsqueeze(0)))
self.reset_parameters()
self._reset_parameters()

def reset_parameters_(self):
def _reset_parameters(self):
if hasattr(self, 'pred_cross'):
torch.nn.init.kaiming_uniform_(self.pred_cross, a=math.sqrt(5))

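# Illustrative sketch (not part of the diff above) of how `pred_cross_mask` is
# built: with d = tot_classes and c = classes_per_group, the comparison yields a
# block lower-triangular mask, so a group's logits may use its own and earlier
# groups only. Small numbers are used here for readability.
import torch

num_groups, classes_per_group = 2, 2
d = num_groups * classes_per_group
c = classes_per_group
group = torch.arange(d) // c                      # tensor([0, 0, 1, 1])
pred_cross_mask = group.unsqueeze(1) >= group.unsqueeze(0)
print(pred_cross_mask)
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])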
@@ -967,7 +858,8 @@ class DiscreteBottleneck(nn.Module):
This will be useful in computing a loss function that has
a likelihood term with negative sign (i.e. the self-prediction).
We'll later need negate the gradient one more more time
where we give the input to whatever module generated 'x'.
where we give the input to the prediction module that
generated 'x'.

Returns a scalar Tensor represnting the total probability.
"""
@@ -980,6 +872,8 @@ class DiscreteBottleneck(nn.Module):

logprobs = self.pred_linear(x)

# Add "cross-terms" to logprobs; this is a regression that uses earlier
# groups to predict later groups
if self.num_groups > 1:
pred_cross = self.pred_cross * self.pred_cross_mask
t = self.tot_classes
@@ -999,11 +893,12 @@ class DiscreteBottleneck(nn.Module):
logprobs = logprobs.reshape(S, N, self.num_groups, self.classes_per_group)
# Normalize the log-probs (so they sum to one)
logprobs = torch.nn.functional.logsoftmax(logprobs, dim=-1)
logprobs = logprobs.reshape(S, N, C)

if padding_mask is not None:
assert padding_mask.dtype == torch.bool and padding_mask.shape == (N, S)
padding_mask = torch.logical_not(padding_mask).transpose(0, 1).unsqueeze(-1)
# padding_mask.shape == (S, N, E)
assert padding_mask.shape == (S, N, 1)
tot_prob = (logprobs * softmax * padding_mask).sum()
else:
tot_prob = (logprobs * softmax).sum()
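# Illustrative sketch (not part of the diff above) of the group-wise normalization
# performed here: reshape to (S, N, num_groups, classes_per_group), apply log-softmax
# over the last dim so each group sums to one, then reshape back. (The standard
# PyTorch spelling of the function is torch.nn.functional.log_softmax.)
import torch

S, N, num_groups, classes_per_group = 5, 3, 4, 8
C = num_groups * classes_per_group  # tot_classes

logits = torch.randn(S, N, C)
logprobs = logits.reshape(S, N, num_groups, classes_per_group)
logprobs = torch.nn.functional.log_softmax(logprobs, dim=-1)
logprobs = logprobs.reshape(S, N, C)

# Each group of classes_per_group entries is separately normalized:
per_group = logprobs.reshape(S, N, num_groups, classes_per_group).exp().sum(dim=-1)
assert torch.allclose(per_group, torch.ones(S, N, num_groups))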
@@ -1080,8 +975,8 @@ class ConformerEncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
attn_mask: Optional[Tensor] = None,
key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
"""
Pass the input through the encoder layer.
@@ -1089,17 +984,18 @@ class ConformerEncoderLayer(nn.Module):
Args:
src: the sequence to the encoder layer (required).
pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
attn_mask: the mask for the src sequence (optional).
key_padding_mask: the mask for the src keys per batch (optional).

Shape:
src: (S, N, E).
pos_emb: (N, 2*S-1, E)
src_mask: (S, S).
src_key_padding_mask: (N, S).
attn_mask: (S, S). This probably won't be used, in fact should not
be (e.g. could in principle ensure causal behavior, but
actually the conformer does not support this).
key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number
"""

# macaron style feed forward module
residual = src
src = self.norm_ff_macaron(src)
@@ -1115,8 +1011,8 @@ class ConformerEncoderLayer(nn.Module):
src,
src,
pos_emb=pos_emb,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)[0]
src = residual + self.dropout(src_att)

@@ -1141,7 +1037,6 @@ class ConformerEncoder(nn.Module):
Args:
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
norm: the layer normalization component (optional).

Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -1153,7 +1048,7 @@ class ConformerEncoder(nn.Module):

def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
super(ConformerEncoder, self).__init__()
self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_leyer) for _ in range(num_layers)])
self.layers = torch.nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])

def forward(
self,
@@ -1183,7 +1078,6 @@ class ConformerEncoder(nn.Module):
return x


class RelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding module.

@@ -1313,7 +1207,6 @@ class RelPositionMultiheadAttention(nn.Module):
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
nn.init.constant_(self.out_proj.bias, 0.0)

nn.init.xavier_uniform_(self.pos_bias_u)
nn.init.xavier_uniform_(self.pos_bias_v)

@@ -1663,6 +1556,7 @@ class RelPositionMultiheadAttention(nn.Module):
)

attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)

attn_output_weights = nn.functional.dropout(
attn_output_weights, p=dropout_p, training=training
)
@@ -1701,7 +1595,7 @@ class CausalConvolutionModule(nn.Module):
self, channels: int, kernel_size: int, bias: bool = True
) -> None:
"""Construct an ConvolutionModule object."""
super(ConvolutionModule, self).__init__()
super(CausalConvolutionModule, self).__init__()
# kernerl_size should be a odd number for 'SAME' padding
assert (kernel_size - 1) % 2 == 0
self.kernel_size = kernel_size
@@ -1752,11 +1646,12 @@ class CausalConvolutionModule(nn.Module):
x = nn.functional.glu(x, dim=1) # (batch, channels, time)

# 1D Depthwise Conv
(B, C, T) = x
(B, C, T) = x.shape
padding = self.kernel_size - 1
x = torch.cat((torch.zeros(B, C, padding, device=x.device, dtype=x.type), x),
dim=2)
x = self.depthwise_conv(x) # <-- This has no padding.
x = self.depthwise_conv(x) # <-- This convolution module does no padding,
# so we padded manually, on the left only.

x = self.activation(self.norm(x))

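# Illustrative sketch (not part of the diff above) of the causal left-padding used
# by CausalConvolutionModule: padding kernel_size - 1 frames on the left of the
# time axis lets a no-padding depthwise Conv1d keep the sequence length while only
# ever looking at the past.
import torch
import torch.nn as nn

B, C, T, kernel_size = 2, 8, 20, 31
x = torch.randn(B, C, T)

padding = kernel_size - 1
x_padded = torch.cat((torch.zeros(B, C, padding, dtype=x.dtype, device=x.device), x), dim=2)

depthwise_conv = nn.Conv1d(C, C, kernel_size, groups=C, padding=0)
y = depthwise_conv(x_padded)
print(y.shape)  # torch.Size([2, 8, 20]): same T, and y[..., t] depends only on x[..., :t + 1]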
@@ -1835,7 +1730,7 @@ class ConvolutionModule(nn.Module):

x = self.pointwise_conv2(x) # (batch, channel, time)

return x.permute(2, 0, 1)
return x.permute(2, 0, 1) # (time, batch channel)


class Swish(torch.nn.Module):
@@ -1851,16 +1746,47 @@ def identity(x):



def test_discrete_bottleneck_conformer():
def _gen_rand_tokens(N: int) -> List[List[int]]:
ans = []
for _ in range(N):
S = random.randint(1, 20)
ans.append([random.randint(3, 30) for _ in range(S)])
return ans

def _gen_supervision(tokens: List[List[int]]):
ans = dict()
N = len(tokens)
ans['sequence_idx'] = torch.arange(N, dtype=torch.int32)
ans['start_frame'] = torch.zeros(N, dtype=torch.int32)
ans['num_frames'] = torch.tensor([ random.randint(20, 35) for _ in tokens])
return ans

def _test_bidirectional_conformer():
num_features = 40
num_classes = 1000
m = DiscreteBottleneckConformer(num_features, num_classes)
m = BidirectionalConformer(num_features, num_classes)
T = 35
N = 10
C = num_features
feats = torch.randn(N, T, C)
ctc_output, _, _ = m(feats)
# [N, T, C].

tokens = _gen_rand_tokens(N)
supervision = _gen_supervision(tokens)
print("tokens = ", tokens)
print("supervision = ", supervision)
# memory: [T, N, C]
(memory, bn_memory, pos_emb, sampled, softmax, key_padding_mask) = m(feats, supervision)

# ctc_output: [N, T, C].
ctc_output = m.ctc_encoder_forward(memory, pos_emb, key_padding_mask)

decoder_loss = m.decoder_forward(memory, key_padding_mask, tokens,
sos_id=1,
eos_id=2)

(T, N, E) = memory.shape
memory_shifted = torch.cat((torch.zeros(1, N, E), memory[:-1,:,:]), dim=0)


if __name__ == '__main__':
test_discrete_bottleneck_conformer()
_test_bidirectional_conformer()

@@ -914,6 +914,9 @@ def encoder_padding_mask(
).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand

# Assert that in each row, i.e. each utterance, at least one frame is not
# masked. Otherwise it may lead to nan's appearing in the attention computation.
assert torch.all(torch.sum(torch.logical_not(mask), dim=1) != 0)
return mask
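# Illustrative sketch (not part of the diff above) of the mask this function builds
# and of the assertion added here: True marks padded frames, and every utterance
# must keep at least one unmasked frame, otherwise the attention softmax can
# produce NaNs. The variable values below are made up for the example.
import torch

lengths = torch.tensor([7, 5, 3])                 # frames per utterance after subsampling
max_len = int(lengths.max())
seq_range_expand = torch.arange(max_len).unsqueeze(0).expand(len(lengths), max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand      # (N, max_len), True where padded

assert torch.all(torch.sum(torch.logical_not(mask), dim=1) != 0)
print(mask)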