From 0d40b4617afee3d74ab457a236ff02138c2bf374 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 25 Apr 2022 13:40:47 +0800 Subject: [PATCH] Add knowledge-base lookup to model --- .../ASR/pruned2_knowledge/conformer.py | 31 +++++++++++++++++++ .../ASR/pruned2_knowledge/sampling.py | 9 ++++-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 257936b59..e07aba60b 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -19,6 +19,7 @@ import copy import math import warnings from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface @@ -61,6 +62,10 @@ class Conformer(EncoderInterface): dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + knowledge_M: int = 256, + knowledge_N: int = 2, + knowledge_D: int = 256, + knowledge_K: int = 16, ) -> None: super(Conformer, self).__init__() @@ -69,6 +74,10 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) + # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: @@ -79,12 +88,17 @@ class Conformer(EncoderInterface): self.encoder_pos = RelPositionalEncoding(d_model, dropout) encoder_layer = ConformerEncoderLayer( + self.knowledge_base, d_model, nhead, dim_feedforward, dropout, layer_dropout, cnn_module_kernel, + knowledge_M, + knowledge_N, + knowledge_D, + knowledge_K ) self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) @@ -134,11 +148,15 @@ class ConformerEncoderLayer(nn.Module): See: "Conformer: Convolution-augmented Transformer for Speech Recognition" Args: + knowledge_base: shared knowledge base parameter matrix, to be passed to constructors + of lookup modules d_model: the number of expected features in the input (required). nhead: the number of heads in the multiheadattention models (required). dim_feedforward: the dimension of the feedforward network model (default=2048). dropout: the dropout value (default=0.1). cnn_module_kernel (int): Kernel size of convolution module. + knowledge_M, knowledge_N, knowledge_D, knowledge_K: parameters for knowledge-base, + see docs for KnowlegeBaseLookup. Examples:: >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) @@ -149,12 +167,17 @@ class ConformerEncoderLayer(nn.Module): def __init__( self, + knowledge_base: nn.Parameter, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, layer_dropout: float = 0.075, cnn_module_kernel: int = 31, + knowledge_M: int = 256, + knowledge_N: int = 2, + knowledge_D: int = 256, + knowledge_K: int = 16, ) -> None: super(ConformerEncoderLayer, self).__init__() @@ -184,6 +207,11 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) + self.norm_final = BasicNorm(d_model) # try to ensure the output is close to zero-mean (or at least, zero-median). @@ -253,6 +281,9 @@ class ConformerEncoderLayer(nn.Module): # feed forward module src = src + self.dropout(self.feed_forward(src)) + # knowledge-base lookup + src = src + self.dropout(self.lookup(src)) + src = self.norm_final(self.balancer(src)) if alpha != 1.0: diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index b6aec23d7..02cac6748 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -417,9 +417,12 @@ def get_indexes_for_samples(P: Tensor, # right=True means we find # P_cumsum[...,index-1] <= this_samples[...,k] < P_cumsum[...,index], # which is what we want, as opposed to ... < ... <= (i.e. swap < and <=) - idx = ans_indexes[...,n] = torch.searchsorted(P_cumsum[...,n,:], # (*, M) - this_samples, # (*, K) - right=True) + # .contiguous() suppresses a warning about searchsorted needing contiguous + # input. N tends to be 2 or 3 so this copy is not too big a deal. + idx = ans_indexes[...,n] = torch.searchsorted( + P_cumsum[...,n,:].contiguous(), # (*, M) + this_samples, # (*, K) + right=True) this_P = torch.gather(P[...,n,:], dim=-1, index=idx) # shape: (*, K) if n == 0: