Fix typo and reformatted zipformer

This commit is contained in:
Yifan Yang 2023-06-02 12:41:34 +08:00
parent 208839fb9b
commit a98f6b27a4
16 changed files with 524 additions and 207 deletions

View File

@ -273,7 +273,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -370,7 +371,9 @@ def decode_one_batch(
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
@ -430,7 +433,10 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
@ -561,7 +567,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -594,7 +602,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -655,7 +664,9 @@ def main():
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -687,9 +698,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -716,9 +727,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -777,7 +788,9 @@ def main():
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None

View File

@ -90,11 +90,13 @@ class DecodeStream(object):
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
decoding_graph
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
@property
def done(self) -> bool:
@ -124,10 +126,13 @@ class DecodeStream(object):
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length
ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
ret_length = min(
self.num_frames - self.num_processed_frames, chunk_length
)
ret_features = self.features[
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
self.num_processed_frames : self.num_processed_frames
+ ret_length # noqa
]
self.num_processed_frames += chunk_size

View File

@ -118,7 +118,9 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output

View File

@ -276,7 +276,9 @@ class EncoderModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return encoder_out, encoder_out_lens
@ -288,7 +290,9 @@ class StreamingEncoderModel(nn.Module):
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
assert (
len(encoder.left_context_frames) == 1
), encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]
@ -315,7 +319,11 @@ class StreamingEncoderModel(nn.Module):
left_context_len = self.left_context_len
cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
(
x,
x_lens,
new_cached_embed_left_pad,
) = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,
@ -335,7 +343,9 @@ class StreamingEncoderModel(nn.Module):
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
@ -377,7 +387,9 @@ class StreamingEncoderModel(nn.Module):
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
processed_lens = torch.zeros(
batch_size, dtype=torch.int32, device=device
)
states.append(processed_lens)
return states
@ -411,9 +423,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -438,9 +450,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -494,10 +506,14 @@ def main():
# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
model.encoder = StreamingEncoderModel(
model.encoder, model.encoder_embed
)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
filename = (
f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
)
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"
@ -516,7 +532,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -266,7 +266,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -126,7 +126,9 @@ def greedy_search(
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
decoder_input = torch.tensor(
hyp, dtype=torch.int32, device=device
).unsqueeze(0)
# decoder_input.shape (1,, 1 context_size)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
else:
@ -146,7 +148,9 @@ def greedy_search(
decoder_input = torch.tensor(
decoder_input, dtype=torch.int32, device=device
).unsqueeze(0)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(
1
)
return hyp, decoder_out
@ -247,7 +251,12 @@ def main():
num_processed_frames += chunk_length
hyp, decoder_out = greedy_search(
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
decoder,
joiner,
encoder_out.squeeze(0),
decoder_out,
hyp,
device=device,
)
context_size = 2
@ -263,7 +272,9 @@ torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -29,8 +29,12 @@ class Joiner(nn.Module):
):
super().__init__()
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.encoder_proj = ScaledLinear(
encoder_dim, joiner_dim, initial_scale=0.25
)
self.decoder_proj = ScaledLinear(
decoder_dim, joiner_dim, initial_scale=0.25
)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
def forward(
@ -58,7 +62,9 @@ class Joiner(nn.Module):
)
if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out

View File

@ -90,7 +90,9 @@ class BatchedOptimizer(Optimizer):
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i]
)
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
@ -108,7 +110,10 @@ class BatchedOptimizer(Optimizer):
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack(
[torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
[
torch.zeros_like(p) if p.grad is None else p.grad
for p in batch
]
)
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
@ -325,8 +330,12 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
for group, group_params_names in zip(
self.param_groups, self.parameters_names
):
with self.batched_params(
group["params"], group_params_names
) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
@ -378,7 +387,9 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
batch_size = p.shape[0]
numel = p.numel() // batch_size
@ -387,7 +398,9 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
)
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
@ -396,7 +409,9 @@ class ScaledAdam(BatchedOptimizer):
)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
def _get_clipping_scale(
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
@ -432,7 +447,9 @@ class ScaledAdam(BatchedOptimizer):
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
tot_sumsq += (
grad**2
).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
@ -451,7 +468,8 @@ class ScaledAdam(BatchedOptimizer):
quartiles = []
for n in range(0, 5):
index = min(
clipping_update_period - 1, (clipping_update_period // 4) * n
clipping_update_period - 1,
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())
@ -536,7 +554,9 @@ class ScaledAdam(BatchedOptimizer):
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
@ -589,7 +609,9 @@ class ScaledAdam(BatchedOptimizer):
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_(
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
(p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt()
)
if step > 0:
# self._size_update() learns the overall scale on the
@ -636,9 +658,13 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state[
"scale_exp_avg_sq"
] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
(scale_grads**2).mean(
dim=0
), # mean over dim `size_update_period`
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
@ -651,7 +677,10 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
-size_lr
* (bias_correction2**0.5)
* scale_grads.sum(dim=0)
/ denom
)
is_too_small = param_rms < param_min_rms
@ -663,7 +692,9 @@ class ScaledAdam(BatchedOptimizer):
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
scale_step = torch.minimum(
scale_step, (param_max_rms - param_rms) / param_rms
)
delta = state["delta"]
# the factor of (1-beta1) relates to momentum.
@ -692,7 +723,9 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
this_step = state["step"] - (
state["zero_step"] if "zero_step" in state else 0
)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
@ -742,7 +775,9 @@ class LRScheduler(object):
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
self.optimizer = optimizer
self.verbose = verbose
@ -869,7 +904,8 @@ class Eden(LRScheduler):
factor = (
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
) ** -0.25 * (
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2)
** -0.25
)
warmup_factor = (
1.0
@ -958,11 +994,17 @@ class Eve(Optimizer):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(
@ -998,7 +1040,9 @@ class Eve(Optimizer):
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
raise RuntimeError(
"AdamW does not support sparse gradients"
)
state = self.state[p]
@ -1036,7 +1080,9 @@ class Eve(Optimizer):
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))
p.addcdiv_(exp_avg, denom, value=-step_size)
@ -1087,7 +1133,8 @@ def _test_scaled_adam(hidden_dim: int):
100.0
* torch.randn(B, T, E, device=device, dtype=dtype)
* input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype)
* output_magnitudes,
)
for _ in range(20)
]

