style check

parent cdb6591014
commit c3b2a522b0

mirror of https://github.com/k2-fsa/icefall.git
```diff
@@ -91,7 +91,7 @@ def get_parser():
 def compute_fbank_wenetspeech_splits(args):
     subset = args.training_subset
     subset = str(subset)
     num_splits = args.num_splits
     output_dir = f"data/fbank/{subset}_split_{num_splits}"
     output_dir = Path(output_dir)
 
```
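Only the preamble of `compute_fbank_wenetspeech_splits` is visible in this hunk. For readers following along, a minimal sketch of the directory-resolution pattern it uses (the function wrapper and the `mkdir` call are illustrative additions; `training_subset` and `num_splits` come from the hunk):

```python
from pathlib import Path

def resolve_split_output_dir(training_subset, num_splits: int) -> Path:
    # Same pattern as the hunk: data/fbank/<subset>_split_<num_splits>
    subset = str(training_subset)  # argparse may hand over a non-str value
    output_dir = Path(f"data/fbank/{subset}_split_{num_splits}")
    output_dir.mkdir(parents=True, exist_ok=True)  # illustrative, not in the diff
    return output_dir

print(resolve_split_output_dir("L", 1000))  # data/fbank/L_split_1000
```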
```diff
@@ -31,12 +31,11 @@ from lhotse import load_manifest
 
 def main():
     paths = [
-        #"./data/fbank/cuts_S.jsonl.gz",
-        #"./data/fbank/cuts_M.jsonl.gz",
-        "./data/fbank/cuts_L.jsonl.gz",
-        #"./data/fbank/cuts_DEV.jsonl.gz",
-        #"./data/fbank/cuts_TEST_NET.jsonl.gz",
-        #"./data/fbank/cuts_TEST_MEETING.jsonl.gz"
+        "./data/fbank/cuts_S.jsonl.gz",
+        "./data/fbank/cuts_M.jsonl.gz",
+        "./data/fbank/cuts_DEV.jsonl.gz",
+        "./data/fbank/cuts_TEST_NET.jsonl.gz",
+        "./data/fbank/cuts_TEST_MEETING.jsonl.gz",
     ]
 
     for path in paths:
```
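The cleaned-up list feeds a loop that prints corpus statistics for each manifest. A minimal sketch of what that loop typically does, assuming lhotse's standard `load_manifest`/`describe` API (the loop body itself is not shown in the hunk):

```python
from lhotse import load_manifest

def show_statistics(paths):
    for path in paths:
        print(f"Statistics for {path}:")
        cuts = load_manifest(path)
        # CutSet.describe() prints cut counts, total speech duration,
        # and duration percentiles for the manifest.
        cuts.describe()

show_statistics(["./data/fbank/cuts_DEV.jsonl.gz"])
```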
```diff
@@ -338,7 +338,7 @@ class WenetSpeechAsrDataModule:
         if sampler_state_dict is not None:
             logging.info("Loading sampler state dict")
             train_dl.sampler.load_state_dict(sampler_state_dict)
 
         return train_dl
 
     def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
```
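The `sampler_state_dict` round trip is what makes training resumable mid-epoch: lhotse samplers expose `state_dict()`/`load_state_dict()`, so the position in the shuffled cut stream survives a restart. A minimal sketch of the save/restore cycle (the checkpoint layout here is illustrative, not icefall's exact format):

```python
import torch
from torch.utils.data import DataLoader

def save_sampler_state(train_dl: DataLoader, path: str) -> None:
    # Capture how far into the epoch the sampler has advanced.
    torch.save({"sampler": train_dl.sampler.state_dict()}, path)

def restore_sampler_state(train_dl: DataLoader, path: str) -> None:
    # Must run before iterating train_dl, mirroring the hunk above.
    sampler_state_dict = torch.load(path)["sampler"]
    if sampler_state_dict is not None:
        train_dl.sampler.load_state_dict(sampler_state_dict)
```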
```diff
@@ -503,8 +503,7 @@ def modified_beam_search(
         for i in range(batch_size):
             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
 
-            #topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
-            topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
+            topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
             topk_token_indexes = (topk_indexes % vocab_size).tolist()
 
             for k in range(len(topk_hyp_indexes)):
```
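Both the removed and the surviving line decode the same thing: `topk` runs over hypotheses × vocabulary flattened into one axis, so integer division recovers the hypothesis index and the remainder recovers the token id. A self-contained illustration (shapes are invented for the demo; the real code indexes k2 ragged tensors):

```python
import torch

vocab_size, beam, num_hyps = 500, 4, 3

# Flattened log-probs, laid out the way topk sees them: hypothesis-major.
log_probs = torch.randn(num_hyps * vocab_size)
topk_log_probs, topk_indexes = log_probs.topk(beam)

# Which hypothesis each flat index belongs to, and which token it picks.
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()

# torch.div(..., rounding_mode="trunc") is the explicit spelling that older
# PyTorch releases asked for when `//` on tensors emitted a deprecation
# warning; for non-negative indexes the two agree.
assert topk_hyp_indexes == torch.div(
    topk_indexes, vocab_size, rounding_mode="trunc"
).tolist()
```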
```diff
@@ -900,7 +900,7 @@ def run(rank, world_size, args):
     train_dl = wenetspeech.train_dataloaders(
         train_cuts, sampler_state_dict=sampler_state_dict
     )
 
     if not params.print_diagnostics and params.start_batch == 0:
         scan_pessimistic_batches_for_oom(
             model=model,
```
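`scan_pessimistic_batches_for_oom` exists to fail fast: it pushes the most memory-hungry batches through a forward/backward pass before real training starts, so an OOM surfaces in seconds rather than hours into the run. A rough sketch of the idea (the function body below is illustrative, not icefall's implementation):

```python
import torch

def scan_worst_case_batches(model, compute_loss, worst_case_batches):
    # Illustrative: run the largest batches once, forward + backward,
    # so out-of-memory errors show up before real training begins.
    for batch in worst_case_batches:
        try:
            loss = compute_loss(model, batch)
            loss.backward()
            model.zero_grad(set_to_none=True)
        except RuntimeError as e:
            if "out of memory" in str(e):
                raise RuntimeError(
                    "OOM on a worst-case batch; try lowering --max-duration"
                ) from e
            raise
```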
```diff
@@ -909,7 +909,7 @@ def run(rank, world_size, args):
             graph_compiler=graph_compiler,
             params=params,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
```
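`GradScaler` carries state of its own: the current loss scale plus its growth and backoff counters. Resuming fp16 training without restoring it can destabilize the first updates, hence the `grad_scaler` entry in the checkpoint. A minimal sketch of the round trip (the checkpoint key follows the hunk; everything else is illustrative):

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=torch.cuda.is_available())

# Save: the loss-scale statistics travel with the rest of the checkpoint.
torch.save({"grad_scaler": scaler.state_dict()}, "checkpoint.pt")

# Resume: restore the scale before the first optimizer step.
checkpoints = torch.load("checkpoint.pt")
if checkpoints and "grad_scaler" in checkpoints:
    scaler.load_state_dict(checkpoints["grad_scaler"])
```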