diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp index a42de904d..b1d6479e3 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp index 451906d49..7b9e471f6 100644 Binary files a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train_lora.py.swp differ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py index 173404859..3a0df11ac 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train_lora.py @@ -101,6 +101,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter from zipformer import Zipformer from data2vec_encoder import FairSeqData2VecEncoder +from data2vec_audio import LoRAModule from icefall import diagnostics from icefall.checkpoint import remove_checkpoints @@ -127,6 +128,12 @@ import wandb #from icefall.checkpoint import save_checkpoint as save_checkpoint_impl LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +class LoRAHook(): + def __init__(self, module): + self.hook = module.register_forward_hook(self.hook_fn) + + def hook_fn(self, module, input, output): + def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP):