Merge branch 'k2-fsa:master' into dev/k2ssl

This commit is contained in:
Yifan Yang 2025-05-11 00:33:35 +08:00 committed by GitHub
commit 260d37b65a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 74 additions and 166 deletions

View File

@ -42,8 +42,8 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel
# For multi-hans fine-tuned whisper model # For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt # huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
# huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct # huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct huggingface-cli download --local-dir models/qwen Qwen/Qwen2-1.5B-Instruct
# First, we only train the projector and freeze other modules. # First, we only train the projector and freeze other modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
@ -55,9 +55,10 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \ --use-flash-attn True \
--use-lora False --unfreeze-llm False --use-lora False \
--unfreeze-llm False
# Then we jointly train the projector and LLM LoRA modules. # Then, we jointly train the projector and LLM LoRA modules.
torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--max-duration 200 \ --max-duration 200 \
--exp-dir ./whisper_llm_zh/exp_test \ --exp-dir ./whisper_llm_zh/exp_test \
@ -67,7 +68,8 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
--deepspeed \ --deepspeed \
--deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \ --deepspeed_config ./whisper_llm_zh/ds_config_zero1.json \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --unfreeze-llm True --use-lora True \
--unfreeze-llm True \
--pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt --pretrained-model-path ./whisper_llm_zh/exp_test/epoch-3.pt
``` ```
@ -81,7 +83,7 @@ huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_aishel
# For multi-hans fine-tuned whisper model # For multi-hans fine-tuned whisper model
# huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt # huggingface-cli download --local-dir models/whisper yuekai/icefall_asr_multi-hans-zh_whisper v1.1/whisper-large-v2-multi-hans-zh-epoch-3-avg-10.pt
huggingface-clie download --local-dir models/qwen Qwen/Qwen2-7B-Instruct huggingface-cli download --local-dir models/qwen Qwen/Qwen2-7B-Instruct
mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B mkdir -p whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B
ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt ln -s models/checkpoint/epoch-10-avg-5.pt whisper_llm_zh/exp_aishell_whisper_qwen2_1.5B/epoch-999.pt
@ -94,5 +96,6 @@ python3 ./whisper_llm_zh/decode.py \
--epoch 999 --avg 1 \ --epoch 999 --avg 1 \
--manifest-dir data/fbank \ --manifest-dir data/fbank \
--use-flash-attn True \ --use-flash-attn True \
--use-lora True --dataset aishell --use-lora True \
--dataset aishell
``` ```

13
egs/speech_llm/ASR_LLM/prepare.sh Normal file → Executable file
View File

