Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Fix typo and reformatted zipformer
parent 208839fb9b · commit a98f6b27a4
@@ -273,7 +273,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",

@@ -370,7 +371,9 @@ def decode_one_batch(
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)

hyps = []

@@ -430,7 +433,10 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,

@@ -561,7 +567,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"

logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results

@@ -594,7 +602,8 @@ def save_results(

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

@@ -655,7 +664,9 @@ def main():
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

@@ -687,9 +698,9 @@ def main():

if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

@@ -716,9 +727,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

@@ -777,7 +788,9 @@ def main():
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
decoding_graph = k2.trivial_graph(
params.vocab_size - 1, device=device
)
else:
decoding_graph = None
word_table = None
@@ -90,11 +90,13 @@ class DecodeStream(object):
)
elif params.decoding_method == "fast_beam_search":
# The rnnt_decoding_stream for fast_beam_search.
self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream(
decoding_graph
self.rnnt_decoding_stream: k2.RnntDecodingStream = (
k2.RnntDecodingStream(decoding_graph)
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)

@property
def done(self) -> bool:

@@ -124,10 +126,13 @@ class DecodeStream(object):
"""Consume chunk_size frames of features"""
chunk_length = chunk_size + self.pad_length

ret_length = min(self.num_frames - self.num_processed_frames, chunk_length)
ret_length = min(
self.num_frames - self.num_processed_frames, chunk_length
)

ret_features = self.features[
self.num_processed_frames : self.num_processed_frames + ret_length # noqa
self.num_processed_frames : self.num_processed_frames
+ ret_length # noqa
]

self.num_processed_frames += chunk_size

@@ -118,7 +118,9 @@ class Decoder(nn.Module):
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
@@ -276,7 +276,9 @@ class EncoderModel(nn.Module):
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)

return encoder_out, encoder_out_lens

@@ -288,7 +290,9 @@ class StreamingEncoderModel(nn.Module):
def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None:
super().__init__()
assert len(encoder.chunk_size) == 1, encoder.chunk_size
assert len(encoder.left_context_frames) == 1, encoder.left_context_frames
assert (
len(encoder.left_context_frames) == 1
), encoder.left_context_frames
self.chunk_size = encoder.chunk_size[0]
self.left_context_len = encoder.left_context_frames[0]

@@ -315,7 +319,11 @@ class StreamingEncoderModel(nn.Module):
left_context_len = self.left_context_len

cached_embed_left_pad = states[-2]
x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward(
(
x,
x_lens,
new_cached_embed_left_pad,
) = self.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lengths,
cached_left_pad=cached_embed_left_pad,

@@ -335,7 +343,9 @@ class StreamingEncoderModel(nn.Module):
new_processed_lens = processed_lens + x_lens

# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)

x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]

@@ -377,7 +387,9 @@ class StreamingEncoderModel(nn.Module):
embed_states = self.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)

processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
processed_lens = torch.zeros(
batch_size, dtype=torch.int32, device=device
)
states.append(processed_lens)

return states
@@ -411,9 +423,9 @@ def main():

if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

@@ -438,9 +450,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

@@ -494,10 +506,14 @@ def main():

# Wrap encoder and encoder_embed as a module
if params.causal:
model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed)
model.encoder = StreamingEncoderModel(
model.encoder, model.encoder_embed
)
chunk_size = model.encoder.chunk_size
left_context_len = model.encoder.left_context_len
filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
filename = (
f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt"
)
else:
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
filename = "jit_script.pt"

@@ -516,7 +532,9 @@ def main():

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)
main()

@@ -266,7 +266,9 @@ def main():

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)
main()
@@ -126,7 +126,9 @@ def greedy_search(
if decoder_out is None:
assert hyp is None, hyp
hyp = [blank_id] * context_size
decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0)
decoder_input = torch.tensor(
hyp, dtype=torch.int32, device=device
).unsqueeze(0)
# decoder_input.shape (1,, 1 context_size)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
else:

@@ -146,7 +148,9 @@ def greedy_search(
decoder_input = torch.tensor(
decoder_input, dtype=torch.int32, device=device
).unsqueeze(0)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1)
decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(
1
)

return hyp, decoder_out

@@ -247,7 +251,12 @@ def main():
num_processed_frames += chunk_length

hyp, decoder_out = greedy_search(
decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device
decoder,
joiner,
encoder_out.squeeze(0),
decoder_out,
hyp,
device=device,
)

context_size = 2

