diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b59784dbf..259de4fd0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/psf/black
-    rev: 21.6b0
+    rev: 21.7b0
     hooks:
       - id: black
         args: [--line-length=80]
diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/conformer.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/conformer.py
index a769a4995..152b8b256 100644
--- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/conformer.py
@@ -869,7 +869,10 @@ class ConvolutionModule(nn.Module):
             groups=channels,
             bias=bias,
         )
-        self.norm = nn.BatchNorm1d(channels)
+        # NOTE(fangjun): The process hangs when using DDP
+        # if we try to recover from CUDA OOM, so we disable
+        # the batchnorm layer here.
+        # self.norm = nn.BatchNorm1d(channels)
         self.pointwise_conv2 = nn.Conv1d(
             channels,
             channels,
@@ -899,7 +902,8 @@ class ConvolutionModule(nn.Module):
 
         # 1D Depthwise Conv
         x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
+        # x = self.activation(self.norm(x))
+        x = self.activation(x)
 
         x = self.pointwise_conv2(x)  # (batch, channel, time)
 
diff --git a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
index 4ec296646..49acf445c 100755
--- a/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_madam_no_warmup/train.py
@@ -153,7 +153,7 @@ def get_params() -> AttributeDict:
         "num_decoder_layers": 6,
         "is_espnet_structure": True,
         "mmi_loss": False,
-        "use_feat_batchnorm": True,
+        "use_feat_batchnorm": False,
         "lr_factor": 2.0,
         "warm_step": 30000,
     }
@@ -282,75 +282,59 @@ def compute_loss_impl(
     assert feature.ndim == 3
     feature = feature.to(device)
 
-    try:
+    supervisions = batch["supervisions"]
 
-        supervisions = batch["supervisions"]
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
+        # nnet_output is [N, T, C]
+
+    # NOTE: We need `encode_supervisions` to sort sequences with
+    # different duration in decreasing order, required by
+    # `k2.intersect_dense` called in `k2.ctc_loss`
+    supervision_segments, texts = encode_supervisions(
+        supervisions, subsampling_factor=params.subsampling_factor
+    )
+
+    token_ids = graph_compiler.texts_to_ids(texts)
+
+    decoding_graph = graph_compiler.compile(token_ids)
+
+    dense_fsa_vec = k2.DenseFsaVec(
+        nnet_output,
+        supervision_segments,
+        allow_truncate=params.subsampling_factor - 1,
+    )
+
+    ctc_loss = k2.ctc_loss(
+        decoding_graph=decoding_graph,
+        dense_fsa_vec=dense_fsa_vec,
+        output_beam=params.beam_size,
+        reduction=params.reduction,
+        use_double_scores=params.use_double_scores,
+    )
+
+    if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            nnet_output, encoder_memory, memory_mask = model(
-                feature, supervisions
-            )
-            # nnet_output is [N, T, C]
-
-            # NOTE: We need `encode_supervisions` to sort sequences with
-            # different duration in decreasing order, required by
-            # `k2.intersect_dense` called in `k2.ctc_loss`
-            supervision_segments, texts = encode_supervisions(
-                supervisions, subsampling_factor=params.subsampling_factor
-            )
-
-            token_ids = graph_compiler.texts_to_ids(texts)
-
-            decoding_graph = graph_compiler.compile(token_ids)
-
-            dense_fsa_vec = k2.DenseFsaVec(
-                nnet_output,
-                supervision_segments,
-                allow_truncate=params.subsampling_factor - 1,
-            )
-
-            ctc_loss = k2.ctc_loss(
-                decoding_graph=decoding_graph,
-                dense_fsa_vec=dense_fsa_vec,
-                output_beam=params.beam_size,
-                reduction=params.reduction,
-                use_double_scores=params.use_double_scores,
-            )
-
-            if params.att_rate != 0.0:
-                with torch.set_grad_enabled(is_training):
-                    if hasattr(model, "module"):
-                        att_loss = model.module.decoder_forward(
-                            encoder_memory,
-                            memory_mask,
-                            token_ids=token_ids,
-                            sos_id=graph_compiler.sos_id,
-                            eos_id=graph_compiler.eos_id,
-                        )
-                    else:
-                        att_loss = model.decoder_forward(
-                            encoder_memory,
-                            memory_mask,
-                            token_ids=token_ids,
-                            sos_id=graph_compiler.sos_id,
-                            eos_id=graph_compiler.eos_id,
-                        )
-                loss = (
-                    1.0 - params.att_rate
-                ) * ctc_loss + params.att_rate * att_loss
-            else:
-                loss = ctc_loss
-                att_loss = torch.tensor([0])
-    except RuntimeError as ex:
-        try:
-            del nnet_output
-            del encoder_memory
-            del dense_fsa_vec
-            del ctc_loss
-            del att_loss
-            del loss
-        except NameError as ne:
-            pass
-        raise ex
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+    else:
+        loss = ctc_loss
+        att_loss = torch.tensor([0])
 
     # train_frames and valid_frames are used for printing.
     if is_training:
@@ -394,11 +378,6 @@ def compute_loss(
         s += f" max duration: {max_cut_duration:.3f} s \n"
         logging.info(s)
 
-        # see https://github.com/pytorch/fairseq/blob/50a671f78d0c8de0392f924180db72ac9b41b801/fairseq/trainer.py#L283
-        for p in model.parameters():
-            if p.grad is not None:
-                del p.grad  # free some memory
-
         torch.cuda.empty_cache()
         gc.collect()
 
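For context on the NOTE(fangjun) comment above, here is a minimal sketch of the failure mode this patch works around. It is illustrative only and not part of the patch; the `train_step` helper and its setup are assumptions. DDP's gradient all-reduce is a collective operation that every rank must enter, so if a single rank catches a CUDA OOM and skips its backward pass, the remaining ranks stay blocked waiting for that rank's gradient buckets. That is why the try/except recovery path in compute_loss_impl is removed outright rather than patched.

# Hypothetical sketch: recovering from CUDA OOM on one rank hangs DDP.
# Not part of this patch; `train_step` and its signature are assumptions.
import torch


def train_step(ddp_model, batch, optimizer):
    optimizer.zero_grad()
    try:
        loss = ddp_model(batch).sum()
        # backward() fires DDP's gradient all-reduce hooks, a collective
        # operation that every rank in the process group must enter.
        loss.backward()
    except RuntimeError:  # e.g. "CUDA out of memory"
        # Only this rank bails out; the other ranks are already inside
        # all-reduce waiting for this rank's gradient buckets and block
        # forever -- the hang described in the NOTE above.
        torch.cuda.empty_cache()
        return
    optimizer.step()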