Testing configuration for conformer_ctc_bn

Daniel Povey 2021-09-17 18:55:34 +08:00
parent cfdfcf657d
commit c6c3750cab
4 changed files with 41 additions and 13 deletions

conformer_ctc_bn/conformer.py

@@ -190,7 +190,7 @@ class DiscreteBottleneck(nn.Module):
         tot_classes: int,
         num_groups: int,
         interp_prob: float = 1.0,
-        straight_through_scale: float = 0.333,
+        straight_through_scale: float = 1.0,
         min_prob_ratio: float = 0.1
     ):
         super(DiscreteBottleneck, self).__init__()
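A note on this tuning knob: per the comment added to the training script at the bottom of this commit, straight_through_scale scales the straight-through component of the gradient in torch_flow_sampling.flow_sample. A minimal sketch of that idea, assuming (the internals of flow_sample are not shown in this diff) that the parameter simply scales the gradient routed through the soft distribution, and ignoring the separate interpolation controlled by interp_prob:

    import torch

    def straight_through(soft: torch.Tensor, hard: torch.Tensor,
                         scale: float) -> torch.Tensor:
        # Forward value is always `hard`; the backward pass routes
        # `scale` times the incoming gradient through `soft`.
        # scale=1.0 is the standard straight-through estimator;
        # scale=0.333 damps that gradient to a third.
        return hard.detach() + scale * (soft - soft.detach())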
@@ -234,6 +234,7 @@ class DiscreteBottleneck(nn.Module):
         (S, N, tot_classes) = x.shape
         x = x.reshape(S, N, self.num_groups, self.classes_per_group)
-        x = torch_flow_sampling.flow_sample(x,
-                                            interp_prob=self.interp_prob,
-                                            straight_through_scale=self.straight_through_scale)
+        if self.training:
+            x = torch_flow_sampling.flow_sample(x,
+                                                interp_prob=self.interp_prob,
+                                                straight_through_scale=self.straight_through_scale)
@@ -241,13 +242,16 @@ class DiscreteBottleneck(nn.Module):
         assert x.shape == (S, N, self.num_groups, self.classes_per_group)
         x = x.reshape(S, N, tot_classes)
         if self.training:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
             self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost
+        else:
+            x = torch.softmax(x, dim=-1)
+            x = x.reshape(S, N, tot_classes)
         x = self.linear2(x)
         x = self.norm_out(x)
         return x
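Taken together, the two hunks above make sampling a train-time-only behavior: .eval() flips the module onto a plain softmax path, which is what the decode results at the bottom of this commit compare against. A self-contained toy condensation of that branch (the _sample helper is a hypothetical stand-in for torch_flow_sampling.flow_sample, and the surrounding linear/norm layers and class_probs bookkeeping are omitted):

    import torch
    import torch.nn as nn

    class ToyDiscreteBottleneck(nn.Module):
        def __init__(self, num_groups: int, classes_per_group: int):
            super().__init__()
            self.num_groups = num_groups
            self.classes_per_group = classes_per_group

        def _sample(self, logits: torch.Tensor) -> torch.Tensor:
            # Hypothetical stand-in for torch_flow_sampling.flow_sample:
            # hard one-hot in the forward pass, softmax gradients backward.
            soft = torch.softmax(logits, dim=-1)
            hard = torch.zeros_like(soft).scatter_(
                -1, soft.argmax(dim=-1, keepdim=True), 1.0)
            return soft + (hard - soft).detach()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            (S, N, tot_classes) = x.shape
            x = x.reshape(S, N, self.num_groups, self.classes_per_group)
            if self.training:
                x = self._sample(x)            # sample only during training
            else:
                x = torch.softmax(x, dim=-1)   # deterministic at test time
            return x.reshape(S, N, tot_classes)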

conformer_ctc_bn/decode.py

@@ -26,7 +26,7 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from conformer import DiscreteBottleneckConformer
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -134,7 +134,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -142,7 +142,6 @@ def get_params() -> AttributeDict:
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
-            "vgg_frontend": False,
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
@@ -529,14 +528,14 @@ def main():
     else:
         G = None

-    model = Conformer(
+    model = DiscreteBottleneckConformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
         d_model=params.attention_dim,
         num_classes=num_classes,
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=params.vgg_frontend,
+        vgg_frontend=False,
         is_espnet_structure=params.is_espnet_structure,
         mmi_loss=params.mmi_loss,
         use_feat_batchnorm=params.use_feat_batchnorm,
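The new exp_dir default matters because decode.py resolves its checkpoints from exp_dir together with the --epoch and --avg flags. A sketch of the usual icefall pattern (simplified; this part of decode.py is not in the diff), using the load_checkpoint and average_checkpoints helpers imported above:

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        # Average the last `avg` epoch checkpoints, e.g. epochs 16..25
        # for --epoch=25 --avg=10 as in the decode commands below.
        filenames = [
            f"{params.exp_dir}/epoch-{i}.pt"
            for i in range(params.epoch - params.avg + 1, params.epoch + 1)
        ]
        model.load_state_dict(average_checkpoints(filenames))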


@@ -1052,6 +1052,7 @@ class Gloam(object):
         return {
             "optimizer": self.optimizer.state_dict(),
             "_step": self._step,
+            "_rate": self._rate,
             "_epoch": self._epoch,
         }
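Saving "_rate" only helps if the matching load path restores it; otherwise a resumed run would fall back to a freshly computed rate. A hedged sketch of what that counterpart would look like (Gloam's actual load_state_dict is not shown in this diff; attribute names are taken from the keys saved above):

    def load_state_dict(self, state_dict):
        # Mirror of state_dict() above.
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self._step = state_dict["_step"]
        self._rate = state_dict["_rate"]
        self._epoch = state_dict["_epoch"]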

conformer_ctc_bn/train.py

@@ -15,6 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# Note: changed straight_through_scale from 0.333 to 1.0 on epoch 15, which
+# seemed to make it train slightly faster, judging by the validation
+# probabilities in the initial logs of epoch 15 (best valid loss: 0.0427 ->
+# 0.0425 on batch 6000 of epoch 15; it was 0.0429 the last time the valid
+# loss was printed in epoch 14).
+#
+# RESULTS (it's worse!), with:
+# python3 conformer_ctc_bn/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=25 --avg=10 --max-duration=30
+#
+# With sampling at test time:
+# ngram_lm_scale_1.2_attention_scale_1.5   3.48   best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.2   8.4    best for test-other
+#
+# After I modified conformer.py so that in eval mode it uses the softmax
+# output with no sampling:
+# ngram_lm_scale_0.9_attention_scale_1.2   3.44   best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.0   8.09   best for test-other
+#
+# Vs. the BASELINE, evaluated with:
+# python3 conformer_ctc/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=23 --avg=10 --max-duration=30 &
+# (also uses foam optimizer)
+# ngram_lm_scale_1.2_attention_scale_1.2   2.8    best for test-clean
+# ngram_lm_scale_0.9_attention_scale_0.7   6.6    best for test-other
 import argparse
 import logging
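For reading the WER lines above: each ngram_lm_scale_X_attention_scale_Y name labels one cell of decode.py's grid search over the two rescoring weights, and the numbers are word error rates in percent. Roughly, icefall's attention-decoder rescoring combines the per-path scores like this (a sketch; the variable names here are assumptions, not code from this commit):

    def combined_score(am_score, ngram_lm_score, attention_score,
                       ngram_lm_scale, attention_scale):
        # e.g. "ngram_lm_scale_0.9_attention_scale_1.2" is the grid cell
        # with ngram_lm_scale=0.9 and attention_scale=1.2; the best path
        # under this total score is picked per utterance.
        return (am_score
                + ngram_lm_scale * ngram_lm_score
                + attention_scale * attention_score)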