Add emformer model.

This commit is contained in:
Fangjun Kuang 2022-03-30 16:09:32 +08:00
parent 395a3f952b
commit b4c7a27f3c
4 changed files with 239 additions and 0 deletions

egs/librispeech/ASR/transducer_emformer/emformer.py

@@ -0,0 +1,163 @@
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling
from torchaudio.models import Emformer as _Emformer

LOG_EPSILON = math.log(1e-10)


class Emformer(EncoderInterface):
"""This is just a simple wrapper around torchaudio.models.Emformer.
We may replace it with our own implementation some time later.
"""
def __init__(
self,
num_features: int,
output_dim: int,
d_model: int,
nhead: int,
dim_feedforward: int,
num_encoder_layers: int,
segment_length: int,
left_context_length: int,
right_context_length: int,
max_memory_size: int = 0,
dropout: float = 0.1,
subsampling_factor: int = 4,
vgg_frontend: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
output_dim:
The output dimension of the model.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
          dim_feedforward:
            The hidden dimension of the feedforward network in each encoder
            layer.
num_encoder_layers:
Number of encoder layers.
segment_length:
Number of frames per segment.
left_context_length:
Number of frames in the left context.
right_context_length:
Number of frames in the right context.
          max_memory_size:
            Maximum size of the memory bank in each layer; 0 (the default)
            disables the memory mechanism. See torchaudio.models.Emformer
            for details.
          dropout:
            Dropout probability in the encoder.
          subsampling_factor:
            The number of output frames is approximately
            num_in_frames // subsampling_factor.
            Currently, subsampling_factor MUST be 4.
vgg_frontend:
True to use vgg style frontend for subsampling.
"""
super().__init__()
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_features)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
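        # For example, with num_features=80 and d_model=512, an input of
        # shape (N, 100, 80) becomes roughly (N, 25, 512); the exact output
        # length is ((T - 1) // 2 - 1) // 2 (see forward()).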
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.right_context_length = right_context_length
assert right_context_length % subsampling_factor == 0
assert segment_length % subsampling_factor == 0
assert left_context_length % subsampling_factor == 0
left_context_length = left_context_length // subsampling_factor
right_context_length = right_context_length // subsampling_factor
segment_length = segment_length // subsampling_factor
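        # For example, with segment_length=16, left_context_length=120, and
        # right_context_length=4 (the values used in test_emformer.py), the
        # wrapped model sees segments of 4 subsampled frames, with 30 frames
        # of left context and 1 frame of right context.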
self.model = _Emformer(
input_dim=d_model,
num_heads=nhead,
ffn_dim=dim_feedforward,
num_layers=num_encoder_layers,
segment_length=segment_length,
dropout=dropout,
activation="relu",
left_context_length=left_context_length,
right_context_length=right_context_length,
max_memory_size=max_memory_size,
weight_init_scale_strategy="depthwise",
tanh_on_mem=False,
negative_inf=-1e8,
)
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
)

    def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
Input features of shape (N, T, C).
          x_lens:
            An int32 tensor of shape (N,) containing the number of valid
            frames in `x` before padding. We have `x.size(1) == x_lens.max()`.
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of shape (N, T', C)
          - encoder_out_lens, an int32 tensor of shape (N,) containing the
            number of valid frames in `encoder_out` before padding
"""
x = nn.functional.pad(
x,
# (left, right, top, bottom)
# left/right are for the channel dimension, i.e., axis 2
# top/bottom are for the time dimension, i.e., axis 1
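            # Padding with LOG_EPSILON, i.e., log(1e-10), approximates
            # padding log-scale features with silence, so the final segment
            # also gets a full right context.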
(0, 0, 0, self.right_context_length),
value=LOG_EPSILON,
) # (N, T, C) -> (N, T+right_context_length, C)
x = self.encoder_embed(x)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Caution: We assume the subsampling factor is 4!
x_lens = ((x_lens - 1) // 2 - 1) // 2
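            # This mirrors the two stride-2 convolutions in the subsampling
            # frontend: T -> ((T - 1) // 2 - 1) // 2, i.e., roughly T // 4.
            # (catch_warnings above silences the __floordiv__ deprecation
            # warning this tensor division triggers on recent torch versions.)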
emformer_out, emformer_out_lens = self.model(x, x_lens)
logits = self.encoder_output_layer(emformer_out)
return logits, emformer_out_lens
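
    # NOTE: the following method is a minimal sketch, not part of the recipe
    # above. It illustrates how the wrapped torchaudio model could be driven
    # chunk by chunk via its Emformer.infer() API; the name
    # `streaming_forward` and the exact chunk-size contract are assumptions.
    @torch.no_grad()
    def streaming_forward(
        self,
        x_chunk: torch.Tensor,
        x_chunk_lens: torch.Tensor,
        states=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, list]:
        """Process one chunk of features in streaming mode (sketch).

        After subsampling, the chunk must contain exactly
        segment_length + right_context_length frames (at the subsampled
        frame rate), as required by torchaudio's Emformer.infer().
        """
        x_chunk = self.encoder_embed(x_chunk)
        # Same length formula as in forward(); assumes subsampling factor 4.
        x_chunk_lens = ((x_chunk_lens - 1) // 2 - 1) // 2
        out, out_lens, states = self.model.infer(x_chunk, x_chunk_lens, states)
        logits = self.encoder_output_layer(out)
        return logits, out_lens, states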

egs/librispeech/ASR/transducer_emformer/encoder_interface.py (symlink)

@@ -0,0 +1 @@
../transducer_stateless/encoder_interface.py

egs/librispeech/ASR/transducer_emformer/subsampling.py (symlink)

@@ -0,0 +1 @@
../conformer_ctc/subsampling.py

egs/librispeech/ASR/transducer_emformer/test_emformer.py

@@ -0,0 +1,74 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_emformer/test_emformer.py
"""

import warnings

import torch
from emformer import Emformer


def test_emformer():
N = 3
T = 300
C = 80
output_dim = 500
encoder = Emformer(
num_features=C,
output_dim=output_dim,
d_model=512,
nhead=8,
dim_feedforward=2048,
num_encoder_layers=12,
segment_length=16,
left_context_length=120,
right_context_length=4,
vgg_frontend=True,
)
x = torch.rand(N, T, C)
x_lens = torch.randint(100, T, (N,))
x_lens[0] = T
y, y_lens = encoder(x, x_lens)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
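        # The Emformer consumes the right_context_length frames that
        # forward() pads at the end, so y_lens should equal the subsampled
        # lengths of the original (un-padded) input.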
assert (y_lens == ((x_lens - 1) // 2 - 1) // 2).all()
    assert y.size(0) == x.size(0)
    assert y.size(1) == y_lens.max()
    assert y.size(2) == output_dim
num_param = sum([p.numel() for p in encoder.parameters()])
print(f"Number of encoder parameters: {num_param}")


def main():
    test_emformer()


if __name__ == "__main__":
    torch.manual_seed(20220329)
    main()