mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Introduce a scale dependent on the masking value

This commit is contained in:
  commit 93dff29243
  parent 1be455438a
@@ -18,17 +18,18 @@
 import copy
 import math
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import logging
 import torch
 import random
 from encoder_interface import EncoderInterface
 from scaling import (
     ActivationBalancer,
     BasicNorm,
     DoubleSwish,
     ScaledConv1d,
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
+    LearnedScale,
 )
 from torch import Tensor, nn
 
@@ -171,6 +172,7 @@ class ConformerEncoderLayer(nn.Module):
         self.self_attn = RelPositionMultiheadAttention(
             d_model, nhead, dropout=dropout,
         )
+        self.self_attn_scale = LearnedScale()
 
         self.feed_forward = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -181,6 +183,7 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
+        self.feed_forward_scale = LearnedScale()
 
         self.feed_forward_macaron = nn.Sequential(
             nn.Linear(d_model, dim_feedforward),
@@ -191,11 +194,14 @@ class ConformerEncoderLayer(nn.Module):
             ScaledLinear(dim_feedforward, d_model,
                          initial_scale=0.1),
         )
+        self.feed_forward_macaron_scale = LearnedScale()
 
         self.conv_module = ConvolutionModule(d_model,
                                              cnn_module_kernel)
+        self.conv_scale = LearnedScale()
 
         self.norm_final = BasicNorm(d_model)
+        self.final_scale = LearnedScale()
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
         self.balancer = ActivationBalancer(
@@ -209,11 +215,11 @@ class ConformerEncoderLayer(nn.Module):
     def forward(
         self,
         src: Tensor,
+        feature_mask: Union[Tensor, float],
         pos_emb: Tensor,
         attn_scores_in: Optional[Tensor] = None,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        feature_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
     ) -> Tensor:
         """
@@ -233,10 +239,10 @@ class ConformerEncoderLayer(nn.Module):
 
         Shape:
             src: (S, N, E).
+            feature_mask: float, or (S, N, 1)
             pos_emb: (N, 2*S-1, E)
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
-            feature_mask: (S, N, E)
             S is the source sequence length, N is the batch size, E is the feature number
         """
         src_orig = src
@@ -254,7 +260,8 @@ class ConformerEncoderLayer(nn.Module):
             alpha = 1.0
 
         # macaron style feed forward module
-        src = src + self.feed_forward_macaron(src)
+        src = src + self.feed_forward_macaron_scale(self.feed_forward_macaron(src),
+                                                    feature_mask)
 
         # multi-headed self-attention module
         src_att, _, attn_scores_out = self.self_attn(
@@ -264,25 +271,24 @@
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )
-        src = src + src_att
+        src = src + self.self_attn_scale(src_att, feature_mask)
 
         # convolution module
-        src = src + self.conv_module(src, src_key_padding_mask=src_key_padding_mask)
+        src = src + self.conv_scale(self.conv_module(src, src_key_padding_mask=src_key_padding_mask),
+                                    feature_mask)
 
         # feed forward module
-        src = src + self.feed_forward(src)
+        src = src + self.feed_forward_scale(self.feed_forward(src),
+                                            feature_mask)
+
+        src = self.final_scale(src, feature_mask)
 
         src = self.norm_final(self.balancer(src))
 
         if alpha != 1.0:
             src = alpha * src + (1 - alpha) * src_orig
 
-        if feature_mask is not None:
-            src = src * feature_mask
-
         return src, attn_scores_out
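A minimal standalone sketch (not part of this commit) of the residual pattern the layer's forward() now follows: each sub-module's output goes through its LearnedScale together with feature_mask before being added back, so masked frames can receive a learned rescaling while unmasked frames pass through unchanged. All names and sizes below are illustrative.

    import torch

    S, N, E = 3, 2, 4                     # (seq, batch, feature) - toy sizes
    alpha = torch.tensor(0.25)            # stands in for a learned LearnedScale.alpha
    src = torch.randn(S, N, E)
    module_out = torch.randn(S, N, E)     # e.g. output of the feed-forward module

    feature_mask = torch.ones(S, N, 1)    # 1.0 = frame kept, 0.0 = frame dropped in training
    feature_mask[1] = 0.0                 # pretend frame 1 was masked

    # LearnedScale behaviour: identity where mask == 1.0, scale by (1 + alpha) where mask == 0.0
    scaled = module_out * (1.0 + alpha * (1.0 - feature_mask))
    src = src + scaled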
@@ -359,23 +365,28 @@ class ConformerEncoder(nn.Module):
             feature_mask_dropout_prob = 0.15
             feature_unmasked_dim = 256  # hardcode dim for now, 1st 256 are non-masked.
 
-            feature_mask = torch.ones_like(src)  # S, N, E
-            # is_masked_frame is 0 with probability `feature_mask_dropout_prob`
-            is_masked_frame = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
-            feature_mask[..., feature_unmasked_dim:] *= is_masked_frame
+            full_feature_mask = torch.ones_like(src)  # S, N, E
+            # feature_mask is 0 with probability `feature_mask_dropout_prob`
+            # feature_mask shape: (S, N, 1)
+            feature_mask = (torch.rand_like(src[...,:1]) > feature_mask_dropout_prob).to(src.dtype)
+            full_feature_mask[..., feature_unmasked_dim:] *= feature_mask
         else:
-            feature_mask = None
+            feature_mask = 1.0
+            full_feature_mask = 1.0
+
+        src = src * full_feature_mask
 
         for i, mod in enumerate(self.layers):
             output, attn_scores = mod(
                 output,
+                feature_mask,
                 pos_emb,
                 attn_scores,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
-                feature_mask=feature_mask,
                 warmup=warmup,
             )
+            output = output * full_feature_mask
             if i in self.aux_layers:
                 outputs.append(output)
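For reference, a self-contained sketch of how full_feature_mask and feature_mask relate; the 0.15 dropout probability and the 256-dim unmasked prefix come from the hunk above, while the tensor sizes are toy values chosen for illustration.

    import torch

    S, N, E = 5, 2, 384                       # toy sizes; E must exceed feature_unmasked_dim
    feature_mask_dropout_prob = 0.15
    feature_unmasked_dim = 256                # first 256 feature dims are never masked

    src = torch.randn(S, N, E)

    full_feature_mask = torch.ones_like(src)  # (S, N, E)
    # per-frame keep/drop decision, shape (S, N, 1); 0 with prob feature_mask_dropout_prob
    feature_mask = (torch.rand_like(src[..., :1]) > feature_mask_dropout_prob).to(src.dtype)
    full_feature_mask[..., feature_unmasked_dim:] *= feature_mask

    src = src * full_feature_mask             # applied before the first layer and after each layer
    # feature_mask is what each ConformerEncoderLayer (and its LearnedScale modules) receives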
@@ -326,6 +326,30 @@ def ScaledConv1d(*args,
     return ans
 
 
+class LearnedScale(torch.nn.Module):
+    """
+    Module that learns a scale dependent on some kind of mask that is typically going to be 0 or 1
+    in training.  The scale will be 1.0 if the mask is 1.0, but may be a different (learned) value
+    if the mask value is not 1.0.
+
+    The idea is that if we have some kind of feature mask that would always be 1.0 in
+    test mode but might sometimes be 0.0 in training mode, we might want to multiply
+    the remaining features by a value dependent on this mask.
+    """
+    def __init__(self):
+        super(LearnedScale, self).__init__()
+        self.alpha = nn.Parameter(torch.tensor(0.0))
+
+    def forward(self,
+                x: Tensor,
+                mask: Tensor):
+        """
+        mask should either be a number (probably 1.0) or a Tensor that broadcasts with x.
+        """
+        if self.training and isinstance(mask, float) and mask == 1.0:
+            return x
+        return x * (1.0 + self.alpha * (1.0 - mask))
 
 
 class ActivationBalancer(torch.nn.Module):
     """
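A small usage sketch for the new module; it assumes the scaling.py changed in this commit is importable, otherwise the formula x * (1 + alpha * (1 - mask)) can be reproduced directly as in the earlier sketches.

    import torch
    from scaling import LearnedScale   # assumes this commit's scaling.py is on the Python path

    scale = LearnedScale()
    x = torch.randn(5, 2, 8)           # (S, N, E)

    # With mask == 1.0 (e.g. at test time) the output equals the input, since (1 - mask) is 0.
    scale.eval()
    assert torch.allclose(scale(x, 1.0), x)

    # In training, positions with mask == 0.0 are multiplied by (1 + alpha);
    # alpha starts at 0.0 and is then learned.
    scale.train()
    mask = torch.ones(5, 2, 1)
    mask[0] = 0.0
    y = scale(x, mask)                 # y == x * (1 + alpha * (1 - mask))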