merge upstream

Desh Raj 2022-11-16 19:50:43 -05:00
commit cad8f6aca4
6 changed files with 110 additions and 101 deletions

View File

@ -164,13 +164,6 @@ def get_parser():
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--initial-lr",
type=float,
@ -516,14 +509,6 @@ def compute_loss(
nnet_output, encoder_memory, memory_mask = model(
feature, supervisions, warmup=warmup
)
# logging.info('feature shape: {}'.format(feature.shape))
# logging.info('nnet_output shape: {}'.format(nnet_output.shape))
# logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
# logging.info('memory_mask shape: {}'.format(memory_mask.shape))
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
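The NOTE above refers to the constraint that supervision segments be sorted in decreasing duration order. A rough, self-contained illustration of that ordering with made-up values (plain PyTorch, not the recipe's actual `encode_supervisions` helper):

```python
import torch

# Made-up supervision segments: one row per utterance, holding
# (sequence_idx, start_frame, num_frames).
supervision_segments = torch.tensor(
    [[0, 0, 80], [1, 0, 120], [2, 0, 95]], dtype=torch.int32
)

# Sort rows so that num_frames (the last column) decreases, which is the
# ordering the NOTE above asks encode_supervisions to produce.
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]
print(supervision_segments)  # rows now ordered with num_frames 120, 95, 80
```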

View File

@ -401,7 +401,6 @@ class TransformerEncoderLayer(nn.Module):
dim_feedforward: int = 2048,
dropout: float = 0.1,
layer_dropout: float = 0.075,
activation: str = "relu",
) -> None:
super(TransformerEncoderLayer, self).__init__()
@ -427,11 +426,6 @@ class TransformerEncoderLayer(nn.Module):
self.dropout = nn.Dropout(dropout)
# def __setstate__(self, state):
# if "activation" not in state:
# state["activation"] = nn.functional.relu
# super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
@ -523,7 +517,6 @@ class TransformerDecoderLayer(nn.Module):
dim_feedforward: int = 2048,
dropout: float = 0.1,
layer_dropout: float = 0.075,
# activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerDecoderLayer, self).__init__()
@ -548,11 +541,6 @@ class TransformerDecoderLayer(nn.Module):
self.dropout = nn.Dropout(dropout)
# def __setstate__(self, state):
# if "activation" not in state:
# state["activation"] = nn.functional.relu
# super(TransformerDecoderLayer, self).__setstate__(state)
def forward(
self,
tgt: torch.Tensor,
@ -637,15 +625,6 @@ class TransformerDecoderLayer(nn.Module):
return tgt
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class TransformerEncoder(nn.Module):
r"""TransformerEncoder is a stack of N encoder layers
@ -690,7 +669,7 @@ class TransformerEncoder(nn.Module):
"""
output = src
for i, mod in enumerate(self.layers):
for mod in self.layers:
output = mod(
output,
src_mask=mask,
@ -751,7 +730,7 @@ class TransformerDecoder(nn.Module):
"""
output = tgt
for i, mod in enumerate(self.layers):
for mod in self.layers:
output = mod(
output,
memory,
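Both the encoder and decoder layers above take a `layer_dropout` argument (default 0.075). As a hedged illustration of the general idea, here is a minimal stochastic-depth-style block that randomly bypasses its sublayer during training; this is a common pattern and not necessarily the exact mechanism these layers use (icefall's layers may tie the skipping to warmup):

```python
import torch
import torch.nn as nn


class ResidualBlockWithLayerDropout(nn.Module):
    """Toy residual block that randomly skips its sublayer while training.

    Illustrative only: the real TransformerEncoderLayer/DecoderLayer may
    combine layer dropout with warmup-dependent scaling.
    """

    def __init__(self, sublayer: nn.Module, layer_dropout: float = 0.075) -> None:
        super().__init__()
        self.sublayer = sublayer
        self.layer_dropout = layer_dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and torch.rand(1).item() < self.layer_dropout:
            # Skip the sublayer entirely; the residual path passes x through.
            return x
        return x + self.sublayer(x)


block = ResidualBlockWithLayerDropout(nn.Linear(16, 16))
out = block(torch.randn(4, 16))  # occasionally equals the input during training
```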

View File

@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--lm",
type=str,
default="G_3_gram",
help="""Stem name for LM used in HLG compiling.
""",
)
parser.add_argument(
"--lang-dir",
type=str,
@ -50,11 +57,13 @@ def get_args():
return parser.parse_args()
def compile_HLG(lang_dir: str) -> k2.Fsa:
def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
"""
Args:
lang_dir:
The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
lm:
The stem name of the LM (e.g., G_3_gram) to be used in HLG compiling.
Return:
An FSA representing HLG.
@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
H = k2.ctc_topo(max_token_id)
L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
if Path("data/lm/G_3_gram.pt").is_file():
logging.info("Loading pre-compiled G_3_gram")
d = torch.load("data/lm/G_3_gram.pt")
if Path(f"data/lm/{lm}.pt").is_file():
logging.info(f"Loading pre-compiled {lm}")
d = torch.load(f"data/lm/{lm}.pt")
G = k2.Fsa.from_dict(d)
else:
logging.info("Loading G_3_gram.fst.txt")
with open("data/lm/G_3_gram.fst.txt") as f:
logging.info(f"Loading {lm}.fst.txt")
with open(f"data/lm/{lm}.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
torch.save(G.as_dict(), f"data/lm/{lm}.pt")
first_token_disambig_id = lexicon.token_table["#0"]
first_word_disambig_id = lexicon.word_table["#0"]
@ -144,7 +153,7 @@ def main():
logging.info(f"Processing {lang_dir}")
HLG = compile_HLG(lang_dir)
HLG = compile_HLG(lang_dir, args.lm)
logging.info(f"Saving HLG.pt to {lang_dir}")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
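The new `lm` parameter lets the same script compile HLG against a different G. A hedged sketch of calling the updated function from Python; the module name `compile_hlg` and the `G_4_gram` stem are illustrative assumptions, not part of this commit:

```python
import torch

from compile_hlg import compile_HLG  # assumed import path for this script

# Hypothetical caller.  Mirroring the loading logic above, this expects
# data/lm/G_4_gram.pt or data/lm/G_4_gram.fst.txt to have been prepared.
lang_dir = "data/lang_bpe_500"
HLG = compile_HLG(lang_dir, lm="G_4_gram")
torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
```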

View File

@ -17,10 +17,10 @@
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
@ -28,7 +28,6 @@ from lhotse.dataset import (
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
@ -159,21 +158,18 @@ class TedLiumAsrDataModule:
"were used to construct it."
),
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
@ -185,7 +181,6 @@ class TedLiumAsrDataModule:
"A value less than 1 means to disable time warp."
),
)
group.add_argument(
"--enable-musan",
type=str2bool,
@ -196,7 +191,36 @@ class TedLiumAsrDataModule:
),
)
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=10,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=0.15,
p=0.9,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to get Musan cuts")
transforms = []
if self.args.enable_musan:
@ -222,40 +246,7 @@ class TedLiumAsrDataModule:
)
] + transforms
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=0.15,
p=0.9,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
@ -273,6 +264,12 @@ class TedLiumAsrDataModule:
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
else:
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
@ -290,6 +287,11 @@ class TedLiumAsrDataModule:
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
@ -302,6 +304,7 @@ class TedLiumAsrDataModule:
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
@ -322,11 +325,13 @@ class TedLiumAsrDataModule:
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
@ -338,25 +343,32 @@ class TedLiumAsrDataModule:
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
if self.args.on_the_fly_feats:
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
test = K2SpeechRecognitionDataset(
return_cuts=self.args.return_cuts,
)
test_sampler = DynamicBucketingSampler(
cuts_test,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
sampler=test_sampler,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return test_dl
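The new `sampler_state_dict` argument makes the training dataloader resumable: the sampler can pick up where a previous run left off. A hedged sketch of the intended call pattern; the surrounding objects (`args`, `checkpoints`, a `train_cuts()` helper) and the "sampler" checkpoint key are assumptions about how train.py would use it, not something defined in this file:

```python
# Hypothetical resume flow (names and checkpoint layout assumed).
tedlium = TedLiumAsrDataModule(args)
train_cuts = tedlium.train_cuts()

sampler_state_dict = None
if checkpoints is not None and "sampler" in checkpoints:
    # Restore the sampler position saved by a previous run.
    sampler_state_dict = checkpoints["sampler"]

train_dl = tedlium.train_dataloaders(
    train_cuts, sampler_state_dict=sampler_state_dict
)

# When writing the next checkpoint, the current position can be captured
# with train_dl.sampler.state_dict() and stored under the same key.
```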

View File

@ -451,7 +451,8 @@ class Nbest(object):
def one_best_decoding(
lattice: k2.Fsa,
use_double_scores: bool = True,
) -> k2.Fsa:
lm_scale_list: Optional[List[float]] = None,
) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
"""Get the best path from a lattice.
Args:
@ -460,11 +461,26 @@ def one_best_decoding(
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
lm_scale_list:
A list of floats representing LM score scales.
Return:
An FsaVec containing linear paths as the best path. If lm_scale_list is
not None, return instead a dict that maps each key "lm_scale_{lm_scale}"
to the best path obtained with that LM scale.
"""
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
return best_path
if lm_scale_list is not None:
ans = dict()
saved_am_scores = lattice.scores - lattice.lm_scores
for lm_scale in lm_scale_list:
am_scores = saved_am_scores / lm_scale
lattice.scores = am_scores + lattice.lm_scores
best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
key = f"lm_scale_{lm_scale}"
ans[key] = best_path
return ans
return k2.shortest_path(lattice, use_double_scores=use_double_scores)
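With `lm_scale_list`, a single lattice can be rescored at several LM weights in one call. A hedged usage sketch; building the lattice itself is omitted, and it is assumed to carry both `scores` and `lm_scores`, as the code above requires:

```python
# Hypothetical caller; `lattice` is a k2.Fsa produced elsewhere in the
# decoding pipeline, with .scores and .lm_scores populated.
lm_scale_list = [0.5, 0.7, 0.9, 1.1, 1.3]

best_paths = one_best_decoding(
    lattice,
    use_double_scores=True,
    lm_scale_list=lm_scale_list,
)

for key, best_path in best_paths.items():
    # key looks like "lm_scale_0.7"; each best_path is an FsaVec of linear
    # paths whose labels can be turned into word sequences downstream.
    print(key)
```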
def nbest_decoding(

View File

@ -192,8 +192,16 @@ def encode_supervisions(
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // subsampling_factor,
supervisions["num_frames"] // subsampling_factor,
torch.div(
supervisions["start_frame"],
subsampling_factor,
rounding_mode="floor",
),
torch.div(
supervisions["num_frames"],
subsampling_factor,
rounding_mode="floor",
),
),
1,
).to(torch.int32)
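For context, replacing the tensor `//` operator with an explicit `torch.div(..., rounding_mode="floor")` keeps the same floor semantics while avoiding the floor-division deprecation warning that some PyTorch releases emit for the operator form. A small self-contained check:

```python
import torch

start_frame = torch.tensor([0, 37, 74], dtype=torch.int32)
num_frames = torch.tensor([100, 63, 26], dtype=torch.int32)
subsampling_factor = 4

# Explicit floor division, as in the updated encode_supervisions.
print(torch.div(start_frame, subsampling_factor, rounding_mode="floor"))  # 0, 9, 18
print(torch.div(num_frames, subsampling_factor, rounding_mode="floor"))   # 25, 15, 6
```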