@@ -263,7 +272,9 @@ torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)
main()
@@ -29,8 +29,12 @@ class Joiner(nn.Module):
):
super().__init__()

self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.encoder_proj = ScaledLinear(
encoder_dim, joiner_dim, initial_scale=0.25
)
self.decoder_proj = ScaledLinear(
decoder_dim, joiner_dim, initial_scale=0.25
)
self.output_linear = nn.Linear(joiner_dim, vocab_size)

def forward(

@@ -58,7 +62,9 @@ class Joiner(nn.Module):
)

if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
logit = self.encoder_proj(encoder_out) + self.decoder_proj(
decoder_out
)
else:
logit = encoder_out + decoder_out
@@ -90,7 +90,9 @@ class BatchedOptimizer(Optimizer):
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i]
)
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]

stacked_params_dict = dict()

@@ -108,7 +110,10 @@ class BatchedOptimizer(Optimizer):
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack(
[torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
[
torch.zeros_like(p) if p.grad is None else p.grad
for p in batch
]
)
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked

@@ -325,8 +330,12 @@ class ScaledAdam(BatchedOptimizer):

batch = True

for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
for group, group_params_names in zip(
self.param_groups, self.parameters_names
):
with self.batched_params(
group["params"], group_params_names
) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.

@@ -378,7 +387,9 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

batch_size = p.shape[0]
numel = p.numel() // batch_size

@@ -387,7 +398,9 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
)
state["param_rms"] = param_rms

state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)

@@ -396,7 +409,9 @@ class ScaledAdam(BatchedOptimizer):
)

# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)

def _get_clipping_scale(
self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]

@@ -432,7 +447,9 @@ class ScaledAdam(BatchedOptimizer):
"ScaledAdam optimizer does not support sparse gradients"
)
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
tot_sumsq += (
grad**2
).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

@@ -451,7 +468,8 @@ class ScaledAdam(BatchedOptimizer):
quartiles = []
for n in range(0, 5):
index = min(
clipping_update_period - 1, (clipping_update_period // 4) * n
clipping_update_period - 1,
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())

@@ -536,7 +554,9 @@ class ScaledAdam(BatchedOptimizer):
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
@@ -589,7 +609,9 @@ class ScaledAdam(BatchedOptimizer):
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_(
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
(p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt()
)
if step > 0:
# self._size_update() learns the overall scale on the

@@ -636,9 +658,13 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level.
beta2_corr = beta2**size_update_period

scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state[
"scale_exp_avg_sq"
] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
(scale_grads**2).mean(
dim=0
), # mean over dim `size_update_period`
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)

@@ -651,7 +677,10 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps

scale_step = (
-size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
-size_lr
* (bias_correction2**0.5)
* scale_grads.sum(dim=0)
/ denom
)

is_too_small = param_rms < param_min_rms

@@ -663,7 +692,9 @@ class ScaledAdam(BatchedOptimizer):
# We have to look at the trained model for parameters at or around the
# param_max_rms, because sometimes they can indicate a problem with the
# topology or settings.
scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
scale_step = torch.minimum(
scale_step, (param_max_rms - param_rms) / param_rms
)

delta = state["delta"]
# the factor of (1-beta1) relates to momentum.

@@ -692,7 +723,9 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))

this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
this_step = state["step"] - (
state["zero_step"] if "zero_step" in state else 0
)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
@@ -742,7 +775,9 @@ class LRScheduler(object):
def __init__(self, optimizer: Optimizer, verbose: bool = False):
# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
raise TypeError(
"{} is not an Optimizer".format(type(optimizer).__name__)
)
self.optimizer = optimizer
self.verbose = verbose

