Mirror of https://github.com/k2-fsa/icefall.git
Add training code.

parent 232caf51ee
commit cd5ed7db20
@@ -121,7 +121,7 @@ class Transducer(nn.Module):
             logit_lengths=x_lens,
             target_lengths=y_lens,
             blank=blank_id,
-            reduction="mean",
+            reduction="sum",
         )

         return loss
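The RNN-T loss switches from reduction="mean" to reduction="sum"; the training script's compute_loss (further below) then divides the displayed loss by the batch size to recover a per-utterance average. A minimal sketch of the difference, assuming the call being configured here is torchaudio.functional.rnnt_loss (its keyword arguments match) and using illustrative shapes:

import torch
import torchaudio.functional as F

N, T, U, C = 2, 50, 10, 500  # batch, frames, target length, vocab size
logits = torch.randn(N, T, U + 1, C)                      # joiner output
targets = torch.randint(1, C, (N, U), dtype=torch.int32)  # token ids, no blank
logit_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), U, dtype=torch.int32)

loss_sum = F.rnnt_loss(logits, targets, logit_lengths, target_lengths,
                       blank=0, reduction="sum")
loss_mean = F.rnnt_loss(logits, targets, logit_lengths, target_lengths,
                        blank=0, reduction="mean")
assert torch.isclose(loss_sum / N, loss_mean)  # same quantity, normalized later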
@@ -25,6 +25,7 @@ from shutil import copyfile
 from typing import Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -36,21 +37,15 @@ from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transducer.conformer import Conformer
 from transducer.decoder import Decoder
+from transducer.joiner import Joiner
+from transducer.model import Transducer
 from transducer.transformer import Noam

-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
-from icefall.utils import (
-    AttributeDict,
-    MetricsTracker,
-    encode_supervisions,
-    setup_logger,
-    str2bool,
-)
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool


 def get_parser():
@@ -107,22 +102,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--lang-dir",
+        "--bpe-model",
         type=str,
-        default="data/lang_bpe_500",
-        help="""The lang dir
-        It contains language related input files such as
-        "lexicon.txt"
-        """,
-    )
-
-    parser.add_argument(
-        "--att-rate",
-        type=float,
-        default=0.8,
-        help="""The attention rate.
-        The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
-        """,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
     )

     parser.add_argument(
@@ -178,16 +161,8 @@ def get_params() -> AttributeDict:

         - attention_dim: Hidden dim for multi-head attention model.

-        - head: Number of heads of multi-head attention model.
-
         - num_decoder_layers: Number of decoder layer of transformer decoder.

-        - beam_size: It is used in k2.ctc_loss
-
-        - reduction: It is used in k2.ctc_loss
-
-        - use_double_scores: It is used in k2.ctc_loss
-
         - weight_decay: The weight_decay for the optimizer.

         - warm_step: The warm_step for Noam optimizer.
@@ -213,16 +188,9 @@ def get_params() -> AttributeDict:
             "vgg_frontend": False,
             "use_feat_batchnorm": True,
             # decoder params
-            "vocab_size": 500,  # including blank
             "decoder_embedding_dim": 1024,
-            "blank_id": 0,
-            "sos_id": 1,
             "num_decoder_layers": 4,
             "decoder_hidden_dim": 512,
-            # parameters for loss
-            "beam_size": 10,
-            "reduction": "sum",
-            "use_double_scores": True,
             # parameters for Noam
             "weight_decay": 1e-6,
             "warm_step": 80000,
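The hard-coded "vocab_size", "blank_id", and "sos_id" entries are dropped because these values are now derived from the BPE model when run() starts (see the hunk for run() further below). A minimal sketch of that lookup, assuming a model produced by local/train_bpe_model.py with <blk> at id 0 and <sos/eos> at id 1:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # the --bpe-model default

print(sp.piece_to_id("<blk>"))      # expected 0: the blank id
print(sp.piece_to_id("<sos/eos>"))  # expected 1: the sos/eos id
print(sp.get_piece_size())          # 500 for lang_bpe_500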
@@ -262,6 +230,27 @@ def get_decoder_model(params: AttributeDict):
     return decoder


+def get_joiner_model(params: AttributeDict):
+    joiner = Joiner(
+        input_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
+    )
+    return joiner
+
+
+def get_transducer_model(params: AttributeDict):
+    encoder = get_encoder_model(params)
+    decoder = get_decoder_model(params)
+    joiner = get_joiner_model(params)
+
+    model = Transducer(
+        encoder=encoder,
+        decoder=decoder,
+        joiner=joiner,
+    )
+    return model
+
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
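The Joiner module itself is not part of this diff; a minimal additive RNN-T joiner consistent with the input_dim/output_dim interface used above might look like the sketch below (an assumption, not necessarily the committed implementation):

import torch
import torch.nn as nn


class Joiner(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor):
        # encoder_out: (N, T, C), decoder_out: (N, U+1, C).
        # Broadcast-add every encoder frame against every decoder step.
        logit = encoder_out.unsqueeze(2) + decoder_out.unsqueeze(1)
        return self.output_linear(torch.tanh(logit))  # (N, T, U+1, output_dim)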
@@ -352,8 +341,8 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: nn.Module,
+    sp: spm.SentencePieceProcessor,
     batch: dict,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
@@ -367,86 +356,35 @@ def compute_loss(
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
-      graph_compiler:
-        It is used to build a decoding graph from a ctc topo and training
-        transcript. The training transcript is contained in the given `batch`,
-        while the ctc topo is built when this compiler is instantiated.
       is_training:
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
     """
-    device = graph_compiler.device
+    device = model.device
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)

     supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    texts = batch["supervisions"]["text"]
+    y = sp.encode(texts, out_type=int)
+    y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-        # nnet_output is (N, T, C)
-
-        # NOTE: We need `encode_supervisions` to sort sequences with
-        # different duration in decreasing order, required by
-        # `k2.intersect_dense` called in `k2.ctc_loss`
-        supervision_segments, texts = encode_supervisions(
-            supervisions, subsampling_factor=params.subsampling_factor
-        )
-
-        token_ids = graph_compiler.texts_to_ids(texts)
-
-        decoding_graph = graph_compiler.compile(token_ids)
-
-        dense_fsa_vec = k2.DenseFsaVec(
-            nnet_output,
-            supervision_segments,
-            allow_truncate=params.subsampling_factor - 1,
-        )
-
-        ctc_loss = k2.ctc_loss(
-            decoding_graph=decoding_graph,
-            dense_fsa_vec=dense_fsa_vec,
-            output_beam=params.beam_size,
-            reduction=params.reduction,
-            use_double_scores=params.use_double_scores,
-        )
-
-    if params.att_rate != 0.0:
-        with torch.set_grad_enabled(is_training):
-            mmodel = model.module if hasattr(model, "module") else model
-            # Note: We need to generate an unsorted version of token_ids
-            # `encode_supervisions()` called above sorts text, but
-            # encoder_memory and memory_mask are not sorted, so we
-            # use an unsorted version `supervisions["text"]` to regenerate
-            # the token_ids
-            #
-            # See https://github.com/k2-fsa/icefall/issues/97
-            # for more details
-            unsorted_token_ids = graph_compiler.texts_to_ids(
-                supervisions["text"]
-            )
-            att_loss = mmodel.decoder_forward(
-                encoder_memory,
-                memory_mask,
-                token_ids=unsorted_token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
-            )
-        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
-    else:
-        loss = ctc_loss
-        att_loss = torch.tensor([0])
+        loss = model(x=feature, x_lens=feature_lens, y=y)

     assert loss.requires_grad == is_training

     info = MetricsTracker()
-    info["frames"] = supervision_segments[:, 2].sum().item()
-    info["ctc_loss"] = ctc_loss.detach().cpu().item()
-    if params.att_rate != 0.0:
-        info["att_loss"] = att_loss.detach().cpu().item()
-
-    info["loss"] = loss.detach().cpu().item()
+    info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+    # We use reduction="sum" in computing the loss.
+    # The displayed loss is the average loss over the batch
+    info["loss"] = loss.detach().cpu().item() / feature.size(0)

     return loss, info
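compute_loss no longer needs a graph compiler: transcripts are encoded directly with SentencePiece and passed to the model as a k2 ragged tensor via model(x=feature, x_lens=feature_lens, y=y). A minimal sketch of what that target-building step produces, assuming the BPE model from --bpe-model exists on disk:

import k2
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")

texts = ["HELLO WORLD", "YES"]      # supervisions["text"] from a batch
y = sp.encode(texts, out_type=int)  # list of per-utterance token-id lists
y = k2.RaggedTensor(y)              # ragged: rows may have different lengths
print(y.dim0)                       # 2 utterances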
@@ -454,7 +392,7 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: nn.Module,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -467,8 +405,8 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=False,
         )
         assert loss.requires_grad is False
@@ -489,7 +427,7 @@ def train_one_epoch(
     params: AttributeDict,
     model: nn.Module,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     tb_writer: Optional[SummaryWriter] = None,
@@ -508,8 +446,6 @@ def train_one_epoch(
         The model for training.
       optimizer:
         The optimizer we are using.
-      graph_compiler:
-        It is used to convert transcripts to FSAs.
       train_dl:
         Dataloader for the training dataset.
       valid_dl:
@@ -530,8 +466,8 @@ def train_one_epoch(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
+            sp=sp,
             batch=batch,
-            graph_compiler=graph_compiler,
             is_training=True,
         )
         # summary stats
@@ -567,7 +503,7 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
+                sp=sp,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -606,50 +542,37 @@ def run(rank, world_size, args):

     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
-    logging.info(params)

     if args.tensorboard and rank == 0:
         tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
     else:
         tb_writer = None

-    model = get_encoder_model(params)
-    model = get_decoder_model(params)
-    print(model)
-    return
-
-    lexicon = Lexicon(params.lang_dir)
-    max_token_id = max(lexicon.tokens)
-    num_classes = max_token_id + 1  # +1 for the blank
-
     device = torch.device("cpu")
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")

-    graph_compiler = BpeCtcTrainingGraphCompiler(
-        params.lang_dir,
-        device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
-    )
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.sos_id = sp.piece_to_id("<sos/eos>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)

     logging.info("About to create model")
-    model = Conformer(
-        num_features=params.feature_dim,
-        nhead=params.nhead,
-        d_model=params.attention_dim,
-        num_classes=num_classes,
-        subsampling_factor=params.subsampling_factor,
-        num_decoder_layers=params.num_decoder_layers,
-        vgg_frontend=False,
-        use_feat_batchnorm=params.use_feat_batchnorm,
-    )
+    model = get_transducer_model(params)

     checkpoints = load_checkpoint_if_available(params=params, model=model)

     model.to(device)
     if world_size > 1:
+        logging.info("Using DDP")
         model = DDP(model, device_ids=[rank])
+    model.device = device

     optimizer = Noam(
         model.parameters(),
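The explicit model.device = device assignment exists because compute_loss above reads device = model.device, and a plain nn.Module has no such attribute; run() attaches it after the optional DDP wrapping, so compute_loss sees it on whatever object model ends up being. A small sketch of the pattern, with nn.Linear standing in for the transducer model:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)       # stand-in for the transducer model
device = torch.device("cpu")
model.to(device)
model.device = device         # plain Python attribute, set after any wrapping
assert model.device == device # what compute_loss-style code relies on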
@@ -659,7 +582,8 @@ def run(rank, world_size, args):
         weight_decay=params.weight_decay,
     )

-    if checkpoints:
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
         optimizer.load_state_dict(checkpoints["optimizer"])

     librispeech = LibriSpeechAsrDataModule(args)
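Guarding on "optimizer" in checkpoints makes resuming tolerant of checkpoints saved without optimizer state, which would otherwise raise a KeyError. A tiny sketch with a hypothetical checkpoint dict:

checkpoints = {"model": {}, "epoch": 3}  # hypothetical: no optimizer state
if checkpoints and "optimizer" in checkpoints:
    print("loading optimizer state dict")  # optimizer.load_state_dict(...)
else:
    print("no optimizer state in checkpoint; optimizer starts fresh")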
@@ -678,7 +602,7 @@ def run(rank, world_size, args):
         model=model,
         train_dl=train_dl,
         optimizer=optimizer,
-        graph_compiler=graph_compiler,
+        sp=sp,
         params=params,
     )

@@ -701,7 +625,7 @@ def run(rank, world_size, args):
             params=params,
             model=model,
             optimizer=optimizer,
-            graph_compiler=graph_compiler,
+            sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
             tb_writer=tb_writer,
@@ -726,7 +650,7 @@ def scan_pessimistic_batches_for_oom(
     model: nn.Module,
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: BpeCtcTrainingGraphCompiler,
+    sp: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -742,8 +666,8 @@ def scan_pessimistic_batches_for_oom(
             loss, _ = compute_loss(
                 params=params,
                 model=model,
+                sp=sp,
                 batch=batch,
-                graph_compiler=graph_compiler,
                 is_training=True,
             )
             loss.backward()
@@ -766,7 +690,6 @@ def main():
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     world_size = args.world_size
     assert world_size >= 1