View File

@ -314,7 +314,9 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward
@ -323,7 +325,9 @@ def main():
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
hyps = []
@ -374,7 +378,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -100,13 +100,17 @@ class Model(nn.Module):
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj
def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self, feature: Tensor, feature_lens: Tensor
) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out)
@ -164,7 +168,9 @@ def main():
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -64,7 +64,9 @@ class PiecewiseLinear(object):
for i in range(1, len(self.pairs)):
next_x, next_y = self.pairs[i]
if x >= cur_x and x <= next_x:
return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
return cur_y + (next_y - cur_y) * (x - cur_x) / (
next_x - cur_x
)
cur_x, cur_y = next_x, next_y
assert False
@ -98,7 +100,9 @@ class PiecewiseLinear(object):
def __eq__(self, other):
return self.pairs == other.pairs
def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
def get_common_basis(
self, p: "PiecewiseLinear", include_crossings: bool = False
):
"""
Returns (self_mod, p_mod) which are equivalent piecewise lienar
functions to self and p, but with the same x values.
@ -110,14 +114,18 @@ class PiecewiseLinear(object):
assert isinstance(p, PiecewiseLinear), type(p)
# get sorted x-values without repetition.
x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
x_vals = sorted(
set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])
)
y_vals1 = [self(x) for x in x_vals]
y_vals2 = [p(x) for x in x_vals]
if include_crossings:
extra_x_vals = []
for i in range(len(x_vals) - 1):
if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
if (y_vals1[i] > y_vals2[i]) != (
y_vals1[i + 1] > y_vals2[i + 1]
):
# if the two lines in this subsegment potentially cross each other..
diff_cur = abs(y_vals1[i] - y_vals2[i])
diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])
@ -163,9 +171,7 @@ class ScheduledFloat(torch.nn.Module):
self.schedule = PiecewiseLinear(*args)
def extra_repr(self) -> str:
return (
f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
)
return f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
def __float__(self):
batch_count = self.batch_count
@ -192,7 +198,8 @@ class ScheduledFloat(torch.nn.Module):
return ScheduledFloat(self.schedule.max(x), default=self.default)
else:
return ScheduledFloat(
self.schedule.max(x.schedule), default=max(self.default, x.default)
self.schedule.max(x.schedule),
default=max(self.default, x.default),
)
@ -374,7 +381,8 @@ class BiasNormFunction(torch.autograd.Function):
with torch.enable_grad():
# recompute scales from x, bias and log_scale.
scales = (
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True)
** -0.5
) * log_scale.exp()
ans = x * scales
ans.backward(gradient=ans_grad)
@ -443,7 +451,8 @@ class BiasNorm(torch.nn.Module):
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)
** -0.5
) * self.log_scale.exp()
return x * scales
@ -455,7 +464,11 @@ class BiasNorm(torch.nn.Module):
)
return BiasNormFunction.apply(
x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
x,
self.bias,
log_scale,
self.channel_dim,
self.store_output_for_backprop,
)
@ -478,7 +491,9 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans
@ -501,7 +516,9 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans
@ -525,7 +542,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans
@ -587,7 +606,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
# first row is correction factors added to the scale near the left edge of the chunk,
# second row is correction factors added to the scale near the right edge of the chunk,
# both of these are added to a default scale of 1.0.
self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
self.chunkwise_conv_scale = nn.Parameter(
torch.zeros(2, channels, kernel_size)
)
self.kernel_size = kernel_size
with torch.no_grad():
@ -595,7 +616,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
self.chunkwise_conv.weight[:] *= initial_scale
if bias:
torch.nn.init.uniform_(
self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
self.causal_conv.bias,
-0.1 * initial_scale,
0.1 * initial_scale,
)
def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
@ -623,7 +646,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
x_chunk = x[..., left_pad:]
num_chunks = x_chunk.shape[2] // chunk_size
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
x_chunk = x_chunk.reshape(
batch_size, num_channels, num_chunks, chunk_size
)
x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
batch_size * num_chunks, num_channels, chunk_size
)
@ -635,9 +660,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
x_chunk = x_chunk.reshape(
batch_size, num_chunks, num_channels, chunk_size
).permute(0, 2, 1, 3)
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
..., :seq_len
]
x_chunk = x_chunk.reshape(
batch_size, num_channels, num_chunks * chunk_size
)[..., :seq_len]
return x_chunk + x_causal
@ -711,13 +736,29 @@ class BalancerFunction(torch.autograd.Function):
channel_dim += x.ndim
ctx.channel_dim = channel_dim
ctx.save_for_backward(x)
ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
ctx.config = (
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
channel_dim,
)
return x
@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None, None, None]:
(x,) = ctx.saved_tensors
(min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
(
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
channel_dim,
) = ctx.config
try:
with torch.enable_grad():
@ -728,7 +769,11 @@ class BalancerFunction(torch.autograd.Function):
mean_dims = [i for i in range(x.ndim) if i != channel_dim]
uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
mean = x.mean(dim=mean_dims, keepdim=True)
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
stddev = (
(uncentered_var - (mean * mean))
.clamp(min=1.0e-20)
.sqrt()
)
rms = uncentered_var.clamp(min=1.0e-20).sqrt()
m = mean / stddev
@ -877,7 +922,13 @@ class Balancer(torch.nn.Module):
assert x.shape[self.channel_dim] == self.num_channels
return BalancerFunction.apply(
x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
x,
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
self.channel_dim,
)
else:
return _no_op(x)
@ -956,7 +1007,9 @@ def _whitening_metric(x: Tensor, num_groups: int):
# the following expression is what we'd get if we took the matrix product
# of each covariance and measured the mean of its trace, i.e.
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
x_covarsq_mean_diag = (x_covar**2).sum() / (
num_groups * channels_per_group
)
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
return metric
@ -1067,7 +1120,11 @@ class Whiten(nn.Module):
and nothing will happen in backprop.
"""
grad_scale = float(self.grad_scale)
if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
if (
not x.requires_grad
or random.random() > self.prob
or grad_scale == 0
):
return _no_op(x)
else:
return WhiteningPenaltyFunction.apply(x, self)
@ -1086,7 +1143,9 @@ class WithLoss(torch.autograd.Function):
def backward(ctx, ans_grad: Tensor):
return (
ans_grad,
torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
torch.ones(
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
),
None,
)
@ -1141,7 +1200,9 @@ class LimitParamValue(torch.autograd.Function):
)
# where x > ctx.max, ensure all grads are positive (this will tend to make
# x more negative).
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
x_grad *= torch.where(
torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0
)
return x_grad, None, None
@ -1213,9 +1274,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.044
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
d_scaled = (deriv - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -1257,7 +1318,9 @@ class Dropout2(nn.Module):
self.p = p
def forward(self, x: Tensor) -> Tensor:
return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
return torch.nn.functional.dropout(
x, p=float(self.p), training=self.training
)
class MulForDropout3(torch.autograd.Function):
@ -1330,9 +1393,9 @@ class SwooshLFunction(torch.autograd.Function):
floor = coeff
ceil = 1.0 + coeff + 0.005
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
grad
)
d_scaled = (grad - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(grad)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -1399,9 +1462,9 @@ class SwooshRFunction(torch.autograd.Function):
floor = -0.08
ceil = 0.925
d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
grad
)
d_scaled = (grad - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(grad)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@ -1472,7 +1535,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
dropout_shape[dropout_shared_dim] = 1
# else it won't be very memory efficient.
dropout_mask = (1.0 / (1.0 - dropout_p)) * (
torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
torch.rand(*dropout_shape, device=x.device, dtype=x.dtype)
> dropout_p
)
else:
dropout_mask = None
@ -1641,7 +1705,9 @@ def _test_whiten():
def _test_balancer_sign():
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
x = 1.0 * (
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
)
x = x.detach()
x.requires_grad = True
m = Balancer(
@ -1665,7 +1731,9 @@ def _test_balancer_sign():
def _test_balancer_magnitude():
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
-1
)
x = x.detach()
x.requires_grad = True
m = Balancer(
@ -1825,9 +1893,13 @@ def _test_activation_dropout_and_linear():
print("y1 = ", y1)
print("y2 = ", y2)
assert torch.allclose(y1, y2, atol=0.02)
assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
assert torch.allclose(
m1[2].weight.grad, m2.weight.grad, atol=1.0e-05
)
if bias:
assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
assert torch.allclose(
m1[2].bias.grad, m2.bias.grad, atol=1.0e-05
)
print("x1.grad = ", x1.grad)
print("x2.grad = ", x2.grad)

View File

@ -153,7 +153,9 @@ def modified_beam_search(
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)
logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1)
@ -170,10 +172,14 @@ def modified_beam_search(
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)
with warnings.catch_warnings():
warnings.simplefilter("ignore")

View File

@ -282,7 +282,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens)
return batch_states
@ -320,7 +322,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
@ -351,7 +355,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
@ -398,7 +404,9 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
@ -486,7 +494,9 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
@ -507,7 +517,9 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states)
@ -565,7 +577,9 @@ def decode_dataset(
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
@ -635,7 +649,9 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
@ -668,7 +684,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -701,7 +718,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
@ -741,9 +760,9 @@ def main():
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@ -770,9 +789,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

View File

@ -107,7 +107,9 @@ class ConvNeXt(nn.Module):
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate
)
else:
@ -273,7 +275,9 @@ class Conv2dSubsampling(nn.Module):
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25),
grad_scale=0.02,
)
@ -396,8 +400,8 @@ class Conv2dSubsampling(nn.Module):
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
cached_embed_left_pad = torch.zeros(
batch_size, channels, left_pad, freq
).to(device)
return cached_embed_left_pad

