mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
add manifest dir option
This commit is contained in:
parent
46605eaef2
commit
fd4ebf3bfe
@ -50,10 +50,13 @@ torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
num_mel_bins: int = 80,
|
||||
perturb_speed: bool = False,
|
||||
whisper_fbank: bool = False,
|
||||
output_dir: str = "data/fbank",
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
output_dir = Path(output_dir)
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
@ -130,6 +133,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="Output directory. Default: data/fbank.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -143,4 +152,5 @@ if __name__ == "__main__":
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
|
@ -379,12 +379,13 @@ fi
|
||||
|
||||
# whisper large-v3 uses 128 mel bins; other models use 80 mel bins
|
||||
whisper_mel_bins=80
|
||||
output_dir=data/fbank_whisper
|
||||
if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
|
||||
log "Stage 30: Compute ${whisper_mel_bins} dim fbank for whisper model fine-tuning"
|
||||
if [ ! -f data/fbank/.aishell.whisper.done ]; then
|
||||
mkdir -p data/fbank
|
||||
./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
|
||||
touch data/fbank/.aishell.whisper.done
|
||||
if [ ! -f $output_dir/.aishell.whisper.done ]; then
|
||||
mkdir -p $output_dir
|
||||
./local/compute_fbank_aishell.py --perturb-speed ${perturb_speed} --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
|
||||
./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true --output-dir $output_dir
|
||||
touch $output_dir/.aishell.whisper.done
|
||||
fi
|
||||
fi
|
||||
|
@ -28,6 +28,7 @@ python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--manifest-dir data/fbank_whisper \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
# Command for decoding using pretrained models (before fine-tuning):
|
||||
@ -36,6 +37,7 @@ python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--epoch -1 --avg 1 \
|
||||
--manifest-dir data/fbank_whisper \
|
||||
--remove-whisper-encoder-input-length-restriction False \
|
||||
--beam-size 10 --max-duration 50
|
||||
|
||||
|
@ -23,6 +23,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir whisper/exp_large_v2 \
|
||||
--model-name large-v2 \
|
||||
--manifest-dir data/fbank_whisper \
|
||||
--deepspeed \
|
||||
--deepspeed_config ./whisper/ds_config_zero1.json
|
||||
|
||||
@ -30,6 +31,7 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--max-duration 200 \
|
||||
--exp-dir whisper/exp_medium \
|
||||
--manifest-dir data/fbank_whisper \
|
||||
--base-lr 1e-5 \
|
||||
--model-name medium
|
||||
"""
|
||||
@ -253,6 +255,7 @@ def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"frame_shift_ms": 10.0,
|
||||
"subsampling_factor": 2,
|
||||
"allowed_excess_duration_ratio": 0.1,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
|
@ -54,9 +54,11 @@ def is_cut_long(c: MonoCut) -> bool:
|
||||
return c.duration > 5
|
||||
|
||||
|
||||
def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
|
||||
def compute_fbank_musan(
|
||||
num_mel_bins: int = 80, whisper_fbank: bool = False, output_dir: str = "data/fbank"
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
output_dir = Path(output_dir)
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
|
||||
dataset_parts = (
|
||||
@ -129,6 +131,12 @@ def get_args():
|
||||
default=False,
|
||||
help="Use WhisperFbank instead of Fbank. Default: False.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
default="data/fbank",
|
||||
help="Output directory. Default: data/fbank.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -138,5 +146,7 @@ if __name__ == "__main__":
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
args = get_args()
|
||||
compute_fbank_musan(
|
||||
num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
output_dir=args.output_dir,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user