style check

luomingshuang committed on 2022-05-06 18:52:24 +08:00
parent cdb6591014
commit c3b2a522b0
5 changed files with 10 additions and 12 deletions

@@ -91,7 +91,7 @@ def get_parser():
 def compute_fbank_wenetspeech_splits(args):
     subset = args.training_subset
-    subset = str(subset)
+    subset = str(subset)
     num_splits = args.num_splits
     output_dir = f"data/fbank/{subset}_split_{num_splits}"
     output_dir = Path(output_dir)
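For reference, a minimal sketch of the naming scheme used above, with hypothetical values standing in for the parsed arguments:

from pathlib import Path

subset = "L"        # hypothetical training-subset value
num_splits = 1000   # hypothetical num-splits value

output_dir = Path(f"data/fbank/{subset}_split_{num_splits}")
print(output_dir)   # data/fbank/L_split_1000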

@@ -31,12 +31,11 @@ from lhotse import load_manifest
 def main():
     paths = [
-        #"./data/fbank/cuts_S.jsonl.gz",
-        #"./data/fbank/cuts_M.jsonl.gz",
         "./data/fbank/cuts_L.jsonl.gz",
-        #"./data/fbank/cuts_DEV.jsonl.gz",
-        #"./data/fbank/cuts_TEST_NET.jsonl.gz",
-        #"./data/fbank/cuts_TEST_MEETING.jsonl.gz"
+        "./data/fbank/cuts_S.jsonl.gz",
+        "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_DEV.jsonl.gz",
+        "./data/fbank/cuts_TEST_NET.jsonl.gz",
+        "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
     ]

     for path in paths:
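The loop body falls outside this hunk; as a sketch, lhotse's load_manifest plus CutSet.describe() is the usual way such a statistics script reports each manifest (the single path below is picked arbitrarily from the list):

from lhotse import load_manifest

# One manifest from the list above; describe() prints cut counts
# and duration statistics for the set.
cuts = load_manifest("./data/fbank/cuts_DEV.jsonl.gz")
cuts.describe()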

@@ -338,7 +338,7 @@ class WenetSpeechAsrDataModule:
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
             train_dl.sampler.load_state_dict(sampler_state_dict)

         return train_dl

     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
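Restoring the sampler state lets a resumed run continue from the same position in the training data. A hedged sketch of the save/restore pairing this relies on; lhotse samplers expose state_dict()/load_state_dict() like torch modules do (the helper names and checkpoint layout below are hypothetical):

import torch

def save_sampler_state(train_dl, path: str) -> None:
    # lhotse samplers implement state_dict(), mirroring torch modules.
    torch.save({"sampler": train_dl.sampler.state_dict()}, path)

def restore_sampler_state(train_dl, path: str) -> None:
    # Restore before iterating so the epoch resumes where it stopped.
    sampler_state_dict = torch.load(path)["sampler"]
    train_dl.sampler.load_state_dict(sampler_state_dict)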

@@ -503,8 +503,7 @@ def modified_beam_search(
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

-            #topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-            topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
             topk_token_indexes = (topk_indexes % vocab_size).tolist()

             for k in range(len(topk_hyp_indexes)):
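The replaced lines recover hypothesis and token indexes from a top-k taken over a flattened (num_hyps x vocab_size) score matrix. A self-contained sketch with hypothetical sizes; torch.div with rounding_mode="trunc" matches // here because the indexes are non-negative:

import torch

beam, vocab_size = 4, 500  # hypothetical sizes

# Flat scores over beam * vocab_size expansion candidates.
log_probs = torch.randn(beam * vocab_size)
topk_log_probs, topk_indexes = log_probs.topk(beam)

# Each flat index encodes hyp_index * vocab_size + token_index.
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()

# Equivalent form that avoids the floor-division deprecation warning
# some PyTorch versions emit for // on tensors.
trunc = torch.div(topk_indexes, vocab_size, rounding_mode="trunc").tolist()
assert trunc == topk_hyp_indexes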

@@ -900,7 +900,7 @@ def run(rank, world_size, args):
     train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )

     if not params.print_diagnostics and params.start_batch == 0:
         scan_pessimistic_batches_for_oom(
             model=model,
@@ -909,7 +909,7 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             params=params,
         )

     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")