train and inference

This commit is contained in:
glynpu 2023-03-16 12:31:22 +08:00
parent b55ae4fd53
commit a49817385a
6 changed files with 1080 additions and 1 deletions

View File

@ -0,0 +1,49 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
# Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
"""
A graph starts with blank/unknown and follwoing by wakeup word.
Args:
wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
It should not contain 0 and 1.
We assume 0 is for blank and 1 is for unknown.
"""
assert 0 not in wakeup_word_tokens
assert 1 not in wakeup_word_tokens
assert len(wakeup_word_tokens) >= 2
keyword_ilabel_start = wakeup_word_tokens[0]
fst_graph = ""
for non_wake_word_token in range(keyword_ilabel_start):
fst_graph += f"0 0 {non_wake_word_token} 0\n"
cur_state = 1
for token_idx in range(len(wakeup_word_tokens) - 1):
token = wakeup_word_tokens[token_idx]
fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n"
fst_graph += f"{cur_state} {cur_state} {token} 0\n"
cur_state += 1
token = wakeup_word_tokens[-1]
fst_graph += f"{cur_state - 1} {cur_state} {token} 1\n"
fst_graph += f"{cur_state} {cur_state} {token} 0\n"
fst_graph += f"{cur_state}\n"
return fst_graph

View File

@ -0,0 +1,203 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
import torch
from lhotse.features.io import NumpyHdf5Writer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
)
from asr_datamodule import HiMiaWuwDataModule
from tdnn import Tdnn
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=10,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 1.",
)
parser.add_argument(
"--avg",
type=int,
default=1,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="ctc_tdnn/exp",
help="The experiment dir",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"env_info": get_env_info(),
"feature_dim": 80,
"number_class": 9,
}
)
return params
def inference_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: torch.nn.Module,
test_set: str,
):
"""Compute and save model output of each utterance.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
test_set:
Name of test set.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
writer = NumpyHdf5Writer(f"{params.out_dir}/{test_set}")
for batch_idx, batch in enumerate(dl):
device = params.device
feature = batch["inputs"]
assert feature.ndim == 3
supervisions = batch["supervisions"]
start_frames = supervisions["start_frame"]
end_frames = start_frames + supervisions["num_frames"]
feature = feature.to(device)
# model_output is log_softmax(logit) with shape [N, T, C]
model_output = model(feature)
for i in range(feature.size(0)):
assert start_frames[i] == 0
cut = batch["supervisions"]["cut"][i]
cur_target = model_output[i][start_frames[i] : end_frames[i]]
writer.store_array(key=cut.id, value=cur_target.cpu().numpy())
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
@torch.no_grad()
def main():
parser = get_parser()
HiMiaWuwDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/"
Path(out_dir).mkdir(parents=True, exist_ok=True)
params.out_dir = out_dir
setup_logger(f"{out_dir}/log-decode")
logging.info("Decoding started")
logging.info(params)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Tdnn(params.feature_dim, params.number_class)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(
average_checkpoints(filenames, device=device), strict=True
)
model.to(device)
model.eval()
params.device = device
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
himia = HiMiaWuwDataModule(args)
aishell_test_cuts = himia.aishell_test_cuts()
test_cuts = himia.test_cuts()
cw_test_cuts = himia.cw_test_cuts()
aishell_test_dl = himia.test_dataloaders(aishell_test_cuts)
test_dl = himia.test_dataloaders(test_cuts)
cw_test_dl = himia.test_dataloaders(cw_test_cuts)
test_sets = ["aishell_test", "test", "cw_test"]
test_dls = [aishell_test_dl, test_dl, cw_test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
inference_dataset(
dl=test_dl,
params=params,
model=model,
test_set=test_set,
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,94 @@
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import torch
from typing import List, Tuple
class WakeupWordTokenizer(object):
def __init__(
self,
wakeup_word: str = "",
wakeup_word_tokens: List[int] = None,
) -> None:
"""
Args:
wakeup_word: content of positive samples.
A sample will be treated as a negative sample unless its context
is exactly the same to key_words.
wakeup_word_tokens: A list if int represents token ids of wakeup_word.
For example: the pronunciation of "你好米雅" is
"n i h ao m i y a".
Suppose we are using following lexicon:
blk 0
unk 1
n 2
i 3
h 4
ao 5
m 6
y 7
a 8
Then wakeup_word_tokens for "你好米雅" is:
n i h ao m i y a
[2, 3, 4, 5, 6, 3, 7, 8]
"""
super().__init__()
assert wakeup_word is not None
assert wakeup_word_tokens is not None
assert (
0 not in wakeup_word_tokens
), f"0 is kept for blank. Please Remove 0 from {wakeup_word_tokens}"
assert 1 not in wakeup_word_tokens, (
f"1 is kept for unknown and negative samples. "
f" Please Remove 1 from {wakeup_word_tokens}"
)
self.wakeup_word = wakeup_word
self.wakeup_word_tokens = wakeup_word_tokens
self.positive_number_tokens = len(wakeup_word_tokens)
self.negative_word_tokens = [1]
self.negative_number_tokens = 1
def texts_to_token_ids(self, texts: List[str]) -> Tuple[torch.Tensor, int]:
"""Convert a list of texts to a list of k2.Fsa based texts.
Args:
texts:
It is a list of strings.
Returns:
Return a list of k2.Fsa, one for an element in texts.
If the element is `wakeup_word`, a graph for positive samples is appneded
into resulting graph_vec, otherwise, a graph for negative samples is appended.
Number of positive samples is also returned to track its proportion.
"""
batch_token_ids = []
target_lengths = []
number_positive_samples = 0
for utt_text in texts:
if utt_text == self.wakeup_word:
batch_token_ids.append(self.wakeup_word_tokens)
target_lengths.append(self.positive_number_tokens)
number_positive_samples += 1
else:
batch_token_ids.append(self.negative_word_tokens)
target_lengths.append(self.negative_number_tokens)
target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
target_lengths = torch.tensor(target_lengths)
return target, target_lengths, number_positive_samples

678
egs/himia/wuw/ctc_tdnn/train.py Executable file
View File

@ -0,0 +1,678 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./ctc_tdnn/train.py \
--exp-dir ./tdnn/exp \
--world-size 4 \
--max-duration 200 \
--num-epochs 20
"""
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import HiMiaWuwDataModule
from tdnn import Tdnn
from lhotse.cut import Cut
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from tokenizer import WakeupWordTokenizer
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
MetricsTracker,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=20,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=1,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
ctc_tdnn/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="ctc_tdnn/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lr-factor",
type=float,
default=0.001,
help="The lr_factor for optimizer",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="The seed for random generators intended for reproducibility",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- number_class: Numer of classes. Each token will have a token id
from [0, num_class).
In this recipe, 0 is usually kept for blank,
and 1 is usually kept for negative words.
- wakeup_word: Text of wakeup word, i.e. positive samples.
- wakeup_word_tokens: A sequence of token ids corresponding wakeup_word.
- weight_decay: The weight_decay for the optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 5,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for model
"feature_dim": 80,
"number_class": 9,
# parameters for tokenizer
"wakeup_word": "你好米雅",
"wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8],
# parameters for Optimizer
"weight_decay": 1e-6,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is larger than 1, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch > 1:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
tokenizer: WakeupWordTokenizer,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
tokenizer:
For positive samples, map their texts to corresponding token index sequence.
While for negative samples, map their texts to unknown no matter what they are.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = model.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
N, T, C = feature.shape
feature = feature.to(device)
supervisions = batch["supervisions"]
texts = supervisions["text"]
with torch.set_grad_enabled(is_training):
# model_output is log_softmax(logit) with shape [N, T, C]
model_output = model(feature)
assert torch.all(supervisions["start_frame"] == 0)
num_frames = supervisions["num_frames"].to(device)
target, target_lengths, number_positive_samples = tokenizer.texts_to_token_ids(
texts
) # noqa E501
target = target.to(device)
target_lengths = target_lengths.to(device)
ctc_loss = nn.CTCLoss(reduction="sum")
# [N, T, C] --> [T, N, C]
model_output = model_output.transpose(0, 1)
loss = ctc_loss(model_output, target, num_frames, target_lengths)
loss /= num_frames.sum()
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = num_frames.sum().item()
info["loss"] = loss.detach().cpu().item() * info["frames"]
# `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
info["utterances"] = feature.size(0)
# averaged input duration in frames over utterances
info["utt_duration"] = supervisions["num_frames"].sum().item()
# averaged padding proportion over utterances
info["utt_pad_proportion"] = (
((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item()
)
info["number_positive_cuts_ratio"] = (number_positive_samples / N) * info["frames"]
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
tokenizer: WakeupWordTokenizer,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
tokenizer: WakeupWordTokenizer,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
tokenizer:
For positive samples, map their texts to corresponding token index sequence.
While for negative samples, map their texts to unknown no matter what they are.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
tokenizer=tokenizer,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(params.seed)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
tokenizer = WakeupWordTokenizer(
wakeup_word=params.wakeup_word,
wakeup_word_tokens=params.wakeup_word_tokens,
)
logging.info("About to create model")
model = Tdnn(params.feature_dim, params.number_class)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = torch.optim.Adam(
model.parameters(),
lr=params.lr_factor,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
himia = HiMiaWuwDataModule(args)
train_cuts = himia.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 0.5 <= c.duration <= 20.0
train_cuts = train_cuts.filter(remove_short_and_long_utt)
train_dl = himia.train_dataloaders(train_cuts)
valid_cuts = himia.dev_cuts()
valid_dl = himia.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
tokenizer=tokenizer,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
# TODO: Support lr scheduler
cur_lr = params.lr_factor
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
tokenizer=tokenizer,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
tokenizer: WakeupWordTokenizer,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
batch=batch,
tokenizer=tokenizer,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
HiMiaWuwDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -2,7 +2,7 @@
set -eou pipefail set -eou pipefail
stage=6 stage=0
stop_stage=6 stop_stage=6
# HI_MIA and aishell dataset are used in this experiment. # HI_MIA and aishell dataset are used in this experiment.

View File

@ -0,0 +1,55 @@
#!/usr/bin/env bash
set -eou pipefail
# You need to execute ./prepare.sh to prepare datasets.
stage=1
stop_stage=2
epoch=10
avg=1
exp_dir=./ctc_tdnn/exp/
epoch_avg=epoch_${epoch}-avg_${avg}
post_dir=${exp_dir}/post/${epoch_avg}
. shared/parse_options.sh || exit 1
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Model training"
python ./ctc_tdnn/train.py \
--num-epochs $epoch
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Get posterior of test sets"
python ctc_tdnn/inference.py \
--avg $avg \
--epoch $epoch \
--exp-dir ${exp_dir}
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Decode and compute area under curve(AUC)"
for test_set in test aishell_test cw_test; do
python ctc_tdnn/decode.py \
--decoding-graph ./data/LG.int \
--post-h5 ${post_dir}/${test_set}.h5 \
--score-file ${post_dir}/fst_${test_set}_pos_h5.txt
done
python ./local/auc.py \
--legend himia_cw \
--positive-score-file ${post_dir}/fst_test_pos_h5.txt \
--negative-score-file ${post_dir}/fst_cw_test_pos_h5.txt
python ./local/auc.py \
--legend himia_aishell \
--positive-score-file ${post_dir}/fst_test_pos_h5.txt \
--negative-score-file ${post_dir}/fst_aishell_test_pos_h5.txt
fi