mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

commit 4c60473622 (parent 2baa7e3631)

    from local
egs/librispeech/ASR/local/data2vec_audio.py | 683 (new file)
@@ -0,0 +1,683 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import math
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II

from fairseq.data.data_utils import compute_mask_indices
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.wav2vec import (
    ConvFeatureExtractionModel,
    TransformerEncoder,
    Wav2Vec2Config,
)
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.utils import index_put

from convolution import ConvolutionModule
from utils import pad_to_multiple

# `logging.getLogger().setLevel(...)` returns None, so assigning its result to
# `logger` would make the `logger.info` / `logger.error` calls below fail.
# Configure the root level and keep a real module-level logger instead.
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger(__name__)


class TransformerEncoderAdapter(TransformerEncoder):
    def __init__(self, args: Wav2Vec2Config):
        super().__init__(args)
        self.adapters = ResidualAdapterModule()

        # Scale down the initial adapter weights so the adapter branch starts
        # with a small contribution relative to the residual connection.
        for p in self.adapters.parameters():
            p.data /= 10.0
            # p.data = nn.Parameter(torch.zeros(p.size()).to('cuda'))
            # p.data = nn.Parameter(torch.randn(p.size()).to('cuda')/20.)

    def forward(self, x, padding_mask=None, layer=None, tgt_layer=None):
        x, layer_results = self.extract_features_with_adapter(
            x, padding_mask=padding_mask, tgt_layer=tgt_layer
        )

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features_with_adapter(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False
                )
                # Apply the residual adapter after every encoder layer.
                x = self.adapters(x, layer_id=i)

                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results


class ResidualAdapterModule(nn.Module):
    """
    Implements a residual adapter based on https://arxiv.org/pdf/1909.08478.pdf.
    The structure follows the original residual adapter, except that the
    LayerNorm is applied last rather than first.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        layer_num: int = 12,
        proj_dim: int = 512,
    ) -> None:
        super().__init__()

        self.type = "linear"

        def build_adapter(embedding_dim, proj_dim, type_=self.type):
            if type_ == "conv":
                # Convolutional adapter: 768-d channels, kernel size 31.
                return ConvolutionModule(768, 31)
            else:
                # Bottleneck adapter: down-project, ReLU, up-project, LayerNorm.
                return nn.Sequential(
                    # nn.LayerNorm(embedding_dim),
                    nn.Linear(embedding_dim, proj_dim),
                    nn.ReLU(),
                    nn.Linear(proj_dim, embedding_dim),
                    nn.LayerNorm(embedding_dim),
                )

        # One adapter per transformer layer.
        self.adapter_layers = nn.ModuleList(
            [
                build_adapter(embedding_dim, proj_dim, type_=self.type)
                for _ in range(layer_num)
            ]
        )

    def forward(self, x, layer_id=-1):
        # x arrives as T x B x C from the transformer encoder.
        x = x.transpose(0, 1)
        residual = x
        x = self.adapter_layers[layer_id](x)
        x = residual + x
        x = x.transpose(0, 1)

        return x
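
# A minimal sanity-check sketch of the adapter call (the tensor sizes are
# illustrative and follow the defaults above, not requirements):
#
#     adapters = ResidualAdapterModule(embedding_dim=768, layer_num=12)
#     x = torch.randn(50, 4, 768)   # T x B x C, as seen inside the encoder
#     y = adapters(x, layer_id=0)   # adapter output added to the residual
#     assert y.shape == x.shape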


@dataclass
class Data2VecAudioConfig(Wav2Vec2Config):

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False
    batch_norm_target_layer: bool = False
    group_norm_target_layer: bool = False

    ema_decay: float = field(
        default=0.999, metadata={"help": "initial ema decay rate"}
    )
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer"},
    )
    ema_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )

    max_update: int = II("optimization.max_update")

    min_target_var: float = field(
        default=0.1, metadata={"help": "stop training if target var falls below this"}
    )
    min_pred_var: float = field(
        default=0.01,
        metadata={"help": "stop training if prediction var falls below this"},
    )
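
# Note: `II("optimization.max_update")` is OmegaConf variable interpolation;
# the field resolves to `optimization.max_update` from the composed training
# config, so EMA decay annealing ends at the final update by default.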


def get_annealed_rate(start, end, curr_step, total_steps):
    """Linearly anneal from `start` to `end` over `total_steps` updates."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
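
# Worked example: annealing the EMA decay from 0.999 to 0.9999 over
# 100_000 updates gives, at step 25_000,
#     r = 9e-4, pct_remaining = 0.75, 0.9999 - 9e-4 * 0.75 = 0.999225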


@register_model("data2vec_audio", dataclass=Data2VecAudioConfig)
class Data2VecAudioModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecAudioConfig):
        super().__init__()
        self.cfg = cfg

        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.extractor_embed = feature_enc_layers[-1][0]

        self.ema = None
        self.embed = cfg.encoder_embed_dim

        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_beta = cfg.loss_beta
        self.loss_scale = cfg.loss_scale

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        self.post_extract_proj = nn.Linear(self.extractor_embed, cfg.encoder_embed_dim)

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        # Use the adapter-augmented encoder instead of the stock one.
        # self.encoder = TransformerEncoder(cfg)
        self.encoder = TransformerEncoderAdapter(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    # The EMA teacher used by data2vec pre-training is disabled in this
    # version; the original construction code is kept below for reference.
    '''
    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )
    '''

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        # EMA teacher creation and decay annealing from the original
        # data2vec training loop, disabled in this version:
        '''
        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)
        '''
        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)

        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params

        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # Masked time steps are zeroed out here instead of being filled
            # with the learned mask embedding used by upstream data2vec.
            # x = index_put(x, mask_indices, self.mask_emb)
            x = index_put(x, mask_indices, 0)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices
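
    # Minimal masking sketch (assumes a built model with mask_prob > 0; the
    # shapes are illustrative):
    #
    #     features = torch.randn(2, 100, 768)   # B x T x C
    #     masked, mask_indices = model.apply_mask(features, padding_mask=None)
    #     # mask_indices is a bool tensor of shape (2, 100); `masked` is zero
    #     # at every channel of the masked time steps.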

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)
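
    # Example with the default wav2vec 2.0 conv stack
    # "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2": applying
    # floor((L - kernel) / stride + 1) layer by layer maps 16000 input
    # samples (1 s at 16 kHz) to 49 output frames, i.e. roughly 50 Hz.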

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)
        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations make sure that all values
            # before the output length indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        # Everything below reproduces the data2vec training objective and
        # requires the EMA teacher (self.ema), whose construction is commented
        # out above; fine-tuning calls with features_only=True and returns
        # before reaching this point.
        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

        target_layer_results = [l[2] for l in y["layer_results"]]

        permuted = False
        if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
            target_layer_results = [
                tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
            ]
            permuted = True

        if self.cfg.batch_norm_target_layer:
            target_layer_results = [
                F.batch_norm(
                    tl.float(), running_mean=None, running_var=None, training=True
                )
                for tl in target_layer_results
            ]

        if self.cfg.instance_norm_target_layer:
            target_layer_results = [
                F.instance_norm(tl.float()) for tl in target_layer_results
            ]

        if permuted:
            target_layer_results = [
                tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
            ]

        if self.cfg.group_norm_target_layer:
            target_layer_results = [
                F.layer_norm(tl.float(), tl.shape[-2:])
                for tl in target_layer_results
            ]

        if self.cfg.layer_norm_target_layer:
            target_layer_results = [
                F.layer_norm(tl.float(), tl.shape[-1:])
                for tl in target_layer_results
            ]

        # The regression target is the average of the top-k teacher layers.
        y = sum(target_layer_results) / len(target_layer_results)

        if self.cfg.layer_norm_targets:
            y = F.layer_norm(y.float(), y.shape[-1:])

        if self.cfg.instance_norm_targets:
            y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

        if not permuted:
            y = y.transpose(0, 1)

        # Compute the loss only over masked positions.
        y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        # Track the variance of targets and predictions to detect
        # representation collapse.
        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            # Aggregate count, sum, and sum of squares across workers so the
            # variance is computed over the global batch.
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
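
    # The distributed branch is the usual unbiased (Bessel-corrected)
    # variance written in terms of the reduced sufficient statistics:
    #     Var = (sum(y^2) - (sum(y))^2 / n) / (n - 1)
    #         = zss / (zc - 1) - zs^2 / (zc * (zc - 1))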

    def extract_features(
        self, source, padding_mask, mask=False, layer=None
    ):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )