Mirror of https://github.com/k2-fsa/icefall.git

Initially working version with delay_loss...

commit 2213457bd3 (parent 65b737576e)
@@ -222,11 +222,6 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
 
-        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
-        self.bottleneck_ctc_output_layer = nn.Sequential(
-            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
-        )
-
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
 
@@ -480,38 +475,6 @@ class BidirectionalConformer(nn.Module):
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
 
-    def bottleneck_ctc_encoder_forward(
-        self,
-        positive_embed: torch.Tensor,
-        pos_emb: torch.Tensor,
-        memory_key_padding_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Passes the output of sample_forward() through the "from-bottleneck"
-        CTC encoder and the CTC output layer to give the output that can
-        be given to the CTC loss function.
-
-        Args:
-          positive_embed:
-            One of the outputs of sample_forward(), with shape (T, N, E)
-          pos_emb:
-            Relative positional embedding tensor: (N, 2*T-1, E)
-          memory_key_padding_mask:
-            The padding mask from forward(), a tensor of bool of shape (N, T)
-
-        Returns:
-          A Tensor with shape (N, T, C) where C is the number of classes
-          (e.g. number of phones or word pieces). Contains normalized
-          log-probabilities.
-        """
-        x = self.bottleneck_ctc_encoder(positive_embed,
-                                        pos_emb,
-                                        key_padding_mask=memory_key_padding_mask)
-        x = self.bottleneck_ctc_output_layer(x)
-        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
-        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
-        return x
-
     def self_prediction_forward(
         self,
@@ -20,6 +20,7 @@
 
 import argparse
 import collections
+import copy
 import logging
 from pathlib import Path
 import random  # temp..
@@ -155,7 +156,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_2"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,  # can't be changed
@@ -171,9 +172,13 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_scale": 0.3,
+            "att_scale": 0.5,
             "reverse_att_scale": 0.2,
-            "bottleneck_ctc_scale": 0.2,  # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
+            "ctc_scale": 0.3,
+            "reconstruction_scale": 0.5,  # Scale on log of reconstruction error after discrete bottleneck.
+            "delay_scale": 2.0,  # Scale on difference between current and
+                                 # delayed version of positive_embed.
+            "delay_minibatches": 200,
             "attention_dim": 512,
             "nhead": 8,
             "num_trunk_encoder_layers": 12,
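Note that `ctc_scale` is now an explicit parameter rather than the derived quantity the old comment described, so the loss weights are no longer tied to sum to 1. Quick arithmetic on the new values (an illustration, not part of the commit):

    # Old scheme: ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
    # New scheme: all weights are free parameters.
    scales = {
        "ctc_scale": 0.3,
        "att_scale": 0.5,
        "reverse_att_scale": 0.2,
        "reconstruction_scale": 0.5,
        "delay_scale": 2.0,
    }
    print(sum(scales.values()))  # 3.5 -- no longer constrained to 1.0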
@@ -349,6 +354,69 @@ class LossRecord(collections.defaultdict):
             tb_writer.add_scalar(prefix + k, v, batch_idx)
 
 
+def get_delayed_model(model: nn.Module,
+                      params: AttributeDict) -> nn.Module:
+    if hasattr(model, "module"):
+        model = model.module
+    delay_minibatches = params.delay_minibatches
+
+    cur_batch_idx = params.batch_idx_train
+
+    try:
+        # hasattr doesn't seem to work for this... use try-except to test
+        # whether it has the attribute.
+        _ = params.cur_delayed_model
+    except KeyError:
+        params.cur_delayed_model = copy.deepcopy(model)
+        params.prev_delayed_model = params.cur_delayed_model
+        params.cur_delayed_batch_idx = cur_batch_idx
+
+    if params.cur_delayed_batch_idx <= cur_batch_idx - delay_minibatches:
+        params.cur_delayed_batch_idx = cur_batch_idx
+        params.prev_delayed_model = params.cur_delayed_model
+        params.cur_delayed_model = copy.deepcopy(model)
+
+    return params.prev_delayed_model
+
+
+def compute_distance(feats1, feats2):
+    """
+    Assumes that feats1 and feats2 are some kind of features with (N, T, C)
+    or (T, N, C) layout.  Computes a distance between them whose derivative
+    w.r.t. feats2 is orthogonal to feats2, which avoids any direct pressure
+    for feats2 to grow or shrink.  (We assume that feats1 is without grad.)
+    Let K = N * T (the total number of frames).  Then the returned value is
+    K times the square root of the average, per frame, of the squared
+    distance between feats1 and (alpha * feats2), for optimally chosen
+    alpha, i.e. the minimum over alpha of:
+
+        ans = K * (((feats1 - alpha * feats2) ** 2).sum() / K).sqrt()
+
+    To find that alpha, take d/d(alpha) of:
+
+        ((feats1 - alpha * feats2) * (feats1 - alpha * feats2)).sum()
+
+    Expanded in terms of alpha, this is:
+
+        (feats1**2).sum() + alpha**2 * (feats2**2).sum() - 2 * alpha * (feats1*feats2).sum()
+
+    and d/d(alpha) of this is:
+
+        2 * alpha * (feats2**2).sum() - 2 * (feats1*feats2).sum()
+
+    so alpha = (feats1*feats2).sum() / (feats2**2).sum()
+    """
+    feats1_prod = (feats1 ** 2).sum()
+    feats2_prod = (feats2 ** 2).sum()
+    cross_prod = (feats2 * feats1).sum()
+    alpha = cross_prod.detach() / feats2_prod.detach()
+    if random.random() < 0.01:
+        logging.info(f"compute_distance: alpha = {alpha.to('cpu').item()}")
+    K = feats1.shape[0] * feats1.shape[1]
+    avg_dist = (feats1_prod + (alpha ** 2 * feats2_prod) - 2 * alpha * cross_prod) / K
+    if avg_dist <= 0.0:
+        avg_dist = torch.tensor([0.0], device=feats1.device)
+    else:
+        avg_dist = avg_dist.sqrt()
+    return K * avg_dist
+
+
 def compute_loss(
     params: AttributeDict,
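As a sanity check on the derivation in the docstring of compute_distance(): with alpha chosen in closed form and detached, the gradient of the returned distance with respect to feats2 is orthogonal to feats2, so the loss puts no direct pressure on the scale of feats2. A small self-contained sketch (not part of the commit) that verifies this numerically:

    import torch

    torch.manual_seed(0)
    feats1 = torch.randn(4, 10, 8)                      # (N, T, C), treated as constant
    feats2 = torch.randn(4, 10, 8, requires_grad=True)

    # Closed-form alpha from the docstring, detached as in compute_distance().
    alpha = ((feats1 * feats2).sum() / (feats2 ** 2).sum()).detach()

    K = feats1.shape[0] * feats1.shape[1]
    dist = K * (((feats1 - alpha * feats2) ** 2).sum() / K).sqrt()
    dist.backward()

    # Inner product of the gradient with feats2 vanishes (up to float error).
    print((feats2.grad * feats2).sum().item())          # ~0.0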
@@ -388,10 +456,28 @@ def compute_loss(
 
     mmodel = model.module if hasattr(model, "module") else model
 
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        with torch.no_grad():
+            delayed_model = get_delayed_model(model, params)
+            with torch.random.fork_rng(devices=[device], enabled=True):
+                (old_memory, _, _) = delayed_model(feature, supervisions)
+                (_, _, old_positive_embed, _, _) = delayed_model.sample_forward(old_memory)
+
+
     with torch.set_grad_enabled(is_training):
         memory, position_embedding, memory_mask = model(feature, supervisions)
         # memory's shape is (N, T, C)
 
+        (sampled, softmax, positive_embed,
+         positive_embed_shifted,
+         negative_embed_shifted) = mmodel.sample_forward(memory)
+
+        if params.cur_epoch > 0 and params.delay_scale > 0.0:
+            delay_loss = compute_distance(old_positive_embed, positive_embed)
+
+        num_subsampled_frames = memory.shape[0] * memory.shape[1]
+        reconstruction_loss = (((positive_embed - memory.detach()) ** 2).sum() / num_subsampled_frames).sqrt() * num_subsampled_frames
+
+
         ctc_output = mmodel.ctc_encoder_forward(memory,
                                                 position_embedding,
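The delayed forward pass runs inside torch.random.fork_rng because it consumes random numbers (dropout, the stochastic sample_forward); forking saves and restores the RNG state, so the main model's forward pass draws the same random stream whether or not the delayed pass ran. A minimal CPU-only sketch of that behaviour (the commit forks the RNG of `device` instead):

    import torch

    torch.manual_seed(42)
    baseline = torch.rand(3)                  # what the "main" pass would draw

    torch.manual_seed(42)
    with torch.random.fork_rng(devices=[]):   # devices=[] forks only the CPU RNG
        _ = torch.rand(100)                   # stand-in for the delayed model's stochastic forward

    print(torch.allclose(torch.rand(3), baseline))  # True: RNG state was restored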
@@ -437,13 +523,6 @@
 
     if params.reverse_att_scale != 0.0:
         with torch.set_grad_enabled(is_training):
-            (sampled, softmax, positive_embed,
-             positive_embed_shifted,
-             negative_embed_shifted) = mmodel.sample_forward(memory)
-
-            #if True: # TEMP
-            #    positive_embed_shifted = torch.randn_like(positive_embed_shifted)
-            #    negative_embed_shifted = positive_embed_shifted
 
             reverse_decoder_logprob = mmodel.reverse_decoder_forward(
                 positive_embed_shifted,
@@ -472,39 +551,27 @@
                      f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                      f"reverse_att_loss = {reverse_att_loss/num_frames}")
 
-        bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed,
-                                                                      position_embedding,
-                                                                      memory_mask)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            bottleneck_ctc_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        bottleneck_ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
     else:
         reverse_att_loss = torch.tensor([0.0]).to(device)
-        bottleneck_ctc_loss = torch.tensor([0.0]).to(device)
 
-    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale - params.bottleneck_ctc_scale
-    loss = (ctc_scale * ctc_loss +
-            params.bottleneck_ctc_scale * bottleneck_ctc_loss +
+    loss = (params.ctc_scale * ctc_loss +
+            (params.reconstruction_scale if params.cur_epoch > 0 else 0.1 * params.reconstruction_scale) * reconstruction_loss +
             params.att_scale * att_loss +
-            (params.reverse_att_scale if params.cur_epoch > 0 else 0.01 * params.reverse_att_scale) * reverse_att_loss)
+            (params.reverse_att_scale if params.cur_epoch > 0 else 0.001 * params.reverse_att_scale) * reverse_att_loss)
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        loss = loss + params.delay_scale * delay_loss
+
 
     assert loss.requires_grad == is_training
 
     info = LossRecord()
     # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
-    info['bottleneck_ctc_loss'] = bottleneck_ctc_loss.detach().cpu().item()
+    info['reconstruction_loss'] = reconstruction_loss.detach().cpu().item()
+    if params.cur_epoch > 0 and params.delay_scale > 0.0:
+        info['delay_loss'] = delay_loss.detach().cpu().item()
     if params.att_scale != 0.0:
         info['att_loss'] = att_loss.detach().cpu().item()
     if params.reverse_att_scale != 0.0:
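To make the delay concrete: get_delayed_model() keeps two snapshots and rotates them every delay_minibatches batches, so after the first rotation the parameters being matched against are between delay_minibatches and roughly 2 * delay_minibatches batches old. A toy trace of that schedule, with batch indices standing in for model copies (hypothetical, not from the commit):

    delay_minibatches = 200
    cur_idx = None                  # batch at which cur_snap was taken

    for step in [0, 100, 200, 399, 400, 600]:
        if cur_idx is None:         # first call: initialize both snapshots
            cur_idx, prev_snap, cur_snap = step, step, step
        if cur_idx <= step - delay_minibatches:
            cur_idx, prev_snap, cur_snap = step, cur_snap, step
        print(f"batch {step}: comparing against snapshot from batch {prev_snap}")

    # prints snapshots 0, 0, 0, 0, 200, 400: after warm-up, 200-400 batches behind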