from local

dohe0342 2023-05-18 16:55:27 +09:00
parent a35d5ec8a0
commit 2a54d49f96
2 changed files with 96 additions and 0 deletions

@@ -133,6 +133,102 @@ class TransformerEncoderAdapter(TransformerEncoder):
        return x, layer_results


class TransformerEncoderAdapter(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
        self.adapters = ResidualAdapterModule(proj_dim=512)
        # Shrink the freshly initialised adapter weights so the residual
        # adapters start out close to an identity mapping.
        for p in self.adapters.parameters():
            p.data /= 10.
            #p.data = nn.Parameter(torch.zeros(p.size()).to('cuda'))
            #p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.)

    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
        x, layer_results = self.extract_features_with_adapter(
            x,
            padding_mask=padding_mask,
            tgt_layer=tgt_layer,
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results
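
    # extract_features_with_adapter follows the stock TransformerEncoder
    # feature extraction, but routes the output of every transformer layer
    # through the residual adapters before it is recorded.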
    def extract_features_with_adapter(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
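        # Same layerdrop loop as the stock encoder, with one addition: the output
        # of each retained layer is passed through self.adapters(x, layer_id=i).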
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False,
                )
                x = self.adapters(x, layer_id=i)
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results


class ResidualAdapterModule(nn.Module):
    """
    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf