mirror of https://github.com/k2-fsa/icefall.git
Add BS to zipformer2

commit 752e16be10
parent b1028c0d61
@@ -27,6 +27,8 @@ import torch.nn.functional as F
 
 from icefall.utils import make_pad_mask
 
+NON_BLANK_THRES = 0.9
+
 
 class FrameReducer(nn.Module):
     """The encoder output is first used to calculate
@@ -72,7 +74,9 @@ class FrameReducer(nn.Module):
         N, T, C = x.size()
 
         padding_mask = make_pad_mask(x_lens)
-        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(0.9)) * (~padding_mask)
+        non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) * (
+            ~padding_mask
+        )
 
         if y_lens is not None or self.training is False:
             # Limit the maximum number of reduced frames
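For context: the hunk above marks a frame as "non-blank" only when the CTC blank log-probability falls below log(NON_BLANK_THRES) and the frame is real input rather than padding. A minimal standalone sketch of that mask, where the shapes and the make_pad_mask expansion (True on padded positions) are assumptions rather than icefall's exact code, and & is used in place of the diff's * (equivalent on bool tensors):

import math

import torch

NON_BLANK_THRES = 0.9
blank_id = 0

N, T, C = 2, 6, 5
ctc_output = torch.log_softmax(torch.randn(N, T, C), dim=-1)  # (N, T, C) log-probs
x_lens = torch.tensor([6, 4])  # valid frames per utterance

# make_pad_mask equivalent: True at padded positions.
padding_mask = torch.arange(T).unsqueeze(0) >= x_lens.unsqueeze(1)

# Keep a frame only if blank log-prob < log(0.9) and the frame is not padding.
non_blank_mask = (ctc_output[:, :, blank_id] < math.log(NON_BLANK_THRES)) & ~padding_mask

# Dropping the blank-dominated frames of utterance 0, for illustration:
reduced_0 = ctc_output[0][non_blank_mask[0]]  # (num_kept_frames, C)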
@@ -181,7 +181,6 @@ class AsrModel(nn.Module):
           reduction:
             Specifies the reduction to apply to the output
         """
-        # TODO: Add delay penalty to CTC Loss
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
         encoder_out_fr = encoder_out
@@ -211,7 +210,7 @@ class AsrModel(nn.Module):
             token_ids=targets,
         )
 
-        # TODO: Find out why we need to do that but not in icefall
+        # TODO: Crash without this line
         supervision_segments = supervision_segments.to("cpu")
         decoding_graph = k2.ctc_graph(
             token_ids, modified=False, device=encoder_out.device
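The crash the TODO refers to is most likely because k2.DenseFsaVec only accepts supervision_segments as an int32 CPU tensor, regardless of where the log-probs live. A minimal self-contained sketch of the surrounding k2 CTC loss computation; every shape and value beyond the calls visible in the diff is an assumption:

import k2
import torch

# Assumed toy setup: one utterance, 5 frames, 3 classes with blank_id = 0.
ctc_output = torch.log_softmax(torch.randn(1, 5, 3), dim=-1)  # (N, T, C)
token_ids = [[1, 2]]  # token ids of the single utterance

# Rows of [sequence_index, start_frame, num_frames]; must be int32 on CPU.
supervision_segments = torch.tensor([[0, 0, 5]], dtype=torch.int32)

decoding_graph = k2.ctc_graph(token_ids, modified=False, device=ctc_output.device)
dense_fsa_vec = k2.DenseFsaVec(ctc_output, supervision_segments)
ctc_loss = k2.ctc_loss(
    decoding_graph=decoding_graph,
    dense_fsa_vec=dense_fsa_vec,
    output_beam=10.0,  # what this commit exposes as params.ctc_beam_size
    reduction="sum",
)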
@@ -538,6 +538,7 @@ def get_params() -> AttributeDict:
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
             # parameters for zipformer
+            "ctc_beam_size": 10,
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
             "warm_step": 2000,
@@ -783,6 +784,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    warmup: float,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute loss given the model and its inputs.
@@ -823,9 +825,12 @@ def compute_loss(
             x=feature,
             x_lens=feature_lens,
             y=y,
+            supervisions=supervisions,
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            ctc_beam_size=params.ctc_beam_size,
+            warmup=warmup,
         )
 
     loss = 0.0
@@ -886,6 +891,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            warmup=(params.batch_idx_train / params.warm_step),
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -980,6 +986,7 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
+            warmup=(params.batch_idx_train / params.warm_step),
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
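Both call sites pass warmup as batch_idx_train / warm_step, a linear ramp that reaches 1.0 at warm_step (2000 in get_params above) and keeps growing afterwards. A standalone sketch of that schedule; the function name is illustrative, not the training script's:

def warmup_factor(batch_idx_train: int, warm_step: int = 2000) -> float:
    """Linear warmup signal passed to compute_loss: 0.0 at the first batch,
    1.0 at warm_step, and above 1.0 beyond it (consumers may clamp it)."""
    return batch_idx_train / warm_step


assert warmup_factor(0) == 0.0
assert warmup_factor(1000) == 0.5
assert warmup_factor(3000) == 1.5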