Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-08 09:32:20 +00:00
[KWS]Remove graph compiler (#1905)
This commit is contained in: parent db9fb8ad31, commit 86bd16d496
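Note: the substance of this change is that text is no longer turned into token ids inside compute_loss by a CharCtcTrainingGraphCompiler; instead, each cut gets its token ids attached once, up front, by a module-level encode_text helper bound with functools.partial, and compute_loss simply reads them back. The sketch below condenses that flow and is illustrative only (prepare_cuts is a hypothetical wrapper; it assumes k2, lhotse, and the icefall helpers text_to_pinyin/num_tokens used in this diff are importable):

    from functools import partial

    import k2
    from lhotse import CutSet
    from icefall.utils import num_tokens, text_to_pinyin


    def encode_text(c, token_table, params):
        # Convert the supervision text to pinyin tokens, then to integer ids,
        # falling back to <unk> for out-of-vocabulary tokens (as in the new train.py).
        text = c.supervisions[0].text
        tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
        c.supervisions[0].tokens = [
            token_table[t] if t in token_table else token_table["<unk>"] for t in tokens
        ]
        return c


    def prepare_cuts(cuts: CutSet, lang_dir, params) -> CutSet:
        # Hypothetical wrapper: run once while building the dataloaders, so that
        # compute_loss can read c.supervisions[0].tokens instead of calling
        # graph_compiler.texts_to_ids().
        token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        params.blank_id = token_table["<blk>"]
        params.vocab_size = num_tokens(token_table) + 1
        return cuts.map(partial(encode_text, token_table=token_table, params=params))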
@@ -108,7 +108,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
 fi

 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
-  log "Stage 2: Finetune the model"
+  log "Stage 3: Finetune the model"

-  # The following configuration of lr schedule should work well
+  # You may also tune the following parameters to adjust learning rate schedule
@@ -143,7 +143,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
 fi

 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
-  log "Stage 1: Decode the finetuned model."
+  log "Stage 4: Decode the finetuned model."
   export CUDA_VISIBLE_DEVICES="0"
   for t in small large; do
     python ./zipformer/decode.py \
@@ -170,7 +170,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi

 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 2: Export the finetuned model."
+  log "Stage 5: Export the finetuned model."

   python ./zipformer/export.py \
     --epoch 10 \
@@ -35,7 +35,6 @@ from lhotse.cut import Cut
 from train import add_model_arguments, get_model, get_params

 from icefall import ContextGraph
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import (
     average_checkpoints,
     average_checkpoints_with_averaged_model,
@@ -90,6 +90,7 @@ from train import (
     add_training_arguments,
     compute_validation_loss,
     display_and_save_batch,
+    encode_text,
     get_adjusted_batch_count,
     get_model,
     get_params,
@@ -100,7 +101,6 @@ from train import (
 )

 from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -110,11 +110,11 @@ from icefall.checkpoint import (
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
+    num_tokens,
     setup_logger,
     str2bool,
     text_to_pinyin,
@@ -254,7 +254,6 @@ def load_model_params(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -289,7 +288,7 @@ def compute_loss(
     warm_step = params.warm_step

-    texts = batch["supervisions"]["text"]
-    y = graph_compiler.texts_to_ids(texts, sep="/")
+    y = [c.supervisions[0].tokens for c in supervisions["cut"]]
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
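Note: with the graph compiler gone, compute_loss builds the ragged label tensor directly from the ids that encode_text stored on each cut (batch["supervisions"]["cut"] is only populated because the data module is created with return_cuts=True; see the main() hunk further down). A tiny illustrative sketch of the k2.RaggedTensor construction, using made-up ids:

    import k2

    # Stand-ins for c.supervisions[0].tokens of two cuts.
    token_ids = [[12, 7, 99], [3, 41, 8, 8]]
    y = k2.RaggedTensor(token_ids)  # same call as in the new compute_loss
    assert y.dim0 == 2  # one row per utterance; rows may have different lengths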
@@ -347,7 +346,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -418,7 +416,6 @@ def train_one_epoch(
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
-                    graph_compiler=graph_compiler,
                    batch=batch,
                    is_training=True,
                )
@@ -436,7 +433,7 @@ def train_one_epoch(
            optimizer.zero_grad()
        except:  # noqa
            save_bad_model()
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params)
            raise

        if params.print_diagnostics and batch_idx == 5:
@@ -523,7 +520,6 @@ def train_one_epoch(
            valid_info = compute_validation_loss(
                params=params,
                model=model,
-                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
            )
@@ -576,14 +572,10 @@ def run(rank, world_size, args):
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

-    lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

    if not params.use_transducer:
        params.ctc_loss_scale = 1.0
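Note: the lang dir's lexicon is no longer loaded at training time; only tokens.txt is read. A hedged sketch of how blank_id and vocab_size are now derived (the path is an example; num_tokens comes from icefall.utils, as imported in this diff):

    from pathlib import Path

    import k2
    from icefall.utils import num_tokens

    lang_dir = Path("data/lang_partial_tone")  # example path, use your own lang dir
    token_table = k2.SymbolTable.from_file(lang_dir / "tokens.txt")

    blank_id = token_table["<blk>"]
    # As in the diff: vocab_size is num_tokens(token_table) + 1 (the +1 leaves room for the blank id).
    vocab_size = num_tokens(token_table) + 1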
@@ -666,17 +658,10 @@ def run(rank, world_size, args):
    else:
        train_cuts = wenetspeech.nihaowenwen_train_cuts()

-    def encode_text(c: Cut):
-        # Text normalize for each sample
-        text = c.supervisions[0].text
-        text = "/".join(
-            text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
-        )
-        c.supervisions[0].text = text
-        return c
+    _encode_text = partial(encode_text, token_table=token_table, params=params)

    train_cuts = train_cuts.filter(remove_short_utt)
-    train_cuts = train_cuts.map(encode_text)
+    train_cuts = train_cuts.map(_encode_text)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
@@ -691,7 +676,7 @@ def run(rank, world_size, args):

    valid_cuts = wenetspeech.nihaowenwen_dev_cuts()
    valid_cuts = valid_cuts.filter(remove_short_utt)
-    valid_cuts = valid_cuts.map(encode_text)
+    valid_cuts = valid_cuts.map(_encode_text)
    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -699,7 +684,6 @@ def run(rank, world_size, args):
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
-            graph_compiler=graph_compiler,
            params=params,
        )

@@ -724,7 +708,6 @@ def run(rank, world_size, args):
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
-            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
@@ -760,6 +743,8 @@ def main():
    WenetSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.return_cuts = True

    world_size = args.world_size
    assert world_size >= 1
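Note: args.return_cuts = True is what makes the pre-computed ids visible at loss time: lhotse's K2SpeechRecognitionDataset attaches the originating cut to each supervision only when return_cuts is enabled, and the new compute_loss reads supervisions["cut"]. A small hedged sketch of that access pattern (batch_token_ids is a hypothetical helper, not part of the recipe):

    from typing import Dict, List


    def batch_token_ids(batch: Dict) -> List[List[int]]:
        # Requires the dataloader to be built with return_cuts=True so that each
        # supervision carries its lhotse cut, and each cut to have been mapped
        # through encode_text() beforehand.
        cuts = batch["supervisions"]["cut"]
        return [c.supervisions[0].tokens for c in cuts]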
@@ -53,6 +53,7 @@ import argparse
 import copy
 import logging
 import warnings
+from functools import partial
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
@@ -79,7 +80,6 @@ from torch.utils.tensorboard import SummaryWriter
 from zipformer import Zipformer2

 from icefall import diagnostics
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
@@ -90,11 +90,11 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     get_parameter_groups_with_lrs,
+    num_tokens,
     setup_logger,
     str2bool,
     text_to_pinyin,
@@ -776,7 +776,6 @@ def save_checkpoint(
 def compute_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     batch: dict,
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
@@ -811,7 +810,7 @@ def compute_loss(
     warm_step = params.warm_step

-    texts = batch["supervisions"]["text"]
-    y = graph_compiler.texts_to_ids(texts, sep="/")
+    y = [c.supervisions[0].tokens for c in supervisions["cut"]]
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
@@ -859,7 +858,6 @@ def compute_loss(
 def compute_validation_loss(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
 ) -> MetricsTracker:
@@ -872,7 +870,6 @@ def compute_validation_loss(
        loss, loss_info = compute_loss(
            params=params,
            model=model,
-            graph_compiler=graph_compiler,
            batch=batch,
            is_training=False,
        )
@@ -895,7 +892,6 @@ def train_one_epoch(
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
-    graph_compiler: CharCtcTrainingGraphCompiler,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
     scaler: GradScaler,
@@ -971,7 +967,6 @@ def train_one_epoch(
                loss, loss_info = compute_loss(
                    params=params,
                    model=model,
-                    graph_compiler=graph_compiler,
                    batch=batch,
                    is_training=True,
                )
@@ -988,7 +983,7 @@ def train_one_epoch(
            optimizer.zero_grad()
        except:  # noqa
            save_bad_model()
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params)
            raise

        if params.print_diagnostics and batch_idx == 5:
@@ -1077,7 +1072,6 @@ def train_one_epoch(
            valid_info = compute_validation_loss(
                params=params,
                model=model,
-                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
            )
@@ -1098,6 +1092,20 @@ def train_one_epoch(
        params.best_train_loss = params.train_loss


+def encode_text(c: Cut, token_table: k2.SymbolTable, params: AttributeDict):
+    text = c.supervisions[0].text
+    tokens = text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
+    ids = []
+    for t in tokens:
+        if t in token_table:
+            ids.append(token_table[t])
+        else:
+            logging.warning(f"Text : {text} has OOV token : {t} , encode to <unk>")
+            ids.append(token_table["<unk>"])
+    c.supervisions[0].tokens = ids
+    return c
+
+
 def run(rank, world_size, args):
     """
     Args:
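Note: a short usage sketch for the new module-level encode_text above, mirroring how run() binds it below. The symbol table here is a toy one built with k2.SymbolTable.from_str, whereas the recipe loads lang_dir/tokens.txt, and the parameter values are made up:

    from functools import partial

    import k2
    from icefall.utils import AttributeDict
    from train import encode_text  # finetune.py imports it the same way in this diff

    # Toy table; the recipe reads it with k2.SymbolTable.from_file(lang_dir / "tokens.txt").
    token_table = k2.SymbolTable.from_str("<blk> 0\nni3 1\nhao3 2\n<unk> 3")
    params = AttributeDict({"pinyin_type": "full_with_tone", "pinyin_errors": "default"})

    _encode_text = partial(encode_text, token_table=token_table, params=params)
    # cuts = cuts.map(_encode_text)  # as in run(): each cut gains supervisions[0].tokens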
@@ -1130,14 +1138,10 @@ def run(rank, world_size, args):
        device = torch.device("cuda", rank)
    logging.info(f"Device: {device}")

-    lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
+    token_table = k2.SymbolTable.from_file(params.lang_dir / "tokens.txt")

-    params.blank_id = lexicon.token_table["<blk>"]
-    params.vocab_size = max(lexicon.tokens) + 1
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1

    if not params.use_transducer:
        params.ctc_loss_scale = 1.0
@@ -1216,17 +1220,10 @@ def run(rank, world_size, args):

        return True

-    def encode_text(c: Cut):
-        # Text normalize for each sample
-        text = c.supervisions[0].text
-        text = "/".join(
-            text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors)
-        )
-        c.supervisions[0].text = text
-        return c
+    _encode_text = partial(encode_text, token_table=token_table, params=params)

    train_cuts = train_cuts.filter(remove_short_and_long_utt)
-    train_cuts = train_cuts.map(encode_text)
+    train_cuts = train_cuts.map(_encode_text)

    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
        # We only load the sampler's state dict when it loads a checkpoint
@@ -1240,7 +1237,7 @@ def run(rank, world_size, args):
    )

    valid_cuts = wenetspeech.valid_cuts()
-    valid_cuts = valid_cuts.map(encode_text)
+    valid_cuts = valid_cuts.map(_encode_text)
    valid_dl = wenetspeech.valid_dataloaders(valid_cuts)

    if not params.print_diagnostics and params.scan_for_oom_batches:
@@ -1248,7 +1245,6 @@ def run(rank, world_size, args):
            model=model,
            train_dl=train_dl,
            optimizer=optimizer,
-            graph_compiler=graph_compiler,
            params=params,
        )

@@ -1273,7 +1269,6 @@ def run(rank, world_size, args):
            model_avg=model_avg,
            optimizer=optimizer,
            scheduler=scheduler,
-            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
@@ -1307,7 +1302,6 @@ def run(rank, world_size, args):
 def display_and_save_batch(
     batch: dict,
     params: AttributeDict,
-    graph_compiler: CharCtcTrainingGraphCompiler,
 ) -> None:
     """Display the batch statistics and save the batch into disk.

@@ -1317,8 +1311,6 @@ def display_and_save_batch(
        for the content in it.
      params:
        Parameters for training. See :func:`get_params`.
-      graph_compiler:
-        The compiler to encode texts to ids.
    """
    from lhotse.utils import uuid4

@@ -1332,8 +1324,8 @@ def display_and_save_batch(
    logging.info(f"features shape: {features.shape}")

-    texts = supervisions["text"]
-    y = graph_compiler.texts_to_ids(texts)
-    num_tokens = sum(len(i) for i in y)
+    tokens = [c.supervisions[0].tokens for c in supervisions["cut"]]
+    num_tokens = sum(len(i) for i in tokens)
    logging.info(f"num tokens: {num_tokens}")


@@ -1341,7 +1333,6 @@ def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
-    graph_compiler: CharCtcTrainingGraphCompiler,
    params: AttributeDict,
 ):
    from lhotse.dataset import find_pessimistic_batches
@@ -1357,7 +1348,6 @@ def scan_pessimistic_batches_for_oom(
                loss, _ = compute_loss(
                    params=params,
                    model=model,
-                    graph_compiler=graph_compiler,
                    batch=batch,
                    is_training=True,
                )
@@ -1372,7 +1362,7 @@ def scan_pessimistic_batches_for_oom(
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
-            display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+            display_and_save_batch(batch, params=params)
            raise
    logging.info(
        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@@ -1385,6 +1375,7 @@ def main():
    args = parser.parse_args()
    args.lang_dir = Path(args.lang_dir)
    args.exp_dir = Path(args.exp_dir)
+    args.return_cuts = True

    world_size = args.world_size
    assert world_size >= 1