from local

dohe0342 2022-12-26 13:25:58 +09:00
parent ce930352ae
commit c64c7eb4a8
10 changed files with 162 additions and 8 deletions


@@ -966,7 +966,6 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_transducer_model(params)
     logging.info(model)
-    exit()

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")


@@ -33,6 +33,141 @@ from fairseq.utils import index_put

 logger = logging.getLogger(__name__)
+
+class TransformerEncoderAdapter(TransformerEncoder):
+    def __init__(self, args: Wav2Vec2Config):
+        super().__init__(args)
+        self.adapters = ResidualAdapterModule()
+
+    @classmethod
+    def add_adapter_arguments(cls, parser: argparse.ArgumentParser):
+        parser.add_argument(
+            "--add-adapter",
+            type=str2bool,
+            default=False,
+            help="add adapter to rep model's encoder",
+        )
+
+    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
+        x, layer_results = self.extract_features_with_adapter(
+            x,
+            padding_mask=padding_mask,
+            tgt_layer=tgt_layer,
+        )
+
+        if self.layer_norm_first and layer is None:
+            x = self.layer_norm(x)
+
+        return x, layer_results
+
+    def extract_features_with_adapter(
+        self,
+        x,
+        padding_mask=None,
+        tgt_layer=None,
+        min_layer=0,
+    ):
+        if padding_mask is not None:
+            x = index_put(x, padding_mask, 0)
+
+        x_conv = self.pos_conv(x.transpose(1, 2))
+        x_conv = x_conv.transpose(1, 2)
+        x = x + x_conv
+
+        if not self.layer_norm_first:
+            x = self.layer_norm(x)
+
+        # pad to the sequence length dimension
+        x, pad_length = pad_to_multiple(
+            x, self.required_seq_len_multiple, dim=-2, value=0
+        )
+        if pad_length > 0 and padding_mask is None:
+            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
+            padding_mask[:, -pad_length:] = True
+        else:
+            padding_mask, _ = pad_to_multiple(
+                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
+            )
+
+        x = F.dropout(x, p=self.dropout, training=self.training)
+
+        # B x T x C -> T x B x C
+        x = x.transpose(0, 1)
+
+        layer_results = []
+        r = None
+        for i, layer in enumerate(self.layers):
+            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
+            if not self.training or (dropout_probability > self.layerdrop):
+                x, (z, lr) = layer(
+                    x, self_attn_padding_mask=padding_mask, need_weights=False, layer_num=i
+                )
+                x = self.adapters(x, layer_id=i)
+                if i >= min_layer:
+                    layer_results.append((x, z, lr))
+            if i == tgt_layer:
+                r = x
+                break
+
+        if r is not None:
+            x = r
+
+        # T x B x C -> B x T x C
+        x = x.transpose(0, 1)
+
+        # undo padding
+        if pad_length > 0:
+            x = x[:, :-pad_length]
+
+            def undo_pad(a, b, c):
+                return (
+                    a[:-pad_length],
+                    b[:-pad_length] if b is not None else b,
+                    c[:-pad_length],
+                )
+
+            layer_results = [undo_pad(*u) for u in layer_results]
+
+        return x, layer_results
+
+class ResidualAdapterModule(nn.Module):
+    """
+    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf
+
+    The modules match the original residual adapter except for the LayerNorm
+    location (first -> last).
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int = 768,
+        layer_num: int = 12,
+        proj_dim: int = 384,
+    ) -> None:
+        super().__init__()
+
+        def build_adapter(embedding_dim, proj_dim):
+            return nn.Sequential(
+                nn.Linear(embedding_dim, proj_dim),
+                nn.ReLU(),
+                nn.Linear(proj_dim, embedding_dim),
+                nn.LayerNorm(embedding_dim),
+            )
+
+        self.adapter_layers = nn.ModuleList(
+            [build_adapter(embedding_dim, proj_dim) for _ in range(layer_num)]
+        )
+
+    def forward(self, x, layer_id):
+        x = x.transpose(0, 1)
+        residual = x
+        x = self.adapter_layers[layer_id](x)
+        x = residual + x
+        x = x.transpose(0, 1)
+        return x
+
 @dataclass
 class Data2VecAudioConfig(Wav2Vec2Config):
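
As a quick sanity check of what the new module computes, here is a minimal standalone sketch of calling ResidualAdapterModule directly (the class comes from the diff above; the T x B x C input shape and the 768/384 dimensions are assumptions matching its defaults):

    import torch

    adapters = ResidualAdapterModule(embedding_dim=768, layer_num=12, proj_dim=384)
    x = torch.randn(50, 4, 768)   # T x B x C, as inside the encoder layer loop
    y = adapters(x, layer_id=0)   # down-projection -> ReLU -> up-projection -> LayerNorm, plus residual
    assert y.shape == x.shape     # the adapter is shape-preserving

One adapter per encoder layer is kept in an nn.ModuleList, so layer_id selects the bottleneck belonging to the layer whose output is being adapted.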


@@ -56,12 +56,6 @@ class FairSeqData2VecEncoder(EncoderInterface):
         assert check_argument_types()
         super().__init__()
-        '''
-        if os.path.exists('/home/work/workspace/models/data2vec_model/audio_base_ls.pt'):
-            self.w2v_model_path = '/home/work/workspace/models/data2vec_model/audio_base_ls.pt'
-        if os.path.exists('/workspace/models/audio_base_ls.pt'):
-            self.w2v_model_path = '/workspace/models/audio_base_ls.pt'
-        '''
         self.w2v_model_path = download_d2v()
         self._output_size = output_size

@@ -120,7 +114,7 @@ class FairSeqData2VecEncoder(EncoderInterface):
             self.num_updates += 1
         elif ft and self.num_updates == self.freeze_finetune_updates + 1:
             self.num_updates += 1
-            logging.info("Start fine-tuning wav2vec parameters!")
+            logging.info("Start fine-tuning data2vec parameters!")

         with torch.no_grad() if not ft else contextlib.nullcontext():
             enc_outputs = self.encoders(
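
The second hunk sits inside the freeze/fine-tune gate: the pretrained encoder runs under no_grad until freeze_finetune_updates optimizer steps have passed, then the same forward line becomes differentiable. A minimal sketch of that gating pattern in isolation, with an nn.Linear standing in for the data2vec encoder (all names and values below are illustrative):

    import contextlib

    import torch
    import torch.nn as nn

    encoder = nn.Linear(80, 768)        # stand-in for the data2vec encoder
    features = torch.randn(4, 100, 80)
    freeze_finetune_updates = 10000     # illustrative warm-up length
    num_updates = 0

    ft = num_updates >= freeze_finetune_updates
    # While frozen, the forward pass runs under no_grad so no encoder
    # gradients are stored; once ft is True, nullcontext() turns the
    # same line into an ordinary differentiable forward pass.
    with torch.no_grad() if not ft else contextlib.nullcontext():
        enc_outputs = encoder(features)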


@@ -26,6 +26,32 @@ from encoder_interface import EncoderInterface
 from icefall.utils import add_sos

+class AdapterHook():
+    '''
+    Implementation of a forward hook to track feature statistics and compute
+    a loss on them. Computes the mean and variance, and uses the l2 distance
+    as the loss.
+    '''
+    def __init__(self, module):
+        self.hook = module.register_forward_hook(self.hook_fn)
+
+    def hook_fn(self, module, input, output):
+        # hook to compute DeepInversion's feature distribution regularization
+        nch = input[0].shape[1]
+        mean = input[0].mean([0, 2, 3])
+        var = input[0].permute(1, 0, 2, 3).contiguous().view([nch, -1]).var(1, unbiased=False)
+        # forcing mean and variance to match between the two distributions;
+        # other ways might work better, e.g. KL divergence
+        r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(
+            module.running_mean.data - mean, 2)
+        self.r_feature = r_feature
+        # the hook must not return anything
+
+    def close(self):
+        self.hook.remove()
+
 class Transducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"