mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent
ce930352ae
commit
c64c7eb4a8
Binary file not shown.
Binary file not shown.
@ -966,7 +966,6 @@ def run(rank, world_size, args):
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
logging.info(model)
|
||||
exit()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -33,6 +33,141 @@ from fairseq.utils import index_put
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransformerEncoderAdapter(TransformerEncoder):
|
||||
def __init__(self, args: Wav2Vec2Config):
|
||||
super().__init__(args)
|
||||
self.adapters = ResidualAdapterModule()
|
||||
|
||||
@classmethod
|
||||
def add_adapter_arguments(cls, parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--add-adapter",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="add adapter to rep model's encoder"
|
||||
)
|
||||
|
||||
def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
|
||||
x, layer_results = self.extract_features_with_adapter(
|
||||
x,
|
||||
padding_mask=padding_mask,
|
||||
tgt_layer=tgt_layer
|
||||
)
|
||||
|
||||
if self.layer_norm_first and layer is None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
return x, layer_results
|
||||
|
||||
def extract_features_with_adapter(
|
||||
self,
|
||||
x,
|
||||
padding_mask=None,
|
||||
tgt_layer=None,
|
||||
min_layer=0,
|
||||
):
|
||||
|
||||
if padding_mask is not None:
|
||||
x = index_put(x, padding_mask, 0)
|
||||
|
||||
x_conv = self.pos_conv(x.transpose(1, 2))
|
||||
x_conv = x_conv.transpose(1, 2)
|
||||
x = x + x_conv
|
||||
|
||||
if not self.layer_norm_first:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# pad to the sequence length dimension
|
||||
x, pad_length = pad_to_multiple(
|
||||
x, self.required_seq_len_multiple, dim=-2, value=0
|
||||
)
|
||||
if pad_length > 0 and padding_mask is None:
|
||||
padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
|
||||
padding_mask[:, -pad_length:] = True
|
||||
else:
|
||||
padding_mask, _ = pad_to_multiple(
|
||||
padding_mask, self.required_seq_len_multiple, dim=-1, value=True
|
||||
)
|
||||
x = F.dropout(x, p=self.dropout, training=self.training)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
layer_results = []
|
||||
r = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
dropout_probability = np.random.random() if self.layerdrop > 0 else 1
|
||||
if not self.training or (dropout_probability > self.layerdrop):
|
||||
x, (z, lr) = layer(
|
||||
x, self_attn_padding_mask=padding_mask, need_weights=False, layer_num=i
|
||||
)
|
||||
x = self.adapters(x, layer_id=i)
|
||||
|
||||
if i >= min_layer:
|
||||
layer_results.append((x, z, lr))
|
||||
|
||||
if i == tgt_layer:
|
||||
r = x
|
||||
break
|
||||
|
||||
if r is not None:
|
||||
x = r
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
# undo paddding
|
||||
if pad_length > 0:
|
||||
x = x[:, :-pad_length]
|
||||
|
||||
def undo_pad(a, b, c):
|
||||
return (
|
||||
a[:-pad_length],
|
||||
b[:-pad_length] if b is not None else b,
|
||||
c[:-pad_length],
|
||||
)
|
||||
|
||||
layer_results = [undo_pad(*u) for u in layer_results]
|
||||
|
||||
return x, layer_results
|
||||
|
||||
|
||||
class ResidualAdapterModule(nn.Module):
|
||||
"""
|
||||
Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf
|
||||
modules similar to the original residual adapter except layernorm location (first -> last)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: float = 768,
|
||||
layer_num: int = 12,
|
||||
proj_dim: float = 384,
|
||||
) -> None:
|
||||
|
||||
super().__init__()
|
||||
|
||||
def build_adapter(embedding_dim, proj_dim):
|
||||
return nn.Sequential(
|
||||
nn.Linear(embedding_dim, proj_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(proj_dim, embedding_dim),
|
||||
nn.LayerNorm(embedding_dim),
|
||||
)
|
||||
|
||||
self.adapter_layers = nn.ModuleList(
|
||||
[build_adapter(embedding_dim, proj_dim) for _ in range(layer_num)]
|
||||
)
|
||||
|
||||
def forward(x, layer_id):
|
||||
x = x.transpose(0, 1)
|
||||
residual = x
|
||||
x = self.adapter_layers[layer_id](x)
|
||||
x = residual + x
|
||||
x = x.transpose(0, 1)
|
||||
return x
|
||||
|
||||
|
||||
@dataclass
|
||||
class Data2VecAudioConfig(Wav2Vec2Config):
|
||||
|
||||
|
||||
@ -56,12 +56,6 @@ class FairSeqData2VecEncoder(EncoderInterface):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
|
||||
'''
|
||||
if os.path.exists('/home/work/workspace/models/data2vec_model/audio_base_ls.pt'):
|
||||
self.w2v_model_path = '/home/work/workspace/models/data2vec_model/audio_base_ls.pt'
|
||||
if os.path.exists('/workspace/models/audio_base_ls.pt'):
|
||||
self.w2v_model_path = '/workspace/models/audio_base_ls.pt'
|
||||
'''
|
||||
self.w2v_model_path = download_d2v()
|
||||
self._output_size = output_size
|
||||
|
||||
@ -120,7 +114,7 @@ class FairSeqData2VecEncoder(EncoderInterface):
|
||||
self.num_updates += 1
|
||||
elif ft and self.num_updates == self.freeze_finetune_updates + 1:
|
||||
self.num_updates += 1
|
||||
logging.info("Start fine-tuning wav2vec parameters!")
|
||||
logging.info("Start fine-tuning data2vec parameters!")
|
||||
|
||||
with torch.no_grad() if not ft else contextlib.nullcontext():
|
||||
enc_outputs = self.encoders(
|
||||
|
||||
@ -26,6 +26,32 @@ from encoder_interface import EncoderInterface
|
||||
from icefall.utils import add_sos
|
||||
|
||||
|
||||
class AdapterHook():
|
||||
'''
|
||||
Implementation of the forward hook to track feature statistics and compute a loss on them.
|
||||
Will compute mean and variance, and will use l2 as a loss
|
||||
'''
|
||||
def __init__(self, module):
|
||||
self.hook = module.register_forward_hook(self.hook_fn)
|
||||
|
||||
def hook_fn(self, module, input, output):
|
||||
# hook co compute deepinversion's feature distribution regularization
|
||||
nch = input[0].shape[1]
|
||||
mean = input[0].mean([0, 2, 3])
|
||||
var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
|
||||
|
||||
#forcing mean and variance to match between two distributions
|
||||
#other ways might work better, i.g. KL divergence
|
||||
r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
|
||||
module.running_mean.data - mean, 2)
|
||||
|
||||
self.r_feature = r_feature
|
||||
# must have no output
|
||||
|
||||
def close(self):
|
||||
self.hook.remove()
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user