@@ -869,7 +904,8 @@ class Eden(LRScheduler):
factor = (
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
) ** -0.25 * (
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2)
** -0.25
)
warmup_factor = (
1.0

@@ -958,11 +994,17 @@ class Eve(Optimizer):
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
raise ValueError(
"Invalid beta parameter at index 0: {}".format(betas[0])
)
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
raise ValueError(
"Invalid beta parameter at index 1: {}".format(betas[1])
)
if not 0 <= weight_decay <= 0.1:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
)
if not 0 < target_rms <= 10.0:
raise ValueError("Invalid target_rms value: {}".format(target_rms))
defaults = dict(

@@ -998,7 +1040,9 @@ class Eve(Optimizer):
# Perform optimization step
grad = p.grad
if grad.is_sparse:
raise RuntimeError("AdamW does not support sparse gradients")
raise RuntimeError(
"AdamW does not support sparse gradients"
)

state = self.state[p]

@@ -1036,7 +1080,9 @@ class Eve(Optimizer):
if p.numel() > 1:
# avoid applying this weight-decay on "scaling factors"
# (which are scalar).
is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
is_above_target_rms = p.norm() > (
target_rms * (p.numel() ** 0.5)
)
p.mul_(1 - (weight_decay * is_above_target_rms))

p.addcdiv_(exp_avg, denom, value=-step_size)

@@ -1087,7 +1133,8 @@ def _test_scaled_adam(hidden_dim: int):
100.0
* torch.randn(B, T, E, device=device, dtype=dtype)
* input_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
torch.randn(B, T, E, device=device, dtype=dtype)
* output_magnitudes,
)
for _ in range(20)
]
@@ -314,7 +314,9 @@ def main():
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]

features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)

# model forward

@@ -323,7 +325,9 @@ def main():
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = model.encoder(
x, x_lens, src_key_padding_mask
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)

hyps = []

@@ -374,7 +378,9 @@ def main():

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)

logging.basicConfig(format=formatter, level=logging.INFO)
main()

@@ -100,13 +100,17 @@ class Model(nn.Module):
self.encoder_embed = encoder_embed
self.encoder_proj = encoder_proj

def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
def forward(
self, feature: Tensor, feature_lens: Tensor
) -> Tuple[Tensor, Tensor]:
x, x_lens = self.encoder_embed(feature, feature_lens)

src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)

encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
encoder_out, encoder_out_lens = self.encoder(
x, x_lens, src_key_padding_mask
)

encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
logits = self.encoder_proj(encoder_out)

@@ -164,7 +168,9 @@ def main():

if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)

main()
@@ -64,7 +64,9 @@ class PiecewiseLinear(object):
for i in range(1, len(self.pairs)):
next_x, next_y = self.pairs[i]
if x >= cur_x and x <= next_x:
return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x)
return cur_y + (next_y - cur_y) * (x - cur_x) / (
next_x - cur_x
)
cur_x, cur_y = next_x, next_y
assert False

@@ -98,7 +100,9 @@ class PiecewiseLinear(object):
def __eq__(self, other):
return self.pairs == other.pairs

