Mirror of https://github.com/k2-fsa/icefall.git
Commit cad8f6aca4: merge upstream
@@ -164,13 +164,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
-
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -516,14 +509,6 @@ def compute_loss(
     nnet_output, encoder_memory, memory_mask = model(
         feature, supervisions, warmup=warmup
     )
-    # logging.info('feature shape: {}'.format(feature.shape))
-    # logging.info('nnet_output shape: {}'.format(nnet_output.shape))
-    # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
-    # logging.info('memory_mask shape: {}'.format(memory_mask.shape))
-    # after the main warmup step, we keep pruned_loss_scale small
-    # for the same amount of time (model_warm_step), to avoid
-    # overwhelming the simple_loss and causing it to diverge,
-    # in case it had not fully learned the alignment yet.

     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by
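The NOTE kept above refers to a k2 requirement: supervision segments must be sorted by decreasing number of frames before a DenseFsaVec is built from them. A minimal standalone sketch of that sorting step, with made-up tensor values (not taken from the patch):

import torch

# Hypothetical (sequence_idx, start_frame, num_frames) rows, one per utterance.
supervision_segments = torch.tensor(
    [[0, 0, 80], [1, 0, 120], [2, 0, 95]], dtype=torch.int32
)

# Sort rows by num_frames (column 2) in decreasing order, as k2's
# intersection routines expect when constructing a DenseFsaVec.
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]
# -> rows now ordered as sequences 1 (120 frames), 2 (95), 0 (80)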
@@ -401,7 +401,6 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        activation: str = "relu",
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()

@@ -427,11 +426,6 @@ class TransformerEncoderLayer(nn.Module):

         self.dropout = nn.Dropout(dropout)

-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerEncoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         src: torch.Tensor,
@@ -523,7 +517,6 @@ class TransformerDecoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        # activation: str = "relu",
         normalize_before: bool = True,
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()
@@ -548,11 +541,6 @@ class TransformerDecoderLayer(nn.Module):

         self.dropout = nn.Dropout(dropout)

-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerDecoderLayer, self).__setstate__(state)
-
     def forward(
         self,
         tgt: torch.Tensor,
@@ -637,15 +625,6 @@ class TransformerDecoderLayer(nn.Module):
         return tgt


-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
-
-
 class TransformerEncoder(nn.Module):
     r"""TransformerEncoder is a stack of N encoder layers

@@ -690,7 +669,7 @@ class TransformerEncoder(nn.Module):
         """
         output = src

-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 src_mask=mask,
@@ -751,7 +730,7 @@ class TransformerDecoder(nn.Module):
         """
         output = tgt

-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 memory,
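The two loop changes above are behavior-preserving: the index produced by enumerate was unused, so iterating the ModuleList directly yields the same outputs. A tiny standalone illustration (generic modules, not icefall code):

import torch
import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
x = torch.randn(2, 8)

out = x
for mod in layers:  # identical to `for i, mod in enumerate(layers)` when i is never used
    out = mod(out)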
@@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon

 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -50,11 +57,13 @@ def get_args():
     return parser.parse_args()


-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The language stem base name.

     Return:
       An FSA representing HLG.
@@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

-    if Path("data/lm/G_3_gram.pt").is_file():
-        logging.info("Loading pre-compiled G_3_gram")
-        d = torch.load("data/lm/G_3_gram.pt")
+    if Path(f"data/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"data/lm/{lm}.pt")
         G = k2.Fsa.from_dict(d)
     else:
-        logging.info("Loading G_3_gram.fst.txt")
-        with open("data/lm/G_3_gram.fst.txt") as f:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"data/lm/{lm}.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+            torch.save(G.as_dict(), f"data/lm/{lm}.pt")

     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
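The branch above implements a simple load-or-compile-and-cache policy keyed by the LM stem. A paraphrase of that logic as a standalone helper, for illustration only (the function name load_G is mine, not part of the patch):

import logging
from pathlib import Path

import k2
import torch


def load_G(lm: str = "G_3_gram") -> k2.Fsa:
    """Return data/lm/{lm}.pt if cached, else build it from data/lm/{lm}.fst.txt."""
    pt_path = Path(f"data/lm/{lm}.pt")
    if pt_path.is_file():
        logging.info(f"Loading pre-compiled {lm}")
        return k2.Fsa.from_dict(torch.load(pt_path))
    logging.info(f"Loading {lm}.fst.txt")
    with open(f"data/lm/{lm}.fst.txt") as f:
        G = k2.Fsa.from_openfst(f.read(), acceptor=False)
    torch.save(G.as_dict(), pt_path)
    return G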
@@ -144,7 +153,7 @@ def main():

     logging.info(f"Processing {lang_dir}")

-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)
     logging.info(f"Saving HLG.pt to {lang_dir}")
     torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")

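With main() forwarding args.lm, the compile script can target a different LM stem than the default G_3_gram, e.g. by passing --lm on the command line. The Python-level equivalent is sketched below; the lang dir and stem are hypothetical, and the snippet assumes the imports and the compile_HLG definition from the script above:

# Equivalent of running the script with `--lang-dir data/lang_bpe_500 --lm G_4_gram`
# (the corresponding files must exist under data/lm and the lang dir).
HLG = compile_HLG("data/lang_bpe_500", "G_4_gram")
torch.save(HLG.as_dict(), "data/lang_bpe_500/HLG.pt")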
@@ -17,10 +17,10 @@


 import argparse
-import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional

 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
@@ -28,7 +28,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -159,21 +158,18 @@ class TedLiumAsrDataModule:
                 "were used to construct it."
             ),
         )

         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
             help="The number of training dataloader workers that collect the batches.",
         )

         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )

         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -185,7 +181,6 @@ class TedLiumAsrDataModule:
                 "A value less than 1 means to disable time warp."
             ),
         )

         group.add_argument(
             "--enable-musan",
             type=str2bool,
@@ -196,7 +191,36 @@ class TedLiumAsrDataModule:
             ),
         )

-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
+
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
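Note the design choice here: the rewritten train_dataloaders pins num_frame_masks=10 instead of probing the installed Lhotse version with inspect (the old block removed in the next hunk). This assumes a reasonably recent Lhotse in which num_frame_masks means the total number of time masks. A standalone sketch of the same transform construction, with an assumed time-warp factor of 80 standing in for the --spec-aug-time-warp-factor option:

from lhotse.dataset import SpecAugment

spec_aug = SpecAugment(
    time_warp_factor=80,  # assumed value; the recipe reads this from a CLI option
    num_frame_masks=10,
    features_mask_size=27,
    num_feature_masks=2,
    frames_mask_size=100,
    max_frames_mask_fraction=0.15,
    p=0.9,
)
# The transform is then handed to K2SpeechRecognitionDataset via input_transforms=[spec_aug].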
@@ -222,40 +246,7 @@ class TedLiumAsrDataModule:
                 )
             ] + transforms

-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
-
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -273,6 +264,12 @@ class TedLiumAsrDataModule:
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )

         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -290,6 +287,11 @@ class TedLiumAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)

         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
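The new sampler_state_dict argument makes it possible to resume training from the exact position the sampler had reached when a checkpoint was written. A hypothetical resume sketch; the checkpoint path, the "sampler" key, and the surrounding objects are assumptions, not taken from the patch:

import torch

# `tedlium` is assumed to be a constructed TedLiumAsrDataModule,
# `train_cuts` an assumed training CutSet.
checkpoint = torch.load("exp/checkpoint-12000.pt", map_location="cpu")  # hypothetical path
train_dl = tedlium.train_dataloaders(
    train_cuts,
    sampler_state_dict=checkpoint.get("sampler"),  # assumed key used at save time
)

# When saving, the matching state would be captured with something like:
#     checkpoint["sampler"] = train_dl.sampler.state_dict()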
@@ -302,6 +304,7 @@ class TedLiumAsrDataModule:
         return train_dl

     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:

         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -322,11 +325,13 @@ class TedLiumAsrDataModule:
                 cut_transforms=transforms,
                 return_cuts=self.args.return_cuts,
             )

         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )

         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -338,25 +343,32 @@ class TedLiumAsrDataModule:

         return valid_dl

-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:

         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )

         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
             persistent_workers=False,
         )
         return test_dl
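The rewritten test_dataloaders makes the two input strategies explicit instead of selecting them with an inline conditional expression; the else branch simply relies on the dataset's default of reading precomputed features, which is why the PrecomputedFeatures import could be dropped. A minimal standalone illustration of the two variants, assuming Lhotse's defaults (not icefall code):

from lhotse import Fbank, FbankConfig
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import OnTheFlyFeatures

# 80-dim fbank computed on the fly from audio at iteration time...
test_otf = K2SpeechRecognitionDataset(
    input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
    return_cuts=True,
)

# ...versus the default input strategy, which loads features precomputed during data prep.
test_precomputed = K2SpeechRecognitionDataset(return_cuts=True)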
@@ -451,7 +451,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.

     Args:
@@ -460,11 +461,26 @@ def one_best_decoding(
       use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
       An FsaVec containing linear paths.
     """
-    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
-    return best_path
+    if lm_scale_list is not None:
+        ans = dict()
+        saved_am_scores = lattice.scores - lattice.lm_scores
+        for lm_scale in lm_scale_list:
+            am_scores = saved_am_scores / lm_scale
+            lattice.scores = am_scores + lattice.lm_scores
+
+            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            key = f"lm_scale_{lm_scale}"
+            ans[key] = best_path
+        return ans
+
+    return k2.shortest_path(lattice, use_double_scores=use_double_scores)


 def nbest_decoding(
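With the new argument, one_best_decoding either behaves exactly as before (returning a single FsaVec) or, when lm_scale_list is given, rescales the acoustic scores per LM scale and returns a dict keyed as lm_scale_<value>. A hypothetical call site; `lattice` is assumed to be a k2 lattice carrying both scores and lm_scores:

best_path = one_best_decoding(lattice, use_double_scores=True)  # k2.Fsa, as before

best_paths = one_best_decoding(
    lattice,
    use_double_scores=True,
    lm_scale_list=[0.5, 0.7, 0.9],
)
# best_paths["lm_scale_0.7"] is the best path found with am_scores / 0.7 + lm_scores.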
@@ -192,8 +192,16 @@ def encode_supervisions(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // subsampling_factor,
-            supervisions["num_frames"] // subsampling_factor,
+            torch.div(
+                supervisions["start_frame"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
         ),
         1,
     ).to(torch.int32)
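Replacing // with torch.div(..., rounding_mode="floor") keeps the same values for these non-negative integer tensors while making the rounding explicit; some PyTorch 1.x releases emit a deprecation warning for tensor floor division written as //. A small self-contained check with illustrative values:

import torch

start_frame = torch.tensor([0, 23, 117])
subsampling_factor = 4

a = torch.div(start_frame, subsampling_factor, rounding_mode="floor")
b = start_frame // subsampling_factor  # may warn on some PyTorch 1.x versions
assert torch.equal(a, b)  # tensor([ 0,  5, 29])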