mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 07:04:18 +00:00)

Add knowledge-base lookup to model

parent a359bfe504
commit 0d40b4617a
@@ -19,6 +19,7 @@ import copy
 import math
 import warnings
 from typing import Optional, Tuple
+from sampling import create_knowledge_base, KnowledgeBaseLookup
 
 import torch
 from encoder_interface import EncoderInterface
@@ -61,6 +62,10 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        knowledge_M: int = 256,
+        knowledge_N: int = 2,
+        knowledge_D: int = 256,
+        knowledge_K: int = 16,
     ) -> None:
         super(Conformer, self).__init__()
 
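For context, a minimal usage sketch of the new constructor arguments. Only the knowledge_* parameters are shown by this diff; num_features and everything else about the constructor is an assumption here, and the stated meanings of M, N, D, K follow from the sketches further below rather than from this hunk.

    # Hypothetical construction of the modified Conformer.
    from conformer import Conformer

    model = Conformer(
        num_features=80,   # assumed input feature dimension (e.g. 80-dim fbank)
        knowledge_M=256,   # defaults from the diff; M**N knowledge-base rows is an
        knowledge_N=2,     #   assumption about how M and N combine (see below)
        knowledge_D=256,   # dimension of each knowledge-base entry
        knowledge_K=16,    # entries looked up per frame (passed to KnowledgeBaseLookup)
    )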
@@ -69,6 +74,10 @@ class Conformer(EncoderInterface):
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
 
+
+        self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N,
+                                                    knowledge_D)
+
         # self.encoder_embed converts the input of shape (N, T, num_features)
         # to the shape (N, T//subsampling_factor, d_model).
         # That is, it does two things simultaneously:
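sampling.py itself is not part of this hunk. As a rough sketch, and assuming the shared knowledge base is a single trainable matrix with M**N rows of dimension D (consistent with the (M, N, D) call here and the nn.Parameter type hint below, but not confirmed by this commit), create_knowledge_base might look like:

    import torch
    import torch.nn as nn

    def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter:
        # One trainable row per possible N-digit index with base M:
        # M**N rows, each of dimension D.  The init scale is arbitrary in this sketch.
        return nn.Parameter(torch.empty(M ** N, D).uniform_(-0.1, 0.1))

    knowledge_base = create_knowledge_base(256, 2, 256)  # shape (65536, 256)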
@@ -79,12 +88,17 @@ class Conformer(EncoderInterface):
         self.encoder_pos = RelPositionalEncoding(d_model, dropout)
 
         encoder_layer = ConformerEncoderLayer(
+            self.knowledge_base,
             d_model,
             nhead,
             dim_feedforward,
             dropout,
             layer_dropout,
             cnn_module_kernel,
+            knowledge_M,
+            knowledge_N,
+            knowledge_D,
+            knowledge_K
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
 
@@ -134,11 +148,15 @@ class ConformerEncoderLayer(nn.Module):
     See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
 
     Args:
+        knowledge_base: shared knowledge-base parameter matrix, to be passed to the
+          constructors of the lookup modules
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
+        knowledge_M, knowledge_N, knowledge_D, knowledge_K: parameters of the knowledge
+          base; see the docstring of KnowledgeBaseLookup.
 
     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -149,12 +167,17 @@ class ConformerEncoderLayer(nn.Module):
 
     def __init__(
         self,
+        knowledge_base: nn.Parameter,
         d_model: int,
         nhead: int,
         dim_feedforward: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        knowledge_M: int = 256,
+        knowledge_N: int = 2,
+        knowledge_D: int = 256,
+        knowledge_K: int = 16,
     ) -> None:
         super(ConformerEncoderLayer, self).__init__()
 
@@ -184,6 +207,11 @@ class ConformerEncoderLayer(nn.Module):
 
         self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
 
+        self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N,
+                                          knowledge_D, knowledge_K,
+                                          d_model,
+                                          knowledge_base)
+
         self.norm_final = BasicNorm(d_model)
 
         # try to ensure the output is close to zero-mean (or at least, zero-median).
@@ -253,6 +281,9 @@ class ConformerEncoderLayer(nn.Module):
         # feed forward module
         src = src + self.dropout(self.feed_forward(src))
 
+        # knowledge-base lookup
+        src = src + self.dropout(self.lookup(src))
+
         src = self.norm_final(self.balancer(src))
 
         if alpha != 1.0:
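The lookup module itself lives in sampling.py, which this diff does not show. The following is a simplified, hypothetical stand-in that treats the lookup as a soft attention over the knowledge-base rows: it predicts N softmax distributions over M choices, forms the joint distribution over the M**N rows as their outer product, and returns the expected entry projected back to d_model. The real KnowledgeBaseLookup presumably uses K sampled indexes (hence knowledge_K and get_indexes_for_samples below) rather than a full expectation; K is ignored in this sketch.

    import torch
    import torch.nn as nn

    class SoftKnowledgeBaseLookup(nn.Module):
        """Hypothetical, simplified stand-in for KnowledgeBaseLookup."""

        def __init__(self, M: int, N: int, D: int, K: int,
                     embed_dim: int, knowledge_base: nn.Parameter) -> None:
            super().__init__()
            self.M, self.N = M, N
            self.knowledge_base = knowledge_base       # assumed shape: (M**N, D)
            self.to_logits = nn.Linear(embed_dim, M * N)
            self.from_kb = nn.Linear(D, embed_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (T, B, embed_dim), as inside the Conformer encoder layer.
            logits = self.to_logits(x).reshape(*x.shape[:-1], self.N, self.M)
            probs = logits.softmax(dim=-1)             # (T, B, N, M)
            # Joint distribution over the M**N rows = outer product of the N softmaxes.
            joint = probs[..., 0, :]
            for n in range(1, self.N):
                joint = (joint.unsqueeze(-1) * probs[..., n, :].unsqueeze(-2)).flatten(-2)
            out = joint @ self.knowledge_base          # (T, B, D)
            return self.from_kb(out)                   # (T, B, embed_dim)

    # Mirrors the residual use above: src = src + self.dropout(self.lookup(src))
    kb = nn.Parameter(torch.empty(16 ** 2, 64).uniform_(-0.1, 0.1))
    lookup = SoftKnowledgeBaseLookup(16, 2, 64, 4, embed_dim=32, knowledge_base=kb)
    src = torch.randn(50, 8, 32)
    src = src + lookup(src)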
@@ -417,9 +417,12 @@ def get_indexes_for_samples(P: Tensor,
         # right=True means we find
         # P_cumsum[...,index-1] <= this_samples[...,k] < P_cumsum[...,index],
         # which is what we want, as opposed to ... < ... <=  (i.e. swap < and <=)
-        idx = ans_indexes[...,n] = torch.searchsorted(P_cumsum[...,n,:],  # (*, M)
-                                                      this_samples,  # (*, K)
-                                                      right=True)
+        # .contiguous() suppresses a warning about searchsorted needing contiguous
+        # input.  N tends to be 2 or 3 so this copy is not too big a deal.
+        idx = ans_indexes[...,n] = torch.searchsorted(
+            P_cumsum[...,n,:].contiguous(),  # (*, M)
+            this_samples,  # (*, K)
+            right=True)
         this_P = torch.gather(P[...,n,:], dim=-1, index=idx)  # shape: (*, K)
 
         if n == 0:
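For reference, a small self-contained illustration of the searchsorted pattern above (shapes and values invented): each sample falls into the cumulative-probability bin satisfying P_cumsum[idx-1] <= sample < P_cumsum[idx], and the slice P_cumsum[..., n, :] is a non-contiguous view, which is what triggers the warning that .contiguous() silences.

    import torch

    P = torch.softmax(torch.randn(3, 2, 8), dim=-1)   # (*, N=2, M=8)
    P_cumsum = P.cumsum(dim=-1)                        # nondecreasing along the last dim
    samples = torch.rand(3, 4)                         # (*, K=4), values in [0, 1)

    n = 0
    # The slice P_cumsum[..., n, :] is non-contiguous; without .contiguous(),
    # torch.searchsorted warns and copies internally anyway.
    idx = torch.searchsorted(P_cumsum[..., n, :].contiguous(),  # (*, M)
                             samples,                           # (*, K)
                             right=True)
    # Guard against floating-point rounding at the top edge (the real code
    # presumably guarantees samples < P_cumsum[..., -1]).
    idx = idx.clamp(max=P.shape[-1] - 1)
    this_P = torch.gather(P[..., n, :], dim=-1, index=idx)      # (*, K)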