def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False):
def get_common_basis(
self, p: "PiecewiseLinear", include_crossings: bool = False
):
"""
Returns (self_mod, p_mod) which are equivalent piecewise lienar
functions to self and p, but with the same x values.

@@ -110,14 +114,18 @@ class PiecewiseLinear(object):
assert isinstance(p, PiecewiseLinear), type(p)

# get sorted x-values without repetition.
x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]))
x_vals = sorted(
set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])
)
y_vals1 = [self(x) for x in x_vals]
y_vals2 = [p(x) for x in x_vals]

if include_crossings:
extra_x_vals = []
for i in range(len(x_vals) - 1):
if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]):
if (y_vals1[i] > y_vals2[i]) != (
y_vals1[i + 1] > y_vals2[i + 1]
):
# if the two lines in this subsegment potentially cross each other..
diff_cur = abs(y_vals1[i] - y_vals2[i])
diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1])

@@ -163,9 +171,7 @@ class ScheduledFloat(torch.nn.Module):
self.schedule = PiecewiseLinear(*args)

def extra_repr(self) -> str:
return (
f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"
)
return f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}"

def __float__(self):
batch_count = self.batch_count

@@ -192,7 +198,8 @@ class ScheduledFloat(torch.nn.Module):
return ScheduledFloat(self.schedule.max(x), default=self.default)
else:
return ScheduledFloat(
self.schedule.max(x.schedule), default=max(self.default, x.default)
self.schedule.max(x.schedule),
default=max(self.default, x.default),
)
@@ -374,7 +381,8 @@ class BiasNormFunction(torch.autograd.Function):
with torch.enable_grad():
# recompute scales from x, bias and log_scale.
scales = (
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5
torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True)
** -0.5
) * log_scale.exp()
ans = x * scales
ans.backward(gradient=ans_grad)

@@ -443,7 +451,8 @@ class BiasNorm(torch.nn.Module):
for _ in range(channel_dim + 1, x.ndim):
bias = bias.unsqueeze(-1)
scales = (
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5
torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True)
** -0.5
) * self.log_scale.exp()
return x * scales

@@ -455,7 +464,11 @@ class BiasNorm(torch.nn.Module):
)

return BiasNormFunction.apply(
x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop
x,
self.bias,
log_scale,
self.channel_dim,
self.store_output_for_backprop,
)

@@ -478,7 +491,9 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans

@@ -501,7 +516,9 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans

@@ -525,7 +542,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d:
with torch.no_grad():
ans.weight[:] *= initial_scale
if ans.bias is not None:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
torch.nn.init.uniform_(
ans.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
return ans

@@ -587,7 +606,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
# first row is correction factors added to the scale near the left edge of the chunk,
# second row is correction factors added to the scale near the right edge of the chunk,
# both of these are added to a default scale of 1.0.
self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size))
self.chunkwise_conv_scale = nn.Parameter(
torch.zeros(2, channels, kernel_size)
)
self.kernel_size = kernel_size

with torch.no_grad():

@@ -595,7 +616,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
self.chunkwise_conv.weight[:] *= initial_scale
if bias:
torch.nn.init.uniform_(
self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale
self.causal_conv.bias,
-0.1 * initial_scale,
0.1 * initial_scale,
)

def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:

@@ -623,7 +646,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):

x_chunk = x[..., left_pad:]
num_chunks = x_chunk.shape[2] // chunk_size
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size)
x_chunk = x_chunk.reshape(
batch_size, num_channels, num_chunks, chunk_size
)
x_chunk = x_chunk.permute(0, 2, 1, 3).reshape(
batch_size * num_chunks, num_channels, chunk_size
)

@@ -635,9 +660,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
x_chunk = x_chunk.reshape(
batch_size, num_chunks, num_channels, chunk_size
).permute(0, 2, 1, 3)
x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[
..., :seq_len
]
x_chunk = x_chunk.reshape(
batch_size, num_channels, num_chunks * chunk_size
)[..., :seq_len]

return x_chunk + x_causal
@@ -711,13 +736,29 @@ class BalancerFunction(torch.autograd.Function):
channel_dim += x.ndim
ctx.channel_dim = channel_dim
ctx.save_for_backward(x)
ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim)
ctx.config = (
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
channel_dim,
)
return x

@staticmethod
def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]:
def backward(
ctx, x_grad: Tensor
) -> Tuple[Tensor, None, None, None, None, None]:
(x,) = ctx.saved_tensors
(min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config
(
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
channel_dim,
) = ctx.config

try:
with torch.enable_grad():

@@ -728,7 +769,11 @@ class BalancerFunction(torch.autograd.Function):
mean_dims = [i for i in range(x.ndim) if i != channel_dim]
uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True)
mean = x.mean(dim=mean_dims, keepdim=True)
stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt()
stddev = (
(uncentered_var - (mean * mean))
.clamp(min=1.0e-20)
.sqrt()
)
rms = uncentered_var.clamp(min=1.0e-20).sqrt()

m = mean / stddev

@@ -877,7 +922,13 @@ class Balancer(torch.nn.Module):
assert x.shape[self.channel_dim] == self.num_channels

return BalancerFunction.apply(
x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim
x,
min_mean,
max_mean,
min_rms,
max_rms,
grad_scale,
self.channel_dim,
)
else:
return _no_op(x)

@@ -956,7 +1007,9 @@ def _whitening_metric(x: Tensor, num_groups: int):
# the following expression is what we'd get if we took the matrix product
# of each covariance and measured the mean of its trace, i.e.
# the same as _diag(torch.matmul(x_covar, x_covar)).mean().
x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group)
x_covarsq_mean_diag = (x_covar**2).sum() / (
num_groups * channels_per_group
)
# this metric will be >= 1.0; the larger it is, the less 'white' the data was.
metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20)
return metric

@@ -1067,7 +1120,11 @@ class Whiten(nn.Module):
and nothing will happen in backprop.
"""
grad_scale = float(self.grad_scale)
if not x.requires_grad or random.random() > self.prob or grad_scale == 0:
if (
not x.requires_grad
or random.random() > self.prob
or grad_scale == 0
):
return _no_op(x)
else:
return WhiteningPenaltyFunction.apply(x, self)

@@ -1086,7 +1143,9 @@ class WithLoss(torch.autograd.Function):
def backward(ctx, ans_grad: Tensor):
return (
ans_grad,
torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
torch.ones(
ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
),
None,
)

