mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
a1840e672a
commit
a1a1964b95
Binary file not shown.
@ -0,0 +1,199 @@
|
|||||||
|
# Copyright 2021 Xuankai Chang
|
||||||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||||
|
|
||||||
|
"""Encoder definition."""
|
||||||
|
import contextlib
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from filelock import FileLock
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from nets_utils import make_pad_mask
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import (
|
||||||
|
ActivationBalancer,
|
||||||
|
BasicNorm,
|
||||||
|
DoubleSwish,
|
||||||
|
ScaledConv1d,
|
||||||
|
ScaledConv2d,
|
||||||
|
ScaledLinear,
|
||||||
|
)
|
||||||
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
from icefall.utils import make_pad_mask, subsequent_chunk_mask
|
||||||
|
|
||||||
|
|
||||||
|
class FairSeqData2VecEncoder(EncoderInterface):
    """FairSeq data2vec / Wav2Vec2.0 encoder module.

    Wraps a pretrained fairseq checkpoint (data2vec audio or wav2vec2) and
    exposes it through the icefall ``EncoderInterface``.

    Args:
      input_size: input feature dim (kept for interface compatibility; the
        fairseq model consumes raw waveform, so this is not used directly).
      w2v_url: url to the Wav2Vec2.0 pretrained model (currently unused for
        loading — the checkpoint is located on disk, see below).
      w2v_dir_path: directory to download the pretrained model into
        (currently unused; kept for interface compatibility).
      output_size: dimension of the encoder output. If it differs from the
        checkpoint's embed dim (or ``additional_block`` is set), a projection
        block is appended.
      freeze_finetune_updates: number of updates during which the pretrained
        encoder stays frozen; 0 means fine-tune from the first update.
      additional_block: force the Linear+LayerNorm+GELU projection block even
        when dims already match.
    """

    def __init__(
        self,
        input_size: int,
        w2v_url: str,
        w2v_dir_path: str = "./",
        output_size: int = 256,
        freeze_finetune_updates: int = 0,
        additional_block: bool = False,
    ):
        assert check_argument_types()
        super().__init__()

        # fairseq is needed unconditionally below (checkpoint loading), so
        # import it unconditionally here.  The original code only imported it
        # when ``w2v_url != ""``, which made ``w2v_url == ""`` fail later
        # with an opaque NameError.
        try:
            import fairseq
            from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model
        except Exception as e:
            logging.error("Error: FairSeq is not properly installed.")
            logging.error(
                "Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done"
            )
            raise e

        # Locate the checkpoint among the known candidate paths.  Order
        # matters: the later candidate wins when both exist (preserves the
        # original lookup behavior).
        # NOTE(review): these paths are hard-coded; consider deriving them
        # from ``w2v_url``/``w2v_dir_path`` via ``download_w2v``.
        self.w2v_model_path = None
        for candidate in (
            '/home/work/workspace/models/data2vec_model/audio_base_ls.pt',
            './models/audio_base_ls.pt',
        ):
            if os.path.exists(candidate):
                self.w2v_model_path = candidate
        if self.w2v_model_path is None:
            # Fail fast with a clear message instead of an AttributeError
            # at load time.
            raise FileNotFoundError(
                "data2vec checkpoint audio_base_ls.pt not found in any of "
                "the expected locations"
            )

        self._output_size = output_size

        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [self.w2v_model_path],
            strict=False,
        )
        model = models[0]
        model.feature_grad_mult = 0.0  # freeze the conv feature extractor
        # model.mask_prob = 0.3

        if not isinstance(model, Wav2Vec2Model):
            # A fine-tuned (CTC) checkpoint nests the raw encoder under
            # ``w2v_encoder.w2v_model``; a data2vec checkpoint has neither
            # attribute and is used as-is.
            try:
                model = model.w2v_encoder.w2v_model
            except AttributeError:
                logging.info("using data2vec ...")

        self.encoders = model
        # Snapshot of the pretrained weights so they can be restored later
        # via ``reload_pretrained_parameters``.
        self.pretrained_params = copy.deepcopy(model.state_dict())

        if model.cfg.encoder_embed_dim != output_size or additional_block:
            # TODO(xkc09): try LSTM
            self.output_layer = torch.nn.Sequential(
                torch.nn.Linear(model.cfg.encoder_embed_dim, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.GELU(),
            )
        else:
            self.output_layer = None

        self.freeze_finetune_updates = freeze_finetune_updates
        # Counts forward passes in training mode; gates fine-tuning.
        self.num_updates = 0

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        warmup=None,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward the FairSeq data2vec encoder.

        Args:
          x: padded input waveform tensor (B, L)
          x_lens: input lengths (B)
          warmup: unused; kept for interface compatibility.
          prev_states: not used for now.

        Returns:
          Tuple of (encoded tensor (B, T, C), output lengths (B)).
        """
        xs_pad = x
        ilens = x_lens

        with torch.no_grad():
            # NOTE(review): this normalizes over the *entire* tensor shape
            # (batch included), not per utterance — confirm this is intended.
            xs_pad = torch.nn.functional.layer_norm(xs_pad, xs_pad.shape)

        masks = make_pad_mask(ilens).to(xs_pad.device)

        # Fine-tune the pretrained encoder only after
        # ``freeze_finetune_updates`` training steps.
        ft = (
            self.freeze_finetune_updates <= self.num_updates
        ) and self.encoders.training
        if self.num_updates <= self.freeze_finetune_updates:
            self.num_updates += 1
        elif ft and self.num_updates == self.freeze_finetune_updates + 1:
            self.num_updates += 1
            logging.info("Start fine-tuning wav2vec parameters!")

        with torch.no_grad() if not ft else contextlib.nullcontext():
            enc_outputs = self.encoders(
                xs_pad,
                masks,
                mask=ft,  # apply SpecAugment-style masking only when training
                features_only=True,
            )

        xs_pad = enc_outputs["x"]  # (B, T, C)
        bs = xs_pad.shape[0]
        if enc_outputs["padding_mask"] is not None:
            masks = enc_outputs["padding_mask"]  # (B, T)
            olens = (~masks).sum(dim=1)  # (B)
        else:
            # No padding: every sequence has the full output length.
            olens = torch.IntTensor([xs_pad.shape[1]]).repeat(bs).to(xs_pad.device)

        if self.output_layer is not None:
            xs_pad = self.output_layer(xs_pad)

        return xs_pad, olens

    def reload_pretrained_parameters(self):
        """Restore the encoder weights snapshotted at construction time."""
        self.encoders.load_state_dict(self.pretrained_params)
        logging.info("Pretrained Wav2Vec model parameters reloaded!")
|
||||||
|
|
||||||
|
|
||||||
|
def download_w2v(model_url, dir_path):
    """Download a Wav2Vec2.0 checkpoint and its letter dictionary.

    Args:
      model_url: url of the pretrained model checkpoint.
      dir_path: directory the files are downloaded into (created if missing).

    Returns:
      Local filesystem path of the downloaded model checkpoint.
    """
    os.makedirs(dir_path, exist_ok=True)

    model_name = model_url.split("/")[-1]
    model_path = os.path.join(dir_path, model_name)

    dict_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt"
    dict_path = os.path.join(dir_path, dict_url.split("/")[-1])

    # FileLock guards against concurrent downloads from multiple workers.
    with FileLock(model_path + ".lock"):
        if not os.path.exists(model_path):
            torch.hub.download_url_to_file(model_url, model_path)
            logging.info("Wav2Vec model downloaded %s", model_path)
        else:
            logging.info("Wav2Vec model %s already exists.", model_path)
        # Fix: check the dict file independently.  Previously it was only
        # fetched together with a missing model, so a partial earlier run
        # (model present, dict absent) could never recover the dictionary.
        if not os.path.exists(dict_path):
            torch.hub.download_url_to_file(dict_url, dict_path)

    return model_path
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: run one raw waveform through the encoder and show the
    # output shape.
    encoder = FairSeqData2VecEncoder(input_size=768, w2v_url='ww', output_size=768)

    waveform = torch.randn([1, 211564])
    waveform_lens = torch.tensor([211564])

    outputs = encoder(waveform, waveform_lens)
    print(outputs[0].size())
|
||||||
Loading…
x
Reference in New Issue
Block a user