From b926cb0045a4e8278f194412195ed3e2b17e1c58 Mon Sep 17 00:00:00 2001 From: dohe0342 Date: Mon, 9 Jan 2023 20:02:35 +0900 Subject: [PATCH] from local --- .../incremental_transf/.identity_train.py.swp | Bin 57344 -> 61440 bytes .../ASR/incremental_transf/identity_train.py | 17 +++++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp b/egs/librispeech/ASR/incremental_transf/.identity_train.py.swp index 8cf25957fa61a2e0e852256309e8c4848dd35a80..9f2a0bce4046f8425bc1eed2b7f7cad17b8083fd 100644 GIT binary patch delta 1016 zcmXZaUr3Wt7{~Evd)IVxYg06%8|MlZ75>SjM7j{OW`AI)#p6Y2TMR@Hy|pd7pD`o?qXj(w9^Un#+#V zA97TclnKi&Vp)oxx;b9(&8D-eihV%mRPWdebJuErS#AG5@-rH?@3T!f4Cky*b40ur z&bjPVb8j~WQ&y2URCJ&nO{j(wIY@CMj(g~YAKEJ}n2?PaFJD3v8c~mGiUASzU=voPuk?J!3-4yihm|LldOnw$|#BdqC@M15@k&j=b;RD`7OUf%ma0}<* z#cmVR+)1$sf>mOX#543ji;WK_aRU2M3M&?f!(;To0~hjP!b_(A0x^u>3c7F@dtk&e sNtmu8(Wcwj50n|rH>Yc3>KaI1+!e7LI3~& delta 526 zcmXBRPbh-{7{~Ev-kEKNNoz?Mld__@Fo{x#MhfKYSUfs65#r#?^5;dglI7kX4e=ZbC29q8?Mv^2Pdv5I6fvGK8pEnTFu zXo_`1r)puVJ({q3ioz+!nL|HcP(rXW=Rd7##hbohRiVgFfyfh*IKd*iPzya2#4I8U zaKVCn#Pda#F^B;A(2hdr@N5>j#tzmn0xzmz!;MLTk~E2&VgjA`Hi~5MfO{O`04k>8 zhX-ZI=81ga6Yogk47*SdbTsK=XL zz8MT)- Td$e_{+#GHGNa&)@{kx7|mswYb diff --git a/egs/librispeech/ASR/incremental_transf/identity_train.py b/egs/librispeech/ASR/incremental_transf/identity_train.py index f7138a066..d75f20842 100755 --- a/egs/librispeech/ASR/incremental_transf/identity_train.py +++ b/egs/librispeech/ASR/incremental_transf/identity_train.py @@ -497,6 +497,23 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: ) return model +def get_interformer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + def load_checkpoint_if_available( params: AttributeDict,