View File

@ -95,7 +95,9 @@ from icefall.utils import (
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_adjusted_batch_count(params: AttributeDict) -> float:
@ -342,7 +344,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -365,7 +368,8 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
@ -738,7 +742,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -785,7 +793,9 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -967,7 +977,9 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and batch_idx % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@ -989,7 +1001,11 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
)
if tb_writer is not None:
@ -1000,13 +1016,20 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -1096,7 +1119,9 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True
),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)

View File

@ -125,7 +125,9 @@ class Zipformer2(EncoderInterface):
if len(x) == 1:
x = x * len(downsampling_factor)
else:
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
assert len(x) == len(downsampling_factor) and isinstance(
x[0], int
)
return x
self.output_downsampling_factor = output_downsampling_factor # int
@ -140,7 +142,9 @@ class Zipformer2(EncoderInterface):
pos_head_dim = _to_tuple(pos_head_dim)
self.num_heads = num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
cnn_module_kernel
)
self.causal = causal
self.chunk_size = chunk_size
@ -192,7 +196,9 @@ class Zipformer2(EncoderInterface):
self.encoders = nn.ModuleList(encoders)
self.downsample_output = SimpleDownsample(
max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout,
)
def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
@ -223,7 +229,8 @@ class Zipformer2(EncoderInterface):
# mask1 shape: (1, batch_size, 1)
mask1 = (
torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype)
# mask2 has additional sequences masked, about twice the number.
@ -271,7 +278,9 @@ class Zipformer2(EncoderInterface):
left_context_chunks = -1
else:
if torch.jit.is_scripting():
assert len(self.left_context_frames) == 1, self.left_context_frames
assert (
len(self.left_context_frames) == 1
), self.left_context_frames
left_context_frames = self.left_context_frames[0]
else:
left_context_frames = random.choice(self.left_context_frames)
@ -370,7 +379,8 @@ class Zipformer2(EncoderInterface):
num_encoders = len(self.encoder_dim)
assert all(
chunk_size * left_context_chunks
>= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
>= (self.cnn_module_kernel[i] // 2)
* self.downsampling_factor[i]
for i in range(num_encoders)
)
else:
@ -390,7 +400,9 @@ class Zipformer2(EncoderInterface):
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
attn_mask = torch.logical_or(
src_c > tgt_c, src_c < tgt_c - left_context_chunks
)
if __name__ == "__main__":
logging.info(f"attn_mask = {attn_mask}")
return attn_mask
@ -448,7 +460,9 @@ class Zipformer2(EncoderInterface):
x, new_layer_states = module.streaming_forward(
x,
states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
states=states[
layer_offset * 6 : (layer_offset + num_layers) * 6
],
left_context_len=self.left_context_frames[0] // ds,
src_key_padding_mask=src_key_padding_mask[..., ::ds],
)
@ -496,24 +510,24 @@ class Zipformer2(EncoderInterface):
nonlin_attn_head_dim = 3 * embed_dim // 4
conv_left_pad = self.cnn_module_kernel[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
device
)
cached_key = torch.zeros(
downsample_left, batch_size, key_dim
).to(device)
cached_nonlin_attn = torch.zeros(
1, batch_size, downsample_left, nonlin_attn_head_dim
).to(device)
cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
device
)
cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
device
)
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
device
)
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
device
)
cached_val1 = torch.zeros(
downsample_left, batch_size, value_dim
).to(device)
cached_val2 = torch.zeros(
downsample_left, batch_size, value_dim
).to(device)
cached_conv1 = torch.zeros(
batch_size, embed_dim, conv_left_pad
).to(device)
cached_conv2 = torch.zeros(
batch_size, embed_dim, conv_left_pad
).to(device)
states += [
cached_key,
cached_nonlin_attn,
@ -621,7 +635,9 @@ class Zipformer2EncoderLayer(nn.Module):
embed_dim, (feedforward_dim * 3) // 4, dropout
)
self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
self.feed_forward2 = FeedforwardModule(
embed_dim, feedforward_dim, dropout
)
self.feed_forward3 = FeedforwardModule(
embed_dim, (feedforward_dim * 5) // 4, dropout
@ -708,7 +724,9 @@ class Zipformer2EncoderLayer(nn.Module):
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(
x.dtype
)
return mask
def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:
@ -774,7 +792,9 @@ class Zipformer2EncoderLayer(nn.Module):
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting():
pass
elif not self.training and random.random() < float(self.const_attention_rate):
elif not self.training and random.random() < float(
self.const_attention_rate
):
# Make attention weights constant. The intention is to
# encourage these modules to do something similar to an
# averaging-over-time operation.
@ -790,7 +810,9 @@ class Zipformer2EncoderLayer(nn.Module):
na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
src = src + (
na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
na
if self_attn_dropout_mask is None
else na * self_attn_dropout_mask
)
self_attn = self.self_attn1(src, attn_weights)
@ -804,10 +826,14 @@ class Zipformer2EncoderLayer(nn.Module):
if torch.jit.is_scripting():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
conv_skip_rate = (
float(self.conv_skip_rate) if self.training else 0.0
)
src = src + self.sequence_dropout(
self.conv_module1(
src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
src,
chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask,
),
conv_skip_rate,
)
@ -834,10 +860,14 @@ class Zipformer2EncoderLayer(nn.Module):
if torch.jit.is_scripting():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
conv_skip_rate = (
float(self.conv_skip_rate) if self.training else 0.0
)
src = src + self.sequence_dropout(
self.conv_module2(
src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
src,
chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask,
),
conv_skip_rate,
)
@ -1150,7 +1180,9 @@ class BypassModule(nn.Module):
embed_dim: int,
skip_rate: FloatLike = 0.0,
straight_through_rate: FloatLike = 0.0,
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_min: FloatLike = ScheduledFloat(
(0.0, 0.9), (20000.0, 0.2), default=0
),
scale_max: FloatLike = 1.0,
):
super().__init__()
@ -1169,11 +1201,15 @@ class BypassModule(nn.Module):
return self.bypass_scale
else:
ans = limit_param_value(
self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
self.bypass_scale,
min=float(self.scale_min),
max=float(self.scale_max),
)
skip_rate = float(self.skip_rate)
if skip_rate != 0.0:
mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
mask = (
torch.rand((batch_size, 1), device=ans.device) > skip_rate
)
ans = ans * mask
# now ans is of shape (batch_size, num_channels), and is zero for sequences
# on which we have randomly chosen to do layer-skipping.
@ -1320,7 +1356,9 @@ class SimpleDownsample(torch.nn.Module):
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds
@ -1354,7 +1392,9 @@ class SimpleUpsample(torch.nn.Module):
"""
upsample = self.upsample
(seq_len, batch_size, num_channels) = src.shape
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
src = src.unsqueeze(1).expand(
seq_len, upsample, batch_size, num_channels
)
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src
@ -1411,12 +1451,18 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
x = (
torch.arange(-(T - 1), T, device=x.device)
.to(torch.float32)
.unsqueeze(1)
)
freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)
@ -1430,7 +1476,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
x_compressed = (
compression_length
* x.sign()
* ((x.abs() + compression_length).log() - math.log(compression_length))
* (
(x.abs() + compression_length).log()
- math.log(compression_length)
)
)
# if self.length_factor == 1.0, then length_scale is chosen so that the
@ -1442,7 +1491,9 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# note for machine implementations: if atan is not available, we can use:
# x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
# check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
x_atan = (
x_compressed / length_scale
).atan() # results between -pi and pi
cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin()
@ -1507,7 +1558,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
query_head_dim: int,
pos_head_dim: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
pos_emb_skip_rate: FloatLike = ScheduledFloat(
(0.0, 0.5), (4000.0, 0.0)
),
) -> None:
super().__init__()
self.embed_dim = embed_dim
@ -1516,7 +1569,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.pos_head_dim = pos_head_dim
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None # will be overwritten in training code; for diagnostics.
self.name = (
None # will be overwritten in training code; for diagnostics.
)
key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@ -1527,7 +1582,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# to be used with the ScaledAdam optimizer; with most other optimizers,
# it would be necessary to apply the scaling factor in the forward function.
self.in_proj = ScaledLinear(
embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
embed_dim,
in_proj_dim,
bias=True,
initial_scale=query_head_dim**-0.25,
)
self.whiten_keys = Whiten(
@ -1601,7 +1659,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
assert p.shape[-1] == num_heads * pos_head_dim
q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
k = self.whiten_keys(
self.balance_keys(k)
) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
@ -1619,15 +1679,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if torch.jit.is_scripting():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
elif not self.training or random.random() >= float(
self.pos_emb_skip_rate
):
use_pos_scores = True
if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
pos_emb = pos_emb.reshape(
-1, seq_len2, num_heads, pos_head_dim
).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
@ -1769,9 +1831,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1 + left_context_len
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
pos_emb = pos_emb.reshape(
-1, seq_len2, num_heads, pos_head_dim
).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
@ -1801,7 +1863,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
), attn_scores.shape
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
assert key_padding_mask.shape == (
batch_size,
k_len,
), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
@ -1846,7 +1911,9 @@ class SelfAttention(nn.Module):
value_head_dim: int,
) -> None:
super().__init__()
self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
self.in_proj = nn.Linear(
embed_dim, num_heads * value_head_dim, bias=True
)
self.out_proj = ScaledLinear(
num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
@ -1958,7 +2025,9 @@ class SelfAttention(nn.Module):
class FeedforwardModule(nn.Module):
"""Feedforward module in Zipformer2 model."""
def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
def __init__(
self, embed_dim: int, feedforward_dim: int, dropout: FloatLike
):
super(FeedforwardModule, self).__init__()
self.in_proj = nn.Linear(embed_dim, feedforward_dim)
@ -2226,7 +2295,9 @@ class ConvolutionModule(nn.Module):
assert kernel_size % 2 == 1
self.depthwise_conv = (
ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
ChunkCausalDepthwiseConv1d(
channels=bottleneck_dim, kernel_size=kernel_size
)
if causal
else nn.Conv1d(
in_channels=bottleneck_dim,
@ -2294,7 +2365,9 @@ class ConvolutionModule(nn.Module):
x = x.permute(1, 2, 0) # (#batch, channels, time).
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = x.masked_fill(
src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0
)
if not torch.jit.is_scripting() and chunk_size >= 0:
# Not support exporting a model for simulated streaming decoding
@ -2344,7 +2417,9 @@ class ConvolutionModule(nn.Module):
x = x.permute(1, 2, 0) # (#batch, channels, time).
if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = x.masked_fill(
src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0
)
x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)