Remove unnecessary code and update docs

This commit is contained in:
pkufool 2021-09-12 15:30:41 +08:00
parent de42c0ebb5
commit 4de7f19e03
11 changed files with 70 additions and 78 deletions

View File

@ -1,3 +0,0 @@
Please visit
<https://icefall.readthedocs.io/en/latest/recipes/aishell/conformer_ctc.html>
for how to run this recipe.

View File

@ -40,6 +40,7 @@ class Conformer(Transformer):
cnn_module_kernel (int): Kernel size of convolution module cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block. normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend. vgg_frontend (bool): whether to use vgg frontend.
use_feat_batchnorm (bool): whether to batch-normalize the input.
""" """
def __init__( def __init__(
@ -56,8 +57,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
is_espnet_structure: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False, use_feat_batchnorm: bool = False,
) -> None: ) -> None:
super(Conformer, self).__init__( super(Conformer, self).__init__(
@ -72,7 +71,6 @@ class Conformer(Transformer):
dropout=dropout, dropout=dropout,
normalize_before=normalize_before, normalize_before=normalize_before,
vgg_frontend=vgg_frontend, vgg_frontend=vgg_frontend,
mmi_loss=mmi_loss,
use_feat_batchnorm=use_feat_batchnorm, use_feat_batchnorm=use_feat_batchnorm,
) )
@ -85,12 +83,10 @@ class Conformer(Transformer):
dropout, dropout,
cnn_module_kernel, cnn_module_kernel,
normalize_before, normalize_before,
is_espnet_structure,
) )
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
self.normalize_before = normalize_before self.normalize_before = normalize_before
self.is_espnet_structure = is_espnet_structure if self.normalize_before:
if self.normalize_before and self.is_espnet_structure:
self.after_norm = nn.LayerNorm(d_model) self.after_norm = nn.LayerNorm(d_model)
else: else:
# Note: TorchScript detects that self.after_norm could be used inside forward() # Note: TorchScript detects that self.after_norm could be used inside forward()
@ -125,7 +121,7 @@ class Conformer(Transformer):
mask = mask.to(x.device) mask = mask.to(x.device)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
if self.normalize_before and self.is_espnet_structure: if self.normalize_before:
x = self.after_norm(x) x = self.after_norm(x)
return x, mask return x, mask
@ -159,11 +155,10 @@ class ConformerEncoderLayer(nn.Module):
dropout: float = 0.1, dropout: float = 0.1,
cnn_module_kernel: int = 31, cnn_module_kernel: int = 31,
normalize_before: bool = True, normalize_before: bool = True,
is_espnet_structure: bool = False,
) -> None: ) -> None:
super(ConformerEncoderLayer, self).__init__() super(ConformerEncoderLayer, self).__init__()
self.self_attn = RelPositionMultiheadAttention( self.self_attn = RelPositionMultiheadAttention(
d_model, nhead, dropout=0.0, is_espnet_structure=is_espnet_structure d_model, nhead, dropout=0.0
) )
self.feed_forward = nn.Sequential( self.feed_forward = nn.Sequential(
@ -436,7 +431,6 @@ class RelPositionMultiheadAttention(nn.Module):
embed_dim: int, embed_dim: int,
num_heads: int, num_heads: int,
dropout: float = 0.0, dropout: float = 0.0,
is_espnet_structure: bool = False,
) -> None: ) -> None:
super(RelPositionMultiheadAttention, self).__init__() super(RelPositionMultiheadAttention, self).__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
@ -459,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):
self._reset_parameters() self._reset_parameters()
self.is_espnet_structure = is_espnet_structure
def _reset_parameters(self) -> None: def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight) nn.init.xavier_uniform_(self.in_proj.weight)
@ -690,8 +683,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:] _b = _b[_start:]
v = nn.functional.linear(value, _w, _b) v = nn.functional.linear(value, _w, _b)
if not self.is_espnet_structure:
q = q * scaling
if attn_mask is not None: if attn_mask is not None:
assert ( assert (
@ -785,14 +776,9 @@ class RelPositionMultiheadAttention(nn.Module):
) # (batch, head, time1, 2*time1-1) ) # (batch, head, time1, 2*time1-1)
matrix_bd = self.rel_shift(matrix_bd) matrix_bd = self.rel_shift(matrix_bd)
if not self.is_espnet_structure: attn_output_weights = (
attn_output_weights = ( matrix_ac + matrix_bd
matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2)
) # (batch, head, time1, time2)
else:
attn_output_weights = (
matrix_ac + matrix_bd
) * scaling # (batch, head, time1, time2)
attn_output_weights = attn_output_weights.view( attn_output_weights = attn_output_weights.view(
bsz * num_heads, tgt_len, -1 bsz * num_heads, tgt_len, -1

View File

@ -57,7 +57,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=34, default=49,
help="It specifies the checkpoint to use for decoding." help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.", "Note: Epoch counts from 0.",
) )
@ -101,7 +101,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--lattice-score-scale",
type=float, type=float,
default=1.0, default=0.5,
help="""The scale to be applied to `lattice.scores`. help="""The scale to be applied to `lattice.scores`.
It's needed if you use any kinds of n-best based rescoring. It's needed if you use any kinds of n-best based rescoring.
Used only when "method" is one of the following values: Used only when "method" is one of the following values:
@ -116,19 +116,19 @@ def get_parser():
def get_params() -> AttributeDict: def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp_char"), "exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_char"), "lang_dir": Path("data/lang_char"),
"lm_dir": Path("data/lm"), "lm_dir": Path("data/lm"),
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80, "feature_dim": 80,
"nhead": 4, "nhead": 4,
"attention_dim": 512, "attention_dim": 512,
"subsampling_factor": 4,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"num_decoder_layers": 6, "num_decoder_layers": 6,
"vgg_frontend": False, "vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
# parameters for decoder
"search_beam": 20, "search_beam": 20,
"output_beam": 7, "output_beam": 7,
"min_active_states": 30, "min_active_states": 30,
@ -364,9 +364,12 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
results_tmp = []
for res in results:
results_tmp.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log f, f"{test_set_name}-{key}", results_tmp, enable_log=enable_log
) )
test_set_wers[key] = wer test_set_wers[key] = wer
@ -440,8 +443,6 @@ def main():
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers, num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Wei Kang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -77,7 +78,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--num-epochs", "--num-epochs",
type=int, type=int,
default=35, default=50,
help="Number of epochs to train.", help="Number of epochs to train.",
) )
@ -111,19 +112,6 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as - lang_dir: It contains language related input files such as
"lexicon.txt" "lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select - best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is the model that has the lowest validation loss. It is
updated during the training. updated during the training.
@ -138,23 +126,45 @@ def get_params() -> AttributeDict:
- log_interval: Print training loss if batch_idx % log_interval is 0 - log_interval: Print training loss if batch_idx % log_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0 - reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- beam_size: It is used in k2.ctc_loss - beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss - use_double_scores: It is used in k2.ctc_loss
- att_rate: The proportion of label smoothing loss, final loss will be
(1 - att_rate) * ctc_loss + att_rate * label_smoothing_loss
- subsampling_factor: The subsampling factor for the model.
- feature_dim: The model input dim. It has to match the one used
in computing features.
- attention_dim: Attention dimension.
- nhead: Number of heads in multi-head attention.
Must satisfy attention_dim % nhead == 0.
- num_encoder_layers: Number of attention encoder layers.
- num_decoder_layers: Number of attention decoder layers.
- use_feat_batchnorm: Whether to do normalization in the input layer.
- weight_decay: The weight_decay for the optimizer.
- lr_factor: The lr_factor for the optimizer.
- warm_step: The warm_step for the optimizer.
""" """
params = AttributeDict( params = AttributeDict(
{ {
"exp_dir": Path("conformer_ctc/exp_char"), "exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_char"), "lang_dir": Path("data/lang_char"),
"feature_dim": 80,
"weight_decay": 1e-6,
"subsampling_factor": 4,
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
"best_train_epoch": -1, "best_train_epoch": -1,
@ -163,18 +173,21 @@ def get_params() -> AttributeDict:
"log_interval": 10, "log_interval": 10,
"reset_interval": 200, "reset_interval": 200,
"valid_interval": 3000, "valid_interval": 3000,
# parameters for k2.ctc_loss
"beam_size": 10, "beam_size": 10,
"reduction": "sum", "reduction": "sum",
"use_double_scores": True, "use_double_scores": True,
"accum_grad": 1,
"att_rate": 0.7, "att_rate": 0.7,
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"attention_dim": 512, "attention_dim": 512,
"nhead": 4, "nhead": 4,
"num_decoder_layers": 6,
"num_encoder_layers": 12, "num_encoder_layers": 12,
"is_espnet_structure": True, "num_decoder_layers": 6,
"mmi_loss": False,
"use_feat_batchnorm": True, "use_feat_batchnorm": True,
# parameters for Noam
"weight_decay": 1e-5,
"lr_factor": 5.0, "lr_factor": 5.0,
"warm_step": 36000, "warm_step": 36000,
} }
@ -648,8 +661,6 @@ def run(rank, world_size, args):
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers, num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False, vgg_frontend=False,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm, use_feat_batchnorm=params.use_feat_batchnorm,
) )

