mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
style check
This commit is contained in:
parent
cdb6591014
commit
c3b2a522b0
@ -91,7 +91,7 @@ def get_parser():
|
||||
|
||||
def compute_fbank_wenetspeech_splits(args):
|
||||
subset = args.training_subset
|
||||
subset = str(subset)
|
||||
subset = str(subset)
|
||||
num_splits = args.num_splits
|
||||
output_dir = f"data/fbank/{subset}_split_{num_splits}"
|
||||
output_dir = Path(output_dir)
|
||||
|
@ -31,12 +31,11 @@ from lhotse import load_manifest
|
||||
|
||||
def main():
|
||||
paths = [
|
||||
#"./data/fbank/cuts_S.jsonl.gz",
|
||||
#"./data/fbank/cuts_M.jsonl.gz",
|
||||
"./data/fbank/cuts_L.jsonl.gz",
|
||||
#"./data/fbank/cuts_DEV.jsonl.gz",
|
||||
#"./data/fbank/cuts_TEST_NET.jsonl.gz",
|
||||
#"./data/fbank/cuts_TEST_MEETING.jsonl.gz"
|
||||
"./data/fbank/cuts_S.jsonl.gz",
|
||||
"./data/fbank/cuts_M.jsonl.gz",
|
||||
"./data/fbank/cuts_DEV.jsonl.gz",
|
||||
"./data/fbank/cuts_TEST_NET.jsonl.gz",
|
||||
"./data/fbank/cuts_TEST_MEETING.jsonl.gz",
|
||||
]
|
||||
|
||||
for path in paths:
|
||||
|
@ -338,7 +338,7 @@ class WenetSpeechAsrDataModule:
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_dl.sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
|
@ -503,8 +503,7 @@ def modified_beam_search(
|
||||
for i in range(batch_size):
|
||||
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
|
||||
|
||||
#topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
|
||||
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
|
||||
topk_token_indexes = (topk_indexes % vocab_size).tolist()
|
||||
|
||||
for k in range(len(topk_hyp_indexes)):
|
||||
|
@ -900,7 +900,7 @@ def run(rank, world_size, args):
|
||||
train_dl = wenetspeech.train_dataloaders(
|
||||
train_cuts, sampler_state_dict=sampler_state_dict
|
||||
)
|
||||
|
||||
|
||||
if not params.print_diagnostics and params.start_batch == 0:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
@ -909,7 +909,7 @@ def run(rank, world_size, args):
|
||||
graph_compiler=graph_compiler,
|
||||
params=params,
|
||||
)
|
||||
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
|
Loading…
x
Reference in New Issue
Block a user