From 8695a97ac58ab361f6d70f5d9181c086d37ec338 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Sat, 10 Dec 2022 15:07:10 +0900 Subject: [PATCH] from local --- .../.train.py.swp | Bin 110592 -> 110592 bytes .../train.py | 13 +++++++++++++ 2 files changed, 13 insertions(+) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/.train.py.swp index 0fe1c626a2fd732b6211d9831b02eaf887d35130..dbcc18b7cc63a3d145b1acaabda5db35a872d1c5 100644 GIT binary patch delta 563 zcmXZYUr1AN6u|LwU2ZB04q+H6U6t4-ahdvJf?)I{E8qM>5#zSq>$ax%vR#cKGQk&L z>k8Y8Pn8IP5E2T}Q(zCJ5K&rBzC=(a2|@-@ez$qx^Ml_xhjV`X77n}%2j0NF$k^~L z^_mtDqFLl~$8!8h{qw*VzmncK_P1zFFDT?|s@1gXA94%-8$`b29Ufx|ag3lJ-S}EB zvW^mp$RUb9bt0#Dg=LuN#*t6tC-(3OAF+iM%wZf?(c+VW{PT+J;|;d4il=aJ6Fq1} zEe>l%cJLY{*sw5xKJ?=E1(CNXA%puE#~4O2ge&minWs&}@Q7ST06+Mf4_HIK=HdI7 z8^3OJ`I=Q?%1Rg+&B^Mi=}^(Wtz2p|gUP%-pw4xwom$!=&DF}?iQEm(2P++QW-x1* zbQG0>((a`)9}E|DN4eaPvj=mwl;^s;v3B=c?4tYVZZQ6=WX7xlUJRp@4NA<17tey#lE6J3pRwpbrMmHOy%1Po3rr~`( delta 332 zcmYMvze_>^7=_`Z)xB1TU2ZrPl?`p`(4wG}xR_~MQ$nb%)zXj%3T=Tn=-3i2)g}%) z=UTRhf)o@eK|lV1-f-x^!};KwPI+jQhlV#B+)t!dL;j%1f^@J}&;7WX-e=2^X?U-&9=y&-FX-VB4`|~G1^CdPmO7{)iyf@sZ}K+1 Y0XP2B9HgDC^Vqt1VJF*m((#f11+oY>N&o-= diff --git a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py index 6cc2856c5..36aa913c4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless_d2v_v2/train.py @@ -835,6 +835,19 @@ def compute_loss( assert loss.requires_grad == is_training + if decode: + model.eval() + with torch.no_grad(): + hypos = model.module.decode( + x=feature, + x_lens=feature_lens, + y=y, + sp=sp + ) + logging.info(f'ref: {batch["supervisions"]["text"][0]}') + logging.info(f'hyp: {" ".join(hypos[0])}') + model.train() + with warnings.catch_warnings(): warnings.simplefilter("ignore") info["frames"] = (feature_lens // params.subsampling_factor).sum().item()