Fix train.py

pkufool 2023-09-22 12:38:23 +08:00
parent 915c4e9d87
commit 22e9d837f8
2 changed files with 39 additions and 34 deletions

View File

@@ -95,16 +95,6 @@ class LibriHeavyAsrDataModule:
             """,
         )
-        group.add_argument(
-            "--with-punctuation",
-            type=str2bool,
-            default=False,
-            help="""True to train the model on transcription with punctuation,
-            False to train the model on normalized transcription
-            (uppercase, without punctuation).
-            """,
-        )
         group.add_argument(
             "--manifest-dir",
             type=Path,
@@ -413,35 +403,29 @@ class LibriHeavyAsrDataModule:
     @lru_cache()
     def train_small_cuts(self) -> CutSet:
         logging.info("About to get small subset cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_small.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz")

     @lru_cache()
     def train_medium_cuts(self) -> CutSet:
         logging.info("About to get medium subset cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_medium.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz")

     @lru_cache()
     def train_large_cuts(self) -> CutSet:
         logging.info("About to get large subset cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_large.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz")

     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_dev.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz")

     @lru_cache()
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get the test-clean cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-clean.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz")

     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get the test-other cuts")
-        basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-other.jsonl.gz"
-        return load_manifest_lazy(self.args.manifest_dir / basename)
+        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz")
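
Each accessor above is wrapped in functools.lru_cache, so a manifest is opened at most once per process and repeated calls return the same lazily-loaded CutSet. A minimal sketch of the pattern, with a stub standing in for lhotse's load_manifest_lazy and an illustrative data/fbank path:

    from functools import lru_cache
    from pathlib import Path

    def load_manifest_lazy(path: Path) -> object:
        # Stub standing in for lhotse.load_manifest_lazy; it only records the call.
        print(f"opening {path}")
        return object()  # pretend this is a lazy CutSet

    class DataModule:
        manifest_dir = Path("data/fbank")  # illustrative default

        @lru_cache()
        def train_small_cuts(self) -> object:
            return load_manifest_lazy(self.manifest_dir / "libriheavy_cuts_small.jsonl.gz")

    dm = DataModule()
    assert dm.train_small_cuts() is dm.train_small_cuts()  # "opening ..." prints only once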

View File

@@ -261,6 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )


 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -470,6 +471,13 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--train-with-punctuation",
+        type=str2bool,
+        default=False,
+        help="If True, the training text will include casing and punctuation.",
+    )
+
     add_model_arguments(parser)

     return parser
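
The new flag uses icefall's str2bool helper, so values like "1", "true", and "True" all parse to a boolean. A self-contained sketch with a simplified stand-in for str2bool (the real helper lives in icefall.utils):

    import argparse

    def str2bool(v: str) -> bool:
        # Simplified stand-in for icefall.utils.str2bool.
        if v.lower() in ("yes", "true", "t", "1"):
            return True
        if v.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError(f"boolean value expected, got {v!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--train-with-punctuation", type=str2bool, default=False)

    print(parser.parse_args([]).train_with_punctuation)                                   # False
    print(parser.parse_args(["--train-with-punctuation", "1"]).train_with_punctuation)    # True
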
@@ -1178,9 +1186,20 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    libriheavy = LibriHeavyAsrDataModule(args)
-
-    train_cuts = libriheavy.train_cuts()
+    def normalize_text(c: Cut):
+        text = c.supervisions[0].text
+        if params.train_with_punctuation:
+            table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
+            text = text.translate(table)
+        else:
+            text = text.replace("‘", "'")
+            text = text.replace("’", "'")
+            tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
+            text_list = [x.upper() if x in tokens else " " for x in text]
+            text = " ".join("".join(text_list).split()).strip()
+        c.supervisions[0].text = text
+        return c

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
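
To see what normalize_text does in the default (no punctuation) branch, the same steps can be run on a plain string; the sample sentence here is arbitrary:

    text = "Hello, world! It’s 9 o’clock."

    # Mirror the else-branch of normalize_text above.
    text = text.replace("‘", "'").replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    text_list = [x.upper() if x in tokens else " " for x in text]
    text = " ".join("".join(text_list).split()).strip()
    print(text)  # HELLO WORLD IT'S 9 O'CLOCK
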
@@ -1204,7 +1223,7 @@ def run(rank, world_size, args):
         # In ./zipformer.py, the conv module uses the following expression
         # for subsampling
         T = ((c.num_frames - 7) // 2 + 1) // 2
-        tokens = sp.encode(c.supervisions[0].texts[0], out_type=str)
+        tokens = sp.encode(c.supervisions[0].text, out_type=str)
         if T < len(tokens):
             logging.warning(
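
The expression mirrors two stride-2 convolutions in the subsampling frontend. For a concrete number, assume a 10 ms frame shift so a 10-second cut has 1000 frames:

    num_frames = 1000  # 10 s at an assumed 10 ms frame shift
    T = ((num_frames - 7) // 2 + 1) // 2
    print(T)  # 248 frames after subsampling, roughly a 4x reduction
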
@@ -1219,6 +1238,15 @@ def run(rank, world_size, args):
         return True

+    libriheavy = LibriHeavyAsrDataModule(args)
+
+    train_cuts = libriheavy.train_small_cuts()
+    if params.subset == "M" or params.subset == "L":
+        train_cuts += libriheavy.train_medium_cuts()
+    if params.subset == "L":
+        train_cuts += libriheavy.train_large_cuts()
+
+    train_cuts = train_cuts.map(normalize_text)
     train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
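
map and filter on a lhotse CutSet are lazy, which is why even the L subset can be concatenated, normalized, and filtered here without materializing anything up front. A toy generator-based analogue of the pattern (not lhotse itself):

    def lazy_map(fn, items):
        return (fn(x) for x in items)

    def lazy_filter(pred, items):
        return (x for x in items if pred(x))

    cuts = iter(range(10))                      # stand-in for a lazily loaded CutSet
    cuts = lazy_map(lambda x: x * 2, cuts)      # like train_cuts.map(normalize_text)
    cuts = lazy_filter(lambda x: x < 10, cuts)  # like .filter(remove_short_and_long_utt)
    print(list(cuts))  # [0, 2, 4, 6, 8] -- nothing runs until iteration
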
@@ -1232,15 +1260,8 @@ def run(rank, world_size, args):
         train_cuts, sampler_state_dict=sampler_state_dict
     )

-    def add_texts(c: Cut):
-        text = c.supervisions[0].text
-        c.supervisions[0].texts = [text]
-        return c
-
-    valid_cuts = libriheavy.librispeech_dev_clean_cuts()
-    valid_cuts += libriheavy.librispeech_dev_other_cuts()
-    valid_cuts = valid_cuts.map(add_texts)
+    valid_cuts = libriheavy.dev_cuts()
+    valid_cuts = valid_cuts.map(normalize_text)

     valid_dl = libriheavy.valid_dataloaders(valid_cuts)