commit 2a54d49f96 (parent a35d5ec8a0)
@@ -133,6 +133,102 @@ class TransformerEncoderAdapter(TransformerEncoder):
        return x, layer_results


class TransformerEncoderAdapter(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
        self.adapters = ResidualAdapterModule(proj_dim=512)

        # Scale down the freshly initialized adapter parameters so the
        # adapters start with small outputs; the commented lines are
        # alternative initializations that were left in place.
        for p in self.adapters.parameters():
            p.data /= 10.
            # p.data = nn.Parameter(torch.zeros(p.size()).to('cuda'))
            # p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.)

    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
        x, layer_results = self.extract_features_with_adapter(
            x,
            padding_mask=padding_mask,
            tgt_layer=tgt_layer
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features_with_adapter(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):

        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False,
                )
                # Apply this layer's residual adapter to the layer output.
                x = self.adapters(x, layer_id=i)

                if i >= min_layer:
                    layer_results.append((x, z, lr))

            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results


class ResidualAdapterModule(nn.Module):
    """
    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf
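
The diff view truncates here, before the body of ResidualAdapterModule. As a rough guide only, the following is a minimal sketch of a per-layer residual bottleneck adapter in the spirit of the cited paper (LayerNorm, down-projection, ReLU, up-projection, with a residual connection). The class name, embed_dim, and num_layers below are illustrative assumptions, not the committed code.

import torch.nn as nn


class ResidualAdapterSketch(nn.Module):
    """Hypothetical per-layer bottleneck adapter (a sketch, not this commit's code)."""

    def __init__(self, proj_dim=512, embed_dim=768, num_layers=12):
        super().__init__()
        # One independent adapter per encoder layer, selected via layer_id.
        self.adapters = nn.ModuleList(
            [
                nn.Sequential(
                    nn.LayerNorm(embed_dim),
                    nn.Linear(embed_dim, proj_dim),
                    nn.ReLU(),
                    nn.Linear(proj_dim, embed_dim),
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x, layer_id):
        # x: T x B x C, matching the layout used in
        # extract_features_with_adapter; the adapter output is added
        # residually onto its input.
        return x + self.adapters[layer_id](x)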
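
The training recipe is also not part of this diff. A common residual-adapter setup (an assumption, not shown in this commit) freezes the pretrained encoder and updates only the adapter parameters:

def freeze_backbone_train_adapters(encoder):
    # `encoder` is assumed to be a TransformerEncoderAdapter instance; after
    # this call, only its ResidualAdapterModule parameters remain trainable.
    for p in encoder.parameters():
        p.requires_grad = False
    for p in encoder.adapters.parameters():
        p.requires_grad = True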