@@ -1141,7 +1200,9 @@ class LimitParamValue(torch.autograd.Function):
)
# where x > ctx.max, ensure all grads are positive (this will tend to make
# x more negative).
x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0)
x_grad *= torch.where(
torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0
)
return x_grad, None, None
@@ -1213,9 +1274,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.044
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
d_scaled = (deriv - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0

@@ -1257,7 +1318,9 @@ class Dropout2(nn.Module):
self.p = p

def forward(self, x: Tensor) -> Tensor:
return torch.nn.functional.dropout(x, p=float(self.p), training=self.training)
return torch.nn.functional.dropout(
x, p=float(self.p), training=self.training
)

class MulForDropout3(torch.autograd.Function):

@@ -1330,9 +1393,9 @@ class SwooshLFunction(torch.autograd.Function):
floor = coeff
ceil = 1.0 + coeff + 0.005

d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
grad
)
d_scaled = (grad - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(grad)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0

@@ -1399,9 +1462,9 @@ class SwooshRFunction(torch.autograd.Function):
floor = -0.08
ceil = 0.925

d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
grad
)
d_scaled = (grad - floor) * (
255.0 / (ceil - floor)
) + torch.rand_like(grad)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0

@@ -1472,7 +1535,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function):
dropout_shape[dropout_shared_dim] = 1
# else it won't be very memory efficient.
dropout_mask = (1.0 / (1.0 - dropout_p)) * (
torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p
torch.rand(*dropout_shape, device=x.device, dtype=x.dtype)
> dropout_p
)
else:
dropout_mask = None

@@ -1641,7 +1705,9 @@ def _test_whiten():
def _test_balancer_sign():
probs = torch.arange(0, 1, 0.01)
N = 1000
x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0)
x = 1.0 * (
(2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0
)
x = x.detach()
x.requires_grad = True
m = Balancer(

@@ -1665,7 +1731,9 @@ def _test_balancer_sign():
def _test_balancer_magnitude():
magnitudes = torch.arange(0, 1, 0.01)
N = 1000
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1)
x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(
-1
)
x = x.detach()
x.requires_grad = True
m = Balancer(

@@ -1825,9 +1893,13 @@ def _test_activation_dropout_and_linear():
print("y1 = ", y1)
print("y2 = ", y2)
assert torch.allclose(y1, y2, atol=0.02)
assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05)
assert torch.allclose(
m1[2].weight.grad, m2.weight.grad, atol=1.0e-05
)
if bias:
assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05)
assert torch.allclose(
m1[2].bias.grad, m2.bias.grad, atol=1.0e-05
)
print("x1.grad = ", x1.grad)
print("x2.grad = ", x2.grad)
@@ -153,7 +153,9 @@ def modified_beam_search(
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, encoder_out_dim)

logits = model.joiner(current_encoder_out, decoder_out, project_input=False)
logits = model.joiner(
current_encoder_out, decoder_out, project_input=False
)
# logits is of shape (num_hyps, 1, 1, vocab_size)

logits = logits.squeeze(1).squeeze(1)

@@ -170,10 +172,14 @@ def modified_beam_search(
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)

for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths)
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(
num_active_paths
)

with warnings.catch_warnings():
warnings.simplefilter("ignore")
@@ -282,7 +282,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
)
batch_states.append(cached_embed_left_pad)

processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens)

return batch_states

@@ -320,7 +322,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1

@@ -351,7 +355,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
cached_conv2_list[i],
]

cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0)
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])

@@ -398,7 +404,9 @@ def streaming_forward(
new_processed_lens = processed_lens + x_lens

# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)

x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]

@@ -486,7 +494,9 @@ def decode_one_chunk(
encoder_out = model.joiner.encoder_proj(encoder_out)

if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens

@@ -507,7 +517,9 @@ def decode_one_chunk(
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)

states = unstack_states(new_states)

@@ -565,7 +577,9 @@ def decode_dataset(
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(model=model, batch_size=1, device=device)
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,

@@ -635,7 +649,9 @@ def decode_dataset(
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}

@@ -668,7 +684,8 @@ def save_results(

test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)

@@ -701,7 +718,9 @@ def main():
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

assert params.causal, params.causal
assert "," not in params.chunk_size, "chunk_size should be one value in decoding."
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."

@@ -741,9 +760,9 @@ def main():

if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"

@@ -770,9 +789,9 @@ def main():
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
@@ -107,7 +107,9 @@ class ConvNeXt(nn.Module):
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
torch.rand(
(batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
)
> layerdrop_rate
)
else:

@@ -273,7 +275,9 @@ class Conv2dSubsampling(nn.Module):
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
whitening_limit=ScheduledFloat(
(0.0, 4.0), (20000.0, 8.0), default=4.0
),
prob=(0.025, 0.25),
grad_scale=0.02,
)

@@ -396,8 +400,8 @@ class Conv2dSubsampling(nn.Module):
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
cached_embed_left_pad = torch.zeros(
batch_size, channels, left_pad, freq
).to(device)

return cached_embed_left_pad
@@ -95,7 +95,9 @@ from icefall.utils import (
str2bool,
)

LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]

