update some documentation for cross-attention zipformer

This commit is contained in:
marcoyang1998 2023-09-19 14:53:33 +08:00
parent 58dc0430be
commit 8401f26342
3 changed files with 954 additions and 725 deletions

View File

@@ -1,141 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from scaling import ScaledLinear, softmax

from icefall.utils import make_pad_mask


class ContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 256,
    ):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors.

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)
            padding_mask (torch.Tensor): An optional padding mask, (B,W)

        Returns:
            torch.Tensor: The fused context embeddings, (B,C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        # by a small probability, dropout the context of a few samples
        # (only during training, matching SelfAttContextFuser below)
        if self.training:
            context_dropout_rate = 0.05
            m = (
                torch.rand((batch_size, 1, 1), device=context.device)
                > context_dropout_rate
            )
            context = context * m

        # average the context
        context = context.sum(dim=1) / context_lens

        return context


class SelfAttContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 384,
        query_head_dim: int = 128,
        nhead: int = 4,
        context_dropout_rate: float = 0.05,
    ):
        """ContextFuser with multi-head self-attention

        Args:
            embed_dim (int, optional): The input embedding dim. Defaults to 384.
            query_head_dim (int, optional): The dim of each attention head. Defaults to 128.
            nhead (int, optional): The number of heads. Defaults to 4.
            context_dropout_rate (float, optional): The probability of dropping the
                context of a sample during training. Defaults to 0.05.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.query_head_dim = query_head_dim
        self.in_proj = ScaledLinear(embed_dim, nhead * query_head_dim)
        self.weight_proj = ScaledLinear(nhead * query_head_dim, nhead)
        self.context_dropout_rate = context_dropout_rate

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors.

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)
            padding_mask (torch.Tensor): A padding mask, (B,W)

        Returns:
            torch.Tensor: The fused context embeddings, (B, nhead * C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        # context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        k = self.in_proj(context)  # (B, W, nhead * query_head_dim)
        w = self.weight_proj(torch.tanh(k))  # (B,W,num_heads)
        w.masked_fill_(padding_mask, -1000)
        w = softmax(w, dim=1)  # (B,W,num_heads)
        w = w.permute(0, 2, 1)  # (B,num_heads,W)

        # reweight and concat the context embeddings
        context = torch.matmul(w, context).view(batch_size, -1)  # (B, num_heads * C)

        # by a small probability, dropout the context of a few samples
        if self.training:
            m = (
                torch.rand((batch_size, 1), device=context.device)
                > self.context_dropout_rate
            )
            context = context * m
        # context = context * 0.0

        return context
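
For reference, a minimal shape check of SelfAttContextFuser (a sketch only, assuming the file above and its scaling/icefall dependencies are importable from the recipe directory; the module name context_fuser is an assumption, not stated in this commit):

import torch
from context_fuser import SelfAttContextFuser  # hypothetical import path for the file above

fuser = SelfAttContextFuser(embed_dim=384, query_head_dim=128, nhead=4)
fuser.eval()  # eval mode disables the random context-dropout branch

context = torch.randn(2, 10, 384)     # (B, W, C) context embeddings
context_lens = torch.tensor([10, 7])  # (B,) number of valid context vectors per sample

fused = fuser(context, context_lens=context_lens)
print(fused.shape)  # torch.Size([2, 1536]) == (B, nhead * C)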

View File

@@ -20,26 +20,29 @@
 """
 Usage:
 
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
-
-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
-
 # For mix precision training:
 
-./pruned_transducer_stateless7/train.py \
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --exp-dir zipformer/exp \
+  --max-duration 1000
+
+# To train a streaming model
+
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --causal 1 \
+  --exp-dir zipformer/exp \
+  --max-duration 1000
 
 """
@@ -54,7 +57,6 @@ from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import k2
-import numpy
 import optim
 import sentencepiece as spm
 import torch
@@ -70,11 +72,11 @@ from model_baseline import Transducer
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
-from text_normalization import train_text_normalization, upper_only_alpha
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from text_normalization import train_text_normalization, upper_only_alpha
 from zipformer import Zipformer2
 
 from icefall import diagnostics
@@ -95,41 +97,28 @@ from icefall.utils import (
     str2bool,
 )
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
-
-
-def random_sampling(texts: List[str]) -> str:
-    return random.choice(texts)
-
-
-def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
-    i = random.randint(0, 1)
-    out = {
-        "text": texts[i],
-        "pre_text": pre_texts[i]
-    }
-    return out
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
 def get_first(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
     rare_word_list: Optional[List[str]] = None,
 ) -> str:
-    out = {
-        "text": texts[0],
-        "pre_text": pre_texts[0]
-    }
+    # Always get the first one, which is the mixed-cased text with punc
+    out = {"text": texts[0], "pre_text": pre_texts[0]}
     return out
 
 
 def get_upper_only_alpha(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
     rare_word_list: Optional[List[str]] = None,
 ) -> str:
-    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
+    # Always get the first one, which is the mixed-cased text with punc,
+    # but with upper case it and remove punctuation
     out = {
         "text": upper_only_alpha(texts[0]),
         "pre_text": upper_only_alpha(pre_texts[0]),
@ -257,8 +246,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -413,8 +401,7 @@ def get_parser():
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" help="The scale to smooth the loss with am (output of encoder network)" "part.",
"part.",
) )
parser.add_argument( parser.add_argument(
@ -583,6 +570,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
) )
return encoder_embed return encoder_embed
def get_encoder_model(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Zipformer2( encoder = Zipformer2(
output_downsampling_factor=2, output_downsampling_factor=2,
@ -789,11 +777,7 @@ def compute_loss(
warmup: a floating point value which increases throughout training; warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present. values >= 1.0 are fully warmed up and have all modules present.
""" """
device = ( device = model.device if isinstance(model, DDP) else next(model.parameters()).device
model.device
if isinstance(model, DDP)
else next(model.parameters()).device
)
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
@ -844,9 +828,7 @@ def compute_loss(
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = ( info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@@ -1035,9 +1017,7 @@ def train_one_epoch(
             # behavior depending on the current grad scale.
             cur_grad_scale = scaler._scale.item()
 
-            if cur_grad_scale < 8.0 or (
-                cur_grad_scale < 32.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                 scaler.update(cur_grad_scale * 2.0)
             if cur_grad_scale < 0.01:
                 if not saved_bad_model:
@@ -1059,11 +1039,7 @@ def train_one_epoch(
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (
-                    f"grad_scale: {scaler._scale.item()}"
-                    if params.use_fp16
-                    else ""
-                )
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
             )
 
             if tb_writer is not None:
@@ -1074,9 +1050,7 @@ def train_one_epoch(
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
                 )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                 if params.use_fp16:
                     tb_writer.add_scalar(
                         "train/grad_scale",
@@ -1084,10 +1058,7 @@
                         params.batch_idx_train,
                     )
 
-        if (
-            batch_idx % params.valid_interval == 0
-            and not params.print_diagnostics
-        ):
+        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
             logging.info("Computing validation loss")
             valid_info = compute_validation_loss(
                 params=params,
@@ -1177,9 +1148,7 @@ def run(rank, world_size, args):
         model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = ScaledAdam(
-        get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True
-        ),
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
         lr=params.base_lr,  # should have no effect
         clipping_scale=2.0,
     )
@@ -1200,7 +1169,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
@@ -1265,14 +1234,14 @@ def run(rank, world_size, args):
     valid_cuts = libriheavy.dev_cuts()
     valid_dl = libriheavy.valid_dataloaders(valid_cuts)
 
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:

File diff suppressed because it is too large