@ -7,6 +7,9 @@ set -eou pipefail
stage=0 stage=0
stop_stage=0 stop_stage=0
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data". # All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it. # You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data mkdir -p data
@ -23,7 +26,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# pip install huggingface_hub['cli'] # pip install huggingface_hub['cli']
# for aishell 1 # for aishell 1
huggingface-cli download --local-dir data yuekai/aishell_whisper_fbank_lhotse huggingface-cli download --repo-type dataset --local-dir data yuekai/aishell_whisper_fbank_lhotse
fi fi
@ -31,9 +34,9 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface" log "stage 1: Download whisper-large-v2 multi-hans-zh fbank feature from huggingface"
# for multi-hans-zh # for multi-hans-zh
huggingface-cli download --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/wenetspeech_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/multi_hans_zh_whisper_fbank_lhotse
huggingface-cli download --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse huggingface-cli download --repo-type dataset --local-dir data/fbank yuekai/alimeeting_aishell4_training_whisper_fbank_lhotse
fi fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@ -41,6 +44,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# for speechio test sets # for speechio test sets
mkdir data_speechio mkdir data_speechio
huggingface-cli download --local-dir data_speechio yuekai/icefall_asr_speechio huggingface-cli download --repo-type model --local-dir data_speechio yuekai/icefall_asr_speechio
mv data_speechio/fbank/* data/fbank mv data_speechio/fbank/* data/fbank
fi fi

View File

@ -0,0 +1 @@
../../../icefall/shared

View File

@ -66,7 +66,7 @@ from train import DEFAULT_SPEECH_TOKEN
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint from icefall.checkpoint import load_checkpoint
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import (
AttributeDict, AttributeDict,
@ -357,43 +357,6 @@ def decode_dataset(
Returns: Returns:
Return a dict, whose key may be "beam-search". Return a dict, whose key may be "beam-search".
""" """
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
results = [] results = []
num_cuts = 0 num_cuts = 0
@ -406,6 +369,7 @@ def decode_dataset(
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
@ -418,12 +382,8 @@ def decode_dataset(
for lm_scale, hyps in hyps_dict.items(): for lm_scale, hyps in hyps_dict.items():
this_batch = [] this_batch = []
assert len(hyps) == len(texts) assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): for cut_id, hyp_text, ref_text in zip(cut_ids, hyps, texts):
ref_text = normalize_text_alimeeting(ref_text) this_batch.append((cut_id, ref_text, hyp_text))
ref_words = ref_text.split()
print(f"ref: {ref_text}")
print(f"hyp: {''.join(hyp_words)}")
this_batch.append((cut_id, ref_words, hyp_words))
results[lm_scale].extend(this_batch) results[lm_scale].extend(this_batch)
@ -439,40 +399,38 @@ def decode_dataset(
def save_results( def save_results(
params: AttributeDict, params: AttributeDict,
test_set_name: str, test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
): ):
enable_log = True
test_set_wers = dict() test_set_wers = dict()
for key, results in results_dict.items(): for key, results in results_dict.items():
recog_path = ( recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
) )
results = sorted(results) results = sorted(results)
store_transcripts(filename=recog_path, texts=results) store_transcripts(filename=recog_path, texts=results, char_level=True)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}") logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned # The following prints out CERs, per-word error statistics and aligned
# ref/hyp pairs. # ref/hyp pairs.
errs_filename = ( errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
) )
# we compute CER for aishell dataset.
results_char = []
for res in results:
results_char.append((res[0], list("".join(res[1])), list("".join(res[2]))))
with open(errs_filename, "w") as f: with open(errs_filename, "w") as f:
wer = write_error_stats( wer = write_error_stats(
f, f"{test_set_name}-{key}", results_char, enable_log=enable_log f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
) )
test_set_wers[key] = wer test_set_wers[key] = wer
if enable_log:
logging.info("Wrote detailed error stats to {}".format(errs_filename)) logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"cer-summary-{test_set_name}-{params.suffix}.txt" errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tCER", file=f) print("settings\tCER", file=f)
for key, val in test_set_wers: for key, val in test_set_wers:
@ -495,9 +453,13 @@ def main():
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
params.res_dir = params.exp_dir / f"{params.method}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger( setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" params.res_dir
/ f"log-decode-{params.method}-beam{params.beam_size}-{params.suffix}"
) )
logging.info("Decoding started") logging.info("Decoding started")
@ -574,23 +536,20 @@ def main():
if params.avg > 1: if params.avg > 1:
start = params.epoch - params.avg + 1 start = params.epoch - params.avg + 1
assert start >= 1, start assert start >= 1, start
checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
assert "model" not in checkpoint
# deepspeed converted checkpoint only contains model state_dict # deepspeed converted checkpoint only contains model state_dict
filenames = [ filenames = [
f"{params.exp_dir}/epoch-{epoch}.pt" f"{params.exp_dir}/epoch-{epoch}/pytorch_model.bin"
for epoch in range(start, params.epoch + 1) for epoch in range(start, params.epoch + 1)
] ]
avg_checkpoint = average_checkpoints(filenames) avg_checkpoint = average_checkpoints(filenames)
model.load_state_dict(avg_checkpoint, strict=False) model.load_state_dict(avg_checkpoint, strict=False)
filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" # filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
torch.save(avg_checkpoint, filename) # torch.save(avg_checkpoint, filename)
else: else:
checkpoint = torch.load( checkpoint = torch.load(
f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" f"{params.exp_dir}/epoch-{params.epoch}/pytorch_model.bin",
map_location="cpu",
) )
model.load_state_dict(checkpoint, strict=False) model.load_state_dict(checkpoint, strict=False)
@ -643,8 +602,7 @@ def main():
logging.info("Done!") logging.info("Done!")
if __name__ == "__main__":
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
if __name__ == "__main__":
main() main()

View File

@ -1,6 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang) # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang # 2024 Yuekai Zhang
# 2025 Yifan Yang
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -42,47 +43,32 @@ torchrun --nproc_per_node 8 ./whisper_llm_zh/train.py \
""" """
import argparse import argparse
import copy
import logging import logging
import os import os
import random
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from typing import Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import deepspeed import deepspeed
import k2
import torch import torch
import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import transformers import transformers
import whisper import whisper
from asr_datamodule import AsrDataModule from asr_datamodule import AsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector from model import IGNORE_TOKEN_ID, SPEECH_LLM, EncoderProjector
from multi_dataset import MultiDataset from multi_dataset import MultiDataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training from peft import LoraConfig, get_peft_model
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModelForCausalLM, AutoTokenizer from transformers import AutoModelForCausalLM, AutoTokenizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.dist import get_rank, get_world_size from icefall.dist import get_rank, get_world_size
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import ( from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
AttributeDict,
MetricsTracker,
filter_uneven_sized_batch,
setup_logger,
str2bool,
)
DEFAULT_SPEECH_TOKEN = "<speech>" DEFAULT_SPEECH_TOKEN = "<speech>"
@ -286,13 +272,6 @@ def compute_loss(
Returns: Returns:
Return a tuple of two elements. The first element is the loss tensor. Return a tuple of two elements. The first element is the loss tensor.
""" """
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
# we simply drop the last few shortest samples, so that the retained total frames
# (after padding) would not exceed `allowed_max_frames`:
# `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
# where `max_frames = max_duration * 1000 // frame_shift_ms`.
# We set allowed_excess_duration_ratio=0.1.
def preprocess( def preprocess(
messages, messages,
@ -347,46 +326,6 @@ def compute_loss(
return input_ids, attention_mask, target_ids return input_ids, attention_mask, target_ids
def normalize_text_alimeeting(text: str, normalize: str = "m2met") -> str:
"""
Text normalization similar to M2MeT challenge baseline.
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
"""
if normalize == "none":
return text
elif normalize == "m2met":
import re
text = text.replace(" ", "")
text = text.replace("<sil>", "")
text = text.replace("<%>", "")
text = text.replace("<->", "")
text = text.replace("<$>", "")
text = text.replace("<#>", "")
text = text.replace("<_>", "")
text = text.replace("<space>", "")
text = text.replace("`", "")
text = text.replace("&", "")
text = text.replace(",", "")
if re.search("[a-zA-Z]", text):
text = text.upper()
text = text.replace("", "A")
text = text.replace("", "A")
text = text.replace("", "B")
text = text.replace("", "C")
text = text.replace("", "K")
text = text.replace("", "T")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
text = text.replace("", "")
return text
max_frames = params.max_duration * 1000 // params.frame_shift_ms
allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
batch = filter_uneven_sized_batch(batch, allowed_max_frames)
device = next(model.parameters()).device device = next(model.parameters()).device
feature = batch["inputs"] feature = batch["inputs"]
@ -397,11 +336,10 @@ def compute_loss(
batch_idx_train = params.batch_idx_train batch_idx_train = params.batch_idx_train
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
# remove spaces in texts
texts = [normalize_text_alimeeting(text) for text in texts]
messages = [] messages = []
for i, text in enumerate(texts): for i, text in enumerate(texts):
text = text.replace(" ", "")
message = [ message = [
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"}, {"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
{"role": "assistant", "content": text}, {"role": "assistant", "content": text},
@ -516,14 +454,17 @@ def train_one_epoch(
The rank of the node in DDP training. If no DDP is used, it should The rank of the node in DDP training. If no DDP is used, it should
be set to 0. be set to 0.
""" """
model.encoder_projector.train() model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl): for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: if batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
valid_info = compute_validation_loss( valid_info = compute_validation_loss(
params=params, params=params,
@ -533,6 +474,9 @@ def train_one_epoch(
world_size=world_size, world_size=world_size,
) )
model.train() model.train()
model.encoder.eval()
if not params.unfreeze_llm:
model.llm.eval()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
logging.info( logging.info(
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
@ -648,7 +592,6 @@ def run(rank, world_size, args):
speech_encoder_dim = whisper_model.dims.n_audio_state speech_encoder_dim = whisper_model.dims.n_audio_state
for name, param in speech_encoder.named_parameters(): for name, param in speech_encoder.named_parameters():
param.requires_grad = False param.requires_grad = False
speech_encoder.eval()
tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name) tokenizer = AutoTokenizer.from_pretrained(params.llm_path_or_name)
if params.use_flash_attn: if params.use_flash_attn:
@ -671,7 +614,6 @@ def run(rank, world_size, args):
if not params.unfreeze_llm: if not params.unfreeze_llm:
for name, param in llm.named_parameters(): for name, param in llm.named_parameters():
param.requires_grad = False param.requires_grad = False
llm.eval()
else: else:
if params.use_lora: if params.use_lora:
lora_config = LoraConfig( lora_config = LoraConfig(
@ -728,7 +670,7 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}") logging.info(f"Device: {device}")
model.to(device) model.to(device)
assert params.deepspeed and world_size > 1 assert params.deepspeed
logging.info("Using DeepSpeed") logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize( model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters() args=params, model=model, model_parameters=model.parameters()
@ -764,7 +706,7 @@ def run(rank, world_size, args):
if params.sampler_state_dict_path: if params.sampler_state_dict_path:
sampler_state_dict = torch.load(params.sampler_state_dict_path) sampler_state_dict = torch.load(params.sampler_state_dict_path)
sampler_state_dict["max_duration"] = params.max_duration sampler_state_dict["max_duration"] = params.max_duration
# TODO: load sampler state dict
train_dl = data_module.train_dataloaders( train_dl = data_module.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict train_cuts, sampler_state_dict=sampler_state_dict
) )
@ -806,15 +748,15 @@ def run(rank, world_size, args):
model.save_checkpoint( model.save_checkpoint(
save_dir=params.exp_dir, save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}", tag=f"zero-epoch-{params.cur_epoch}",
client_state={}, client_state={},
exclude_frozen_parameters=True, exclude_frozen_parameters=True,
) )
if rank == 0: if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict( convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}",
tag=f"epoch-{params.cur_epoch}", tag=f"zero-epoch-{params.cur_epoch}",
exclude_frozen_parameters=True, exclude_frozen_parameters=True,
) )
# save sampler state dict into checkpoint # save sampler state dict into checkpoint
@ -824,7 +766,7 @@ def run(rank, world_size, args):
f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt", f"{params.exp_dir}/epoch-{params.cur_epoch}-sampler.pt",
) )
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}") os.system(f"rm -rf {params.exp_dir}/zero-epoch-{params.cur_epoch}")
logging.info("Done!") logging.info("Done!")
@ -865,6 +807,7 @@ def main():
torch.set_num_threads(1) torch.set_num_threads(1)
torch.set_num_interop_threads(1) torch.set_num_interop_threads(1)
warnings.filterwarnings("ignore", category=FutureWarning)
run(rank=rank, world_size=world_size, args=args) run(rank=rank, world_size=world_size, args=args)

View File

@ -422,7 +422,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
unk_id = params.unk_id unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp) y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):

View File

@ -397,7 +397,7 @@ def compute_loss(
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
unk_id = params.unk_id unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp) y = convert_texts_into_ids(texts, sp=sp)
y = k2.RaggedTensor(y).to(device) y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):