[KWS] Remove graph compiler (#1905)

Wei Kang 2025-04-02 22:10:06 +08:00 committed by GitHub
parent db9fb8ad31
commit 86bd16d496
4 changed files with 43 additions and 68 deletions

View File

@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 2: Finetune the model"
log "Stage 3: Finetune the model"
# The following configuration of lr schedule should work well
# You may also tune the following parameters to adjust learning rate schedule
@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 1: Decode the finetuned model."
log "Stage 4: Decode the finetuned model."
export CUDA_VISIBLE_DEVICES="0"
for t in small large; do
python ./zipformer/decode.py \
@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 2: Export the finetuned model."
log "Stage 5: Export the finetuned model."
python ./zipformer/export.py \
--epoch 10 \

View File

@ -35,7 +35,6 @@ from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph
- from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,

View File

@ -90,6 +90,7 @@ from train import (
add_training_arguments,
compute_validation_loss,
display_and_save_batch,
+ encode_text,
get_adjusted_batch_count,
get_model,
get_params,
@ -100,7 +101,6 @@ from train import (
)
from icefall import diagnostics
- from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -110,11 +110,11 @@ from icefall.checkpoint import (
from icefall.dist import cleanup_dist, setup_dist
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
- from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
+ num_tokens,
setup_logger,
str2bool,
text_to_pinyin,
@ -254,7 +254,6 @@ def load_model_params(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@ -289,7 +288,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
- y = graph_compiler.texts_to_ids(texts, sep="/")
+ y = [c.supervisions[0].tokens for c in supervisions["cut"]]
y = k2.RaggedTensor(y)
with torch.set_grad_enabled(is_training):
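Note on this hunk: the labels are no longer compiled from the raw text at loss time; each cut already carries integer token ids (attached by encode_text when the cuts are mapped), and compute_loss only gathers them into a ragged tensor. A minimal runnable sketch of that gathering step, with made-up ids standing in for c.supervisions[0].tokens:

    import k2

    # Stand-ins for the per-cut ids produced earlier by encode_text.
    per_cut_tokens = [[12, 7, 3], [5, 9], [4, 4, 4, 1]]
    y = k2.RaggedTensor(per_cut_tokens)
    print(y.dim0)         # 3 utterances
    print(y.tot_size(1))  # 9 tokens in total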
@ -347,7 +346,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
- graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@ -418,7 +416,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@ -436,7 +433,7 @@ def train_one_epoch(
optimizer.zero_grad()
except: # noqa
save_bad_model()
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5:
@ -523,7 +520,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@ -576,14 +572,10 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
- lexicon = Lexicon(params.lang_dir)
- graph_compiler = CharCtcTrainingGraphCompiler(
- lexicon=lexicon,
- device=device,
- )
+ token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")
- params.blank_id = lexicon.token_table["<blk>"]
- params.vocab_size = max(lexicon.tokens) + 1
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
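The Lexicon/graph-compiler pair above is replaced by reading tokens.txt directly. A small self-contained sketch of the new blank_id/vocab_size computation; the inline table is made up (the recipe reads <lang_dir>/tokens.txt), and num_tokens is the icefall.utils helper imported above, which skips disambiguation symbols and the id-0 entry:

    import k2
    from icefall.utils import num_tokens

    token_table = k2.SymbolTable.from_str("<blk> 0\nni3 1\nhao3 2\n<unk> 3")
    blank_id = token_table["<blk>"]            # -> 0, as params.blank_id above
    vocab_size = num_tokens(token_table) + 1   # -> 4 (3 non-blank tokens + blank)
    print(blank_id, vocab_size)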
@ -666,17 +658,10 @@ def run(rank, world_size, args):
else:
train_cuts = wenetspeech.nihaowenwen_train_cuts()
- def encode_text(c: Cut):
- # Text normalize for each sample
- text = c.supervisions[0].text
- text = "/".join(
- text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
- )
- c.supervisions[0].text = text
- return c
+ _encode_text = partial(encode_text, token_table=token_table, params=params)
train_cuts = train_cuts.filter(remove_short_utt)
- train_cuts = train_cuts.map(encode_text)
+ train_cuts = train_cuts.map(_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
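The per-sample normalization closure is gone; finetune.py now imports encode_text from train.py and binds the extra arguments with functools.partial so the mapper again takes a single Cut. A toy illustration of that partial() pattern (the names and table below are invented for the example, not part of the recipe):

    from functools import partial

    def encode(item: str, table: dict) -> list:
        # Same shape as encode_text(c, token_table=..., params=...): the extra
        # lookup state is bound once as keyword arguments.
        return [table.get(ch, table["<unk>"]) for ch in item]

    table = {"<unk>": 0, "a": 1, "b": 2}
    _encode = partial(encode, table=table)
    print([_encode(s) for s in ["ab", "ba", "ax"]])  # [[1, 2], [2, 1], [1, 0]]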
@ -691,7 +676,7 @@ def run(rank, world_size, args):
valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
valid_cuts = valid_cuts.filter(remove_short_utt)
- valid_cuts = valid_cuts.map(encode_text)
+ valid_cuts = valid_cuts.map(_encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
@ -699,7 +684,6 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
- graph_compiler=graph_compiler,
params=params,
)
@ -724,7 +708,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
- graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@ -760,6 +743,8 @@ def main():
WenetSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+ args.return_cuts = True
world_size = args.world_size
assert world_size >= 1
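args.return_cuts = True is what makes the new label path work: with it, the lhotse dataset keeps the Cut objects in batch["supervisions"]["cut"], which is where compute_loss now reads the precomputed .tokens from. A hedged sketch of the underlying lhotse option (the recipe sets it through the datamodule argument rather than constructing the dataset directly):

    from lhotse.dataset import K2SpeechRecognitionDataset

    dataset = K2SpeechRecognitionDataset(return_cuts=True)
    # Each batch then contains batch["supervisions"]["cut"], so compute_loss can do:
    # y = [c.supervisions[0].tokens for c in batch["supervisions"]["cut"]]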

View File

@ -53,6 +53,7 @@ import argparse
import copy
import logging
import warnings
+ from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Optional, Tuple, Union
@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter
from zipformer import Zipformer2
from icefall import diagnostics
- from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import (
@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.err import raise_grad_scale_is_too_small_error
from icefall.hooks import register_inf_check_hooks
- from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
+ num_tokens,
setup_logger,
str2bool,
text_to_pinyin,
@ -776,7 +776,6 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
@ -811,7 +810,7 @@ def compute_loss(
warm_step = params.warm_step
texts = batch["supervisions"]["text"]
- y = graph_compiler.texts_to_ids(texts, sep="/")
+ y = [c.supervisions[0].tokens for c in supervisions["cut"]]
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
@ -859,7 +858,6 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
- graph_compiler: CharCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@ -872,7 +870,6 @@ def compute_validation_loss(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=False,
)
@ -895,7 +892,6 @@ def train_one_epoch(
model: Union[nn.Module, DDP],
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
- graph_compiler: CharCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@ -971,7 +967,6 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@ -988,7 +983,7 @@ def train_one_epoch(
optimizer.zero_grad()
except: # noqa
save_bad_model()
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
if params.print_diagnostics and batch_idx == 5:
@ -1077,7 +1072,6 @@ def train_one_epoch(
valid_info = compute_validation_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
@ -1098,6 +1092,20 @@ def train_one_epoch(
params.best_train_loss = params.train_loss
+ def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict):
+     text = c.supervisions[0].text
+     tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
+     ids = []
+     for t in tokens:
+         if t in token_table:
+             ids.append(token_table[t])
+         else:
+             logging.warning(f"Text : {text} has OOV token : {t} , encode to <unk>")
+             ids.append(token_table["<unk>"])
+     c.supervisions[0].tokens = ids
+     return c
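The new module-level encode_text attaches integer ids to each cut up front, falling back to <unk> for pinyin syllables missing from tokens.txt. A tiny runnable check of that OOV fallback with a made-up 4-symbol table (real tables come from <lang_dir>/tokens.txt and the syllables from text_to_pinyin):

    import k2

    token_table = k2.SymbolTable.from_str("<blk> 0\nni3 1\nhao3 2\n<unk> 3")
    tokens = ["ni3", "hao3", "zai4"]  # pretend text_to_pinyin returned these
    ids = [token_table[t] if t in token_table else token_table["<unk>"] for t in tokens]
    print(ids)  # [1, 2, 3] -- the unseen syllable maps to <unk>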
def run(rank, world_size, args):
"""
Args:
@ -1130,14 +1138,10 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
- lexicon = Lexicon(params.lang_dir)
- graph_compiler = CharCtcTrainingGraphCompiler(
- lexicon=lexicon,
- device=device,
- )
+ token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")
- params.blank_id = lexicon.token_table["<blk>"]
- params.vocab_size = max(lexicon.tokens) + 1
+ params.blank_id = token_table["<blk>"]
+ params.vocab_size = num_tokens(token_table) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
@ -1216,17 +1220,10 @@ def run(rank, world_size, args):
return True
- def encode_text(c: Cut):
- # Text normalize for each sample
- text = c.supervisions[0].text
- text = "/".join(
- text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
- )
- c.supervisions[0].text = text
- return c
+ _encode_text = partial(encode_text, token_table=token_table, params=params)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
- train_cuts = train_cuts.map(encode_text)
+ train_cuts = train_cuts.map(_encode_text)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
@ -1240,7 +1237,7 @@ def run(rank, world_size, args):
)
valid_cuts = wenetspeech.valid_cuts()
- valid_cuts = valid_cuts.map(encode_text)
+ valid_cuts = valid_cuts.map(_encode_text)
valid_dl = wenetspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics and params.scan_for_oom_batches:
@ -1248,7 +1245,6 @@ def run(rank, world_size, args):
model=model,
train_dl=train_dl,
optimizer=optimizer,
- graph_compiler=graph_compiler,
params=params,
)
@ -1273,7 +1269,6 @@ def run(rank, world_size, args):
model_avg=model_avg,
optimizer=optimizer,
scheduler=scheduler,
- graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@ -1307,7 +1302,6 @@ def run(rank, world_size, args):
def display_and_save_batch(
batch: dict,
params: AttributeDict,
- graph_compiler: CharCtcTrainingGraphCompiler,
) -> None:
"""Display the batch statistics and save the batch into disk.
@ -1317,8 +1311,6 @@ def display_and_save_batch(
for the content in it.
params:
Parameters for training. See :func:`get_params`.
- graph_compiler:
- The compiler to encode texts to ids.
"""
from lhotse.utils import uuid4
@ -1332,8 +1324,8 @@ def display_and_save_batch(
logging.info(f"features shape: {features.shape}")
texts = supervisions["text"]
- y = graph_compiler.texts_to_ids(texts)
- num_tokens = sum(len(i) for i in y)
+ tokens = [c.supervisions[0].tokens for c in supervisions["cut"]]
+ num_tokens = sum(len(i) for i in tokens)
logging.info(f"num tokens: {num_tokens}")
@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
- graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom(
loss, _ = compute_loss(
params=params,
model=model,
- graph_compiler=graph_compiler,
batch=batch,
is_training=True,
)
@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom(
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
- display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ display_and_save_batch(batch, params=params)
raise
logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@ -1385,6 +1375,7 @@ def main():
args = parser.parse_args()
args.lang_dir = Path(args.lang_dir)
args.exp_dir = Path(args.exp_dir)
+ args.return_cuts = True
world_size = args.world_size
assert world_size >= 1