fix style

marcoyang 2024-03-15 11:05:45 +08:00
parent 7ead73f746
commit 77bfecd3d8
6 changed files with 66 additions and 54 deletions

View File

@@ -484,13 +484,9 @@ class LibriSpeechAsrDataModule:
     @lru_cache()
     def gigaspeech_dev_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech dev cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_DEV.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

     @lru_cache()
     def gigaspeech_test_cuts(self) -> CutSet:
         logging.info("About to get Gigaspeech test cuts")
-        return load_manifest_lazy(
-            self.args.manifest_dir / "cuts_TEST.jsonl.gz"
-        )
+        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")

View File

@@ -121,7 +121,7 @@ from beam_search import (
     modified_beam_search_lm_shallow_fusion,
     modified_beam_search_LODR,
 )
-from finetune import add_model_arguments, add_finetune_arguments, get_model, get_params
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
 from icefall import ContextGraph, LmScorer, NgramLm
 from icefall.checkpoint import (

View File

@@ -165,9 +165,9 @@ from typing import List, Tuple

 import k2
 import torch
+from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
 from scaling_converter import convert_scaled_to_non_scaled
 from torch import Tensor, nn
-from finetune import add_model_arguments, add_finetune_arguments, get_model, get_params

 from icefall.checkpoint import (
     average_checkpoints,
@@ -499,7 +499,7 @@ def main():
         for k in param_names:
             assert k in state_dict.keys()
             new_state_dict[k] = state_dict[k]

         base_model.load_state_dict(new_state_dict, strict=True)

         model = base_model
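For orientation, the hunk above keeps only the parameter names that the plain (non-LoRA) model expects and loads them strictly. A toy, self-contained sketch of that idea follows; the two nn.Linear stand-ins are hypothetical and not the recipe's actual models.

import torch.nn as nn

# Stand-ins for the exported base model and the fine-tuned model
# (in the recipe the LoRA weights are already merged before this step).
base_model = nn.Linear(4, 2)
finetuned_model = nn.Linear(4, 2)

param_names = [name for name, _ in base_model.named_parameters()]
state_dict = finetuned_model.state_dict()

new_state_dict = {}
for k in param_names:
    assert k in state_dict.keys()
    new_state_dict[k] = state_dict[k]

# strict=True guarantees no LoRA-specific keys leak into the exported model.
base_model.load_state_dict(new_state_dict, strict=True)
model = base_model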

View File

@@ -147,17 +147,11 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--use-lora",
-        type=str2bool,
-        default=True,
-        help="If use LoRA for fine-tune"
+        "--use-lora", type=str2bool, default=True, help="If use LoRA for fine-tune"
     )

     parser.add_argument(
-        "--lora-r",
-        type=int,
-        default=0,
-        help="The bottleneck dimension of LoRA"
+        "--lora-r", type=int, default=0, help="The bottleneck dimension of LoRA"
     )

     parser.add_argument(
@@ -1287,8 +1281,12 @@ def run(rank, world_size, args):
         else:
             p.requires_grad = False

-    logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))
+    logging.info(
+        "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
+            num_trainable, num_trainable / num_param * 100
+        )
+    )

     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")
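The hunk above only reformats the logging call. As a minimal sketch of the surrounding freezing logic that produces this log line, assuming LoRA parameters can be identified by the "lora_" substring in their names (an assumption for illustration, not taken verbatim from the recipe):

import logging

import torch.nn as nn


def freeze_all_but_lora(model: nn.Module) -> None:
    # Count all parameters, then keep only the LoRA ones trainable.
    num_param = sum(p.numel() for p in model.parameters())
    num_trainable = 0
    for name, p in model.named_parameters():
        if "lora_" in name:  # assumed naming convention (lora_A / lora_B)
            p.requires_grad = True
            num_trainable += p.numel()
        else:
            p.requires_grad = False
    logging.info(
        "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
            num_trainable, num_trainable / num_param * 100
        )
    )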

View File

@@ -15,16 +15,17 @@
 # limitations under the License.

-from typing import Optional, Tuple, Union
 import logging
-import k2
-from torch.cuda.amp import custom_fwd, custom_bwd
-import random
-import torch
 import math
+import random
+from typing import Optional, Tuple, Union
+
+import k2
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd


 def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
@@ -518,18 +519,19 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
         torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
     return ans


 class LoRALayer:
     def __init__(
         self,
         r: int,
         lora_alpha: int,
         lora_dropout: float,
         merge_weights: bool,
     ):
         self.r = r
         self.lora_alpha = lora_alpha
         # Optional dropout
-        if lora_dropout > 0.:
+        if lora_dropout > 0.0:
             self.lora_dropout = nn.Dropout(p=lora_dropout)
         else:
             self.lora_dropout = lambda x: x
@@ -537,23 +539,29 @@ class LoRALayer:
         self.merged = False
         self.merge_weights = merge_weights


 class ScaledLinear_lora(nn.Linear, LoRALayer):
     def __init__(
         self,
         in_features: int,
         out_features: int,
-        r: int=0,
-        fan_in_fan_out: bool=False,
-        lora_alpha: int=1,
-        lora_dropout: float=0.0,
+        r: int = 0,
+        fan_in_fan_out: bool = False,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
         initial_scale: float = 1.0,
         merge_weights: bool = True,
         **kwargs,
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
-                           merge_weights=merge_weights)
+        LoRALayer.__init__(
+            self,
+            r=r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            merge_weights=merge_weights,
+        )
         self.initial_scale = initial_scale
         self.fan_in_fan_out = fan_in_fan_out
         if r > 0:
@@ -563,7 +571,7 @@ class ScaledLinear_lora(nn.Linear, LoRALayer):
             self.weight.requires_grad = False

         self.reset_parameters()

     def reset_parameters(self):
         # initialize the parameters
         nn.Linear.reset_parameters(self)
@@ -572,16 +580,19 @@ class ScaledLinear_lora(nn.Linear, LoRALayer):
         with torch.no_grad():
             self.weight[:] *= initial_scale
             if self.bias is not None:
-                nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
-        if hasattr(self, 'lora_A'):
+                nn.init.uniform_(
+                    self.bias, -0.1 * initial_scale, 0.1 * initial_scale
+                )
+        if hasattr(self, "lora_A"):
             # initialize B the same way as the default for nn.Linear and A to zero
             # this is different than what is described in the paper but should not affect performance
             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
             nn.init.zeros_(self.lora_B)

-    def train(self, mode: bool=True):
+    def train(self, mode: bool = True):
         def T(w):
             return w.transpose(0, 1) if self.fan_in_fan_out else w

         nn.Linear.train(self, mode)
         if mode:
             # We don't want the weights to be merged in training mode
@@ -595,18 +606,24 @@ class ScaledLinear_lora(nn.Linear, LoRALayer):
                 # Merge the weights and mark it
                 if self.r > 0:
                     self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                 self.merged = True

     def forward(self, x: torch.Tensor):
         def T(w):
             return w.transpose(0, 1) if self.fan_in_fan_out else w

         if self.r > 0 and not self.merged:
             result = F.linear(x, T(self.weight), bias=self.bias)
-            delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
+            delta_result = (
+                self.lora_dropout(x)
+                @ self.lora_A.transpose(0, 1)
+                @ self.lora_B.transpose(0, 1)
+            )
             return result + delta_result * self.scaling
         else:
             return F.linear(x, T(self.weight), bias=self.bias)


 def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
     """
     Behaves like a constructor of a modified version of nn.Conv1d
@@ -1740,6 +1757,7 @@ class ActivationDropoutAndLinear(torch.nn.Module):
             self.dropout_shared_dim,
         )


 class ActivationDropoutAndLinear_lora(torch.nn.Module):
     def __init__(
         self,
@@ -1749,9 +1767,9 @@ class ActivationDropoutAndLinear_lora(torch.nn.Module):
         activation: str = "SwooshL",
         dropout_p: FloatLike = 0.0,
         dropout_shared_dim: Optional[int] = -1,
-        r: int=0,
-        lora_alpha: int=1,
-        lora_dropout: float=0.0,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
         initial_scale: float = 1.0,
     ):
         super().__init__()
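The ScaledLinear_lora changes above are purely cosmetic, but the layer's two code paths are worth spelling out: during training the low-rank delta is added on the fly, while at merge time B @ A is folded into the frozen weight. A self-contained numerical check of that equivalence, assuming the usual scaling = lora_alpha / r (the diff itself does not show where scaling is set):

import torch

torch.manual_seed(0)
out_f, in_f, r, lora_alpha = 4, 8, 2, 4
scaling = lora_alpha / r  # assumed definition, as in standard LoRA

weight = torch.randn(out_f, in_f)  # frozen base weight, shape (out, in)
lora_A = torch.randn(r, in_f)      # shape (r, in)
lora_B = torch.randn(out_f, r)     # shape (out, r)
x = torch.randn(3, in_f)

# Unmerged (training-time) path: base output plus the low-rank delta.
unmerged = x @ weight.t() + (x @ lora_A.t() @ lora_B.t()) * scaling

# Merged (inference-time) path: fold B @ A into the weight once.
merged_weight = weight + (lora_B @ lora_A) * scaling
merged = x @ merged_weight.t()

assert torch.allclose(unmerged, merged, atol=1e-5)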

View File

@@ -30,7 +30,6 @@ from scaling import (
 )
 from scaling import (
     ScaledLinear,  # not as in other dirs.. just scales down initial parameter values.
-    ScaledLinear_lora
 )
 from scaling import (
     ActivationDropoutAndLinear,
@@ -40,6 +39,7 @@ from scaling import (
     ChunkCausalDepthwiseConv1d,
     Dropout2,
     FloatLike,
+    ScaledLinear_lora,
     ScheduledFloat,
     Whiten,
     convert_num_channels,
@@ -636,7 +636,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )

         self.self_attn1 = SelfAttention(
             embed_dim,
             num_heads,
             value_head_dim,
             lora_r=lora_r,
@@ -645,7 +645,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )

         self.self_attn2 = SelfAttention(
             embed_dim,
             num_heads,
             value_head_dim,
             lora_r=lora_r,
@@ -654,7 +654,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )

         self.feed_forward1 = FeedforwardModule(
             embed_dim,
             (feedforward_dim * 3) // 4,
             dropout,
             lora_r=lora_r,
@@ -672,7 +672,7 @@ class Zipformer2EncoderLayer(nn.Module):
         )

         self.feed_forward3 = FeedforwardModule(
             embed_dim,
             (feedforward_dim * 5) // 4,
             dropout,
             lora_r=lora_r,
@@ -1566,7 +1566,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
         lora_r: int = 0,
         lora_alpha: int = 4,
-        lora_dropout: float=0.0
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -1935,7 +1935,7 @@ class SelfAttention(nn.Module):
         value_head_dim: int,
         lora_r: int = 0,
         lora_alpha: int = 4,
-        lora_dropout: float=0.0
+        lora_dropout: float = 0.0,
     ) -> None:
         super().__init__()
         self.in_proj = ScaledLinear_lora(
@@ -2064,7 +2064,7 @@ class FeedforwardModule(nn.Module):
         dropout: FloatLike,
         lora_r: int = 0,
         lora_alpha: int = 4,
-        lora_dropout: float=0.0
+        lora_dropout: float = 0.0,
     ):
         super(FeedforwardModule, self).__init__()
         self.in_proj = ScaledLinear_lora(