View File

@ -41,7 +41,6 @@ class Transformer(nn.Module):
dropout: float = 0.1, dropout: float = 0.1,
normalize_before: bool = True, normalize_before: bool = True,
vgg_frontend: bool = False, vgg_frontend: bool = False,
mmi_loss: bool = True,
use_feat_batchnorm: bool = False, use_feat_batchnorm: bool = False,
) -> None: ) -> None:
""" """
@ -70,7 +69,6 @@ class Transformer(nn.Module):
If True, use pre-layer norm; False to use post-layer norm. If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend: vgg_frontend:
True to use vgg style frontend for subsampling. True to use vgg style frontend for subsampling.
mmi_loss:
use_feat_batchnorm: use_feat_batchnorm:
True to use batchnorm for the input layer. True to use batchnorm for the input layer.
""" """
@ -122,14 +120,9 @@ class Transformer(nn.Module):
) )
if num_decoder_layers > 0: if num_decoder_layers > 0:
if mmi_loss: self.decoder_num_class = (
self.decoder_num_class = ( self.num_classes
self.num_classes + 1 ) # bpe model already has sos/eos symbol
) # +1 for the sos/eos symbol
else:
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding( self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model num_embeddings=self.decoder_num_class, embedding_dim=d_model

View File

@ -124,7 +124,7 @@ def lexicon_to_fst_no_sil(
def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
"""Return if all the tokens are in token symbol table. """Check if all the given tokens are in token symbol table.
Args: Args:
token_sym_table: token_sym_table:

View File

@ -3,7 +3,7 @@
set -eou pipefail set -eou pipefail
nj=15 nj=15
stage=6 stage=-1
stop_stage=10 stop_stage=10
# We assume dl_dir (download dir) contains the following # We assume dl_dir (download dir) contains the following
@ -11,7 +11,7 @@ stop_stage=10
# by this script automatically. # by this script automatically.
# #
# - $dl_dir/aishell # - $dl_dir/aishell
# You can data_aishell, resource_aishell inside it. # You can find data_aishell, resource_aishell inside it.
# You can download them from https://www.openslr.org/33 # You can download them from https://www.openslr.org/33
# #
# - $dl_dir/lm # - $dl_dir/lm
@ -27,6 +27,7 @@ stop_stage=10
# - music # - music
# - noise # - noise
# - speech # - speech
dl_dir=$PWD/download dl_dir=$PWD/download
. shared/parse_options.sh || exit 1 . shared/parse_options.sh || exit 1

View File

@ -73,14 +73,14 @@ class AishellAsrDataModule(DataModule):
group.add_argument( group.add_argument(
"--max-duration", "--max-duration",
type=int, type=int,
default=500.0, default=200.0,
help="Maximum pooled recordings duration (seconds) in a " help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.", "single batch. You can reduce it if it causes CUDA OOM.",
) )
group.add_argument( group.add_argument(
"--bucketing-sampler", "--bucketing-sampler",
type=str2bool, type=str2bool,
default=False, default=True,
help="When enabled, the batches will come from buckets of " help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).", "similar duration (saves padding frames).",
) )

View File

@ -95,7 +95,7 @@ def get_params() -> AttributeDict:
# Possible values for method: # Possible values for method:
# - 1best # - 1best
# - nbest # - nbest
"method": "nbest", "method": "1best",
# num_paths is used when method is "nbest" # num_paths is used when method is "nbest"
"num_paths": 30, "num_paths": 30,
} }
@ -274,8 +274,11 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned # The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
results_tmp = []
for res in results:
results_tmp.append((list("".join(res[0])), list("".join(res[1]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results) wer = write_error_stats(f, f"{test_set_name}-{key}", results_tmp)
test_set_wers[key] = wer test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))

View File

@ -883,3 +883,4 @@ def rescore_with_attention_decoder(
key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}"
ans[key] = best_path_fsa ans[key] = best_path_fsa
return ans return ans

View File

@ -99,7 +99,6 @@ def setup_logger(
""" """
now = datetime.now() now = datetime.now()
date_time = now.strftime("%Y-%m-%d-%H-%M-%S") date_time = now.strftime("%Y-%m-%d-%H-%M-%S")
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
world_size = dist.get_world_size() world_size = dist.get_world_size()
rank = dist.get_rank() rank = dist.get_rank()