Mirror of https://github.com/k2-fsa/icefall.git

commit 2f0d3d7ae3
parent c195a12a36

    black
@@ -19,6 +19,7 @@ import argparse
 import codecs
 import sys
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -52,6 +53,6 @@ def main():
         print(remove_punc_to_upper(line))
         line = f.readline()
 
 
 if __name__ == "__main__":
     main()
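For orientation, the context lines in the hunk above come from a small read-and-print loop. Below is a self-contained sketch of that loop; the real remove_punc_to_upper is defined elsewhere in the repo and is not part of this diff, so a clearly labelled stub stands in for it.

    import sys

    def remove_punc_to_upper(line: str) -> str:
        # Stand-in only: the real implementation lives elsewhere in icefall.
        return line.strip().upper()

    f = sys.stdin
    line = f.readline()
    while line:
        print(remove_punc_to_upper(line))
        line = f.readline()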
@@ -20,11 +20,13 @@ import json
 import sys
 from pathlib import Path
 
+
 def simple_cleanup(text: str) -> str:
     table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
     text = text.translate(table)
     return text.strip()
 
+
 # Assign text of the supervisions and remove unnecessary entries.
 def main():
     assert len(sys.argv) == 3, "Usage: ./local/prepare_manifest.py INPUT OUTPUT_DIR"
@@ -33,7 +35,9 @@ def main():
     with gzip.open(sys.argv[1], "r") as fin, gzip.open(oname, "w") as fout:
         for line in fin:
             cut = json.loads(line)
-            cut["supervisions"][0]["text"] = simple_cleanup(cut["supervisions"][0]["custom"]["texts"][0])
+            cut["supervisions"][0]["text"] = simple_cleanup(
+                cut["supervisions"][0]["custom"]["texts"][0]
+            )
             del cut["supervisions"][0]["custom"]
             del cut["custom"]
             fout.write((json.dumps(cut) + "\n").encode())
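The simple_cleanup helper shown above is self-contained, so its effect can be checked directly; the sample input below is made up for illustration.

    def simple_cleanup(text: str) -> str:
        # Copied from the hunk above: map full-width/curly punctuation to ASCII.
        table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]")
        text = text.translate(table)
        return text.strip()

    print(simple_cleanup("“Hello”,world!"))  # -> "Hello",world!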
@@ -44,8 +44,8 @@ def get_args():
 
     parser.add_argument(
         "--byte-fallback",
-        action='store_true',
-        help="""Whether to enable byte_fallback when training bpe."""
+        action="store_true",
+        help="""Whether to enable byte_fallback when training bpe.""",
     )
 
     parser.add_argument(
@@ -56,15 +56,11 @@ def get_args():
     )
 
     parser.add_argument(
-        "--transcript",
-        type=str,
-        help="Training transcript.",
+        "--transcript", type=str, help="Training transcript.",
     )
 
     parser.add_argument(
-        "--vocab-size",
-        type=int,
-        help="Vocabulary size for BPE training",
+        "--vocab-size", type=int, help="Vocabulary size for BPE training",
     )
 
     return parser.parse_args()
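For context, flags such as --transcript, --vocab-size and --byte-fallback are typically forwarded to SentencePiece BPE training. The call below is only an illustrative sketch; the actual training call in this script is not visible in the diff, and the model_prefix value is made up.

    import sentencepiece as spm

    def train_bpe(transcript: str, vocab_size: int, byte_fallback: bool) -> None:
        # Hypothetical forwarding of the parsed arguments to SentencePiece.
        spm.SentencePieceTrainer.train(
            input=transcript,
            vocab_size=vocab_size,
            model_type="bpe",
            byte_fallback=byte_fallback,  # back off to UTF-8 bytes for unseen characters
            model_prefix="bpe",  # made-up output prefix
        )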
@@ -215,9 +215,7 @@ class LibriHeavyAsrDataModule:
         )
 
     def train_dataloaders(
-        self,
-        cuts_train: CutSet,
-        sampler_state_dict: Optional[Dict[str, Any]] = None,
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
     ) -> DataLoader:
         """
         Args:
@@ -359,13 +357,10 @@ class LibriHeavyAsrDataModule:
             )
         else:
             validate = K2SpeechRecognitionDataset(
-                cut_transforms=transforms,
-                return_cuts=self.args.return_cuts,
+                cut_transforms=transforms, return_cuts=self.args.return_cuts,
             )
         valid_sampler = DynamicBucketingSampler(
-            cuts_valid,
-            max_duration=self.args.max_duration,
-            shuffle=False,
+            cuts_valid, max_duration=self.args.max_duration, shuffle=False,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
@@ -387,45 +382,52 @@ class LibriHeavyAsrDataModule:
             return_cuts=self.args.return_cuts,
         )
         sampler = DynamicBucketingSampler(
-            cuts,
-            max_duration=self.args.max_duration,
-            shuffle=False,
+            cuts, max_duration=self.args.max_duration, shuffle=False,
         )
         logging.debug("About to create test dataloader")
         test_dl = DataLoader(
-            test,
-            batch_size=None,
-            sampler=sampler,
-            num_workers=self.args.num_workers,
+            test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
         )
         return test_dl
 
     @lru_cache()
     def train_small_cuts(self) -> CutSet:
         logging.info("About to get small subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_small.jsonl.gz"
+        )
 
     @lru_cache()
     def train_medium_cuts(self) -> CutSet:
         logging.info("About to get medium subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_medium.jsonl.gz"
+        )
 
     @lru_cache()
     def train_large_cuts(self) -> CutSet:
         logging.info("About to get large subset cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_large.jsonl.gz"
+        )
 
     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_dev.jsonl.gz"
+        )
 
     @lru_cache()
     def test_clean_cuts(self) -> CutSet:
         logging.info("About to get the test-clean cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_test_clean.jsonl.gz"
+        )
 
     @lru_cache()
     def test_other_cuts(self) -> CutSet:
         logging.info("About to get the test-other cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "libriheavy_cuts_test_other.jsonl.gz"
+        )
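The *_cuts accessors above all follow the same pattern: @lru_cache() memoizes the method, so the manifest is read from disk only on the first call. A self-contained illustration, with a stand-in loader in place of lhotse's load_manifest_lazy:

    from functools import lru_cache

    class DataModuleSketch:
        @lru_cache()
        def dev_cuts(self):
            print("loading manifest ...")  # runs only on the first call
            return ["cut-1", "cut-2"]  # stand-in for a lhotse CutSet

    dm = DataModuleSketch()
    dm.dev_cuts()  # prints "loading manifest ..."
    dm.dev_cuts()  # second call is served from the cache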
@@ -255,24 +255,17 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )
 
     parser.add_argument(
-        "--use-ctc",
-        type=str2bool,
-        default=False,
-        help="If True, use CTC head.",
+        "--use-ctc", type=str2bool, default=False, help="If True, use CTC head.",
     )
 
 
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
     parser.add_argument(
-        "--world-size",
-        type=int,
-        default=1,
-        help="Number of GPUs for DDP training.",
+        "--world-size", type=int, default=1, help="Number of GPUs for DDP training.",
     )
 
     parser.add_argument(
@@ -290,10 +283,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--num-epochs",
-        type=int,
-        default=30,
-        help="Number of epochs to train.",
+        "--num-epochs", type=int, default=30, help="Number of epochs to train.",
     )
 
     parser.add_argument(
@@ -401,10 +391,7 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--ctc-loss-scale",
-        type=float,
-        default=0.2,
-        help="Scale for CTC loss.",
+        "--ctc-loss-scale", type=float, default=0.2, help="Scale for CTC loss.",
     )
 
     parser.add_argument(
@@ -615,11 +602,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 
 
 def get_model(params: AttributeDict) -> nn.Module:
-    assert (
-        params.use_transducer or params.use_ctc
-    ), (f"At least one of them should be True, "
-    f"but got params.use_transducer={params.use_transducer}, "
-    f"params.use_ctc={params.use_ctc}")
+    assert params.use_transducer or params.use_ctc, (
+        f"At least one of them should be True, "
+        f"but got params.use_transducer={params.use_transducer}, "
+        f"params.use_ctc={params.use_ctc}"
+    )
 
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
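Both layouts of the assert rely on implicit concatenation of adjacent f-strings inside the parentheses, so the message remains a single string and the reformatting is behaviour-preserving. A small check with made-up values standing in for the params fields:

    use_transducer, use_ctc = False, False
    try:
        assert use_transducer or use_ctc, (
            f"At least one of them should be True, "
            f"but got params.use_transducer={use_transducer}, "
            f"params.use_ctc={use_ctc}"
        )
    except AssertionError as e:
        print(e)  # one concatenated message, not a tuple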
@@ -797,12 +784,12 @@ def compute_loss(
 
     batch_idx_train = params.batch_idx_train
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
 
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y)
 
     with torch.set_grad_enabled(is_training):
         simple_loss, pruned_loss, ctc_loss = model(
             x=feature,
@@ -820,17 +807,16 @@ def compute_loss(
             # take down the scale on the simple loss from 1.0 at the start
             # to params.simple_loss scale by warm_step.
             simple_loss_scale = (
-                s if batch_idx_train >= warm_step
+                s
+                if batch_idx_train >= warm_step
                 else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
             )
             pruned_loss_scale = (
-                1.0 if batch_idx_train >= warm_step
+                1.0
+                if batch_idx_train >= warm_step
                 else 0.1 + 0.9 * (batch_idx_train / warm_step)
             )
-            loss += (
-                simple_loss_scale * simple_loss
-                + pruned_loss_scale * pruned_loss
-            )
+            loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
         if params.use_ctc:
             loss += params.ctc_loss_scale * ctc_loss
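The two ternaries above implement a linear warm-up of the loss weights. A worked example using the same formulas; warm_step=2000 and s=0.5 are made-up values for illustration:

    warm_step = 2000
    s = 0.5  # stands in for params.simple_loss_scale

    for batch_idx_train in (0, 1000, 2000, 5000):
        simple_loss_scale = (
            s
            if batch_idx_train >= warm_step
            else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
        )
        pruned_loss_scale = (
            1.0
            if batch_idx_train >= warm_step
            else 0.1 + 0.9 * (batch_idx_train / warm_step)
        )
        print(batch_idx_train, simple_loss_scale, pruned_loss_scale)
    # (0, 1.0, 0.1) ramps linearly to (2000, 0.5, 1.0), then stays flat.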
@@ -867,11 +853,7 @@ def compute_validation_loss(
 
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=False,
+            params=params, model=model, sp=sp, batch=batch, is_training=False,
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -961,11 +943,7 @@ def train_one_epoch(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
+                    params=params, model=model, sp=sp, batch=batch, is_training=True,
                 )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -975,7 +953,9 @@ def train_one_epoch(
             scaler.scale(loss).backward()
             scheduler.step_batch(params.batch_idx_train)
             # Use the number of hours of speech to adjust the learning rate
-            scheduler.step_epoch(params.batch_idx_train * params.max_duration * params.world_size / 3600)
+            scheduler.step_epoch(
+                params.batch_idx_train * params.max_duration * params.world_size / 3600
+            )
 
             scaler.step(optimizer)
             scaler.update()
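The reformatted step_epoch call passes the cumulative hours of speech seen so far: batches times the maximum batch duration in seconds times the number of GPUs, divided by 3600. With made-up numbers:

    batch_idx_train = 10_000  # made-up batch count
    max_duration = 600        # seconds of audio per batch per GPU, made-up
    world_size = 4            # number of GPUs, made-up
    hours_seen = batch_idx_train * max_duration * world_size / 3600
    print(hours_seen)  # about 6666.7 hours of speech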
@@ -994,9 +974,7 @@ def train_one_epoch(
                 and params.batch_idx_train % params.average_period == 0
             ):
                 update_averaged_model(
-                    params=params,
-                    model_cur=model,
-                    model_avg=model_avg,
+                    params=params, model_cur=model, model_avg=model_avg,
                 )
 
             if (
@@ -1016,9 +994,7 @@ def train_one_epoch(
                     rank=rank,
                 )
                 remove_checkpoints(
-                    out_dir=params.exp_dir,
-                    topk=params.keep_last_k,
-                    rank=rank,
+                    out_dir=params.exp_dir, topk=params.keep_last_k, rank=rank,
                 )
 
         if batch_idx % 100 == 0 and params.use_fp16:
@@ -1180,14 +1156,13 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
     if params.inf_check:
         register_inf_check_hooks(model)
 
-
     def normalize_text(c: Cut):
         text = remove_punc_to_upper(c.supervisions[0].text)
         c.supervisions[0].text = text
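The only code change in this hunk is whitespace around the power operator; the value itself matches the inline comment, since 2 ** 22 bytes is 4 MiB:

    print(2 ** 22)  # 4194304
    print(2 ** 22 == 4 * 1024 * 1024)  # True: 4 MiB, as the comment says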
@@ -1233,7 +1208,7 @@ def run(rank, world_size, args):
     libriheavy = LibriHeavyAsrDataModule(args)
 
     train_cuts = libriheavy.train_small_cuts()
-    if params.subset == 'M' or params.subset == 'L':
+    if params.subset == "M" or params.subset == "L":
         train_cuts += libriheavy.train_medium_cuts()
     if params.subset == "L":
         train_cuts += libriheavy.train_large_cuts()
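The quote change above does not alter the subset logic, which is cumulative: "M" trains on small plus medium, "L" on small plus medium plus large, and any other value uses only the small cuts. A stand-in sketch with lists in place of CutSets:

    def select_train_cuts(subset, small, medium, large):
        train_cuts = list(small)
        if subset == "M" or subset == "L":
            train_cuts += medium
        if subset == "L":
            train_cuts += large
        return train_cuts

    print(select_train_cuts("S", ["s"], ["m"], ["l"]))  # ['s']
    print(select_train_cuts("L", ["s"], ["m"], ["l"]))  # ['s', 'm', 'l']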
@@ -1322,9 +1297,7 @@ def run(rank, world_size, args):
 
 
 def display_and_save_batch(
-    batch: dict,
-    params: AttributeDict,
-    sp: spm.SentencePieceProcessor,
+    batch: dict, params: AttributeDict, sp: spm.SentencePieceProcessor,
 ) -> None:
     """Display the batch statistics and save the batch into disk.
 
@@ -1371,11 +1344,7 @@ def scan_pessimistic_batches_for_oom(
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    sp=sp,
-                    batch=batch,
-                    is_training=True,
+                    params=params, model=model, sp=sp, batch=batch, is_training=True,
                 )
             loss.backward()
             optimizer.zero_grad()