def get_adjusted_batch_count(params: AttributeDict) -> float:

@@ -342,7 +344,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)

parser.add_argument(

@@ -365,7 +368,8 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)

parser.add_argument(

@@ -738,7 +742,11 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3

@@ -785,7 +793,9 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)

# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()

@@ -967,7 +977,9 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()

if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and batch_idx % 400 == 0
):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:

@@ -989,7 +1001,11 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
)

if tb_writer is not None:
@@ -1000,13 +1016,20 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
"train/grad_scale",
cur_grad_scale,
params.batch_idx_train,
)

if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,

@@ -1096,7 +1119,9 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank], find_unused_parameters=True)

optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True
),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
@@ -125,7 +125,9 @@ class Zipformer2(EncoderInterface):
if len(x) == 1:
x = x * len(downsampling_factor)
else:
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
assert len(x) == len(downsampling_factor) and isinstance(
x[0], int
)
return x

self.output_downsampling_factor = output_downsampling_factor # int

@@ -140,7 +142,9 @@ class Zipformer2(EncoderInterface):
pos_head_dim = _to_tuple(pos_head_dim)
self.num_heads = num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)
self.cnn_module_kernel = cnn_module_kernel = _to_tuple(
cnn_module_kernel
)

self.causal = causal
self.chunk_size = chunk_size

@@ -192,7 +196,9 @@ class Zipformer2(EncoderInterface):
self.encoders = nn.ModuleList(encoders)

self.downsample_output = SimpleDownsample(
max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout,
)

def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:

@@ -223,7 +229,8 @@ class Zipformer2(EncoderInterface):

# mask1 shape: (1, batch_size, 1)
mask1 = (
torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob
torch.rand(1, batch_size, 1, device=x.device)
> feature_mask_dropout_prob
).to(x.dtype)

# mask2 has additional sequences masked, about twice the number.

@@ -271,7 +278,9 @@ class Zipformer2(EncoderInterface):
left_context_chunks = -1
else:
if torch.jit.is_scripting():
assert len(self.left_context_frames) == 1, self.left_context_frames
assert (
len(self.left_context_frames) == 1
), self.left_context_frames
left_context_frames = self.left_context_frames[0]
else:
left_context_frames = random.choice(self.left_context_frames)

@@ -370,7 +379,8 @@ class Zipformer2(EncoderInterface):
num_encoders = len(self.encoder_dim)
assert all(
chunk_size * left_context_chunks
>= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i]
>= (self.cnn_module_kernel[i] // 2)
* self.downsampling_factor[i]
for i in range(num_encoders)
)
else:

@@ -390,7 +400,9 @@ class Zipformer2(EncoderInterface):
src_c = c
tgt_c = c.unsqueeze(-1)

attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks)
attn_mask = torch.logical_or(
src_c > tgt_c, src_c < tgt_c - left_context_chunks
)
if __name__ == "__main__":
logging.info(f"attn_mask = {attn_mask}")
return attn_mask
@@ -448,7 +460,9 @@ class Zipformer2(EncoderInterface):

x, new_layer_states = module.streaming_forward(
x,
states=states[layer_offset * 6 : (layer_offset + num_layers) * 6],
states=states[
layer_offset * 6 : (layer_offset + num_layers) * 6
],
left_context_len=self.left_context_frames[0] // ds,
src_key_padding_mask=src_key_padding_mask[..., ::ds],
)

