diff --git a/egs/librispeech/ASR/.distillation_with_hubert.sh.swp b/egs/librispeech/ASR/.distillation_with_hubert.sh.swp
deleted file mode 100644
index 21fbae33c..000000000
Binary files a/egs/librispeech/ASR/.distillation_with_hubert.sh.swp and /dev/null differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless5/.train.py.swp
deleted file mode 100644
index a0986cff4..000000000
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless5/.train.py.swp and /dev/null differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index 6b7122efc..f3c6df3ff 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -966,7 +966,6 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/.conformer.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless6/.conformer.py.swp
deleted file mode 100644
index f67431d0f..000000000
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless6/.conformer.py.swp and /dev/null differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/.hubert_xlarge.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless6/.hubert_xlarge.py.swp
index 3db792bc0..b5e5613e3 100644
Binary files a/egs/librispeech/ASR/pruned_transducer_stateless6/.hubert_xlarge.py.swp and b/egs/librispeech/ASR/pruned_transducer_stateless6/.hubert_xlarge.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp
new file mode 100644
index 000000000..7b4418016
Binary files /dev/null and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.data2vec_audio.py.swp differ
diff --git a/egs/librispeech/ASR/.run_v3.sh.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
similarity index 51%
rename from egs/librispeech/ASR/.run_v3.sh.swp
rename to egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp
index 4bb4e7349..a9389cd0e 100644
Binary files a/egs/librispeech/ASR/.run_v3.sh.swp and b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp differ
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
index 405a6107a..4c9f531e4 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_audio.py
@@ -33,6 +33,141 @@ from fairseq.utils import index_put
 
 logger = logging.getLogger(__name__)
 
+class TransformerEncoderAdapter(TransformerEncoder):
+    def __init__(self, args: Wav2Vec2Config):
+        super().__init__(args)
+        self.adapters = ResidualAdapterModule()
+
+    @classmethod
+    def add_adapter_arguments(cls, parser: argparse.ArgumentParser):
+        parser.add_argument(
+            "--add-adapter",
+            type=str2bool,
+            default=False,
+            help="add residual adapters to the representation model's encoder",
+        )
+
+    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
+        x, layer_results = self.extract_features_with_adapter(
+            x,
+            padding_mask=padding_mask,
+            tgt_layer=tgt_layer,
+        )
+
+        if self.layer_norm_first and layer is None:
+            x = self.layer_norm(x)
+
+        return x, layer_results
+
+    def extract_features_with_adapter(
+        self,
+        x,
+        padding_mask=None,
+        tgt_layer=None,
+        min_layer=0,
+    ):
+
+        if padding_mask is not None:
+            x = index_put(x, padding_mask, 0)
+
+        x_conv = self.pos_conv(x.transpose(1, 2))
+        x_conv = x_conv.transpose(1, 2)
+        x = x + x_conv
+
+        if not self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        # pad to the sequence length dimension
+        x, pad_length = pad_to_multiple(
+            x, self.required_seq_len_multiple, dim=-2, value=0
+        )
+        if pad_length > 0 and padding_mask is None:
+            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+            padding_mask[:, -pad_length:] = True
+        else:
+            padding_mask, _ = pad_to_multiple(
+                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+            )
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        layer_results = []
+        r = None
+
+        for i, layer in enumerate(self.layers):
+            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, (z, lr) = layer(
+                    x, self_attn_padding_mask=padding_mask, need_weights=False, layer_num=i
+                )
+                x = self.adapters(x, layer_id=i)
+
+                if i >= min_layer:
+                    layer_results.append((x, z, lr))
+            if i == tgt_layer:
+                r = x
+                break
+
+        if r is not None:
+            x = r
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        # undo padding
+        if pad_length > 0:
+            x = x[:, :-pad_length]
+
+            def undo_pad(a, b, c):
+                return (
+                    a[:-pad_length],
+                    b[:-pad_length] if b is not None else b,
+                    c[:-pad_length],
+                )
+
+            layer_results = [undo_pad(*u) for u in layer_results]
+
+        return x, layer_results
+
+
+class ResidualAdapterModule(nn.Module):
+    """
+    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf;
+    identical to the original residual adapter except for the LayerNorm location (first -> last).
+    """
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        layer_num: int = 12,
+        proj_dim: int = 384,
+    ) -> None:
+
+        super().__init__()
+
+        def build_adapter(embedding_dim, proj_dim):
+            return nn.Sequential(
+                nn.Linear(embedding_dim, proj_dim),
+                nn.ReLU(),
+                nn.Linear(proj_dim, embedding_dim),
+                nn.LayerNorm(embedding_dim),
+            )
+
+        self.adapter_layers = nn.ModuleList(
+            [build_adapter(embedding_dim, proj_dim) for _ in range(layer_num)]
+        )
+
+    def forward(self, x, layer_id):
+        x = x.transpose(0, 1)
+        residual = x
+        x = self.adapter_layers[layer_id](x)
+        x = residual + x
+        x = x.transpose(0, 1)
+        return x
+
+
 @dataclass
 class Data2VecAudioConfig(Wav2Vec2Config):
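Note on the adapter path above: each transformer layer's output is passed through a small bottleneck and added back residually. The following is a minimal standalone sketch of that per-layer computation, assuming the default sizes in the patch (embedding_dim=768, proj_dim=384); the tensor shapes and names are illustrative only, not part of the patch. Since nn.Linear and nn.LayerNorm act on the last dimension, the sketch skips the module's T/B transposes, which do not change the result.

```python
import torch
import torch.nn as nn

embedding_dim, proj_dim = 768, 384
adapter = nn.Sequential(
    nn.Linear(embedding_dim, proj_dim),   # down-project (bottleneck)
    nn.ReLU(),
    nn.Linear(proj_dim, embedding_dim),   # up-project back to model width
    nn.LayerNorm(embedding_dim),          # LayerNorm last, per the docstring
)

x = torch.randn(50, 2, embedding_dim)     # (T, B, C), as inside the encoder loop
out = x + adapter(x)                      # residual connection around the adapter
assert out.shape == x.shape               # the adapter preserves the layer shape
```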
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_encoder.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_encoder.py
index 37ad7edf1..cdd630744 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_encoder.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/data2vec_encoder.py
@@ -56,12 +56,6 @@ class FairSeqData2VecEncoder(EncoderInterface):
         assert check_argument_types()
         super().__init__()
 
-        '''
-        if os.path.exists('/home/work/workspace/models/data2vec_model/audio_base_ls.pt'):
-            self.w2v_model_path = '/home/work/workspace/models/data2vec_model/audio_base_ls.pt'
-        if os.path.exists('/workspace/models/audio_base_ls.pt'):
-            self.w2v_model_path = '/workspace/models/audio_base_ls.pt'
-        '''
         self.w2v_model_path = download_d2v()
 
         self._output_size = output_size
@@ -120,7 +114,7 @@ class FairSeqData2VecEncoder(EncoderInterface):
             self.num_updates += 1
         elif ft and self.num_updates == self.freeze_finetune_updates + 1:
             self.num_updates += 1
-            logging.info("Start fine-tuning wav2vec parameters!")
+            logging.info("Start fine-tuning data2vec parameters!")
 
         with torch.no_grad() if not ft else contextlib.nullcontext():
             enc_outputs = self.encoders(
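The second hunk touches the freeze-then-finetune gating: the data2vec encoder runs under torch.no_grad() until freeze_finetune_updates steps have elapsed, after which the log line fires once and gradients start flowing. A rough, self-contained sketch of that pattern follows; the counter bookkeeping is simplified and the names are illustrative, not the recipe's exact code.

```python
import contextlib
import torch
import torch.nn as nn

encoder = nn.Linear(8, 8)          # stand-in for the data2vec encoder
freeze_finetune_updates = 2
num_updates = 0

for step in range(4):
    ft = num_updates > freeze_finetune_updates
    # frozen phase: the encoder runs under no_grad(), so no gradients flow into it
    with torch.no_grad() if not ft else contextlib.nullcontext():
        out = encoder(torch.randn(3, 8))
    num_updates += 1
    print(f"step {step}: {'fine-tuning' if ft else 'frozen'}")
```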
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
index bef12f8fd..673f4ef9a 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/model.py
@@ -26,6 +26,32 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos
 
 
+class AdapterHook:
+    '''
+    Implements a forward hook that tracks feature statistics and computes a loss on them.
+    It computes the per-channel mean and variance and uses an L2 penalty as the loss.
+    '''
+    def __init__(self, module):
+        self.hook = module.register_forward_hook(self.hook_fn)
+
+    def hook_fn(self, module, input, output):
+        # hook to compute DeepInversion's feature distribution regularization
+        nch = input[0].shape[1]
+        mean = input[0].mean([0, 2, 3])
+        var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
+
+        # force the mean and variance of the two distributions to match;
+        # other distances might work better, e.g. KL divergence
+        r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
+            module.running_mean.data - mean, 2)
+
+        self.r_feature = r_feature
+        # forward hooks must not return a value
+
+    def close(self):
+        self.hook.remove()
+
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
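A possible way to exercise AdapterHook, assuming it is attached to a BatchNorm2d-style module: the hook reads module.running_mean/running_var and its statistics assume 4-D NCHW input. The layer and shapes below are illustrative only.

```python
import torch
import torch.nn as nn
# AdapterHook as defined in model.py above

bn = nn.BatchNorm2d(num_features=8)
bn.eval()                               # keep running stats fixed for the example
hook = AdapterHook(bn)                  # registers the forward hook on bn

x = torch.randn(4, 8, 16, 16, requires_grad=True)   # (N, C, H, W)
_ = bn(x)                               # hook_fn fires and stores r_feature

hook.r_feature.backward()               # the penalty is differentiable w.r.t. x
print(hook.r_feature.item(), x.grad.shape)
hook.close()                            # remove the hook when done
```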