mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00

commit 6f8b7b9c3b (parent c4cc952265)

    First version that seems to be converging OK...
@@ -222,6 +222,11 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
 
+        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
+        self.bottleneck_ctc_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
+
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
 
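The two added modules form a small CTC head on top of the sampled bottleneck features: a short Conformer stack followed by dropout and a linear projection to the class dimension. Below is a minimal, self-contained sketch of the same structure; torch.nn.TransformerEncoder stands in for icefall's ConformerEncoder, and the sizes are illustrative rather than taken from this commit.

    import torch
    import torch.nn as nn

    d_model, nhead, num_ctc_encoder_layers, num_classes, dropout = 512, 8, 2, 500, 0.1

    # Stand-in for ConformerEncoder(encoder_layer, num_ctc_encoder_layers).
    encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout)
    bottleneck_ctc_encoder = nn.TransformerEncoder(encoder_layer, num_ctc_encoder_layers)

    # Per-frame projection to class scores, mirroring bottleneck_ctc_output_layer.
    bottleneck_ctc_output_layer = nn.Sequential(
        nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
    )

    x = torch.randn(100, 4, d_model)                       # (T, N, E)
    y = bottleneck_ctc_output_layer(bottleneck_ctc_encoder(x))
    print(y.shape)                                         # torch.Size([100, 4, 500])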
@@ -354,7 +359,8 @@ class BidirectionalConformer(nn.Module):
         Given the "memory" from forward(), run the sample_and_predict module.
         See documentation for forward() of class SampleAndPredict for more info.
 
-        Returns (sampled, softmax, positive_embed_shifted, negative_embed_shifted),
+        Returns (sampled, softmax, positive_embed, positive_embed_shifted,
+                 negative_embed_shifted),
         where positive_embed_shifted, for instance, is positive_embed
         shifted by one so that positive_embed_shifted[t] == positive_embed[t-1], as in:
               (T, N, E) = positive_embed.shape
@@ -368,7 +374,7 @@ class BidirectionalConformer(nn.Module):
               positive_embed_shifted = torch.cat((zeros, positive_embed[:-1,:,:]), dim=0)
               negative_embed_shifted = torch.cat((zeros, negative_embed[:-1,:,:]), dim=0)
 
-        return (sampled, softmax, positive_embed_shifted, negative_embed_shifted)
+        return (sampled, softmax, positive_embed, positive_embed_shifted, negative_embed_shifted)
 
     def decoder_forward(
         self,
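The one-frame shift that the docstring describes is easy to check in isolation. A small sketch with arbitrary shapes, confirming that frame 0 becomes zeros and every later frame sees the previous frame's embedding:

    import torch

    T, N, E = 5, 2, 3
    positive_embed = torch.randn(T, N, E)

    # Prepend a zero frame and drop the last frame, as in sample_forward().
    zeros = torch.zeros(1, N, E)
    positive_embed_shifted = torch.cat((zeros, positive_embed[:-1, :, :]), dim=0)

    assert torch.equal(positive_embed_shifted[0], zeros[0])
    for t in range(1, T):
        assert torch.equal(positive_embed_shifted[t], positive_embed[t - 1])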
@@ -451,7 +457,7 @@ class BidirectionalConformer(nn.Module):
     ) -> torch.Tensor:
         """
         Passes the output of forward() through the CTC encoder and the CTC
-        output to give the output that can be given to the CTC loss function
+        output layer to give the output that can be given to the CTC loss function
 
         Args:
           memory:
@@ -474,6 +480,38 @@ class BidirectionalConformer(nn.Module):
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
         return x
 
+    def bottleneck_ctc_encoder_forward(
+        self,
+        positive_embed: torch.Tensor,
+        pos_emb: torch.Tensor,
+        memory_key_padding_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Passes the output of sample_forward() through the "from-bottleneck"
+        CTC encoder and the CTC output layer, giving the output that can be
+        passed to the CTC loss function.
+
+        Args:
+          positive_embed:
+            One of the outputs of sample_forward(), with shape (T, N, E)
+          pos_emb:
+            Relative positional embedding tensor: (N, 2*T-1, E)
+          memory_key_padding_mask:
+            The padding mask from forward(), a tensor of bool of shape (N, T)
+
+        Returns:
+            A Tensor with shape [N, T, C] where C is the number of classes
+            (e.g. number of phones or word pieces).  Contains normalized
+            log-probabilities.
+        """
+        x = self.bottleneck_ctc_encoder(positive_embed,
+                                        pos_emb,
+                                        key_padding_mask=memory_key_padding_mask)
+        x = self.bottleneck_ctc_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+        x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
+        return x
+
+
     def self_prediction_forward(
         self,
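The tail of the new method only rearranges and normalizes the encoder output: (T, N, E) frames come out as (N, T, C) per-frame log-probabilities, which is the layout k2's CTC loss expects. A toy sketch of just that final step, with the encoder omitted and C chosen arbitrarily:

    import torch
    import torch.nn as nn

    T, N, C = 100, 4, 500                     # frames, batch, classes (illustrative)
    x = torch.randn(T, N, C)                  # pretend output of the CTC output layer

    x = x.permute(1, 0, 2)                    # (T, N, C) -> (N, T, C)
    x = nn.functional.log_softmax(x, dim=-1)  # normalized log-probabilities

    print(x.shape)                                                # torch.Size([4, 100, 500])
    print(torch.allclose(x.exp().sum(dim=-1), torch.ones(N, T)))  # True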
@@ -171,8 +171,9 @@ def get_params() -> AttributeDict:
             "reduction": "sum",
             "use_double_scores": True,
             "accum_grad": 1,
-            "att_scale": 0.6,
-            "reverse_att_scale": 0.1,  # ctc_scale == 1.0 - att_scale - reverse_att_scale
+            "att_scale": 0.3,
+            "reverse_att_scale": 0.2,
+            "bottleneck_ctc_scale": 0.2,  # ctc_scale == 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
             "attention_dim": 512,
             "nhead": 8,
             "num_trunk_encoder_layers": 12,
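The CTC weight itself is not stored; it is derived so that the four scales sum to one. With the new defaults it stays at roughly the same value as before (1.0 - 0.6 - 0.1 = 0.3 previously), now computed as:

    att_scale, reverse_att_scale, bottleneck_ctc_scale = 0.3, 0.2, 0.2
    ctc_scale = 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
    print(ctc_scale)  # 0.29999999999999993, i.e. ~0.3 up to floating-point rounding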
@@ -391,6 +392,7 @@ def compute_loss(
         memory, position_embedding, memory_mask = model(feature, supervisions)
         # memory's shape is (N, T, C)
 
+
         ctc_output = mmodel.ctc_encoder_forward(memory,
                                                 position_embedding,
                                                 memory_mask)
@@ -435,10 +437,14 @@ def compute_loss(
 
     if params.reverse_att_scale != 0.0:
         with torch.set_grad_enabled(is_training):
-            (sampled, softmax,
+            (sampled, softmax, positive_embed,
              positive_embed_shifted,
              negative_embed_shifted) = mmodel.sample_forward(memory)
 
+            #if True: # TEMP
+            # positive_embed_shifted = torch.randn_like(positive_embed_shifted)
+            # negative_embed_shifted = positive_embed_shifted
+
             reverse_decoder_logprob = mmodel.reverse_decoder_forward(
                 positive_embed_shifted,
                 memory_mask,
@@ -465,19 +471,40 @@ def compute_loss(
             print(f"Self-prediction logprob = {self_prediction_logprob/num_frames}, "
                   f"reverse-decoder logprob = {reverse_decoder_logprob/num_frames}, "
                   f"reverse_att_loss = {reverse_att_loss/num_frames}")
+
+            bottleneck_ctc_output = mmodel.bottleneck_ctc_encoder_forward(positive_embed,
+                                                                          position_embedding,
+                                                                          memory_mask)
+
+            dense_fsa_vec = k2.DenseFsaVec(
+                bottleneck_ctc_output,
+                supervision_segments,
+                allow_truncate=params.subsampling_factor - 1,
+            )
+
+            bottleneck_ctc_loss = k2.ctc_loss(
+                decoding_graph=decoding_graph,
+                dense_fsa_vec=dense_fsa_vec,
+                output_beam=params.beam_size,
+                reduction=params.reduction,
+                use_double_scores=params.use_double_scores,
+            )
     else:
         reverse_att_loss = torch.tensor([0.0]).to(device)
+        bottleneck_ctc_loss = torch.tensor([0.0]).to(device)
 
-    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale
+    ctc_scale = 1.0 - params.att_scale - params.reverse_att_scale - params.bottleneck_ctc_scale
     loss = (ctc_scale * ctc_loss +
+            params.bottleneck_ctc_scale * bottleneck_ctc_loss +
             params.att_scale * att_loss +
-            params.reverse_att_scale * reverse_att_loss)
+            (params.reverse_att_scale if params.cur_epoch > 0 else 0.01 * params.reverse_att_scale) * reverse_att_loss)
     assert loss.requires_grad == is_training
 
     info = LossRecord()
     # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
+    info['bottleneck_ctc_loss'] = bottleneck_ctc_loss.detach().cpu().item()
     if params.att_scale != 0.0:
         info['att_loss'] = att_loss.detach().cpu().item()
     if params.reverse_att_scale != 0.0:
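Putting the weighting together: the total loss is a weighted sum of the CTC, bottleneck-CTC, attention and reverse-attention terms, and the reverse-attention weight is cut by a factor of 100 during epoch 0 (params.cur_epoch == 0, recorded in the training loop below), which reads as a warm-up for that branch. A minimal sketch with placeholder loss values, not taken from any actual run:

    import torch

    # Weights matching the new defaults in get_params().
    att_scale, reverse_att_scale, bottleneck_ctc_scale = 0.3, 0.2, 0.2
    ctc_scale = 1.0 - att_scale - reverse_att_scale - bottleneck_ctc_scale
    cur_epoch = 0  # set once per epoch via params.cur_epoch = epoch

    # Placeholder scalars standing in for the real k2/attention losses.
    ctc_loss = torch.tensor(10.0)
    att_loss = torch.tensor(8.0)
    reverse_att_loss = torch.tensor(12.0)
    bottleneck_ctc_loss = torch.tensor(11.0)

    loss = (ctc_scale * ctc_loss +
            bottleneck_ctc_scale * bottleneck_ctc_loss +
            att_scale * att_loss +
            (reverse_att_scale if cur_epoch > 0 else 0.01 * reverse_att_scale) * reverse_att_loss)
    print(loss)  # ~tensor(7.6240) for these placeholder values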
@@ -709,6 +736,7 @@ def run(rank, world_size, args):
     for epoch in range(params.start_epoch, params.num_epochs):
         optimizer.set_epoch(epoch)  # specific to Gloam
         train_dl.sampler.set_epoch(epoch)
+        params.cur_epoch = epoch
 
         cur_lr = optimizer._rate
         if tb_writer is not None: