mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
* Add k2SSL * fix flake8 * fix for black * fix for black * fix for black * Update ssl_datamodule.py * Fix bugs in HubertDataset * update comments * add librilight * add checkpoint convert script * format --------- Co-authored-by: yifanyeung <yifanyeung@yifanyeung.local> Co-authored-by: zzasdf <15218404468@163.com>
109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
#
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
#
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import math
|
|
from typing import List, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
|
|
|
|
|
|
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolutional blocks applied to a single-channel input.

    Every entry of ``conv_layers`` is a ``(dim, kernel_size, stride)`` triple
    and produces one block of the form
    ``Conv1d -> Dropout -> [normalisation] -> GELU``. The first block reads
    1 input channel; each following block reads the ``dim`` of the block
    before it.

    Normalisation depends on ``mode``:

    * ``"default"``: only the first block is normalised, using
      ``Fp32GroupNorm(dim, dim)``; the remaining blocks have no norm layer.
    * ``"layer_norm"``: every block is normalised with
      ``Fp32LayerNorm(dim)`` wrapped between two ``TransposeLast`` modules.

    Args:
        conv_layers: One ``(dim, kernel_size, stride)`` triple per block.
        dropout: Drop probability of the ``nn.Dropout`` in every block.
        mode: Either ``"default"`` or ``"layer_norm"``.
        conv_bias: Whether the convolutions carry a bias term.

    Raises:
        AssertionError: If ``mode`` is not one of the two supported values or
            an entry of ``conv_layers`` does not have exactly three items.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build one ``Conv1d -> Dropout -> [norm] -> GELU`` block.

            The norm layers are sized from ``n_out`` (the convolution's
            output channels) rather than from a variable of the enclosing
            scope, so the helper does not depend on where it is called from.
            """

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # NOTE(review): TransposeLast comes from utils; presumably
                    # it swaps the last two axes so the layer norm runs over
                    # the channel axis -- confirm against utils.
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )

            if is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    # NOTE(review): presumably (num_groups, num_channels),
                    # i.e. one group per channel -- confirm against utils.
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )

            return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        # forward() adds a single channel axis, so the first block has one
        # input channel.
        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # In "default" mode only the very first block is normalised.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            # The next block consumes this block's output channels.
            in_d = dim

    def forward(self, x):
        """Run ``x`` through every block in order and return the result.

        A channel axis is inserted at position 1 before the first block
        (``BxT -> BxCxT`` with ``C == 1``).
        """
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
|