From 51b1b746717dadeaef09323479414ab0b6c97f50 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Tue, 14 Feb 2023 15:11:18 +0900 Subject: [PATCH] from local --- .../ASR/conformer_ctc3/.decode_multi.py.swp | Bin 49152 -> 53248 bytes .../ASR/conformer_ctc3/decode_multi.py | 144 ++++-------------- 2 files changed, 33 insertions(+), 111 deletions(-) diff --git a/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp b/egs/tedlium2/ASR/conformer_ctc3/.decode_multi.py.swp index 4622dd2f1e0e52b987981808983d1e09ede6b24c..cbaf71bb0d06d261eede59552cd40ed22498290e 100644 GIT binary patch delta 1810 zcmb8uduUr#9Ki9DSDWS4m#OX6)!WtCT>ElWnv&aWOWnZdIyV21IjC_>Y;{R$TlAl- ziT}tf({3JcRdj0@h?Q>Z7RpdmC@4Pum^g5DLqudKvx22MHz)(YH@O$AL&3m@bAGw! z+;h(FcTYBJ%TC#5t>qSX-d66lSn6a&#qFv`6n_?@EwinTjZk#tH&vYWomK4RnMJ1% zU#($Sxey)3w4cj_xPT+riC*||%_qd?_zcH!3^DZM0kohQOI{&9#Wco|z+>2gwOC}I z7je!bGGda>I9^5@T2Tfcy!gW{#5@jRH>&ZIONfs!iKEzxM*PV^uHiEDL!HNmcn29g z0zVcw^cfs>)(J62rw6C-3Z6g^ozA?)URx@c8tf&B=wLFS%FeEErNXZmiov+j*O!dS zU{?d05JnL9>alw3&_t}S@78~ne6VYqRlX2vl&^>OO6S398Az{_y}KN;PFpKGG{20+ z%j7Oikx6Z%oFJakYUP}^ULH%iZ{-k$~Bu7ZT{#Ed4O+B0W zfaoRY|E91W0?Uj%j~S$)LB$H?xP({nC^R^+Nc!JI8ZD^DuVnjcyoDDLLMeVJ72*v{ zl=7X7&}o3catVuK25D$eaor)rHyFo*2*QIab|JpR7=lpY!BsBAEDm4?no$ckZd&<< zkVD!^9zt|}=Vo2PVN3pKZrUY#+NAN3}4W;%C9B3OF-O!>W^o**7OQH&^->~G$rh6;1?AJ6nuWupB;e9M))4!sEGjn?D$ za0n5saNRCo3Q6q6L%0uKT;{r+Kpgt@tA_{klO619$k4jY0dSs;iHrnlSV$vyhwy!asHOe0*?ecW4 z(s*9Uczs-hF|R{!t?>=U8b($K^!iZ`va~npKS5aM<_>mnvr?O^~?aZRDdDF)d z0(&nqGKh+-u#83*_TfcD1ku-37e!G)P*(-fC$auhA?StQc+a!vJ?A~o^PZlwxjj9( z*K-PO&U9{}%Vuj)9?z^PxiUWL(i=R(?!IT*4b5JdoRue1?ckq6k-CYDhXo?Ha0Yu( z2RGihMeg7V25=gOkw5^yT_V45AD58CK{UaS_`Ii$nt+6DpVRm~LY2Iv# zx0SjuCepcwWg3x=L}%4_M7Q)vR5#36ykacSp0MJ+u@yW@b%mCWugr%3Uz?3p_(Yu# z`P4J~3C&R@;cC?w4yse(ptmz>n2{=z`;e90!+5}oR)aR!KYDdbW-d`w57o00|s_MZOZ^u?` eSJcubs+)+LZF;vE)3=#M+gPt^Bt1H`YWiPpDD4dZ diff --git a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py index 030c602cf..2c4d8725c 100755 --- a/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py +++ b/egs/tedlium2/ASR/conformer_ctc3/decode_multi.py @@ -777,121 +777,43 @@ def main() -> None: num_decoder_layers=params.num_decoder_layers, group_num=params.group_num, ) + + for epoch in range(params.start, params.end+1): + load_checkpoint(f"{params.exp_dir}/epoch-{epoch}.pt", model) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) + # we need cut ids to display recognition results. + args.return_cuts = True + tedlium = TedLiumAsrDataModule(args) + + valid_cuts = tedlium.dev_cuts() + + valid_dl = tedlium.valid_dataloaders(valid_cuts) + + test_sets = ["dev"] + test_dls = [dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, ) - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") + save_results(params=params, test_set_name=test_set, results_dict=results_dict) - # we need cut ids to display recognition results. - args.return_cuts = True - tedlium = TedLiumAsrDataModule(args) - - valid_cuts = tedlium.dev_cuts() - test_cuts = tedlium.test_cuts() - - valid_dl = tedlium.valid_dataloaders(valid_cuts) - test_dl = tedlium.test_dataloaders(test_cuts) - - #test_sets = ["dev", "test"] - #test_dls = [valid_dl, test_dl] - test_sets = ["dev"] - test_dls = [dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - H=H, - bpe_model=bpe_model, - word_table=lexicon.word_table, - G=G, - sos_id=sos_id, - eos_id=eos_id, - ) - - save_results(params=params, test_set_name=test_set, results_dict=results_dict) - - logging.info("Done!") + logging.info("Done!") torch.set_num_threads(1)