mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
First version of subformer that runs.
This commit is contained in:
parent
1b8be0744f
commit
047c6ffc58
67
egs/libriheavy/LM/zipformer1/decoder.py
Normal file
67
egs/libriheavy/LM/zipformer1/decoder.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
    """
    A 'decoder' that computes the probability of symbols in a language
    modeling task: a linear projection from the encoder embedding to the
    vocabulary, followed by log-softmax and a lookup of the label positions.
    """
    def __init__(self,
                 embed_dim: int,
                 vocab_size: int):
        """
        Args:
          embed_dim: the embedding dimension of the encoder output.
          vocab_size: the number of symbols in the vocabulary.
        """
        super().__init__()
        self.out_proj = nn.Linear(embed_dim,
                                  vocab_size)

    def forward(self,
                labels: Tensor,
                encoder_embed: Tensor) -> Tensor:
        """
        Compute log-probs.
        Args:
         labels: the labels, a Tensor of integer type of shape (batch_size, seq_len);
         encoder_embed: the embeddings from the encoder, of shape (seq_len, batch_size, embed_dim)

        Returns:
         returns the log-probs for each symbol, in a Tensor of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape
        (num_chunks, _batch_size, embed_dim) = encoder_embed.shape

        assert batch_size == _batch_size, (batch_size, _batch_size)
        # the gather below requires exactly one encoder frame per label position.
        assert num_chunks == seq_len, (num_chunks, seq_len)

        x = self.out_proj(encoder_embed)  # (seq_len, batch_size, vocab_size)

        x = x.transpose(0, 1)

        # x: (batch_size, seq_len, vocab_size)

        x = x.log_softmax(dim=-1)

        # pick out, at each position, the log-prob assigned to the reference label.
        logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)

        return logprobs
|
@ -19,14 +19,14 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
from zipformer import Zipformer2
|
from subformer import Subformer
|
||||||
|
|
||||||
|
|
||||||
class Zipformer2LM(nn.Module):
|
class SubformerLM(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
encoder_embed: nn.Module,
|
encoder_embed: nn.Module,
|
||||||
encoder: Zipformer2,
|
encoder: Subformer,
|
||||||
decoder: nn.Module):
|
decoder: nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.encoder_embed = encoder_embed
|
self.encoder_embed = encoder_embed
|
||||||
|
@ -41,7 +41,7 @@ from scaling import (
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
|
||||||
|
|
||||||
class Subformer2(EncoderInterface):
|
class Subformer(EncoderInterface):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
|
||||||
@ -94,11 +94,11 @@ class Subformer2(EncoderInterface):
|
|||||||
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
feedforward_dim: Union[int, Tuple[int]] = 1536,
|
||||||
memory_dim: int = -1,
|
memory_dim: int = -1,
|
||||||
pos_dim: int = 4,
|
pos_dim: int = 4,
|
||||||
dropout: FloatLike = None, # see code below for default
|
dropout: Optional[FloatLike] = None, # see code below for default
|
||||||
warmup_batches: float = 4000.0,
|
warmup_batches: float = 4000.0,
|
||||||
causal: bool = False,
|
causal: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Subformer2, self).__init__()
|
super(Subformer, self).__init__()
|
||||||
|
|
||||||
if dropout is None:
|
if dropout is None:
|
||||||
dropout = ScheduledFloat((0.0, 0.3),
|
dropout = ScheduledFloat((0.0, 0.3),
|
||||||
@ -129,13 +129,13 @@ class Subformer2(EncoderInterface):
|
|||||||
for u,d in zip(encoder_unmasked_dim, encoder_dim):
|
for u,d in zip(encoder_unmasked_dim, encoder_dim):
|
||||||
assert u <= d
|
assert u <= d
|
||||||
|
|
||||||
# each one will be Subformer2Encoder or DownsampledSubformer2Encoder
|
# each one will be SubformerEncoder or DownsampledSubformerEncoder
|
||||||
encoders = []
|
encoders = []
|
||||||
|
|
||||||
num_encoders = len(downsampling_factor)
|
num_encoders = len(downsampling_factor)
|
||||||
for i in range(num_encoders):
|
for i in range(num_encoders):
|
||||||
|
|
||||||
encoder_layer = Subformer2EncoderLayer(
|
encoder_layer = SubformerEncoderLayer(
|
||||||
embed_dim=encoder_dim[i],
|
embed_dim=encoder_dim[i],
|
||||||
pos_dim=pos_dim,
|
pos_dim=pos_dim,
|
||||||
num_heads=num_heads[i],
|
num_heads=num_heads[i],
|
||||||
@ -149,7 +149,7 @@ class Subformer2(EncoderInterface):
|
|||||||
|
|
||||||
# For the segment of the warmup period, we let the Conv2dSubsampling
|
# For the segment of the warmup period, we let the Conv2dSubsampling
|
||||||
# layer learn something. Then we start to warm up the other encoders.
|
# layer learn something. Then we start to warm up the other encoders.
|
||||||
encoder = Subformer2Encoder(
|
encoder = SubformerEncoder(
|
||||||
encoder_layer,
|
encoder_layer,
|
||||||
num_encoder_layers[i],
|
num_encoder_layers[i],
|
||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
@ -159,7 +159,7 @@ class Subformer2(EncoderInterface):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if downsampling_factor[i] != 1:
|
if downsampling_factor[i] != 1:
|
||||||
encoder = DownsampledSubformer2Encoder(
|
encoder = DownsampledSubformerEncoder(
|
||||||
encoder,
|
encoder,
|
||||||
dim=encoder_dim[i],
|
dim=encoder_dim[i],
|
||||||
downsample=downsampling_factor[i],
|
downsample=downsampling_factor[i],
|
||||||
@ -359,7 +359,7 @@ def _balancer_schedule(min_prob: float):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Subformer2EncoderLayer(nn.Module):
|
class SubformerEncoderLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
embed_dim: the number of expected features in the input (required).
|
embed_dim: the number of expected features in the input (required).
|
||||||
@ -368,7 +368,7 @@ class Subformer2EncoderLayer(nn.Module):
|
|||||||
dropout: the dropout value (default=0.1).
|
dropout: the dropout value (default=0.1).
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
|
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
|
||||||
>>> src = torch.rand(10, 32, 512)
|
>>> src = torch.rand(10, 32, 512)
|
||||||
>>> pos_emb = torch.rand(32, 19, 512)
|
>>> pos_emb = torch.rand(32, 19, 512)
|
||||||
>>> out = encoder_layer(src, pos_emb)
|
>>> out = encoder_layer(src, pos_emb)
|
||||||
@ -391,7 +391,7 @@ class Subformer2EncoderLayer(nn.Module):
|
|||||||
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
|
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
|
||||||
bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
|
bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Subformer2EncoderLayer, self).__init__()
|
super(SubformerEncoderLayer, self).__init__()
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
# self.bypass implements layer skipping as well as bypass; see its default values.
|
# self.bypass implements layer skipping as well as bypass; see its default values.
|
||||||
@ -508,25 +508,6 @@ class Subformer2EncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_bypass_scale(self, batch_size: int):
    """
    Returns the bypass scale, of shape (num_channels,) or
    (batch_size, num_channels).  This is actually the scale on the
    non-residual term, so 0 corresponds to bypassing this module.
    """
    # Inference / scripted mode: return the raw parameter, unclamped.
    if torch.jit.is_scripting() or not self.training:
        return self.bypass_scale
    # Training: clamp the learned scale into [bypass_min, bypass_max].
    scale = limit_param_value(self.bypass_scale,
                              min=float(self.bypass_min),
                              max=float(self.bypass_max))
    skip_rate = float(self.layer_skip_rate)
    if skip_rate != 0.0:
        # Zero out the scale for randomly chosen sequences, which makes
        # those sequences skip this layer entirely.
        keep_mask = torch.rand((batch_size, 1), device=scale.device) > skip_rate
        scale = scale * keep_mask
    # scale: (batch_size, num_channels); zero for sequences on which we
    # have randomly chosen to do layer-skipping.
    return scale
|
|
||||||
|
|
||||||
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
|
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
|
||||||
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
|
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
|
||||||
return None
|
return None
|
||||||
@ -645,16 +626,16 @@ class Subformer2EncoderLayer(nn.Module):
|
|||||||
|
|
||||||
return src
|
return src
|
||||||
|
|
||||||
class Subformer2Encoder(nn.Module):
|
class SubformerEncoder(nn.Module):
|
||||||
r"""Subformer2Encoder is a stack of N encoder layers
|
r"""SubformerEncoder is a stack of N encoder layers
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
|
encoder_layer: an instance of the SubformerEncoderLayer() class (required).
|
||||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
|
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
|
||||||
>>> zipformer_encoder = Subformer2Encoder(encoder_layer, num_layers=6)
|
>>> zipformer_encoder = SubformerEncoder(encoder_layer, num_layers=6)
|
||||||
>>> src = torch.rand(10, 32, 512)
|
>>> src = torch.rand(10, 32, 512)
|
||||||
>>> out = zipformer_encoder(src)
|
>>> out = zipformer_encoder(src)
|
||||||
"""
|
"""
|
||||||
@ -824,7 +805,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor) -> Tuple[Tensor, Tensor]:
|
x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: a Tensor of shape (seq_len, batch_size, embed_dim)
|
x: a Tensor of shape (seq_len, batch_size, embed_dim)
|
||||||
@ -853,17 +834,19 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
|
|
||||||
# TODO: if seq_len / downsampling_factor <= 2, do something special.
|
# TODO: if seq_len / downsampling_factor <= 2, do something special.
|
||||||
|
|
||||||
|
intermediate_rate = float(self.intermediate_rate)
|
||||||
|
|
||||||
# 'right' is the rightmost of the 2 limits; we want the scores indexed
|
# 'right' is the rightmost of the 2 limits; we want the scores indexed
|
||||||
# 'upper' to be mapped to around 0.0
|
# 'upper' to be mapped to around 0.0
|
||||||
right = seq_len_reduced
|
right = seq_len_reduced
|
||||||
# we want scores around 'left' to be mapped to around 1.0.
|
# we want scores around 'left' to be mapped to around 1.0.
|
||||||
left = int(seq_len_reduced * (1.0 - self.intermediate_rate))
|
left = int(seq_len_reduced * (1.0 - intermediate_rate))
|
||||||
|
|
||||||
# 'collar' determines the range of positions in the sorted list that we use to
|
# 'collar' determines the range of positions in the sorted list that we use to
|
||||||
# compute the average. We could let collar be 0.0, which would more exactly
|
# compute the average. We could let collar be 0.0, which would more exactly
|
||||||
# accomplish what we want; but we don't, because this would cause too-noisy
|
# accomplish what we want; but we don't, because this would cause too-noisy
|
||||||
# gradients, with too much gradient going to one frame.
|
# gradients, with too much gradient going to one frame.
|
||||||
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
|
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
|
||||||
|
|
||||||
# right_avg: shape (batch_size,), this is to be mapped to 0.0
|
# right_avg: shape (batch_size,), this is to be mapped to 0.0
|
||||||
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
|
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
|
||||||
@ -1009,9 +992,9 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
return ans + x_orig * not_kept.t().unsqueeze(-1)
|
return ans + x_orig * not_kept.t().unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
class DownsampledSubformer2Encoder(nn.Module):
|
class DownsampledSubformerEncoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
|
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
|
||||||
after convolutional downsampling, and then upsampled again at the output, and combined
|
after convolutional downsampling, and then upsampled again at the output, and combined
|
||||||
with the original input, so that the output has the same shape as the input.
|
with the origin input, so that the output has the same shape as the input.
|
||||||
"""
|
"""
|
||||||
@ -1020,7 +1003,7 @@ class DownsampledSubformer2Encoder(nn.Module):
|
|||||||
dim: int,
|
dim: int,
|
||||||
downsample: int,
|
downsample: int,
|
||||||
dropout: FloatLike):
|
dropout: FloatLike):
|
||||||
super(DownsampledSubformer2Encoder, self).__init__()
|
super(DownsampledSubformerEncoder, self).__init__()
|
||||||
self.downsample_factor = downsample
|
self.downsample_factor = downsample
|
||||||
self.downsampler = LearnedDownsamplingModule(dim,
|
self.downsampler = LearnedDownsamplingModule(dim,
|
||||||
downsample)
|
downsample)
|
||||||
@ -1028,12 +1011,11 @@ class DownsampledSubformer2Encoder(nn.Module):
|
|||||||
|
|
||||||
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
|
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
pos_emb: Tensor,
|
pos_emb: Tensor,
|
||||||
|
attn_offset: Tensor,
|
||||||
feature_mask: Union[Tensor, float] = 1.0,
|
feature_mask: Union[Tensor, float] = 1.0,
|
||||||
attn_offset: Optional[Tensor] = None,
|
|
||||||
memory: Optional[Tensor] = None,
|
memory: Optional[Tensor] = None,
|
||||||
memory_key_padding_mask: Optional[Tensor] = None,
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
@ -1638,7 +1620,7 @@ class MultiheadAttentionWeights(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FeedforwardModule(nn.Module):
|
class FeedforwardModule(nn.Module):
|
||||||
"""Feedforward module in Subformer2 model.
|
"""Feedforward module in Subformer model.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
@ -1795,7 +1777,7 @@ def _test_zipformer_main(causal: bool = False):
|
|||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
memory_dim = 100
|
memory_dim = 100
|
||||||
|
|
||||||
c = Subformer2(
|
c = Subformer(
|
||||||
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
|
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
|
||||||
causal=causal,
|
causal=causal,
|
||||||
memory_dim=memory_dim,
|
memory_dim=memory_dim,
|
||||||
|
@ -60,11 +60,11 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from lm_datamodule import LmDataset, LmDataloader
|
from lm_datamodule import LmDataset, LmDataloader
|
||||||
from zipformer import Zipformer2
|
from subformer import Subformer
|
||||||
from scaling import ScheduledFloat
|
from scaling import ScheduledFloat
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from model import Zipformer2LM
|
from model import SubformerLM
|
||||||
from optim import Eden, ScaledAdam
|
from optim import Eden, ScaledAdam
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -121,15 +121,15 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-encoder-layers",
|
"--num-encoder-layers",
|
||||||
type=str,
|
type=str,
|
||||||
default="2,4,8",
|
default="2,4,8,4,2",
|
||||||
help="Number of zipformer encoder layers per stack, comma separated.",
|
help="Number of subformer encoder layers per stack, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--downsampling-factor",
|
"--downsampling-factor",
|
||||||
type=str,
|
type=str,
|
||||||
default="1,2,4",
|
default="1,2,4,2,1",
|
||||||
help="Downsampling factor for each stack of encoder layers.",
|
help="Downsampling factor for each stack of encoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -137,21 +137,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--feedforward-dim",
|
"--feedforward-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="768,1024,1536",
|
default="512,768,1024,768,512",
|
||||||
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
|
help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-heads",
|
"--num-heads",
|
||||||
type=str,
|
type=str,
|
||||||
default="4,4,8",
|
default="4,4,8,4,4",
|
||||||
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
|
help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dim",
|
"--encoder-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="256,384,512",
|
default="256,256,384,256,256",
|
||||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -170,42 +170,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--pos-head-dim",
|
"--pos-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="4",
|
default="4",
|
||||||
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
|
help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-unmasked-dim",
|
"--encoder-unmasked-dim",
|
||||||
type=str,
|
type=str,
|
||||||
default="192,192,256",
|
default="192,192,256,192,192",
|
||||||
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
||||||
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--cnn-module-kernel",
|
|
||||||
type=str,
|
|
||||||
default="31,31,15",
|
|
||||||
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
|
|
||||||
"a single int or comma-separated list.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder-hidden-size",
|
|
||||||
type=int,
|
|
||||||
default=768,
|
|
||||||
help="LSTM dimension in decoder",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decoder-num-layers",
|
|
||||||
type=int,
|
|
||||||
default=2,
|
|
||||||
help="Number of LSTM layers in decoder",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -474,7 +452,7 @@ def get_params() -> AttributeDict:
|
|||||||
"warm_step": 2000,
|
"warm_step": 2000,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
"bytes_per_segment": 2048,
|
"bytes_per_segment": 2048,
|
||||||
"batch_size": 32,
|
"batch_size": 16,
|
||||||
"train_file_list": "train.txt",
|
"train_file_list": "train.txt",
|
||||||
"valid_file_list": "valid.txt",
|
"valid_file_list": "valid.txt",
|
||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
@ -499,28 +477,26 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
|
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
|
||||||
encoder = Zipformer2(
|
encoder = Subformer(
|
||||||
#output_downsampling_factor=chunk_size,
|
#output_downsampling_factor=chunk_size,
|
||||||
downsampling_factor=_to_int_tuple(params.downsampling_factor),
|
downsampling_factor=_to_int_tuple(params.downsampling_factor),
|
||||||
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
|
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
|
||||||
encoder_dim=_to_int_tuple(params.encoder_dim),
|
encoder_dim=_to_int_tuple(params.encoder_dim),
|
||||||
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
|
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
|
||||||
query_head_dim=_to_int_tuple(params.query_head_dim),
|
query_head_dim=_to_int_tuple(params.query_head_dim),
|
||||||
pos_head_dim=_to_int_tuple(params.pos_head_dim),
|
pos_dim=int(params.pos_dim),
|
||||||
value_head_dim=_to_int_tuple(params.value_head_dim),
|
value_head_dim=_to_int_tuple(params.value_head_dim),
|
||||||
num_heads=_to_int_tuple(params.num_heads),
|
num_heads=_to_int_tuple(params.num_heads),
|
||||||
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
feedforward_dim=_to_int_tuple(params.feedforward_dim),
|
||||||
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
|
||||||
warmup_batches=4000.0,
|
warmup_batches=4000.0,
|
||||||
causal=True,
|
causal=True,
|
||||||
chunk_size=(chunk_size,),
|
|
||||||
left_context_frames=(-1,),
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = DecoderDecoder(
|
decoder = Decoder(
|
||||||
embed_dim=max(_to_int_tuple(params.encoder_dim)),
|
embed_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||||
vocab_size=256, # bytes
|
vocab_size=256, # bytes
|
||||||
)
|
)
|
||||||
@ -532,7 +508,7 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
|
|
||||||
model = Zipformer2LM(
|
model = SubformerLM(
|
||||||
encoder_embed=encoder_embed,
|
encoder_embed=encoder_embed,
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
@ -698,7 +674,7 @@ def compute_loss(
|
|||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
model:
|
model:
|
||||||
The model for training. It is an instance of Zipformer in our case.
|
The model for training. It is an instance of Subformer in our case.
|
||||||
batch:
|
batch:
|
||||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||||
for the content in it.
|
for the content in it.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user