Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-15 04:52:22 +00:00)
Testing configuration for conformer_ctc_bn
commit c6c3750cab (parent cfdfcf657d)
@@ -190,7 +190,7 @@ class DiscreteBottleneck(nn.Module):
                  tot_classes: int,
                  num_groups: int,
                  interp_prob: float = 1.0,
-                 straight_through_scale: float = 0.333,
+                 straight_through_scale: float = 1.0,
                  min_prob_ratio: float = 0.1
     ):
         super(DiscreteBottleneck, self).__init__()
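For context on the parameter retuned above: straight_through_scale controls how much of the straight-through trick (a hard sample in the forward pass, soft gradients in the backward pass) is mixed into the sampler's output. The sketch below only illustrates that idea under simplified assumptions (already-normalized probabilities, one hard sample per row); it is a hypothetical helper, not the actual torch_flow_sampling.flow_sample implementation.

import torch
import torch.nn.functional as F

def straight_through_mix(probs: torch.Tensor, straight_through_scale: float) -> torch.Tensor:
    # Illustration only (hypothetical helper, not part of torch_flow_sampling).
    # probs: (..., C), assumed non-negative and summing to 1 over the last dim.
    flat = probs.reshape(-1, probs.shape[-1])
    idx = torch.multinomial(flat, num_samples=1).squeeze(-1)          # one hard sample per row
    hard = F.one_hot(idx, num_classes=probs.shape[-1]).to(probs.dtype)
    hard = hard.reshape(probs.shape)
    # The detached term contributes nothing to gradients, so backprop sees only `probs`.
    # scale=0.0 -> pure soft output; scale=1.0 -> one-hot forward value with soft gradients.
    return probs + straight_through_scale * (hard - probs).detach()

With a scale of 1.0 the forward value is fully discrete while gradients still flow through the soft probabilities, which is the direction the default is being moved in this hunk.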
@@ -234,6 +234,7 @@ class DiscreteBottleneck(nn.Module):
         (S, N, tot_classes) = x.shape
         x = x.reshape(S, N, self.num_groups, self.classes_per_group)

+        if self.training:
             x = torch_flow_sampling.flow_sample(x,
                                                 interp_prob=self.interp_prob,
                                                 straight_through_scale=self.straight_through_scale)
@@ -241,13 +242,16 @@ class DiscreteBottleneck(nn.Module):
             assert x.shape == (S, N, self.num_groups, self.classes_per_group)
             x = x.reshape(S, N, tot_classes)

-        if self.training:
             mean_class_probs = torch.mean(x.detach(), dim=(0,1))
             self.class_probs = (self.class_probs * self.class_probs_decay +
                                 mean_class_probs * (1.0 - self.class_probs_decay))
             prob_floor = self.min_prob_ratio / self.classes_per_group
             self.class_offsets += (self.class_probs > prob_floor) * self.prob_boost

+        else:
+            x = torch.softmax(x, dim=-1)
+            x = x.reshape(S, N, tot_classes)
+
         x = self.linear2(x)
         x = self.norm_out(x)
         return x
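The net effect of the two hunks above is that the bottleneck now behaves differently in train and eval mode: flow sampling plus class-usage bookkeeping when self.training is True, and a plain per-group softmax otherwise. Below is a self-contained toy module showing the same train/eval branching pattern; the names, shapes, and the Gumbel-softmax stand-in for flow sampling are illustrative and are not taken from conformer.py.

import torch
import torch.nn as nn

class ToyBottleneck(nn.Module):
    # Minimal stand-in that mirrors the train/eval split introduced above.
    def __init__(self, num_groups: int, classes_per_group: int):
        super().__init__()
        self.num_groups = num_groups
        self.classes_per_group = classes_per_group
        # Running estimate of class usage, updated only in training mode.
        self.register_buffer("class_probs",
                             torch.full((num_groups, classes_per_group),
                                        1.0 / classes_per_group))
        self.class_probs_decay = 0.95

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        (S, N, tot_classes) = x.shape
        x = x.reshape(S, N, self.num_groups, self.classes_per_group)
        if self.training:
            # Stand-in for flow sampling: a Gumbel-softmax style relaxation.
            x = torch.nn.functional.gumbel_softmax(x, tau=1.0, hard=False, dim=-1)
            # Decaying average of how much probability mass each class receives.
            mean_class_probs = torch.mean(x.detach(), dim=(0, 1))
            self.class_probs = (self.class_probs * self.class_probs_decay +
                                mean_class_probs * (1.0 - self.class_probs_decay))
        else:
            # Eval: deterministic softmax, no sampling and no statistics update.
            x = torch.softmax(x, dim=-1)
        return x.reshape(S, N, tot_classes)

m = ToyBottleneck(num_groups=2, classes_per_group=4)
logits = torch.randn(5, 3, 8)
m.train(); y_train = m(logits)   # stochastic path
m.eval();  y_eval = m(logits)    # deterministic softmax path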
@@ -26,7 +26,7 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from conformer import Conformer
+from conformer import DiscreteBottleneckConformer

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
@@ -134,7 +134,7 @@ def get_parser():
 def get_params() -> AttributeDict:
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_bn/exp_gloam_5e-4_0.85_discrete8"),
             "lang_dir": Path("data/lang_bpe"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 80,
@@ -142,7 +142,6 @@ def get_params() -> AttributeDict:
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
-            "vgg_frontend": False,
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
@@ -529,14 +528,14 @@ def main():
     else:
         G = None

-    model = Conformer(
+    model = DiscreteBottleneckConformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
         d_model=params.attention_dim,
         num_classes=num_classes,
         subsampling_factor=params.subsampling_factor,
         num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=params.vgg_frontend,
+        vgg_frontend=False,
         is_espnet_structure=params.is_espnet_structure,
         mmi_loss=params.mmi_loss,
         use_feat_batchnorm=params.use_feat_batchnorm,
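The WER results quoted in the comment block further down were decoded with --epoch=25 --avg=10, i.e. the model weights are an average of the last 10 epoch checkpoints. That part of decode.py is not shown in this diff; the following is only a sketch of the usual icefall pattern, assuming the standard exp_dir/epoch-<n>.pt layout, the --epoch/--avg command-line options, and the average_checkpoints/load_checkpoint imports from the earlier hunk.

# Sketch (not part of this hunk): load one checkpoint, or an average of the last `avg` epochs.
if params.avg == 1:
    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
    start = params.epoch - params.avg + 1
    filenames = [f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)]
    model.load_state_dict(average_checkpoints(filenames))
model.eval()

Note that model.eval() is also what routes the new DiscreteBottleneck onto its no-sampling softmax path at decode time.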
@@ -1052,6 +1052,7 @@ class Gloam(object):
         return {
             "optimizer": self.optimizer.state_dict(),
             "_step": self._step,
+            "_rate": self._rate,
             "_epoch": self._epoch,
         }

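Saving _rate only helps a resumed run if the matching load_state_dict() puts it back. Gloam's load side is not shown in this hunk, so the snippet below is a self-contained stand-in (a hypothetical TinySchedulerState class, not the real Gloam) that shows the save/restore symmetry the added key relies on.

import torch

class TinySchedulerState:
    # Minimal stand-in illustrating the state_dict()/load_state_dict() pairing.
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self._rate = 0.0
        self._epoch = 0

    def state_dict(self):
        return {
            "optimizer": self.optimizer.state_dict(),
            "_step": self._step,
            "_rate": self._rate,
            "_epoch": self._epoch,
        }

    def load_state_dict(self, state_dict):
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(value)
            else:
                # Restores _step, _epoch and now also _rate.
                setattr(self, key, value)

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
s = TinySchedulerState(opt)
s._rate = 5e-4
restored = TinySchedulerState(opt)
restored.load_state_dict(s.state_dict())
assert restored._rate == 5e-4   # without the new key, the learning rate would be lost on resume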
@@ -15,6 +15,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Note: changed straight_through_scale from 0.333 to 1.0 on epoch 15, which seemed
+# to make it train slightly faster, based on comparison of valid prob on
+# the initial logs from epoch 15. (best valid loss: 0.0427 -> 0.0425, on batch
+# 6000 of epoch 15. was 0.429 last time the valid loss was printed on
+# epoch 14.
+
+
+# RESULTS (it's worse!), with:
+# python3 conformer_ctc_bn/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=25 --avg=10 --max-duration=30
+#
+# With sampling in test-time:
+# ngram_lm_scale_1.2_attention_scale_1.5 3.48 best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.2 8.4 best for test-other
+
+# After I modified conformer.py so that in eval mode, it uses the softmax output with no sampling:
+# ngram_lm_scale_0.9_attention_scale_1.2 3.44 best for test-clean
+# ngram_lm_scale_0.9_attention_scale_1.0 8.09 best for test-other
+
+# Vs. BASELINE:
+# evaluated with
+# python3 conformer_ctc/decode.py --lattice-score-scale=0.5 --method=attention-decoder --epoch=23 --avg=10 --max-duration=30 &
+# (also uses foam optimizer)
+# ngram_lm_scale_1.2_attention_scale_1.2 2.8 best for test-clean
+# ngram_lm_scale_0.9_attention_scale_0.7 6.6 best for test-other

 import argparse
 import logging