@@ -496,24 +510,24 @@ class Zipformer2(EncoderInterface):
nonlin_attn_head_dim = 3 * embed_dim // 4
conv_left_pad = self.cnn_module_kernel[i] // 2
for layer in range(num_layers):
cached_key = torch.zeros(downsample_left, batch_size, key_dim).to(
device
)
cached_key = torch.zeros(
downsample_left, batch_size, key_dim
).to(device)
cached_nonlin_attn = torch.zeros(
1, batch_size, downsample_left, nonlin_attn_head_dim
).to(device)
cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to(
device
)
cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to(
device
)
cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
device
)
cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to(
device
)
cached_val1 = torch.zeros(
downsample_left, batch_size, value_dim
).to(device)
cached_val2 = torch.zeros(
downsample_left, batch_size, value_dim
).to(device)
cached_conv1 = torch.zeros(
batch_size, embed_dim, conv_left_pad
).to(device)
cached_conv2 = torch.zeros(
batch_size, embed_dim, conv_left_pad
).to(device)
states += [
cached_key,
cached_nonlin_attn,
@@ -621,7 +635,9 @@ class Zipformer2EncoderLayer(nn.Module):
embed_dim, (feedforward_dim * 3) // 4, dropout
)

self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout)
self.feed_forward2 = FeedforwardModule(
embed_dim, feedforward_dim, dropout
)

self.feed_forward3 = FeedforwardModule(
embed_dim, (feedforward_dim * 5) // 4, dropout

@@ -708,7 +724,9 @@ class Zipformer2EncoderLayer(nn.Module):
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
return None
batch_size = x.shape[1]
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)
mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(
x.dtype
)
return mask

def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor:

@@ -774,7 +792,9 @@ class Zipformer2EncoderLayer(nn.Module):
selected_attn_weights = attn_weights[0:1]
if torch.jit.is_scripting():
pass
elif not self.training and random.random() < float(self.const_attention_rate):
elif not self.training and random.random() < float(
self.const_attention_rate
):
# Make attention weights constant. The intention is to
# encourage these modules to do something similar to an
# averaging-over-time operation.

@@ -790,7 +810,9 @@ class Zipformer2EncoderLayer(nn.Module):
na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

src = src + (
na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
na
if self_attn_dropout_mask is None
else na * self_attn_dropout_mask
)

self_attn = self.self_attn1(src, attn_weights)

@@ -804,10 +826,14 @@ class Zipformer2EncoderLayer(nn.Module):
if torch.jit.is_scripting():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
conv_skip_rate = (
float(self.conv_skip_rate) if self.training else 0.0
)
src = src + self.sequence_dropout(
self.conv_module1(
src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
src,
chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask,
),
conv_skip_rate,
)
@ -834,10 +860,14 @@ class Zipformer2EncoderLayer(nn.Module):
if torch.jit.is_scripting():
conv_skip_rate = 0.0
else:
conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
conv_skip_rate = (
float(self.conv_skip_rate) if self.training else 0.0
)
src = src + self.sequence_dropout(
self.conv_module2(
src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
src,
chunk_size=chunk_size,
src_key_padding_mask=src_key_padding_mask,
),
conv_skip_rate,
)
@ -1150,7 +1180,9 @@ class BypassModule(nn.Module):
embed_dim: int,
skip_rate: FloatLike = 0.0,
straight_through_rate: FloatLike = 0.0,
scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0),
scale_min: FloatLike = ScheduledFloat(
(0.0, 0.9), (20000.0, 0.2), default=0
),
scale_max: FloatLike = 1.0,
):
super().__init__()
@ -1169,11 +1201,15 @@ class BypassModule(nn.Module):
return self.bypass_scale
else:
ans = limit_param_value(
self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max)
self.bypass_scale,
min=float(self.scale_min),
max=float(self.scale_max),
)
skip_rate = float(self.skip_rate)
if skip_rate != 0.0:
mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate
mask = (
torch.rand((batch_size, 1), device=ans.device) > skip_rate
)
ans = ans * mask
# now ans is of shape (batch_size, num_channels), and is zero for sequences
# on which we have randomly chosen to do layer-skipping.
@ -1320,7 +1356,9 @@ class SimpleDownsample(torch.nn.Module):
if seq_len != d_seq_len * ds:
# right-pad src, repeating the last element.
pad = d_seq_len * ds - seq_len
src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
src_extra = src[src.shape[0] - 1 :].expand(
pad, src.shape[1], src.shape[2]
)
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds

