mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 15:14:18 +00:00
Fix train.py
This commit is contained in:
parent
915c4e9d87
commit
22e9d837f8
@ -95,16 +95,6 @@ class LibriHeavyAsrDataModule:
|
||||
""",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--with-punctuation",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to train the model on transcription with punctuation,
|
||||
False to train the model on normalized transcription
|
||||
(upper without punctuation).
|
||||
""",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
@ -413,35 +403,29 @@ class LibriHeavyAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_small_cuts(self) -> CutSet:
|
||||
logging.info("About to get small subset cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_small.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def train_medium_cuts(self) -> CutSet:
|
||||
logging.info("About to get medium subset cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_medium.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def train_large_cuts(self) -> CutSet:
|
||||
logging.info("About to get large subset cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_large.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_dev.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get the test-clean cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-clean.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz")
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get the test-other cuts")
|
||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-other.jsonl.gz"
|
||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz")
|
||||
|
@ -261,6 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -470,6 +471,13 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--train-with-punctuation",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="If True, the training text will include casing and punctuation.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -1178,9 +1186,20 @@ def run(rank, world_size, args):
|
||||
if params.inf_check:
|
||||
register_inf_check_hooks(model)
|
||||
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
train_cuts = libriheavy.train_cuts()
|
||||
def normalize_text(c: Cut):
|
||||
text = c.supervisions[0].text
|
||||
if params.train_with_punctuation:
|
||||
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||
text = text.translate(table)
|
||||
else:
|
||||
text = text.replace("‘", "'")
|
||||
text = text.replace("’", "'")
|
||||
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||
text_list = [x.upper() if x in tokens else " " for x in text]
|
||||
text = " ".join("".join(text_list).split()).strip()
|
||||
c.supervisions[0].text = text
|
||||
return c
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
@ -1204,7 +1223,7 @@ def run(rank, world_size, args):
|
||||
# In ./zipformer.py, the conv module uses the following expression
|
||||
# for subsampling
|
||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||
tokens = sp.encode(c.supervisions[0].texts[0], out_type=str)
|
||||
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||
|
||||
if T < len(tokens):
|
||||
logging.warning(
|
||||
@ -1219,6 +1238,15 @@ def run(rank, world_size, args):
|
||||
|
||||
return True
|
||||
|
||||
libriheavy = LibriHeavyAsrDataModule(args)
|
||||
|
||||
train_cuts = libriheavy.train_small_cuts()
|
||||
if params.subset == 'M' or params.subset == 'L':
|
||||
train_cuts += libriheavy.train_medium_cuts()
|
||||
if params.subset == "L":
|
||||
train_cuts += libriheavy.train_large_cuts()
|
||||
train_cuts = train_cuts.map(normalize_text)
|
||||
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
|
||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||
@ -1232,15 +1260,8 @@ def run(rank, world_size, args):
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
def add_texts(c: Cut):
|
||||
text = c.supervisions[0].text
|
||||
c.supervisions[0].texts = [text]
|
||||
return c
|
||||
|
||||
valid_cuts = libriheavy.librispeech_dev_clean_cuts()
|
||||
valid_cuts += libriheavy.librispeech_dev_other_cuts()
|
||||
|
||||
valid_cuts = valid_cuts.map(add_texts)
|
||||
valid_cuts = libriheavy.dev_cuts()
|
||||
valid_cuts = valid_cuts.map(normalize_text)
|
||||
|
||||
valid_dl = libriheavy.valid_dataloaders(valid_cuts)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user