mirror of https://github.com/k2-fsa/icefall.git

update some documentation for cross-attention zipformer

commit 8401f26342
parent 58dc0430be
@@ -1,141 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from scaling import ScaledLinear, softmax

from icefall.utils import make_pad_mask


class ContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 256,
    ):
        super().__init__()

        self.embed_dim = embed_dim

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors.

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)

        Returns:
            torch.Tensor: The fused context embeddings, (B,C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent division by 0

        # with a small probability, drop out the context of a few samples
        context_dropout_rate = 0.05
        m = torch.rand((batch_size, 1, 1), device=context.device) > context_dropout_rate
        context = context * m

        # average the context over the valid positions
        context = context.sum(dim=1) / context_lens

        return context
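For reference, the masked averaging that ContextFuser.forward performs can be reproduced with a short, self-contained sketch in plain PyTorch; the shapes follow the docstring above, the dummy tensors are invented for illustration, the padding mask is built inline instead of with icefall's make_pad_mask, and the context-dropout step is omitted:

import torch

B, W, C = 2, 5, 256                    # batch size, number of context words, embedding dim
context = torch.randn(B, W, C)         # (B, W, C) context embeddings
context_lens = torch.tensor([5, 3])    # (B,) number of valid words per sample

# Equivalent of make_pad_mask(context_lens): True at padded positions.
padding_mask = torch.arange(W).unsqueeze(0) >= context_lens.unsqueeze(1)  # (B, W)

# Zero out the padded positions, then average over the valid ones.
masked = context.masked_fill(padding_mask.unsqueeze(-1), 0.0)
fused = masked.sum(dim=1) / context_lens.unsqueeze(-1)  # (B, C)
print(fused.shape)  # torch.Size([2, 256])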


class SelfAttContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 384,
        query_head_dim: int = 128,
        nhead: int = 4,
        context_dropout_rate: float = 0.05,
    ):
        """ContextFuser with multi-head self-attention.

        Args:
            embed_dim (int, optional): The input embedding dim. Defaults to 384.
            query_head_dim (int, optional): The dimension per attention head. Defaults to 128.
            nhead (int, optional): The number of heads. Defaults to 4.
            context_dropout_rate (float, optional): The probability of dropping the
                context of a sample during training. Defaults to 0.05.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.nhead = nhead
        self.query_head_dim = query_head_dim

        self.in_proj = ScaledLinear(embed_dim, nhead * query_head_dim)
        self.weight_proj = ScaledLinear(nhead * query_head_dim, nhead)
        self.context_dropout_rate = context_dropout_rate

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor = None,
        padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The lengths of the context embeddings, (B,)
            padding_mask (torch.Tensor): A padding mask, (B,W)

        Returns:
            torch.Tensor: The fused context embeddings, (B, num_heads * C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        # context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent division by 0

        k = self.in_proj(context)  # (B,W,num_heads * query_head_dim)
        w = self.weight_proj(torch.tanh(k))  # (B,W,num_heads)

        w.masked_fill_(padding_mask, -1000)
        w = softmax(w, dim=1)  # (B,W,num_heads)
        w = w.permute(0, 2, 1)  # (B,num_heads,W)

        # reweight and concatenate the context embeddings
        context = torch.matmul(w, context).view(batch_size, -1)  # (B, num_heads * C)

        # with a small probability, drop out the context of a few samples
        if self.training:
            m = torch.rand((batch_size, 1), device=context.device) > self.context_dropout_rate
            context = context * m
            # context = context * 0.0

        return context
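The attention-style pooling in SelfAttContextFuser can likewise be sketched in plain PyTorch. The snippet below uses nn.Linear in place of the recipe's ScaledLinear and the built-in softmax in place of scaling.softmax (assumptions made only to keep it self-contained); shapes follow the docstring above, and the dummy tensors are invented for illustration:

import torch
import torch.nn as nn

B, W, C = 2, 5, 384
nhead, query_head_dim = 4, 128

context = torch.randn(B, W, C)                      # (B, W, C) context embeddings
padding_mask = torch.zeros(B, W, dtype=torch.bool)  # True at padded positions
padding_mask[1, 3:] = True                          # second sample has only 3 valid words

in_proj = nn.Linear(C, nhead * query_head_dim)      # stand-in for ScaledLinear
weight_proj = nn.Linear(nhead * query_head_dim, nhead)

k = in_proj(context)                    # (B, W, nhead * query_head_dim)
w = weight_proj(torch.tanh(k))          # (B, W, nhead): one score per head and word
w = w.masked_fill(padding_mask.unsqueeze(-1), -1000.0)
w = w.softmax(dim=1)                    # normalize the scores over the W words
w = w.permute(0, 2, 1)                  # (B, nhead, W)

# Each head computes its own weighted average of the context; the heads are concatenated.
fused = torch.matmul(w, context).view(B, -1)  # (B, nhead * C)
print(fused.shape)  # torch.Size([2, 1536])

As in the module above, each attention head keeps its own averaged copy of the context, so downstream layers receive a (B, nhead * C) vector rather than a single (B, C) average.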
@@ -20,26 +20,29 @@
"""
Usage:

export CUDA_VISIBLE_DEVICES="0,1,2,3"

./pruned_transducer_stateless7/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --exp-dir pruned_transducer_stateless7/exp \
  --full-libri 1 \
  --max-duration 300

# For mix precision training:

./pruned_transducer_stateless7/train.py \
export CUDA_VISIBLE_DEVICES="0,1,2,3"

./zipformer/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir pruned_transducer_stateless7/exp \
  --full-libri 1 \
  --max-duration 550
  --exp-dir zipformer/exp \
  --max-duration 1000

# To train a streaming model

./zipformer/train.py \
  --world-size 4 \
  --num-epochs 30 \
  --start-epoch 1 \
  --use-fp16 1 \
  --causal 1 \
  --exp-dir zipformer/exp \
  --max-duration 1000

"""

@@ -54,7 +57,6 @@ from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import k2
import numpy
import optim
import sentencepiece as spm
import torch
@@ -70,11 +72,11 @@ from model_baseline import Transducer
from optim import Eden, ScaledAdam
from scaling import ScheduledFloat
from subsampling import Conv2dSubsampling
from text_normalization import train_text_normalization, upper_only_alpha
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from text_normalization import train_text_normalization, upper_only_alpha
from zipformer import Zipformer2

from icefall import diagnostics
@@ -95,41 +97,28 @@ from icefall.utils import (
    str2bool,
)

LRSchedulerType = Union[
    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


def random_sampling(texts: List[str]) -> str:
    return random.choice(texts)


def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
    i = random.randint(0, 1)
    out = {
        "text": texts[i],
        "pre_text": pre_texts[i]
    }
    return out


def get_first(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
    rare_word_list: Optional[List[str]] = None,
) -> str:
    out = {
        "text": texts[0],
        "pre_text": pre_texts[0]
    }
    # Always get the first one, which is the mixed-cased text with punc
    out = {"text": texts[0], "pre_text": pre_texts[0]}
    return out

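A small, made-up illustration of how these selectors behave: each returns a dict with "text" and "pre_text" entries rather than a bare string, with index 0 holding the mixed-cased, punctuated transcript (the sample strings below are invented):

texts = ["Hello, world!", "HELLO WORLD"]
pre_texts = ["A previous sentence.", "A PREVIOUS SENTENCE"]

print(get_first(texts, pre_texts))
# -> {'text': 'Hello, world!', 'pre_text': 'A previous sentence.'}

print(joint_random_sampling(texts, pre_texts))
# -> picks index 0 or 1 at random for both fields, e.g.
#    {'text': 'HELLO WORLD', 'pre_text': 'A PREVIOUS SENTENCE'}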

def get_upper_only_alpha(
    texts: List[str],
    pre_texts: List[str],
    context_list: Optional[str] = None,
    rare_word_list: Optional[List[str]] = None,
) -> str:
    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
    # Always get the first one, which is the mixed-cased text with punc,
    # but with upper case it and remove punctuation
    out = {
        "text": upper_only_alpha(texts[0]),
        "pre_text": upper_only_alpha(pre_texts[0]),
@@ -252,13 +241,12 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        default=512,
        help="Embedding dimension in the decoder model.",
    )


    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means tri-gram",
        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
    )

    parser.add_argument(
@@ -413,8 +401,7 @@ def get_parser():
        "--am-scale",
        type=float,
        default=0.0,
        help="The scale to smooth the loss with am (output of encoder network)"
        "part.",
        help="The scale to smooth the loss with am (output of encoder network)" "part.",
    )

    parser.add_argument(
@@ -582,7 +569,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
    )
    return encoder_embed



def get_encoder_model(params: AttributeDict) -> nn.Module:
    encoder = Zipformer2(
        output_downsampling_factor=2,
@@ -789,11 +777,7 @@ compute_loss(
      warmup: a floating point value which increases throughout training;
        values >= 1.0 are fully warmed up and have all modules present.
    """
    device = (
        model.device
        if isinstance(model, DDP)
        else next(model.parameters()).device
    )
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    assert feature.ndim == 3
@@ -809,7 +793,7 @@
    texts = [train_text_normalization(t) for t in texts]
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y).to(device)


    if random.random() < 0.02:
        logging.info(f"Ref texts: {texts[0]}")

@@ -844,9 +828,7 @@
    info = MetricsTracker()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        info["frames"] = (
            (feature_lens // params.subsampling_factor).sum().item()
        )
        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()

    # Note: We use reduction=sum while computing the loss.
    info["loss"] = loss.detach().cpu().item()
@@ -1035,9 +1017,7 @@ def train_one_epoch(
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

            if cur_grad_scale < 8.0 or (
                cur_grad_scale < 32.0 and batch_idx % 400 == 0
            ):
            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
@@ -1059,11 +1039,7 @@
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
                + (
                    f"grad_scale: {scaler._scale.item()}"
                    if params.use_fp16
                    else ""
                )
                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
@@ -1074,9 +1050,7 @@
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
@@ -1084,10 +1058,7 @@
                        params.batch_idx_train,
                    )

            if (
                batch_idx % params.valid_interval == 0
                and not params.print_diagnostics
            ):
            if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
                logging.info("Computing validation loss")
                valid_info = compute_validation_loss(
                    params=params,
@@ -1177,9 +1148,7 @@ def run(rank, world_size, args):
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(
            model, lr=params.base_lr, include_names=True
        ),
        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )
@@ -1200,7 +1169,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
            2 ** 22
            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1210,7 +1179,7 @@ def run(rank, world_size, args):
    libriheavy = LibriHeavyAsrDataModule(args)

    train_cuts = libriheavy.train_cuts()


    def remove_short_and_long_utt(c: Cut):
        # Keep only utterances with duration between 1 second and 20 seconds
        #
@@ -1265,14 +1234,14 @@
    valid_cuts = libriheavy.dev_cuts()
    valid_dl = libriheavy.valid_dataloaders(valid_cuts)

    # if not params.print_diagnostics:
    #     scan_pessimistic_batches_for_oom(
    #         model=model,
    #         train_dl=train_dl,
    #         optimizer=optimizer,
    #         sp=sp,
    #         params=params,
    #     )
    if not params.print_diagnostics:
        scan_pessimistic_batches_for_oom(
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
            sp=sp,
            params=params,
        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
File diff suppressed because it is too large.