From 0ba1e37cb389af106138ab10e242f5b5ef953703 Mon Sep 17 00:00:00 2001
From: dohe0342
Date: Mon, 9 Jan 2023 19:19:40 +0900
Subject: [PATCH] from local

---
 .../ASR/incremental_transf/.model.py.swp |  Bin 4096 -> 24576 bytes
 .../ASR/incremental_transf/model.py       | 181 ++++++++++++++++++
 2 files changed, 181 insertions(+)

diff --git a/egs/librispeech/ASR/incremental_transf/.model.py.swp b/egs/librispeech/ASR/incremental_transf/.model.py.swp
index c80cdc060e500d6b25cb2b1213674132cf9db5de..97bc4e8e3489b5b6d6fc7038002e7adcf2b3f547 100644
GIT binary patch
[binary vim swap-file data omitted]
diff --git a/egs/librispeech/ASR/incremental_transf/model.py b/egs/librispeech/ASR/incremental_transf/model.py
index 272d06c37..837d24f65 100644
--- a/egs/librispeech/ASR/incremental_transf/model.py
+++ b/egs/librispeech/ASR/incremental_transf/model.py
@@ -205,3 +205,184 @@ class Transducer(nn.Module):
             )
 
         return (simple_loss, pruned_loss)
+
+
+class Interformer(nn.Module):
+    """It implements https://arxiv.org/pdf/1211.3711.pdf
+    "Sequence Transduction with Recurrent Neural Networks"
+    """
+
+    def __init__(
+        self,
+        encoder: EncoderInterface,
+        decoder: nn.Module,
+        joiner: nn.Module,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
+        vocab_size: int,
+    ):
+        """
+        Args:
+          encoder:
+            It is the transcription network in the paper. It accepts two
+            inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+          decoder:
+            It is the prediction network in the paper. Its input shape
+            is (N, U) and its output shape is (N, U, decoder_dim).
+            It should contain one attribute: `blank_id`.
+          joiner:
+            It has two inputs with shapes: (N, T, encoder_dim) and
+            (N, U, decoder_dim).
+            Its output shape is (N, T, U, vocab_size). Note that its output
+            contains unnormalized probs, i.e., not processed by log-softmax.
+        """
+        super().__init__()
+        assert isinstance(encoder, EncoderInterface), type(encoder)
+        assert hasattr(decoder, "blank_id")
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joiner = joiner
+
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_lens: torch.Tensor,
+        y: k2.RaggedTensor,
+        prune_range: int = 5,
+        am_scale: float = 0.0,
+        lm_scale: float = 0.0,
+        warmup: float = 1.0,
+        reduction: str = "sum",
+        delay_penalty: float = 0.0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 3-D tensor of shape (N, T, C).
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          y:
+            A ragged tensor with 2 axes [utt][label]. It contains labels of each
+            utterance.
+          prune_range:
+            The prune range for rnnt loss, i.e., how many symbols (context)
+            we consider for each frame when computing the loss.
+          am_scale:
+            The scale to smooth the loss with the am (encoder network output)
+            part.
+          lm_scale:
+            The scale to smooth the loss with the lm (prediction network output)
+            part.
+          warmup:
+            A value warmup >= 0 that determines which modules are active; values
+            greater than 1 mean the model is fully warmed up and all modules are active.
+          reduction:
+            "sum" to sum the losses over all utterances in the batch.
+            "none" to return the loss in a 1-D tensor for each utterance
+            in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
+        Returns:
+          Return the transducer losses as a tuple:
+          (simple_loss, pruned_loss).
+
+        Note:
+          Regarding am_scale & lm_scale, they will make the loss function
+          take the form:
+            lm_scale * lm_probs + am_scale * am_probs +
+            (1-lm_scale-am_scale) * combined_probs
+        """
+        assert reduction in ("sum", "none"), reduction
+        assert x.ndim == 3, x.shape
+        assert x_lens.ndim == 1, x_lens.shape
+        assert y.num_axes == 2, y.num_axes
+
+        assert x.size(0) == x_lens.size(0) == y.dim0
+
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
+        assert torch.all(x_lens > 0)
+
+        # Now for the decoder, i.e., the prediction network
+        row_splits = y.shape.row_splits(1)
+        y_lens = row_splits[1:] - row_splits[:-1]
+
+        blank_id = self.decoder.blank_id
+        sos_y = add_sos(y, sos_id=blank_id)
+
+        # sos_y_padded: [B, S + 1], start with SOS.
+        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+        # decoder_out: [B, S + 1, decoder_dim]
+        decoder_out = self.decoder(sos_y_padded)
+
+        # Note: y does not start with SOS
+        # y_padded : [B, S]
+        y_padded = y.pad(mode="constant", padding_value=0)
+
+        y_padded = y_padded.to(torch.int64)
+        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+        boundary[:, 2] = y_lens
+        boundary[:, 3] = x_lens
+
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction=reduction,
+                delay_penalty=delay_penalty,
+                return_grad=True,
+            )
+
+        # ranges : [B, T, prune_range]
+        ranges = k2.get_rnnt_prune_ranges(
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=prune_range,
+        )
+
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
+        am_pruned, lm_pruned = k2.do_rnnt_pruning(
+            am=self.joiner.encoder_proj(encoder_out),
+            lm=self.joiner.decoder_proj(decoder_out),
+            ranges=ranges,
+        )
+
+        # logits : [B, T, prune_range, vocab_size]
+
+        # project_input=False since we already applied the encoder and decoder
+        # input projections prior to do_rnnt_pruning (an optimization for speed).
+        logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                delay_penalty=delay_penalty,
+                reduction=reduction,
+            )
+
+        return (simple_loss, pruned_loss)
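
For reference, a minimal runnable sketch of the `boundary` tensor that `Interformer.forward` builds and passes to `k2.rnnt_loss_smoothed` and `k2.rnnt_loss_pruned`. The utterance lengths below are made-up example values, not taken from this patch; per the k2 documentation, each row is interpreted as [begin_symbol, begin_frame, end_symbol, end_frame].

import torch

# Illustrative only (not part of the patch): boundary layout for a batch of 2
# utterances, with assumed frame counts (x_lens) and label counts (y_lens).
x_lens = torch.tensor([95, 80])  # assumed frames per utterance after the encoder
y_lens = torch.tensor([12, 7])   # assumed label counts per utterance

boundary = torch.zeros((2, 4), dtype=torch.int64)
boundary[:, 2] = y_lens  # end_symbol column
boundary[:, 3] = x_lens  # end_frame column
print(boundary)
# tensor([[ 0,  0, 12, 95],
#         [ 0,  0,  7, 80]])

With this layout, frames beyond x_lens[i] and labels beyond y_lens[i] are treated as padding and do not contribute to the loss for utterance i.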