Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-10 17:44:20 +00:00)

Commit 8401f26342 (parent 58dc0430be): update some documentation for cross-attention zipformer
Removed file:

@@ -1,141 +0,0 @@
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from scaling import ScaledLinear, softmax

from icefall.utils import make_pad_mask

class ContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 256,
    ):
        super().__init__()

        self.embed_dim = embed_dim

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor=None,
        padding_mask: torch.Tensor=None
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The length of context embeddings, (B,)

        Returns:
            torch.Tensor: The fused context embeddings, (B,C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        # by a small probability, dropout the context of a few samples
        context_dropout_rate = 0.05
        m = torch.rand((batch_size, 1, 1), device=context.device) > context_dropout_rate
        context = context * m

        # average the context
        context = context.sum(dim=1) / context_lens

        return context

class SelfAttContextFuser(nn.Module):
    def __init__(
        self,
        embed_dim: int = 384,
        query_head_dim: int = 128,
        nhead: int=4,
        context_dropout_rate: float=0.05,
    ):
        """ContextFuser with multi-head self-attention

        Args:
            embed_dim (int, optional): The input embedding dim. Defaults to 256.
            nhead (int, optional): The number of heads. Defaults to 4.
        """

        super().__init__()

        self.embed_dim = embed_dim
        self.nhead = nhead
        self.query_head_dim = query_head_dim

        self.in_proj = ScaledLinear(embed_dim, nhead * query_head_dim)
        self.weight_proj = ScaledLinear(nhead * query_head_dim, nhead)
        self.context_dropout_rate = context_dropout_rate

    def forward(
        self,
        context: torch.Tensor,
        context_lens: torch.Tensor=None,
        padding_mask: torch.Tensor=None,
    ) -> torch.Tensor:
        """A module fusing the context embedding vectors

        Args:
            context (torch.Tensor): The context embeddings, (B,W,C)
            context_lens (torch.Tensor): The length of context embeddings, (B,)
            padding_mask (torch.Tensor): A padding mask (B,W)

        Returns:
            torch.Tensor: The fused context embeddings, (B,C)
        """
        batch_size = context.size(0)
        if padding_mask is None:
            assert context_lens is not None
            padding_mask = make_pad_mask(context_lens).unsqueeze(-1)
        else:
            if padding_mask.ndim != 3:
                padding_mask = padding_mask.unsqueeze(-1)
        # context.masked_fill_(padding_mask, 0)

        if context_lens is None:
            max_len = padding_mask.size(1)
            context_lens = max_len - padding_mask.sum(dim=1) + 1e-5  # to prevent 0

        k = self.in_proj(context)  # (B,W,C)
        w = self.weight_proj(torch.tanh(k))  # (B,W,num_heads)

        w.masked_fill_(padding_mask, -1000)
        w = softmax(w, dim=1)  # (B,W,num_heads)
        w = w.permute(0, 2, 1)  # (B,num_heads,W)

        # reweight and concat the context embeddings
        context = torch.matmul(w, context).view(batch_size, -1)  # (B, num_heads * C)

        # by a small probability, dropout the context of a few samples
        if self.training:
            m = torch.rand((batch_size, 1), device=context.device) > self.context_dropout_rate
            context = context * m
            # context = context * 0.0

        return context
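For reference, here is a minimal smoke test showing how the removed SelfAttContextFuser was intended to be called. This is a hypothetical sketch, not part of the commit: it assumes the code above is saved locally as context_fuser.py (the module name is an assumption) and that the recipe's scaling.py and icefall.utils are importable.

# Hypothetical smoke test for the module removed above. Assumes the code is
# saved as context_fuser.py and that scaling.py / icefall.utils (from the
# icefall recipe directory) are on PYTHONPATH.
import torch

from context_fuser import SelfAttContextFuser  # assumed module name

fuser = SelfAttContextFuser(embed_dim=384, query_head_dim=128, nhead=4)
fuser.eval()  # disable the random context-dropout branch

context = torch.randn(2, 10, 384)     # (B, W, C): 2 samples, 10 context tokens
context_lens = torch.tensor([10, 7])  # (B,): valid lengths per sample

with torch.no_grad():
    fused = fuser(context, context_lens=context_lens)

# Per-head attention weights over W are masked by the padding mask, and the
# per-head weighted sums are concatenated, giving (B, nhead * embed_dim).
print(fused.shape)  # torch.Size([2, 1536])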
Changed file (training script):

@@ -20,26 +20,29 @@
 """
 Usage:
 
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
 
-./pruned_transducer_stateless7/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 300
 
 # For mix precision training:
 
-./pruned_transducer_stateless7/train.py \
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./zipformer/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless7/exp \
-  --full-libri 1 \
-  --max-duration 550
+  --exp-dir zipformer/exp \
+  --max-duration 1000
+# To train a streaming model
 
+./zipformer/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --causal 1
+  --exp-dir zipformer/exp \
+  --max-duration 1000
+
 """
 
@@ -54,7 +57,6 @@ from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import k2
-import numpy
 import optim
 import sentencepiece as spm
 import torch
@@ -70,11 +72,11 @@ from model_baseline import Transducer
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
 from subsampling import Conv2dSubsampling
+from text_normalization import train_text_normalization, upper_only_alpha
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from text_normalization import train_text_normalization, upper_only_alpha
 from zipformer import Zipformer2
 
 from icefall import diagnostics
@@ -95,41 +97,28 @@ from icefall.utils import (
     str2bool,
 )
 
-LRSchedulerType = Union[
-    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
-]
-
-def random_sampling(texts: List[str]) -> str:
-    return random.choice(texts)
-
-def joint_random_sampling(texts: List[str], pre_texts: List[str]) -> str:
-    i = random.randint(0, 1)
-    out = {
-        "text": texts[i],
-        "pre_text": pre_texts[i]
-    }
-    return out
-
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
 def get_first(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
     rare_word_list: Optional[List[str]] = None,
 ) -> str:
-    out = {
-        "text": texts[0],
-        "pre_text": pre_texts[0]
-    }
+    # Always get the first one, which is the mixed-cased text with punc
+    out = {"text": texts[0], "pre_text": pre_texts[0]}
     return out
 
 
 def get_upper_only_alpha(
     texts: List[str],
     pre_texts: List[str],
     context_list: Optional[str] = None,
     rare_word_list: Optional[List[str]] = None,
 ) -> str:
-    # Always get the first one, which is the gt (mixed-cased trans), but with upper_only_alpha
+    # Always get the first one, which is the mixed-cased text with punc,
+    # but with upper case it and remove punctuation
     out = {
         "text": upper_only_alpha(texts[0]),
         "pre_text": upper_only_alpha(pre_texts[0]),
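As a side note on the helpers touched above: get_upper_only_alpha builds its training sample from the first (mixed-case, punctuated) transcript and normalizes it with upper_only_alpha from text_normalization.py. Below is a minimal sketch of the idea; the real upper_only_alpha lives in the recipe, so it is re-implemented here under the assumption (taken from the comment in the diff) that it upper-cases the text and strips punctuation.

# Hypothetical illustration of the normalization used by get_upper_only_alpha.
# The regex below is an assumed stand-in for the recipe's upper_only_alpha.
import re

def upper_only_alpha_sketch(text: str) -> str:
    # Assumed behaviour (per the diff comment): upper-case and drop punctuation.
    return re.sub(r"[^A-Z' ]+", "", text.upper()).strip()

texts = ["Mr. Brown lives in Boston.", "MR BROWN LIVES IN BOSTON"]
pre_texts = ["He moved there recently.", "HE MOVED THERE RECENTLY"]

# Mirrors what get_upper_only_alpha in the diff does, but with the sketch helper:
out = {
    "text": upper_only_alpha_sketch(texts[0]),
    "pre_text": upper_only_alpha_sketch(pre_texts[0]),
}
print(out)  # {'text': 'MR BROWN LIVES IN BOSTON', 'pre_text': 'HE MOVED THERE RECENTLY'}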
@@ -257,8 +246,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; "
-        "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -413,8 +401,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)"
-        "part.",
+        help="The scale to smooth the loss with am (output of encoder network)" "part.",
     )
 
     parser.add_argument(
@@ -583,6 +570,7 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
     )
     return encoder_embed
 
+
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer2(
         output_downsampling_factor=2,
@@ -789,11 +777,7 @@ def compute_loss(
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
-    device = (
-        model.device
-        if isinstance(model, DDP)
-        else next(model.parameters()).device
-    )
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
@@ -844,9 +828,7 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -1035,9 +1017,7 @@ def train_one_epoch(
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()

-            if cur_grad_scale < 8.0 or (
-                cur_grad_scale < 32.0 and batch_idx % 400 == 0
-            ):
+            if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0):
                scaler.update(cur_grad_scale * 2.0)
            if cur_grad_scale < 0.01:
                if not saved_bad_model:
@@ -1059,11 +1039,7 @@ def train_one_epoch(
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}, "
-                + (
-                    f"grad_scale: {scaler._scale.item()}"
-                    if params.use_fp16
-                    else ""
-                )
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
            )

            if tb_writer is not None:
@@ -1074,9 +1050,7 @@ def train_one_epoch(
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
-                tot_loss.write_summary(
-                    tb_writer, "train/tot_", params.batch_idx_train
-                )
+                tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
                if params.use_fp16:
                    tb_writer.add_scalar(
                        "train/grad_scale",
@@ -1084,10 +1058,7 @@ def train_one_epoch(
                        params.batch_idx_train,
                    )

-            if (
-                batch_idx % params.valid_interval == 0
-                and not params.print_diagnostics
-            ):
+            if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
                logging.info("Computing validation loss")
                valid_info = compute_validation_loss(
                    params=params,
@@ -1177,9 +1148,7 @@ def run(rank, world_size, args):
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    optimizer = ScaledAdam(
-        get_parameter_groups_with_lrs(
-            model, lr=params.base_lr, include_names=True
-        ),
+        get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
        lr=params.base_lr,  # should have no effect
        clipping_scale=2.0,
    )
@@ -1200,7 +1169,7 @@ def run(rank, world_size, args):

    if params.print_diagnostics:
        opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
        )  # allow 4 megabytes per sub-module
        diagnostic = diagnostics.attach_diagnostics(model, opts)

@@ -1265,14 +1234,14 @@ def run(rank, world_size, args):
    valid_cuts = libriheavy.dev_cuts()
    valid_dl = libriheavy.valid_dataloaders(valid_cuts)

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
    if checkpoints and "grad_scaler" in checkpoints:
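The block re-enabled above calls scan_pessimistic_batches_for_oom, which pushes the sampler's most demanding (longest-duration) batches through a forward/backward pass before training starts, so out-of-memory failures surface immediately rather than hours into an epoch. The sketch below illustrates that idea only; it is not the icefall implementation, and model, batches, and compute_loss are placeholders.

# Generic sketch of a pessimistic-batch OOM scan (illustration, not icefall's code).
import torch

def scan_for_oom_sketch(model: torch.nn.Module, batches, compute_loss) -> None:
    """Run the largest batches once (forward + backward) to surface OOM early.

    `batches` is assumed to be an iterable of the most demanding batches the
    sampler can produce; `compute_loss(model, batch)` returns a scalar loss.
    """
    for i, batch in enumerate(batches):
        try:
            loss = compute_loss(model, batch)
            loss.backward()
            model.zero_grad(set_to_none=True)
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                raise RuntimeError(
                    f"OOM on pessimistic batch {i}; consider reducing --max-duration"
                ) from e
            raise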
(The diff of one additional file in this commit is suppressed because it is too large.)