mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
208 lines
6.3 KiB
Python
208 lines
6.3 KiB
Python
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
This file provides functions to convert `ScaledLinear`, `ScaledConv1d`,
|
|
`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts:
|
|
`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.
|
|
|
|
The scaled versions are required only at training time. It simplifies our
|
|
life by converting them to their non-scaled versions during inference.
|
|
"""
|
|
|
|
import copy
|
|
import re
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from scaling import ScaledConv1d, ScaledConv2d, ScaledEmbedding, ScaledLinear
|
|
|
|
|
|
def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear:
    """Convert an instance of ScaledLinear to nn.Linear.

    Args:
      scaled_linear:
        The layer to be converted.
    Returns:
      Return a linear layer. It satisfies:

        scaled_linear(x) == linear(x)

      for any given input tensor `x`. The returned layer is created on the
      same device and with the same dtype as `scaled_linear.get_weight()`,
      so it can replace `scaled_linear` in a model that has already been
      moved to a GPU or cast to a lower precision.
    """
    assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear)

    weight = scaled_linear.get_weight()
    bias = scaled_linear.get_bias()
    has_bias = bias is not None

    linear = torch.nn.Linear(
        in_features=scaled_linear.in_features,
        out_features=scaled_linear.out_features,
        bias=True,  # otherwise, it throws errors when converting to PNNX format.
        device=weight.device,
        # Without this the layer would default to float32 even when the
        # source model is in half/bfloat16 precision.
        dtype=weight.dtype,
    )
    linear.weight.data.copy_(weight)

    if has_bias:
        linear.bias.data.copy_(bias)
    else:
        # The layer always has a bias (see above); zero it so the output
        # matches a bias-free `scaled_linear`.
        linear.bias.data.zero_()

    return linear
|
|
|
|
|
|
def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d:
    """Convert an instance of ScaledConv1d to nn.Conv1d.

    Args:
      scaled_conv1d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv1d that has the same `forward()` behavior
      of the given `scaled_conv1d`. The returned layer is created on the
      same device and with the same dtype as `scaled_conv1d.get_weight()`.
    """
    assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d)

    weight = scaled_conv1d.get_weight()
    bias = scaled_conv1d.get_bias()
    # Decide on the bias from get_bias() -- the same value that is copied
    # below -- so the new layer never ends up with an uninitialised
    # (randomly initialised) bias.
    has_bias = bias is not None

    conv1d = nn.Conv1d(
        in_channels=scaled_conv1d.in_channels,
        out_channels=scaled_conv1d.out_channels,
        kernel_size=scaled_conv1d.kernel_size,
        stride=scaled_conv1d.stride,
        padding=scaled_conv1d.padding,
        dilation=scaled_conv1d.dilation,
        groups=scaled_conv1d.groups,
        bias=has_bias,
        padding_mode=scaled_conv1d.padding_mode,
        # Otherwise the layer would be created on the CPU in float32,
        # regardless of where the source weight lives.
        device=weight.device,
        dtype=weight.dtype,
    )

    conv1d.weight.data.copy_(weight)
    if has_bias:
        conv1d.bias.data.copy_(bias)

    return conv1d
|
|
|
|
|
|
def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d:
    """Convert an instance of ScaledConv2d to nn.Conv2d.

    Args:
      scaled_conv2d:
        The layer to be converted.
    Returns:
      Return an instance of nn.Conv2d that has the same `forward()` behavior
      of the given `scaled_conv2d`. The returned layer is created on the
      same device and with the same dtype as `scaled_conv2d.get_weight()`.
    """
    assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d)

    weight = scaled_conv2d.get_weight()
    bias = scaled_conv2d.get_bias()
    # Decide on the bias from get_bias() -- the same value that is copied
    # below -- so the new layer never ends up with an uninitialised
    # (randomly initialised) bias.
    has_bias = bias is not None

    conv2d = nn.Conv2d(
        in_channels=scaled_conv2d.in_channels,
        out_channels=scaled_conv2d.out_channels,
        kernel_size=scaled_conv2d.kernel_size,
        stride=scaled_conv2d.stride,
        padding=scaled_conv2d.padding,
        dilation=scaled_conv2d.dilation,
        groups=scaled_conv2d.groups,
        bias=has_bias,
        padding_mode=scaled_conv2d.padding_mode,
        # Otherwise the layer would be created on the CPU in float32,
        # regardless of where the source weight lives.
        device=weight.device,
        dtype=weight.dtype,
    )

    conv2d.weight.data.copy_(weight)
    if has_bias:
        conv2d.bias.data.copy_(bias)

    return conv2d
|
|
|
|
|
|
def scaled_embedding_to_embedding(
    scaled_embedding: ScaledEmbedding,
) -> nn.Embedding:
    """Convert an instance of ScaledEmbedding to nn.Embedding.

    Args:
      scaled_embedding:
        The layer to be converted.
    Returns:
      Return an instance of nn.Embedding that has the same `forward()` behavior
      of the given `scaled_embedding`. Its table is
      `scaled_embedding.weight * scaled_embedding.scale.exp()`, created on
      the same device and with the same dtype as `scaled_embedding.weight`.
    """
    assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding)

    weight = scaled_embedding.weight
    scale = scaled_embedding.scale

    embedding = nn.Embedding(
        num_embeddings=scaled_embedding.num_embeddings,
        embedding_dim=scaled_embedding.embedding_dim,
        padding_idx=scaled_embedding.padding_idx,
        scale_grad_by_freq=scaled_embedding.scale_grad_by_freq,
        sparse=scaled_embedding.sparse,
        # Otherwise the table would be created on the CPU in float32,
        # regardless of where the source weight lives.
        device=weight.device,
        dtype=weight.dtype,
    )

    # Fold `scale` into the table so no scaling is needed at lookup time.
    # NOTE(review): assumes ScaledEmbedding.forward uses
    # `weight * scale.exp()`; confirm against scaling.py.
    embedding.weight.data.copy_(weight * scale.exp())

    return embedding
|
|
|
|
|
|
def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False):
    """Convert `ScaledLinear`, `ScaledConv1d`, `ScaledConv2d`, and
    `ScaledEmbedding` in the given model to their unscaled versions
    `nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`.

    `ScaledLinear` modules whose qualified name matches
    `self_attn.in_proj` or `self_attn.out_proj` are left unconverted.

    Args:
      model:
        The model to be converted.
      inplace:
        If True, the input model is modified inplace.
        If False, the input model is copied and we modify the copied version.
    Returns:
      Return a model without scaled layers (apart from the excluded
      attention projections above). If `model` is itself one of the
      scaled layers, its converted counterpart is returned instead, since a
      root module cannot be swapped in place.
    """
    if not inplace:
        model = copy.deepcopy(model)

    # NOTE(review): presumably these are skipped because the attention
    # module reads their scaled parameters directly instead of calling
    # forward(); confirm against the attention implementation.
    excluded_pattern = re.compile(r"self_attn\.(in|out)_proj")

    # Collect all replacements first and apply them afterwards, so the
    # traversal below only ever sees the original modules.
    d = {}
    for name, m in model.named_modules():
        if isinstance(m, ScaledLinear):
            if excluded_pattern.search(name) is not None:
                continue
            d[name] = scaled_linear_to_linear(m)
        elif isinstance(m, ScaledConv1d):
            d[name] = scaled_conv1d_to_conv1d(m)
        elif isinstance(m, ScaledConv2d):
            d[name] = scaled_conv2d_to_conv2d(m)
        elif isinstance(m, ScaledEmbedding):
            d[name] = scaled_embedding_to_embedding(m)

    if "" in d:
        # named_modules() yields the root under the empty name; setattr with
        # an empty attribute name would register a bogus child instead.
        return d[""]

    for name, new_module in d.items():
        # For a direct child of `model`, `parent` is "" and
        # get_submodule("") returns `model` itself.
        parent, _, child = name.rpartition(".")
        setattr(model.get_submodule(parent), child, new_module)

    return model
|