Daniil b293db4baf
Tedlium3 conformer ctc2 (#696)
* modify preparation

* small refacor

* add tedlium3 conformer_ctc2

* modify decode

* filter unk in decode

* add scaling converter

* address comments

* fix lambda function lhotse

* add implicit manifest shuffle

* refactor ctc_greedy_search

* import model arguments from train.py

* style fix

* fix ci test and last style issues

* update RESULTS

* fix RESULTS numbers

* fix label smoothing loss

* update model parameters number in RESULTS
2022-12-13 16:13:26 +08:00

321 lines
9.9 KiB
Python

# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
The scaled version are required only in the training time. It simplifies our
life by converting them to their non-scaled version during inference.
"""
import copy
import re
from typing import List
import torch
import torch.nn as nn
from lstmp import LSTMP
from scaling import (
ActivationBalancer,
BasicNorm,
ScaledConv1d,
ScaledConv2d,
ScaledEmbedding,
ScaledLinear,
ScaledLSTM,
)
class NonScaledNorm(nn.Module):
"""See BasicNorm for doc"""
def __init__(
self,
num_channels: int,
eps_exp: float,
channel_dim: int = -1, # CAUTION: see documentation.
):
super().__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
self.eps_exp = eps_exp
def forward(self, x: torch.Tensor) -> torch.Tensor:
if not torch.jit.is_tracing():
assert x.shape[self.channel_dim] == self.num_channels
scales = (
torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp
).pow(-0.5)
return x * scales
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
"""Convert an instance of ScaledLinear to nn.Linear.
Args:
scaled_linear:
The layer to be converted.
Returns:
Return a linear layer. It satisfies:
scaled_linear(x) == linear(x)
for any given input tensor `x`.
"""
assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)
weight = scaled_linear.get_weight()
bias = scaled_linear.get_bias()
has_bias = bias is not None
linear = torch.nn.Linear(
in_features=scaled_linear.in_features,
out_features=scaled_linear.out_features,
bias=True, # otherwise, it throws errors when converting to PNNX format
# device=weight.device, # Pytorch version before v1.9.0 does not have
# this argument. Comment out for now, we will
# see if it will raise error for versions
# after v1.9.0
)
linear.weight.data.copy_(weight)
if has_bias:
linear.bias.data.copy_(bias)
else:
linear.bias.data.zero_()
return linear
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
"""Convert an instance of ScaledConv1d to nn.Conv1d.
Args:
scaled_conv1d:
The layer to be converted.
Returns:
Return an instance of nn.Conv1d that has the same `forward()` behavior
of the given `scaled_conv1d`.
"""
assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)
weight = scaled_conv1d.get_weight()
bias = scaled_conv1d.get_bias()
has_bias = bias is not None
conv1d = nn.Conv1d(
in_channels=scaled_conv1d.in_channels,
out_channels=scaled_conv1d.out_channels,
kernel_size=scaled_conv1d.kernel_size,
stride=scaled_conv1d.stride,
padding=scaled_conv1d.padding,
dilation=scaled_conv1d.dilation,
groups=scaled_conv1d.groups,
bias=scaled_conv1d.bias is not None,
padding_mode=scaled_conv1d.padding_mode,
)
conv1d.weight.data.copy_(weight)
if has_bias:
conv1d.bias.data.copy_(bias)
return conv1d
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
"""Convert an instance of ScaledConv2d to nn.Conv2d.
Args:
scaled_conv2d:
The layer to be converted.
Returns:
Return an instance of nn.Conv2d that has the same `forward()` behavior
of the given `scaled_conv2d`.
"""
assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)
weight = scaled_conv2d.get_weight()
bias = scaled_conv2d.get_bias()
has_bias = bias is not None
conv2d = nn.Conv2d(
in_channels=scaled_conv2d.in_channels,
out_channels=scaled_conv2d.out_channels,
kernel_size=scaled_conv2d.kernel_size,
stride=scaled_conv2d.stride,
padding=scaled_conv2d.padding,
dilation=scaled_conv2d.dilation,
groups=scaled_conv2d.groups,
bias=scaled_conv2d.bias is not None,
padding_mode=scaled_conv2d.padding_mode,
)
conv2d.weight.data.copy_(weight)
if has_bias:
conv2d.bias.data.copy_(bias)
return conv2d
def scaled_embedding_to_embedding(
scaled_embedding: ScaledEmbedding,
) -> nn.Embedding:
"""Convert an instance of ScaledEmbedding to nn.Embedding.
Args:
scaled_embedding:
The layer to be converted.
Returns:
Return an instance of nn.Embedding that has the same `forward()` behavior
of the given `scaled_embedding`.
"""
assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)
embedding = nn.Embedding(
num_embeddings=scaled_embedding.num_embeddings,
embedding_dim=scaled_embedding.embedding_dim,
padding_idx=scaled_embedding.padding_idx,
scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
sparse=scaled_embedding.sparse,
)
weight = scaled_embedding.weight
scale = scaled_embedding.scale
embedding.weight.data.copy_(weight * scale.exp())
return embedding
def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm:
assert isinstance(basic_norm, BasicNorm), type(BasicNorm)
norm = NonScaledNorm(
num_channels=basic_norm.num_channels,
eps_exp=basic_norm.eps.data.exp().item(),
channel_dim=basic_norm.channel_dim,
)
return norm
def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM:
"""Convert an instance of ScaledLSTM to nn.LSTM.
Args:
scaled_lstm:
The layer to be converted.
Returns:
Return an instance of nn.LSTM that has the same `forward()` behavior
of the given `scaled_lstm`.
"""
assert isinstance(scaled_lstm, ScaledLSTM), type(scaled_lstm)
lstm = nn.LSTM(
input_size=scaled_lstm.input_size,
hidden_size=scaled_lstm.hidden_size,
num_layers=scaled_lstm.num_layers,
bias=scaled_lstm.bias,
batch_first=scaled_lstm.batch_first,
dropout=scaled_lstm.dropout,
bidirectional=scaled_lstm.bidirectional,
proj_size=scaled_lstm.proj_size,
)
assert lstm._flat_weights_names == scaled_lstm._flat_weights_names
for idx in range(len(scaled_lstm._flat_weights_names)):
scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp()
lstm._flat_weights[idx].data.copy_(scaled_weight)
return lstm
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_onnx: bool = False,
):
"""Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d`
in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`,
and `nn.Conv2d`.
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_onnx:
If True, we are going to export the model to ONNX. In this case,
we will convert nn.LSTM with proj_size to LSTMP.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
excluded_patterns = r"(self|src)_attn\.(in|out)_proj"
p = re.compile(excluded_patterns)
d = {}
for name, m in model.named_modules():
if isinstance(m, ScaledLinear):
if p.search(name) is not None:
continue
d[name] = scaled_linear_to_linear(m)
elif isinstance(m, ScaledConv1d):
d[name] = scaled_conv1d_to_conv1d(m)
elif isinstance(m, ScaledConv2d):
d[name] = scaled_conv2d_to_conv2d(m)
elif isinstance(m, ScaledEmbedding):
d[name] = scaled_embedding_to_embedding(m)
elif isinstance(m, BasicNorm):
d[name] = convert_basic_norm(m)
elif isinstance(m, ScaledLSTM):
if is_onnx:
d[name] = LSTMP(scaled_lstm_to_lstm(m))
# See
# https://github.com/pytorch/pytorch/issues/47887
# d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m)))
else:
d[name] = scaled_lstm_to_lstm(m)
elif isinstance(m, ActivationBalancer):
d[name] = nn.Identity()
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model