update some documentation for cross-attention zipformer

This commit is contained in:
marcoyang1998 2023-09-19 14:53:33 +08:00
parent 58dc0430be
commit 8401f26342
3 changed files with 954 additions and 725 deletions


@ -1,141 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear, softmax
from icefall.utils import make_pad_mask
class ContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 256,
    ):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Fuse the context embedding vectors by masked averaging.

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)
            padding_mask (torch.Tensor): An optional padding mask, (B,W);
                True marks padded positions.

        Returns:
            torch.Tensor: The fused context embeddings, (B,C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        # by a small probability, drop out the context of a few samples
        context_dropout_rate = 0.05
        m = torch.rand((batch_size, 1, 1), device=context.device) > context_dropout_rate
        context = context * m

        # average the context over the valid positions
        context = context.sum(dim=1) / context_lens

        return context
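
Below is a minimal, hypothetical usage sketch (not part of the original file) showing how ContextFuser pools a padded batch of context embeddings into one vector per sample. It assumes torch and the class above are in scope and builds the padding mask by hand, following icefall's convention that True marks padded positions.

# Hypothetical example, not from the repository:
fuser = ContextFuser(embed_dim=256)
context = torch.randn(2, 5, 256)  # (B, W, C)
# sample 0 has 5 valid positions, sample 1 has 3 (last two are padding)
padding_mask = torch.tensor(
    [
        [False, False, False, False, False],
        [False, False, False, True, True],
    ]
)
fused = fuser(context, padding_mask=padding_mask)
print(fused.shape)  # torch.Size([2, 256])
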
class SelfAttContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 384,
        query_head_dim: int = 128,
        nhead: int = 4,
        context_dropout_rate: float = 0.05,
    ):
        """ContextFuser with multi-head self-attention

        Args:
            embed_dim (int, optional): The input embedding dim. Defaults to 384.
            query_head_dim (int, optional): The query dimension of each head.
                Defaults to 128.
            nhead (int, optional): The number of heads. Defaults to 4.
            context_dropout_rate (float, optional): The probability of dropping
                the context of a sample during training. Defaults to 0.05.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.query_head_dim = query_head_dim
        self.in_proj = ScaledLinear(embed_dim, nhead * query_head_dim)
        self.weight_proj = ScaledLinear(nhead * query_head_dim, nhead)
        self.context_dropout_rate = context_dropout_rate

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Fuse the context embedding vectors with learned attention weights.

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)
            padding_mask (torch.Tensor): A padding mask, (B,W); True marks
                padded positions.

        Returns:
            torch.Tensor: The fused context embeddings, (B, nhead * C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        # context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        k = self.in_proj(context)  # (B,W,nhead * query_head_dim)
        w = self.weight_proj(torch.tanh(k))  # (B,W,num_heads)
        w.masked_fill_(padding_mask, -1000)
        w = softmax(w, dim=1)  # (B,W,num_heads)
        w = w.permute(0, 2, 1)  # (B,num_heads,W)

        # reweight and concat the context embeddings
        context = torch.matmul(w, context).view(batch_size, -1)  # (B, num_heads * C)

        # by a small probability, drop out the context of a few samples
        if self.training:
            m = (
                torch.rand((batch_size, 1), device=context.device)
                > self.context_dropout_rate
            )
            context = context * m
        # context = context * 0.0

        return context
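
A similar hypothetical sketch for SelfAttContextFuser (again assuming the class above and its icefall dependencies are in scope); note that the heads are concatenated, so the output dimension is nhead * embed_dim.

# Hypothetical example, not from the repository:
fuser = SelfAttContextFuser(embed_dim=384, query_head_dim=128, nhead=4)
fuser.eval()  # disable the random context dropout for this illustration
context = torch.randn(2, 6, 384)  # (B, W, C)
padding_mask = torch.zeros(2, 6, dtype=torch.bool)
padding_mask[1, 4:] = True  # last two positions of sample 1 are padding
fused = fuser(context, padding_mask=padding_mask)
print(fused.shape)  # torch.Size([2, 1536]), i.e. (B, nhead * embed_dim)
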


@ -20,26 +20,29 @@
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless7/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless7/train.py \
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7/exp \
--full-libri 1 \
--max-duration 550
--exp-dir zipformer/exp \
--max-duration 1000
# To train a streaming model
./zipformer/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--causal 1 \
--exp-dir zipformer/exp \
--max-duration 1000
"""
@ -54,7 +57,6 @@ from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import k2
import numpy
import optim
import sentencepiece as spm
import torch
@ -70,11 +72,11 @@ from model_baseline import Transducer
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from text_normalization import train_text_normalization, upper_only_alpha
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from text_normalization import train_text_normalization, upper_only_alpha
from zipformer import Zipformer2
from icefall import diagnostics
@ -95,41 +97,28 @@ from icefall.utils import (
str2bool,
)
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
def random_sampling(texts: List[str]) -> str:
return random.choice(texts)
def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
i = random.randint(0, 1)
out = {
"text": texts[i],
"pre_text": pre_texts[i]
}
return out
def get_first(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
) -> str:
out = {
"text": texts[0],
"pre_text": pre_texts[0]
}
# Always get the first one, which is the mixed-cased text with punc
out = {"text": texts[0], "pre_text": pre_texts[0]}
return out
def get_upper_only_alpha(
texts: List[str],
pre_texts: List[str],
context_list: Optional[str] = None,
rare_word_list: Optional[List[str]] = None,
) -> str:
# Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
# Always get the first one, which is the mixed-cased text with punc,
# but upper-cased and with punctuation removed
out = {
"text": upper_only_alpha(texts[0]),
"pre_text": upper_only_alpha(pre_texts[0]),
@ -252,13 +241,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
default=512,
help="Embedding dimension in the decoder model.",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
@ -413,8 +401,7 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
help="The scale to smooth the loss with am (output of encoder network)" "part.",
)
parser.add_argument(
@ -582,7 +569,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
)
return encoder_embed
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Zipformer2(
output_downsampling_factor=2,
@ -789,11 +777,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = (
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
@ -809,7 +793,7 @@ def compute_loss(
texts = [train_text_normalization(t) for t in texts]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
if random.random() < 0.02:
logging.info(f"Ref texts: {texts[0]}")
@ -844,9 +828,7 @@ def compute_loss(
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -1035,9 +1017,7 @@ def train_one_epoch(
# behavior depending on the current grad scale.
cur_grad_scale = scaler._scale.item()
if cur_grad_scale < 8.0 or (
cur_grad_scale < 32.0 and batch_idx % 400 == 0
):
if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
scaler.update(cur_grad_scale * 2.0)
if cur_grad_scale < 0.01:
if not saved_bad_model:
@ -1059,11 +1039,7 @@ def train_one_epoch(
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (
f"grad_scale: {scaler._scale.item()}"
if params.use_fp16
else ""
)
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
)
if tb_writer is not None:
@ -1074,9 +1050,7 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
tb_writer.add_scalar(
"train/grad_scale",
@ -1084,10 +1058,7 @@ def train_one_epoch(
params.batch_idx_train,
)
if (
batch_idx % params.valid_interval == 0
and not params.print_diagnostics
):
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
@ -1177,9 +1148,7 @@ def run(rank, world_size, args):
model = DDP(model, device_ids=[rank], find_unused_parameters=True)
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(
model, lr=params.base_lr, include_names=True
),
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
@ -1200,7 +1169,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
2**22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -1210,7 +1179,7 @@ def run(rank, world_size, args):
libriheavy = LibriHeavyAsrDataModule(args)
train_cuts = libriheavy.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
@ -1265,14 +1234,14 @@ def run(rank, world_size, args):
valid_cuts = libriheavy.dev_cuts()
valid_dl = libriheavy.valid_dataloaders(valid_cuts)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# params=params,
# )
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:

File diff suppressed because it is too large