@ -1354,7 +1392,9 @@ class SimpleUpsample(torch.nn.Module):
"""
upsample = self.upsample
(seq_len, batch_size, num_channels) = src.shape
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
src = src.unsqueeze(1).expand(
seq_len, upsample, batch_size, num_channels
)
src = src.reshape(seq_len * upsample, batch_size, num_channels)
return src

@ -1411,12 +1451,18 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# the length of self.pe is 2 * input_len - 1
if self.pe.size(0) >= T * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return

# if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ]
x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1)
x = (
torch.arange(-(T - 1), T, device=x.device)
.to(torch.float32)
.unsqueeze(1)
)

freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device)

@ -1430,7 +1476,10 @@ class CompactRelPositionalEncoding(torch.nn.Module):
x_compressed = (
compression_length
* x.sign()
* ((x.abs() + compression_length).log() - math.log(compression_length))
* (
(x.abs() + compression_length).log()
- math.log(compression_length)
)
)

# if self.length_factor == 1.0, then length_scale is chosen so that the
@ -1442,7 +1491,9 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# note for machine implementations: if atan is not available, we can use:
# x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2)
# check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x))
x_atan = (x_compressed / length_scale).atan() # results between -pi and pi
x_atan = (
x_compressed / length_scale
).atan() # results between -pi and pi

cosines = (x_atan * freqs).cos()
sines = (x_atan * freqs).sin()
@ -1507,7 +1558,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
query_head_dim: int,
pos_head_dim: int,
dropout: float = 0.0,
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
pos_emb_skip_rate: FloatLike = ScheduledFloat(
(0.0, 0.5), (4000.0, 0.0)
),
) -> None:
super().__init__()
self.embed_dim = embed_dim
@ -1516,7 +1569,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
self.pos_head_dim = pos_head_dim
self.dropout = dropout
self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
self.name = None # will be overwritten in training code; for diagnostics.
self.name = (
None # will be overwritten in training code; for diagnostics.
)

key_head_dim = query_head_dim
in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads
@ -1527,7 +1582,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
# to be used with the ScaledAdam optimizer; with most other optimizers,
# it would be necessary to apply the scaling factor in the forward function.
self.in_proj = ScaledLinear(
embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25
embed_dim,
in_proj_dim,
bias=True,
initial_scale=query_head_dim**-0.25,
)

self.whiten_keys = Whiten(
@ -1601,7 +1659,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
assert p.shape[-1] == num_heads * pos_head_dim

q = self.copy_query(q) # for diagnostics only, does nothing.
k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass.
k = self.whiten_keys(
self.balance_keys(k)
) # does nothing in the forward pass.
p = self.copy_pos_query(p) # for diagnostics only, does nothing.

q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
@ -1619,15 +1679,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
if torch.jit.is_scripting():
# We can't put random.random() in the same line
use_pos_scores = True
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
elif not self.training or random.random() >= float(
self.pos_emb_skip_rate
):
use_pos_scores = True

if use_pos_scores:
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
pos_emb = pos_emb.reshape(
-1, seq_len2, num_heads, pos_head_dim
).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
@ -1769,9 +1831,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1 + left_context_len
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
2, 0, 3, 1
)
pos_emb = pos_emb.reshape(
-1, seq_len2, num_heads, pos_head_dim
).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
@ -1801,7 +1863,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
), attn_scores.shape

if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape
assert key_padding_mask.shape == (
batch_size,
k_len,
), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
@ -1846,7 +1911,9 @@ class SelfAttention(nn.Module):
value_head_dim: int,
) -> None:
super().__init__()
self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True)
self.in_proj = nn.Linear(
embed_dim, num_heads * value_head_dim, bias=True
)

self.out_proj = ScaledLinear(
num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05
@ -1958,7 +2025,9 @@ class SelfAttention(nn.Module):
class FeedforwardModule(nn.Module):
"""Feedforward module in Zipformer2 model."""

def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike):
def __init__(
self, embed_dim: int, feedforward_dim: int, dropout: FloatLike
):
super(FeedforwardModule, self).__init__()
self.in_proj = nn.Linear(embed_dim, feedforward_dim)

@ -2226,7 +2295,9 @@ class ConvolutionModule(nn.Module):
assert kernel_size % 2 == 1

self.depthwise_conv = (
ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size)
ChunkCausalDepthwiseConv1d(
channels=bottleneck_dim, kernel_size=kernel_size
)
if causal
else nn.Conv1d(
in_channels=bottleneck_dim,
@ -2294,7 +2365,9 @@ class ConvolutionModule(nn.Module):
x = x.permute(1, 2, 0) # (#batch, channels, time).

if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = x.masked_fill(
src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0
)

if not torch.jit.is_scripting() and chunk_size >= 0:
# Not support exporting a model for simulated streaming decoding
@ -2344,7 +2417,9 @@ class ConvolutionModule(nn.Module):
x = x.permute(1, 2, 0) # (#batch, channels, time).

if src_key_padding_mask is not None:
x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0)
x = x.masked_fill(
src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0
)

x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)