From 3492d9415c2cda10a39f3047d614691d38502e56 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Mon, 11 Mar 2024 12:02:35 +0800
Subject: [PATCH] add LoRA version of ActivationDropoutAndLinear; currently a simple version

---
 egs/librispeech/ASR/zipformer_lora/scaling.py | 39 +++++++++++++++++++
 .../ASR/zipformer_lora/zipformer.py           |  7 +++-
 2 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer_lora/scaling.py b/egs/librispeech/ASR/zipformer_lora/scaling.py
index 7aeb25721..fdfca7f6c 100644
--- a/egs/librispeech/ASR/zipformer_lora/scaling.py
+++ b/egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -1740,6 +1740,45 @@ class ActivationDropoutAndLinear(torch.nn.Module):
             self.dropout_shared_dim,
         )
 
+class ActivationDropoutAndLinear_lora(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        bias: bool = True,
+        activation: str = "SwooshL",
+        dropout_p: FloatLike = 0.0,
+        dropout_shared_dim: Optional[int] = -1,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        initial_scale: float = 1.0,
+    ):
+        super().__init__()
+        # the LoRA-adapted linear projection; activation and dropout are
+        # applied to the input before it in forward().
+        self.l = ScaledLinear_lora(
+            in_features=in_channels,
+            out_features=out_channels,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            initial_scale=initial_scale,
+            bias=bias,
+        )
+
+        self.activation = activation
+        self.dropout = Dropout3(dropout_p, dropout_shared_dim)
+
+    def forward(self, x: Tensor):
+        if self.activation == "SwooshL":
+            x = SwooshLForward(x)
+        elif self.activation == "SwooshR":
+            x = SwooshRForward(x)
+        else:
+            assert False, self.activation
+        return self.l(self.dropout(x))
+
 
 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
     if num_channels <= x.shape[-1]:
diff --git a/egs/librispeech/ASR/zipformer_lora/zipformer.py b/egs/librispeech/ASR/zipformer_lora/zipformer.py
index fcba1f1c3..259d216aa 100644
--- a/egs/librispeech/ASR/zipformer_lora/zipformer.py
+++ b/egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -34,6 +34,7 @@ from scaling import (
 )
 from scaling import (
     ActivationDropoutAndLinear,
+    ActivationDropoutAndLinear_lora,
     Balancer,
     BiasNorm,
     ChunkCausalDepthwiseConv1d,
@@ -2066,7 +2067,6 @@ class FeedforwardModule(nn.Module):
         lora_dropout: float=0.0
     ):
         super(FeedforwardModule, self).__init__()
-        # self.in_proj = nn.Linear(embed_dim, feedforward_dim)
         self.in_proj = ScaledLinear_lora(
             in_features=embed_dim,
             out_features=feedforward_dim,
@@ -2086,13 +2086,16 @@ class FeedforwardModule(nn.Module):
         )
 
         # shared_dim=0 means we share the dropout mask along the time axis
-        self.out_proj = ActivationDropoutAndLinear(
+        self.out_proj = ActivationDropoutAndLinear_lora(
             feedforward_dim,
             embed_dim,
             activation="SwooshL",
             dropout_p=dropout,
             dropout_shared_dim=0,
             bias=True,
+            r=lora_r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
             initial_scale=0.1,
         )
 
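
For reference, a minimal usage sketch of the new module (not part of the diff). It assumes that
egs/librispeech/ASR/zipformer_lora is on the Python path so scaling.py is importable, and the
channel sizes and LoRA hyperparameters below are made-up example values, roughly mirroring how
FeedforwardModule wires its out_proj (feedforward_dim -> embed_dim).

import torch
from scaling import ActivationDropoutAndLinear_lora  # assumes zipformer_lora/ is on PYTHONPATH

# Hypothetical sizes for illustration only.
layer = ActivationDropoutAndLinear_lora(
    in_channels=1536,
    out_channels=384,
    bias=True,
    activation="SwooshL",
    dropout_p=0.1,
    dropout_shared_dim=0,   # share the dropout mask along the time axis
    r=8,                    # LoRA rank (example value; the default in the signature is 0)
    lora_alpha=16,
    lora_dropout=0.05,
    initial_scale=0.1,
)

x = torch.randn(100, 4, 1536)  # (seq_len, batch_size, channels)
y = layer(x)                   # SwooshL activation -> shared-dim dropout -> LoRA linear
print(y.shape)                 # torch.Size([100, 4, 384])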