From a98f6b27a40dc8502074804be4e56fc1a639d91b Mon Sep 17 00:00:00 2001 From: Yifan Yang Date: Fri, 2 Jun 2023 12:41:34 +0800 Subject: [PATCH] Fix typo and reformatted zipformer --- egs/librispeech/ASR/zipformer/decode.py | 39 ++-- .../ASR/zipformer/decode_stream.py | 15 +- egs/librispeech/ASR/zipformer/decoder.py | 4 +- egs/librispeech/ASR/zipformer/export.py | 46 +++-- .../ASR/zipformer/jit_pretrained.py | 4 +- .../ASR/zipformer/jit_pretrained_streaming.py | 19 +- egs/librispeech/ASR/zipformer/joiner.py | 12 +- egs/librispeech/ASR/zipformer/optim.py | 95 ++++++--- egs/librispeech/ASR/zipformer/pretrained.py | 12 +- egs/librispeech/ASR/zipformer/profile.py | 12 +- egs/librispeech/ASR/zipformer/scaling.py | 160 +++++++++++---- .../ASR/zipformer/streaming_beam_search.py | 12 +- .../ASR/zipformer/streaming_decode.py | 51 +++-- egs/librispeech/ASR/zipformer/subsampling.py | 14 +- egs/librispeech/ASR/zipformer/train.py | 47 ++++- egs/librispeech/ASR/zipformer/zipformer.py | 189 ++++++++++++------ 16 files changed, 524 insertions(+), 207 deletions(-) diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py index 7bff65c10..f4b81cfe3 100755 --- a/egs/librispeech/ASR/zipformer/decode.py +++ b/egs/librispeech/ASR/zipformer/decode.py @@ -273,7 +273,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -370,7 +371,9 @@ def decode_one_batch( src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens = model.encoder( + x, x_lens, src_key_padding_mask + ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) hyps = [] @@ -430,7 +433,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -561,7 +567,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -594,7 +602,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -655,7 +664,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -687,9 +698,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : 
params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -716,9 +727,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -777,7 +788,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/zipformer/decode_stream.py b/egs/librispeech/ASR/zipformer/decode_stream.py index 946db275c..cb173619e 100644 --- a/egs/librispeech/ASR/zipformer/decode_stream.py +++ b/egs/librispeech/ASR/zipformer/decode_stream.py @@ -90,11 +90,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -124,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames + + ret_length # noqa ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/zipformer/decoder.py b/egs/librispeech/ASR/zipformer/decoder.py index 0ca06233a..78d22a878 100644 --- a/egs/librispeech/ASR/zipformer/decoder.py +++ b/egs/librispeech/ASR/zipformer/decoder.py @@ -118,7 +118,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/zipformer/export.py b/egs/librispeech/ASR/zipformer/export.py index e44fd72b1..d1152372a 100755 --- a/egs/librispeech/ASR/zipformer/export.py +++ b/egs/librispeech/ASR/zipformer/export.py @@ -276,7 +276,9 @@ class EncoderModel(nn.Module): src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens = self.encoder( + x, x_lens, src_key_padding_mask + ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) return encoder_out, encoder_out_lens @@ -288,7 +290,9 @@ class StreamingEncoderModel(nn.Module): def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> 
None: super().__init__() assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + assert ( + len(encoder.left_context_frames) == 1 + ), encoder.left_context_frames self.chunk_size = encoder.chunk_size[0] self.left_context_len = encoder.left_context_frames[0] @@ -315,7 +319,11 @@ class StreamingEncoderModel(nn.Module): left_context_len = self.left_context_len cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = self.encoder_embed.streaming_forward( x=features, x_lens=feature_lengths, cached_left_pad=cached_embed_left_pad, @@ -335,7 +343,9 @@ class StreamingEncoderModel(nn.Module): new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -377,7 +387,9 @@ class StreamingEncoderModel(nn.Module): embed_states = self.encoder_embed.get_init_states(batch_size, device) states.append(embed_states) - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + processed_lens = torch.zeros( + batch_size, dtype=torch.int32, device=device + ) states.append(processed_lens) return states @@ -411,9 +423,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -438,9 +450,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -494,10 +506,14 @@ def main(): # Wrap encoder and encoder_embed as a module if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + model.encoder = StreamingEncoderModel( + model.encoder, model.encoder_embed + ) chunk_size = model.encoder.chunk_size left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + filename = ( + f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + ) else: model.encoder = EncoderModel(model.encoder, model.encoder_embed) filename = "jit_script.pt" @@ -516,7 +532,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py index 4092d165e..7bdb8cbee 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py @@ -266,7 +266,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] 
%(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py index 58d736685..f3333c4d2 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_streaming.py @@ -126,7 +126,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32, device=device).unsqueeze(0) + decoder_input = torch.tensor( + hyp, dtype=torch.int32, device=device + ).unsqueeze(0) # decoder_input.shape (1,, 1 context_size) decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) else: @@ -146,7 +148,9 @@ def greedy_search( decoder_input = torch.tensor( decoder_input, dtype=torch.int32, device=device ).unsqueeze(0) - decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze(1) + decoder_out = decoder(decoder_input, torch.tensor([False])).squeeze( + 1 + ) return hyp, decoder_out @@ -247,7 +251,12 @@ def main(): num_processed_frames += chunk_length hyp, decoder_out = greedy_search( - decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp, device=device + decoder, + joiner, + encoder_out.squeeze(0), + decoder_out, + hyp, + device=device, ) context_size = 2 @@ -263,7 +272,9 @@ torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_mode(False) torch._C._set_graph_executor_optimize(False) if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/joiner.py b/egs/librispeech/ASR/zipformer/joiner.py index dfb0a0057..d64eeded6 100644 --- a/egs/librispeech/ASR/zipformer/joiner.py +++ b/egs/librispeech/ASR/zipformer/joiner.py @@ -29,8 +29,12 @@ class Joiner(nn.Module): ): super().__init__() - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25) + self.encoder_proj = ScaledLinear( + encoder_dim, joiner_dim, initial_scale=0.25 + ) + self.decoder_proj = ScaledLinear( + decoder_dim, joiner_dim, initial_scale=0.25 + ) self.output_linear = nn.Linear(joiner_dim, vocab_size) def forward( @@ -58,7 +62,9 @@ class Joiner(nn.Module): ) if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/zipformer/optim.py b/egs/librispeech/ASR/zipformer/optim.py index f71e437c1..98af7b70c 100644 --- a/egs/librispeech/ASR/zipformer/optim.py +++ b/egs/librispeech/ASR/zipformer/optim.py @@ -90,7 +90,9 @@ class BatchedOptimizer(Optimizer): sorted_idx = sorted( range(len(batches_names)), key=lambda i: batches_names_keys[i] ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches_names = [ + batches_names[batches_names_keys[idx]] for idx in sorted_idx + ] batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] stacked_params_dict = dict() @@ -108,7 +110,10 @@ class BatchedOptimizer(Optimizer): state = self.state[p] p_stacked = torch.stack(batch) grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] + [ + 
torch.zeros_like(p) if p.grad is None else p.grad + for p in batch + ] ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked @@ -325,8 +330,12 @@ class ScaledAdam(BatchedOptimizer): batch = True - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: + for group, group_params_names in zip( + self.param_groups, self.parameters_names + ): + with self.batched_params( + group["params"], group_params_names + ) as batches: # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. @@ -378,7 +387,9 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) batch_size = p.shape[0] numel = p.numel() // batch_size @@ -387,7 +398,9 @@ class ScaledAdam(BatchedOptimizer): # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = ( + (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + ) state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) @@ -396,7 +409,9 @@ class ScaledAdam(BatchedOptimizer): ) # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) def _get_clipping_scale( self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] @@ -432,7 +447,9 @@ class ScaledAdam(BatchedOptimizer): "ScaledAdam optimizer does not support sparse gradients" ) if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] + tot_sumsq += ( + grad**2 + ).sum() # sum() to change shape [1] to [] else: tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() @@ -451,7 +468,8 @@ class ScaledAdam(BatchedOptimizer): quartiles = [] for n in range(0, 5): index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n + clipping_update_period - 1, + (clipping_update_period // 4) * n, ) quartiles.append(sorted_norms[index].item()) @@ -536,7 +554,9 @@ class ScaledAdam(BatchedOptimizer): sorted_by_proportion = { k: v for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + all_sumsq_orig.items(), + key=lambda item: item[1][0], + reverse=True, ) } dominant_param_name = next(iter(sorted_by_proportion)) @@ -589,7 +609,9 @@ class ScaledAdam(BatchedOptimizer): if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + (p**2) + .mean(dim=list(range(1, p.ndim)), keepdim=True) + .sqrt() ) if step > 0: # self._size_update() learns the overall scale on the @@ -636,9 +658,13 @@ class ScaledAdam(BatchedOptimizer): # faster decay at this level. beta2_corr = beta2**size_update_period - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) 
+ scale_exp_avg_sq = state[ + "scale_exp_avg_sq" + ] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` + (scale_grads**2).mean( + dim=0 + ), # mean over dim `size_update_period` alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...) @@ -651,7 +677,10 @@ class ScaledAdam(BatchedOptimizer): denom = scale_exp_avg_sq.sqrt() + eps scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom + -size_lr + * (bias_correction2**0.5) + * scale_grads.sum(dim=0) + / denom ) is_too_small = param_rms < param_min_rms @@ -663,7 +692,9 @@ class ScaledAdam(BatchedOptimizer): # We have to look at the trained model for parameters at or around the # param_max_rms, because sometimes they can indicate a problem with the # topology or settings. - scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms) + scale_step = torch.minimum( + scale_step, (param_max_rms - param_rms) / param_rms + ) delta = state["delta"] # the factor of (1-beta1) relates to momentum. @@ -692,7 +723,9 @@ class ScaledAdam(BatchedOptimizer): exp_avg_sq = state["exp_avg_sq"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) + this_step = state["step"] - ( + state["zero_step"] if "zero_step" in state else 0 + ) bias_correction2 = 1 - beta2 ** (this_step + 1) if bias_correction2 < 0.99: # note: not in-place. @@ -742,7 +775,9 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose @@ -869,7 +904,8 @@ class Eden(LRScheduler): factor = ( (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) + ** -0.25 ) warmup_factor = ( 1.0 @@ -958,11 +994,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -998,7 +1040,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -1036,7 +1080,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). 
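# --- Illustration (standalone sketch, not part of the patch) -------------
# The lines reformatted just below gate Eve's weight decay on the
# parameter's RMS: decay acts only while RMS(p) > target_rms, pulling
# parameters toward that RMS rather than toward zero. Values are made up.
import torch

def gated_weight_decay(p: torch.Tensor, weight_decay: float = 1e-3,
                       target_rms: float = 0.1) -> None:
    # p.norm() > target_rms * sqrt(numel)  is equivalent to  RMS(p) > target_rms
    is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
    p.mul_(1 - (weight_decay * is_above_target_rms))

p = torch.full((4, 4), 0.5)        # RMS = 0.5 > 0.1, so decay applies
gated_weight_decay(p)
assert torch.allclose(p, torch.full((4, 4), 0.5 * (1 - 1e-3)))
# -------------------------------------------------------------------------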
- is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -1087,7 +1133,8 @@ def _test_scaled_adam(hidden_dim: int): 100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) + * output_magnitudes, ) for _ in range(20) ] diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index 3999fe95c..1efc2d8b7 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -314,7 +314,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) # model forward @@ -323,7 +325,9 @@ def main(): src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens = model.encoder( + x, x_lens, src_key_padding_mask + ) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) hyps = [] @@ -374,7 +378,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/profile.py b/egs/librispeech/ASR/zipformer/profile.py index 57f44a90a..b460b5338 100755 --- a/egs/librispeech/ASR/zipformer/profile.py +++ b/egs/librispeech/ASR/zipformer/profile.py @@ -100,13 +100,17 @@ class Model(nn.Module): self.encoder_embed = encoder_embed self.encoder_proj = encoder_proj - def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]: + def forward( + self, feature: Tensor, feature_lens: Tensor + ) -> Tuple[Tensor, Tensor]: x, x_lens = self.encoder_embed(feature, feature_lens) src_key_padding_mask = make_pad_mask(x_lens) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out, encoder_out_lens = self.encoder( + x, x_lens, src_key_padding_mask + ) encoder_out = encoder_out.permute(1, 0, 2) # (N, T, C) -> (T, N, C) logits = self.encoder_proj(encoder_out) @@ -164,7 +168,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 391bac5c0..23b23bfd8 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -64,7 +64,9 @@ class PiecewiseLinear(object): for i in range(1, len(self.pairs)): next_x, next_y = self.pairs[i] if x >= cur_x and x <= next_x: - return cur_y + (next_y - cur_y) * (x - cur_x) / (next_x - cur_x) + return cur_y + (next_y - cur_y) * (x - cur_x) / ( + next_x - cur_x + ) cur_x, 
cur_y = next_x, next_y assert False @@ -98,7 +100,9 @@ class PiecewiseLinear(object): def __eq__(self, other): return self.pairs == other.pairs - def get_common_basis(self, p: "PiecewiseLinear", include_crossings: bool = False): + def get_common_basis( + self, p: "PiecewiseLinear", include_crossings: bool = False + ): """ Returns (self_mod, p_mod) which are equivalent piecewise lienar functions to self and p, but with the same x values. @@ -110,14 +114,18 @@ class PiecewiseLinear(object): assert isinstance(p, PiecewiseLinear), type(p) # get sorted x-values without repetition. - x_vals = sorted(set([x for x, _ in self.pairs] + [x for x, _ in p.pairs])) + x_vals = sorted( + set([x for x, _ in self.pairs] + [x for x, _ in p.pairs]) + ) y_vals1 = [self(x) for x in x_vals] y_vals2 = [p(x) for x in x_vals] if include_crossings: extra_x_vals = [] for i in range(len(x_vals) - 1): - if (y_vals1[i] > y_vals2[i]) != (y_vals1[i + 1] > y_vals2[i + 1]): + if (y_vals1[i] > y_vals2[i]) != ( + y_vals1[i + 1] > y_vals2[i + 1] + ): # if the two lines in this subsegment potentially cross each other.. diff_cur = abs(y_vals1[i] - y_vals2[i]) diff_next = abs(y_vals1[i + 1] - y_vals2[i + 1]) @@ -163,9 +171,7 @@ class ScheduledFloat(torch.nn.Module): self.schedule = PiecewiseLinear(*args) def extra_repr(self) -> str: - return ( - f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" - ) + return f"batch_count={self.batch_count}, schedule={str(self.schedule.pairs[1:-1])}" def __float__(self): batch_count = self.batch_count @@ -192,7 +198,8 @@ class ScheduledFloat(torch.nn.Module): return ScheduledFloat(self.schedule.max(x), default=self.default) else: return ScheduledFloat( - self.schedule.max(x.schedule), default=max(self.default, x.default) + self.schedule.max(x.schedule), + default=max(self.default, x.default), ) @@ -374,7 +381,8 @@ class BiasNormFunction(torch.autograd.Function): with torch.enable_grad(): # recompute scales from x, bias and log_scale. 
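# --- Illustration (standalone sketch, not part of the patch) -------------
# PiecewiseLinear/ScheduledFloat, defined above and used throughout this
# file, interpolate a value linearly between (x, y) knots and clamp it
# outside the knot range; this is a minimal reimplementation for intuition.
def piecewise_linear(pairs, x):
    if x <= pairs[0][0]:
        return pairs[0][1]
    if x >= pairs[-1][0]:
        return pairs[-1][1]
    for (x0, y0), (x1, y1) in zip(pairs, pairs[1:]):
        if x0 <= x <= x1:
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)

# e.g. a schedule like ScheduledFloat((0.0, 4.0), (20000.0, 8.0)),
# which appears elsewhere in this patch, evaluated halfway along:
assert piecewise_linear([(0.0, 4.0), (20000.0, 8.0)], 10000.0) == 6.0
# -------------------------------------------------------------------------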
scales = ( - torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) ** -0.5 + torch.mean((x - bias) ** 2, dim=ctx.channel_dim, keepdim=True) + ** -0.5 ) * log_scale.exp() ans = x * scales ans.backward(gradient=ans_grad) @@ -443,7 +451,8 @@ class BiasNorm(torch.nn.Module): for _ in range(channel_dim + 1, x.ndim): bias = bias.unsqueeze(-1) scales = ( - torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) ** -0.5 + torch.mean((x - bias) ** 2, dim=channel_dim, keepdim=True) + ** -0.5 ) * self.log_scale.exp() return x * scales @@ -455,7 +464,11 @@ class BiasNorm(torch.nn.Module): ) return BiasNormFunction.apply( - x, self.bias, log_scale, self.channel_dim, self.store_output_for_backprop + x, + self.bias, + log_scale, + self.channel_dim, + self.store_output_for_backprop, ) @@ -478,7 +491,9 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) return ans @@ -501,7 +516,9 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) return ans @@ -525,7 +542,9 @@ def ScaledConv2d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv2d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_( + ans.bias, -0.1 * initial_scale, 0.1 * initial_scale + ) return ans @@ -587,7 +606,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): # first row is correction factors added to the scale near the left edge of the chunk, # second row is correction factors added to the scale near the right edge of the chunk, # both of these are added to a default scale of 1.0. 
- self.chunkwise_conv_scale = nn.Parameter(torch.zeros(2, channels, kernel_size)) + self.chunkwise_conv_scale = nn.Parameter( + torch.zeros(2, channels, kernel_size) + ) self.kernel_size = kernel_size with torch.no_grad(): @@ -595,7 +616,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): self.chunkwise_conv.weight[:] *= initial_scale if bias: torch.nn.init.uniform_( - self.causal_conv.bias, -0.1 * initial_scale, 0.1 * initial_scale + self.causal_conv.bias, + -0.1 * initial_scale, + 0.1 * initial_scale, ) def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor: @@ -623,7 +646,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x_chunk = x[..., left_pad:] num_chunks = x_chunk.shape[2] // chunk_size - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks, chunk_size) + x_chunk = x_chunk.reshape( + batch_size, num_channels, num_chunks, chunk_size + ) x_chunk = x_chunk.permute(0, 2, 1, 3).reshape( batch_size * num_chunks, num_channels, chunk_size ) @@ -635,9 +660,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module): x_chunk = x_chunk.reshape( batch_size, num_chunks, num_channels, chunk_size ).permute(0, 2, 1, 3) - x_chunk = x_chunk.reshape(batch_size, num_channels, num_chunks * chunk_size)[ - ..., :seq_len - ] + x_chunk = x_chunk.reshape( + batch_size, num_channels, num_chunks * chunk_size + )[..., :seq_len] return x_chunk + x_causal @@ -711,13 +736,29 @@ class BalancerFunction(torch.autograd.Function): channel_dim += x.ndim ctx.channel_dim = channel_dim ctx.save_for_backward(x) - ctx.config = (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) + ctx.config = ( + min_mean, + max_mean, + min_rms, + max_rms, + grad_scale, + channel_dim, + ) return x @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None, None, None]: (x,) = ctx.saved_tensors - (min_mean, max_mean, min_rms, max_rms, grad_scale, channel_dim) = ctx.config + ( + min_mean, + max_mean, + min_rms, + max_rms, + grad_scale, + channel_dim, + ) = ctx.config try: with torch.enable_grad(): @@ -728,7 +769,11 @@ class BalancerFunction(torch.autograd.Function): mean_dims = [i for i in range(x.ndim) if i != channel_dim] uncentered_var = (x**2).mean(dim=mean_dims, keepdim=True) mean = x.mean(dim=mean_dims, keepdim=True) - stddev = (uncentered_var - (mean * mean)).clamp(min=1.0e-20).sqrt() + stddev = ( + (uncentered_var - (mean * mean)) + .clamp(min=1.0e-20) + .sqrt() + ) rms = uncentered_var.clamp(min=1.0e-20).sqrt() m = mean / stddev @@ -877,7 +922,13 @@ class Balancer(torch.nn.Module): assert x.shape[self.channel_dim] == self.num_channels return BalancerFunction.apply( - x, min_mean, max_mean, min_rms, max_rms, grad_scale, self.channel_dim + x, + min_mean, + max_mean, + min_rms, + max_rms, + grad_scale, + self.channel_dim, ) else: return _no_op(x) @@ -956,7 +1007,9 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar**2).sum() / ( + num_groups * channels_per_group + ) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. 
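# --- Illustration (standalone sketch, not part of the patch) -------------
# Sanity check of the property stated above for the whitening metric
# computed just below: with covariance proportional to the identity the
# metric is 1.0, and channel correlation raises it. A single group is
# assumed here for simplicity.
import torch

def whitening_metric(x: torch.Tensor) -> torch.Tensor:
    x = x - x.mean(dim=0)                      # x: (num_frames, num_channels)
    covar = x.t() @ x / x.shape[0]
    mean_diag = covar.diag().mean()
    covarsq_mean_diag = (covar ** 2).sum() / covar.shape[0]
    return covarsq_mean_diag / (mean_diag ** 2 + 1.0e-20)

torch.manual_seed(0)
white = torch.randn(100000, 8)                 # uncorrelated channels: ~1.0
correlated = white @ torch.randn(8, 8)         # mixed channels: > 1.0
print(whitening_metric(white), whitening_metric(correlated))
# -------------------------------------------------------------------------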
metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) return metric @@ -1067,7 +1120,11 @@ class Whiten(nn.Module): and nothing will happen in backprop. """ grad_scale = float(self.grad_scale) - if not x.requires_grad or random.random() > self.prob or grad_scale == 0: + if ( + not x.requires_grad + or random.random() > self.prob + or grad_scale == 0 + ): return _no_op(x) else: return WhiteningPenaltyFunction.apply(x, self) @@ -1086,7 +1143,9 @@ class WithLoss(torch.autograd.Function): def backward(ctx, ans_grad: Tensor): return ( ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), + torch.ones( + ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device + ), None, ) @@ -1141,7 +1200,9 @@ class LimitParamValue(torch.autograd.Function): ) # where x > ctx.max, ensure all grads are positive (this will tend to make # x more negative). - x_grad *= torch.where(torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0) + x_grad *= torch.where( + torch.logical_and(x_grad < 0, x > ctx.max), -1.0, 1.0 + ) return x_grad, None, None @@ -1213,9 +1274,9 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.044 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) + d_scaled = (deriv - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(deriv) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1257,7 +1318,9 @@ class Dropout2(nn.Module): self.p = p def forward(self, x: Tensor) -> Tensor: - return torch.nn.functional.dropout(x, p=float(self.p), training=self.training) + return torch.nn.functional.dropout( + x, p=float(self.p), training=self.training + ) class MulForDropout3(torch.autograd.Function): @@ -1330,9 +1393,9 @@ class SwooshLFunction(torch.autograd.Function): floor = coeff ceil = 1.0 + coeff + 0.005 - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) + d_scaled = (grad - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(grad) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1399,9 +1462,9 @@ class SwooshRFunction(torch.autograd.Function): floor = -0.08 ceil = 0.925 - d_scaled = (grad - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - grad - ) + d_scaled = (grad - floor) * ( + 255.0 / (ceil - floor) + ) + torch.rand_like(grad) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1472,7 +1535,8 @@ class ActivationDropoutAndLinearFunction(torch.autograd.Function): dropout_shape[dropout_shared_dim] = 1 # else it won't be very memory efficient. 
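# --- Illustration (standalone sketch, not part of the patch) -------------
# The dropout mask built just below is drawn once per (batch, channel) and
# broadcast along the shared dim (saving memory), with kept values
# pre-scaled by 1/(1 - p) so the expected activation is unchanged.
import torch

x = torch.ones(10, 3, 4)        # (time, batch, channels); values made up
dropout_p = 0.25
dropout_shape = list(x.shape)
dropout_shape[0] = 1            # share the mask across the time dim
dropout_mask = (1.0 / (1.0 - dropout_p)) * (
    torch.rand(*dropout_shape) > dropout_p
)
y = x * dropout_mask            # a dropped channel is zero at every step
# -------------------------------------------------------------------------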
dropout_mask = (1.0 / (1.0 - dropout_p)) * ( - torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) > dropout_p + torch.rand(*dropout_shape, device=x.device, dtype=x.dtype) + > dropout_p ) else: dropout_mask = None @@ -1641,7 +1705,9 @@ def _test_whiten(): def _test_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) + x = 1.0 * ( + (2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0 + ) x = x.detach() x.requires_grad = True m = Balancer( @@ -1665,7 +1731,9 @@ def _test_balancer_sign(): def _test_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = Balancer( @@ -1825,9 +1893,13 @@ def _test_activation_dropout_and_linear(): print("y1 = ", y1) print("y2 = ", y2) assert torch.allclose(y1, y2, atol=0.02) - assert torch.allclose(m1[2].weight.grad, m2.weight.grad, atol=1.0e-05) + assert torch.allclose( + m1[2].weight.grad, m2.weight.grad, atol=1.0e-05 + ) if bias: - assert torch.allclose(m1[2].bias.grad, m2.bias.grad, atol=1.0e-05) + assert torch.allclose( + m1[2].bias.grad, m2.bias.grad, atol=1.0e-05 + ) print("x1.grad = ", x1.grad) print("x2.grad = ", x2.grad) diff --git a/egs/librispeech/ASR/zipformer/streaming_beam_search.py b/egs/librispeech/ASR/zipformer/streaming_beam_search.py index e6e0fb1c8..9bcd2f9f9 100644 --- a/egs/librispeech/ASR/zipformer/streaming_beam_search.py +++ b/egs/librispeech/ASR/zipformer/streaming_beam_search.py @@ -153,7 +153,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -170,10 +172,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/zipformer/streaming_decode.py b/egs/librispeech/ASR/zipformer/streaming_decode.py index c4d85ecec..3f140b4fa 100755 --- a/egs/librispeech/ASR/zipformer/streaming_decode.py +++ b/egs/librispeech/ASR/zipformer/streaming_decode.py @@ -282,7 +282,9 @@ def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: ) batch_states.append(cached_embed_left_pad) - processed_lens = torch.cat([state_list[i][-1] for i in range(batch_size)], dim=0) + processed_lens = torch.cat( + [state_list[i][-1] for i in range(batch_size)], dim=0 + ) batch_states.append(processed_lens) return batch_states @@ -320,7 +322,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: for layer in range(tot_num_layers): layer_offset = layer * 6 # cached_key: (left_context_len, batch_size, key_dim) - cached_key_list = batch_states[layer_offset].chunk(chunks=batch_size, dim=1) + 
cached_key_list = batch_states[layer_offset].chunk( + chunks=batch_size, dim=1 + ) # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( chunks=batch_size, dim=1 @@ -351,7 +355,9 @@ def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: cached_conv2_list[i], ] - cached_embed_left_pad_list = batch_states[-2].chunk(chunks=batch_size, dim=0) + cached_embed_left_pad_list = batch_states[-2].chunk( + chunks=batch_size, dim=0 + ) for i in range(batch_size): state_list[i].append(cached_embed_left_pad_list[i]) @@ -398,7 +404,9 @@ def streaming_forward( new_processed_lens = processed_lens + x_lens # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) encoder_states = states[:-2] @@ -486,7 +494,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = torch.tensor(processed_lens, device=device) processed_lens = processed_lens + encoder_out_lens @@ -507,7 +517,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = unstack_states(new_states) @@ -565,7 +577,9 @@ def decode_dataset( decode_streams = [] for num, cut in enumerate(cuts): # each utterance has a DecodeStream. - initial_states = get_init_states(model=model, batch_size=1, device=device) + initial_states = get_init_states( + model=model, batch_size=1, device=device + ) decode_stream = DecodeStream( params=params, cut_id=cut.id, @@ -635,7 +649,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -668,7 +684,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -701,7 +718,9 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" assert params.causal, params.causal - assert "," not in params.chunk_size, "chunk_size should be one value in decoding." + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." assert ( "," not in params.left_context_frames ), "left_context_frames should be one value in decoding." 
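# --- Illustration (standalone sketch, not part of the patch) -------------
# Round-trip of the stack_states/unstack_states contract above: per-stream
# state tensors are concatenated along the batch dim for batched decoding,
# then split back per stream. Shapes here are made up; real streams carry
# six tensors per layer plus the embed and processed_lens entries.
import torch

a = torch.randn(64, 1, 192)    # e.g. cached_key: (left_context, batch, dim)
b = torch.randn(64, 1, 192)
batched = torch.cat([a, b], dim=1)          # stack_states-style
a2, b2 = batched.chunk(chunks=2, dim=1)     # unstack_states-style
assert torch.equal(a, a2) and torch.equal(b, b2)
# -------------------------------------------------------------------------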
@@ -741,9 +760,9 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" @@ -770,9 +789,9 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( f"No checkpoints found for" diff --git a/egs/librispeech/ASR/zipformer/subsampling.py b/egs/librispeech/ASR/zipformer/subsampling.py index 068c463ef..47403f13c 100644 --- a/egs/librispeech/ASR/zipformer/subsampling.py +++ b/egs/librispeech/ASR/zipformer/subsampling.py @@ -107,7 +107,9 @@ class ConvNeXt(nn.Module): if layerdrop_rate != 0.0: batch_size = x.shape[0] mask = ( - torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) + torch.rand( + (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device + ) > layerdrop_rate ) else: @@ -273,7 +275,9 @@ class Conv2dSubsampling(nn.Module): # many copies of this extra gradient term. self.out_whiten = Whiten( num_groups=1, - whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0), + whitening_limit=ScheduledFloat( + (0.0, 4.0), (20000.0, 8.0), default=4.0 + ), prob=(0.025, 0.25), grad_scale=0.02, ) @@ -396,8 +400,8 @@ class Conv2dSubsampling(nn.Module): left_pad = self.convnext.padding[0] freq = self.out_width channels = self.layer3_channels - cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to( - device - ) + cached_embed_left_pad = torch.zeros( + batch_size, channels, left_pad, freq + ).to(device) return cached_embed_left_pad diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index bec9a3986..3cfd155b4 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -95,7 +95,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_adjusted_batch_count(params: AttributeDict) -> float: @@ -342,7 +344,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -365,7 +368,8 @@ def get_parser(): "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( @@ -738,7 +742,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. 
""" - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -785,7 +793,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -967,7 +977,9 @@ def train_one_epoch( # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: if not saved_bad_model: @@ -989,7 +1001,11 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -1000,13 +1016,20 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if params.use_fp16: tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1096,7 +1119,9 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = ScaledAdam( - get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + get_parameter_groups_with_lrs( + model, lr=params.base_lr, include_names=True + ), lr=params.base_lr, # should have no effect clipping_scale=2.0, ) diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py index 7d605d200..7aae6ae22 100644 --- a/egs/librispeech/ASR/zipformer/zipformer.py +++ b/egs/librispeech/ASR/zipformer/zipformer.py @@ -125,7 +125,9 @@ class Zipformer2(EncoderInterface): if len(x) == 1: x = x * len(downsampling_factor) else: - assert len(x) == len(downsampling_factor) and isinstance(x[0], int) + assert len(x) == len(downsampling_factor) and isinstance( + x[0], int + ) return x self.output_downsampling_factor = output_downsampling_factor # int @@ -140,7 +142,9 @@ class Zipformer2(EncoderInterface): pos_head_dim = _to_tuple(pos_head_dim) self.num_heads = num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) - self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple( + cnn_module_kernel + ) self.causal = causal self.chunk_size = chunk_size @@ -192,7 +196,9 @@ class Zipformer2(EncoderInterface): self.encoders = nn.ModuleList(encoders) self.downsample_output = 
SimpleDownsample( - max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout + max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout, ) def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: @@ -223,7 +229,8 @@ class Zipformer2(EncoderInterface): # mask1 shape: (1, batch_size, 1) mask1 = ( - torch.rand(1, batch_size, 1, device=x.device) > feature_mask_dropout_prob + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob ).to(x.dtype) # mask2 has additional sequences masked, about twice the number. @@ -271,7 +278,9 @@ class Zipformer2(EncoderInterface): left_context_chunks = -1 else: if torch.jit.is_scripting(): - assert len(self.left_context_frames) == 1, self.left_context_frames + assert ( + len(self.left_context_frames) == 1 + ), self.left_context_frames left_context_frames = self.left_context_frames[0] else: left_context_frames = random.choice(self.left_context_frames) @@ -370,7 +379,8 @@ class Zipformer2(EncoderInterface): num_encoders = len(self.encoder_dim) assert all( chunk_size * left_context_chunks - >= (self.cnn_module_kernel[i] // 2) * self.downsampling_factor[i] + >= (self.cnn_module_kernel[i] // 2) + * self.downsampling_factor[i] for i in range(num_encoders) ) else: @@ -390,7 +400,9 @@ class Zipformer2(EncoderInterface): src_c = c tgt_c = c.unsqueeze(-1) - attn_mask = torch.logical_or(src_c > tgt_c, src_c < tgt_c - left_context_chunks) + attn_mask = torch.logical_or( + src_c > tgt_c, src_c < tgt_c - left_context_chunks + ) if __name__ == "__main__": logging.info(f"attn_mask = {attn_mask}") return attn_mask @@ -448,7 +460,9 @@ class Zipformer2(EncoderInterface): x, new_layer_states = module.streaming_forward( x, - states=states[layer_offset * 6 : (layer_offset + num_layers) * 6], + states=states[ + layer_offset * 6 : (layer_offset + num_layers) * 6 + ], left_context_len=self.left_context_frames[0] // ds, src_key_padding_mask=src_key_padding_mask[..., ::ds], ) @@ -496,24 +510,24 @@ class Zipformer2(EncoderInterface): nonlin_attn_head_dim = 3 * embed_dim // 4 conv_left_pad = self.cnn_module_kernel[i] // 2 for layer in range(num_layers): - cached_key = torch.zeros(downsample_left, batch_size, key_dim).to( - device - ) + cached_key = torch.zeros( + downsample_left, batch_size, key_dim + ).to(device) cached_nonlin_attn = torch.zeros( 1, batch_size, downsample_left, nonlin_attn_head_dim ).to(device) - cached_val1 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_val2 = torch.zeros(downsample_left, batch_size, value_dim).to( - device - ) - cached_conv1 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) - cached_conv2 = torch.zeros(batch_size, embed_dim, conv_left_pad).to( - device - ) + cached_val1 = torch.zeros( + downsample_left, batch_size, value_dim + ).to(device) + cached_val2 = torch.zeros( + downsample_left, batch_size, value_dim + ).to(device) + cached_conv1 = torch.zeros( + batch_size, embed_dim, conv_left_pad + ).to(device) + cached_conv2 = torch.zeros( + batch_size, embed_dim, conv_left_pad + ).to(device) states += [ cached_key, cached_nonlin_attn, @@ -621,7 +635,9 @@ class Zipformer2EncoderLayer(nn.Module): embed_dim, (feedforward_dim * 3) // 4, dropout ) - self.feed_forward2 = FeedforwardModule(embed_dim, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule( + embed_dim, feedforward_dim, dropout + ) self.feed_forward3 = FeedforwardModule( embed_dim, (feedforward_dim * 5) // 4, dropout @@ -708,7 +724,9 @@ class 
Zipformer2EncoderLayer(nn.Module): if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): return None batch_size = x.shape[1] - mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to( + x.dtype + ) return mask def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: @@ -774,7 +792,9 @@ class Zipformer2EncoderLayer(nn.Module): selected_attn_weights = attn_weights[0:1] if torch.jit.is_scripting(): pass - elif not self.training and random.random() < float(self.const_attention_rate): + elif not self.training and random.random() < float( + self.const_attention_rate + ): # Make attention weights constant. The intention is to # encourage these modules to do something similar to an # averaging-over-time operation. @@ -790,7 +810,9 @@ class Zipformer2EncoderLayer(nn.Module): na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights)) src = src + ( - na if self_attn_dropout_mask is None else na * self_attn_dropout_mask + na + if self_attn_dropout_mask is None + else na * self_attn_dropout_mask ) self_attn = self.self_attn1(src, attn_weights) @@ -804,10 +826,14 @@ class Zipformer2EncoderLayer(nn.Module): if torch.jit.is_scripting(): conv_skip_rate = 0.0 else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + conv_skip_rate = ( + float(self.conv_skip_rate) if self.training else 0.0 + ) src = src + self.sequence_dropout( self.conv_module1( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + src, + chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask, ), conv_skip_rate, ) @@ -834,10 +860,14 @@ class Zipformer2EncoderLayer(nn.Module): if torch.jit.is_scripting(): conv_skip_rate = 0.0 else: - conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0 + conv_skip_rate = ( + float(self.conv_skip_rate) if self.training else 0.0 + ) src = src + self.sequence_dropout( self.conv_module2( - src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask + src, + chunk_size=chunk_size, + src_key_padding_mask=src_key_padding_mask, ), conv_skip_rate, ) @@ -1150,7 +1180,9 @@ class BypassModule(nn.Module): embed_dim: int, skip_rate: FloatLike = 0.0, straight_through_rate: FloatLike = 0.0, - scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_min: FloatLike = ScheduledFloat( + (0.0, 0.9), (20000.0, 0.2), default=0 + ), scale_max: FloatLike = 1.0, ): super().__init__() @@ -1169,11 +1201,15 @@ class BypassModule(nn.Module): return self.bypass_scale else: ans = limit_param_value( - self.bypass_scale, min=float(self.scale_min), max=float(self.scale_max) + self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max), ) skip_rate = float(self.skip_rate) if skip_rate != 0.0: - mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + mask = ( + torch.rand((batch_size, 1), device=ans.device) > skip_rate + ) ans = ans * mask # now ans is of shape (batch_size, num_channels), and is zero for sequences # on which we have randomly chosen to do layer-skipping. @@ -1320,7 +1356,9 @@ class SimpleDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. 
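# --- Illustration (standalone sketch, not part of the patch) -------------
# The padding step just below, in isolation, assuming the usual ceil
# division for d_seq_len: the sequence is right-padded to a multiple of
# the downsampling factor by repeating its last frame, so every output
# frame of SimpleDownsample averages exactly `ds` inputs.
import torch

src = torch.randn(7, 2, 4)                  # (seq_len, batch, channels)
ds = 2
d_seq_len = (src.shape[0] + ds - 1) // ds   # ceil(7 / 2) = 4
pad = d_seq_len * ds - src.shape[0]         # 1 frame of padding needed
src_extra = src[-1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds       # 8 frames, a multiple of ds
# -------------------------------------------------------------------------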
pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0] - 1 :].expand( + pad, src.shape[1], src.shape[2] + ) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds @@ -1354,7 +1392,9 @@ class SimpleUpsample(torch.nn.Module): """ upsample = self.upsample (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) + src = src.unsqueeze(1).expand( + seq_len, upsample, batch_size, num_channels + ) src = src.reshape(seq_len * upsample, batch_size, num_channels) return src @@ -1411,12 +1451,18 @@ class CompactRelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(0) >= T * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] - x = torch.arange(-(T - 1), T, device=x.device).to(torch.float32).unsqueeze(1) + x = ( + torch.arange(-(T - 1), T, device=x.device) + .to(torch.float32) + .unsqueeze(1) + ) freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) @@ -1430,7 +1476,10 @@ class CompactRelPositionalEncoding(torch.nn.Module): x_compressed = ( compression_length * x.sign() - * ((x.abs() + compression_length).log() - math.log(compression_length)) + * ( + (x.abs() + compression_length).log() + - math.log(compression_length) + ) ) # if self.length_factor == 1.0, then length_scale is chosen so that the @@ -1442,7 +1491,9 @@ class CompactRelPositionalEncoding(torch.nn.Module): # note for machine implementations: if atan is not available, we can use: # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) - x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + x_atan = ( + x_compressed / length_scale + ).atan() # results between -pi and pi cosines = (x_atan * freqs).cos() sines = (x_atan * freqs).sin() @@ -1507,7 +1558,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_head_dim: int, pos_head_dim: int, dropout: float = 0.0, - pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)), + pos_emb_skip_rate: FloatLike = ScheduledFloat( + (0.0, 0.5), (4000.0, 0.0) + ), ) -> None: super().__init__() self.embed_dim = embed_dim @@ -1516,7 +1569,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): self.pos_head_dim = pos_head_dim self.dropout = dropout self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) - self.name = None # will be overwritten in training code; for diagnostics. + self.name = ( + None # will be overwritten in training code; for diagnostics. + ) key_head_dim = query_head_dim in_proj_dim = (query_head_dim + key_head_dim + pos_head_dim) * num_heads @@ -1527,7 +1582,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): # to be used with the ScaledAdam optimizer; with most other optimizers, # it would be necessary to apply the scaling factor in the forward function. 
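# --- Illustration (standalone numeric sketch, not part of the patch) -----
# One way to read the initial_scale just below: q and k both come out of
# in_proj, so scaling its weights by query_head_dim**-0.25 scales the q.k
# dot products by d**-0.5 -- the usual 1/sqrt(d) attention scaling, baked
# into the weights instead of being applied in the forward pass.
import torch

d = 32
torch.manual_seed(0)
q = torch.randn(100000, d) * d ** -0.25
k = torch.randn(100000, d) * d ** -0.25
scores = (q * k).sum(dim=-1)
print(scores.var())   # ~1.0: unit-variance scores without explicit scaling
# -------------------------------------------------------------------------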
self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=query_head_dim**-0.25 + embed_dim, + in_proj_dim, + bias=True, + initial_scale=query_head_dim**-0.25, ) self.whiten_keys = Whiten( @@ -1601,7 +1659,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): assert p.shape[-1] == num_heads * pos_head_dim q = self.copy_query(q) # for diagnostics only, does nothing. - k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + k = self.whiten_keys( + self.balance_keys(k) + ) # does nothing in the forward pass. p = self.copy_pos_query(p) # for diagnostics only, does nothing. q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) @@ -1619,15 +1679,17 @@ class RelPositionMultiheadAttentionWeights(nn.Module): if torch.jit.is_scripting(): # We can't put random.random() in the same line use_pos_scores = True - elif not self.training or random.random() >= float(self.pos_emb_skip_rate): + elif not self.training or random.random() >= float( + self.pos_emb_skip_rate + ): use_pos_scores = True if use_pos_scores: pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) + pos_emb = pos_emb.reshape( + -1, seq_len2, num_heads, pos_head_dim + ).permute(2, 0, 3, 1) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) @@ -1769,9 +1831,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): pos_emb = self.linear_pos(pos_emb) seq_len2 = 2 * seq_len - 1 + left_context_len - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( - 2, 0, 3, 1 - ) + pos_emb = pos_emb.reshape( + -1, seq_len2, num_heads, pos_head_dim + ).permute(2, 0, 3, 1) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) @@ -1801,7 +1863,10 @@ class RelPositionMultiheadAttentionWeights(nn.Module): ), attn_scores.shape if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, k_len), key_padding_mask.shape + assert key_padding_mask.shape == ( + batch_size, + k_len, + ), key_padding_mask.shape attn_scores = attn_scores.masked_fill( key_padding_mask.unsqueeze(1), -1000, @@ -1846,7 +1911,9 @@ class SelfAttention(nn.Module): value_head_dim: int, ) -> None: super().__init__() - self.in_proj = nn.Linear(embed_dim, num_heads * value_head_dim, bias=True) + self.in_proj = nn.Linear( + embed_dim, num_heads * value_head_dim, bias=True + ) self.out_proj = ScaledLinear( num_heads * value_head_dim, embed_dim, bias=True, initial_scale=0.05 @@ -1958,7 +2025,9 @@ class SelfAttention(nn.Module): class FeedforwardModule(nn.Module): """Feedforward module in Zipformer2 model.""" - def __init__(self, embed_dim: int, feedforward_dim: int, dropout: FloatLike): + def __init__( + self, embed_dim: int, feedforward_dim: int, dropout: FloatLike + ): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(embed_dim, feedforward_dim) @@ -2226,7 +2295,9 @@ class ConvolutionModule(nn.Module): assert kernel_size % 2 == 1 self.depthwise_conv = ( - ChunkCausalDepthwiseConv1d(channels=bottleneck_dim, kernel_size=kernel_size) + ChunkCausalDepthwiseConv1d( + channels=bottleneck_dim, kernel_size=kernel_size + ) if causal else nn.Conv1d( in_channels=bottleneck_dim, @@ -2294,7 +2365,9 @@ class ConvolutionModule(nn.Module): x = x.permute(1, 2, 0) # (#batch, channels, 
time). if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = x.masked_fill( + src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0 + ) if not torch.jit.is_scripting() and chunk_size >= 0: # Not support exporting a model for simulated streaming decoding @@ -2344,7 +2417,9 @@ class ConvolutionModule(nn.Module): x = x.permute(1, 2, 0) # (#batch, channels, time). if src_key_padding_mask is not None: - x = x.masked_fill(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) + x = x.masked_fill( + src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0 + ) x, cache = self.depthwise_conv.streaming_forward(x, cache=cache)
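# --- Closing illustration (standalone sketch, not part of the patch) -----
# The cache contract behind depthwise_conv.streaming_forward above, in a
# simplified plain-causal form: keep the last (kernel_size - 1) input
# frames as the cache, prepend them to the next chunk, and the chunked
# output matches the full-sequence causal convolution exactly.
import torch
import torch.nn.functional as F

def causal_dw_conv(x, weight):                     # x: (batch, channels, time)
    pad = weight.shape[-1] - 1
    return F.conv1d(F.pad(x, (pad, 0)), weight, groups=x.shape[1])

def streaming_step(x, cache, weight):
    x = torch.cat([cache, x], dim=2)               # prepend cached left frames
    new_cache = x[:, :, -(weight.shape[-1] - 1):]  # tail becomes next cache
    return F.conv1d(x, weight, groups=x.shape[1]), new_cache

torch.manual_seed(0)
C, K = 4, 5
w = torch.randn(C, 1, K)                           # depthwise kernels
full = torch.randn(1, C, 32)
ref = causal_dw_conv(full, w)

cache = torch.zeros(1, C, K - 1)                   # initial left context
outs = []
for chunk in full.split(8, dim=2):                 # 4 chunks of 8 frames
    y, cache = streaming_step(chunk, cache, w)
    outs.append(y)
assert torch.allclose(ref, torch.cat(outs, dim=2), atol=1e-5)
# -------------------------------------------------------------------------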