mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
add speechio results
This commit is contained in:
parent
b422e7a97f
commit
a00c0c5279
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_aishell2(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -69,7 +78,9 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
dataset_parts,
|
||||
)
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -84,7 +95,7 @@ def compute_fbank_aishell2(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -129,5 +140,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell2(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
nj=30
|
||||
stage=1
|
||||
stop_stage=1
|
||||
stage=0
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import LilcomChunkyWriter, CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_aishell4(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/aishell4")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -87,7 +98,7 @@ def compute_fbank_aishell4(num_mel_bins: int = 80, perturb_speed: bool = False,
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_aishell4(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=20
|
||||
stop_stage=20
|
||||
stage=-1
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
|
||||
|
@ -29,7 +29,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,7 +49,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_alimeeting(
|
||||
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/alimeeting")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -86,7 +97,7 @@ def compute_fbank_alimeeting(num_mel_bins: int = 80, perturb_speed: bool = False
|
||||
supervisions=m["supervisions"],
|
||||
)
|
||||
if "train" in partition and perturb_speed:
|
||||
logging.info(f"Doing speed perturb")
|
||||
logging.info("Doing speed perturb")
|
||||
cut_set = (
|
||||
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
|
||||
)
|
||||
@ -140,5 +151,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_alimeeting(
|
||||
num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
perturb_speed=args.perturb_speed,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=20
|
||||
stop_stage=20
|
||||
stage=-1
|
||||
stop_stage=7
|
||||
perturb_speed=true
|
||||
|
||||
# We assume dl_dir (download dir) contains the following
|
||||
|
@ -17,14 +17,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -32,6 +40,7 @@ from icefall.utils import str2bool
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -52,6 +61,7 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_kespeech_dev_test(args):
|
||||
in_out_dir = Path("data/fbank/kespeech")
|
||||
# number of workers in dataloader
|
||||
@ -70,7 +80,9 @@ def compute_fbank_kespeech_dev_test(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device=device))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device=device)
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
|
@ -25,16 +25,17 @@ from pathlib import Path
|
||||
import torch
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slows things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -129,7 +130,9 @@ def compute_fbank_kespeech_splits(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
logging.info(f"device: {device}")
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -42,6 +49,7 @@ from icefall.utils import get_executor, str2bool
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -62,7 +70,10 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
|
||||
def compute_fbank_magicdata(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/magicdata")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -86,7 +97,9 @@ def compute_fbank_magicdata(num_mel_bins: int = 80, speed_perturb: bool = False,
|
||||
)
|
||||
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda"))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -145,5 +158,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_magicdata(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_primewords(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/primewords")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -67,7 +76,9 @@ def compute_fbank_primewords(num_mel_bins: int = 80, speed_perturb: bool = False
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -128,5 +139,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_primewords(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_stcmds(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/stcmds")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -67,7 +76,9 @@ def compute_fbank_stcmds(num_mel_bins: int = 80, speed_perturb: bool = False, wh
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -126,5 +137,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_stcmds(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -43,7 +50,9 @@ torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
|
||||
def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False):
|
||||
def compute_fbank_thchs30(
|
||||
num_mel_bins: int = 80, speed_perturb: bool = False, whisper_fbank: bool = False
|
||||
):
|
||||
src_dir = Path("data/manifests/thchs30")
|
||||
output_dir = Path("data/fbank")
|
||||
num_jobs = min(15, os.cpu_count())
|
||||
@ -71,7 +80,9 @@ def compute_fbank_thchs30(num_mel_bins: int = 80, speed_perturb: bool = False, w
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -132,5 +143,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_thchs30(
|
||||
num_mel_bins=args.num_mel_bins, speed_perturb=args.speed_perturb, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
speed_perturb=args.speed_perturb,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -5,8 +5,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=121
|
||||
stop_stage=121
|
||||
stage=-1
|
||||
stop_stage=100
|
||||
num_splits=100
|
||||
|
||||
dl_dir=$PWD/download
|
||||
@ -482,5 +482,3 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then
|
||||
python ./local/compile_lg.py --lang-dir $lang_dir
|
||||
done
|
||||
fi
|
||||
|
||||
|
||||
|
@ -52,14 +52,14 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
|
||||
from asr_datamodule import AsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from multi_dataset import MultiDataset
|
||||
from tn.chinese.normalizer import Normalizer
|
||||
from whisper.normalizers import BasicTextNormalizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from zhconv import convert
|
||||
from lhotse.cut import Cut
|
||||
from multi_dataset import MultiDataset
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
|
@ -34,10 +34,10 @@ torchrun --nproc-per-node 8 ./whisper/train.py \
|
||||
--model-name medium
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
@ -52,13 +52,13 @@ import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
from asr_datamodule import AsrDataModule
|
||||
from multi_dataset import MultiDataset
|
||||
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
|
||||
from label_smoothing import LabelSmoothingLoss
|
||||
from lhotse import CutSet, load_manifest
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from multi_dataset import MultiDataset
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
@ -626,7 +626,9 @@ def train_one_epoch(
|
||||
f"{params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}.pt",
|
||||
tag=f"epoch-{params.cur_epoch}-checkpoint-{batch_idx}",
|
||||
)
|
||||
os.system(f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}")
|
||||
os.system(
|
||||
f"rm -rf {params.exp_dir}/epoch-{params.cur_epoch}-checkpoint-{batch_idx}"
|
||||
)
|
||||
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
@ -761,9 +763,7 @@ def run(rank, world_size, args):
|
||||
del model.alignment_heads
|
||||
|
||||
if params.pretrained_model_path:
|
||||
checkpoint = torch.load(
|
||||
params.pretrained_model_path, map_location="cpu"
|
||||
)
|
||||
checkpoint = torch.load(params.pretrained_model_path, map_location="cpu")
|
||||
if "model" not in checkpoint:
|
||||
model.load_state_dict(checkpoint, strict=True)
|
||||
else:
|
||||
|
15
egs/speechio/ASR/README.md
Normal file
15
egs/speechio/ASR/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
|
||||
# Introduction
|
||||
|
||||
This recipe contains the decoding results of several pretrained ASR models on the [SpeechIO](https://github.com/SpeechColab/Leaderboard) test sets.
|
||||
|
||||
[./RESULTS.md](./RESULTS.md) contains the latest results.
|
||||
|
||||
# Pretrained Models
|
||||
|
||||
The following table lists the pretrained models.
|
||||
|
||||
| | Huggingface | Comment |
|
||||
|---------------------------------------|--------------------|-----------------------------|
|
||||
| `zipformer` | zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24 | Trained with the [multi-hans-zh recipe](../../multi_zh-hans/ASR/zipformer/) |
|
||||
| `whisper` | yuekai/icefall_asr_wenetspeech_whisper | Trained with the [wenetspeech recipe](../../wenetspeech/ASR/whisper/) |
|
92
egs/speechio/ASR/RESULTS.md
Normal file
92
egs/speechio/ASR/RESULTS.md
Normal file
@ -0,0 +1,92 @@
|
||||
## Results
|
||||
|
||||
### SpeechIO Test Set Decoding Results
|
||||
|
||||
##### Decoding results using pretrained [multi-hans-zh zipformer](https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24), [whisper-large-v2](https://github.com/openai/whisper/blob/main/whisper/__init__.py#L27), [whisper-large-v2-wenetspeech-ft](https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper).
|
||||
|
||||
| | zipformer_transducer | zipformer_transducer_blank_penalty_2 | whisper_large_v2 | whisper_large_v2_wenetspeech | whisper_large_v2_wenetspeech_zipformer_fusion |
|
||||
|------------------------|----------------------|--------------------------------------|------------------|------------------------------|-----------------------------------------------|
|
||||
| SPEECHIO_ASR_ZH00000 | 10.04 | 8.04 | 11.4 | 9.88 | 7.78 |
|
||||
| SPEECHIO_ASR_ZH00001 | 1.67 | 1.51 | 2.49 | 1.57 | 1.38 |
|
||||
| SPEECHIO_ASR_ZH00002 | 5.89 | 5.27 | 7.89 | 5.65 | 4.99 |
|
||||
| SPEECHIO_ASR_ZH00003 | 2.66 | 2.79 | 5.94 | 2.27 | 2.33 |
|
||||
| SPEECHIO_ASR_ZH00004 | 3.6 | 3.34 | 4.57 | 3.62 | 3.26 |
|
||||
| SPEECHIO_ASR_ZH00005 | 7.54 | 5.81 | 8.39 | 7.26 | 5.43 |
|
||||
| SPEECHIO_ASR_ZH00006 | 15.59 | 13.34 | 19.07 | 13.64 | 11.96 |
|
||||
| SPEECHIO_ASR_ZH00007 | 15.9 | 15.05 | 16.7 | 14.06 | 13.73 |
|
||||
| SPEECHIO_ASR_ZH00008 | 11.07 | 9.68 | 14.69 | 10.34 | 8.87 |
|
||||
| SPEECHIO_ASR_ZH00009 | 7.38 | 6.23 | 8.32 | 6.74 | 5.96 |
|
||||
| SPEECHIO_ASR_ZH00010 | 9.19 | 7.33 | 11.2 | 8.85 | 6.97 |
|
||||
| SPEECHIO_ASR_ZH00011 | 4.16 | 3.84 | 54.56 | 4.09 | 3.72 |
|
||||
| SPEECHIO_ASR_ZH00012 | 7.61 | 6.58 | 10.53 | 8.35 | 6.27 |
|
||||
| SPEECHIO_ASR_ZH00013 | 8.72 | 7.66 | 9.32 | 7.26 | 6.7 |
|
||||
| SPEECHIO_ASR_ZH00014 | 9.69 | 8.71 | 9.03 | 7.03 | 6.59 |
|
||||
| SPEECHIO_ASR_ZH00015 | 11.94 | 11.37 | 16.58 | 12.02 | 11.11 |
|
||||
| SPEECHIO_ASR_ZH00016 | 9.79 | 8.79 | 14.1 | 10.19 | 8.15 |
|
||||
| SPEECHIO_ASR_ZH00017 | 8 | 6.72 | 9.04 | 8.9 | 6.44 |
|
||||
| SPEECHIO_ASR_ZH00018 | 5.42 | 5.02 | 6.06 | 4.86 | 4.4 |
|
||||
| SPEECHIO_ASR_ZH00019 | 11.26 | 9.06 | 14.8 | 9.83 | 8.22 |
|
||||
| SPEECHIO_ASR_ZH00020 | 4.37 | 4.23 | 5.97 | 4.23 | 4.13 |
|
||||
| SPEECHIO_ASR_ZH00021 | 7.81 | 6.34 | 8.53 | 7.08 | 5.88 |
|
||||
| SPEECHIO_ASR_ZH00022 | 9.11 | 8.54 | 9.7 | 8.97 | 8.02 |
|
||||
| SPEECHIO_ASR_ZH00023 | 9.98 | 8.98 | 6.31 | 9.44 | 8.57 |
|
||||
| SPEECHIO_ASR_ZH00024 | 16.15 | 12.95 | 20.54 | 15.92 | 12.28 |
|
||||
| SPEECHIO_ASR_ZH00025 | 10.38 | 9.82 | 11.4 | 10.26 | 9.27 |
|
||||
| SPEECHIO_ASR_ZH00026 | 5.69 | 5.63 | 9.09 | 5.95 | 5.51 |
|
||||
| Average WER (001-026) | 8.48 | 7.48 | 12.11 | 8.01 | 6.93 |
|
||||
|
||||
|
||||
|
||||
|
||||
Command for decoding using fine-tuned whisper:
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://huggingface.co/yuekai/icefall_asr_wenetspeech_whisper
|
||||
ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-4-avg3.pt whisper/exp_large_v2_wenetspeech/epoch-999.pt
|
||||
|
||||
python3 ./whisper/decode.py \
|
||||
--exp-dir whisper/exp_large_v2_wenetspeech \
|
||||
--model-name large-v2 \
|
||||
--epoch 999 --avg 1 \
|
||||
--start-index 0 --end-index 26 \
|
||||
--remove-whisper-encoder-input-length-restriction True \
|
||||
--manifest-dir data/fbank \
|
||||
--beam-size 1 --max-duration 50
|
||||
```
|
||||
Command for decoding using pretrained zipformer:
|
||||
```bash
|
||||
git lfs install
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/zrjin/icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
|
||||
cd icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24
|
||||
git lfs pull --include "exp/pretrained.pt"
|
||||
git lfs pull --include "data/lang_bpe_2000/*"
|
||||
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/exp/pretrained.pt zipformer/exp_pretrain/epoch-999.pt
|
||||
ln -s ../icefall-asr-multi-zh-hans-zipformer-ctc-2023-10-24/data/lang_bpe_2000/ ./data
|
||||
wget https://huggingface.co/pkufool/icefall-asr-zipformer-wenetspeech-20230615/resolve/main/data/lang_char/words.txt
|
||||
mv words.txt ./data/lang_bpe_2000/
|
||||
|
||||
./zipformer/decode.py \
|
||||
--epoch 999 \
|
||||
--avg 1 \
|
||||
--blank-penalty 2.0 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./zipformer/exp_pretrain \
|
||||
--max-duration 600 \
|
||||
--start-index 0 --end-index 26 \
|
||||
--manifest-dir data/fbank_kaldi \
|
||||
--decoding-method greedy_search
|
||||
```
|
||||
Command for fusing the above decoding results from whisper and zipformer:
|
||||
```bash
|
||||
python local/whisper_zipformer_fusion.py \
|
||||
--whisper-log-dir ./whisper/exp_large_v2_wenetspeech \
|
||||
--zipformer-log-dir ./zipformer/exp_pretrain/greedy_search \
|
||||
--output-log-dir ./results_fusion
|
||||
|
||||
```
|
||||
|
||||
See why the fusion helps [here](./local/whisper_zipformer_fusion.py).
|
||||
|
||||
SpeechIO fbank features, decoding scripts, logs, and decoding results
|
||||
are available at
|
||||
<https://huggingface.co/yuekai/icefall_asr_speechio>
|
@ -30,7 +30,14 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, Fbank, FbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor, str2bool
|
||||
@ -44,7 +51,13 @@ torch.set_num_interop_threads(1)
|
||||
|
||||
SPEECHIO_TESTSET_INDEX = 26 # Currently, from 0 - 26 test sets are open source.
|
||||
|
||||
def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False, fbank_dir: str = "data/fbank", whisper_fbank: bool = False):
|
||||
|
||||
def compute_fbank_speechio(
|
||||
num_mel_bins: int = 80,
|
||||
speed_perturb: bool = False,
|
||||
fbank_dir: str = "data/fbank",
|
||||
whisper_fbank: bool = False,
|
||||
):
|
||||
src_dir = Path("data/manifests")
|
||||
output_dir = Path(fbank_dir)
|
||||
num_jobs = min(8, os.cpu_count())
|
||||
@ -72,7 +85,9 @@ def compute_fbank_speechio(num_mel_bins: int = 80, speed_perturb: bool = False,
|
||||
)
|
||||
|
||||
if whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
|
||||
|
||||
@ -127,5 +142,7 @@ if __name__ == "__main__":
|
||||
|
||||
args = get_args()
|
||||
compute_fbank_speechio(
|
||||
num_mel_bins=args.num_mel_bins, fbank_dir=args.fbank_dir, whisper_fbank=args.whisper_fbank
|
||||
num_mel_bins=args.num_mel_bins,
|
||||
fbank_dir=args.fbank_dir,
|
||||
whisper_fbank=args.whisper_fbank,
|
||||
)
|
||||
|
@ -41,9 +41,12 @@ def main():
|
||||
for partition in dataset_parts:
|
||||
path = f"./data/fbank/{prefix}_cuts_{partition}.{suffix}"
|
||||
cuts = load_manifest_lazy(path)
|
||||
print(f"===================Duration statistics of {partition}===================")
|
||||
print(
|
||||
f"===================Duration statistics of {partition}==================="
|
||||
)
|
||||
cuts.describe()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
217
egs/speechio/ASR/local/whisper_zipformer_fusion.py
Normal file
217
egs/speechio/ASR/local/whisper_zipformer_fusion.py
Normal file
@ -0,0 +1,217 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2024 Author: Yuekai Zhang
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This file uses whisper and zipformer decoding results to generate fusion decoding results.
|
||||
Since whisper model is more likely to make deletion errors and zipformer model is more likely to make substitution and insertion errors,
|
||||
we trust whisper model when it makes substitution and insertion errors and trust zipformer model when it makes deletion errors.
|
||||
|
||||
Usage:
|
||||
python whisper_zipformer_fusion.py --whisper-log-dir ./whisper_decoding_log_dir --zipformer-log-dir ./zipformer_decoding_log_dir --output-log-dir ./results_fusion
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import kaldialign
|
||||
|
||||
from icefall.utils import store_transcripts, write_error_stats
|
||||
|
||||
|
||||
def get_parser():
    """Build the command-line parser for the whisper/zipformer fusion script.

    The parser defines three string options, each with a default value:
    ``--whisper-log-dir``, ``--zipformer-log-dir`` and ``--output-log-dir``.
    Defaults are shown in ``--help`` output because
    ``ArgumentDefaultsHelpFormatter`` is used.

    Returns:
      The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--whisper-log-dir",
        type=str,
        default="./recogs_whisper",
        help="The directory to store the whisper logs: e.g. recogs-SPEECHIO_ASR_ZH00014-beam-search-epoch--1-avg-1.txt",
    )

    parser.add_argument(
        "--zipformer-log-dir",
        type=str,
        default="./recogs_zipformer",
        help="The directory to store the zipformer logs",
    )

    parser.add_argument(
        "--output-log-dir",
        type=str,
        default="./results_fusion",
        help="The directory to store the fusion logs",
    )

    return parser
|
||||
|
||||
|
||||
def save_results(
    res_dir: Path,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    """Write transcripts, error statistics and a WER summary for one test set.

    For every ``(key, results)`` entry of ``results_dict`` two files are
    created inside ``res_dir``:

      - ``recogs-{test_set_name}-{key}-epoch-999-avg-1.txt`` via
        ``store_transcripts``;
      - ``errs-{test_set_name}-{key}-epoch-999-avg-1.txt`` via
        ``write_error_stats``, whose return value is recorded as the WER
        of that setting.

    Afterwards a tab-separated ``wer-summary-...`` file listing all settings
    sorted by WER is written, and the same table is printed to stdout.

    Args:
      res_dir:
        Directory that receives all output files. It must already exist.
      test_set_name:
        Name of the test set; used in file names and in printed messages.
      results_dict:
        Maps a setting name to its list of result tuples. Each list is
        sorted before being written.

    Returns:
      None.
    """
    if not results_dict:
        # Without any setting there is nothing to score, and no setting name
        # is available for the summary file name; previously this case
        # crashed on an unbound loop variable.
        print(f"No results to save for {test_set_name}")
        return

    suffix = "epoch-999-avg-1"
    test_set_wers = dict()

    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        print(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # The transcripts contain Chinese text, so do not rely on the
        # locale-dependent default encoding.
        with open(errs_filename, "w", encoding="utf-8") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )
            test_set_wers[key] = wer

        print("Wrote detailed error stats to {}".format(errs_filename))

    # The summary file has always been named after the last setting that was
    # processed; keep that name, but derive it explicitly instead of relying
    # on the loop variable surviving the loop.
    summary_key = list(results_dict)[-1]
    ranked_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    errs_info = res_dir / f"wer-summary-{test_set_name}-{summary_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf-8") as f:
        print("settings\tWER", file=f)
        for setting, val in ranked_wers:
            print("{}\t{}".format(setting, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (lowest-WER) row carries the "best" note.
    note = "\tbest for {}".format(test_set_name)
    for setting, val in ranked_wers:
        s += "{}\t{}{}\n".format(setting, val, note)
        note = ""
    print(s)
|
||||
|
||||
|
||||
def _parse_ref_field(text):
    """Convert the text following ``ref=`` into a single reference string."""
    # Local import so this helper does not depend on the file's import block.
    import ast

    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError):
        value = None

    if isinstance(value, (list, tuple)):
        # Normal case: a printed Python list of tokens. literal_eval also
        # copes with tokens that repr() printed with double quotes.
        return "".join(str(token) for token in value)
    if isinstance(value, str):
        return value
    # Not a parsable literal: fall back to the historical heuristic that
    # strips the leading "['" / trailing "']" and the "', '" separators.
    return "".join(text[2:-2].split("', '"))


def extract_hyp_ref_wavname(filename):
    """Read hypotheses, references and utterance names from a recogs file.

    The file is expected to contain lines such as

      0Phqz8RWYuE_0007-5: ref=['R', 'Y', 'Y', 'B', '它最大的优势就是进光量或者说是对光线利用率的提升']
      0Phqz8RWYuE_0007-5: hyp=而YB它最大的优势是近光量或者说是对光线利用率的提升

    A line is classified by whichever of the markers ``ref=`` / ``hyp=``
    occurs first in it, so a hypothesis whose text merely contains "ref"
    is no longer mistaken for a reference line. Lines with neither marker
    are ignored.

    Args:
      filename:
        Path of the recogs file to read (decoded as UTF-8).

    Returns:
      A tuple ``(hyps, refs, wav_name)`` of three lists:
        - hyps: the stripped text after ``hyp=`` of every hyp line;
        - refs: the reference tokens of every ref line joined into one string;
        - wav_name: the text before the first ":" of every hyp line.
    """
    ref_marker, hyp_marker = "ref=", "hyp="
    hyps, refs, wav_name = [], [], []
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            ref_pos = line.find(ref_marker)
            hyp_pos = line.find(hyp_marker)
            # The marker that appears first decides the line type; a later
            # occurrence belongs to the transcript text itself.
            is_ref = ref_pos != -1 and (hyp_pos == -1 or ref_pos < hyp_pos)
            if is_ref:
                ref_text = line[ref_pos + len(ref_marker) :].strip()
                refs.append(_parse_ref_field(ref_text))
            elif hyp_pos != -1:
                hyps.append(line[hyp_pos + len(hyp_marker) :].strip())
                wav_name.append(line.split(":")[0])
    return hyps, refs, wav_name
|
||||
|
||||
|
||||
def get_pair_filenames(
    whisper_log_dir,
    zipformer_log_dir,
    whisper_suffix="beam-search-epoch-999-avg-1",
    zipformer_suffix="greedy_search_blank_penalty_2.0-epoch-999-avg-1-context-2-max-sym-per-frame-1-blank-penalty-2.0",
    start_index=0,
    end_index=26,
):
    """Build the (whisper, zipformer) recogs file-name pair of each test set.

    Test sets are named ``SPEECHIO_ASR_ZH000NN`` where ``NN`` is the
    zero-padded index; one pair is produced for every index in
    ``[start_index, end_index]`` (both ends inclusive).

    Args:
      whisper_log_dir:
        Directory holding the whisper recogs files.
      zipformer_log_dir:
        Directory holding the zipformer recogs files.
      whisper_suffix:
        Part of the whisper file name that follows the test-set name.
      zipformer_suffix:
        Part of the zipformer file name that follows the test-set name.
      start_index:
        Index of the first test set to include.
      end_index:
        Index of the last test set to include.

    Returns:
      A list of ``(whisper_filename, zipformer_filename)`` string tuples,
      ordered by test-set index. No check is made that the files exist.
    """
    results = []
    for i in range(start_index, end_index + 1):
        partition = f"SPEECHIO_ASR_ZH000{i:02d}"
        whisper_filename = f"{whisper_log_dir}/recogs-{partition}-{whisper_suffix}.txt"
        zipformer_filename = (
            f"{zipformer_log_dir}/recogs-{partition}-{zipformer_suffix}.txt"
        )
        results.append((whisper_filename, zipformer_filename))
    return results
|
||||
|
||||
|
||||
def fusion_hyps_trust_substituion_insertion(
    hyps_whisper, hyps_zipformer, refs, ERR="*"
):
    """Fuse hypotheses, keeping whisper's symbol wherever it has one.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with ``kaldialign.align``. Alignment example:

      [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]

    The left element of a pair comes from whisper, the right one from
    zipformer, and ``ERR`` stands in for a missing symbol on either side:

      - whisper substitution -> take the left symbol
      - whisper insertion    -> take the left symbol
      - whisper deletion     -> take the right symbol

    Args:
      hyps_whisper:
        Whisper hypotheses, one string per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same order as ``hyps_whisper``.
      refs:
        References; only used to bound the number of utterances processed.
      ERR:
        Placeholder symbol passed to the aligner.

    Returns:
      A list with one fused hypothesis string per utterance.
    """

    def merge(aligned_pairs):
        # Use zipformer's symbol only where whisper has a gap.
        return "".join(
            z_sym if w_sym == ERR else w_sym for w_sym, z_sym in aligned_pairs
        )

    return [
        merge(kaldialign.align(w_text, z_text, ERR))
        for w_text, z_text, _ in zip(hyps_whisper, hyps_zipformer, refs)
    ]
|
||||
|
||||
|
||||
def fusion_hyps_trust_substituion(hyps_whisper, hyps_zipformer, refs, ERR="*"):
    """Fuse hypotheses, trusting whisper only where both models emit a symbol.

    Each whisper hypothesis is aligned against the matching zipformer
    hypothesis with ``kaldialign.align``. Alignment example:

      [('我', '你'), ('在', '*'), ('任', '任'), ('的', '的'), ('时', '时'), ('候', '候'), ('*', '呢')]

    The left element of a pair comes from whisper, the right one from
    zipformer, and ``ERR`` stands in for a missing symbol on either side:

      - whisper substitution -> take the left symbol
      - whisper insertion    -> follow zipformer, i.e. emit nothing
      - whisper deletion     -> take the right symbol

    Args:
      hyps_whisper:
        Whisper hypotheses, one string per utterance.
      hyps_zipformer:
        Zipformer hypotheses, in the same order as ``hyps_whisper``.
      refs:
        References; only used to bound the number of utterances processed.
      ERR:
        Placeholder symbol passed to the aligner.

    Returns:
      A list with one fused hypothesis string per utterance.
    """

    def merge(aligned_pairs):
        kept = []
        for w_sym, z_sym in aligned_pairs:
            if w_sym == ERR:
                # Whisper has a gap here: fill it from zipformer.
                kept.append(z_sym)
                continue
            if z_sym == ERR:
                # Only whisper emitted something: drop it.
                continue
            kept.append(w_sym)
        return "".join(kept)

    return [
        merge(kaldialign.align(w_text, z_text, ERR))
        for w_text, z_text, _ in zip(hyps_whisper, hyps_zipformer, refs)
    ]
|
||||
|
||||
|
||||
def main():
    """Fuse whisper and zipformer recognition logs for every test set.

    Parses the command line, creates the output directory if needed, then
    for each (whisper, zipformer) log pair: reads both logs, fuses the
    hypotheses with ``fusion_hyps_trust_substituion_insertion`` and saves
    the result under the setting name "fusion". References and utterance
    names are taken from the whisper log.
    """
    args = get_parser().parse_args()

    # mkdir output_log_dir
    output_dir = Path(args.output_log_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    log_pairs = get_pair_filenames(args.whisper_log_dir, args.zipformer_log_dir)
    for whisper_log, zipformer_log in log_pairs:
        hyps_whisper, refs, wav_name = extract_hyp_ref_wavname(whisper_log)
        hyps_zipformer, _, _ = extract_hyp_ref_wavname(zipformer_log)

        hyps_fusion = fusion_hyps_trust_substituion_insertion(
            hyps_whisper, hyps_zipformer, refs
        )

        # File names look like "recogs-<partition>-<suffix>.txt"; the
        # partition is the second dash-separated field of the base name.
        base_name = whisper_log.split("/")[-1]
        partition_name = base_name.split("-")[1]

        fused_results = list(zip(wav_name, refs, hyps_fusion))
        save_results(output_dir, partition_name, {"fusion": fused_results})

        print(f"Processed {partition_name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -34,9 +34,7 @@ from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
@ -53,14 +53,14 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
|
||||
from asr_datamodule import AsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from multi_dataset import MultiDataset
|
||||
from tn.chinese.normalizer import Normalizer
|
||||
from whisper.normalizers import BasicTextNormalizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from zhconv import convert
|
||||
from lhotse.cut import Cut
|
||||
from multi_dataset import MultiDataset
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
|
@ -53,9 +53,7 @@ class MultiDataset:
|
||||
path = f"{prefix}_cuts_{partition}.{suffix}"
|
||||
|
||||
logging.info(f"Loading {path} set in lazy mode")
|
||||
test_cuts = load_manifest_lazy(
|
||||
self.fbank_dir / path
|
||||
)
|
||||
test_cuts = load_manifest_lazy(self.fbank_dir / path)
|
||||
results_dict[partition] = test_cuts
|
||||
|
||||
return results_dict
|
@ -303,6 +303,17 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
@ -431,6 +442,7 @@ def decode_one_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
blank_penalty=params.blank_penalty,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -455,6 +467,7 @@ def decode_one_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
blank_penalty=params.blank_penalty,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
@ -468,8 +481,9 @@ def decode_one_batch(
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
key = f"blank_penalty_{params.blank_penalty}"
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
return {"greedy_search_" + key: hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
@ -657,6 +671,7 @@ def main():
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
params.suffix += f"-blank-penalty-{params.blank_penalty}"
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
1
egs/speechio/ASR/zipformer/train.py
Symbolic link
1
egs/speechio/ASR/zipformer/train.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../multi_zh-hans/ASR/zipformer/train.py
|
@ -16,11 +16,19 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig, KaldifeatFbank, KaldifeatFbankConfig, LilcomChunkyWriter
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
)
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@ -32,6 +40,7 @@ torch.multiprocessing.set_sharing_strategy("file_system")
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -52,6 +61,7 @@ def get_parser():
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def compute_fbank_wenetspeech_dev_test(args):
|
||||
in_out_dir = Path("data/fbank")
|
||||
# number of workers in dataloader
|
||||
@ -66,7 +76,9 @@ def compute_fbank_wenetspeech_dev_test(args):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
if args.whisper_fbank:
|
||||
extractor = WhisperFbank(WhisperFbankConfig(num_filters=args.num_mel_bins, device='cuda'))
|
||||
extractor = WhisperFbank(
|
||||
WhisperFbankConfig(num_filters=args.num_mel_bins, device="cuda")
|
||||
)
|
||||
else:
|
||||
extractor = KaldifeatFbank(KaldifeatFbankConfig(device=device))
|
||||
|
||||
|
@ -22,20 +22,19 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import (
|
||||
from lhotse import ( # KaldifeatWhisperFbank,; KaldifeatWhisperFbankConfig,
|
||||
CutSet,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
# KaldifeatWhisperFbank,
|
||||
# KaldifeatWhisperFbankConfig,
|
||||
KaldifeatFbank,
|
||||
KaldifeatFbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
set_audio_duration_mismatch_tolerance,
|
||||
set_caching_enabled,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool, get_executor
|
||||
from icefall.utils import get_executor, str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
# Do this outside of main() in case it needs to take effect
|
||||
@ -177,13 +176,6 @@ def compute_fbank_wenetspeech_splits(args):
|
||||
storage_type=LilcomChunkyWriter,
|
||||
overwrite=True,
|
||||
)
|
||||
# cut_set = cut_set.compute_and_store_features(
|
||||
# extractor=extractor,
|
||||
# storage_path=f"{output_dir}/feats_{subset}_{idx}",
|
||||
# num_jobs=args.num_workers,
|
||||
# executor=ex,
|
||||
# storage_type=LilcomChunkyWriter,
|
||||
# )
|
||||
logging.info(f"Saving to {cuts_path}")
|
||||
cut_set.to_file(cuts_path)
|
||||
|
||||
|
@ -6,8 +6,8 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
set -eou pipefail
|
||||
|
||||
nj=15
|
||||
stage=131
|
||||
stop_stage=131
|
||||
stage=0
|
||||
stop_stage=100
|
||||
|
||||
# Split L subset to this number of pieces
|
||||
# This is to avoid OOM during feature extraction.
|
||||
|
@ -74,10 +74,10 @@ It will generate the following 3 files inside $repo/exp:
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from icefall import is_module_available
|
||||
import torch
|
||||
from onnx_pretrained import OnnxModel
|
||||
|
||||
import torch
|
||||
from icefall import is_module_available
|
||||
|
||||
|
||||
def get_parser():
|
||||
|
@ -52,13 +52,13 @@ import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import whisper
|
||||
|
||||
from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from lhotse.cut import Cut
|
||||
from tn.chinese.normalizer import Normalizer
|
||||
from whisper.normalizers import BasicTextNormalizer
|
||||
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
|
||||
from zhconv import convert
|
||||
from lhotse.cut import Cut
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import (
|
||||
|
@ -834,6 +834,7 @@ def run(rank, world_size, args):
|
||||
# )
|
||||
return False
|
||||
return True
|
||||
|
||||
train_cuts = wenetspeech.train_cuts()
|
||||
train_cuts = train_cuts.filter(remove_short_and_long_utt)
|
||||
train_dl = wenetspeech.train_dataloaders(train_cuts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user