Initially working version with delay_loss...

Daniel Povey 2021-09-23 11:25:42 +08:00
parent 65b737576e
commit 2213457bd3
2 changed files with 100 additions and 70 deletions


@@ -222,11 +222,6 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
-        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
-        self.bottleneck_ctc_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
-        )
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
@@ -480,38 +475,6 @@ class BidirectionalConformer(nn.Module):
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
-    def bottleneck_ctc_encoder_forward(
-        self,
-        positive_embed: torch.Tensor,
-        pos_emb: torch.Tensor,
-        memory_key_padding_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Passes the output of sample_forward() through the "from-bottleneck" CTC
-        encoder and the CTC output layer to give the output that can be given
-        to the CTC loss function.
-        Args:
-          positive_embed:
-            One of the outputs of sample_forward(), with shape (T, N, E)
-          pos_emb:
-            Relative positional embedding tensor: (N, 2*T-1, E)
-          memory_key_padding_mask:
-            The padding mask from forward(), a tensor of bool of shape (N, T)
-        Returns:
-          A Tensor with shape (N, T, C) where C is the number of classes
-          (e.g. number of phones or word pieces).  Contains normalized
-          log-probabilities.
-        """
-        x = self.bottleneck_ctc_encoder(positive_embed,
-                                        pos_emb,
-                                        key_padding_mask=memory_key_padding_mask)
-        x = self.bottleneck_ctc_output_layer(x)
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
-        return x
     def self_prediction_forward(
         self,


@@ -20,6 +20,7 @@
 import argparse
 import collections
+import copy
 import logging
 from pathlib import Path
 import random  # temp..
@@ -155,7 +156,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_2"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,  # can't be changed
@@ -171,9 +172,13 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_scale": 0.3,
+            "att_scale": 0.5,
             "reverse_att_scale": 0.2,
-            "bottleneck_ctc_scale": 0.2,  # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
+            "ctc_scale": 0.3,
+            "reconstruction_scale": 0.5,  # Scale on log of reconstruction error after discrete bottleneck.
+            "delay_scale": 2.0,  # Scale on difference between current and
+            # delayed version of positive_embed.
+            "delay_minibatches": 200,
             "attention_dim": 512,
             "nhead": 8,
             "num_trunk_encoder_layers": 12,
@@ -349,6 +354,69 @@ class LossRecord(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
+
+def get_delayed_model(model: nn.Module,
+                      params: AttributeDict) -> nn.Module:
+    if hasattr(model, "module"):
+        model = model.module
+    delay_minibatches = params.delay_minibatches
+    cur_batch_idx = params.batch_idx_train
+    try:
+        # hasattr doesn't seem to work for this... use try-except to test if it
+        # has the attribute.
+        _ = params.cur_delayed_model
+    except KeyError:
+        params.cur_delayed_model = copy.deepcopy(model)
+        params.prev_delayed_model = params.cur_delayed_model
+        params.cur_delayed_batch_idx = cur_batch_idx
+    if params.cur_delayed_batch_idx <= cur_batch_idx - delay_minibatches:
+        params.cur_delayed_batch_idx = cur_batch_idx
+        params.prev_delayed_model = params.cur_delayed_model
+        params.cur_delayed_model = copy.deepcopy(model)
+    return params.prev_delayed_model
+
+def compute_distance(feats1, feats2):
+    """
+    Assumes that feats1 and feats2 are features with (N, T, C) or (T, N, C) layout.
+    Computes a distance between them that has the property that its derivative
+    w.r.t. feats2 is orthogonal to feats2.  This avoids any direct pressure for
+    feats2 to grow or shrink.  (We assume that feats1 is without grad.)
+
+    Let K = T * N (the total number of frames).  The returned value is K times
+    the root of the average, per frame, of the squared distance between feats1
+    and (alpha * feats2), for optimally chosen alpha, i.e. the minimum over
+    alpha of:
+        ans = K * (((feats1 - alpha * feats2) ** 2).sum() / K).sqrt()
+    To find the optimal alpha, set to zero d/d(alpha) of:
+        ((feats1 - alpha * feats2) * (feats1 - alpha * feats2)).sum()
+    Expanded in terms of alpha, this is:
+        (feats1**2).sum() + alpha**2 * (feats2**2).sum() - 2 * alpha * (feats1*feats2).sum()
+    and d/d(alpha) of this is:
+        2 * alpha * (feats2**2).sum() - 2 * (feats1*feats2).sum()
+    so alpha = (feats1*feats2).sum() / (feats2**2).sum()
+    """
+    feats1_prod = (feats1 ** 2).sum()
+    feats2_prod = (feats2 ** 2).sum()
+    cross_prod = (feats2 * feats1).sum()
+    alpha = cross_prod.detach() / feats2_prod.detach()
+    if random.random() < 0.01:
+        logging.info(f"compute_distance: alpha = {alpha.to('cpu').item()}")
+    K = feats1.shape[0] * feats1.shape[1]
+    avg_dist = (feats1_prod + (alpha**2 * feats2_prod) - 2 * alpha * cross_prod) / K
+    if avg_dist <= 0.0:
+        avg_dist = torch.tensor([0.0], device=feats1.device)
+    else:
+        avg_dist = avg_dist.sqrt()
+    return K * avg_dist
 def compute_loss(
     params: AttributeDict,
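As a quick sanity check on the closed-form alpha derived in the compute_distance docstring above, the following minimal sketch (an editorial illustration, not part of this commit; the random tensors and helper names are made up) compares the analytic alpha against a brute-force grid search:

import torch

def best_alpha_closed_form(feats1, feats2):
    # alpha = (feats1 * feats2).sum() / (feats2 ** 2).sum(), as in the docstring above.
    return (feats1 * feats2).sum() / (feats2 ** 2).sum()

def best_alpha_grid(feats1, feats2, lo=-5.0, hi=5.0, steps=20001):
    # Brute-force search over a grid of alpha values, keeping the minimizer.
    alphas = torch.linspace(lo, hi, steps)
    errs = torch.stack([((feats1 - a * feats2) ** 2).sum() for a in alphas])
    return alphas[errs.argmin()]

torch.manual_seed(0)
feats1 = torch.randn(4, 10, 8)  # stand-in for old_positive_embed, (N, T, C)
feats2 = torch.randn(4, 10, 8)  # stand-in for positive_embed
print(best_alpha_closed_form(feats1, feats2).item(),
      best_alpha_grid(feats1, feats2).item())  # should agree to within the grid spacing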
@@ -388,10 +456,28 @@ def compute_loss(
     mmodel = model.module if hasattr(model, "module") else model
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        with torch.no_grad():
+            delayed_model = get_delayed_model(model, params)
+            with torch.random.fork_rng(devices=[device], enabled=True):
+                (old_memory, _, _) = delayed_model(feature, supervisions)
+                (_, _, old_positive_embed, _, _) = delayed_model.sample_forward(old_memory)
     with torch.set_grad_enabled(is_training):
         memory, position_embedding, memory_mask = model(feature, supervisions)
         # memory's shape is (N, T, C)
+        (sampled, softmax, positive_embed,
+         positive_embed_shifted,
+         negative_embed_shifted) = mmodel.sample_forward(memory)
+        if params.cur_epoch > 0 and params.delay_scale > 0.0:
+            delay_loss = compute_distance(old_positive_embed, positive_embed)
+        num_subsampled_frames = memory.shape[0] * memory.shape[1]
+        reconstruction_loss = (((positive_embed - memory.detach()) ** 2).sum() / num_subsampled_frames).sqrt() * num_subsampled_frames
         ctc_output = mmodel.ctc_encoder_forward(memory,
                                                 position_embedding,
@@ -437,13 +523,6 @@ def compute_loss(
     if params.reverse_att_scale != 0.0:
         with torch.set_grad_enabled(is_training):
-            (sampled, softmax, positive_embed,
-             positive_embed_shifted,
-             negative_embed_shifted) = mmodel.sample_forward(memory)
-
-            #if True: # TEMP
-            #    positive_embed_shifted = torch.randn_like(positive_embed_shifted)
-            #    negative_embed_shifted = positive_embed_shifted
             reverse_decoder_logprob = mmodel.reverse_decoder_forward(
                 positive_embed_shifted,
@@ -472,39 +551,27 @@ def compute_loss(
                      f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                      f"reverse_att_loss = {reverse_att_loss/num_frames}")
-        bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed,
-                                                                      position_embedding,
-                                                                      memory_mask)
-        dense_fsa_vec = k2.DenseFsaVec(
-            bottleneck_ctc_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-        bottleneck_ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
     else:
         reverse_att_loss = torch.tensor([0.0]).to(device)
-        bottleneck_ctc_loss = torch.tensor([0.0]).to(device)
-    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale - params.bottleneck_ctc_scale
-    loss = (ctc_scale * ctc_loss +
-            params.bottleneck_ctc_scale * bottleneck_ctc_loss +
+    loss = (params.ctc_scale * ctc_loss +
+            (params.reconstruction_scale if params.cur_epoch > 0 else 0.1 * params.reconstruction_scale) * reconstruction_loss +
             params.att_scale * att_loss +
-            (params.reverse_att_scale if params.cur_epoch > 0 else 0.01 * params.reverse_att_scale) * reverse_att_loss)
+            (params.reverse_att_scale if params.cur_epoch > 0 else 0.001 * params.reverse_att_scale) * reverse_att_loss)
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        loss = loss + params.delay_scale * delay_loss
     assert loss.requires_grad == is_training
     info = LossRecord()
     # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
-    info['bottleneck_ctc_loss'] = bottleneck_ctc_loss.detach().cpu().item()
+    info['reconstruction_loss'] = reconstruction_loss.detach().cpu().item()
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        info['delay_loss'] = delay_loss.detach().cpu().item()
     if params.att_scale != 0.0:
         info['att_loss'] = att_loss.detach().cpu().item()
     if params.reverse_att_scale != 0.0:
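Pulling the pieces of the new objective together, here is a minimal sketch (an editorial illustration only, not the commit's code; the helper name and plain-float arguments are made up) of how the loss terms are weighted, including the epoch-0 warm-up and the delay term:

def combine_losses(params, cur_epoch, ctc_loss, att_loss, reverse_att_loss,
                   reconstruction_loss, delay_loss):
    # Epoch 0 acts as a warm-up: the reconstruction and reverse-attention terms
    # are heavily down-weighted, and the delay term is disabled entirely.
    recon_scale = params.reconstruction_scale if cur_epoch > 0 else 0.1 * params.reconstruction_scale
    rev_scale = params.reverse_att_scale if cur_epoch > 0 else 0.001 * params.reverse_att_scale
    loss = (params.ctc_scale * ctc_loss +
            recon_scale * reconstruction_loss +
            params.att_scale * att_loss +
            rev_scale * reverse_att_loss)
    if cur_epoch > 0 and params.delay_scale > 0.0:
        loss = loss + params.delay_scale * delay_loss
    return loss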