From 9ec52fdb77b340a97f1b793e060ca77741b23364 Mon Sep 17 00:00:00 2001
From: wangtiance
Date: Thu, 26 Oct 2023 18:34:08 +0800
Subject: [PATCH] black format

---
 .../ASR/tiny_transducer_ctc/asr_datamodule.py |   4 +-
 .../ASR/tiny_transducer_ctc/ctc_decode.py     |   9 +-
 .../ASR/tiny_transducer_ctc/decode.py         |   3 +-
 .../ASR/tiny_transducer_ctc/encoder.py        | 143 ++++++++++--------
 .../ASR/tiny_transducer_ctc/export.py         |  11 +-
 .../ASR/tiny_transducer_ctc/pretrained.py     |   2 +-
 .../ASR/tiny_transducer_ctc/train.py          |  84 +++++-----
 7 files changed, 131 insertions(+), 125 deletions(-)

diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
index cd550467a..8facb6dba 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py
@@ -264,8 +264,8 @@ class LibriSpeechAsrDataModule:
                     num_feature_masks=2,
                     features_mask_size=5,
                     num_frame_masks=10,
-                    frames_mask_size=5,
-                    p=0.5,
+                    frames_mask_size=5,
+                    p=0.5,
                 )
             )
         else:
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
index aaa5380a6..402aeac0c 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py
@@ -373,7 +373,6 @@ def decode_one_batch(
         # lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
         lm_scale_list = [0.6, 0.7, 0.8, 0.9]
-
        if params.decoding_method == "nbest-rescoring":
            best_path_dict = rescore_with_n_best_list(
                lattice=lattice,
@@ -507,9 +506,7 @@ def save_results(
    logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
-    )
+    errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
@@ -576,7 +573,9 @@ def main():
    params.blank_id = 0

    if params.decoding_method == "ctc-decoding":
-        assert "lang_bpe" in str(params.lang_dir), "ctc-decoding only supports BPE lexicons."
+        assert "lang_bpe" in str(
+            params.lang_dir
+        ), "ctc-decoding only supports BPE lexicons."
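In the "ctc-decoding" branch that follows, no lexicon or LM graph is composed: HLG stays None and decoding walks the CTC topology H alone. A minimal sketch of building that topology with k2 (illustrative only; the max_token_id value here is hypothetical, the real script derives it from the BPE model):

    import k2

    max_token_id = 500  # hypothetical BPE vocabulary size
    # Token 0 is the blank id (params.blank_id = 0 above); ctc_topo builds
    # the CTC graph that accepts any token sequence with optional blanks.
    H = k2.ctc_topo(max_token=max_token_id, modified=False)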
        HLG = None
        H = k2.ctc_topo(
            max_token=max_token_id,
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
index d4b94f3cf..74aae3ad3 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/decode.py
@@ -39,7 +39,6 @@ from icefall.utils import (

 LOG_EPS = math.log(1e-10)

-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
@@ -493,7 +492,7 @@ def save_results(
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
         params.res_dir
-        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
index 9c6f0792c..4c7fca4fc 100644
--- a/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
+++ b/egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
@@ -26,7 +26,7 @@ from torch import Tensor, nn

 class Conv1dNet(EncoderInterface):
     """
-    1D Convolution network with causal squeeze and excitation
+    1D Convolution network with causal squeeze and excitation
     module and optional skip connections.

     Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.
@@ -34,11 +34,11 @@ class Conv1dNet(EncoderInterface):
     Args:
       output_dim (int): Number of output channels of the last layer.
       input_dim (int): Number of input features
-      conv_layers (int): Number of convolution layers,
+      conv_layers (int): Number of convolution layers,
         excluding the subsampling layers.
-      channels (int): Number of output channels for each layer,
+      channels (int): Number of output channels for each layer,
        except the last layer.
-      subsampling_factor (int): The subsampling factor for the model.
+      subsampling_factor (int): The subsampling factor for the model.
      skip_add (bool): Whether to use skip connection for each convolution layer.
      dscnn (bool): Whether to use depthwise-separable convolution.
      activation (str): Activation function type.
@@ -53,7 +53,7 @@ class Conv1dNet(EncoderInterface):
         subsampling_factor: int = 4,
         skip_add: bool = False,
         dscnn: bool = True,
-        activation: str = 'relu',
+        activation: str = "relu",
     ) -> None:
         super().__init__()
         assert subsampling_factor == 4, "Only support subsampling = 4"
@@ -62,10 +62,12 @@ class Conv1dNet(EncoderInterface):
         self.skip_add = skip_add
         # 80ms latency for subsample_layer
         self.subsample_layer = nn.Sequential(
-            conv1d_bn_block(input_dim, channels, 9,
-                            stride=2, activation=activation, dscnn=dscnn),
-            conv1d_bn_block(channels, channels, 5,
-                            stride=2, activation=activation, dscnn=dscnn),
+            conv1d_bn_block(
+                input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
+            ),
+            conv1d_bn_block(
+                channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
+            ),
         )

         self.conv_blocks = nn.ModuleList()
@@ -82,13 +84,15 @@ class Conv1dNet(EncoderInterface):
                     3,
                     activation=activation,
                     dscnn=dscnn,
-                    causal=ly % 2),
-                CausalSqueezeExcite1d(cout[ly], 16, 30)
+                    causal=ly % 2,
+                ),
+                CausalSqueezeExcite1d(cout[ly], 16, 30),
             )
         )

     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         x_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -104,24 +108,25 @@ class Conv1dNet(EncoderInterface):
          - lengths, a tensor of shape (batch_size,) containing the number
            of frames in `embeddings` before padding.
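          Example (illustrative): for x of shape (8, 1000, 80) with the
          default subsampling_factor of 4, the two stride-2 subsampling
          convolutions reduce 1000 frames to 250, so `embeddings` has
          shape (8, 250, output_dim) and lengths equals x_lens >> 2.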
""" - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.subsample_layer(x) for idx, layer in enumerate(self.conv_blocks): - if self.skip_add and 0 < idx < self.conv_layers-1: + if self.skip_add and 0 < idx < self.conv_layers - 1: x = layer(x) + x else: x = layer(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) lengths = x_lens >> 2 return x, lengths -def get_activation(name: str, - channels: int, - channel_dim: int = -1, - min_val: int = 0, - max_val: int = 1, - ) -> nn.Module: +def get_activation( + name: str, + channels: int, + channel_dim: int = -1, + min_val: int = 0, + max_val: int = 1, +) -> nn.Module: """ Get activation function from name in string. @@ -140,44 +145,42 @@ def get_activation(name: str, """ act_layer = nn.Identity() name = name.lower() - if name == 'prelu': + if name == "prelu": act_layer = nn.PReLU(channels) - elif name == 'relu': + elif name == "relu": act_layer = nn.ReLU() - elif name == 'relu6': + elif name == "relu6": act_layer = nn.ReLU6() - elif name == 'hardtanh': + elif name == "hardtanh": act_layer = nn.Hardtanh(min_val, max_val) - elif name in ['swish', 'silu']: + elif name in ["swish", "silu"]: act_layer = nn.SiLU() - elif name == 'elu': + elif name == "elu": act_layer = nn.ELU() - elif name == 'doubleswish': + elif name == "doubleswish": act_layer = nn.Sequential( - ActivationBalancer( - num_channels=channels, - channel_dim=channel_dim), + ActivationBalancer(num_channels=channels, channel_dim=channel_dim), DoubleSwish(), ) - elif name == '': + elif name == "": act_layer = nn.Identity() else: - raise Exception(f'Unknown activation function: {name}') + raise Exception(f"Unknown activation function: {name}") return act_layer class CausalSqueezeExcite1d(nn.Module): """ - Causal squeeze and excitation module with input and output shape - (batch, channels, time). The global average pooling in the original + Causal squeeze and excitation module with input and output shape + (batch, channels, time). The global average pooling in the original SE module is replaced by a causal filter, so - the layer does not introduce any algorithmic latency. + the layer does not introduce any algorithmic latency. Args: channels (int): Number of channels reduction (int): channel reduction rate - context_window (int): Context window size for the moving average operation. + context_window (int): Context window size for the moving average operation. For EMA, the smoothing factor is 1 / context_window. """ @@ -205,8 +208,8 @@ class CausalSqueezeExcite1d(nn.Module): self.ema_matrix_size = 0 def _precompute_ema_matrix(self, N: int, device: torch.device): - a = 1.0 / self.context_window # smoothing factor - w = [[(1-a)**k * a for k in range(n, n-N, -1)] for n in range(N)] + a = 1.0 / self.context_window # smoothing factor + w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)] w = torch.tensor(w).to(device).tril() w[:, 0] *= self.context_window self.ema_matrix = w.T @@ -218,8 +221,8 @@ class CausalSqueezeExcite1d(nn.Module): y[t] = (1-a) * y[t-1] + a * x[t] where a = 1 / self.context_window is the smoothing factor. - For training, the iterative version is too slow. A better way is - to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication. + For training, the iterative version is too slow. A better way is + to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication. 
        The weight matrix can be precomputed if the smoothing factor is fixed.
        """
        if self.training:
@@ -234,29 +237,29 @@ class CausalSqueezeExcite1d(nn.Module):
        y = torch.empty_like(x)
        y[:, :, 0] = x[:, :, 0]
        for t in range(1, y.shape[-1]):
-            y[:, :, t] = (1-a) * y[:, :, t-1] + a * x[:, :, t]
+            y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
        return y

    def moving_avg(self, x: Tensor) -> Tensor:
        """
-        Simple moving average with context_window as window size.
+        Simple moving average with context_window as window size.
        """
        y = torch.empty_like(x)
        k = min(x.shape[2], self.context_window)
-        w = [[1/n] * n + [0] * (k-n-1) for n in range(1, k)]
+        w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
        w = torch.tensor(w, device=x.device)
-        y[:, :, :k-1] = torch.matmul(x[:, :, :k-1], w.T)
-        y[:, :, k-1:] = F.avg_pool1d(x, k, 1)
+        y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
+        y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
        return y

    def forward(self, x: Tensor) -> Tensor:
        assert len(x.shape) == 3, "Input is not a 3D tensor!"
        y = self.exponential_moving_avg(x)
-        y = y.permute(0, 2, 1)  # make channel last for squeeze op
+        y = y.permute(0, 2, 1)  # make channel last for squeeze op
        y = self.act1(self.linear1(y))
        y = self.act2(self.linear2(y))
-        y = y.permute(0, 2, 1)  # back to original shape
+        y = y.permute(0, 2, 1)  # back to original shape
        y = x * y
        return y
@@ -267,12 +270,12 @@ def conv1d_bn_block(
    kernel_size: int = 3,
    stride: int = 1,
    dilation: int = 1,
-    activation: str = 'relu',
+    activation: str = "relu",
    dscnn: bool = False,
    causal: bool = False,
 ) -> nn.Sequential:
    """
-    Conv1d - batchnorm - activation block.
+    Conv1d - batchnorm - activation block.
    If kernel size is even, output length = input length + 1.
    Otherwise, output and input lengths are equal.
@@ -296,16 +299,19 @@ def conv1d_bn_block(
                stride=stride,
                dilation=dilation,
                groups=in_channels,
-                bias=False) if causal else
-            nn.Conv1d(
+                bias=False,
+            )
+            if causal
+            else nn.Conv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=stride,
-                padding=(kernel_size//2) * dilation,
+                padding=(kernel_size // 2) * dilation,
                dilation=dilation,
                groups=in_channels,
-                bias=False),
+                bias=False,
+            ),
            nn.BatchNorm1d(in_channels),
            get_activation(activation, in_channels),
            nn.Conv1d(in_channels, out_channels, 1, bias=False),
@@ -320,15 +326,18 @@ def conv1d_bn_block(
                kernel_size,
                stride=stride,
                dilation=dilation,
-                bias=False) if causal else
-            nn.Conv1d(
+                bias=False,
+            )
+            if causal
+            else nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
-                padding=(kernel_size//2) * dilation,
+                padding=(kernel_size // 2) * dilation,
                dilation=dilation,
-                bias=False),
+                bias=False,
+            ),
            nn.BatchNorm1d(out_channels),
            get_activation(activation, out_channels),
        )
@@ -336,7 +345,7 @@ def conv1d_bn_block(

 class CausalConv1d(nn.Module):
    """
-    Causal convolution with padding automatically chosen to match input/output length.
+    Causal convolution with padding automatically chosen to match input/output length.
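    The left context consumed is dilation * (kernel_size - 1) frames:
    nn.Conv1d pads that many zeros on both sides, and forward() trims
    padding // stride frames from the right, so with stride=1 the output
    length equals the input length. E.g. kernel_size=3, dilation=2 pads
    4 frames per side and trims 4 frames from the end.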
""" def __init__( @@ -352,11 +361,19 @@ class CausalConv1d(nn.Module): super(CausalConv1d, self).__init__() assert kernel_size > 2 - self.padding = dilation*(kernel_size-1) + self.padding = dilation * (kernel_size - 1) self.stride = stride - self.conv = nn.Conv1d(in_channels, out_channels, - kernel_size, stride, self.padding, dilation, groups, bias=bias) + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + self.padding, + dilation, + groups, + bias=bias, + ) def forward(self, x: Tensor) -> Tensor: - return self.conv(x)[:, :, :-self.padding // self.stride] + return self.conv(x)[:, :, : -self.padding // self.stride] diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/export.py b/egs/librispeech/ASR/tiny_transducer_ctc/export.py index adf154d97..4117f7244 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/export.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/export.py @@ -78,9 +78,12 @@ from pathlib import Path import sentencepiece as spm import torch -from icefall.checkpoint import (average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, load_checkpoint) +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import UniqLexicon from icefall.utils import str2bool from train import add_model_arguments, get_params, get_transducer_model @@ -188,7 +191,7 @@ def main(): if "lang_bpe" in str(params.lang_dir): sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + '/bpe.model') + sp.load(params.lang_dir + "/bpe.model") # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py index ec6d5bb24..981039b8f 100755 --- a/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py @@ -228,7 +228,7 @@ def main(): params.update(vars(args)) sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + '/bpe.model') + sp.load(params.lang_dir + "/bpe.model") # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") diff --git a/egs/librispeech/ASR/tiny_transducer_ctc/train.py b/egs/librispeech/ASR/tiny_transducer_ctc/train.py index 25adaff58..307ad72aa 100644 --- a/egs/librispeech/ASR/tiny_transducer_ctc/train.py +++ b/egs/librispeech/ASR/tiny_transducer_ctc/train.py @@ -79,15 +79,12 @@ from icefall.utils import ( LRSchedulerType = torch.optim.lr_scheduler._LRScheduler -def set_batch_count( - model: Union[nn.Module, DDP], - batch_count: float -) -> None: +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: if isinstance(model, DDP): # get underlying nn.Module model = model.module for module in model.modules(): - if hasattr(module, 'batch_count'): + if hasattr(module, "batch_count"): module.batch_count = batch_count @@ -140,7 +137,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="""Use skip connection in the encoder. """, ) - + def get_parser(): parser = argparse.ArgumentParser( @@ -226,8 +223,7 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; " - "2 means tri-gram", + help="The context size in the decoder. 
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )

    parser.add_argument(
@@ -250,8 +246,7 @@ def get_parser():
        "--am-scale",
        type=float,
        default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network) " "part.",
    )

    parser.add_argument(
@@ -450,7 +445,7 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

-    if is_module_available('thop'):
+    if is_module_available("thop"):
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
@@ -459,13 +454,14 @@ def get_transducer_model(params: AttributeDict) -> Transducer:
        x = torch.zeros((1, 1000, params.feature_dim)).to(device)
        x_lens = torch.Tensor([1000]).int().to(device)
        from thop import clever_format, profile
+
        m = copy.deepcopy(encoder)
        m = m.to(device)
        ops, _ = clever_format(profile(m, (x, x_lens), verbose=False))
-        logging.info(f'Encoder MAC ops for 10 seconds of audio is {ops}')
+        logging.info(f"Encoder MAC ops for 10 seconds of audio is {ops}")
    else:
-        logging.info('You can install thop to calculate the number of ops.')
-        logging.info('Command: pip install thop')
+        logging.info("You can install thop to calculate the number of ops.")
+        logging.info("Command: pip install thop")

    model = Transducer(
        encoder=encoder,
@@ -624,11 +620,7 @@ def compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
@@ -662,18 +654,17 @@ def compute_loss(
            # take down the scale on the simple loss from 1.0 at the start
            # to params.simple_loss scale by warm_step.
            simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
            )
            pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                else 0.1 + 0.9 * (batch_idx_train / warm_step)
            )
-            loss = (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

        # Compute ctc loss
@@ -704,16 +695,14 @@ def compute_loss(
        )
        assert ctc_loss.requires_grad == is_training
        assert 0 <= params.ctc_loss_scale <= 1, "ctc_loss_scale must be between 0 and 1"
-        loss = params.ctc_loss_scale * ctc_loss + (1-params.ctc_loss_scale) * loss
+        loss = params.ctc_loss_scale * ctc_loss + (1 - params.ctc_loss_scale) * loss

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
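    # Summing (instead of averaging) pairs with the frame count recorded in
    # info["frames"] above: the tracker can then normalise the accumulated
    # loss by the total number of frames, so logged values stay comparable
    # across batches of different lengths.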
info["loss"] = loss.detach().cpu().item() @@ -843,8 +832,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, - sp=sp, phone_lexicon=phone_lexicon) + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) raise if params.print_diagnostics and batch_idx == 5: @@ -896,7 +886,8 @@ def train_one_epoch( logging.warning(f"Grad scale is small: {cur_grad_scale}") if cur_grad_scale < 1.0e-05: raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}") + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) if batch_idx % params.log_interval == 0: cur_lr = scheduler.get_last_lr()[0] @@ -918,9 +909,7 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", cur_grad_scale, params.batch_idx_train @@ -938,11 +927,10 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) + valid_info.write_summary(tb_writer, "train/valid_", params.batch_idx_train) loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value @@ -987,7 +975,7 @@ def run(rank, world_size, args): if "lang_bpe" in str(params.lang_dir): sp = spm.SentencePieceProcessor() - sp.load(params.lang_dir + '/bpe.model') + sp.load(params.lang_dir + "/bpe.model") # is defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() @@ -1031,8 +1019,7 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: logging.info("Using DDP") - model = DDP(model, device_ids=[rank], - find_unused_parameters=True) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = AdamW( model.parameters(), @@ -1056,7 +1043,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2 ** 22 + 2**22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1108,8 +1095,7 @@ def run(rank, world_size, args): # params=params, # ) - scaler = GradScaler(enabled=params.use_fp16, - init_scale=1.0) + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: logging.info("Loading grad scaler state dict") scaler.load_state_dict(checkpoints["grad_scaler"]) @@ -1235,11 +1221,13 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, - sp=sp, phone_lexicon=phone_lexicon) + display_and_save_batch( + batch, params=params, sp=sp, phone_lexicon=phone_lexicon + ) raise logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) def main():