Add BS to zipformer2

This commit is contained in:
Erwan 2023-08-15 10:26:38 +02:00
parent b1028c0d61
commit 752e16be10
3 changed files with 13 additions and 3 deletions

View File

@ -27,6 +27,8 @@ import torch.nn.functional as F
from icefall.utils import make_pad_mask
NON_BLANK_THRES = 0.9
class FrameReducer(nn.Module):
"""The encoder output is first used to calculate
@ -72,7 +74,9 @@ class FrameReducer(nn.Module):
N, T, C = x.size()
padding_mask = make_pad_mask(x_lens)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
~padding_mask
)
if y_lens is not None or self.training is False:
# Limit the maximum number of reduced frames

View File

@ -181,7 +181,6 @@ class AsrModel(nn.Module):
reduction:
Specifies the reduction to apply to the output
"""
# TODO: Add delay penalty to CTC Loss
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
encoder_out_fr = encoder_out
@ -211,7 +210,7 @@ class AsrModel(nn.Module):
token_ids=targets,
)
# TODO: Find out why we need to do that but not in icefall
# TODO: Crash without this line
supervision_segments = supervision_segments.to("cpu")
decoding_graph = k2.ctc_graph(
token_ids, modified=False, device=encoder_out.device

View File

@ -538,6 +538,7 @@ def get_params() -> AttributeDict:
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for zipformer
"ctc_beam_size": 10,
"feature_dim": 80,
"subsampling_factor": 4, # not passed in, this is fixed.
"warm_step": 2000,
@ -783,6 +784,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute loss given the model and its inputs.
@ -823,9 +825,12 @@ def compute_loss(
x=feature,
x_lens=feature_lens,
y=y,
supervisions=supervisions,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
ctc_beam_size=params.ctc_beam_size,
warmup=warmup,
)
loss = 0.0
@ -886,6 +891,7 @@ def compute_validation_loss(
sp=sp,
batch=batch,
is_training=False,
warmup=(params.batch_idx_train / params.warm_step),
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
@ -980,6 +986,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info