mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00, from local)
parent 1d60614f12
commit 3debc7ff4d
@@ -38,99 +38,6 @@ from convolution import ConvolutionModule

logger = logging.getLogger()
logger.setLevel(logging.INFO)

class TransformerEncoderLoRA(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
        self.lora = LoRAModule()

    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
        x, layer_results = self.extract_features_with_lora(
            x,
            padding_mask=padding_mask,
            tgt_layer=tgt_layer,
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features_with_lora(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        # Zero out padded frames before adding the positional convolution.
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            # LayerDrop: randomly skip layers during training.
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False
                )
                # Add the low-rank LoRA residual for this layer.
                x_diff = self.lora(x, layer_idx=i)
                x += x_diff

                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results
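
The hunk references LoRAModule but never shows its definition. Below is a minimal sketch of what it might look like, assuming one low-rank projection pair per encoder layer whose output is added as a residual (x += self.lora(x, layer_idx=i) above). The constructor arguments num_layers, embed_dim, rank, and alpha are illustrative defaults, not values taken from this commit.

import torch
import torch.nn as nn


class LoRAModule(nn.Module):
    """Hypothetical per-layer low-rank residual module (sketch, not from this commit)."""

    def __init__(self, num_layers: int = 12, embed_dim: int = 768, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.scaling = alpha / rank
        # One down-projection / up-projection pair per transformer layer.
        self.down = nn.ModuleList(
            [nn.Linear(embed_dim, rank, bias=False) for _ in range(num_layers)]
        )
        self.up = nn.ModuleList(
            [nn.Linear(rank, embed_dim, bias=False) for _ in range(num_layers)]
        )
        for proj in self.up:
            # Zero-init the up-projection so training starts from the frozen encoder's behavior.
            nn.init.zeros_(proj.weight)

    def forward(self, x: torch.Tensor, layer_idx: int) -> torch.Tensor:
        # x is T x B x C at the call site in extract_features_with_lora.
        return self.up[layer_idx](self.down[layer_idx](x)) * self.scaling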

class TransformerEncoderAdapter(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
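
For context, the usual LoRA fine-tuning recipe freezes the pretrained encoder and trains only the low-rank parameters. A hedged usage sketch, not part of this commit: it assumes `encoder` is an instance of TransformerEncoderLoRA with its LoRA parameters living under the `lora.` prefix as in the class above, and the learning rate is arbitrary.

for name, param in encoder.named_parameters():
    # Train only the LoRA parameters; keep the pretrained weights frozen.
    param.requires_grad = name.startswith("lora.")

optimizer = torch.optim.AdamW(
    (p for p in encoder.parameters() if p.requires_grad), lr=1e-4
)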