First version of subformer that runs.

Daniel Povey 2023-05-15 16:03:01 +08:00
parent 1b8be0744f
commit 047c6ffc58
4 changed files with 115 additions and 90 deletions

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import torch
from torch import nn, Tensor

class Decoder(nn.Module):
    """
    A 'decoder' that computes the probability of symbols in a language modeling task.
    """
    def __init__(self,
                 embed_dim: int,
                 vocab_size: int):
        super().__init__()
        self.out_proj = nn.Linear(embed_dim,
                                  vocab_size)

    def forward(self,
                labels: Tensor,
                encoder_embed: Tensor) -> Tensor:
        """
        Compute log-probs.
        Args:
          labels: the labels, a Tensor of integer type of shape (batch_size, seq_len);
          encoder_embed: the embeddings from the encoder, of shape (seq_len, batch_size, embed_dim)
        Returns:
          the log-probs of the labels, a Tensor of shape (batch_size, seq_len).
        """
        (batch_size, seq_len) = labels.shape
        (num_chunks, _batch_size, embed_dim) = encoder_embed.shape
        assert batch_size == _batch_size

        x = self.out_proj(encoder_embed)
        x = x.transpose(0, 1)  # (batch_size, seq_len, vocab_size)
        x = x.log_softmax(dim=-1)

        logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
        return logprobs
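A minimal usage sketch for this Decoder (not part of the commit; the shapes and values below are made up, and the loss is just the mean negative log-prob):

```python
import torch

from decoder import Decoder

# Assumed toy dimensions: 2 sequences of 8 byte-level labels, 256-dim encoder output.
batch_size, seq_len, embed_dim, vocab_size = 2, 8, 256, 256

decoder = Decoder(embed_dim=embed_dim, vocab_size=vocab_size)

labels = torch.randint(0, vocab_size, (batch_size, seq_len))   # (batch_size, seq_len)
encoder_embed = torch.randn(seq_len, batch_size, embed_dim)    # (seq_len, batch_size, embed_dim)

logprobs = decoder(labels, encoder_embed)                      # (batch_size, seq_len)
loss = -logprobs.mean()
print(logprobs.shape, loss.item())
```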

View File

@@ -19,14 +19,14 @@
import torch
from torch import nn, Tensor
from zipformer import Zipformer2
from subformer import Subformer
class Zipformer2LM(nn.Module):
class SubformerLM(nn.Module):
def __init__(self,
encoder_embed: nn.Module,
encoder: Zipformer2,
encoder: Subformer,
decoder: nn.Module):
super().__init__()
self.encoder_embed = encoder_embed

View File

@@ -41,7 +41,7 @@ from scaling import (
from torch import Tensor, nn
class Subformer2(EncoderInterface):
class Subformer(EncoderInterface):
"""
Args:
@@ -94,11 +94,11 @@ class Subformer2(EncoderInterface):
feedforward_dim: Union[int, Tuple[int]] = 1536,
memory_dim: int = -1,
pos_dim: int = 4,
dropout: FloatLike = None, # see code below for default
dropout: Optional[FloatLike] = None, # see code below for default
warmup_batches: float = 4000.0,
causal: bool = False,
) -> None:
super(Subformer2, self).__init__()
super(Subformer, self).__init__()
if dropout is None:
dropout = ScheduledFloat((0.0, 0.3),
@@ -129,13 +129,13 @@ class Subformer2(EncoderInterface):
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
# each one will be Subformer2Encoder or DownsampledSubformer2Encoder
# each one will be SubformerEncoder or DownsampledSubformerEncoder
encoders = []
num_encoders = len(downsampling_factor)
for i in range(num_encoders):
encoder_layer = Subformer2EncoderLayer(
encoder_layer = SubformerEncoderLayer(
embed_dim=encoder_dim[i],
pos_dim=pos_dim,
num_heads=num_heads[i],
@@ -149,7 +149,7 @@ class Subformer2(EncoderInterface):
# For the segment of the warmup period, we let the Conv2dSubsampling
# layer learn something. Then we start to warm up the other encoders.
encoder = Subformer2Encoder(
encoder = SubformerEncoder(
encoder_layer,
num_encoder_layers[i],
dropout=dropout,
@@ -159,7 +159,7 @@ class Subformer2(EncoderInterface):
)
if downsampling_factor[i] != 1:
encoder = DownsampledSubformer2Encoder(
encoder = DownsampledSubformerEncoder(
encoder,
dim=encoder_dim[i],
downsample=downsampling_factor[i],
@@ -359,7 +359,7 @@ def _balancer_schedule(min_prob: float):
class Subformer2EncoderLayer(nn.Module):
class SubformerEncoderLayer(nn.Module):
"""
Args:
embed_dim: the number of expected features in the input (required).
@@ -368,7 +368,7 @@ class Subformer2EncoderLayer(nn.Module):
dropout: the dropout value (default=0.1).
Examples::
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = encoder_layer(src, pos_emb)
@@ -391,7 +391,7 @@ class Subformer2EncoderLayer(nn.Module):
ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)),
bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0),
) -> None:
super(Subformer2EncoderLayer, self).__init__()
super(SubformerEncoderLayer, self).__init__()
self.embed_dim = embed_dim
# self.bypass implements layer skipping as well as bypass; see its default values.
@@ -508,25 +508,6 @@ class Subformer2EncoderLayer(nn.Module):
)
def get_bypass_scale(self, batch_size: int):
# returns bypass-scale of shape (num_channels,),
# or (batch_size, num_channels,). This is actually the
# scale on the non-residual term, so 0 corresponds to bypassing
# this module.
if torch.jit.is_scripting() or not self.training:
return self.bypass_scale
else:
ans = limit_param_value(self.bypass_scale,
min=float(self.bypass_min),
max=float(self.bypass_max))
layer_skip_rate = float(self.layer_skip_rate)
if layer_skip_rate != 0.0:
mask = torch.rand((batch_size, 1), device=ans.device) > layer_skip_rate
ans = ans * mask
# now ans is of shape (batch_size, num_channels), and is zero for sequences
# on which we have randomly chosen to do layer-skipping.
return ans
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
return None
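The bypass logic above amounts to clamping a learned per-channel scale and zeroing it for a random subset of sequences, which then skip the layer. A standalone sketch of that idea, substituting torch.clamp for limit_param_value:

```python
import torch

batch_size, num_channels = 4, 8
bypass_scale = torch.full((num_channels,), 0.9)   # learned parameter in the real module
bypass_min, bypass_max, layer_skip_rate = 0.1, 1.0, 0.5

# Clamp the learned scale into [bypass_min, bypass_max].
ans = bypass_scale.clamp(min=bypass_min, max=bypass_max)

# Zero the scale for randomly chosen sequences: for those sequences the
# non-residual term is multiplied by 0, i.e. the layer is skipped.
mask = torch.rand(batch_size, 1) > layer_skip_rate
ans = ans * mask          # (batch_size, num_channels)
print(ans)
```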
@@ -645,16 +626,16 @@ class Subformer2EncoderLayer(nn.Module):
return src
class Subformer2Encoder(nn.Module):
r"""Subformer2Encoder is a stack of N encoder layers
class SubformerEncoder(nn.Module):
r"""SubformerEncoder is a stack of N encoder layers
Args:
encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
encoder_layer: an instance of the SubformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
Examples::
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
>>> zipformer_encoder = Subformer2Encoder(encoder_layer, num_layers=6)
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
>>> zipformer_encoder = SubformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> out = zipformer_encoder(src)
"""
@@ -824,7 +805,7 @@ class LearnedDownsamplingModule(nn.Module):
def forward(self,
x: Tensor) -> Tuple[Tensor, Tensor]:
x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
x: a Tensor of shape (seq_len, batch_size, embed_dim)
@@ -853,17 +834,19 @@ class LearnedDownsamplingModule(nn.Module):
# TODO: if seq_len / downsampling_factor <= 2, do something special.
intermediate_rate = float(self.intermediate_rate)
# 'right' is the rightmost of the 2 limits; we want the scores around index
# 'right' to be mapped to around 0.0
right = seq_len_reduced
# we want scores around 'left' to be mapped to around 1.0.
left = int(seq_len_reduced * (1.0 - self.intermediate_rate))
left = int(seq_len_reduced * (1.0 - intermediate_rate))
# 'collar' determines the range of positions in the sorted list that we use to
# compute the average. We could let collar be 0.0, which would more exactly
# accomplish what we want; but we don't, because this would cause too-noisy
# gradients, with too much gradient going to one frame.
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))
# right_avg: shape (batch_size,), this is to be mapped to 0.0
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
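To make the 'collar' comment concrete, here is a self-contained sketch; the variable names follow the code above, but the final rescaling step is an assumption rather than this module's actual code:

```python
import torch

batch_size, seq_len, seq_len_reduced = 3, 100, 50
intermediate_rate = 0.1

# Assumed: per-frame keep-scores sorted in descending order.
sscores, _ = torch.randn(batch_size, seq_len).sort(dim=-1, descending=True)

right = seq_len_reduced                                     # scores near here -> ~0.0
left = int(seq_len_reduced * (1.0 - intermediate_rate))     # scores near here -> ~1.0
collar = max(1, int(seq_len_reduced * 0.5 * intermediate_rate))

# Averaging over a small window around each anchor spreads the gradient over
# several frames instead of putting it all on a single frame.
right_avg = sscores[:, right - collar: right + collar + 1].mean(dim=-1, keepdim=True)
left_avg = sscores[:, left - collar: left + collar + 1].mean(dim=-1, keepdim=True)

# One plausible use of the two averages: an affine map sending left_avg -> 1.0
# and right_avg -> 0.0, then clamping to [0, 1].
scores = ((sscores - right_avg) / (left_avg - right_avg).clamp(min=1e-5)).clamp(0.0, 1.0)
print(scores.shape)   # (batch_size, seq_len)
```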
@@ -1009,9 +992,9 @@ class LearnedDownsamplingModule(nn.Module):
return ans + x_orig * not_kept.t().unsqueeze(-1)
class DownsampledSubformer2Encoder(nn.Module):
class DownsampledSubformerEncoder(nn.Module):
"""
DownsampledSubformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
DownsampledSubformerEncoder is a zipformer encoder evaluated at a reduced frame rate,
after learned downsampling, and then upsampled again at the output, and combined
with the original input, so that the output has the same shape as the input.
"""
@@ -1020,7 +1003,7 @@ class DownsampledSubformer2Encoder(nn.Module):
dim: int,
downsample: int,
dropout: FloatLike):
super(DownsampledSubformer2Encoder, self).__init__()
super(DownsampledSubformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsampler = LearnedDownsamplingModule(dim,
downsample)
@@ -1028,12 +1011,11 @@ class DownsampledSubformer2Encoder(nn.Module):
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
def forward(self,
src: Tensor,
pos_emb: Tensor,
attn_offset: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
attn_offset: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
@@ -1638,7 +1620,7 @@ class MultiheadAttentionWeights(nn.Module):
class FeedforwardModule(nn.Module):
"""Feedforward module in Subformer2 model.
"""Feedforward module in Subformer model.
"""
def __init__(self,
embed_dim: int,
@@ -1795,7 +1777,7 @@ def _test_zipformer_main(causal: bool = False):
# Just make sure the forward pass runs.
memory_dim = 100
c = Subformer2(
c = Subformer(
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
causal=causal,
memory_dim=memory_dim,

View File

@@ -60,11 +60,11 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from lm_datamodule import LmDataset, LmDataloader
from zipformer import Zipformer2
from subformer import Subformer
from scaling import ScheduledFloat
from lhotse.utils import fix_random_seed
from decoder import Decoder
from model import Zipformer2LM
from model import SubformerLM
from optim import Eden, ScaledAdam
from torch import Tensor
from torch import nn
@@ -121,15 +121,15 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
default="2,4,8",
help="Number of zipformer encoder layers per stack, comma separated.",
default="2,4,8,4,2",
help="Number of subformer encoder layers per stack, comma separated.",
)
parser.add_argument(
"--downsampling-factor",
type=str,
default="1,2,4",
default="1,2,4,2,1",
help="Downsampling factor for each stack of encoder layers.",
)
@@ -137,21 +137,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--feedforward-dim",
type=str,
default="768,1024,1536",
help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
default="512,768,1024,768,512",
help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
)
parser.add_argument(
"--num-heads",
type=str,
default="4,4,8",
help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.",
default="4,4,8,4,4",
help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.",
)
parser.add_argument(
"--encoder-dim",
type=str,
default="256,384,512",
default="256,256,384,256,256",
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
)
@@ -170,42 +170,20 @@ def add_model_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--pos-head-dim",
"--pos-dim",
type=str,
default="4",
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
default="192,192,256",
default="192,192,256,192,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
)
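These per-stack options are plain comma-separated strings. A small sketch of how they expand and how the unmasked-dim constraint pairs up with encoder-dim; the parse helper here is defined locally and merely stands in for the script's own helper:

```python
def to_int_tuple(s: str):
    # Parse '256,256,384,256,256' into (256, 256, 384, 256, 256).
    return tuple(int(x) for x in s.split(","))

encoder_dim = to_int_tuple("256,256,384,256,256")
encoder_unmasked_dim = to_int_tuple("192,192,256,192,192")

# Each stack's unmasked dim must not exceed that stack's encoder dim.
assert len(encoder_dim) == len(encoder_unmasked_dim)
for u, d in zip(encoder_unmasked_dim, encoder_dim):
    assert u <= d, (u, d)
```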
parser.add_argument(
"--cnn-module-kernel",
type=str,
default="31,31,15",
help="Sizes of convolutional kernels in convolution modules in each encoder stack: "
"a single int or comma-separated list.",
)
parser.add_argument(
"--decoder-hidden-size",
type=int,
default=768,
help="LSTM dimension in decoder",
)
parser.add_argument(
"--decoder-num-layers",
type=int,
default=2,
help="Number of LSTM layers in decoder",
)
def get_parser():
parser = argparse.ArgumentParser(
@@ -474,7 +452,7 @@ def get_params() -> AttributeDict:
"warm_step": 2000,
"env_info": get_env_info(),
"bytes_per_segment": 2048,
"batch_size": 32,
"batch_size": 16,
"train_file_list": "train.txt",
"valid_file_list": "valid.txt",
"num_workers": 4,
@@ -499,28 +477,26 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module:
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
encoder = Zipformer2(
encoder = Subformer(
#output_downsampling_factor=chunk_size,
downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_head_dim=_to_int_tuple(params.pos_head_dim),
pos_dim=int(params.pos_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=True,
chunk_size=(chunk_size,),
left_context_frames=(-1,),
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = DecoderDecoder(
decoder = Decoder(
embed_dim=max(_to_int_tuple(params.encoder_dim)),
vocab_size=256, # bytes
)
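Since vocab_size is 256 ("bytes"), the labels are simply the UTF-8 bytes of the training text; a tiny illustration (assumed, not taken from this file):

```python
text = "subformer"
labels = list(text.encode("utf-8"))   # values in [0, 255], one label per byte
print(labels)
```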
@@ -532,7 +508,7 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
model = Zipformer2LM(
model = SubformerLM(
encoder_embed=encoder_embed,
encoder=encoder,
decoder=decoder,
@@ -698,7 +674,7 @@ def compute_loss(
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Zipformer in our case.
The model for training. It is an instance of Subformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.