training pruned_transducer_stateless4 with delay-penalty

pkufool 2022-07-26 13:31:38 +08:00
parent 116d0cf26d
commit 718086460e
3 changed files with 59 additions and 4 deletions


@@ -78,6 +78,8 @@ class Transducer(nn.Module):
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
         warmup: float = 1.0,
+        delay_penalty: float = 0.0,
+        return_sym_delay: bool = False,
     ) -> torch.Tensor:
         """
         Args:
@@ -155,10 +157,31 @@
                 lm_only_scale=lm_scale,
                 am_only_scale=am_scale,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
                 return_grad=True,
             )
 
+        sym_delay = None
+        if return_sym_delay:
+            B, S, T0 = px_grad.shape
+            T = T0 - 1
+            if boundary is None:
+                offset = torch.tensor(
+                    (T - 1) / 2,
+                    dtype=px_grad.dtype,
+                    device=px_grad.device,
+                ).expand(B, 1, 1)
+                total_syms = S * B
+            else:
+                offset = (boundary[:, 3] - 1) / 2
+                total_syms = torch.sum(boundary[:, 2])
+            offset = torch.arange(
+                T0, device=px_grad.device
+            ).reshape(1, 1, T0) - offset.reshape(B, 1, 1)
+            sym_delay = px_grad * offset
+            sym_delay = torch.sum(sym_delay) / total_syms
+
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
             px_grad=px_grad,
@@ -188,7 +211,8 @@
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction="sum",
             )
 
-        return (simple_loss, pruned_loss)
+        return (simple_loss, pruned_loss, sym_delay)
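The sym_delay block above is easiest to read by treating px_grad (returned by k2's RNN-T losses when return_grad=True, and already used below for pruning) as per-(symbol, frame) occupation probabilities: sym_delay is then the expected frame at which a symbol is emitted, measured relative to the utterance midpoint (T - 1) / 2 and averaged over all symbols in the batch. A toy sketch of the same arithmetic, using made-up tensors rather than real k2 output:

import torch

B, S, T = 1, 2, 5       # batch size, symbols per sequence, frames
T0 = T + 1              # px_grad has one extra frame column, as above
px_grad = torch.zeros(B, S, T0)
px_grad[0, 0, 4] = 1.0  # symbol 0 emitted (with certainty) at frame 4
px_grad[0, 1, 4] = 1.0  # symbol 1 also emitted at frame 4, i.e. late

# Frame index relative to the utterance midpoint (T - 1) / 2.
offset = torch.arange(T0).reshape(1, 1, T0) - (T - 1) / 2
sym_delay = torch.sum(px_grad * offset) / (B * S)
print(sym_delay)  # tensor(2.): emitted 2 frames after the midpoint on average

A positive value means symbols come out later than the utterance midpoint on average. Note that the statistic is already normalized per symbol, which is why the MetricsTracker change at the bottom of this commit skips the per-frame normalization for it.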


@@ -322,6 +322,25 @@ def get_parser():
         """,
     )
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay. It may be
+        needed when training with time masking, to prevent the time masking
+        from encouraging the network to delay the emission of symbols.
+        """,
+    )
+    parser.add_argument(
+        "--return-sym-delay",
+        type=str2bool,
+        default=False,
+        help="""Whether to return `sym_delay` during training. This is a
+        statistic that measures symbol emission delay, and is mainly useful
+        when training with time masking.
+        """,
+    )
     parser.add_argument(
         "--use-fp16",
         type=str2bool,
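Note that --delay-penalty is simply forwarded to k2.rnnt_loss_smoothed and k2.rnnt_loss_pruned (see the model changes above). As a rough intuition, and this is a toy sketch rather than k2's actual implementation, such a penalty tilts the emission scores px by a term that decreases with the frame index, so emitting a symbol early becomes cheaper than emitting it late:

import torch

def apply_delay_penalty(px: torch.Tensor, delay_penalty: float) -> torch.Tensor:
    """px: [B, S, T + 1] emission log-scores; returns tilted scores."""
    T0 = px.shape[-1]
    T = T0 - 1
    # Positive bonus before the utterance midpoint, negative after it.
    bonus = (T - 1) / 2 - torch.arange(T0, device=px.device)
    return px + delay_penalty * bonus.reshape(1, 1, T0)

px = torch.zeros(2, 3, 6)  # toy scores: B=2, S=3, T=5
print(apply_delay_penalty(px, 0.1)[0, 0])  # bonus shrinks as frames advance

With delay_penalty == 0.0 the scores are untouched, which is why 0.0 is a safe default for non-streaming training (and why run() below asserts it stays 0.0 unless dynamic_chunk_training is enabled).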
@@ -625,7 +644,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
 
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, sym_delay = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -633,6 +652,8 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            delay_penalty=params.delay_penalty,
+            return_sym_delay=params.return_sym_delay,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -662,6 +683,9 @@
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+    if params.return_sym_delay:
+        info["sym_delay"] = sym_delay.detach().cpu().item()
 
     return loss, info
@@ -905,6 +929,10 @@ def run(rank, world_size, args):
         assert (
             params.causal_convolution
         ), "dynamic_chunk_training requires causal convolution"
+    else:
+        assert (
+            params.delay_penalty == 0.0
+        ), "delay_penalty is intended for dynamic_chunk_training"
 
     logging.info(params)


@@ -546,8 +546,11 @@ class MetricsTracker(collections.defaultdict):
         ans = []
         for k, v in self.items():
             if k != "frames":
-                norm_value = float(v) / num_frames
-                ans.append((k, norm_value))
+                if k != "sym_delay":
+                    norm_value = float(v) / num_frames
+                    ans.append((k, norm_value))
+                else:
+                    ans.append((k, float(v)))
         return ans
 
     def reduce(self, device):
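The extra branch exists because sym_delay enters the tracker already averaged over symbols (model.py divides it by total_syms), whereas the loss entries are sums that still need dividing by the frame count. A small illustration with made-up numbers:

# Made-up per-batch stats, mimicking what MetricsTracker holds.
stats = {"frames": 2000, "simple_loss": 900.0,
         "pruned_loss": 500.0, "sym_delay": 0.35}

# The same rule as norm_items above: normalize by the frame count,
# except for sym_delay, which is already a per-symbol average.
ans = [
    (k, float(v) / stats["frames"] if k != "sym_delay" else float(v))
    for k, v in stats.items()
    if k != "frames"
]
print(ans)  # [('simple_loss', 0.45), ('pruned_loss', 0.25), ('sym_delay', 0.35)]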