merge upstream

Desh Raj 2022-11-16 19:50:43 -05:00
commit cad8f6aca4
6 changed files with 110 additions and 101 deletions

View File

@@ -164,13 +164,6 @@ def get_parser():
         """,
     )
-    parser.add_argument(
-        "--bpe-model",
-        type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
-    )
     parser.add_argument(
         "--initial-lr",
         type=float,
@@ -516,14 +509,6 @@ def compute_loss(
         nnet_output, encoder_memory, memory_mask = model(
             feature, supervisions, warmup=warmup
         )
-        # logging.info('feature shape: {}'.format(feature.shape))
-        # logging.info('nnet_output shape: {}'.format(nnet_output.shape))
-        # logging.info('encoder_memory shape: {}'.format(encoder_memory.shape))
-        # logging.info('memory_mask shape: {}'.format(memory_mask.shape))
-        # after the main warmup step, we keep pruned_loss_scale small
-        # for the same amount of time (model_warm_step), to avoid
-        # overwhelming the simple_loss and causing it to diverge,
-        # in case it had not fully learned the alignment yet.
     # NOTE: We need `encode_supervisions` to sort sequences with
     # different duration in decreasing order, required by

View File

@@ -401,7 +401,6 @@ class TransformerEncoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        activation: str = "relu",
     ) -> None:
         super(TransformerEncoderLayer, self).__init__()
@@ -427,11 +426,6 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout = nn.Dropout(dropout)
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerEncoderLayer, self).__setstate__(state)
     def forward(
         self,
         src: torch.Tensor,
@@ -523,7 +517,6 @@ class TransformerDecoderLayer(nn.Module):
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        # activation: str = "relu",
         normalize_before: bool = True,
     ) -> None:
         super(TransformerDecoderLayer, self).__init__()
@@ -548,11 +541,6 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout = nn.Dropout(dropout)
-    # def __setstate__(self, state):
-    #     if "activation" not in state:
-    #         state["activation"] = nn.functional.relu
-    #     super(TransformerDecoderLayer, self).__setstate__(state)
     def forward(
         self,
         tgt: torch.Tensor,
@@ -637,15 +625,6 @@ class TransformerDecoderLayer(nn.Module):
         return tgt
-def _get_activation_fn(activation: str):
-    if activation == "relu":
-        return nn.functional.relu
-    elif activation == "gelu":
-        return nn.functional.gelu
-    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
 class TransformerEncoder(nn.Module):
     r"""TransformerEncoder is a stack of N encoder layers
@@ -690,7 +669,7 @@ class TransformerEncoder(nn.Module):
         """
         output = src
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 src_mask=mask,
@@ -751,7 +730,7 @@ class TransformerDecoder(nn.Module):
         """
         output = tgt
-        for i, mod in enumerate(self.layers):
+        for mod in self.layers:
             output = mod(
                 output,
                 memory,

View File

@@ -40,6 +40,13 @@ from icefall.lexicon import Lexicon
 def get_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lm",
+        type=str,
+        default="G_3_gram",
+        help="""Stem name for LM used in HLG compiling.
+        """,
+    )
     parser.add_argument(
         "--lang-dir",
         type=str,
@@ -50,11 +57,13 @@ def get_args():
     return parser.parse_args()
-def compile_HLG(lang_dir: str) -> k2.Fsa:
+def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa:
     """
     Args:
       lang_dir:
         The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
+      lm:
+        The language stem base name.
     Return:
       An FSA representing HLG.
@@ -65,15 +74,15 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
     H = k2.ctc_topo(max_token_id)
     L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))
-    if Path("data/lm/G_3_gram.pt").is_file():
-        logging.info("Loading pre-compiled G_3_gram")
-        d = torch.load("data/lm/G_3_gram.pt")
+    if Path(f"data/lm/{lm}.pt").is_file():
+        logging.info(f"Loading pre-compiled {lm}")
+        d = torch.load(f"data/lm/{lm}.pt")
         G = k2.Fsa.from_dict(d)
     else:
-        logging.info("Loading G_3_gram.fst.txt")
-        with open("data/lm/G_3_gram.fst.txt") as f:
+        logging.info(f"Loading {lm}.fst.txt")
+        with open(f"data/lm/{lm}.fst.txt") as f:
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-            torch.save(G.as_dict(), "data/lm/G_3_gram.pt")
+            torch.save(G.as_dict(), f"data/lm/{lm}.pt")
     first_token_disambig_id = lexicon.token_table["#0"]
     first_word_disambig_id = lexicon.word_table["#0"]
@@ -144,7 +153,7 @@ def main():
     logging.info(f"Processing {lang_dir}")
-    HLG = compile_HLG(lang_dir)
+    HLG = compile_HLG(lang_dir, args.lm)
     logging.info(f"Saving HLG.pt to {lang_dir}")
     torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
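
The hunks above stop hard-coding G_3_gram and instead thread the new --lm stem through compile_HLG. A minimal sketch of the resulting load-or-compile pattern, assuming the same data/lm/ layout (the load_G helper is illustrative, not part of the diff):

    from pathlib import Path

    import k2
    import torch

    def load_G(lm: str = "G_3_gram") -> k2.Fsa:
        """Load data/lm/{lm}.pt if cached; otherwise compile it from {lm}.fst.txt."""
        pt_path = Path(f"data/lm/{lm}.pt")
        if pt_path.is_file():
            # Reuse the FSA compiled on a previous run.
            return k2.Fsa.from_dict(torch.load(pt_path))
        with open(f"data/lm/{lm}.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        torch.save(G.as_dict(), pt_path)  # cache for later runs
        return G

With the new flag, switching to another LM (say a 4-gram) only requires that the corresponding data/lm/<stem>.fst.txt or <stem>.pt file exist.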

View File

@@ -17,10 +17,10 @@
 import argparse
-import inspect
 import logging
 from functools import lru_cache
 from pathlib import Path
+from typing import Any, Dict, Optional
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (
@@ -28,7 +28,6 @@ from lhotse.dataset import (
     CutMix,
     DynamicBucketingSampler,
     K2SpeechRecognitionDataset,
-    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -159,21 +158,18 @@ class TedLiumAsrDataModule:
                 "were used to construct it."
             ),
         )
         group.add_argument(
             "--num-workers",
             type=int,
             default=2,
             help="The number of training dataloader workers that collect the batches.",
         )
         group.add_argument(
             "--enable-spec-aug",
             type=str2bool,
             default=True,
             help="When enabled, use SpecAugment for training dataset.",
         )
         group.add_argument(
             "--spec-aug-time-warp-factor",
             type=int,
@@ -185,7 +181,6 @@ class TedLiumAsrDataModule:
                 "A value less than 1 means to disable time warp."
             ),
         )
         group.add_argument(
             "--enable-musan",
             type=str2bool,
@@ -196,7 +191,36 @@ class TedLiumAsrDataModule:
             ),
         )
-    def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
+    def train_dataloaders(
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+    ) -> DataLoader:
+        """
+        Args:
+          cuts_train:
+            CutSet for training.
+          sampler_state_dict:
+            The state dict for the training sampler.
+        """
+        input_transforms = []
+        if self.args.enable_spec_aug:
+            logging.info("Enable SpecAugment")
+            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+            input_transforms.append(
+                SpecAugment(
+                    time_warp_factor=self.args.spec_aug_time_warp_factor,
+                    num_frame_masks=10,
+                    features_mask_size=27,
+                    num_feature_masks=2,
+                    frames_mask_size=100,
+                    max_frames_mask_fraction=0.15,
+                    p=0.9,
+                )
+            )
+        else:
+            logging.info("Disable SpecAugment")
         logging.info("About to get Musan cuts")
         transforms = []
         if self.args.enable_musan:
@@ -222,40 +246,7 @@ class TedLiumAsrDataModule:
                 )
             ] + transforms
-        input_transforms = []
-        if self.args.enable_spec_aug:
-            logging.info("Enable SpecAugment")
-            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
-            # Set the value of num_frame_masks according to Lhotse's version.
-            # In different Lhotse's versions, the default of num_frame_masks is
-            # different.
-            num_frame_masks = 10
-            num_frame_masks_parameter = inspect.signature(
-                SpecAugment.__init__
-            ).parameters["num_frame_masks"]
-            if num_frame_masks_parameter.default == 1:
-                num_frame_masks = 2
-            logging.info(f"Num frame mask: {num_frame_masks}")
-            input_transforms.append(
-                SpecAugment(
-                    time_warp_factor=self.args.spec_aug_time_warp_factor,
-                    num_frame_masks=num_frame_masks,
-                    features_mask_size=27,
-                    num_feature_masks=2,
-                    frames_mask_size=100,
-                    max_frames_mask_fraction=0.15,
-                    p=0.9,
-                )
-            )
-        else:
-            logging.info("Disable SpecAugment")
         logging.info("About to create train dataset")
-        train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
-            return_cuts=self.args.return_cuts,
-        )
         if self.args.on_the_fly_feats:
             # NOTE: the PerturbSpeed transform should be added only if we
             # remove it from data prep stage.
@@ -273,6 +264,12 @@ class TedLiumAsrDataModule:
                 input_transforms=input_transforms,
                 return_cuts=self.args.return_cuts,
             )
+        else:
+            train = K2SpeechRecognitionDataset(
+                cut_transforms=transforms,
+                input_transforms=input_transforms,
+                return_cuts=self.args.return_cuts,
+            )
         if self.args.bucketing_sampler:
             logging.info("Using DynamicBucketingSampler.")
@@ -290,6 +287,11 @@ class TedLiumAsrDataModule:
                 max_duration=self.args.max_duration,
                 shuffle=self.args.shuffle,
             )
+        if sampler_state_dict is not None:
+            logging.info("Loading sampler state dict")
+            train_sampler.load_state_dict(sampler_state_dict)
         logging.info("About to create train dataloader")
         train_dl = DataLoader(
             train,
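
The sampler_state_dict plumbing above lets a resumed training run restore the sampler to where the previous run stopped, instead of re-iterating the data from the beginning of the epoch. A hedged usage sketch; the datamodule instance, the checkpoint dict, and its "sampler" key are illustrative assumptions, not taken from this diff:

    # `tedlium` is assumed to be a TedLiumAsrDataModule; `checkpoints` a dict
    # loaded from disk whose optional "sampler" entry holds the sampler state.
    sampler_state_dict = None
    if checkpoints is not None and "sampler" in checkpoints:
        sampler_state_dict = checkpoints["sampler"]

    train_dl = tedlium.train_dataloaders(
        train_cuts, sampler_state_dict=sampler_state_dict
    )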
@@ -302,6 +304,7 @@ class TedLiumAsrDataModule:
         return train_dl
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -322,11 +325,13 @@ class TedLiumAsrDataModule:
             cut_transforms=transforms,
             return_cuts=self.args.return_cuts,
         )
         valid_sampler = DynamicBucketingSampler(
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
             validate,
@@ -338,25 +343,32 @@ class TedLiumAsrDataModule:
         return valid_dl
-    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+    def test_dataloaders(self, cuts_test: CutSet) -> DataLoader:
         logging.debug("About to create test dataset")
-        test = K2SpeechRecognitionDataset(
-            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
-            if self.args.on_the_fly_feats
-            else PrecomputedFeatures(),
-            return_cuts=self.args.return_cuts,
-        )
-        sampler = DynamicBucketingSampler(
-            cuts,
+        if self.args.on_the_fly_feats:
+            test = K2SpeechRecognitionDataset(
+                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
+                return_cuts=self.args.return_cuts,
+            )
+        else:
+            test = K2SpeechRecognitionDataset(
+                return_cuts=self.args.return_cuts,
+            )
+        test_sampler = DynamicBucketingSampler(
+            cuts_test,
             max_duration=self.args.max_duration,
             shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
             test,
             batch_size=None,
-            sampler=sampler,
+            sampler=test_sampler,
             num_workers=self.args.num_workers,
+            persistent_workers=False,
         )
         return test_dl

View File

@@ -451,7 +451,8 @@ class Nbest(object):
 def one_best_decoding(
     lattice: k2.Fsa,
     use_double_scores: bool = True,
-) -> k2.Fsa:
+    lm_scale_list: Optional[List[float]] = None,
+) -> Union[k2.Fsa, Dict[str, k2.Fsa]]:
     """Get the best path from a lattice.
     Args:
@@ -460,11 +461,26 @@ def one_best_decoding(
       use_double_scores:
        True to use double precision floating point in the computation.
        False to use single precision.
+      lm_scale_list:
+        A list of floats representing LM score scales.
     Return:
       An FsaVec containing linear paths.
     """
-    best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
-    return best_path
+    if lm_scale_list is not None:
+        ans = dict()
+        saved_am_scores = lattice.scores - lattice.lm_scores
+        for lm_scale in lm_scale_list:
+            am_scores = saved_am_scores / lm_scale
+            lattice.scores = am_scores + lattice.lm_scores
+            best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores)
+            key = f"lm_scale_{lm_scale}"
+            ans[key] = best_path
+        return ans
+    return k2.shortest_path(lattice, use_double_scores=use_double_scores)
 def nbest_decoding(
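
With lm_scale_list given, one_best_decoding now rescales the acoustic part of the lattice scores once per scale and runs shortest-path for each; dividing the saved AM scores by lm_scale (rather than multiplying lm_scores) leaves lattice.lm_scores untouched, so every iteration starts from the same saved AM scores. A hedged usage sketch (the scale values are illustrative):

    # `lattice` is assumed to be a k2.Fsa carrying an `lm_scores` attribute.
    best_paths = one_best_decoding(
        lattice,
        use_double_scores=True,
        lm_scale_list=[0.1, 0.3, 0.5, 0.7, 0.9],
    )
    # Returns a dict keyed like "lm_scale_0.3" -> FsaVec of best paths;
    # with lm_scale_list=None the function still returns a single FsaVec.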

View File

@@ -192,8 +192,16 @@ def encode_supervisions(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // subsampling_factor,
-            supervisions["num_frames"] // subsampling_factor,
+            torch.div(
+                supervisions["start_frame"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
+            torch.div(
+                supervisions["num_frames"],
+                subsampling_factor,
+                rounding_mode="floor",
+            ),
         ),
         1,
     ).to(torch.int32)
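
Replacing // with torch.div(..., rounding_mode="floor") keeps the same integer floor division for these non-negative frame counts while avoiding the __floordiv__ deprecation warning that some PyTorch versions emit for tensor //. A small self-contained check (the values are illustrative, not from the commit):

    import torch

    start_frame = torch.tensor([0, 37, 123], dtype=torch.int64)
    subsampling_factor = 4

    a = torch.div(start_frame, subsampling_factor, rounding_mode="floor")
    b = start_frame // subsampling_factor  # may warn on some PyTorch versions
    assert torch.equal(a, b)  # both give tensor([ 0,  9, 30])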