mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 07:04:18 +00:00)

Add BS to zipformer2

This commit is contained in:
parent b1028c0d61
commit 752e16be10
@@ -27,6 +27,8 @@ import torch.nn.functional as F

 from icefall.utils import make_pad_mask

+NON_BLANK_THRES = 0.9
+

 class FrameReducer(nn.Module):
     """The encoder output is first used to calculate
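Since `ctc_output` holds log-probabilities (the CTC head ends in a log-softmax), comparing the blank column against `math.log(NON_BLANK_THRES)` keeps exactly the frames whose blank probability is below 0.9. A minimal sketch of the predicate, assuming a hypothetical `(N, T, C)` log-prob tensor with the blank symbol at index 0:

import math

import torch

# Hypothetical inputs: (N, T, C) log-probs, blank at index 0.
ctc_output = torch.randn(2, 5, 10).log_softmax(dim=-1)
blank_id = 0

# True where the model is NOT confident the frame is blank,
# i.e. where P(blank) < NON_BLANK_THRES = 0.9.
non_blank = ctc_output[:, :, blank_id] < math.log(0.9)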
@@ -72,7 +74,9 @@ class FrameReducer(nn.Module):
         N, T, C = x.size()

         padding_mask = make_pad_mask(x_lens)
-        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
+            ~padding_mask
+        )

         if y_lens is not None or self.training is False:
             # Limit the maximum number of reduced frames
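Only the frames selected by `non_blank_mask` survive the reduction. A minimal sketch of the gather-and-pad step that such a mask typically drives (a hypothetical `reduce_frames` helper, not the exact FrameReducer code):

import torch

def reduce_frames(x: torch.Tensor, non_blank_mask: torch.Tensor):
    # x: (N, T, C) encoder output; non_blank_mask: (N, T) bool.
    N, T, C = x.size()
    out_lens = non_blank_mask.sum(dim=1)  # frames kept per utterance
    out = x.new_zeros(N, int(out_lens.max()), C)
    for i in range(N):
        kept = x[i][non_blank_mask[i]]    # (out_lens[i], C)
        out[i, : kept.size(0)] = kept
    return out, out_lens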
@@ -181,7 +181,6 @@ class AsrModel(nn.Module):
           reduction:
             Specifies the reduction to apply to the output
         """
-        # TODO: Add delay penalty to CTC Loss
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
         encoder_out_fr = encoder_out
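The `_fr` suffix suggests `encoder_out_fr` is the frame-reduced encoder output, initialized to the full output before the reducer runs. A hypothetical sketch of the wiring (the condition and the call signature are assumptions, inferred from the FrameReducer code above):

# Hypothetical wiring, not the commit's exact code:
encoder_out_fr = encoder_out
encoder_out_fr_lens = encoder_out_lens
if self.training:  # hypothetical condition
    encoder_out_fr, encoder_out_fr_lens = self.frame_reducer(
        encoder_out, encoder_out_lens, ctc_output, y_lens, blank_id
    )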
@@ -211,7 +210,7 @@ class AsrModel(nn.Module):
                 token_ids=targets,
             )

-            # TODO: Find out why we need to do that but not in icefall
+            # TODO: Crash without this line
             supervision_segments = supervision_segments.to("cpu")
             decoding_graph = k2.ctc_graph(
                 token_ids, modified=False, device=encoder_out.device
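For context, `k2.DenseFsaVec` requires `supervision_segments` to be an int32 tensor that lives on the CPU even when the posteriors are on the GPU, which is the likely reason the `.to("cpu")` line is needed. A minimal sketch of how a graph built with `k2.ctc_graph` is typically scored against the network output (standard k2 usage, not necessarily this file's exact code; the `output_beam` value is an assumption, plausibly fed by the `ctc_beam_size` parameter added below):

import k2

# ctc_output: (N, T, C) log-probs; supervision_segments: (num_utts, 3)
# int32 CPU tensor of (utterance_index, start_frame, num_frames) rows.
dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
lattice = k2.intersect_dense(decoding_graph, dense_fsa_vec, output_beam=10.0)
tot_scores = lattice.get_tot_scores(log_semiring=True, use_double_scores=True)
ctc_loss = -tot_scores.sum()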
@@ -538,6 +538,7 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
+            "ctc_beam_size": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
@@ -783,6 +784,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    warmup: float,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -823,9 +825,12 @@ def compute_loss(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            supervisions=supervisions,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            ctc_beam_size=params.ctc_beam_size,
+            warmup=warmup,
         )

         loss = 0.0
@@ -886,6 +891,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            warmup=(params.batch_idx_train / params.warm_step),
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -980,6 +986,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                warmup=(params.batch_idx_train / params.warm_step),
             )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
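`params.batch_idx_train / params.warm_step` is the fraction of the warm-up period completed: with `warm_step=2000` from `get_params()` it reaches 1.0 after 2000 training batches and keeps growing afterwards. A hypothetical example of how such a fraction is typically consumed, e.g. ramping an auxiliary loss in gradually:

# Hypothetical use of the warmup fraction, not this commit's exact code:
aux_scale = min(warmup, 1.0)  # 0 -> 1 over the first warm_step batches
loss = main_loss + aux_scale * aux_loss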