mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 23:24:17 +00:00
Fix train.py
This commit is contained in:
parent
915c4e9d87
commit
22e9d837f8
@ -95,16 +95,6 @@ class LibriHeavyAsrDataModule:
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
group.add_argument(
|
|
||||||
"--with-punctuation",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""True to train the model on transcription with punctuation,
|
|
||||||
False to train the model on normalized transcription
|
|
||||||
(upper without punctuation).
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
group.add_argument(
|
group.add_argument(
|
||||||
"--manifest-dir",
|
"--manifest-dir",
|
||||||
type=Path,
|
type=Path,
|
||||||
@ -413,35 +403,29 @@ class LibriHeavyAsrDataModule:
|
|||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_small_cuts(self) -> CutSet:
|
def train_small_cuts(self) -> CutSet:
|
||||||
logging.info("About to get small subset cuts")
|
logging.info("About to get small subset cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_small.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_medium_cuts(self) -> CutSet:
|
def train_medium_cuts(self) -> CutSet:
|
||||||
logging.info("About to get medium subset cuts")
|
logging.info("About to get medium subset cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_medium.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def train_large_cuts(self) -> CutSet:
|
def train_large_cuts(self) -> CutSet:
|
||||||
logging.info("About to get large subset cuts")
|
logging.info("About to get large subset cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_large.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def dev_cuts(self) -> CutSet:
|
def dev_cuts(self) -> CutSet:
|
||||||
logging.info("About to get dev cuts")
|
logging.info("About to get dev cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_dev.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_clean_cuts(self) -> CutSet:
|
def test_clean_cuts(self) -> CutSet:
|
||||||
logging.info("About to get the test-clean cuts")
|
logging.info("About to get the test-clean cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-clean.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def test_other_cuts(self) -> CutSet:
|
def test_other_cuts(self) -> CutSet:
|
||||||
logging.info("About to get the test-other cuts")
|
logging.info("About to get the test-other cuts")
|
||||||
basename = f"libriheavy_{'punc_' if self.args.with_punctuation else ''}cuts_test-other.jsonl.gz"
|
return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz")
|
||||||
return load_manifest_lazy(self.args.manifest_dir / basename)
|
|
||||||
|
@ -261,6 +261,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -470,6 +471,13 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--train-with-punctuation",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="If True, the training text will include casing and punctuation.",
|
||||||
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@ -1178,9 +1186,20 @@ def run(rank, world_size, args):
|
|||||||
if params.inf_check:
|
if params.inf_check:
|
||||||
register_inf_check_hooks(model)
|
register_inf_check_hooks(model)
|
||||||
|
|
||||||
libriheavy = LibriHeavyAsrDataModule(args)
|
|
||||||
|
|
||||||
train_cuts = libriheavy.train_cuts()
|
def normalize_text(c: Cut):
|
||||||
|
text = c.supervisions[0].text
|
||||||
|
if params.train_with_punctuation:
|
||||||
|
table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
|
||||||
|
text = text.translate(table)
|
||||||
|
else:
|
||||||
|
text = text.replace("‘", "'")
|
||||||
|
text = text.replace("’", "'")
|
||||||
|
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
|
||||||
|
text_list = [x.upper() if x in tokens else " " for x in text]
|
||||||
|
text = " ".join("".join(text_list).split()).strip()
|
||||||
|
c.supervisions[0].text = text
|
||||||
|
return c
|
||||||
|
|
||||||
def remove_short_and_long_utt(c: Cut):
|
def remove_short_and_long_utt(c: Cut):
|
||||||
# Keep only utterances with duration between 1 second and 20 seconds
|
# Keep only utterances with duration between 1 second and 20 seconds
|
||||||
@ -1204,7 +1223,7 @@ def run(rank, world_size, args):
|
|||||||
# In ./zipformer.py, the conv module uses the following expression
|
# In ./zipformer.py, the conv module uses the following expression
|
||||||
# for subsampling
|
# for subsampling
|
||||||
T = ((c.num_frames - 7) // 2 + 1) // 2
|
T = ((c.num_frames - 7) // 2 + 1) // 2
|
||||||
tokens = sp.encode(c.supervisions[0].texts[0], out_type=str)
|
tokens = sp.encode(c.supervisions[0].text, out_type=str)
|
||||||
|
|
||||||
if T < len(tokens):
|
if T < len(tokens):
|
||||||
logging.warning(
|
logging.warning(
|
||||||
@ -1219,6 +1238,15 @@ def run(rank, world_size, args):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
libriheavy = LibriHeavyAsrDataModule(args)
|
||||||
|
|
||||||
|
train_cuts = libriheavy.train_small_cuts()
|
||||||
|
if params.subset == 'M' or params.subset == 'L':
|
||||||
|
train_cuts += libriheavy.train_medium_cuts()
|
||||||
|
if params.subset == "L":
|
||||||
|
train_cuts += libriheavy.train_large_cuts()
|
||||||
|
train_cuts = train_cuts.map(normalize_text)
|
||||||
|
|
||||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||||
|
|
||||||
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
|
||||||
@ -1232,15 +1260,8 @@ def run(rank, world_size, args):
|
|||||||
train_cuts, sampler_state_dict=sampler_state_dict
|
train_cuts, sampler_state_dict=sampler_state_dict
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_texts(c: Cut):
|
valid_cuts = libriheavy.dev_cuts()
|
||||||
text = c.supervisions[0].text
|
valid_cuts = valid_cuts.map(normalize_text)
|
||||||
c.supervisions[0].texts = [text]
|
|
||||||
return c
|
|
||||||
|
|
||||||
valid_cuts = libriheavy.librispeech_dev_clean_cuts()
|
|
||||||
valid_cuts += libriheavy.librispeech_dev_other_cuts()
|
|
||||||
|
|
||||||
valid_cuts = valid_cuts.map(add_texts)
|
|
||||||
|
|
||||||
valid_dl = libriheavy.valid_dataloaders(valid_cuts)
|
valid_dl = libriheavy.valid_dataloaders(valid_cuts)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user