training pruned_transducer_stateless4 with delay-penalty
parent 116d0cf26d
commit 718086460e
@@ -78,6 +78,8 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        delay_penalty: float = 0.0,
+        return_sym_delay: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -155,10 +157,31 @@ class Transducer(nn.Module):
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
                 return_grad=True,
             )

+        sym_delay = None
+        if return_sym_delay:
+            B, S, T0 = px_grad.shape
+            T = T0 - 1
+            if boundary is None:
+                offset = torch.tensor(
+                    (T - 1) / 2,
+                    dtype=px_grad.dtype,
+                    device=px_grad.device,
+                ).expand(B, 1, 1)
+                total_syms = S * B
+            else:
+                offset = (boundary[:, 3] - 1) / 2
+                total_syms = torch.sum(boundary[:, 2])
+            offset = torch.arange(
+                T0, device=px_grad.device
+            ).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
+            sym_delay = px_grad * offset
+            sym_delay = torch.sum(sym_delay) / total_syms
+
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
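For context on the new block: `px_grad`, returned by `k2.rnnt_loss_smoothed` when `return_grad=True`, behaves like the occupation mass for emitting symbol s at frame t, so weighting it by each frame's offset from the sequence midpoint `(T - 1) / 2` and averaging over symbols gives an expected emission delay in frames. A minimal sketch of the same arithmetic on a toy tensor (the random `px_grad` below is a stand-in, not a real gradient):

    import torch

    # Toy stand-in for px_grad: [B, S, T0] emission mass for symbol s
    # at frame t, with T0 = T + 1; each (b, s) row sums to 1.
    B, S, T = 2, 3, 8
    T0 = T + 1
    px_grad = torch.softmax(torch.randn(B, S, T0), dim=-1)

    # Offset of each frame from the midpoint (T - 1) / 2, as in the
    # boundary-is-None branch above.
    offset = torch.arange(T0).reshape(1, 1, T0) - (T - 1) / 2

    # Mean delay in frames over all B * S symbols; close to 0 when the
    # emission mass is spread uniformly, positive when skewed late.
    sym_delay = torch.sum(px_grad * offset) / (B * S)
    print(float(sym_delay))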
@@ -188,7 +211,8 @@ class Transducer(nn.Module):
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
             )

-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, sym_delay)
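The loss-side effect comes from forwarding `delay_penalty` into `k2.rnnt_loss_smoothed` and `k2.rnnt_loss_pruned`. As a sketch of the idea only (not k2's exact internals), a delay penalty can be viewed as biasing the emission log-scores so that each frame of lateness past the midpoint costs `delay_penalty`:

    import torch

    def penalize_late_emission(px: torch.Tensor, delay_penalty: float) -> torch.Tensor:
        # px: [B, S, T0] log-scores for emitting symbol s at frame t.
        # Hypothetical helper for illustration only; k2 applies its
        # penalty internally when delay_penalty is passed to the loss.
        _, _, T0 = px.shape
        T = T0 - 1
        offset = torch.arange(T0, dtype=px.dtype, device=px.device) - (T - 1) / 2
        return px - delay_penalty * offset.reshape(1, 1, T0)

With the default delay_penalty=0.0 the scores are untouched, which is why the new arguments stay backward compatible for non-streaming runs.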
@@ -322,6 +322,25 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay. It may be
+        needed when training with time masking, to prevent the time masking
+        from encouraging the network to delay symbols.
+        """,
+    )
+
+    parser.add_argument(
+        "--return-sym-delay",
+        type=str2bool,
+        default=False,
+        help="""Whether to return `sym_delay` during training. This is a
+        stat that measures symbol emission delay, mainly useful for
+        time-masking training.
+        """,
+    )
+
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
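For completeness, a self-contained sketch of how the two new flags parse; `str2bool` below is a stand-in with assumed behavior for icefall's helper of the same name:

    import argparse

    def str2bool(v: str) -> bool:
        # Assumed behavior of icefall's str2bool helper.
        if v.lower() in ("true", "t", "yes", "y", "1"):
            return True
        if v.lower() in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"invalid bool value: {v!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--delay-penalty", type=float, default=0.0)
    parser.add_argument("--return-sym-delay", type=str2bool, default=False)

    args = parser.parse_args(["--delay-penalty", "0.1", "--return-sym-delay", "True"])
    print(args.delay_penalty, args.return_sym_delay)  # 0.1 True

In an actual run one might add e.g. `--delay-penalty 0.1 --return-sym-delay True` to the usual pruned_transducer_stateless4/train.py invocation (illustrative values only).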
@@ -625,7 +644,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, sym_delay = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -633,6 +652,8 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            delay_penalty=params.delay_penalty,
+            return_sym_delay=params.return_sym_delay,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
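Note the calling-convention change in `compute_loss`: the model now always returns a 3-tuple, with `sym_delay` set to `None` unless `--return-sym-delay` is given, so the unpacking works unconditionally. A toy illustration (names are hypothetical):

    # Hypothetical stand-in mimicking the new return shape of
    # Transducer.forward: sym_delay is None unless requested.
    def fake_model(return_sym_delay: bool = False):
        sym_delay = 1.7 if return_sym_delay else None
        return 10.0, 5.0, sym_delay  # simple_loss, pruned_loss, sym_delay

    simple_loss, pruned_loss, sym_delay = fake_model()
    assert sym_delay is None  # safe to ignore when not requested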
@@ -662,6 +683,9 @@ def compute_loss(
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()

+    if params.return_sym_delay:
+        info["sym_delay"] = sym_delay.detach().cpu().item()
+
     return loss, info

@@ -905,6 +929,10 @@ def run(rank, world_size, args):
         assert (
             params.causal_convolution
         ), "dynamic_chunk_training requires causal convolution"
+    else:
+        assert (
+            params.delay_penalty == 0.0
+        ), "delay_penalty is intended for dynamic_chunk_training"

     logging.info(params)

@@ -546,8 +546,11 @@ class MetricsTracker(collections.defaultdict):
         ans = []
         for k, v in self.items():
             if k != "frames":
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
+                if k != "sym_delay":
+                    norm_value = float(v) / num_frames
+                    ans.append((k, norm_value))
+                else:
+                    ans.append((k, float(v)))
         return ans

     def reduce(self, device):
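Rationale for the `MetricsTracker` change: every tracked loss is a sum over frames and is normalized by the frame count before logging, but `sym_delay` is already averaged over symbols inside the model, so dividing it again would distort the stat. A toy run of the updated logic (values are made up for illustration):

    import collections

    stats = collections.defaultdict(float)
    stats["frames"] = 2000
    stats["pruned_loss"] = 500.0  # summed over frames
    stats["sym_delay"] = 1.7      # already an average per symbol

    num_frames = stats["frames"]
    ans = []
    for k, v in stats.items():
        if k != "frames":
            if k != "sym_delay":
                # Losses: normalize by the number of frames.
                ans.append((k, float(v) / num_frames))
            else:
                # sym_delay: report as-is.
                ans.append((k, float(v)))
    print(ans)  # [('pruned_loss', 0.25), ('sym_delay', 1.7)]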