This commit is contained in:
marcoyang 2023-02-14 15:52:54 +08:00
parent 56c2474c0d
commit f6f29b9321
3 changed files with 12 additions and 34 deletions

View File

@@ -302,9 +302,7 @@ def decode_one_batch(
en_hyps.append(en_text)
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@@ -360,9 +358,7 @@ def decode_one_batch(
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
model=model, encoder_out=encoder_out_i, beam=params.beam_size,
)
else:
raise ValueError(
@@ -726,19 +722,13 @@ def main():
sp=sp,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
params=params, test_set_name=test_set, results_dict=results_dict,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=zh_results_dict,
params=params, test_set_name=test_set, results_dict=zh_results_dict,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=en_results_dict,
params=params, test_set_name=test_set, results_dict=en_results_dict,
)
logging.info("Done!")

View File

@@ -107,10 +107,7 @@ def get_parser():
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="Path to the lang",
"--lang-dir", type=str, default="data/lang_char", help="Path to the lang",
)
parser.add_argument(
@@ -137,8 +134,7 @@ def get_parser():
def export_encoder_model_jit_trace(
encoder_model: torch.nn.Module,
encoder_filename: str,
encoder_model: torch.nn.Module, encoder_filename: str,
) -> None:
"""Export the given encoder model with torch.jit.trace()
@@ -160,8 +156,7 @@ def export_encoder_model_jit_trace(
def export_decoder_model_jit_trace(
decoder_model: torch.nn.Module,
decoder_filename: str,
decoder_model: torch.nn.Module, decoder_filename: str,
) -> None:
"""Export the given decoder model with torch.jit.trace()
@@ -182,8 +177,7 @@ def export_decoder_model_jit_trace(
def export_joiner_model_jit_trace(
joiner_model: torch.nn.Module,
joiner_filename: str,
joiner_model: torch.nn.Module, joiner_filename: str,
) -> None:
"""Export the given joiner model with torch.jit.trace()

View File

@@ -210,9 +210,7 @@ class TAL_CSASRAsrDataModule:
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None,
) -> DataLoader:
"""
Args:
@@ -357,8 +355,7 @@ class TAL_CSASRAsrDataModule:
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
cut_transforms=transforms, return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
@@ -395,10 +392,7 @@ class TAL_CSASRAsrDataModule:
)
logging.info("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
test, batch_size=None, sampler=sampler, num_workers=self.args.num_workers,
)
return test_dl