mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
from local
This commit is contained in:
parent ce930352ae
commit c64c7eb4a8
Binary file not shown.
Binary file not shown.
@@ -966,7 +966,6 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()
 
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -33,6 +33,141 @@ from fairseq.utils import index_put
 
 logger = logging.getLogger(__name__)
 
 
+class TransformerEncoderAdapter(TransformerEncoder):
+    def __init__(self, args: Wav2Vec2Config):
+        super().__init__(args)
+        self.adapters = ResidualAdapterModule()
+
+    @classmethod
+    def add_adapter_arguments(cls, parser: argparse.ArgumentParser):
+        parser.add_argument(
+            "--add-adapter",
+            type=str2bool,
+            default=False,
+            help="add adapter to rep model's encoder",
+        )
+
+    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
+        x, layer_results = self.extract_features_with_adapter(
+            x,
+            padding_mask=padding_mask,
+            tgt_layer=tgt_layer,
+        )
+
+        if self.layer_norm_first and layer is None:
+            x = self.layer_norm(x)
+
+        return x, layer_results
+
+    def extract_features_with_adapter(
+        self,
+        x,
+        padding_mask=None,
+        tgt_layer=None,
+        min_layer=0,
+    ):
+        if padding_mask is not None:
+            x = index_put(x, padding_mask, 0)
+
+        x_conv = self.pos_conv(x.transpose(1, 2))
+        x_conv = x_conv.transpose(1, 2)
+        x = x + x_conv
+
+        if not self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        # pad to the sequence length dimension
+        x, pad_length = pad_to_multiple(
+            x, self.required_seq_len_multiple, dim=-2, value=0
+        )
+        if pad_length > 0 and padding_mask is None:
+            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+            padding_mask[:, -pad_length:] = True
+        else:
+            padding_mask, _ = pad_to_multiple(
+                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+            )
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        layer_results = []
+        r = None
+
+        for i, layer in enumerate(self.layers):
+            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, (z, lr) = layer(
+                    x, self_attn_padding_mask=padding_mask, need_weights=False, layer_num=i
+                )
+                # apply the per-layer residual adapter to the layer output
+                x = self.adapters(x, layer_id=i)
+
+                if i >= min_layer:
+                    layer_results.append((x, z, lr))
+
+            if i == tgt_layer:
+                r = x
+                break
+
+        if r is not None:
+            x = r
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        # undo padding
+        if pad_length > 0:
+            x = x[:, :-pad_length]
+
+            def undo_pad(a, b, c):
+                return (
+                    a[:-pad_length],
+                    b[:-pad_length] if b is not None else b,
+                    c[:-pad_length],
+                )
+
+            layer_results = [undo_pad(*u) for u in layer_results]
+
+        return x, layer_results
+
+
+class ResidualAdapterModule(nn.Module):
+    """
+    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf
+    The modules match the original residual adapter except for the layer-norm
+    location (first -> last).
+    """
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        layer_num: int = 12,
+        proj_dim: int = 384,
+    ) -> None:
+        super().__init__()
+
+        def build_adapter(embedding_dim, proj_dim):
+            return nn.Sequential(
+                nn.Linear(embedding_dim, proj_dim),
+                nn.ReLU(),
+                nn.Linear(proj_dim, embedding_dim),
+                nn.LayerNorm(embedding_dim),
+            )
+
+        self.adapter_layers = nn.ModuleList(
+            [build_adapter(embedding_dim, proj_dim) for _ in range(layer_num)]
+        )
+
+    def forward(self, x, layer_id):
+        # x is T x B x C; apply the bottleneck adapter with a residual connection
+        x = x.transpose(0, 1)
+        residual = x
+        x = self.adapter_layers[layer_id](x)
+        x = residual + x
+        x = x.transpose(0, 1)
+        return x
+
+
 @dataclass
 class Data2VecAudioConfig(Wav2Vec2Config):
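For orientation, here is a minimal standalone sketch of the bottleneck adapter that ResidualAdapterModule builds once per transformer layer: down-projection, ReLU, up-projection, LayerNorm, wrapped in a residual connection. The dimensions and the toy input are illustrative assumptions, not values read from this commit.

import torch
import torch.nn as nn

# One bottleneck adapter (illustrative sizes: 768-dim features, 384-dim bottleneck).
adapter = nn.Sequential(
    nn.Linear(768, 384),
    nn.ReLU(),
    nn.Linear(384, 768),
    nn.LayerNorm(768),
)

x = torch.randn(4, 50, 768)  # (batch, time, channels), toy input
out = x + adapter(x)         # residual connection around the adapter
print(out.shape)             # torch.Size([4, 50, 768])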
@@ -56,12 +56,6 @@ class FairSeqData2VecEncoder(EncoderInterface):
         assert check_argument_types()
         super().__init__()
 
-        '''
-        if os.path.exists('/home/work/workspace/models/data2vec_model/audio_base_ls.pt'):
-            self.w2v_model_path = '/home/work/workspace/models/data2vec_model/audio_base_ls.pt'
-        if os.path.exists('/workspace/models/audio_base_ls.pt'):
-            self.w2v_model_path = '/workspace/models/audio_base_ls.pt'
-        '''
         self.w2v_model_path = download_d2v()
         self._output_size = output_size
 
@@ -120,7 +114,7 @@ class FairSeqData2VecEncoder(EncoderInterface):
             self.num_updates += 1
         elif ft and self.num_updates == self.freeze_finetune_updates + 1:
             self.num_updates += 1
-            logging.info("Start fine-tuning wav2vec parameters!")
+            logging.info("Start fine-tuning data2vec parameters!")
 
         with torch.no_grad() if not ft else contextlib.nullcontext():
             enc_outputs = self.encoders(
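The hunk above only changes a log message; the surrounding encoder logic gates gradient flow with freeze_finetune_updates, running the pretrained data2vec encoder under torch.no_grad() until enough updates have passed and fine-tuning it afterwards. A hedged sketch of that gating pattern, with hypothetical names standing in for the attributes this hunk does not show:

import contextlib
import torch

def run_encoder(encoder, features, num_updates, freeze_finetune_updates):
    # Hypothetical helper: `ft` is True once the warm-up budget is spent.
    ft = num_updates > freeze_finetune_updates
    # While frozen, run the encoder without building a computation graph.
    with torch.no_grad() if not ft else contextlib.nullcontext():
        return encoder(features)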
@@ -26,6 +26,32 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos
 
 
+class AdapterHook():
+    '''
+    Implementation of a forward hook that tracks feature statistics and computes a loss on them.
+    It computes the mean and variance of the input features and uses an L2 penalty as the loss.
+    '''
+    def __init__(self, module):
+        self.hook = module.register_forward_hook(self.hook_fn)
+
+    def hook_fn(self, module, input, output):
+        # hook to compute DeepInversion's feature distribution regularization
+        nch = input[0].shape[1]
+        mean = input[0].mean([0, 2, 3])
+        var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
+
+        # forcing mean and variance to match between the two distributions;
+        # other ways might work better, e.g. KL divergence
+        r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
+            module.running_mean.data - mean, 2)
+
+        self.r_feature = r_feature
+        # must have no output
+
+    def close(self):
+        self.hook.remove()
+
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
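A brief sketch of how a hook like AdapterHook can be attached and read out. It assumes a module with running_mean/running_var buffers and 4-D inputs (e.g. nn.BatchNorm2d), since hook_fn reduces over dimensions [0, 2, 3]; the module and input below are illustrative, not part of this commit.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(16)
bn.eval()  # keep the stored running statistics fixed while probing them

hook = AdapterHook(bn)            # registers hook_fn on the module
_ = bn(torch.randn(2, 16, 8, 8))  # a forward pass populates hook.r_feature

feature_reg_loss = hook.r_feature  # L2 gap between batch and running statistics
hook.close()                       # remove the hook when done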