From c6c3750cabc26de327a856b7f361b81dd3a95674 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Fri, 17 Sep 2021 18:55:34 +0800
Subject: [PATCH] Testing configuration for conformer_ctc_bn

---
 .../ASR/conformer_ctc_bn/conformer.py         | 20 ++++++++++++--------
 .../ASR/conformer_ctc_bn/decode.py            |  9 ++++-----
 egs/librispeech/ASR/conformer_ctc_bn/madam.py |  1 +
 egs/librispeech/ASR/conformer_ctc_bn/train.py | 24 ++++++++++++++++++++++++
 4 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc_bn/conformer.py b/egs/librispeech/ASR/conformer_ctc_bn/conformer.py
index 566cad8cf..e4130183f 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn/conformer.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn/conformer.py
@@ -190,7 +190,7 @@ class DiscreteBottleneck(nn.Module):
         tot_classes: int,
         num_groups: int,
         interp_prob: float = 1.0,
-        straight_through_scale: float = 0.333,
+        straight_through_scale: float = 1.0,
         min_prob_ratio: float = 0.1
     ):
         super(DiscreteBottleneck, self).__init__()
@@ -234,20 +234,24 @@ class DiscreteBottleneck(nn.Module):
 
         (S, N, tot_classes) = x.shape
         x = x.reshape(S, N, self.num_groups, self.classes_per_group)
 
-        x = torch_flow_sampling.flow_sample(x,
-                                            interp_prob=self.interp_prob,
-                                            straight_through_scale=self.straight_through_scale)
-
-        assert x.shape == (S, N, self.num_groups, self.classes_per_group)
-        x = x.reshape(S, N, tot_classes)
-
         if self.training:
+            x = torch_flow_sampling.flow_sample(x,
+                                                interp_prob=self.interp_prob,
+                                                straight_through_scale=self.straight_through_scale)
+
+            assert x.shape == (S, N, self.num_groups, self.classes_per_group)
+            x = x.reshape(S, N, tot_classes)
+
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
             self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
+        else:
+            x = torch.softmax(x, dim=-1)
+            x = x.reshape(S, N, tot_classes)
+
         x = self.linear2(x)
         x = self.norm_out(x)
         return x
diff --git a/egs/librispeech/ASR/conformer_ctc_bn/decode.py b/egs/librispeech/ASR/conformer_ctc_bn/decode.py
index cfdcff756..14162b95a 100755
--- a/egs/librispeech/ASR/conformer_ctc_bn/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn/decode.py
@@ -26,7 +26,7 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from conformer import DiscreteBottleneckConformer
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -134,7 +134,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -142,7 +142,6 @@
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
-            "vgg_frontend": False,
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
@@ -529,14 +528,14 @@ def main():
     else:
         G = None
 
-    model = Conformer(
+    model = DiscreteBottleneckConformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
         d_model=params.attention_dim,
         num_classes=num_classes,
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=params.vgg_frontend,
+        vgg_frontend=False,
         is_espnet_structure=params.is_espnet_structure,
         mmi_loss=params.mmi_loss,
         use_feat_batchnorm=params.use_feat_batchnorm,
diff --git a/egs/librispeech/ASR/conformer_ctc_bn/madam.py b/egs/librispeech/ASR/conformer_ctc_bn/madam.py
index 0ee0655d4..3c6c47df3 100644
--- a/egs/librispeech/ASR/conformer_ctc_bn/madam.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn/madam.py
@@ -1052,6 +1052,7 @@ class Gloam(object):
         return {
             "optimizer": self.optimizer.state_dict(),
             "_step": self._step,
+            "_rate": self._rate,
             "_epoch": self._epoch,
         }
 
diff --git a/egs/librispeech/ASR/conformer_ctc_bn/train.py b/egs/librispeech/ASR/conformer_ctc_bn/train.py
index 48a58d96c..667f2c773 100755
--- a/egs/librispeech/ASR/conformer_ctc_bn/train.py
+++ b/egs/librispeech/ASR/conformer_ctc_bn/train.py
@@ -15,6 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Note: changed straight_through_scale from 0.333 to 1.0 on epoch 15, which seemed
+# to make it train slightly faster, based on a comparison of the valid prob in
+# the initial logs from epoch 15.  (best valid loss: 0.0427 -> 0.0425, on batch
+# 6000 of epoch 15; it was 0.0429 the last time the valid loss was printed
+# during epoch 14.)
+
+
+# RESULTS (it's worse!), with:
+# python3 conformer_ctc_bn/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=25 --avg=10 --max-duration=30
+#
+# With sampling at test time:
+# ngram_lm_scale_1.2_attention_scale_1.5  3.48  best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.2  8.4   best for test-other
+
+# After I modified conformer.py so that in eval mode it uses the softmax output with no sampling:
+# ngram_lm_scale_0.9_attention_scale_1.2  3.44  best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.0  8.09  best for test-other
+
+# Vs. BASELINE:
+# evaluated with
+# python3 conformer_ctc/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=23 --avg=10 --max-duration=30 &
+# (the baseline also uses the foam optimizer)
+# ngram_lm_scale_1.2_attention_scale_1.2  2.8   best for test-clean
+# ngram_lm_scale_0.9_attention_scale_0.7  6.6   best for test-other
 
 import argparse
 import logging
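
For reference, below is a minimal standalone sketch of the behaviour the conformer.py hunk introduces: flow-based sampling of the grouped logits in training, and a plain per-group softmax in eval. The flow_sample() call signature is copied from the diff above; the wrapper function, shape handling and defaults are illustrative assumptions, not the actual DiscreteBottleneck module.

import torch


def bottleneck_activations(x: torch.Tensor,
                           num_groups: int,
                           interp_prob: float = 1.0,
                           straight_through_scale: float = 1.0,
                           training: bool = True) -> torch.Tensor:
    """x: (S, N, tot_classes) logits; returns activations of the same shape."""
    (S, N, tot_classes) = x.shape
    classes_per_group = tot_classes // num_groups
    x = x.reshape(S, N, num_groups, classes_per_group)
    if training:
        # Training path: relaxed / straight-through sample per group
        # (external dependency; call signature as used in the patch).
        import torch_flow_sampling
        x = torch_flow_sampling.flow_sample(
            x, interp_prob=interp_prob,
            straight_through_scale=straight_through_scale)
    else:
        # Eval path after this patch: deterministic per-group softmax, no sampling.
        x = torch.softmax(x, dim=-1)
    return x.reshape(S, N, tot_classes)


# e.g. 512 channels split into 8 groups of 64 classes each (hypothetical sizes):
# y = bottleneck_activations(torch.randn(100, 4, 512), num_groups=8, training=False)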
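
The madam.py hunk only adds "_rate" to the saved state. A hedged sketch of why that matters, assuming a Gloam-style wrapper that manages the learning rate itself: without "_rate" in state_dict(), a resumed run would start from whatever rate the constructor computes rather than the rate in use when the checkpoint was written. The key names come from the diff; the class body is illustrative only and is not the icefall Gloam class.

import torch


class LrScheduleWrapper(object):
    """Illustrative stand-in for a Gloam-style optimizer wrapper."""

    def __init__(self, optimizer: torch.optim.Optimizer, initial_rate: float = 5e-4):
        self.optimizer = optimizer
        self._step = 0
        self._epoch = 0
        self._rate = initial_rate

    def state_dict(self):
        # Mirrors the keys saved by the patched Gloam.state_dict(), including "_rate".
        return {
            "optimizer": self.optimizer.state_dict(),
            "_step": self._step,
            "_rate": self._rate,
            "_epoch": self._epoch,
        }

    def load_state_dict(self, state_dict):
        # Restore the inner optimizer first, then the scalar schedule state.
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(value)
            else:
                setattr(self, key, value)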