mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)
Remove reconstruction loss, have randomly averaged CTC loss
commit 6fa0f16e0c (parent 3415dab779)
@@ -222,6 +222,11 @@ class BidirectionalConformer(nn.Module):
             nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
         )
 
+        self.bottleneck_ctc_encoder = ConformerEncoder(encoder_layer, num_ctc_encoder_layers)
+        self.bottleneck_ctc_output_layer = nn.Sequential(
+            nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
+        )
+
         # absolute position encoding, used by various layer types
         self.abs_pos = PositionalEncoding(d_model, dropout)
 
@@ -449,6 +454,7 @@ class BidirectionalConformer(nn.Module):
         memory: torch.Tensor,
         pos_emb: torch.Tensor,
         memory_key_padding_mask: torch.Tensor,
+        positive_embed: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         """
         Passes the output of forward() through the CTC encoder and the CTC
@@ -461,6 +467,8 @@ class BidirectionalConformer(nn.Module):
             Relative positional embedding tensor: (N, 2*T-1, E)
           memory_key_padding_mask:
             The padding mask from forward(), a tensor of bool of shape (N, T)
+          positive_embed:
+            Needed only during training, so we can train the bottleneck layer.
 
         Returns:
           A Tensor with shape [N, T, C] where C is the number of classes
@@ -473,6 +481,21 @@ class BidirectionalConformer(nn.Module):
         x = self.ctc_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         x = nn.functional.log_softmax(x, dim=-1)  # (N, T, C)
 
+        if self.training:
+            # Randomly interpolate half-and-half with the bottleneck CTC
+            # encoder, at the frame level
+            y = self.bottleneck_ctc_encoder(positive_embed,
+                                            pos_emb,
+                                            key_padding_mask=memory_key_padding_mask)
+            y = self.bottleneck_ctc_output_layer(y)
+            y = y.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
+            y = nn.functional.log_softmax(y, dim=-1)  # (N, T, C)
+            (N, T, C) = y.shape
+            r = torch.rand(N, T, 1, device=y.device)
+            x = (y * r) + x - (x * r)
+
+
         return x
 
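Note: the mixing step added in the hunk above is a per-frame convex combination of the two log-softmax outputs; with r drawn uniformly in [0, 1) for every (batch, frame) position, x = (y * r) + x - (x * r) is algebraically r * y + (1 - r) * x. A minimal standalone sketch of that interpolation, with illustrative shapes and names that are not taken from the repository:

import torch

def random_frame_interpolation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x, y: (N, T, C) log-probabilities from the two CTC branches.
    # r is sampled independently per (batch, frame) position and broadcast
    # over the class dimension, so each frame gets its own random mixture.
    N, T, _ = x.shape
    r = torch.rand(N, T, 1, device=x.device)
    return (y * r) + x - (x * r)  # same as r * y + (1 - r) * x

# Illustrative shapes only: batch of 2, 5 frames, 10 classes.
x = torch.randn(2, 5, 10).log_softmax(dim=-1)  # main CTC branch
y = torch.randn(2, 5, 10).log_softmax(dim=-1)  # bottleneck CTC branch
print(random_frame_interpolation(x, y).shape)  # torch.Size([2, 5, 10])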
@@ -156,7 +156,7 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay"),
+            "exp_dir": Path("conformer_ctc_bn_2d/exp_bidirectional_delay_norecon"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "subsampling_factor": 4,  # can't be changed
@@ -175,8 +175,7 @@ def get_params() -> AttributeDict:
             "att_scale": 0.5,
             "reverse_att_scale": 0.2,
             "ctc_scale": 0.3,
-            "reconstruction_scale": 0.5,  # Scale on log of reconstruction error after discrete bottleneck.
-            "delay_scale": 2.0,  # Scale on difference between current and
+            "delay_scale": 0.1,  # Scale on difference between current and
                                  # delayed version of positive_embed.
             "delay_minibatches": 200,
             "attention_dim": 512,
@@ -476,12 +475,11 @@ def compute_loss(
         delay_loss = compute_distance(old_positive_embed, positive_embed)
 
-        num_subsampled_frames = memory.shape[0] * memory.shape[1]
-        reconstruction_loss = (((positive_embed - memory.detach()) ** 2).sum() / num_subsampled_frames).sqrt() * num_subsampled_frames
 
 
         ctc_output = mmodel.ctc_encoder_forward(memory,
                                                 position_embedding,
-                                                memory_mask)
+                                                memory_mask,
+                                                positive_embed)
 
 
         # NOTE: We need `encode_supervisions` to sort sequences with
@@ -556,7 +554,6 @@ def compute_loss(
 
 
         loss = (params.ctc_scale * ctc_loss +
-                (params.reconstruction_scale if params.cur_epoch > 0 else 0.1 * params.reconstruction_scale) * reconstruction_loss +
                 params.att_scale * att_loss +
                 (params.reverse_att_scale if params.cur_epoch > 0 else 0.001 * params.reverse_att_scale) * reverse_att_loss)
         if params.cur_epoch > 0 and params.delay_scale > 0.0:
@@ -569,7 +566,6 @@ def compute_loss(
     # TODO: there are many GPU->CPU transfers here, maybe combine them into one.
     info['frames'] = supervision_segments[:, 2].sum().item()
     info['ctc_loss'] = ctc_loss.detach().cpu().item()
-    info['reconstruction_loss'] = reconstruction_loss.detach().cpu().item()
     if params.cur_epoch > 0 and params.delay_scale > 0.0:
         info['delay_loss'] = delay_loss.detach().cpu().item()
     if params.att_scale != 0.0:
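Note: with the reconstruction term removed, the total loss in compute_loss is the weighted sum shown above (CTC, attention, and reverse-attention losses), with the reverse-attention scale cut by 1000x during the first epoch; the delay term is handled separately under the `if params.cur_epoch > 0 and params.delay_scale > 0.0:` branch, whose body this diff does not show. A rough numeric sketch of the combination, using the scale values from get_params() but placeholder loss values:

import torch

# Scale values from get_params() in this commit; the loss values themselves
# are placeholders, not numbers from a real run.
ctc_scale, att_scale, reverse_att_scale = 0.3, 0.5, 0.2
cur_epoch = 0

ctc_loss = torch.tensor(120.0)          # placeholder
att_loss = torch.tensor(80.0)           # placeholder
reverse_att_loss = torch.tensor(95.0)   # placeholder

# The reverse-attention term is heavily down-weighted during the first epoch.
reverse_scale = reverse_att_scale if cur_epoch > 0 else 0.001 * reverse_att_scale

loss = (ctc_scale * ctc_loss +
        att_scale * att_loss +
        reverse_scale * reverse_att_loss)
print(loss)  # tensor(76.0190)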