[KWS] Remove graph compiler (#1905)

Author: Wei Kang, 2025-04-02 22:10:06 +08:00 (committed by GitHub)
parent db9fb8ad31
commit 86bd16d496
4 changed files with 43 additions and 68 deletions

View File

@@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 2: Finetune the model"
+  log "Stage 3: Finetune the model"

   # The following configuration of lr schedule should work well
   # You may also tune the following parameters to adjust learning rate schedule
@@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 1: Decode the finetuned model."
+  log "Stage 4: Decode the finetuned model."
   export CUDA_VISIBLE_DEVICES="0"
   for t in small large; do
     python ./zipformer/decode.py \
@@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 2: Export the finetuned model."
+  log "Stage 5: Export the finetuned model."

   python ./zipformer/export.py \
     --epoch 10 \

View File

@@ -35,7 +35,6 @@ from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params

 from icefall import ContextGraph
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,

View File

@@ -90,6 +90,7 @@ from train import (
     add_training_arguments,
     compute_validation_loss,
     display_and_save_batch,
+    encode_text,
     get_adjusted_batch_count,
     get_model,
     get_params,
@@ -100,7 +101,6 @@ from train import (
 )

 from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -110,11 +110,11 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
+    num_tokens,
     setup_logger,
     str2bool,
     text_to_pinyin,
@@ -254,7 +254,6 @@ def load_model_params(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -289,7 +288,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    y = graph_compiler.texts_to_ids(texts, sep="/")
+    y = [c.supervisions[0].tokens for c in supervisions["cut"]]
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
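Note: with this change compute_loss no longer compiles texts to ids per batch; it expects every cut to already carry integer token ids (attached by encode_text, which finetune.py now imports from train.py) and reads them off the cuts returned alongside the batch. A minimal sketch of that step, assuming the dataloader was built with return_cuts=True so that batch["supervisions"]["cut"] is populated; the helper name below is illustrative, not part of this PR:

import k2


def ragged_token_ids(batch: dict) -> k2.RaggedTensor:
    # Each cut went through encode_text() beforehand, so its first supervision
    # carries a plain Python list of integer token ids in `tokens`.
    token_ids = [c.supervisions[0].tokens for c in batch["supervisions"]["cut"]]
    # k2.RaggedTensor accepts a list of variable-length integer lists directly.
    return k2.RaggedTensor(token_ids)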
@@ -347,7 +346,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -418,7 +416,6 @@ def train_one_epoch(
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
             )
@@ -436,7 +433,7 @@ def train_one_epoch(
             optimizer.zero_grad()
         except:  # noqa
             save_bad_model()
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params)
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -523,7 +520,6 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -576,14 +572,10 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
@@ -666,17 +658,10 @@ def run(rank, world_size, args):
     else:
         train_cuts = wenetspeech.nihaowenwen_train_cuts()

-    def encode_text(c: Cut):
-        # Text normalize for each sample
-        text = c.supervisions[0].text
-        text = "/".join(
-            text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
-        )
-        c.supervisions[0].text = text
-        return c
+    _encode_text = partial(encode_text, token_table=token_table, params=params)

     train_cuts = train_cuts.filter(remove_short_utt)
-    train_cuts = train_cuts.map(encode_text)
+    train_cuts = train_cuts.map(_encode_text)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -691,7 +676,7 @@ def run(rank, world_size, args):

     valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
     valid_cuts = valid_cuts.filter(remove_short_utt)
-    valid_cuts = valid_cuts.map(encode_text)
+    valid_cuts = valid_cuts.map(_encode_text)
     valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -699,7 +684,6 @@ def run(rank, world_size, args):
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            graph_compiler=graph_compiler,
             params=params,
         )

@@ -724,7 +708,6 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -760,6 +743,8 @@ def main():
     WenetSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.return_cuts = True

     world_size = args.world_size
     assert world_size >= 1

View File

@@ -53,6 +53,7 @@ import argparse
 import copy
 import logging
 import warnings
+from functools import partial
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2

 from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
+    num_tokens,
     setup_logger,
     str2bool,
     text_to_pinyin,
@@ -776,7 +776,6 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -811,7 +810,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    y = graph_compiler.texts_to_ids(texts, sep="/")
+    y = [c.supervisions[0].tokens for c in supervisions["cut"]]
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
@@ -859,7 +858,6 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -872,7 +870,6 @@ def compute_validation_loss(
         loss, loss_info = compute_loss(
             params=params,
             model=model,
-            graph_compiler=graph_compiler,
             batch=batch,
             is_training=False,
         )
@@ -895,7 +892,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -971,7 +967,6 @@ def train_one_epoch(
             loss, loss_info = compute_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
                 batch=batch,
                 is_training=True,
             )
@@ -988,7 +983,7 @@ def train_one_epoch(
             optimizer.zero_grad()
         except:  # noqa
             save_bad_model()
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params)
             raise

         if params.print_diagnostics and batch_idx == 5:
@@ -1077,7 +1072,6 @@ def train_one_epoch(
             valid_info = compute_validation_loss(
                 params=params,
                 model=model,
-                graph_compiler=graph_compiler,
                 valid_dl=valid_dl,
                 world_size=world_size,
             )
@@ -1098,6 +1092,20 @@ def train_one_epoch(
             params.best_train_loss = params.train_loss


+def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict):
+    text = c.supervisions[0].text
+    tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
+    ids = []
+    for t in tokens:
+        if t in token_table:
+            ids.append(token_table[t])
+        else:
+            logging.warning(f"Text : {text} has OOV token : {t} , encode to <unk>")
+            ids.append(token_table["<unk>"])
+    c.supervisions[0].tokens = ids
+    return c


 def run(rank, world_size, args):
     """
     Args:
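Note: encode_text is now a module-level function, so train.py and finetune.py can each bind the token table and params once with functools.partial and map the bound function over their CutSets before the dataloaders are built. A condensed sketch of that wiring, assuming params, train_cuts and params.lang_dir are already set up as in run():

from functools import partial

import k2

token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

# Bind the fixed arguments once; CutSet.map() then calls _encode_text(cut) on
# every cut, attaching the integer token ids to cut.supervisions[0].tokens.
_encode_text = partial(encode_text, token_table=token_table, params=params)
train_cuts = train_cuts.map(_encode_text)

With lazy manifests the mapping is applied as cuts are drawn by the sampler, so the pinyin conversion and table lookups do not require a separate preprocessing pass.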
@@ -1130,14 +1138,10 @@ def run(rank, world_size, args):
         device = torch.device("cuda", rank)
     logging.info(f"Device: {device}")

-    lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

     if not params.use_transducer:
         params.ctc_loss_scale = 1.0
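Note: the vocabulary setup now needs only tokens.txt, which is assumed to be in the usual k2 symbol-table text format (one "symbol id" pair per line), rather than a full Lexicon. A hedged sketch with made-up example symbols:

import k2

from icefall.utils import num_tokens

# Example tokens.txt contents (illustrative pinyin-style symbols):
#   <blk> 0
#   <unk> 1
#   ni3 2
#   hao3 3
token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

params.blank_id = token_table["<blk>"]  # 0 in the example above
# num_tokens() is assumed not to count id 0, hence the +1 used in this diff.
params.vocab_size = num_tokens(token_table) + 1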
@@ -1216,17 +1220,10 @@ def run(rank, world_size, args):
         return True

-    def encode_text(c: Cut):
-        # Text normalize for each sample
-        text = c.supervisions[0].text
-        text = "/".join(
-            text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
-        )
-        c.supervisions[0].text = text
-        return c
+    _encode_text = partial(encode_text, token_table=token_table, params=params)

     train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_cuts = train_cuts.map(encode_text)
+    train_cuts = train_cuts.map(_encode_text)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1240,7 +1237,7 @@ def run(rank, world_size, args):
     )

     valid_cuts = wenetspeech.valid_cuts()
-    valid_cuts = valid_cuts.map(encode_text)
+    valid_cuts = valid_cuts.map(_encode_text)
     valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

     if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -1248,7 +1245,6 @@ def run(rank, world_size, args):
             model=model,
             train_dl=train_dl,
             optimizer=optimizer,
-            graph_compiler=graph_compiler,
             params=params,
         )

@@ -1273,7 +1269,6 @@ def run(rank, world_size, args):
             model_avg=model_avg,
             optimizer=optimizer,
             scheduler=scheduler,
-            graph_compiler=graph_compiler,
             train_dl=train_dl,
             valid_dl=valid_dl,
             scaler=scaler,
@@ -1307,7 +1302,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    graph_compiler: CharCtcTrainingGraphCompiler,
 ) -> None:
     """Display the batch statistics and save the batch into disk.

@@ -1317,8 +1311,6 @@ def display_and_save_batch(
       for the content in it.
     params:
       Parameters for training. See :func:`get_params`.
-    graph_compiler:
-      The compiler to encode texts to ids.
     """
     from lhotse.utils import uuid4

@@ -1332,8 +1324,8 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")

     texts = supervisions["text"]
-    y = graph_compiler.texts_to_ids(texts)
-    num_tokens = sum(len(i) for i in y)
+    tokens = [c.supervisions[0].tokens for c in supervisions["cut"]]
+    num_tokens = sum(len(i) for i in tokens)
     logging.info(f"num tokens: {num_tokens}")

@@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
@@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom(
         loss, _ = compute_loss(
             params=params,
             model=model,
-            graph_compiler=graph_compiler,
             batch=batch,
             is_training=True,
         )
@@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom(
             f"Failing criterion: {criterion} "
             f"(={crit_values[criterion]}) ..."
         )
-        display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+        display_and_save_batch(batch, params=params)
         raise
     logging.info(
         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -1385,6 +1375,7 @@ def main():
     args = parser.parse_args()
     args.lang_dir = Path(args.lang_dir)
     args.exp_dir = Path(args.exp_dir)
+    args.return_cuts = True

     world_size = args.world_size
     assert world_size >= 1
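Note: both scripts now force args.return_cuts = True in main(). The assumption behind this (inferred from how the rewritten compute_loss and display_and_save_batch read tokens) is that the lhotse K2SpeechRecognitionDataset built by the recipe's asr datamodule only puts the originating cuts into batch["supervisions"]["cut"] when return_cuts is enabled. An illustrative sketch with a placeholder manifest path:

from lhotse import load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from torch.utils.data import DataLoader

cuts = load_manifest_lazy("data/fbank/cuts_train.jsonl.gz")  # placeholder path
dataset = K2SpeechRecognitionDataset(return_cuts=True)  # keep cuts in the batch
sampler = DynamicBucketingSampler(cuts, max_duration=200.0, shuffle=True)
dl = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)

for batch in dl:
    # With return_cuts=True the cuts (and their precomputed .tokens, if
    # encode_text was mapped over the CutSet) ride along with each batch.
    cut_list = batch["supervisions"]["cut"]
    break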